In [1]:
import torch

In [2]:
text = open('dataset.txt').read()
characters = sorted(list(set(text)))

stoi = { c:i for i, c in enumerate(characters)}

itos = { i:c for i, c in enumerate(characters)}

len(text), len(characters)

(1115393, 65)

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [4]:
batch_size = 4
block_size = 8

In [5]:
train_ratio = 0.9
train_size = int(len(text)*train_ratio)
train_text, val_text = text[:train_size], text[train_size:]

len(train_text), len(val_text)

(1003853, 111540)

In [6]:
encode = lambda s : [stoi[c] for c in s]
decode = lambda li : "".join([itos[i] for i in li])

encode("abc"), decode([39,40,41])

([39, 40, 41], 'abc')

In [7]:
train = torch.tensor(encode(train_text), dtype=torch.long)
val = torch.tensor(encode(val_text), dtype=torch.long)

len(train), len(val)

(1003853, 111540)

In [8]:
torch.manual_seed(1337)

def get_batches(type):
    if type == 'train':
        data = train
    else:
        data = val

    start_points = torch.randint(len(data) - block_size, (batch_size,))
    
    x = [data[i: i + block_size] for i in start_points]        
    y = [data[i + 1: i + block_size + 1] for i in start_points]    

    return torch.stack(x), torch.stack(y)

xb, yb = get_batches('train')

xb

tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]])

In [9]:
xb.to('cuda')

tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]], device='cuda:0')

In [10]:
for i, b in enumerate(xb):
    for j in range(block_size):
        current_context = b[: j + 1]
        prediction = yb[i][j]

        ints = [ix.item() for ix in current_context]
        print(f"Context: {(decode(ints))} -> Prediction: {itos[prediction.item()]}")

Context: o -> Prediction: u
Context: ou -> Prediction: ,
Context: ou, -> Prediction:  
Context: ou,  -> Prediction: t
Context: ou, t -> Prediction: r
Context: ou, tr -> Prediction: i
Context: ou, tri -> Prediction: b
Context: ou, trib -> Prediction: u
Context: k -> Prediction: e
Context: ke -> Prediction: e
Context: kee -> Prediction: p
Context: keep -> Prediction:  
Context: keep  -> Prediction: i
Context: keep i -> Prediction: t
Context: keep it -> Prediction:  
Context: keep it  -> Prediction: t
Context: A -> Prediction: n
Context: An -> Prediction: g
Context: Ang -> Prediction: e
Context: Ange -> Prediction: l
Context: Angel -> Prediction: o
Context: Angelo -> Prediction: .
Context: Angelo. -> Prediction: 

Context: Angelo.
 -> Prediction: N
Context:   -> Prediction: a
Context:  a -> Prediction:  
Context:  a  -> Prediction: h
Context:  a h -> Prediction: o
Context:  a ho -> Prediction: u
Context:  a hou -> Prediction: s
Context:  a hous -> Prediction: e
Context:  a house -> Predic

In [11]:
embedding_dim = (len(characters), len(characters))
embedding_dim

(65, 65)

In [12]:
temp = None

In [27]:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, emb_dims):
        super(Model, self).__init__()
        self.temp = None
        self.embedding_table = nn.Embedding(emb_dims[0], emb_dims[1])

    def forward(self, x, y=None):
        logits = self.embedding_table(x)
        if y is None:
            return logits, None
        B, T, C = logits.shape
        print(B, T, C)
        # cross entropy expects 
        logits = logits.view(B*T, C)
        y = y.view(B*T)
        loss = F.cross_entropy(input=logits, target=y)
        return logits, loss
    
    def generate(self, idx, max_tokens):
        for i in range(0, max_tokens):
            logits, loss = self(idx, None) # logits -> 4, 8, 65
            logits = logits[:, -1, :] # because we have to predict the next token in the sequence. it logits are present at the last location. shape = 4, 65. Weight of each character in vocabulary for last char in input for each batch
            probs = F.softmax(logits, dim=1) # probability per batch
            idx_next = torch.multinomial(probs, num_samples=1)
            # print(idx_next)
            idx = torch.cat([idx, idx_next], dim=1)
            self.temp = idx
        return idx


In [28]:
m = Model(embedding_dim)
logits, loss = m(xb, yb)

32 8 65


In [15]:
v = m.generate(torch.zeros((1, 1), dtype=torch.long), max_tokens=8)

v[0]

