In [1]:
import torch
from torch import nn


In [9]:
# Read the whole corpus into memory as a single string.
with open("input.txt", 'r', encoding="UTF-8") as f:
    text = f.read()

# The vocabulary is every distinct character in the corpus; sorting makes the
# character -> index assignment deterministic from run to run.
vocab = sorted(set(text))
vocab_size = len(vocab)
print(vocab_size)

65


In [13]:
# Character-level tokenizer tables built from the sorted vocabulary:
# stoi maps character -> index, itos maps index -> character.
stoi = {ch: idx for idx, ch in enumerate(vocab)}
itos = dict(enumerate(vocab))


def encode(enc):
    """Return the list of vocabulary indices for the characters in `enc`.

    Raises KeyError if a character is not in the vocabulary.
    """
    return [stoi[c] for c in enc]


def decode(dec):
    """Return the string obtained by mapping each index in `dec` back to its character.

    Raises KeyError if an index is not in the vocabulary.
    """
    return "".join(itos[i] for i in dec)


In [15]:
# Fraction of the corpus used for training; the remainder is held out.
train_percent = 0.8

# Encode the whole corpus once as a tensor of token ids.
data = torch.tensor(encode(text))

# encode emits exactly one id per character, so an index computed from
# len(text) is also a valid split point into `data`.
train_split = int(len(text) * train_percent)
train_data, test_data = data[:train_split], data[train_split:]

In [27]:
def get_batch(data, context_size, batch_size):
    """Sample a random batch of (input, target) windows from `data`.

    Args:
        data: Sequence tensor to sample from; windows are sliced along its
            first dimension.
        context_size: Length of every sampled window.
        batch_size: Number of windows to sample.

    Returns:
        A tuple (x, y) of tensors stacked along a new leading batch dimension.
        Each row of y is the corresponding row of x shifted one position to
        the right in `data`, i.e. the next-token targets.
    """
    # randint's upper bound is exclusive, so even the largest start offset
    # leaves room for the target window, which reaches one element further.
    starts = torch.randint(0, len(data) - context_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        stop = start + context_size
        inputs.append(data[start:stop])
        targets.append(data[start + 1:stop + 1])

    return torch.stack(inputs, dim=0), torch.stack(targets, dim=0)

# Small defaults for trying get_batch by hand, e.g.
# get_batch(train_data, context_size, batch_size)
context_size, batch_size = 8, 1

In [371]:
class NGramModel(nn.Module):
    """Character-level lookup-table language model.

    Every token id selects one row of an embedding table, and that row is used
    directly as the logits for the next token. A prediction therefore depends
    only on the current token, which makes this a bigram model regardless of
    the `context_size` it is constructed with.
    """

    def __init__(self, context_size, n_dim, vocab_size):
        """Create the embedding table.

        Args:
            context_size: Accepted for interface compatibility; the model does
                not use it.
            n_dim: Width of each embedding row. Because rows are used as
                logits, this must equal `vocab_size` for training with a
                cross-entropy loss over the vocabulary and for `generate` to
                sample valid token ids.
            vocab_size: Number of distinct token ids.
        """
        super().__init__()

        self.token_embedding = nn.Embedding(vocab_size, n_dim)

    def forward(self, x, batched):
        """Look up the next-token logits for every token id in `x`.

        Args:
            x: Integer tensor of token ids; shape (B, T) when `batched` is true.
            batched: When true, flatten the (B, T, C) logits to (B*T, C), the
                layout a cross-entropy loss expects alongside flattened
                targets. When false, return the logits with the shape of `x`
                plus a trailing channel dimension.

        Returns:
            The logits tensor, flattened or not as described above.
        """
        logits = self.token_embedding(x)
        if batched:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)

        return logits

    @torch.no_grad()  # sampling never needs gradients; avoid building a graph
    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens autoregressively.

        Args:
            idx: Integer tensor of shape (B, T) holding the starting context.
            max_new_tokens: Number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the original
            context followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Only the most recent token influences the next-token
            # distribution, so embed just that column rather than the whole
            # (growing) sequence.
            logits = self(idx[:, -1:], False)[:, -1, :]
            probs = torch.softmax(logits, dim=-1)

            # Draw one token per row and append it to the running context.
            new_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, new_idx), dim=1)
        return idx





In [372]:
# Smoke test on an untrained model: push one tiny batch through the forward pass.
# The model is built before context_size is reassigned below; NGramModel does
# not use that argument, so the order only matters for the random draws.
n_gram_model = NGramModel(n_dim=vocab_size, vocab_size=vocab_size, context_size=context_size)

batch_size, context_size = 1, 2
x, y = get_batch(train_data, context_size, batch_size)

