# Introduction
In this laboratory we will get our hands dirty working with Large Language Models (e.g. GPT and BERT) to do various useful things. If you haven't already, it is highly recommended to:

+ Read the [Attention is All you Need](https://arxiv.org/abs/1706.03762) paper, which is the basis for all transformer-based LLMs.
+ Watch (and potentially *code along*) with this [Andrej Karpathy video](https://www.youtube.com/watch?v=kCc8FmEb1nY) which shows you how to build an autoregressive GPT model from the ground up.

# Exercise 1: Warming Up
In this first exercise you will train a *small* autoregressive GPT model for character generation (the one used by Karpathy in his video) to generate text in the style of Dante Alighieri. Use [this file](https://archive.org/stream/ladivinacommedia00997gut/1ddcd09.txt), which contains the entire text of Dante's Inferno (**note**: you will have to delete some introductory text at the top of the file before training). Train the model for a few epochs, monitor the loss, and generate some text at the end of training. Qualitatively evaluate the results.
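
One way to drop the introductory text is to cut everything before the first verse. The sketch below assumes the poem starts with the line "Nel mezzo del cammin di nostra vita" and that this line does not also appear in the preamble. The helper `strip_header` and the sample string are hypothetical and only illustrate the idea.

```python
def strip_header(raw_text, first_verse="Nel mezzo del cammin di nostra vita"):
    """Return raw_text starting at the first occurrence of first_verse.

    Assumption: first_verse marks the start of the poem. If the marker is
    missing, the text is returned unchanged so nothing is silently lost.
    """
    start = raw_text.find(first_verse)
    return raw_text if start == -1 else raw_text[start:]


# Tiny stand-in for the downloaded file (not the real header)
sample = (
    "Some introductory notes\n\n"
    "Nel mezzo del cammin di nostra vita\n"
    "mi ritrovai per una selva oscura"
)
cleaned = strip_header(sample)
print(cleaned.splitlines()[0])
```

Inspecting the first few lines of the cleaned text before training is a cheap check that the marker matched where expected.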

In [3]:
#if using colab
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 10000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('/content/drive/MyDrive/LaboratoryDLA/LAB2/dante.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [None]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # token embeddings and learned position embeddings are summed and fed to the transformer blocks
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # better init, not covered in the original GPT video, but important, will cover in followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
      # idx is (B, T) array of indices in the current context
      for _ in range(max_new_tokens):
          # crop idx to the last block_size tokens
          idx_cond = idx[:, -block_size:]
          # get the predictions
          logits, loss = self(idx_cond)
          # focus only on the last time step
          logits = logits[:, -1, :] # becomes (B, C)
          # apply softmax to get probabilities
          probs = F.softmax(logits, dim=-1) # (B, C)
          # sample from the distribution
          idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
          # append sampled index to the running sequence
          idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
      return idx

In [None]:
model = GPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
#open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))

10.783546 M parameters
step 0: train loss 4.0481, val loss 4.0425
step 500: train loss 1.7516, val loss 1.8010
step 1000: train loss 1.2666, val loss 1.5518
step 1500: train loss 0.7492, val loss 1.6696
step 2000: train loss 0.2980, val loss 2.1068
step 2500: train loss 0.1388, val loss 2.5982
step 3000: train loss 0.0991, val loss 2.9071
step 3500: train loss 0.0872, val loss 3.1320
step 4000: train loss 0.0803, val loss 3.3237
step 4500: train loss 0.0779, val loss 3.3884
step 5000: train loss 0.0731, val loss 3.4739
step 5500: train loss 0.0697, val loss 3.5807
step 6000: train loss 0.0673, val loss 3.6510
step 6500: train loss 0.0654, val loss 3.7594
step 7000: train loss 0.0642, val loss 3.8150
step 7500: train loss 0.0628, val loss 3.8619
step 8000: train loss 0.0615, val loss 3.9680
step 8500: train loss 0.0606, val loss 3.9595
step 9000: train loss 0.0594, val loss 3.9947
step 9500: train loss 0.0595, val loss 4.0293
step 9999: train loss 0.0582, val loss 4.0792

udir: <<Perche

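The model is clearly overfitting: the validation loss bottoms out at 1.5518 around step 1000 and then climbs to 4.0792 by step 9999, while the training loss keeps falling to 0.0582. Stopping near step 1000 would give the best validation loss seen in this run.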
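
A minimal sketch of how the logged validation losses could drive checkpoint selection and early stopping. The loss values are copied from the first evaluations of the log above; the function names and the `patience` parameter are hypothetical and not part of the training script.

```python
# Validation losses copied from the training log above (step -> val loss)
val_loss_by_step = {
    0: 4.0425,
    500: 1.8010,
    1000: 1.5518,
    1500: 1.6696,
    2000: 2.1068,
}


def best_checkpoint(val_losses):
    """Return the (step, loss) pair with the lowest validation loss."""
    step = min(val_losses, key=val_losses.get)
    return step, val_losses[step]


def should_stop(history, patience=2):
    """True when the last `patience` evaluations all failed to beat
    the best loss seen before them."""
    if len(history) <= patience:
        return False
    best_before = min(history[:-patience])
    return all(loss >= best_before for loss in history[-patience:])


best_step, best_loss = best_checkpoint(val_loss_by_step)
print(best_step, best_loss)

history = list(val_loss_by_step.values())
print(should_stop(history))
```

With `patience=2` this rule would have stopped the run at step 2000, long before the 10000 iterations configured above.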

# Exercise 2: Working with Real LLMs

Our toy GPT can only take us so far. In this exercise we will see how to use the [Hugging Face](https://huggingface.co/) model and dataset ecosystem to access a *huge* variety of pre-trained transformer models.

## Exercise 2.1: Installation and text tokenization

First things first, we need to install the [Hugging Face Transformers library](https://huggingface.co/docs/transformers/index):

    conda install -c huggingface -c conda-forge transformers
    
The key classes that you will work with are `GPT2Tokenizer` to encode text into sub-word tokens, and the `GPT2LMHeadModel`. **Note** the `LMHead` part of the class name -- this is the version of the GPT2 architecture that has the text prediction heads attached to the final hidden layer representations (i.e. what we need to **generate** text).

Instantiate the `GPT2Tokenizer` and experiment with encoding text into integer tokens. Compare the length of input with the encoded sequence length.

**Tip**: Pass the `return_tensors='pt'` argument to the tokenizer to get PyTorch tensors as output (instead of lists).

In [5]:
!pip install transformers datasets wandb torchmetrics



In [6]:
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
text = "Nel mezzo del cammin di nostra vita mi ritrovai per una selva oscura che' la diritta via era smarrita."

tokenized_text = tokenizer.encode(text, return_tensors="pt")
print(f"Tokenized text: {tokenized_text}")
print(f"Text length {len(text)}, tokenized length {len(tokenized_text[0])}")

# For the sentence above this prints: Text length 102, tokenized length 38

Tokenized text: tensor([[   45,   417,   502, 47802,  1619, 12172,  1084,  2566, 18216,   430,
           410,  5350, 21504,   374,   270, 18657,  1872,   583,   555,    64,
           384,  6780,    64,   267,  1416,  5330,  1125,     6,  8591, 26672,
         48519,  2884,  6980,   895,   283,   799,    64,    13]])
Text length 102, tokenized length 38
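
As a rough reading of the output above, the ratio between characters and tokens can be computed directly. The helper name `chars_per_token` is made up for illustration. The ratio will likely differ for other texts: GPT-2's vocabulary was learned mostly from English text, so Italian words tend to be split into more pieces.

```python
def chars_per_token(num_chars, num_tokens):
    """Average number of characters covered by one token."""
    if num_tokens <= 0:
        raise ValueError("num_tokens must be positive")
    return num_chars / num_tokens


# Numbers taken from the cell output above
ratio = chars_per_token(102, 38)
print(f"{ratio:.2f} characters per token")
```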


## Exercise 2.2: Generating Text

There are a lot of ways we can, given a *prompt* in input, sample text from a GPT2 model. Instantiate a pre-trained `GPT2LMHeadModel` and use the [`generate()`](https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/text_generation#transformers.GenerationMixin.generate) method to generate text from a prompt.

**Note**: The default inference mode for GPT2 is *greedy*, which might not result in satisfying generated text. Look at the `do_sample` and `temperature` parameters.

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, set_seed

class Prompt_Generator:
    def __init__(self, config):
        self.tokenizer = GPT2Tokenizer.from_pretrained(config["Tokenizer"])
        self.model = GPT2LMHeadModel.from_pretrained(config["Model"])
        self.do_sample = config["do_sample"]
        self.temperature = config["temperature"]
        self.max_lenght = config["max_lenght"]
        self.no_repeat_ngram_size = config["no_repeat_ngram_size"]
        self.prompt = config.get("prompt")

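    def generate(self, prompt = None, do_sample = None, temperature = None, max_lenght = None, no_repeat_ngram_size = None):
        # arguments passed here overwrite the stored settings and persist for later calls
        if prompt is not None:
          self.prompt = prompt
        # tokenize the stored prompt, so a call without `prompt` falls back to the configured one
        inputs = self.tokenizer(self.prompt, return_tensors="pt")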
        if temperature is not None:
          self.temperature = temperature
        if do_sample is not None:
          self.do_sample = do_sample
        if max_lenght is not None:
          self.max_lenght = max_lenght
        if no_repeat_ngram_size is not None:
          self.no_repeat_ngram_size = no_repeat_ngram_size
        tokenized_output = self.model.generate(**inputs,
                                  max_new_tokens = self.max_lenght,
                                  no_repeat_ngram_size = self.no_repeat_ngram_size,
                                  temperature = self.temperature,
                                  do_sample = self.do_sample,
                                  pad_token_id = self.tokenizer.eos_token_id)
        output = self.tokenizer.decode(tokenized_output[0], skip_special_tokens=True)
        print(output)


In [None]:
config = {"Tokenizer": "openai-community/gpt2",
        "Model": "openai-community/gpt2",
        "do_sample": True,
        "temperature": 2.0,
        "max_lenght": 30,
        "no_repeat_ngram_size": 2,
        }

prompt_generator = Prompt_Generator(config)
prompt_generator.generate("What are")
prompt_generator.generate("I am")
prompt_generator.generate("You")

What are you supposed to say if it doesn't happen… that it happened? Well that means to give your team the winning hand that got them down against an
I am sorry - what were your worries before? I have learned not you - but what have yours - like I was? That, ahah? Oh...
You like them. The one who came down under from this is really very attractive," Bitt said, referring to the Trump who won with fewer supporters than


In [None]:
prompt_generator.generate("What are", do_sample=False)
prompt_generator.generate("I am")
prompt_generator.generate("You")



What are the best ways to get your hands on a new game?

We've got a lot of great games coming out this year, and we're
I am not a fan of the idea of a "big-budget" movie. I think it's a waste of money.

I think the movie
You, the man who has been the most important person in my life, I am sorry for what I have done. I will never forget that day.


In [None]:
prompt_generator.generate("What are", do_sample=True, temperature= 9.0)
prompt_generator.generate("I am")
prompt_generator.generate("You")

What are YOU watching next (not on CBS right where We've never heard on that) CBS SundayNight's Jon Belgrade chats politics (@SandyStrick
I am looking across both systems again... "A: Okay for any further question." ―Thon Tse'nin and Lor Rond with Hao J
You "must always carry something. And be as free not averse..." - Michael Carusoglu In order unto men of virtue to fulfill those qualities to


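Raising the temperature to 9.0 makes the generated text largely incoherent: dividing the logits by a large temperature flattens the next-token distribution, so unlikely tokens are sampled much more often.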
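
To see why, here is a small pure-Python sketch of temperature scaling, assuming the temperature divides the logits before the softmax. The logits are made up for a four-token vocabulary and `softmax_with_temperature` is a hypothetical helper, not a library function.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for a 4-token vocabulary
logits = [4.0, 2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0, 9.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Low temperatures concentrate the probability mass on the top token (approaching greedy decoding), while high temperatures push the distribution towards uniform.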

In [None]:
prompt_generator.generate("What are", do_sample=True, temperature= 2.0, no_repeat_ngram_size=5)
prompt_generator.generate("I am")
prompt_generator.generate("You")

What are Your thoughts about this case?

Your favorite news outlets make each of their headlines based entirely on what stories they want (see newsrooms from your
I am afraid we all learn very soon from childhood how quickly we have come into contact between man the thing to worry about the little one to whom to relate.
You to try?

Have something specific to offer? Email: askalev.bazakh < AskAriV@google.com> or


In [None]:
prompt_generator.generate("What are", do_sample=True, temperature= 2.0, no_repeat_ngram_size=9)
prompt_generator.generate("I am")
prompt_generator.generate("You")

What are the latest tech and games release trends across every game category to know about and try to forecast next year? Who did the best debut last month under "
I am truly humbled to be speaking up and making sure every single mother's son receives a clean, honest medical treatment that respects a fetus from her or her
You a liar?" my mother asked with wide eyes and no remorse? We laughed like little boys! And who among us looked shocked when I got back to


Raising `no_repeat_ngram_size` too far (here from 2 to 5 and then 9) also makes the generated text a little less natural.
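
The effect of `no_repeat_ngram_size` can be sketched in plain Python. This is a simplified illustration of the idea, not the library's implementation: before each step, any token that would complete an n-gram already present in the sequence is banned. The function name `banned_next_tokens` and the toy token ids are hypothetical.

```python
def banned_next_tokens(tokens, n):
    """Tokens that would repeat an n-gram already present in `tokens`."""
    if n <= 0 or len(tokens) < n - 1:
        return set()
    # The last n-1 tokens form the prefix of the n-gram being completed
    prefix = tuple(tokens[len(tokens) - (n - 1):]) if n > 1 else ()
    banned = set()
    for i in range(len(tokens) - n + 1):
        ngram = tuple(tokens[i:i + n])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    return banned


history = [1, 2, 3, 1]
print(banned_next_tokens(history, 2))  # the bigram (1, 2) was already generated
print(banned_next_tokens(history, 3))  # no trigram starts with (3, 1) yet
```

With `n=2` no bigram may repeat, which is a strict constraint; larger values only block longer repeats, so they constrain the text less.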

# Exercise 3: Reusing Pre-trained LLMs (choose one)

Choose **one** of the following exercises (well, *at least* one). In each of these you are asked to adapt a pre-trained LLM (`GPT2Model` or `DistilBERT` are two good choices) to a new Natural Language Understanding task. A few comments:

+ Since GPT2 is an *autoregressive* model, there is no latent space aggregation at the last transformer layer (you get the same number of tokens out that you give in input). To use a pre-trained model for a classification or retrieval task, you should aggregate these tokens somehow (or opportunistically select *one* to use).

+ BERT models (including DistilBERT) have a special [CLS] token prepended to the input sequence, so its latent representation comes first in the output of every self-attention block. You can directly use this as a representation for classification (or retrieval).

+ The first *two* exercises below can probably be done *without* any fine-tuning -- that is, just training a shallow MLP to classify or represent with the appropriate loss function.
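
The two aggregation strategies mentioned for autoregressive models can be sketched as follows. This is a minimal example on a random tensor standing in for a model's last hidden state; the function names are hypothetical, and `last_token_pool` assumes the batch is padded on the right.

```python
import torch


def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors over the sequence, ignoring padded positions.

    hidden_states: (B, T, C) float tensor
    attention_mask: (B, T) tensor with 1 for real tokens and 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (B, T, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # (B, C)
    counts = mask.sum(dim=1).clamp(min=1.0)                      # (B, 1)
    return summed / counts


def last_token_pool(hidden_states, attention_mask):
    """Select the vector of the last real token (assumes right padding)."""
    last_index = attention_mask.sum(dim=1).clamp(min=1) - 1      # (B,)
    batch_index = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_index, last_index]                # (B, C)


# Random stand-in for a last hidden state: batch of 2, 4 tokens, 8 channels
hidden = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(masked_mean_pool(hidden, mask).shape)
print(last_token_pool(hidden, mask).shape)
```

In a causal model only the last real token has attended to the whole sequence, which is why it is a natural single token to select.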

## Exercise 3.1: Training a Text Classifier (easy)

Peruse the [text classification datasets on Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:text-classification&sort=downloads). Choose a *moderately* sized dataset and use an LLM to train a classifier to solve the problem.

**Note**: A good first baseline for this problem is certainly to use an LLM *exclusively* as a feature extractor and then train a shallow model.



In [1]:
from transformers import DistilBertTokenizer, DistilBertModel, set_seed
from datasets import load_dataset
import torch
import torch.nn as nn
from torch.nn import functional as F

class TextClassifier(nn.Module):
    def __init__(self, num_classes = 2):
        super(TextClassifier, self).__init__()
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        self.backbone = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.head = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, text,attention_mask, device):
        #inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = text.to(device)
        output = self.backbone(inputs, attention_mask = attention_mask)
        cls_token = output['last_hidden_state'][:, 0, :]
        return self.head(cls_token)

## Exercise 3.2: Training a Question Answering Model (harder)

Peruse the [multiple choice question answering datasets on Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:multiple-choice&sort=downloads). Choose a *moderately* sized one and train a model to answer contextualized multiple-choice questions. You *might* be able to avoid fine-tuning by training a simple model to *rank* the multiple choices (see margin ranking loss in PyTorch).
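
A minimal sketch of the ranking idea with `torch.nn.MarginRankingLoss`, under the assumption that the model produces one scalar score per option. The scores below are made up; in practice they would come from a scoring head on top of the pre-trained features.

```python
import torch
import torch.nn as nn

# Hypothetical scores for one question with 4 options; option 2 is correct
scores = torch.tensor([0.2, 1.5, 0.9, -0.3], requires_grad=True)
correct = 2

# Pair the correct option's score against every wrong option's score
wrong = [i for i in range(scores.numel()) if i != correct]
positive = scores[[correct] * len(wrong)]  # (3,) correct score repeated
negative = scores[wrong]                   # (3,) scores of the wrong options
target = torch.ones(len(wrong))            # +1: `positive` should rank higher

# loss = mean(max(0, -target * (positive - negative) + margin))
criterion = nn.MarginRankingLoss(margin=1.0)
loss = criterion(positive, negative, target)
loss.backward()
print(loss.item())
```

The loss is zero only when the correct option beats every wrong option by at least the margin, so minimizing it pushes the correct option to the top of the ranking.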


In [None]:
# TODO: check whether https://huggingface.co/docs/transformers/tasks/multiple_choice can be used here; it relies on
# the ready-made multiple-choice classes of the Hugging Face transformers library

In [None]:
from transformers import DistilBertModel, DistilBertTokenizer
import torch
import torch.nn as nn

# Load model and tokenizer
model_name = "distilbert-base-uncased"
model = DistilBertModel.from_pretrained(model_name)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Define custom classifier
class MultipleChoiceModel(nn.Module):
    def __init__(self, distilbert_model, tokenizer):
        super(MultipleChoiceModel, self).__init__()
        self.distilbert = distilbert_model
        self.tokenizer = tokenizer
        self.classifier = nn.Linear(distilbert_model.config.hidden_size, 1)

    def forward(self, context, options):
        inputs = [self.tokenizer.encode_plus(context, option, return_tensors='pt', padding='max_length', truncation=True, max_length=512) for option in options]
        input_ids = torch.stack([input['input_ids'] for input in inputs]).squeeze(1)
        attention_mask = torch.stack([input['attention_mask'] for input in inputs]).squeeze(1)

        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        cls_outputs = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_outputs).squeeze(-1)
        return logits

# Initialize custom model
custom_model = MultipleChoiceModel(model, tokenizer)

# Example question and options
question = "What is the capital of France?"
options = ["London", "Berlin", "Paris", "Madrid"]

# Model inference
with torch.no_grad():
    logits = custom_model(question, options)

# Choose the option with the highest score
best_choice = torch.argmax(logits, dim=0).item()

# Print the answer
print(f"The answer is: {options[best_choice]}")


The answer is: Madrid


# LLM Trainer

Let's create an **LLM trainer** inspired by the one in LAB1.

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchmetrics
import wandb
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset  # load_metric is unused here and was removed from recent `datasets` releases
from tqdm import tqdm

class LLMTrainer:
    def __init__(self, config):

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_epochs = config["num_epochs"]
        self.batch_size = config.get("batch_size", 16)
        self.learning_rate = config.get("learning_rate", 5e-5)
        self.model_name = config["model_name"]

        if config["use_wandb"]:
            self.use_wandb = True
            wandb.init(project="TextClassification", config=config)
        else:
            self.use_wandb = False

        #self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained)
        #self.model = AutoModelForSequenceClassification.from_pretrained(self.pretrained, num_labels=int(config["num_classes"]))
        match self.model_name:
            case "text_classifier":
                self.model = TextClassifier(num_classes=int(config["num_classes"]))
                # reuse the tokenizer the classifier already loaded; the backbone lives in self.model.backbone
                self.tokenizer = self.model.tokenizer
            case _:
                raise ValueError(f'Model not found. Received {self.model_name}.')

        self.model.to(self.device)


        # Dataset loading
        dataset_name = config["dataset"]
        self.dataset = load_dataset(dataset_name)

        # Tokenize datasets
        self.train_dataset = self.dataset['train'].map(self.tokenize_function, batched=True)
        if 'validation' in self.dataset:
            self.val_dataset = self.dataset['validation'].map(self.tokenize_function, batched=True)
        else:
            # Create validation split from training data. train_test_split keeps both parts as
            # datasets.Dataset objects, so set_format below still works (random_split returns Subsets).
            split = self.train_dataset.train_test_split(test_size=0.2)
            self.train_dataset, self.val_dataset = split['train'], split['test']

        self.test_dataset = self.dataset['test'].map(self.tokenize_function, batched=True)

        self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
        self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
        self.test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

        self.training_mode = config["training_mode"]
        print(self.training_mode)
        if self.training_mode == "finetune":
            backbone_params = list(self.model.backbone.parameters())
            head_params = list(self.model.head.parameters())
            lrs = [
                  {'params': backbone_params, 'lr': config["lr"]/1e2},
                  {'params': head_params, 'lr': config["lr"]}
            ]
        elif self.training_mode == "head_only":
            lrs = [
                  {'params': self.model.head.parameters(), 'lr': config["lr"]}
            ]
            for param in self.model.backbone.parameters():
                  param.requires_grad = False
        else:
            lrs = [
                  {'params': self.model.parameters(), 'lr': config["lr"]}
            ]
        match config["optimizer"]:
            case "adam":
                self.optimizer = optim.Adam(lrs)
            case "sgd":
                self.optimizer = optim.SGD(lrs)
            case "adamw":
                self.optimizer = optim.AdamW(lrs)
            case _:
                raise ValueError(f'Optimizer not found. Received {config["optimizer"]}.')

        match config["loss"]:
            case "cross_entropy":
                self.loss = nn.CrossEntropyLoss()
            case _:
                raise ValueError(f'Loss not found. Received {config["loss"]}.')

        self.metrics = []
        for metric in config["metrics"]:
            match metric:
                case "accuracy":
                    self.metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=config["num_classes"]).to(self.device))
                case "precision":
                    self.metrics.append(torchmetrics.Precision(task="multiclass", average='macro', num_classes=config["num_classes"]).to(self.device))
                case "recall":
                    self.metrics.append(torchmetrics.Recall(task="multiclass", average='macro', num_classes=config["num_classes"]).to(self.device))
                case _:
                    raise ValueError(f'Metric not found. Received {metric}.')

    def tokenize_function(self, examples):
        return self.tokenizer(examples['text'], padding="max_length", truncation=True)

    def train_head(self, epochs=10, lr=3e-4):
        self.model.train()
        optimizer = torch.optim.AdamW(self.model.head.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            acc_loss = 0
            for i, batch in tqdm(enumerate(self.train_loader)):
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["label"].to(self.device)
                optimizer.zero_grad()
                # TextClassifier.forward needs the device and returns the logits tensor directly
                outputs = self.model(input_ids, attention_mask=attention_mask, device=self.device)
                loss = self.loss(outputs, labels)
                loss.backward()
                optimizer.step()
                acc_loss += loss.item()
                if (i + 1) % 500 == 0:
                    if self.use_wandb:
                        wandb.log({"loss": acc_loss / 500})
                    print(f"Epoch {epoch}, batch {i}, loss {acc_loss / 500}")
                    acc_loss = 0

    def train_one_epoch(self):
        self.model.train()
        running_loss = 0.0

        for batch in tqdm(self.train_loader):
            #print(batch)  # Print batch for debugging

            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["label"].to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(input_ids, attention_mask=attention_mask, device = self.device)
            loss = self.loss(outputs, labels)

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        train_loss = running_loss / len(self.train_loader)
        print(f"Training Loss: {train_loss:.4f}")

    def evaluate(self):
        self.model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.val_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["label"].to(self.device)
                outputs = self.model(input_ids, attention_mask=attention_mask, device=self.device)
                loss = self.loss(outputs, labels)

                val_loss += loss.item()
                all_preds.append(outputs.argmax(dim=-1))
                all_labels.append(labels)

        val_loss /= len(self.val_loader)
        val_preds = torch.cat(all_preds)
        val_labels = torch.cat(all_labels)

        metrics = {metric.__class__.__name__: metric(val_preds, val_labels) for metric in self.metrics}

        if self.use_wandb:
            wandb.log({"val_loss": val_loss, **metrics})

        print(f'Validation Loss: {val_loss:.4f}')
        for name, value in metrics.items():
            print(f'Validation {name}: {value:.4f}')


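    def test(self):
        self.model.eval()
        test_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.test_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["label"].to(self.device)
                outputs = self.model(input_ids, attention_mask=attention_mask, device=self.device)
                loss = self.loss(outputs, labels)
                test_loss += loss.item()
                # Collect predictions so the metrics are computed on the whole test set.
                all_preds.append(outputs.argmax(dim=-1))
                all_labels.append(labels)

        test_loss /= len(self.test_loader)
        test_preds = torch.cat(all_preds)
        test_labels = torch.cat(all_labels)

        # Mirror evaluate(): the previous version called self.metric.compute(),
        # but the test loop never updated any metric with test predictions.
        metrics = {f"test_{metric.__class__.__name__}": metric(test_preds, test_labels) for metric in self.metrics}

        if self.use_wandb:
            wandb.log({"test_loss": test_loss, **metrics})

        print(f'Test Loss: {test_loss:.4f}')
        for name, value in metrics.items():
            print(f'{name}: {value:.4f}')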

    def run(self):
        for epoch in range(self.num_epochs):
            print(f'Epoch {epoch + 1}/{self.num_epochs}')
            self.train_one_epoch()
            self.evaluate()
        self.test()
        if self.use_wandb:
            wandb.finish()
# To train head only
# trainer.train_head(epochs=10, lr=3e-4)
# To finetune the whole model
# trainer.finetune(epochs=10, lr=3e-4)
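
The `run` method above always trains for the full `num_epochs`, even if the validation loss has stopped improving. One common way to handle this is early stopping. The sketch below is only an illustration, not part of `LLMTrainer`: the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the `epochs_until_stop` helper are all hypothetical names, and wiring it into `run` would additionally assume that `evaluate()` is changed to return `val_loss`.

```python
class EarlyStopping:
    """Hypothetical helper: signals a stop when validation loss stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def epochs_until_stop(val_losses, patience=3, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    stopper = EarlyStopping(patience=patience, min_delta=min_delta)
    for epoch, loss in enumerate(val_losses, start=1):
        if stopper.step(loss):
            return epoch
    return None


if __name__ == "__main__":
    # Illustrative validation losses (made up for this example).
    losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.5]
    print(epochs_until_stop(losses, patience=3))
```

Inside `run`, one would create the stopper once before the epoch loop and `break` when `stopper.step(val_loss)` returns `True`. The choice of `patience` and `min_delta` is a judgment call: a larger `min_delta` treats tiny fluctuations in validation loss as "no improvement" and stops sooner.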


In [None]:
# Example usage
config = {
    "use_wandb": True,
    "num_epochs": 15,
    "model_name": "text_classifier",
    "dataset": "dair-ai/emotion",
    "batch_size": 4,
    "lr": 5e-5,
    "num_classes": 6,
    "metrics": ["accuracy", "precision", "recall"],
    "optimizer": "adamw",
    "loss": "cross_entropy",
    "training_mode": "finetune"
}

trainer = LLMTrainer(config)
trainer.run()

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Map:   0%|          | 0/16000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

finetune
Epoch 1/15


100%|██████████| 4000/4000 [13:25<00:00,  4.96it/s]


Training Loss: 1.0044
Validation Loss: 0.6154
Validation MulticlassAccuracy: 0.7685
Validation MulticlassPrecision: 0.7685
Validation MulticlassRecall: 0.6599
Epoch 2/15


100%|██████████| 4000/4000 [13:29<00:00,  4.94it/s]


Training Loss: 0.5230
Validation Loss: 0.4198
Validation MulticlassAccuracy: 0.8565
Validation MulticlassPrecision: 0.8320
Validation MulticlassRecall: 0.8213
Epoch 3/15


100%|██████████| 4000/4000 [13:29<00:00,  4.94it/s]


Training Loss: 0.3764
Validation Loss: 0.3244
Validation MulticlassAccuracy: 0.8875
Validation MulticlassPrecision: 0.8616
Validation MulticlassRecall: 0.8533
Epoch 4/15


100%|██████████| 4000/4000 [13:28<00:00,  4.95it/s]


Training Loss: 0.2964
Validation Loss: 0.2761
Validation MulticlassAccuracy: 0.9070
Validation MulticlassPrecision: 0.8854
Validation MulticlassRecall: 0.8677
Epoch 5/15


100%|██████████| 4000/4000 [13:26<00:00,  4.96it/s]


Training Loss: 0.2446
Validation Loss: 0.2717
Validation MulticlassAccuracy: 0.9035
Validation MulticlassPrecision: 0.8972
Validation MulticlassRecall: 0.8618
Epoch 6/15


100%|██████████| 4000/4000 [13:29<00:00,  4.94it/s]


Training Loss: 0.2113
Validation Loss: 0.2255
Validation MulticlassAccuracy: 0.9180
Validation MulticlassPrecision: 0.8959
Validation MulticlassRecall: 0.8873
Epoch 7/15


100%|██████████| 4000/4000 [13:28<00:00,  4.95it/s]


Training Loss: 0.1805
Validation Loss: 0.2050
Validation MulticlassAccuracy: 0.9190
Validation MulticlassPrecision: 0.8943
Validation MulticlassRecall: 0.8929
Epoch 8/15


100%|██████████| 4000/4000 [13:27<00:00,  4.95it/s]


Training Loss: 0.1604
Validation Loss: 0.1960
Validation MulticlassAccuracy: 0.9230
Validation MulticlassPrecision: 0.9138
Validation MulticlassRecall: 0.8869
Epoch 9/15


100%|██████████| 4000/4000 [13:25<00:00,  4.96it/s]


Training Loss: 0.1453
Validation Loss: 0.1779
Validation MulticlassAccuracy: 0.9240
Validation MulticlassPrecision: 0.8931
Validation MulticlassRecall: 0.9080
Epoch 10/15


100%|██████████| 4000/4000 [13:27<00:00,  4.96it/s]


Training Loss: 0.1325
Validation Loss: 0.1776
Validation MulticlassAccuracy: 0.9280
Validation MulticlassPrecision: 0.9015
Validation MulticlassRecall: 0.9094
Epoch 11/15


100%|██████████| 4000/4000 [13:26<00:00,  4.96it/s]


Training Loss: 0.1190
Validation Loss: 0.1777
Validation MulticlassAccuracy: 0.9330
Validation MulticlassPrecision: 0.9059
Validation MulticlassRecall: 0.9211
Epoch 12/15


100%|██████████| 4000/4000 [13:27<00:00,  4.96it/s]


Training Loss: 0.1103
Validation Loss: 0.1780
Validation MulticlassAccuracy: 0.9280
Validation MulticlassPrecision: 0.8994
Validation MulticlassRecall: 0.9152
Epoch 13/15


 25%|██▌       | 1011/4000 [03:24<10:05,  4.93it/s]