tensor([ 0, 50,  7, 29, 37, 48, 58,  5, 15])

In [16]:
v

tensor([[ 0, 50,  7, 29, 37, 48, 58,  5, 15]])

In [17]:
torch.zeros((1, 1), dtype=torch.long).tolist()

[[0]]

In [18]:
print(decode(v[0].tolist()))


l-QYjt'C


In [19]:
batch_size = 32
epochs = 10000

optimizer = torch.optim.Adam(m.parameters(), lr=1e-3)

In [20]:
for i in range(epochs):
    xb, yb = get_batches('train')
    _, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if i % 1000 == 0:
        print(f"Epoch {i} Loss: {loss.item()}")

print(f"Final loss = {loss.item()}")

Epoch 0 Loss: 4.656340599060059
Epoch 1000 Loss: 3.68070912361145
Epoch 2000 Loss: 2.9998435974121094
Epoch 3000 Loss: 2.8883938789367676
Epoch 4000 Loss: 2.6013710498809814
Epoch 5000 Loss: 2.461345911026001
Epoch 6000 Loss: 2.572709560394287
Epoch 7000 Loss: 2.422856330871582
Epoch 8000 Loss: 2.3459699153900146
Epoch 9000 Loss: 2.4689576625823975
Final loss = 2.475316047668457


In [21]:
loss

tensor(2.4753, grad_fn=<NllLossBackward0>)

In [22]:
v = m.generate(torch.zeros((1, 1), dtype=torch.long), max_tokens=1000)
print(decode(v[0].tolist()))


By arermet hn y, denjohece w illd CHAL, mer thoun s's:
Thicuntilalllevise sthat dy hangilyoteng h hasbe pave pirance
RDe hicomyonthar's
PES:
AKEd ith henourzincenonthioneir thondy, y heltieiengerofo'dsssit ey
KINld pe wither vouprrouthercckehathe; d!
My hind tt hinig t ouchos tes; st yo hind wotte grotonear 'so itJas
Waketancothanan hay.JUCle n prids, r loncave w hollular s O:
HIs; ht anjx?

DUThineent.

LaZEESTEORDY:
h l.
KEONGBUCHand po be y,-JZNEEYowddy scat t tridesar, wne'shenous s ls, theresseys
PlorseelapinghienHen yof GLANCHI me. strsithisgothers je are!
ABer wotouciullle's fldrwertho s?
NDan'spererds cist ripl chyreer orlese;
Yo jowhan, hecere ek? wf HEThot mowo soaf lou f;
Ane his, t, f at. fal thetrimy tepof tor atha s y d utho fplimimave.
NENTIt cir selle p wie wede
Ro n apenor f'Y toverawitys an sh d w t e w!
CELINoretoaveE IINGAwe n ck. cung.
ORDUSURes hacin benqurd bll, d a r w wistatsowor ath
Fivet bloll ail aror;
ARKIO:
My f tsce larry t I Ane szy t
A hy thit,
n.
Faur

# temp trails

In [23]:
arr = [
    [[2,3,4], [2,3,4], [2,3,4],[2,3,4],[2,3,4],[2,3,4],[2,3,4],[2,3,4],[2543,3,4],[2,332,4]],
    [[2,3,4], [2,3,4], [2,3,4],[2,3,4],[2,3,4],[2,3,4],[2,3,4],[2,3,4],[2543,3,4],[2,332,4]],
    [[12,3,4], [12,3,4], [12,3,4],[12,3,4],[12,3,4],[12,3,4],[12,3,4],[12,3,4],[12,3,4],[12,332,4]],
    [[21,3,4], [2,13,4], [2,31,4],[2,3,14],[2,3,41],[2,3,4],[21,3,4],[2543,3,4],[12,3,4],[21,332,4]]
]

temp = torch.stack([torch.tensor(ix) for ix in [i for i in arr]])


In [24]:
temp.shape, temp[-1].shape, temp[: -1].shape, temp[:, -1, :].shape

(torch.Size([4, 10, 3]),
 torch.Size([10, 3]),
 torch.Size([3, 10, 3]),
 torch.Size([4, 3]))

In [25]:
temp[:, 1, :], temp[:, 1, :].shape

(tensor([[ 2,  3,  4],
         [ 2,  3,  4],
         [12,  3,  4],
         [ 2, 13,  4]]),
 torch.Size([4, 3]))

In [26]:
temp[: -1].shape

torch.Size([3, 10, 3])