In [None]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm

In [None]:
# Fixed seed so the dataset split, shuffling order, weight initialisation and
# sampling are repeatable after Restart Kernel -> Run All.
SEED = 1337

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the global RNG (used by randperm / parameter init / multinomial) ...
torch.manual_seed(SEED)
# ... and the dedicated generator handed to random_split and the DataLoaders.
generator = torch.Generator(device=device)
generator.manual_seed(SEED)

# Every tensor created without an explicit device now lands on `device`.
torch.set_default_device(device)
print(f"default device set to {device}")

In [None]:
# Read the whole corpus into memory as one string.
corpus_path = "/kaggle/input/shakespeare/input.txt"
with open(corpus_path, mode="r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

In [None]:
# Vocabulary: every distinct character of the corpus, sorted so that the
# character -> id assignment is deterministic.
unique_chars = set(text)
vocab = sorted(unique_chars)
vocab_size = len(vocab)

In [None]:
# Lookup tables between characters and their integer ids (position in `vocab`).
stoi = {ch: idx for idx, ch in enumerate(vocab)}
itos = dict(enumerate(vocab))

# Quick look at one entry in each direction.
print(stoi["h"])
print(itos[46])

In [None]:
def encode(e):
    """Turn an iterable of characters (e.g. a string) into a list of ids via ``stoi``."""
    return [stoi[ch] for ch in e]


def decode(d):
    """Turn an iterable of integer ids back into a string via ``itos``."""
    return "".join(itos[idx] for idx in d)


# Round-trip check: decoding the encoding should give back the same sentence.
encoded = encode("hello how are you?")
decoded = decode(encoded)

print(encoded, decoded)

In [None]:
# Number of characters in one training window.
context_size = 256

# Demo on a 10,000-character toy length: shuffled start offsets of
# non-overlapping windows, each a multiple of `context_size`.
demo_window_count = 10000 // context_size
random_idx_tensor = torch.randperm(demo_window_count) * context_size
print(random_idx_tensor)

In [None]:
def make_dataset(data):
    """Cut a 1-D id tensor into shuffled, non-overlapping next-character examples.

    Window start offsets are multiples of the module-level ``context_size``,
    visited in a random order. For each start ``s`` the input is
    ``data[s : s + context_size]`` and the label is the same span shifted one
    position to the right.

    Args:
        data: 1-D tensor of token ids.

    Returns:
        TensorDataset of ``(inputs, labels)``, both int64 with
        ``context_size`` columns.
    """
    n_windows = (len(data) - context_size) // context_size
    window_starts = torch.randperm(n_windows) * context_size

    input_windows, label_windows = [], []
    for start in window_starts:
        stop = start + context_size
        input_windows.append(data[start:stop])
        # Labels are the inputs shifted by one: the "next character" targets.
        label_windows.append(data[start + 1:stop + 1])

    inputs = torch.stack(input_windows).to(torch.long)
    labels = torch.stack(label_windows).to(torch.long)
    return TensorDataset(inputs, labels)

In [None]:
# Encode the full corpus once, then slice it into (input, label) windows.
data = torch.tensor(encode(text))
dataset = make_dataset(data=data)

# Peek at the first example; the label is the input shifted one character ahead.
sample_input, sample_label = dataset[0]

print(sample_input)
print(sample_label)

print(f"dataset length --> {len(dataset)} ({len(dataset) * context_size} characters), that is, about the length of text {len(text)} - context_size --> {len(text)-context_size}")

In [None]:
# 75 % of the windows go to training, the remainder to testing.
train_split = int(len(dataset) * 0.75)
test_split = len(dataset) - train_split

train_dataset, test_dataset = random_split(
    dataset,
    [train_split, test_split],
    generator=generator,
)

In [None]:
batch_size = 32

# Both loaders share the same settings and the same shuffling generator.
loader_kwargs = dict(batch_size=batch_size, shuffle=True, generator=generator)
train_dataloader = DataLoader(train_dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
class Head(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        n_embd: size of the last dimension of the input (C).
        head_size: output size of the query/key/value projections.
        context_size: largest sequence length supported; fixes the size of
            the lower-triangular causal mask.
    """

    def __init__(self, n_embd, head_size, context_size):
        super().__init__()

        # Each projection maps (B, T, n_embd) -> (B, T, head_size).
        self.Q = nn.Linear(in_features=n_embd, out_features=head_size)
        self.K = nn.Linear(in_features=n_embd, out_features=head_size)
        self.V = nn.Linear(in_features=n_embd, out_features=head_size)

        # Lower-triangular ones: position t may only attend to positions <= t.
        # Registered as a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer("tril", torch.tril(torch.ones(size=(context_size, context_size))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply masked self-attention.

        Args:
            x: tensor of shape (B, T, n_embd) with T <= context_size.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        q = self.Q(x)  # (B, T, head_size)
        k = self.K(x)  # (B, T, head_size)
        v = self.V(x)  # (B, T, head_size)

        # Scale by 1/sqrt(d_k) where d_k is the key dimension (head_size),
        # not the embedding width: the dot product sums over head_size terms,
        # so that is the dimension whose variance has to be normalised.
        scale = k.shape[-1] ** -0.5
        wei = (q @ k.transpose(-2, -1)) * scale  # (B, T, T)

        # Slice the mask to [:T, :T] so inputs shorter than context_size work.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = torch.softmax(wei, dim=-1)

        return wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)


In [None]:
class MultiHeadedAttention(nn.Module):
    """Several attention heads run side by side, concatenated and projected.

    The ``n_heads`` outputs of size ``head_size`` are concatenated along the
    last dimension and passed through a linear layer back to ``n_embd``, so
    the module's output has the same shape as its input.
    """

    def __init__(self, n_embd, context_size, n_heads, head_size):
        super().__init__()

        head_list = []
        for _ in range(n_heads):
            head_list.append(Head(n_embd=n_embd, head_size=head_size, context_size=context_size))
        self.heads = nn.ModuleList(head_list)

        # (B, T, n_heads * head_size) -> (B, T, n_embd)
        self.projection = nn.Linear(in_features=n_heads * head_size, out_features=n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of every head's output for ``x``."""
        per_head = [head(x) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)  # (B, T, n_heads * head_size)
        return self.projection(concatenated)

In [None]:
class FeedForward(nn.Module):
    """Two-layer perceptron applied to the last dimension.

    Widens ``in_features`` by a factor of 4, applies ReLU, then projects back
    to ``in_features``.
    """

    def __init__(self, in_features):
        super().__init__()

        hidden_features = in_features * 4
        widen = nn.Linear(in_features, hidden_features)
        narrow = nn.Linear(hidden_features, in_features)
        self.ffwrd_layer = nn.Sequential(widen, nn.ReLU(), narrow)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the expand -> ReLU -> project stack."""
        return self.ffwrd_layer(x)

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, both residual.

    Each sub-layer sees a layer-normalised copy of its input and its output is
    added back onto the un-normalised input.
    """

    def __init__(self, n_heads, head_size, n_embd, context_size):
        super().__init__()

        # Attention maps (B, T, n_embd) -> (B, T, n_embd) thanks to its output projection.
        self.multiheaded_self_attetion = MultiHeadedAttention(
            n_embd=n_embd,
            context_size=context_size,
            n_heads=n_heads,
            head_size=head_size,
        )
        self.ffwrd = FeedForward(in_features=n_embd)

        self.layer_norm1 = nn.LayerNorm(n_embd)
        self.layer_norm2 = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the residual attention and feed-forward steps."""
        attended = self.multiheaded_self_attetion(self.layer_norm1(x))
        x = x + attended

        transformed = self.ffwrd(self.layer_norm2(x))
        return x + transformed

In [None]:
class MLP(nn.Module):
    """Character-level, decoder-only transformer language model.

    Despite its name this is not a plain multi-layer perceptron: token and
    position embeddings are summed, passed through a stack of ``Block``
    modules followed by a final LayerNorm, and projected to vocabulary logits.

    Args:
        n_embd: embedding width (C).
        context_size: maximum sequence length (T) the model accepts.
        vocab_size: number of distinct tokens.
        n_heads: attention heads per block (default 16).
        head_size: size of each attention head (default 64). It does not have
            to equal ``n_embd // n_heads`` because every attention module
            projects its concatenated heads back to ``n_embd``.
        n_blocks: number of transformer blocks (default 2).
    """

    def __init__(self, n_embd, context_size, vocab_size, n_heads=16, head_size=64, n_blocks=2):
        super().__init__()

        self.vocab_size = vocab_size
        self.context_size = context_size
        self.n_embd = n_embd
        self.num_sa_heads = n_heads
        self.sa_head_size = head_size

        # One learned vector per token id and one per position in the context.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(context_size, n_embd)

        # Each block maps (B, T, C) -> (B, T, C); a final LayerNorm follows the stack.
        blocks = [
            Block(
                n_heads=self.num_sa_heads,
                head_size=self.sa_head_size,
                context_size=self.context_size,
                n_embd=self.n_embd,
            )
            for _ in range(n_blocks)
        ]
        self.attention_blocks = nn.Sequential(*blocks, nn.LayerNorm(n_embd))

        self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits for every position.

        Args:
            x: integer tensor of token ids with shape (B, T), T <= context_size.

        Returns:
            Logits flattened to shape (B * T, vocab_size), which lines up with
            labels flattened to (B * T,).

        Raises:
            ValueError: if T exceeds ``context_size`` (there is no positional
                embedding for those positions).
        """
        B, T = x.shape
        if T > self.context_size:
            raise ValueError(
                f"sequence length {T} exceeds the model's context_size {self.context_size}"
            )

        # Build the position ids on the input's device instead of relying on
        # the process-wide default device.
        positions = torch.arange(T, device=x.device)

        pos_emb = self.positional_embedding_table(positions)  # (T, C), broadcast over the batch
        token_emb = self.token_embedding_table(x)  # (B, T, C)

        x = token_emb + pos_emb  # (B, T, C)
        x = self.attention_blocks(x)  # (B, T, C)
        x = self.lm_head(x)  # (B, T, vocab_size)

        return x.view(B * T, self.vocab_size)

    def generate(self, starting_idx: torch.Tensor, max_length: int, debug: bool = False) -> str:
        """Sample ``max_length`` tokens autoregressively and return the text.

        Args:
            starting_idx: integer tensor of shape (1, T0) holding the prompt
                token ids. Only a batch of one is supported, because the
                prediction is read from the last row of the flattened logits.
            max_length: number of new tokens to sample.
            debug: when true, print the decoded context before every step.

        Returns:
            The decoded prompt followed by the sampled characters, as a string.
        """
        full_text = decode(starting_idx[0].tolist())
        context = starting_idx

        for _ in range(max_length):
            # Keep at most the last `context_size` tokens.
            context = context[:, -self.context_size:]

            if debug:
                print(f"predicting on context: {decode(context[0].tolist())}")

            logits = self(context)  # (1 * T, vocab_size)
            # Only the prediction made at the last position is needed.
            logits = logits[-1, :].view(1, self.vocab_size)
            probs = torch.softmax(logits, dim=1)
            pred = torch.multinomial(probs, num_samples=1)  # (1, 1)
            full_text += decode(pred.tolist()[0])

            # Append along the time dimension, not the batch dimension.
            context = torch.cat([context, pred], dim=1)

        return full_text

In [None]:
class model_generator:
    """Convenience wrapper that samples text from a model and stores the results.

    Attributes:
        model: object exposing ``eval()`` and
            ``generate(starting_idx=..., max_length=..., debug=...)``.
        max_length: number of tokens to sample per output.
        num_samples: number of outputs produced by each ``generate`` call.
        vocab_size: number of token ids; used to draw a random starting token.
        last_output: most recently generated string ("" when none).
        params_dict: mirror of the current settings plus the list of
            ``previous_outputs``.
    """

    def __init__(self, model: object, max_length: int, num_samples: int, vocab_size: int):
        self.model = model
        self.max_length = max_length
        self.num_samples = num_samples
        self.vocab_size = vocab_size

        self.last_output = ""

        self.params_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples,
            "previous_outputs": []
        }

    @torch.no_grad()
    def generate(self, starting_char: str = None, clear_outputs: bool = True, debug: bool = False):
        """Generate ``num_samples`` texts and store them.

        Puts the model in eval mode and leaves it there.

        Args:
            starting_char: single character used as the prompt. When ``None``
                a random token id below ``vocab_size`` is drawn once and used
                for every sample of this call.
            clear_outputs: drop previously stored outputs first.
            debug: forwarded to the model's ``generate``.
        """
        self.model.eval()

        if clear_outputs:
            self.clear_outputs()

        if starting_char is None:
            starting_char = decode([torch.randint(0, self.vocab_size, (1,)).item()])

        for _ in range(self.num_samples):
            starting_idx = torch.tensor(encode(starting_char), dtype=torch.long).view(1, 1)
            # Use the wrapped model so that update_params(model=...) takes effect.
            output = self.model.generate(starting_idx=starting_idx, max_length=self.max_length, debug=debug)
            self.params_dict["previous_outputs"].append(output)
            self.last_output = output

    def update_params(self, model: object = None, max_length: int = None, num_samples: int = None, clear_outputs: bool = None):
        """Change any of ``model``, ``max_length``, ``num_samples``.

        Arguments left as ``None`` keep their current value. When
        ``clear_outputs`` is truthy the stored outputs are cleared first.
        """
        if clear_outputs:
            self.clear_outputs()

        updated_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples
        }

        for attribute, value in updated_dict.items():
            if value is not None:
                # Keep the attribute and its mirror in params_dict in sync.
                self.params_dict[attribute] = value
                setattr(self, attribute, value)

    def clear_outputs(self):
        """Forget every stored output."""
        self.params_dict["previous_outputs"] = []
        self.last_output = ""

    # Backward-compatible alias for the original, misspelled method name.
    clear_ouptuts = clear_outputs

    def print_outputs(self, last: bool = None):
        """Print only the last output when ``last`` is truthy, otherwise all of them."""
        if last:
            print(self.last_output)
        else:
            for output in self.params_dict["previous_outputs"]:
                print(f"{output}\n\n")
            

In [None]:
# Embedding width used throughout the model.
n_embd = 1024
vocab_size = len(vocab)
model = MLP(n_embd=n_embd, context_size=context_size, vocab_size=vocab_size)
# Wrapper that samples from `model` and remembers what it produced.
mlp_generator = model_generator(model=model, max_length=32, num_samples=1, vocab_size=vocab_size)
#print(model)
# Sample 32 characters after the prompt "." -- the model has not been trained at this point.
mlp_generator.generate(starting_char=".", debug=False)
mlp_generator.print_outputs()

In [None]:
#mlp_generator.clear_ouptuts()
# Switch to one longer sample (100 generated characters).
mlp_generator.update_params(max_length=100, num_samples=1)
# No starting_char given, so a random starting character is drawn;
# previously stored outputs are cleared by default.
mlp_generator.generate()
mlp_generator.print_outputs()

In [None]:
# Sanity check: cross-entropy of the model on a single training batch,
# computed without tracking gradients.
model.eval()
test_loss_fn = nn.CrossEntropyLoss()
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))