# batched=True flattens the logits to one row per target token.
logits = n_gram_model(x, True)

# Flatten the targets the same way and show them.
flat_targets = y.view(-1)
print()
print(flat_targets)




tensor([50, 53])


In [683]:
from tqdm import tqdm
# The embedding rows double as next-token logits (class labels run from 0 to
# vocab_size - 1), so the embedding width n_dim is set to vocab_size.
context_size = 2
n_gram_model = NGramModel(
    context_size=context_size,
    n_dim=vocab_size,
    vocab_size=vocab_size,
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=n_gram_model.parameters(), lr=1e-3)


In [685]:
# Each "epoch" here is a single optimizer step on one freshly sampled mini-batch.
epochs = 10000
batch_size = 32
context_size = 8

for epoch in tqdm(range(epochs)):
    x, y = get_batch(train_data, context_size, batch_size)

    # Drop the previous step's gradients before computing this step's.
    optimizer.zero_grad(set_to_none=True)

    # batched=True returns logits flattened to (B*T, C); flattening the targets
    # to (B*T,) gives CrossEntropyLoss the class dimension in second position,
    # as it expects.
    logits = n_gram_model(x, True)
    loss = loss_fn(logits, y.view(-1))

    loss.backward()
    optimizer.step()

# Loss of the final mini-batch only, not an average over training.
print(loss)



100%|██████████| 10000/10000 [00:12<00:00, 798.54it/s]

tensor(2.4617, grad_fn=<NllLossBackward0>)





In [636]:
# Sanity check: draw a single (input, target) pair with a context of two tokens.
# Note that an index tensor such as [[0, 0]] plucks row 0 out of the embedding twice.
print(get_batch(train_data, 2, 1))

# Seed generation with an encoded prompt. An alternative start is a single
# token 0 (a line break, per the original note), e.g.
# idx=torch.zeros((1, 1), dtype=torch.long).
prompt = torch.tensor([encode("I hate my life ")])
sample = n_gram_model.generate(idx=prompt, max_new_tokens=100)

# Show the result as a batch, as its single row, as plain ints, and as text.
generated = sample[0]
print(sample)
print(generated)
print(generated.tolist())
print(decode(generated.tolist()))

(tensor([[57,  1]]), tensor([[ 1, 46]]))
tensor([[21,  1, 46, 39, 58, 43,  1, 51, 63,  1, 50, 47, 44, 43,  1, 46, 53, 61,
          8,  0, 41, 39, 52, 53,  1, 54, 56, 43,  1, 39, 45, 57,  8,  0, 30, 10,
          0, 59, 57, 58, 43, 39, 58, 46,  7, 50, 53, 59, 54, 57, 58, 46, 39, 58,
         46, 47, 58, 46, 39, 56, 50, 50, 53, 59, 56, 43,  1, 57,  1, 47, 50,  1,
         58, 46, 43, 57, 58,  1, 46, 43,  6,  1, 40, 43, 16,  1, 46,  1, 51, 39,
         52, 43, 43, 50,  1, 46,  1, 40, 39, 51,  6,  0, 23, 21, 15, 21,  1, 53,
          1, 39, 52, 53,  1, 51, 43]])
tensor([21,  1, 46, 39, 58, 43,  1, 51, 63,  1, 50, 47, 44, 43,  1, 46, 53, 61,
         8,  0, 41, 39, 52, 53,  1, 54, 56, 43,  1, 39, 45, 57,  8,  0, 30, 10,
         0, 59, 57, 58, 43, 39, 58, 46,  7, 50, 53, 59, 54, 57, 58, 46, 39, 58,
        46, 47, 58, 46, 39, 56, 50, 50, 53, 59, 56, 43,  1, 57,  1, 47, 50,  1,
        58, 46, 43, 57, 58,  1, 46, 43,  6,  1, 40, 43, 16,  1, 46,  1, 51, 39,
        52, 43, 43, 50,  1, 46,  1

In [691]:
# Attention groundwork: tril keeps the lower triangle of a matrix (diagonal
# included) and zeroes every entry above the diagonal.
test = torch.randn(4, 4)
tril = test.tril()
print(test)
print(tril)

tensor([[ 1.1383,  0.0087,  0.8069,  1.7460],
        [-0.8166, -0.5755, -1.4708, -0.6814],
        [ 0.4011,  1.7600,  0.8195, -0.0321],
        [ 0.4064, -0.0048, -0.5751,  0.1558]])
tensor([[ 1.1383,  0.0000,  0.0000,  0.0000],
        [-0.8166, -0.5755,  0.0000,  0.0000],
        [ 0.4011,  1.7600,  0.8195,  0.0000],
        [ 0.4064, -0.0048, -0.5751,  0.1558]])
