In [None]:
import torch
from torch import nn

In [None]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
class DecoderTransformer(nn.Module):
    """Character-level decoder-only transformer.

    Pipeline: token embedding -> sinusoidal position encoding -> ``n_layers``
    stacked ``Block``s -> linear projection to vocabulary logits.

    Args:
        n_layers: number of stacked ``Block`` layers.
        embedding_dim: width of the token embeddings and of every Block.
        vocab_size: number of distinct tokens (rows of the embedding table and
            width of the output logits).
        context_length: sequence length used to size the position-encoding
            table and each Block's attention mask.
        head_size: query/key projection width inside each Block's attention.
        fc_inner_size: hidden width of each Block's feed-forward MLP.
    """

    def __init__(self, n_layers, embedding_dim, vocab_size, context_length, head_size, fc_inner_size):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_encoding = PositionEncoding(embedding_dim, context_length)
        layers = []
        for _ in range(n_layers):
            layers.append(Block(embedding_dim, head_size, context_length, fc_inner_size))
        self.blocks = nn.Sequential(*layers)
        self.last_fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        hidden = self.position_encoding(self.embedding(x))  # (batch, seq, embedding_dim)
        hidden = self.blocks(hidden)
        # One next-token prediction per position, so the result is (batch, seq, vocab_size).
        return self.last_fc(hidden)

class PositionEncoding(nn.Module):
    """Fixed sinusoidal position encoding added to token embeddings.

    Uses the formulation from "Attention Is All You Need":

        PE[pos, 2i]   = sin(pos / base ** (2i / embedding_dim))
        PE[pos, 2i+1] = cos(pos / base ** (2i / embedding_dim))

    The table is precomputed once for ``context_len`` positions and stored as
    the buffer ``pe``, so it is not recomputed on every forward pass.

    Args:
        embedding_dim: width of the embeddings (odd widths are supported).
        context_len: number of positions precomputed (maximum sequence length).
        base: wavelength base; 10000 is the value used in the original paper.
    """

    def __init__(self, embedding_dim, context_len, base=10000.0):
        super().__init__()

        positions = torch.arange(context_len, dtype=torch.float32).unsqueeze(1)  # (context_len, 1)
        even_indices = torch.arange(0, embedding_dim, 2, dtype=torch.float32)  # the "2i" values
        div_term = base ** (-even_indices / embedding_dim)  # (ceil(embedding_dim / 2),)

        pe = torch.zeros(context_len, embedding_dim)
        pe[:, 0::2] = torch.sin(positions * div_term)
        # With an odd embedding_dim there is one fewer odd column than even column,
        # so only the first embedding_dim // 2 frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(positions * div_term[: embedding_dim // 2])

        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        """Add position encodings to ``word_embeddings`` of shape (batch, seq, embedding_dim).

        Raises:
            ValueError: if ``seq`` exceeds the number of precomputed positions.
        """
        _, seq_len, _ = word_embeddings.shape
        if seq_len > self.pe.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {self.pe.shape[0]} precomputed positions"
            )
        return word_embeddings + self.pe[:seq_len, :]
    