with torch.inference_mode():
    logits = model(batch_sample_inputs)  # (B*T, vocab_size)
    labels = batch_sample_labels.view(-1)  # (B*T,) so it lines up with the logits
    loss = test_loss_fn(logits, labels)
    print(loss)

In [None]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    """Train ``model`` in place for ``epochs`` passes over ``dataloader``.

    Args:
        model: module whose forward returns logits of shape (B*T, vocab_size).
        dataloader: iterable yielding ``(X, y)`` batches of token ids.
        loss_fn: callable taking ``(logits, labels)`` and returning a scalar loss.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``dataloader``.

    The model is switched to training mode and the loss is printed every
    100 batches.
    """
    model.train()

    for epoch in range(epochs):
        # Wrap the dataloader itself (not the enumerate object) so tqdm can
        # read its length and show a total / ETA.
        for batch, (X, y) in enumerate(tqdm(dataloader)):
            logits = model(X)  # (B*T, vocab_size)
            labels = y.reshape(-1)  # (B*T,): one target per predicted character
            loss = loss_fn(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                # .item() prints a plain number instead of the tensor repr.
                print(f"loss for batch {batch} --> {loss.item():.4f} at epoch {epoch}")

In [None]:
# Cross-entropy on the flattened logits; Adam with a fixed learning rate.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Train on the training split for 15 epochs; progress and loss are printed by train_model.
train_model(model=model, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer, epochs=15)

In [None]:
# After training: draw five samples of 1000 generated characters each.
mlp_generator.update_params(max_length=1000, num_samples=5)
mlp_generator.generate()
mlp_generator.print_outputs() # prints all five samples; pass last=True to print only the most recent one