In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-05-25 16:41:51--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-05-25 16:41:51 (17.7 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import torch
import numpy as np
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
with open('input.txt', 'r') as f:
    text = f.read()

1115394


In [6]:
vocabulary = sorted(list(set(text)))
vocab_size = len(vocabulary)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [7]:
vocab_to_int = {v:i for i, v in enumerate(vocabulary)}
int_to_vocab = {i:v for i, v in enumerate(vocabulary)}

encode = lambda x: [vocab_to_int[i] for i in x]
decode = lambda x: "".join([int_to_vocab[i] for i in x])

[20, 47, 1, 58, 46, 43, 56, 43, 2]
Hi there!


In [9]:
dataset = torch.tensor(encode(text), dtype = torch.long)
train_data = dataset[:int(0.9 * len(dataset))]
val_data = dataset[int(0.9 * len(dataset)): int(0.97 * len(dataset))]
test_data = dataset[int(0.97 * len(dataset)): ]

In [10]:
class ShaksphereDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, seq_len, batch_size):
        self.dataset = dataset
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.sub_seq = (len(dataset) // self.seq_len)
        self.num_batches = self.sub_seq // self.batch_size

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        offsets = torch.randperm(len(self.dataset) - self.seq_len)[:self.sub_seq]
        inputs = torch.stack([self.dataset[i: i + self.seq_len] for i in offsets])
        outputs = torch.stack([self.dataset[i + 1: i + 1 + self.seq_len] for i in offsets])
        rem = inputs.shape[0] % self.batch_size
        if rem != 0:
            inputs = inputs[:-rem, :].reshape(self.num_batches, self.batch_size, self.seq_len)
            outputs = outputs[:-rem, :].reshape(self.num_batches, self.batch_size, self.seq_len)
        else:
            inputs = inputs.reshape(self.num_batches, self.batch_size, self.seq_len)
            outputs = outputs.reshape(self.num_batches, self.batch_size, self.seq_len)

        batch_idx = 0
        while batch_idx < self.num_batches:
            yield inputs[batch_idx, : , :].to(device), outputs[batch_idx, :, :].to(device)
            batch_idx = batch_idx + 1

In [11]:
ModelConfig = {
    'seq_len': 256,
    "batch_size": 128,
    'embed_dim': 512,
    'qk_dim': 512
}

seq_len = ModelConfig['seq_len']
batch_size = ModelConfig['batch_size']
embed_dim = ModelConfig['embed_dim']
qk_dim = ModelConfig['qk_dim']
value_dim = ModelConfig['embed_dim']
model_size = ModelConfig['embed_dim']

In [12]:
train_loader = ShaksphereDataLoader(train_data, seq_len = ModelConfig['seq_len'], batch_size = ModelConfig['batch_size'])
val_loader = ShaksphereDataLoader(val_data, seq_len = ModelConfig['seq_len'], batch_size = ModelConfig['batch_size'])

30
2


torch.Size([128, 256])
torch.Size([128, 256])


In [14]:
class MaskedAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.query_layer = torch.nn.Linear(embed_dim, embed_dim, bias = False)
        self.key_layer = torch.nn.Linear(embed_dim, qk_dim, bias = False)
        self.value_layer = torch.nn.Linear(embed_dim, value_dim, bias = False)
        self.mask = torch.tril(torch.ones(seq_len, seq_len)).to(device)

    def forward(self, tok_embeddings, return_attention_weights = False):
        B, T, C = tok_embeddings.shape
        Q = self.query_layer(tok_embeddings) # (B, T, qk)
        K = self.key_layer(tok_embeddings)   # (B, T, qk)
        V = self.value_layer(tok_embeddings) # (B, T, v)

        affinities = (Q @ K.transpose(-1, -2)) * K.shape[-1] ** -0.5 # (B, T, T)
        affinities = affinities.masked_fill(self.mask[:T, :T] == 0, float('-inf'))

        attention_weights = F.softmax(affinities, dim = -1) # (B, T, T)

        if return_attention_weights:
            return attention_weights @ V, attention_weights
        return attention_weights @ V  # (B, T, v)

In [15]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.mini_head_size = int(self.head_size / self.num_heads)

        self.query = torch.nn.Linear(embed_dim, self.head_size)
        self.key = torch.nn.Linear(embed_dim, self.head_size)
        self.value = torch.nn.Linear(embed_dim, self.head_size)
        self.up_project = torch.nn.Linear(self.num_heads * self.mini_head_size, head_size)
        self.mask = torch.tril(torch.ones(seq_len, seq_len)).to(device)

    def forward(self, tok_embeddings, return_attention_weights = False):
        B, T, C = tok_embeddings.shape
        Q = self.query(tok_embeddings)
        K = self.key(tok_embeddings)
        V = self.value(tok_embeddings)

        # Reshape into N sub heads for parallel processing
        mini_Q = Q.view(B, T, self.num_heads, self.mini_head_size).permute(0, 2, 1, 3) # (B, nh, T, mini_head)
        mini_K = K.view(B, T, self.num_heads, self.mini_head_size).permute(0, 2, 1, 3) # (B, nh, T, mini_head)
        mini_V = V.view(B, T, self.num_heads, self.mini_head_size).permute(0, 2, 1, 3) # (B, nh, T, mini_head)

        # Scaled dot Product
        affinities = (mini_Q @ mini_K.transpose(-1, -2)) * (self.mini_head_size ** -0.5) #  (B, nh, T, T)
        # Makes it causal decoder
        affinities = affinities.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
        attention_weights = F.softmax(affinities, dim = -1) # (B, nh, T, T)

        # weighted vlaues for each head
        mini_head_weighted_values = attention_weights @ mini_V # (B, nh, T, mini_head)
        # Concatenating each mini_head outputs
        mini_head_weighted_values = mini_head_weighted_values.permute(0, 2, 1, 3)
        head_weighted_values = mini_head_weighted_values.reshape(B, T, self.mini_head_size * self.num_heads) # (B, T, head_size)

        if return_attention_weights:
            return self.up_project(head_weighted_values), attention_weights
        return self.up_project(head_weighted_values) # (B, T, head_size)

In [16]:
class DecoderBlock(torch.nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        self.attention = MultiHeadAttention(num_heads=num_heads, head_size=model_size)
        self.layer_norm_1 = torch.nn.LayerNorm(model_size)
        self.linear_1 = torch.nn.Linear(model_size, model_size)
        self.linear_2 = torch.nn.Linear(model_size, model_size)
        self.relu = torch.nn.ReLU()
        self.layer_norm_2 = torch.nn.LayerNorm(model_size)
        self.dropout_1 = torch.nn.Dropout(0.1)
        self.dropout_2 = torch.nn.Dropout(0.1)

    def forward(self, embeddings, return_attention_weights = False):
        # tokens ---> (B, T, embed_size)
        B, T, C = embeddings.shape

        # Attention
        if return_attention_weights:
            weighted_values, attention_weights = self.attention.forward(embeddings, return_attention_weights)
        else:
            weighted_values = self.attention.forward(embeddings) # (B, T, embed_dim)

        # Attention Norm + dropout
        weighted_values_drp = self.dropout_1(weighted_values)
        norm_values = self.layer_norm_1(weighted_values_drp + embeddings) # (B, T, embed_dim)

        # FFN + dropout
        linear_1 = self.linear_1(norm_values)
        act_vals = self.relu(linear_1)
        linear_2 = self.linear_2(act_vals)
        linear_2_drp = self.dropout_2(linear_2)

        # LayerNorm
        ffn_norm = self.layer_norm_2(norm_values + linear_2_drp) # (B, T, embed_dim)


        if return_attention_weights:
            return ffn_norm, attention_weights
        return ffn_norm # (B, T, embed_dim)

In [17]:
class LanguageModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_layer = torch.nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = torch.nn.Embedding(seq_len, embed_dim)
        self.dropout_0 = torch.nn.Dropout(0.1)
        self.decoder_block_1 = DecoderBlock(num_heads=8)
        self.decoder_block_2 = DecoderBlock(num_heads=8)
        self.decoder_block_3 = DecoderBlock(num_heads=8)

        self.projection = torch.nn.Linear(model_size, vocab_size)
        self.projection.weight = self.embedding_layer.weight

    def forward(self, tokens, return_attention_weights = False):
        B, T = tokens.shape
        # Token and Positional Embeddings + dropout
        tok_embs = self.embedding_layer(tokens) # (B, T, embed_dim)
        pos_embs = self.positional_embedding(torch.arange(T).to(device))
        pos_tok_embs = tok_embs + pos_embs # (B, T, embed_dim)
        pos_tok_embs_drp = self.dropout_0(pos_tok_embs)

        # Decoder Blocks
        decoder_1 = self.decoder_block_1.forward(pos_tok_embs)
        decoder_2 = self.decoder_block_2.forward(decoder_1)
        if return_attention_weights:
            decoder_3, last_layer_attention_weights = self.decoder_block_3.forward(decoder_2, return_attention_weights=True)
        else:
            decoder_3 = self.decoder_block_3.forward(decoder_2)

        # projection layer
        logits = self.projection(decoder_3) # (B, T, vocab_size)

        if return_attention_weights:
            return logits, last_layer_attention_weights
        return logits # (B, T, vocab_size)

In [18]:
model = LanguageModel()
model.to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/ShakGPT_3_512.pt', map_location = device))
#summary(model, x)

<All keys matched successfully>

In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
scheduler =  torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=4) #TODO

In [21]:
optimizer.param_groups[0]['lr'] = 5e-5

In [24]:
class Trainer():
    def __init__(self, model, train_loader, val_loader, optimizer, criterion):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epoch = 0
        self.optimizer = optimizer
        self.criterion = criterion
        train_loss = 0


    def train(self):
        self.model.train()
        total_loss = 0
        for x, y in self.train_loader:
            logits, attention_weights = self.model.forward(x, return_attention_weights=True)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            y = y.view(B*T)
            loss = self.criterion(logits, y)
            total_loss = total_loss + loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()


        epoch_loss = total_loss / len(train_loader)
        self.epoch = self.epoch + 1
        print(f"Epoch: {self.epoch}\nTrain Loss: {epoch_loss}, Train Perplexity: {np.exp(epoch_loss)}, LR: {self.optimizer.param_groups[0]['lr']}")
        return epoch_loss

    @torch.no_grad()
    def validate(self):
        self.model.eval()
        total_loss = 0
        for x, y in self.val_loader:
            logits = self.model.forward(x)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            y = y.view(B*T)
            val_loss = self.criterion(logits, y)
            total_loss = total_loss + val_loss.item()

        epoch_loss = total_loss / len(self.val_loader)
        print(f'Validation Loss: {epoch_loss}')
        return epoch_loss

    @torch.no_grad()
    def generate(self, max_tokens, prompt, temperature = 1, top_p = 1):
        self.model.eval()
        for i in range(max_tokens):
            logits = self.model.forward(prompt)
            logit = logits[:, -1, :] # (B, C)
            logit = logit / temperature
            probs = F.softmax(logit, dim = -1)
            weighted_probs = self.topPTransform(probs, top_p)
            #token = torch.argmax(probs, dim = -1).view(-1, 1)
            token = torch.multinomial(weighted_probs, num_samples = 1) # (B, 1)
            prompt = torch.cat((prompt, token), dim = -1) # (B, T + 1)
        return prompt

    def topPTransform(self, probs, top_p):
        probs_sorted_vals, probs_sort_idx = torch.sort(probs, descending=True)
        prob_cumsum = torch.cumsum(probs_sorted_vals, dim = -1)

        absolute_diff = torch.abs(prob_cumsum - top_p)
        closest_index = torch.argmin(absolute_diff).item()
        idx_to_remove = probs_sort_idx[:, closest_index + 1:]

        mask = torch.ones_like(probs)
        mask[:, idx_to_remove] = 0

        probs = probs * mask
        weighted_probs = probs / torch.sum(probs, dim = -1)
        return weighted_probs

    def save(self, path):
        torch.save(self.model.state_dict(), path)
        print("Model saved at " + path)

In [25]:
epochs = 10
trainer = Trainer(model = model, train_loader=train_loader, val_loader=val_loader, optimizer=optimizer, criterion = criterion)
best_val_loss = float('inf')
for epoch in range(epochs):
    train_loss = trainer.train()
    val_loss = trainer.validate()
    #scheduler.step(val_loss)
    if val_loss < best_val_loss:
        trainer.save('/content/drive/MyDrive/ShakGPT_3_512.pt')
        best_val_loss = val_loss
    print("Generation: " + decode(trainer.generate(prompt =
                            torch.tensor(encode("BARNARDO: Who's there?\nFRANCISCO: Nay, answer me. Stand and unfold yourself.\nBARNARDO:")).view(1, -1).to(device),
                            max_tokens=170,
                            top_p = 0.7)[0].tolist()))
    print('\n')
    if epoch == 700:
        optimizer.param_groups[0]['lr'] = 1e-5

Epoch: 1
Train Loss: 8.601622295379638, Train Perplexity: 5440.478499124308, LR: 5e-05
Validation Loss: 4.496948003768921
Model saved at /content/drive/MyDrive/ShakGPT_3_512.pt
Generation: BARNARDO: Who's there?
FRANCISCO: Nay, answer me. Stand and unfold yourself.
BARNARDO:      e              e                                  e                      ee                e e                                                     e  ee            




KeyboardInterrupt: 

In [None]:
from random import uniform
import time
from IPython.display import display, clear_output

def topPTransform(probs, top_p):
        probs_sorted_vals, probs_sort_idx = torch.sort(probs, descending=True)
        prob_cumsum = torch.cumsum(probs_sorted_vals, dim = -1)

        absolute_diff = torch.abs(prob_cumsum - top_p)
        closest_index = torch.argmin(absolute_diff).item()
        idx_to_remove = probs_sort_idx[:, closest_index + 1:]

        mask = torch.ones_like(probs)
        mask[:, idx_to_remove] = 0

        probs = probs * mask
        weighted_probs = probs / torch.sum(probs, dim = -1)
        return weighted_probs

def persistent_generation(prompt, max_tokens, temperature = 1, top_p = 1):
    generated_text = prompt
    model.eval()
    tokens = torch.tensor(encode(prompt)).reshape(1, -1).to(device)
    buffer = 10
    token_count = 0
    while token_count < max_tokens:
        if tokens.shape[-1] >= seq_len:
            tokens = tokens[:, tokens.shape[-1] - seq_len + 10: ]
        logits = model.forward(tokens)
        logit = logits[:, -1, :] # (B, C)

        # temperature
        logit = logit / temperature

        # top_p
        probs = F.softmax(logit, dim = -1) # (B, C)
        weighted_probs = topPTransform(probs, top_p)
        predicted_token = torch.multinomial(weighted_probs, num_samples = 1) # (B, 1)
        generated_text = generated_text + decode(predicted_token[0].cpu().detach().tolist())
        tokens = torch.cat((tokens, predicted_token), dim = -1) # (B, T + 1) # (1, 1)
        clear_output(wait=True)

        print(generated_text)
        token_count = token_count + 1

print(persistent_generation("ACT I\n\nSCENE I. Elsinore. A platform before the Castle.\n\n\nEnter Francisco and Barnardo, two sentinels.", max_tokens = 700, top_p = 0.7))