In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F

In [None]:
batch_size = 64 # B
block_size = 8 # T
# C = 65 -> vocab length
head_size = 16
steps = 5000
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-3
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout_percent = 0.0

In [3]:
from datasets import load_dataset
ds = load_dataset("Trelis/tiny-shakespeare")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
ds_ = ds['train']['Text']

In [5]:
chars = sorted(set(''.join(ds_))) # vocabulory
len(chars)

65

In [6]:
encoder_dict = {k:v for k,v in zip(chars, range(len(chars)))}
decoder_dict = {v:k for k,v in zip(chars, range(len(chars)))}

# Encoder, Decoder
encode = lambda x: [encoder_dict[letter] for letter in x]
decode = lambda x: ''.join([decoder_dict[letter] for letter in x])
encode('hello'), decode([46, 43, 50, 50, 53])

([46, 43, 50, 50, 53], 'hello')

In [7]:
ds_all = '\n'.join(ds_)
ds_encoded = encode(ds_all)

In [8]:
n = int(.9*len(ds_encoded))
train_data = ds_encoded[:n]
val_data = ds_encoded[n:]

In [9]:
def get_batch(data, block_size=block_size, batch_size=batch_size):
    # print(len(data))
    ix = torch.randint(len(data)-block_size, (batch_size,)) # these numbers are the start of each batch
    xy = torch.tensor([[data[i] for i in range(ix[i],ix[i]+block_size+1)] for i in range(batch_size)])
    xb = xy[:, :block_size]
    yb = xy[:,1:]
    # print(xb.shape, yb.shape) # (B x T)
    return xb, yb

get_batch(train_data)

(tensor([[39, 45, 59, 43,  1, 47, 52, 48],
         [53, 59, 40, 58, 44, 59, 50,  1],
         [43, 56,  1, 58, 53, 52, 45, 59],
         [ 1, 21,  1, 40, 43, 46, 43, 50],
         [53, 57, 57,  1, 50, 53, 59, 58],
         [ 1, 21,  1, 61, 47, 50, 50,  1],
         [53, 44,  1, 58, 46, 43, 57, 43],
         [43, 42, 57,  6,  0, 35, 46, 47],
         [ 1, 58, 46, 47, 52, 45, 57,  1],
         [ 1, 61, 47, 58, 46,  1, 58, 46],
         [ 0, 18, 53, 56,  1, 47, 52,  1],
         [58, 56, 47, 39, 50,  8,  0, 25],
         [ 0, 29, 33, 17, 17, 26,  1, 17],
         [26, 21, 33, 31, 10,  0, 20, 43],
         [58,  1, 58, 46, 43, 47, 56,  1],
         [59, 57,  1, 46, 53, 61,  1, 52],
         [53, 59, 50, 42,  1, 57, 53,  1],
         [41, 46,  1, 53, 44,  1, 58, 46],
         [ 1, 58, 53,  1, 40, 43,  1, 57],
         [59, 45, 46,  1, 51, 63,  1, 60],
         [58, 46, 43, 47, 56,  1, 61, 56],
         [52, 53, 58, 11,  1, 63, 53, 59],
         [57,  6,  0, 32, 46, 39, 58,  1],
         [4

In [23]:
# write code to estimate loss with train and test dataset

In [None]:
class Head(nn.Module):
    '''one head in self attention'''
    def __init__(self):
        super().__init__()
        self.key = nn.Linear(n_embd, int(n_embd/n_head), bias=False)
        self.query = nn.Linear(n_embd, int(n_embd/n_head), bias=False)
        self.value = nn.Linear(n_embd, int(n_embd/n_head), bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # registers are saved along with the model, but are never updated during training. They dont count as model parameters
        self.dropout = nn.Dropout(dropout_percent)
    
    def forward(self, x): pass

In [10]:
class LLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            # print(B, T, C)
            logits = logits.view(B*T, C)
            # print(targets.shape, B*T)
            # print(targets.dtype)
            targets = targets.float().view(B*T).long()
            # print(targets.shape)
            # print(targets)
            loss = F.cross_entropy(logits, targets)
            # print(logits.shape, loss.shape)
        return logits, loss
    
    def generate(self, idx, max_tokens=100):
        for _ in range(max_tokens):
            logits, loss = self(idx)
            # logits is (B x T x C)
            logits = logits[:, -1, :] # take only the last(latest) one in T component
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

llm = LLM(len(chars))
# llm.forward(get_batch(train_data)[0])
# decode(llm.generate(get_batch(train_data)[0]))
print(decode(llm.generate(idx = torch.zeros((1, 1), dtype=torch.long))[0].tolist()))



jXn?wKPcMkibnuuaiSnfRxTegQFus&CGaccTtO!ZutVo.bnggYpJzOQEVutny
Ocwdz!-$?y
;xtQOCU-N,W3Q!W,SugVI!ayIa$


In [12]:
optimizer = torch.optim.AdamW(llm.parameters(), lr=1e-3)
for steps in range(1000): # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch(train_data)

    # evaluate the loss
    logits, loss = llm(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # if steps % 100 == 0: print(loss.item())
print(loss.item())

64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
64 8 65
