# Introduction
In this laboratory we will get our hands dirty working with Large Language Models (e.g. GPT and BERT) to do various useful things. I you haven't already, it is highly recommended to:

+ Read the [Attention is All you Need](https://arxiv.org/abs/1706.03762) paper, which is the basis for all transformer-based LLMs.
+ Watch (and potentially *code along*) with this [Andrej Karpathy video](https://www.youtube.com/watch?v=kCc8FmEb1nY) which shows you how to build an autoregressive GPT model from the ground up.

# Exercise 1: Warming Up
In this first exercise you will train a *small* autoregressive GPT model for character generation (the one used by Karpathy in his video) to generate text in the style of Dante Aligheri. Use [this file](https://archive.org/stream/ladivinacommedia00997gut/1ddcd09.txt), which contains the entire text of Dante's Inferno (**note**: you will have to delete some introductory text at the top of the file before training). Train the model for a few epochs, monitor the loss, and generate some text at the end of training. Qualitatively evaluate the results 

In [3]:
# Your code here.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import wandb
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Data proeprocessing

In [2]:
from os.path import expanduser

class TextDS(Dataset):
    def __init__(self, block_size: int, path = expanduser('~') + "/datasets/commedia.txt", start: float = 0, stop: float= 0.7) -> None:
        super().__init__()
        with open(path, 'r') as f:
            self.text = f.read()
        l = len(self.text)
        self.text = self.text[int(start * l): int(stop * l)]
        self.vocab = sorted(list(set(self.text)))
        self.vocab_size = len(self.vocab)
        self.stoi = { ch: i for i, ch in enumerate(self.vocab) }
        self.itos = { i: ch for i, ch in enumerate(self.vocab) }
        self.encode = lambda s: [self.stoi[c] for c in s]
        self.decode = lambda l: ''.join([self.itos[i] for i in l])
        self.block_size = block_size

        self.data = torch.Tensor(self.encode(self.text)).type(torch.LongTensor)

    def __len__(self): return len(self.text) - (self.block_size + 1)

    def __getitem__(self, index):
        x = self.data[index: index + self.block_size]
        y = self.data[index + 1: index + self.block_size + 1]
        return x, y

In [3]:
block_size = 32
batch_size = 8
train_ds = TextDS(block_size, start=0, stop=0.3)
train_dl = DataLoader(train_ds, batch_size, True)
val_ds = TextDS(block_size, start=0.7, stop=0.9)
val_dl = DataLoader(val_ds, batch_size, True)
test_ds = TextDS(block_size, start=0.9, stop=1)
train_dl = DataLoader(train_ds, batch_size, True)

for x, y in train_dl:
    print(x.shape, y.shape)
    break


torch.Size([8, 32]) torch.Size([8, 32])


In [74]:
class AttentionHead(nn.Module):
    # take in input an embedding

    def __init__(self, input_size: int, head_size: int, block_size: int, masked: bool = True) -> None:
        super(AttentionHead, self).__init__()
        self.Q = nn.Linear(input_size, head_size)
        self.K = nn.Linear(input_size, head_size)
        self.V = nn.Linear(input_size, head_size)
        self.dropout = nn.Dropout()
        self.d = head_size
        self.masked = masked
        if self.masked:
            self.tril = torch.tril(torch.ones((block_size, block_size))).to(device)

    def forward(self, X):
        # X [B T C]
        # B, T, C = X.shape
        q = self.Q(X) # [B T D]
        k = self.K(X)
        v = self.V(X)

        qk: torch.Tensor = (q @ k.transpose(-1, -2)) / (self.d ** 0.5) # [B T D] @ [B D T] = [B T T]      
        if self.masked:
            qk = self.dropout(qk)
            qk = qk.masked_fill(self.tril == 0, float('-inf'))
        # qk = F.softmax(qk, dim=-1)
        # print("a", qk.isnan().any())
        return F.softmax(qk, dim=-1) @ v # [B T T] @ [B T D] = [B T D]
        
class MultiHeadAttention(nn.Module):

    def __init__(self, embedding_size: int, head_size: int, num_heads: int, block_size: int, masked: bool = True) -> None:
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList([AttentionHead(embedding_size, head_size, block_size, masked) for _ in range(num_heads)])
        self.projection = nn.Linear(head_size * num_heads, embedding_size)
        self.dropout = nn.Dropout()

    def forward(self, X):
        # each head has an output of [B T T], stacking at [B T (T*N)]
        concat = torch.cat([head(X) for head in self.heads], dim=-1) # last dimension
        out = self.projection(concat)
        out = self.dropout(out)
        return out

class FeedFoward(nn.Module):

    def __init__(self, n_embd):
        super(FeedFoward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):

    def __init__(self, embedding_size, num_heads, block_size) -> None:
        super(TransformerBlock, self).__init__()
        head_size = embedding_size // num_heads
        self.mha = MultiHeadAttention(embedding_size, head_size, num_heads, block_size)
        self.feed_forward = FeedFoward(embedding_size)
        self.ln1 = nn.LayerNorm(embedding_size)
        self.ln2 = nn.LayerNorm(embedding_size)

    def forward(self, X):
        X = X + self.mha(self.ln1(X))
        X = X + self.feed_forward(self.ln2(X))
        return X
    
class BLM(nn.Module):
    def __init__(self, vocab_size: int, embedding_size: int, num_heads: int, block_size: int, num_layers: int) -> None:
        super(BLM, self).__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_size)
        self.position_embedding_table = nn.Embedding(block_size, embedding_size)
        self.blocks = nn.Sequential(*[TransformerBlock(embedding_size, num_heads, block_size) for _ in range(num_layers)]) 
        self.ln_f = nn.LayerNorm(embedding_size) # final layer norm
        self.lm_head = nn.Linear(embedding_size, vocab_size)

    def forward(self, X):
        T = X.shape[-1]
        token_embedding = self.token_embedding_table(X)
        position_embedding = self.position_embedding_table(torch.arange(T, device=device))
        X = token_embedding + position_embedding
        X = self.blocks(X)
        X = self.ln_f(X)
        X = self.lm_head(X)
        return X

    def generate(self, X, max_new_token: int):
        # NOTE: Assume the input to be a single element, not a batch!
        # X is [T]
        T = X.shape[-1]
        for _ in range(max_new_token):
            # use last T tokens
            logits = self(X[-T:]) # [T C]
            # get the last timestep
            logits = logits[-1, :]
            # get the distribution of the next element
            probs = F.softmax(logits, dim=-1) # softmax over the last dimensiont
            next_token = torch.multinomial(probs, num_samples=1)
            X = torch.cat((X, next_token), dim=0)
        return X

# Train & Validation

In [75]:
@torch.no_grad()
def validation(model, dataloader, loss_fn):
    model.eval()
    loss = 0
    acc = 0
    for x, y in tqdm(dataloader, "Validation: ", leave=False):
        x, y = x.to(device), y.to(device)
        prediction_logits = model(x)

        B, T, C = prediction_logits.shape
        prediction_logits = prediction_logits.view(B*T, C)
        y = y.view(B*T)

        loss += loss_fn(prediction_logits, y).item()
        acc += (prediction_logits.argmax(1) == y).float().sum().item()
    return loss / len(dataloader), acc / len(dataloader.dataset)

def training(model, train_dataloader, validation_dataloader, loss_fn, optimizer, epochs, validation_freq, log):
    losses, accs = [], []
    for t in range(1, epochs + 1):
        model.train()
        for x, y in tqdm(train_dataloader, f"Epoch #{t}: ", leave=False):
            x, y = x.to(device), y.to(device)
            prediction_logits = model(x)

            B, T, C = prediction_logits.shape
            prediction_logits = prediction_logits.view(B*T, C)
            y = y.view(B*T)
            
            loss = loss_fn(prediction_logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        if t % validation_freq == 0:
            lss, acc = validation(model, validation_dataloader, loss_fn)
            losses.append(lss)
            accs.append(acc)
            log_dict = {"loss": lss, "accuracy": acc}
            if log:
                wandb.log(log_dict)

    return losses, accs

In [80]:
epochs = 10
embedding_size = 512
num_head = 16
num_layers = 6
validation_freq = 5
blm = BLM(train_ds.vocab_size, embedding_size, num_head, block_size, num_layers).to(device)
optimizer = torch.optim.Adam(blm.parameters(), lr=0.001)
loss = F.cross_entropy

# wandb.init(
#     project="DLA Lab 2 1.0",
# )
# training(blm, train_dl, val_dl, loss, optimizer, epochs, validation_freq, False)
# wandb.finish()

In [81]:
blm.eval()
generated_tokens = blm.generate(train_ds[100][0].to(device),max_new_token=50)
generated_tokens = generated_tokens.tolist()
generated_text = train_ds.decode(generated_tokens)
print(generated_text)


  che' la diritta via era smarrNgtVCh` "VAZaCSQtr!Es;-"c.S:VB.vMgpEPhtB,>bRLDBeZZ


In [82]:
training(blm, train_dl, val_dl, loss, optimizer, epochs, validation_freq, False)

                                                             

KeyboardInterrupt: 

In [18]:
sd = blm.state_dict()
torch.save(sd, "weight.pth")

In [61]:
blm.load_state_dict(torch.load("weight.pth"))

RuntimeError: Error(s) in loading state_dict for BLM:
	Unexpected key(s) in state_dict: "blocks.1.mha.heads.0.Q.weight", "blocks.1.mha.heads.0.Q.bias", "blocks.1.mha.heads.0.K.weight", "blocks.1.mha.heads.0.K.bias", "blocks.1.mha.heads.0.V.weight", "blocks.1.mha.heads.0.V.bias", "blocks.1.mha.heads.1.Q.weight", "blocks.1.mha.heads.1.Q.bias", "blocks.1.mha.heads.1.K.weight", "blocks.1.mha.heads.1.K.bias", "blocks.1.mha.heads.1.V.weight", "blocks.1.mha.heads.1.V.bias", "blocks.1.mha.heads.2.Q.weight", "blocks.1.mha.heads.2.Q.bias", "blocks.1.mha.heads.2.K.weight", "blocks.1.mha.heads.2.K.bias", "blocks.1.mha.heads.2.V.weight", "blocks.1.mha.heads.2.V.bias", "blocks.1.mha.heads.3.Q.weight", "blocks.1.mha.heads.3.Q.bias", "blocks.1.mha.heads.3.K.weight", "blocks.1.mha.heads.3.K.bias", "blocks.1.mha.heads.3.V.weight", "blocks.1.mha.heads.3.V.bias", "blocks.1.mha.heads.4.Q.weight", "blocks.1.mha.heads.4.Q.bias", "blocks.1.mha.heads.4.K.weight", "blocks.1.mha.heads.4.K.bias", "blocks.1.mha.heads.4.V.weight", "blocks.1.mha.heads.4.V.bias", "blocks.1.mha.heads.5.Q.weight", "blocks.1.mha.heads.5.Q.bias", "blocks.1.mha.heads.5.K.weight", "blocks.1.mha.heads.5.K.bias", "blocks.1.mha.heads.5.V.weight", "blocks.1.mha.heads.5.V.bias", "blocks.1.mha.heads.6.Q.weight", "blocks.1.mha.heads.6.Q.bias", "blocks.1.mha.heads.6.K.weight", "blocks.1.mha.heads.6.K.bias", "blocks.1.mha.heads.6.V.weight", "blocks.1.mha.heads.6.V.bias", "blocks.1.mha.heads.7.Q.weight", "blocks.1.mha.heads.7.Q.bias", "blocks.1.mha.heads.7.K.weight", "blocks.1.mha.heads.7.K.bias", "blocks.1.mha.heads.7.V.weight", "blocks.1.mha.heads.7.V.bias", "blocks.1.mha.heads.8.Q.weight", "blocks.1.mha.heads.8.Q.bias", "blocks.1.mha.heads.8.K.weight", "blocks.1.mha.heads.8.K.bias", "blocks.1.mha.heads.8.V.weight", "blocks.1.mha.heads.8.V.bias", "blocks.1.mha.heads.9.Q.weight", "blocks.1.mha.heads.9.Q.bias", "blocks.1.mha.heads.9.K.weight", "blocks.1.mha.heads.9.K.bias", "blocks.1.mha.heads.9.V.weight", "blocks.1.mha.heads.9.V.bias", "blocks.1.mha.heads.10.Q.weight", "blocks.1.mha.heads.10.Q.bias", "blocks.1.mha.heads.10.K.weight", "blocks.1.mha.heads.10.K.bias", "blocks.1.mha.heads.10.V.weight", "blocks.1.mha.heads.10.V.bias", "blocks.1.mha.heads.11.Q.weight", "blocks.1.mha.heads.11.Q.bias", "blocks.1.mha.heads.11.K.weight", "blocks.1.mha.heads.11.K.bias", "blocks.1.mha.heads.11.V.weight", "blocks.1.mha.heads.11.V.bias", "blocks.1.mha.heads.12.Q.weight", "blocks.1.mha.heads.12.Q.bias", "blocks.1.mha.heads.12.K.weight", "blocks.1.mha.heads.12.K.bias", "blocks.1.mha.heads.12.V.weight", "blocks.1.mha.heads.12.V.bias", "blocks.1.mha.heads.13.Q.weight", "blocks.1.mha.heads.13.Q.bias", "blocks.1.mha.heads.13.K.weight", "blocks.1.mha.heads.13.K.bias", "blocks.1.mha.heads.13.V.weight", "blocks.1.mha.heads.13.V.bias", "blocks.1.mha.heads.14.Q.weight", "blocks.1.mha.heads.14.Q.bias", "blocks.1.mha.heads.14.K.weight", "blocks.1.mha.heads.14.K.bias", "blocks.1.mha.heads.14.V.weight", "blocks.1.mha.heads.14.V.bias", "blocks.1.mha.heads.15.Q.weight", "blocks.1.mha.heads.15.Q.bias", "blocks.1.mha.heads.15.K.weight", "blocks.1.mha.heads.15.K.bias", "blocks.1.mha.heads.15.V.weight", "blocks.1.mha.heads.15.V.bias", "blocks.1.mha.projection.weight", "blocks.1.mha.projection.bias", "blocks.1.feed_forward.net.0.weight", "blocks.1.feed_forward.net.0.bias", "blocks.1.feed_forward.net.2.weight", "blocks.1.feed_forward.net.2.bias", "blocks.1.ln1.weight", "blocks.1.ln1.bias", "blocks.1.ln2.weight", "blocks.1.ln2.bias", "blocks.2.mha.heads.0.Q.weight", "blocks.2.mha.heads.0.Q.bias", "blocks.2.mha.heads.0.K.weight", "blocks.2.mha.heads.0.K.bias", "blocks.2.mha.heads.0.V.weight", "blocks.2.mha.heads.0.V.bias", "blocks.2.mha.heads.1.Q.weight", "blocks.2.mha.heads.1.Q.bias", "blocks.2.mha.heads.1.K.weight", "blocks.2.mha.heads.1.K.bias", "blocks.2.mha.heads.1.V.weight", "blocks.2.mha.heads.1.V.bias", "blocks.2.mha.heads.2.Q.weight", "blocks.2.mha.heads.2.Q.bias", "blocks.2.mha.heads.2.K.weight", "blocks.2.mha.heads.2.K.bias", "blocks.2.mha.heads.2.V.weight", "blocks.2.mha.heads.2.V.bias", "blocks.2.mha.heads.3.Q.weight", "blocks.2.mha.heads.3.Q.bias", "blocks.2.mha.heads.3.K.weight", "blocks.2.mha.heads.3.K.bias", "blocks.2.mha.heads.3.V.weight", "blocks.2.mha.heads.3.V.bias", "blocks.2.mha.heads.4.Q.weight", "blocks.2.mha.heads.4.Q.bias", "blocks.2.mha.heads.4.K.weight", "blocks.2.mha.heads.4.K.bias", "blocks.2.mha.heads.4.V.weight", "blocks.2.mha.heads.4.V.bias", "blocks.2.mha.heads.5.Q.weight", "blocks.2.mha.heads.5.Q.bias", "blocks.2.mha.heads.5.K.weight", "blocks.2.mha.heads.5.K.bias", "blocks.2.mha.heads.5.V.weight", "blocks.2.mha.heads.5.V.bias", "blocks.2.mha.heads.6.Q.weight", "blocks.2.mha.heads.6.Q.bias", "blocks.2.mha.heads.6.K.weight", "blocks.2.mha.heads.6.K.bias", "blocks.2.mha.heads.6.V.weight", "blocks.2.mha.heads.6.V.bias", "blocks.2.mha.heads.7.Q.weight", "blocks.2.mha.heads.7.Q.bias", "blocks.2.mha.heads.7.K.weight", "blocks.2.mha.heads.7.K.bias", "blocks.2.mha.heads.7.V.weight", "blocks.2.mha.heads.7.V.bias", "blocks.2.mha.heads.8.Q.weight", "blocks.2.mha.heads.8.Q.bias", "blocks.2.mha.heads.8.K.weight", "blocks.2.mha.heads.8.K.bias", "blocks.2.mha.heads.8.V.weight", "blocks.2.mha.heads.8.V.bias", "blocks.2.mha.heads.9.Q.weight", "blocks.2.mha.heads.9.Q.bias", "blocks.2.mha.heads.9.K.weight", "blocks.2.mha.heads.9.K.bias", "blocks.2.mha.heads.9.V.weight", "blocks.2.mha.heads.9.V.bias", "blocks.2.mha.heads.10.Q.weight", "blocks.2.mha.heads.10.Q.bias", "blocks.2.mha.heads.10.K.weight", "blocks.2.mha.heads.10.K.bias", "blocks.2.mha.heads.10.V.weight", "blocks.2.mha.heads.10.V.bias", "blocks.2.mha.heads.11.Q.weight", "blocks.2.mha.heads.11.Q.bias", "blocks.2.mha.heads.11.K.weight", "blocks.2.mha.heads.11.K.bias", "blocks.2.mha.heads.11.V.weight", "blocks.2.mha.heads.11.V.bias", "blocks.2.mha.heads.12.Q.weight", "blocks.2.mha.heads.12.Q.bias", "blocks.2.mha.heads.12.K.weight", "blocks.2.mha.heads.12.K.bias", "blocks.2.mha.heads.12.V.weight", "blocks.2.mha.heads.12.V.bias", "blocks.2.mha.heads.13.Q.weight", "blocks.2.mha.heads.13.Q.bias", "blocks.2.mha.heads.13.K.weight", "blocks.2.mha.heads.13.K.bias", "blocks.2.mha.heads.13.V.weight", "blocks.2.mha.heads.13.V.bias", "blocks.2.mha.heads.14.Q.weight", "blocks.2.mha.heads.14.Q.bias", "blocks.2.mha.heads.14.K.weight", "blocks.2.mha.heads.14.K.bias", "blocks.2.mha.heads.14.V.weight", "blocks.2.mha.heads.14.V.bias", "blocks.2.mha.heads.15.Q.weight", "blocks.2.mha.heads.15.Q.bias", "blocks.2.mha.heads.15.K.weight", "blocks.2.mha.heads.15.K.bias", "blocks.2.mha.heads.15.V.weight", "blocks.2.mha.heads.15.V.bias", "blocks.2.mha.projection.weight", "blocks.2.mha.projection.bias", "blocks.2.feed_forward.net.0.weight", "blocks.2.feed_forward.net.0.bias", "blocks.2.feed_forward.net.2.weight", "blocks.2.feed_forward.net.2.bias", "blocks.2.ln1.weight", "blocks.2.ln1.bias", "blocks.2.ln2.weight", "blocks.2.ln2.bias", "blocks.3.mha.heads.0.Q.weight", "blocks.3.mha.heads.0.Q.bias", "blocks.3.mha.heads.0.K.weight", "blocks.3.mha.heads.0.K.bias", "blocks.3.mha.heads.0.V.weight", "blocks.3.mha.heads.0.V.bias", "blocks.3.mha.heads.1.Q.weight", "blocks.3.mha.heads.1.Q.bias", "blocks.3.mha.heads.1.K.weight", "blocks.3.mha.heads.1.K.bias", "blocks.3.mha.heads.1.V.weight", "blocks.3.mha.heads.1.V.bias", "blocks.3.mha.heads.2.Q.weight", "blocks.3.mha.heads.2.Q.bias", "blocks.3.mha.heads.2.K.weight", "blocks.3.mha.heads.2.K.bias", "blocks.3.mha.heads.2.V.weight", "blocks.3.mha.heads.2.V.bias", "blocks.3.mha.heads.3.Q.weight", "blocks.3.mha.heads.3.Q.bias", "blocks.3.mha.heads.3.K.weight", "blocks.3.mha.heads.3.K.bias", "blocks.3.mha.heads.3.V.weight", "blocks.3.mha.heads.3.V.bias", "blocks.3.mha.heads.4.Q.weight", "blocks.3.mha.heads.4.Q.bias", "blocks.3.mha.heads.4.K.weight", "blocks.3.mha.heads.4.K.bias", "blocks.3.mha.heads.4.V.weight", "blocks.3.mha.heads.4.V.bias", "blocks.3.mha.heads.5.Q.weight", "blocks.3.mha.heads.5.Q.bias", "blocks.3.mha.heads.5.K.weight", "blocks.3.mha.heads.5.K.bias", "blocks.3.mha.heads.5.V.weight", "blocks.3.mha.heads.5.V.bias", "blocks.3.mha.heads.6.Q.weight", "blocks.3.mha.heads.6.Q.bias", "blocks.3.mha.heads.6.K.weight", "blocks.3.mha.heads.6.K.bias", "blocks.3.mha.heads.6.V.weight", "blocks.3.mha.heads.6.V.bias", "blocks.3.mha.heads.7.Q.weight", "blocks.3.mha.heads.7.Q.bias", "blocks.3.mha.heads.7.K.weight", "blocks.3.mha.heads.7.K.bias", "blocks.3.mha.heads.7.V.weight", "blocks.3.mha.heads.7.V.bias", "blocks.3.mha.heads.8.Q.weight", "blocks.3.mha.heads.8.Q.bias", "blocks.3.mha.heads.8.K.weight", "blocks.3.mha.heads.8.K.bias", "blocks.3.mha.heads.8.V.weight", "blocks.3.mha.heads.8.V.bias", "blocks.3.mha.heads.9.Q.weight", "blocks.3.mha.heads.9.Q.bias", "blocks.3.mha.heads.9.K.weight", "blocks.3.mha.heads.9.K.bias", "blocks.3.mha.heads.9.V.weight", "blocks.3.mha.heads.9.V.bias", "blocks.3.mha.heads.10.Q.weight", "blocks.3.mha.heads.10.Q.bias", "blocks.3.mha.heads.10.K.weight", "blocks.3.mha.heads.10.K.bias", "blocks.3.mha.heads.10.V.weight", "blocks.3.mha.heads.10.V.bias", "blocks.3.mha.heads.11.Q.weight", "blocks.3.mha.heads.11.Q.bias", "blocks.3.mha.heads.11.K.weight", "blocks.3.mha.heads.11.K.bias", "blocks.3.mha.heads.11.V.weight", "blocks.3.mha.heads.11.V.bias", "blocks.3.mha.heads.12.Q.weight", "blocks.3.mha.heads.12.Q.bias", "blocks.3.mha.heads.12.K.weight", "blocks.3.mha.heads.12.K.bias", "blocks.3.mha.heads.12.V.weight", "blocks.3.mha.heads.12.V.bias", "blocks.3.mha.heads.13.Q.weight", "blocks.3.mha.heads.13.Q.bias", "blocks.3.mha.heads.13.K.weight", "blocks.3.mha.heads.13.K.bias", "blocks.3.mha.heads.13.V.weight", "blocks.3.mha.heads.13.V.bias", "blocks.3.mha.heads.14.Q.weight", "blocks.3.mha.heads.14.Q.bias", "blocks.3.mha.heads.14.K.weight", "blocks.3.mha.heads.14.K.bias", "blocks.3.mha.heads.14.V.weight", "blocks.3.mha.heads.14.V.bias", "blocks.3.mha.heads.15.Q.weight", "blocks.3.mha.heads.15.Q.bias", "blocks.3.mha.heads.15.K.weight", "blocks.3.mha.heads.15.K.bias", "blocks.3.mha.heads.15.V.weight", "blocks.3.mha.heads.15.V.bias", "blocks.3.mha.projection.weight", "blocks.3.mha.projection.bias", "blocks.3.feed_forward.net.0.weight", "blocks.3.feed_forward.net.0.bias", "blocks.3.feed_forward.net.2.weight", "blocks.3.feed_forward.net.2.bias", "blocks.3.ln1.weight", "blocks.3.ln1.bias", "blocks.3.ln2.weight", "blocks.3.ln2.bias", "blocks.4.mha.heads.0.Q.weight", "blocks.4.mha.heads.0.Q.bias", "blocks.4.mha.heads.0.K.weight", "blocks.4.mha.heads.0.K.bias", "blocks.4.mha.heads.0.V.weight", "blocks.4.mha.heads.0.V.bias", "blocks.4.mha.heads.1.Q.weight", "blocks.4.mha.heads.1.Q.bias", "blocks.4.mha.heads.1.K.weight", "blocks.4.mha.heads.1.K.bias", "blocks.4.mha.heads.1.V.weight", "blocks.4.mha.heads.1.V.bias", "blocks.4.mha.heads.2.Q.weight", "blocks.4.mha.heads.2.Q.bias", "blocks.4.mha.heads.2.K.weight", "blocks.4.mha.heads.2.K.bias", "blocks.4.mha.heads.2.V.weight", "blocks.4.mha.heads.2.V.bias", "blocks.4.mha.heads.3.Q.weight", "blocks.4.mha.heads.3.Q.bias", "blocks.4.mha.heads.3.K.weight", "blocks.4.mha.heads.3.K.bias", "blocks.4.mha.heads.3.V.weight", "blocks.4.mha.heads.3.V.bias", "blocks.4.mha.heads.4.Q.weight", "blocks.4.mha.heads.4.Q.bias", "blocks.4.mha.heads.4.K.weight", "blocks.4.mha.heads.4.K.bias", "blocks.4.mha.heads.4.V.weight", "blocks.4.mha.heads.4.V.bias", "blocks.4.mha.heads.5.Q.weight", "blocks.4.mha.heads.5.Q.bias", "blocks.4.mha.heads.5.K.weight", "blocks.4.mha.heads.5.K.bias", "blocks.4.mha.heads.5.V.weight", "blocks.4.mha.heads.5.V.bias", "blocks.4.mha.heads.6.Q.weight", "blocks.4.mha.heads.6.Q.bias", "blocks.4.mha.heads.6.K.weight", "blocks.4.mha.heads.6.K.bias", "blocks.4.mha.heads.6.V.weight", "blocks.4.mha.heads.6.V.bias", "blocks.4.mha.heads.7.Q.weight", "blocks.4.mha.heads.7.Q.bias", "blocks.4.mha.heads.7.K.weight", "blocks.4.mha.heads.7.K.bias", "blocks.4.mha.heads.7.V.weight", "blocks.4.mha.heads.7.V.bias", "blocks.4.mha.heads.8.Q.weight", "blocks.4.mha.heads.8.Q.bias", "blocks.4.mha.heads.8.K.weight", "blocks.4.mha.heads.8.K.bias", "blocks.4.mha.heads.8.V.weight", "blocks.4.mha.heads.8.V.bias", "blocks.4.mha.heads.9.Q.weight", "blocks.4.mha.heads.9.Q.bias", "blocks.4.mha.heads.9.K.weight", "blocks.4.mha.heads.9.K.bias", "blocks.4.mha.heads.9.V.weight", "blocks.4.mha.heads.9.V.bias", "blocks.4.mha.heads.10.Q.weight", "blocks.4.mha.heads.10.Q.bias", "blocks.4.mha.heads.10.K.weight", "blocks.4.mha.heads.10.K.bias", "blocks.4.mha.heads.10.V.weight", "blocks.4.mha.heads.10.V.bias", "blocks.4.mha.heads.11.Q.weight", "blocks.4.mha.heads.11.Q.bias", "blocks.4.mha.heads.11.K.weight", "blocks.4.mha.heads.11.K.bias", "blocks.4.mha.heads.11.V.weight", "blocks.4.mha.heads.11.V.bias", "blocks.4.mha.heads.12.Q.weight", "blocks.4.mha.heads.12.Q.bias", "blocks.4.mha.heads.12.K.weight", "blocks.4.mha.heads.12.K.bias", "blocks.4.mha.heads.12.V.weight", "blocks.4.mha.heads.12.V.bias", "blocks.4.mha.heads.13.Q.weight", "blocks.4.mha.heads.13.Q.bias", "blocks.4.mha.heads.13.K.weight", "blocks.4.mha.heads.13.K.bias", "blocks.4.mha.heads.13.V.weight", "blocks.4.mha.heads.13.V.bias", "blocks.4.mha.heads.14.Q.weight", "blocks.4.mha.heads.14.Q.bias", "blocks.4.mha.heads.14.K.weight", "blocks.4.mha.heads.14.K.bias", "blocks.4.mha.heads.14.V.weight", "blocks.4.mha.heads.14.V.bias", "blocks.4.mha.heads.15.Q.weight", "blocks.4.mha.heads.15.Q.bias", "blocks.4.mha.heads.15.K.weight", "blocks.4.mha.heads.15.K.bias", "blocks.4.mha.heads.15.V.weight", "blocks.4.mha.heads.15.V.bias", "blocks.4.mha.projection.weight", "blocks.4.mha.projection.bias", "blocks.4.feed_forward.net.0.weight", "blocks.4.feed_forward.net.0.bias", "blocks.4.feed_forward.net.2.weight", "blocks.4.feed_forward.net.2.bias", "blocks.4.ln1.weight", "blocks.4.ln1.bias", "blocks.4.ln2.weight", "blocks.4.ln2.bias", "blocks.5.mha.heads.0.Q.weight", "blocks.5.mha.heads.0.Q.bias", "blocks.5.mha.heads.0.K.weight", "blocks.5.mha.heads.0.K.bias", "blocks.5.mha.heads.0.V.weight", "blocks.5.mha.heads.0.V.bias", "blocks.5.mha.heads.1.Q.weight", "blocks.5.mha.heads.1.Q.bias", "blocks.5.mha.heads.1.K.weight", "blocks.5.mha.heads.1.K.bias", "blocks.5.mha.heads.1.V.weight", "blocks.5.mha.heads.1.V.bias", "blocks.5.mha.heads.2.Q.weight", "blocks.5.mha.heads.2.Q.bias", "blocks.5.mha.heads.2.K.weight", "blocks.5.mha.heads.2.K.bias", "blocks.5.mha.heads.2.V.weight", "blocks.5.mha.heads.2.V.bias", "blocks.5.mha.heads.3.Q.weight", "blocks.5.mha.heads.3.Q.bias", "blocks.5.mha.heads.3.K.weight", "blocks.5.mha.heads.3.K.bias", "blocks.5.mha.heads.3.V.weight", "blocks.5.mha.heads.3.V.bias", "blocks.5.mha.heads.4.Q.weight", "blocks.5.mha.heads.4.Q.bias", "blocks.5.mha.heads.4.K.weight", "blocks.5.mha.heads.4.K.bias", "blocks.5.mha.heads.4.V.weight", "blocks.5.mha.heads.4.V.bias", "blocks.5.mha.heads.5.Q.weight", "blocks.5.mha.heads.5.Q.bias", "blocks.5.mha.heads.5.K.weight", "blocks.5.mha.heads.5.K.bias", "blocks.5.mha.heads.5.V.weight", "blocks.5.mha.heads.5.V.bias", "blocks.5.mha.heads.6.Q.weight", "blocks.5.mha.heads.6.Q.bias", "blocks.5.mha.heads.6.K.weight", "blocks.5.mha.heads.6.K.bias", "blocks.5.mha.heads.6.V.weight", "blocks.5.mha.heads.6.V.bias", "blocks.5.mha.heads.7.Q.weight", "blocks.5.mha.heads.7.Q.bias", "blocks.5.mha.heads.7.K.weight", "blocks.5.mha.heads.7.K.bias", "blocks.5.mha.heads.7.V.weight", "blocks.5.mha.heads.7.V.bias", "blocks.5.mha.heads.8.Q.weight", "blocks.5.mha.heads.8.Q.bias", "blocks.5.mha.heads.8.K.weight", "blocks.5.mha.heads.8.K.bias", "blocks.5.mha.heads.8.V.weight", "blocks.5.mha.heads.8.V.bias", "blocks.5.mha.heads.9.Q.weight", "blocks.5.mha.heads.9.Q.bias", "blocks.5.mha.heads.9.K.weight", "blocks.5.mha.heads.9.K.bias", "blocks.5.mha.heads.9.V.weight", "blocks.5.mha.heads.9.V.bias", "blocks.5.mha.heads.10.Q.weight", "blocks.5.mha.heads.10.Q.bias", "blocks.5.mha.heads.10.K.weight", "blocks.5.mha.heads.10.K.bias", "blocks.5.mha.heads.10.V.weight", "blocks.5.mha.heads.10.V.bias", "blocks.5.mha.heads.11.Q.weight", "blocks.5.mha.heads.11.Q.bias", "blocks.5.mha.heads.11.K.weight", "blocks.5.mha.heads.11.K.bias", "blocks.5.mha.heads.11.V.weight", "blocks.5.mha.heads.11.V.bias", "blocks.5.mha.heads.12.Q.weight", "blocks.5.mha.heads.12.Q.bias", "blocks.5.mha.heads.12.K.weight", "blocks.5.mha.heads.12.K.bias", "blocks.5.mha.heads.12.V.weight", "blocks.5.mha.heads.12.V.bias", "blocks.5.mha.heads.13.Q.weight", "blocks.5.mha.heads.13.Q.bias", "blocks.5.mha.heads.13.K.weight", "blocks.5.mha.heads.13.K.bias", "blocks.5.mha.heads.13.V.weight", "blocks.5.mha.heads.13.V.bias", "blocks.5.mha.heads.14.Q.weight", "blocks.5.mha.heads.14.Q.bias", "blocks.5.mha.heads.14.K.weight", "blocks.5.mha.heads.14.K.bias", "blocks.5.mha.heads.14.V.weight", "blocks.5.mha.heads.14.V.bias", "blocks.5.mha.heads.15.Q.weight", "blocks.5.mha.heads.15.Q.bias", "blocks.5.mha.heads.15.K.weight", "blocks.5.mha.heads.15.K.bias", "blocks.5.mha.heads.15.V.weight", "blocks.5.mha.heads.15.V.bias", "blocks.5.mha.projection.weight", "blocks.5.mha.projection.bias", "blocks.5.feed_forward.net.0.weight", "blocks.5.feed_forward.net.0.bias", "blocks.5.feed_forward.net.2.weight", "blocks.5.feed_forward.net.2.bias", "blocks.5.ln1.weight", "blocks.5.ln1.bias", "blocks.5.ln2.weight", "blocks.5.ln2.bias", "blocks.0.mha.heads.1.Q.weight", "blocks.0.mha.heads.1.Q.bias", "blocks.0.mha.heads.1.K.weight", "blocks.0.mha.heads.1.K.bias", "blocks.0.mha.heads.1.V.weight", "blocks.0.mha.heads.1.V.bias", "blocks.0.mha.heads.2.Q.weight", "blocks.0.mha.heads.2.Q.bias", "blocks.0.mha.heads.2.K.weight", "blocks.0.mha.heads.2.K.bias", "blocks.0.mha.heads.2.V.weight", "blocks.0.mha.heads.2.V.bias", "blocks.0.mha.heads.3.Q.weight", "blocks.0.mha.heads.3.Q.bias", "blocks.0.mha.heads.3.K.weight", "blocks.0.mha.heads.3.K.bias", "blocks.0.mha.heads.3.V.weight", "blocks.0.mha.heads.3.V.bias", "blocks.0.mha.heads.4.Q.weight", "blocks.0.mha.heads.4.Q.bias", "blocks.0.mha.heads.4.K.weight", "blocks.0.mha.heads.4.K.bias", "blocks.0.mha.heads.4.V.weight", "blocks.0.mha.heads.4.V.bias", "blocks.0.mha.heads.5.Q.weight", "blocks.0.mha.heads.5.Q.bias", "blocks.0.mha.heads.5.K.weight", "blocks.0.mha.heads.5.K.bias", "blocks.0.mha.heads.5.V.weight", "blocks.0.mha.heads.5.V.bias", "blocks.0.mha.heads.6.Q.weight", "blocks.0.mha.heads.6.Q.bias", "blocks.0.mha.heads.6.K.weight", "blocks.0.mha.heads.6.K.bias", "blocks.0.mha.heads.6.V.weight", "blocks.0.mha.heads.6.V.bias", "blocks.0.mha.heads.7.Q.weight", "blocks.0.mha.heads.7.Q.bias", "blocks.0.mha.heads.7.K.weight", "blocks.0.mha.heads.7.K.bias", "blocks.0.mha.heads.7.V.weight", "blocks.0.mha.heads.7.V.bias", "blocks.0.mha.heads.8.Q.weight", "blocks.0.mha.heads.8.Q.bias", "blocks.0.mha.heads.8.K.weight", "blocks.0.mha.heads.8.K.bias", "blocks.0.mha.heads.8.V.weight", "blocks.0.mha.heads.8.V.bias", "blocks.0.mha.heads.9.Q.weight", "blocks.0.mha.heads.9.Q.bias", "blocks.0.mha.heads.9.K.weight", "blocks.0.mha.heads.9.K.bias", "blocks.0.mha.heads.9.V.weight", "blocks.0.mha.heads.9.V.bias", "blocks.0.mha.heads.10.Q.weight", "blocks.0.mha.heads.10.Q.bias", "blocks.0.mha.heads.10.K.weight", "blocks.0.mha.heads.10.K.bias", "blocks.0.mha.heads.10.V.weight", "blocks.0.mha.heads.10.V.bias", "blocks.0.mha.heads.11.Q.weight", "blocks.0.mha.heads.11.Q.bias", "blocks.0.mha.heads.11.K.weight", "blocks.0.mha.heads.11.K.bias", "blocks.0.mha.heads.11.V.weight", "blocks.0.mha.heads.11.V.bias", "blocks.0.mha.heads.12.Q.weight", "blocks.0.mha.heads.12.Q.bias", "blocks.0.mha.heads.12.K.weight", "blocks.0.mha.heads.12.K.bias", "blocks.0.mha.heads.12.V.weight", "blocks.0.mha.heads.12.V.bias", "blocks.0.mha.heads.13.Q.weight", "blocks.0.mha.heads.13.Q.bias", "blocks.0.mha.heads.13.K.weight", "blocks.0.mha.heads.13.K.bias", "blocks.0.mha.heads.13.V.weight", "blocks.0.mha.heads.13.V.bias", "blocks.0.mha.heads.14.Q.weight", "blocks.0.mha.heads.14.Q.bias", "blocks.0.mha.heads.14.K.weight", "blocks.0.mha.heads.14.K.bias", "blocks.0.mha.heads.14.V.weight", "blocks.0.mha.heads.14.V.bias", "blocks.0.mha.heads.15.Q.weight", "blocks.0.mha.heads.15.Q.bias", "blocks.0.mha.heads.15.K.weight", "blocks.0.mha.heads.15.K.bias", "blocks.0.mha.heads.15.V.weight", "blocks.0.mha.heads.15.V.bias". 
	size mismatch for blocks.0.mha.heads.0.Q.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for blocks.0.mha.heads.0.Q.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.0.mha.heads.0.K.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for blocks.0.mha.heads.0.K.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for blocks.0.mha.heads.0.V.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for blocks.0.mha.heads.0.V.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([512]).

In [79]:
blm.eval()
print(blm(train_ds[100][0].to(device)))
generated_tokens = blm.generate(train_ds[100][0].to(device),max_new_token=50)
generated_tokens = generated_tokens.tolist()
generated_text = train_ds.decode(generated_tokens)
print(generated_text)

tensor([[ 4.9235e+00,  5.2436e+00, -5.5760e+00,  ..., -1.9069e+00,
         -6.8754e-01, -4.3279e+00],
        [ 2.6395e-01,  1.0988e+01, -9.1082e+00,  ..., -8.1945e-01,
          1.8988e-02, -6.0349e+00],
        [-5.1276e+00, -2.7756e-01, -7.7493e+00,  ..., -7.8942e-01,
          1.8550e+00, -5.7235e+00],
        ...,
        [ 1.1628e+00,  4.2442e+00, -3.4044e+00,  ..., -1.9836e+00,
          6.4581e-01, -2.5607e+00],
        [-3.5277e+00,  3.6280e+00, -7.2505e+00,  ..., -8.2779e-03,
         -1.4042e-01, -1.7086e+00],
        [-5.1594e+00,  1.3357e+00, -7.0264e+00,  ...,  2.9148e+00,
         -2.9007e+00, -2.8000e+00]], device='cuda:0', grad_fn=<AddmmBackward0>)

  che' la diritta via era smarragne non denti,
  questannonto,
  no fura no, elle


# Exercise 2: Working with Real LLMs

Our toy GPT can only take us so far. In this exercise we will see how to use the [Hugging Face](https://huggingface.co/) model and dataset ecosystem to access a *huge* variety of pre-trained transformer models.

## Exercise 2.1: Installation and text tokenization

First things first, we need to install the [Hugging Face transformer library](https://huggingface.co/docs/transformers/index):

    conda install -c huggingface -c conda-forge transformers
    
The key classes that you will work with are `GPT2Tokenizer` to encode text into sub-word tokens, and the `GPT2LMHeadModel`. **Note** the `LMHead` part of the class name -- this is the version of the GPT2 architecture that has the text prediction heads attached to the final hidden layer representations (i.e. what we need to **generate** text). 

Instantiate the `GPT2Tokenizer` and experiment with encoding text into integer tokens. Compare the length of input with the encoded sequence length.

**Tip**: Pass the `return_tensors='pt'` argument to the togenizer to get Pytorch tensors as output (instead of lists).

In [4]:
# Your code here.
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [5]:
text = "My first GPT2 encoded text"
encoded_text = tokenizer.encode(text, return_tensors='pt')
print(encoded_text.shape) # expected shape [1 X]

torch.Size([1, 7])


In [6]:
print(tokenizer.decode(encoded_text[0]))
for i in range(encoded_text.shape[1]):
    print(tokenizer.decode(encoded_text[0, i]))

2023-07-20 18:45:44.413075: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


My first GPT2 encoded text
My
 first
 G
PT
2
 encoded
 text


In [7]:
coded = tokenizer.encode(["<SOS>", "<EOS>", "<PAD>", "<UNK>", "<RANDOM>"])
print(coded)
print(tokenizer.decode(coded))

[50256, 50256, 50256, 50256, 50256]
<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>


## Exercise 2.2: Generating Text

There are a lot of ways we can, given a *prompt* in input, sample text from a GPT2 model. Instantiate a pre-trained `GPT2LMHeadModel` and use the [`generate()`](https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/text_generation#transformers.GenerationMixin.generate) method to generate text from a prompt.

**Note**: The default inference mode for GPT2 is *greedy* which might not results in satisfying generated text. Look at the `do_sample` and `temperature` parameters.

In [8]:
# Your code here.
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained("gpt2")

In [9]:
generated = model.generate()
print(generated)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


tensor([[50256,   198,   464,   717,   640,   314,  2497,   262,   649,  2196,
           286,   262,   983,    11,   314,   373,   523,  6568,    13,   314]])


In [10]:
print(tokenizer.decode(generated[0]))

<|endoftext|>
The first time I saw the new version of the game, I was so excited. I


# Exercise 3: Reusing Pre-trained LLMs (choose one)

Choose **one** of the following exercises (well, *at least* one). In each of these you are asked to adapt a pre-trained LLM (`GPT2Model` or `DistillBERT` are two good choices) to a new Natural Language Understanding task. A few comments:

+ Since GPT2 is a *autoregressive* model, there is no latent space aggregation at the last transformer layer (you get the same number of tokens out that you give in input). To use a pre-trained model for a classification or retrieval task, you should aggregate these tokens somehow (or opportunistically select *one* to use).

+ BERT models (including DistillBERT) have a special [CLS] token prepended to each latent representation in output from a self-attention block. You can directly use this as a representation for classification (or retrieval).

+ The first *two* exercises below can probably be done *without* any fine-tuning -- that is, just training a shallow MLP to classify or represent with the appropriate loss function.

# Exercise 3.1: Training a Text Classifier (easy)

Peruse the [text classification datasets on Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:text-classification&sort=downloads). Choose a *moderately* sized dataset and use a LLM to train a classifier to solve the problem.

**Note**: A good first baseline for this problem is certainly to use an LLM *exclusively* as a feature extractor and then train a shallow model.

# Exercise 3.2: Training a Question Answering Model (harder)

Peruse the [multiple choice question answering datasets on Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:multiple-choice&sort=downloads). Chose a *moderately* sized one and train a model to answer contextualized multiple-choice questions. You *might* be able to avoid fine-tuning by training a simple model to *rank* the multiple choices (see margin ranking loss in Pytorch).

# Exercise 3.3: Training a Retrieval Model (hardest)

The Hugging Face dataset repository contains a large number of ["text retrieval" problems](https://huggingface.co/datasets?task_categories=task_categories:text-retrieval&p=1&sort=downloads). These tasks generally require that the model measure *similarity* between text in some metric space -- naively, just a cosine similarity between [CLS] tokens can get you pretty far. Find an interesting retrieval problem and train a model (starting from a pre-trained LLM of course) to solve it.

**Tip**: Sometimes identifying the *retrieval* problems in these datasets can be half the challenge. [This dataset](https://huggingface.co/datasets/BeIR/scifact) might be a good starting point.

# Exercise 1:
## Dataset exploration

In [11]:
from datasets import load_dataset
dataset = load_dataset("rotten_tomatoes", data_dir="~/datasets")

Found cached dataset rotten_tomatoes (/home/manu/.cache/huggingface/datasets/rotten_tomatoes/default-data_dir=~%2Fdatasets/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)


  0%|          | 0/3 [00:00<?, ?it/s]

In [16]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 8530
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
})


In [17]:
dataset['train'][0]


{'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .',
 'label': 1}

In [12]:
from torch.utils.data import Dataset, DataLoader

class RottenTomatoes(Dataset):

    def __init__(self, tokenizer, split: str = "train" or "validation" or "test") -> None:
        super().__init__()
        rt = load_dataset("rotten_tomatoes", data_dir="~/datasets")
        self.data = rt[split]

    def __len__(self) -> int: return len(self.data)

    def __getitem__(self, index) -> tuple[str, int]: return self.data[index]['text'], torch.Tensor(self.data[index]['label']).int()

In [13]:
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
train_ds = RottenTomatoes(tokenizer)
val_ds = RottenTomatoes(tokenizer, split="validation")
test_ds = RottenTomatoes(tokenizer, split="test")

bsz = 64

train_dl = DataLoader(train_ds, batch_size=bsz, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=bsz, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=bsz, shuffle=True)

Found cached dataset rotten_tomatoes (/home/manu/.cache/huggingface/datasets/rotten_tomatoes/default-data_dir=~%2Fdatasets/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset rotten_tomatoes (/home/manu/.cache/huggingface/datasets/rotten_tomatoes/default-data_dir=~%2Fdatasets/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset rotten_tomatoes (/home/manu/.cache/huggingface/datasets/rotten_tomatoes/default-data_dir=~%2Fdatasets/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)


  0%|          | 0/3 [00:00<?, ?it/s]

In [20]:
len(train_ds)

8530

In [21]:
train_ds[1][0].shape

NameError: name 'torch' is not defined

In [15]:
from transformers import GPT2Model

gpt2 = GPT2Model.from_pretrained("gpt2")

In [16]:
import torch
e = gpt2(torch.ones((1, 10)).int())
e.last_hidden_state

tensor([[[-0.0518,  0.1207, -0.4284,  ..., -0.0340, -0.0397, -0.1079],
         [-0.4345,  0.2382, -0.6182,  ...,  0.1135,  0.3325,  0.0097],
         [-0.5839,  0.0889, -0.8555,  ..., -0.1169,  0.4172,  0.5132],
         ...,
         [-0.1124, -0.2327, -0.8494,  ..., -0.2287,  0.4269,  0.3823],
         [-0.0653, -0.2508, -0.8165,  ..., -0.2782,  0.4119,  0.3119],
         [-0.0328, -0.2671, -0.7836,  ..., -0.3210,  0.4001,  0.2521]]],
       grad_fn=<ViewBackward0>)

# Compound model

In [17]:
import torch.nn as nn
class TextClassifier(nn.Module):

    def __init__(self, n_classes: int = 2) -> None:
        super(TextClassifier, self).__init__()
        self.tokenizer = tokenizer
        self.encoder = GPT2Model.from_pretrained('gpt2')
        self.head = nn.Sequential(
            nn.Linear(self.encoder.embed_dim, self.encoder.embed_dim * 4),
            nn.ReLU(),
            nn.Linear(self.encoder.embed_dim * 4, self.encoder.embed_dim),
            nn.ReLU(),
            nn.Linear(self.encoder.embed_dim, n_classes)
        )

    def forward(self, X: str):
        with torch.no_grad():
            X = self.tokenizer.encode(X, return_tensors='pt')
            X = self.encoder(X).last_hidden_state
        return self.head(X)
    
tc = TextClassifier().to(device)
tc(train_ds[0][0].to(device)).shape

AttributeError: 'str' object has no attribute 'to'

In [None]:
@torch.no_grad()
def validation(model, dl, loss):
    val_loss = 0
    accs = 0
    for x, y in tqdm(train_dl, "Validation", leave=False):
        x, y = x.to(device), y.to(device)
        logits = model(x)
        val_loss += loss(logits, y).item()
        accs += (logits.argmax(dim=1) == y).sum().float().item()
    return val_loss / len(dl), accs / len(dl.dataset)



def train(model, train_dl, val_dl, loss, optim, epochs, validation_freq):
    for t in range(1, epochs + 1):
        for x, y in tqdm(train_dl, f"Epoch #{t}", leave=False):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            lss = loss(logits, y)

            optim.zero_grad()
            lss.backward()
            optim.step()
        if t % validation_freq == 0:
            lss, acc = validation(model, val_dl, loss)
            print(f"Epoch #{t}: ", lss, acc)

In [None]:
loss_fn = F.binary_cross_entropy
optim = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.05)
train(tc, train_dl, val_dl, loss_fn, optim, 10, 1)


                                                 

RuntimeError: stack expects each tensor to be equal size, but got [1, 35] at entry 0 and [1, 13] at entry 1