class SelfAttention(nn.Module):
    """Single-head causal (masked) self-attention.

    Queries and keys are projected to ``head_size``. Values keep the full
    ``embedding_dim``, so the output has the same width as the input.

    Args:
        embedding_dim: width of the input token representations.
        head_size: width of the query/key projections.
        context_length: maximum sequence length; sizes the precomputed causal mask.
    """

    def __init__(self, embedding_dim, head_size, context_length):
        super().__init__()

        self.query = nn.Linear(embedding_dim, head_size)
        self.key = nn.Linear(embedding_dim, head_size)
        self.value = nn.Linear(embedding_dim, embedding_dim)

        # Lower-triangular ones: position t may only attend to positions <= t.
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq, embedding_dim) with seq <= context_length.

        Returns:
            Tensor of shape (batch, seq, embedding_dim).

        Raises:
            ValueError: if seq exceeds the context_length the mask was built for.
        """
        seq_len = x.shape[1]
        if seq_len > self.tril.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {self.tril.shape[0]}"
            )

        q = self.query(x)  # (batch, seq, head_size)
        k = self.key(x)    # (batch, seq, head_size)
        v = self.value(x)  # (batch, seq, embedding_dim)

        head_size = q.shape[-1]
        # Scaled dot-product scores, shape (batch, seq, seq).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (head_size ** 0.5)

        # Slice the mask to the actual sequence length so that inputs shorter
        # than context_length (prefixes, generation prompts) are handled correctly.
        causal_mask = self.tril[:seq_len, :seq_len]
        scores = scores.masked_fill(causal_mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)

        return torch.matmul(weights, v)  # (batch, seq, embedding_dim)
        
    
class Block(nn.Module):
    """One decoder layer: causal self-attention followed by a feed-forward MLP.

    Each sub-layer is wrapped in a residual connection (output = sublayer(x) + x).
    No normalization layer is applied.

    Args:
        embedding_dim: width of the token representations (preserved by the layer).
        head_size: query/key projection width passed to SelfAttention.
        context_length: sequence length passed to SelfAttention for its causal mask.
        fc_inner_size: hidden width of the two-layer ReLU MLP.
    """

    def __init__(self, embedding_dim, head_size, context_length, fc_inner_size):
        super().__init__()

        self.attention = SelfAttention(embedding_dim, head_size, context_length)
        expand = nn.Linear(embedding_dim, fc_inner_size)
        contract = nn.Linear(fc_inner_size, embedding_dim)
        self.fc = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        # x: (batch size, context length, embedding dim); the output has the same shape.
        attended = x + self.attention(x)
        return attended + self.fc(attended)

In [None]:
# Load the whole Tiny Shakespeare corpus into a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [None]:
n_characters = len(text)  # corpus size in characters
n_characters

In [None]:
# Preview the opening 1000 characters of the corpus.
preview = text[0:1000]
print(preview)

In [None]:
# The vocabulary is every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars), vocab_size, sep='\n')

In [None]:
# Character-level tokenizer: each vocabulary character maps to its index in `chars`.
char_to_id = {ch: i for i, ch in enumerate(chars)}
id_to_char = {i: ch for i, ch in enumerate(chars)}


def encode(text):
    """Convert a string to a list of integer token ids, one per character.

    Raises:
        KeyError: if ``text`` contains a character not in the vocabulary.
    """
    return [char_to_id[character] for character in text]


def decode(ids):
    """Convert an iterable of token ids back into a string.

    Accepts Python ints or anything ``int()`` understands (e.g. the elements of
    a 1-D integer tensor), so model outputs can be decoded directly.
    """
    return ''.join(id_to_char[int(index)] for index in ids)

In [None]:
# Round-trip a sample string through the tokenizer to sanity-check it.
sample_ids = encode("Hello mama")
print(sample_ids)
print(decode(sample_ids))

In [None]:
# Encode the whole corpus as a 1-D tensor of int64 character ids.
data = torch.tensor(encode(text), dtype=torch.int64)
first_chunk = data[:1000]
print(first_chunk)

In [None]:
# Contiguous 90% / 10% train / validation split (no shuffling).
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [None]:
# Model / data hyperparameters.
# NOTE(review): with the 10-layer model instantiated below, these widths give
# roughly 215M parameters — very large for a character-level model with an
# 8-token context; confirm this is intended.
context_length = 8      # tokens per training window (a.k.a. block_size)
embedding_dim = 2048    # width of token embeddings and of every Block
head_size = 128         # query/key projection width in SelfAttention
fc_inner_layer = 4096   # hidden width of each Block's feed-forward MLP
train_data[:context_length + 1]

In [None]:
# One window of length context_length holds context_length training examples:
# for each t, the model sees x[:t+1] and should predict y[t], the next character.
x = data[:context_length]
y = data[1:context_length + 1]
for t in range(context_length):
    prefix, target = x[:t + 1], y[t]
    print("Context:", prefix, "Target:", target)

In [None]:
torch.manual_seed(1337)
batch_size = 4  # how many independent sequences will we process in parallel?


def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    ``split == 'train'`` samples from ``train_data``; any other value samples
    from ``val_data``. The globals ``batch_size`` and ``context_length`` are read
    at call time, so rebinding them later changes the batch shape.

    Returns:
        ``(x, y)``, two tensors of shape (batch_size, context_length). ``y`` is
        the same window offset by one, i.e. ``y[:, t]`` is the token after ``x[:, t]``.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data
    starts = torch.randint(len(source) - context_length, (batch_size,))  # random offsets
    inputs = torch.stack([source[s:s + context_length] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + context_length] for s in starts])
    return inputs, targets


xb, yb = get_batch('train')
for label, batch_part in (('inputs:', xb), ('targets:', yb)):
    print(label)
    print(batch_part.shape)
    print(batch_part)

print('----')

for b in range(batch_size):  # batch dimension
    for t in range(context_length):  # time dimension
        context, target = xb[b, :t + 1], yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

In [None]:
# One batch of inputs for a shape smoke test; the targets are not needed here.
# NOTE(review): this name shadows the built-in `input`.
input, _sample_targets = get_batch('train')

In [None]:
input.size()  # (batch_size, context_length)

In [None]:
model = DecoderTransformer(
    n_layers=10,
    embedding_dim=embedding_dim,
    vocab_size=vocab_size,
    context_length=context_length,
    head_size=head_size,
    fc_inner_size=fc_inner_layer,
)
model(input).shape  # (batch_size, context_length, vocab_size)

In [None]:
# Run the model on every prefix of the example window `x`.
# DecoderTransformer expects (batch, seq) input, so each 1-D prefix is given a
# leading batch dimension of 1. Prefixes shorter than context_length also
# require SelfAttention's causal mask to match the actual sequence length.
# This is inference only, so gradients are not tracked.
with torch.no_grad():
    for t in range(context_length):
        prefix = x[:t + 1].unsqueeze(0)  # (1, t + 1)
        print("Model output:", model(prefix).shape)
        print("Context:", x[:t + 1].shape, "Target:", y[t])

In [None]:
# Token-level cross-entropy objective, optimized with Adam over all model parameters.
# NOTE(review): lr=1e-6 is far below the learning rates usually used with Adam
# for models like this (~1e-4 to 1e-3); confirm it is intended.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-6)

In [None]:
batch_size = 32  # rebinding this global also changes the batches get_batch() returns
n_epochs = 10000  # number of optimizer steps (one random batch each), not passes over the dataset
log_interval = 100  # print the training loss every `log_interval` steps

for step in range(n_epochs):
    xb, yb = get_batch('train')  # each (batch_size, context_length)

    logits = model(xb)  # (batch_size, context_length, vocab_size)

    # CrossEntropyLoss takes class scores as (N, C) and targets as (N,), so fold
    # the batch and time dimensions into a single axis of N = B*T predictions.
    B, T, C = logits.shape
    logits = logits.reshape(B * T, C)
    yb = yb.reshape(B * T)
    loss = criterion(logits, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Log periodically (and on the final step) instead of one line per step.
    if step % log_interval == 0 or step == n_epochs - 1:
        print(f"step {step:5d} | train loss {loss.item():.4f}")