# Introduction
In this laboratory we will get our hands dirty working with Large Language Models (e.g. GPT and BERT) to do various useful things. If you haven't already, it is highly recommended to:

+ Read the [Attention is All you Need](https://arxiv.org/abs/1706.03762) paper, which is the basis for all transformer-based LLMs.
+ Watch (and potentially *code along*) with this [Andrej Karpathy video](https://www.youtube.com/watch?v=kCc8FmEb1nY) which shows you how to build an autoregressive GPT model from the ground up.

# Exercise 1: Warming Up
In this first exercise you will train a *small* autoregressive GPT model for character generation (the one used by Karpathy in his video) to generate text in the style of Dante Alighieri. Use [this file](https://archive.org/stream/ladivinacommedia00997gut/1ddcd09.txt), which contains the entire text of Dante's Inferno (**note**: you will have to delete some introductory text at the top of the file before training). Train the model for a few epochs, monitor the loss, and generate some text at the end of training. Qualitatively evaluate the results.

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 1000
eval_interval = 100
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
# ------------

torch.manual_seed(1337)

# input2.txt: the Dante text linked above, with the introductory header removed
with open('input2.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ feed-forward MLP: linear layer, ReLU non-linearity, linear projection, dropout """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # lookup tables for token embeddings and learned position embeddings
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # better init, not covered in the original GPT video, but important, will cover in followup video
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = GPTLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in tqdm(range(max_iters)):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
#open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))

10.783546 M parameters


  0%|                                                  | 0/1000 [00:00<?, ?it/s]

step 0: train loss 4.0476, val loss 4.0427


 10%|████                                    | 100/1000 [01:24<05:26,  2.76it/s]

step 100: train loss 2.3760, val loss 2.4034


 20%|████████                                | 200/1000 [02:48<04:45,  2.80it/s]

step 200: train loss 2.2835, val loss 2.3137


 30%|████████████                            | 300/1000 [04:11<04:13,  2.77it/s]

step 300: train loss 2.1615, val loss 2.1907


 40%|████████████████                        | 400/1000 [05:36<03:37,  2.76it/s]

step 400: train loss 1.9235, val loss 1.9592


 50%|████████████████████                    | 500/1000 [07:00<03:00,  2.77it/s]

step 500: train loss 1.7546, val loss 1.7993


 60%|████████████████████████                | 600/1000 [08:24<02:24,  2.76it/s]

step 600: train loss 1.6406, val loss 1.7044


 70%|████████████████████████████            | 700/1000 [09:49<01:48,  2.77it/s]

step 700: train loss 1.5294, val loss 1.6300


 80%|████████████████████████████████        | 800/1000 [11:13<01:12,  2.76it/s]

step 800: train loss 1.4382, val loss 1.5861


 90%|████████████████████████████████████    | 900/1000 [12:37<00:36,  2.76it/s]

step 900: train loss 1.3449, val loss 1.5481


100%|███████████████████████████████████████▉| 999/1000 [14:01<00:00,  2.76it/s]

step 999: train loss 1.2498, val loss 1.5333


100%|███████████████████████████████████████| 1000/1000 [14:49<00:00,  1.12it/s]




Di  se tutto ferbele bestilme
  seguerta ne l'altemma, utger divesse,
  franne siede, acmicola` giu` de' fuggira

me le prose tu temo de l'embre e cota;
  me m'abbe a da questa ta lietra onde me;
  sotto ci solo acqua fella a pie` quica tutti>>.

Quinfin sente l'erra viede 'l maestro,
  fuol, ch'acqua greffante eratta ciaccia
  di foco Il fatte forcese tenduna.

E 'nvidi la far sconda lrude gratto
  cimo susci strattinia e suscilitsi
  da volta vedesta da cio` non che montre

facealde sotto Sie
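The logged losses are mean cross-entropies in nats per character, so `exp(loss)` turns them into a perplexity, which is easier to interpret. A small sketch using values copied from the log above:

```python
import math

def perplexity(loss: float) -> float:
    """Perplexity = exp(mean cross-entropy in nats): the effective number of
    equally likely next characters the model is choosing between."""
    return math.exp(loss)

# (train loss, val loss) pairs copied from the training log above.
logged = {"step 0": (4.0476, 4.0427), "step 999": (1.2498, 1.5333)}

for step, (train_loss, val_loss) in logged.items():
    print(f"{step}: train ppl {perplexity(train_loss):.2f}, val ppl {perplexity(val_loss):.2f}")
```

At step 0 the validation perplexity is about 57: a freshly initialized model spreads probability almost uniformly, so its loss sits near ln(vocab_size), suggesting a character vocabulary of roughly that size. By step 999 the validation perplexity is about 4.6, while the widening train/val gap (1.25 vs 1.53) hints that longer training may start to overfit.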


In [None]:
torch.save(m.state_dict(), "model.pt")

In [None]:
# continue training the same model (and optimizer state) for another max_iters iterations
for iter in tqdm(range(max_iters)):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))



Inferno: Canto XI


Ne' quel che m'abbaianco a si confesse
  la` reno`, da lui, si` conforte,
  e due 'l seguir qual e` com'uscive.

Io savisi poggiati e con piu` dentro;
  e di rispontar la coda di spiata
  di quell'aere quanto a mul che non modo.

Ma per l'occhio scoglio travien sospiri,
  a mi cantare un peccatori, a cu' il giuso
  con tenea ch'era fatto al cio` ch'i' vinsi.

Vedi che disse: "Qui son piglio convien che e trova
  a quel papeccator, e pugno, presso regio
  e volte a la` bramos


In [None]:
torch.save(m.state_dict(), "model2.pt")

# Exercise 2: Working with Real LLMs

Our toy GPT can only take us so far. In this exercise we will see how to use the [Hugging Face](https://huggingface.co/) model and dataset ecosystem to access a *huge* variety of pre-trained transformer models.

## Exercise 2.1: Installation and text tokenization

First things first, we need to install the [Hugging Face transformer library](https://huggingface.co/docs/transformers/index):

    conda install -c huggingface -c conda-forge transformers
    
The key classes that you will work with are `GPT2Tokenizer` to encode text into sub-word tokens, and the `GPT2LMHeadModel`. **Note** the `LMHead` part of the class name -- this is the version of the GPT2 architecture that has the text prediction heads attached to the final hidden layer representations (i.e. what we need to **generate** text).
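What the LM head adds can be sketched in NumPy: it maps each final hidden state to one logit per vocabulary entry. In GPT-2 this projection shares (ties) its weights with the token-embedding matrix. The sizes and random values below are toy placeholders; GPT-2 small uses 768-dimensional hidden states and a 50,257-token vocabulary.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_model, seq_len = 10, 4, 3  # toy sizes for illustration only

# Token-embedding matrix (vocab_size, d_model); GPT-2 reuses it as the LM head.
wte = rng.normal(size=(vocab_size, d_model))
# Stand-in for the final-layer hidden states of one sequence (what GPT2Model returns).
hidden = rng.normal(size=(seq_len, d_model))

# LM head: one score per vocabulary token at every position.
logits = hidden @ wte.T  # (seq_len, vocab_size)
# Only the last position predicts the *next* token; greedy decoding takes the argmax.
next_token = int(np.argmax(logits[-1]))
print(logits.shape, next_token)
```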

Instantiate the `GPT2Tokenizer` and experiment with encoding text into integer tokens. Compare the length of input with the encoded sequence length.

**Tip**: Pass the `return_tensors='pt'` argument to the tokenizer to get PyTorch tensors as output (instead of lists).

In [None]:
# Your code here.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
tokenizer("Well, here we are", return_tensors='pt')['input_ids']

tensor([[5779,   11,  994,  356,  389]])
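The length comparison asked for above can be made explicit for this example. The GPT-2 ids are copied from the output above; the character-level encoding is a stand-in for the one-token-per-character scheme of Exercise 1.

```python
text = "Well, here we are"
# Token ids copied from the GPT-2 tokenizer output above.
gpt2_ids = [5779, 11, 994, 356, 389]
# Character-level scheme as in Exercise 1: one token per character
# (ord() is just a stand-in id here, not the Exercise 1 vocabulary).
char_ids = [ord(c) for c in text]

print(f"{len(text)} characters -> {len(char_ids)} char tokens vs {len(gpt2_ids)} GPT-2 tokens")
print(f"about {len(text) / len(gpt2_ids):.1f} characters per GPT-2 token")
```

Here the four words and the comma each become one sub-word token, so the sequence is about 3.4 times shorter than a character-level encoding, which lets the same context window cover much more text.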

## Exercise 2.2: Generating Text

Given a *prompt* as input, there are many ways to sample text from a GPT-2 model. Instantiate a pre-trained `GPT2LMHeadModel` and use the [`generate()`](https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/text_generation#transformers.GenerationMixin.generate) method to generate text from a prompt.

**Note**: The default inference mode for GPT-2 is *greedy* decoding, which may not result in satisfying generated text. Look at the `do_sample` and `temperature` parameters.
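The idea behind `do_sample` and `temperature` can be sketched in NumPy on a toy 4-token vocabulary: greedy decoding always takes the argmax, while sampling draws from `softmax(logits / T)`. Lower temperatures sharpen the distribution and higher ones flatten it. This sketch ignores other filters `generate()` may apply (such as top-k).

```python
import numpy as np

def next_token_probs(logits, temperature=1.0):
    """Softmax of logits / temperature (temperature must be > 0)."""
    z = np.asarray(logits, dtype=float) / temperature
    z -= z.max()  # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0])  # toy next-token scores

greedy = int(np.argmax(logits))  # greedy decoding: always the top-scoring token
rng = np.random.default_rng(0)
for t in (0.5, 1.0, 2.3):
    p = next_token_probs(logits, t)
    sampled = int(rng.choice(len(p), p=p))  # one sampling step at temperature t
    print(f"T={t}: probs={np.round(p, 3)}, sampled={sampled}")
print("greedy:", greedy)
```

With `temperature=2.3`, as used below, the low-scoring tokens receive much more probability mass, which helps explain why that generated text drifts.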

In [None]:
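from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# Forward pass with labels=input_ids: returns the language-modeling loss on the prompt
# (outputs.loss); this does not generate text, see generate() in the next cell
outputs = model(**inputs, labels=inputs["input_ids"])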




The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


tensor([[50256,   198,   464,   717,   640,   314,  2497,   262,   649,  2196,
           286,   262,   983,    11,   314,   373,   523,  6568,    13,   314]])

In [None]:
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer(["Today is"], return_tensors="pt")

# Sample 80 new tokens with do_sample=True and temperature=2.3 (multinomial sampling, not greedy search)

outputs = model.generate(**inputs, max_new_tokens=80, return_dict_in_generate=True, output_scores=True, do_sample=True, temperature=2.3)

input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]

generated_tokens = outputs.sequences[:, input_length:]
for tok in generated_tokens:
    print(tokenizer.decode(tok))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


 now an easy day. And it looks amazing today because we came at 10 PM, we started that 5min time on Thursday 9 to end we had 30mins but by Monday it looks we still ran into 50k-60mins per mile for 24+ times, which in terms people did have fun to work after 10 hours long with no training needed and very interesting challenges after it starts! So there


# Exercise 3: Reusing Pre-trained LLMs (choose one)

Choose **one** of the following exercises (well, *at least* one). In each of these you are asked to adapt a pre-trained LLM (`GPT2Model` or `DistilBERT` are two good choices) to a new Natural Language Understanding task. A few comments:

+ Since GPT-2 is an *autoregressive* model, there is no latent space aggregation at the last transformer layer (you get out the same number of tokens that you give as input). To use a pre-trained model for a classification or retrieval task, you should aggregate these tokens somehow (or opportunistically select *one* to use).

+ BERT models (including DistilBERT) prepend a special [CLS] token to every input sequence, and its last-layer representation is commonly used as a summary of the whole sequence. You can use it directly as a representation for classification (or retrieval).

+ The first *two* exercises below can probably be done *without* any fine-tuning -- that is, just training a shallow MLP to classify or represent with the appropriate loss function.
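The three options mentioned above (aggregating all tokens, selecting one token, or using [CLS]) can be sketched in NumPy on a stand-in for a model's `last_hidden_state` of shape (batch, seq_len, hidden). The values are random placeholders, and the attention mask assumes right padding. For GPT-2 the last real token is a common choice because, with causal attention, it is the only position that has attended to the whole input.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model = 2, 5, 4
hidden = rng.normal(size=(batch, seq_len, d_model))  # stand-in for last_hidden_state
# 1 = real token, 0 = padding (assumes right padding)
attention_mask = np.array([[1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1]])

# (a) Mean pooling over real tokens only.
mask = attention_mask[:, :, None]  # (batch, seq_len, 1)
mean_pooled = (hidden * mask).sum(axis=1) / mask.sum(axis=1)

# (b) Last real token of each sequence (typical for GPT-2).
last_idx = attention_mask.sum(axis=1) - 1
last_token = hidden[np.arange(batch), last_idx]

# (c) BERT-style: the [CLS] token sits at position 0.
cls = hidden[:, 0]

print(mean_pooled.shape, last_token.shape, cls.shape)  # each (batch, d_model)
```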

# Exercise 3.1: Training a Text Classifier (easy)

Peruse the [text classification datasets on Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:text-classification&sort=downloads). Choose a *moderately* sized dataset and use an LLM to train a classifier to solve the problem.

**Note**: A good first baseline for this problem is certainly to use an LLM *exclusively* as a feature extractor and then train a shallow model.
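A sketch of the "shallow model on frozen features" baseline: a logistic-regression classifier trained with plain gradient descent in NumPy. The features here are synthetic placeholders for pooled LLM embeddings (in practice, one vector per text computed under `torch.no_grad()`). A library model such as scikit-learn's `LogisticRegression` would do the same job.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 16
# Synthetic stand-in for frozen LLM features, one pooled vector per example.
X = rng.normal(size=(n, d))
true_w = rng.normal(size=d)
y = (X @ true_w > 0).astype(float)  # labels that the features can separate

def train_logreg(X, y, lr=0.1, epochs=500):
    """Batch gradient descent on the binary cross-entropy (logistic) loss."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted P(y = 1)
        grad = p - y                             # d loss / d logit
        w -= lr * X.T @ grad / len(y)
        b -= lr * grad.mean()
    return w, b

w, b = train_logreg(X, y)
acc = ((X @ w + b > 0) == (y == 1)).mean()
print(f"train accuracy: {acc:.3f}")
```

Only `w` and `b` are trained; the LLM that produced the features stays frozen, which is what makes this baseline cheap.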

# Exercise 3.2: Training a Question Answering Model (harder)

Peruse the [multiple choice question answering datasets on Hugging Face](https://huggingface.co/datasets?task_categories=task_categories:multiple-choice&sort=downloads). Choose a *moderately* sized one and train a model to answer contextualized multiple-choice questions. You *might* be able to avoid fine-tuning by training a simple model to *rank* the multiple choices (see margin ranking loss in PyTorch).
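A NumPy sketch of the ranking idea. It implements `mean(max(0, -y * (x1 - x2) + margin))`, the formula documented for `torch.nn.MarginRankingLoss` with its default `reduction='mean'`. Each correct choice is paired with every wrong choice, with `y = 1` meaning "x1 should score higher". The scores are hypothetical model outputs for 3 questions with 4 choices each.

```python
import numpy as np

def margin_ranking_loss(x1, x2, y, margin=0.0):
    """mean(max(0, -y * (x1 - x2) + margin))"""
    x1, x2, y = map(np.asarray, (x1, x2, y))
    return float(np.maximum(0.0, -y * (x1 - x2) + margin).mean())

# Hypothetical scores: one row per question, one column per answer choice.
scores = np.array([[2.0, 0.5, 0.1, -1.0],
                   [0.3, 0.9, 0.2,  0.0],
                   [1.0, 1.2, 0.8,  0.1]])
correct = np.array([0, 1, 0])  # index of the right answer for each question

# Pair each question's correct score with each of its wrong-choice scores.
rows, cols = np.nonzero(np.arange(scores.shape[1]) != correct[:, None])
pos = scores[rows, correct[rows]]
neg = scores[rows, cols]
loss = margin_ranking_loss(pos, neg, np.ones_like(pos), margin=0.5)

pred = scores.argmax(axis=1)  # at test time: pick the highest-scoring choice
print(f"loss={loss:.3f}, predicted={pred}, accuracy={(pred == correct).mean():.2f}")
```

Only the third question contributes to the loss: its correct choice (1.0) is outscored by choice 1 (1.2) and beaten by less than the margin by choice 2 (0.8).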

# Exercise 3.3: Training a Retrieval Model (hardest)

The Hugging Face dataset repository contains a large number of ["text retrieval" problems](https://huggingface.co/datasets?task_categories=task_categories:text-retrieval&p=1&sort=downloads). These tasks generally require that the model measure *similarity* between text in some metric space -- naively, just a cosine similarity between [CLS] tokens can get you pretty far. Find an interesting retrieval problem and train a model (starting from a pre-trained LLM of course) to solve it.
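The naive cosine-similarity baseline can be sketched in NumPy: normalize the query and document embeddings, take dot products, and keep the top-k. The 3-dimensional vectors are hypothetical stand-ins for [CLS] embeddings.

```python
import numpy as np

def cosine_top_k(query_emb, doc_embs, k=2):
    """Indices and scores of the k documents most cosine-similar to the query."""
    q = query_emb / np.linalg.norm(query_emb)
    D = doc_embs / np.linalg.norm(doc_embs, axis=1, keepdims=True)
    sims = D @ q                 # cosine similarity for each document
    top = np.argsort(-sims)[:k]  # highest similarity first
    return top, sims[top]

# Hypothetical "[CLS] embeddings" for three documents and one query.
docs = np.array([[1.0, 0.0, 0.0],
                 [0.7, 0.7, 0.0],
                 [0.0, 0.0, 1.0]])
query = np.array([0.9, 0.1, 0.0])
idx, sims = cosine_top_k(query, docs, k=2)
print(idx, np.round(sims, 3))
```

For a real corpus you would embed all documents once, store the normalized matrix, and reuse it for every query.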

**Tip**: Sometimes identifying the *retrieval* problems in these datasets can be half the challenge. [This dataset](https://huggingface.co/datasets/BeIR/scifact) might be a good starting point.

# Exercise 3.1: Training a Text Classifier
I decided to do classic sentiment classification with **IMDB**, a dataset of highly polar movie reviews (25,000 for training and 25,000 for testing) where we need to classify each review as "Negative" (0) or "Positive" (1).


In [None]:
from datasets import load_dataset

In [None]:
train_data = load_dataset("imdb", "plain_text",split='train').shuffle(seed=42)
test_data = load_dataset("imdb", "plain_text",split='test')

Found cached dataset imdb (/home/muduard/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Loading cached shuffled indices for dataset at /home/muduard/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9c48ce5d173413c7.arrow
Found cached dataset imdb (/home/muduard/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


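We use `bert-base-uncased` for classification.
We define a simple preprocessing function that tokenizes a batch of reviews (padding and truncating each to 128 tokens) and attaches the labels, producing the inputs that the Trainer passes to the model's forward function.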

In [None]:
from transformers import AutoTokenizer
import numpy as np

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_data(examples):
  # take a batch of texts
  text = examples["text"]
  # encode them
  encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
  # add labels
  encoding["labels"] = examples["label"]

  return encoding


Encode the dataset, check that the encoding works correctly, then format it for PyTorch use.

In [None]:

encoded_train = train_data.map(preprocess_data, batched=True)
encoded_test = test_data.map(preprocess_data, batched=True)

Loading cached processed dataset at /home/muduard/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c58134a836ca495e.arrow


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

In [None]:
example = encoded_train[0]
print(example.keys())

dict_keys(['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])


In [None]:
tokenizer.decode(example['input_ids'])

'[CLS] i rented i am curious - yellow from my video store because of all the controversy that surrounded it when it was first released in 1967. i also heard that at first it was seized by u. s. customs if it ever tried to enter this country, therefore being a fan of films considered " controversial " i really had to see this for myself. < br / > < br / > the plot is centered around a young swedish drama student named lena who wants to learn everything she can about life. in particular she wants to focus her attentions to making some sort of documentary on what the average swede thought about certain political issues [SEP]'
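The decoded review above stops mid-sentence and ends with `[SEP]` because of `max_length=128`. A pure-Python sketch of what `padding="max_length", truncation=True` does to a single sequence is below. It illustrates the behavior, not the tokenizer's implementation. The special-token ids are those of `bert-base-uncased`, and the word-piece ids are hypothetical.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased special-token ids

def pad_and_truncate(token_ids, max_length=128):
    """Keep at most max_length - 2 content tokens, wrap them in [CLS] ... [SEP],
    then pad with [PAD] up to max_length. Returns (input_ids, attention_mask)."""
    body = token_ids[: max_length - 2]
    ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))
    ids = ids + [PAD_ID] * (max_length - len(ids))
    return ids, attention_mask

long_review = list(range(1000, 1300))   # 300 hypothetical word-piece ids
short_review = list(range(1000, 1010))  # 10 hypothetical word-piece ids

ids, mask = pad_and_truncate(long_review)
print(len(ids), ids[-1] == SEP_ID, sum(mask))  # 128 True 128
ids, mask = pad_and_truncate(short_review)
print(len(ids), sum(mask))                     # 128 12
```

Many IMDB reviews are longer than 128 word pieces, so the model only sees their beginning. Raising `max_length` (up to the 512 positions shown in the model config below) keeps more text at a higher compute cost.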

In [None]:
encoded_train.set_format("torch")
encoded_test.set_format("torch")

Hugging Face provides a way to load a pretrained `bert-base-uncased` with a head for the specific task of sequence classification.
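A NumPy sketch of the head on top of the encoder. BERT's pooler (a dense layer with tanh) is applied to the [CLS] hidden state, and a single linear classifier maps the result to `num_labels` logits. `BertForSequenceClassification` also applies dropout before the classifier, which is omitted here because it is inactive at inference. The classifier weights are not in the pretrained checkpoint, so they start random and are learned during fine-tuning. The sizes and weights below are toy placeholders; `bert-base` uses a hidden size of 768.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, num_labels, seq_len = 8, 2, 5  # toy sizes

last_hidden_state = rng.normal(size=(seq_len, hidden_size))  # stand-in encoder output

# Pooler: dense layer + tanh on the [CLS] position (index 0).
W_pool = rng.normal(size=(hidden_size, hidden_size)) * 0.1
b_pool = np.zeros(hidden_size)
pooled = np.tanh(last_hidden_state[0] @ W_pool + b_pool)

# Classifier: one linear layer, randomly initialized (std 0.02, as in the config).
W_cls = rng.normal(size=(hidden_size, num_labels)) * 0.02
b_cls = np.zeros(num_labels)
logits = pooled @ W_cls + b_cls  # (num_labels,)
print(logits.shape, int(np.argmax(logits)))  # predicted class: 0 = negative, 1 = positive
```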

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

loading configuration file config.json from cache at /home/muduard/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8a40076/config.json
Model config BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.23.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading weights file model.safetensors from cache at /home/muduard/.cache/huggingface/hub/models--bert-base-uncased/snapshots/1dbc166cf8765166998eff31ade2eb64c8

Initialize the arguments to pass to the trainer: we use the F1 score as the metric for selecting the best model and fine-tune BERT for 5 epochs.
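The number of optimization steps and the learning rates in the training log below follow from these arguments. A worked check, assuming the Trainer defaults of linear decay with no warmup and a loss log every 500 steps:

```python
import math

n_train, batch_size, epochs, base_lr = 25_000, 8, 5, 2e-5

steps_per_epoch = math.ceil(n_train / batch_size)  # 3125, matches checkpoint-3125
total_steps = steps_per_epoch * epochs              # 15625, "Total optimization steps"

def linear_lr(step, total=total_steps, lr0=base_lr):
    """Linear decay from lr0 to 0 over `total` steps, no warmup."""
    return lr0 * max(0.0, (total - step) / total)

for step in (500, 3500, 15500):  # epochs 0.16, 1.12 and 4.96 in the log
    print(step, f"{linear_lr(step):.4e}")
```

These reproduce the logged learning rates of 1.936e-05, 1.552e-05 and 1.6e-07 at those steps.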

In [None]:
batch_size = 8
metric_name = "f1"
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    "bert-imdb-classification",  # output_dir: where checkpoints are saved
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


Define the metrics and the callback function for the trainer

In [None]:
from transformers import EvalPrediction
import numpy as np

def b_tp(preds, labels):
  '''Returns True Positives (TP): predicted 1 and actually 1'''
  return int(np.sum((preds == 1) & (labels == 1)))

def b_fp(preds, labels):
  '''Returns False Positives (FP): predicted 1 but actually 0'''
  return int(np.sum((preds == 1) & (labels == 0)))

def b_tn(preds, labels):
  '''Returns True Negatives (TN): predicted 0 and actually 0'''
  return int(np.sum((preds == 0) & (labels == 0)))

def b_fn(preds, labels):
  '''Returns False Negatives (FN): predicted 0 but actually 1'''
  return int(np.sum((preds == 0) & (labels == 1)))

def b_metrics(preds, labels):
  '''
  Takes logits of shape (N, 2) and integer labels of shape (N,).
  Returns the following metrics:
    - f1        = 2TP / (2TP + FP + FN)
    - accuracy  = (TP + TN) / N
    - precision = TP / (TP + FP)
    - recall    = TP / (TP + FN)
  '''
  preds = np.argmax(preds, axis=1).flatten()  # logits -> predicted class
  labels = np.asarray(labels).flatten()
  tp = b_tp(preds, labels)
  tn = b_tn(preds, labels)
  fp = b_fp(preds, labels)
  fn = b_fn(preds, labels)
  nan = float('nan')
  b_accuracy = (tp + tn) / len(labels)
  b_precision = tp / (tp + fp) if (tp + fp) > 0 else nan
  b_recall = tp / (tp + fn) if (tp + fn) > 0 else nan
  f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else nan
  metrics = {'f1': f1,
             'accuracy': b_accuracy,
             'precision': b_precision,
             'recall': b_recall}
  return metrics

# Function called by the trainer at each evaluation
def compute_metrics(p: EvalPrediction):
    # some models return a tuple of outputs; the logits are its first element
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    return b_metrics(preds=preds, labels=p.label_ids)

Use the all-in-one Hugging Face `Trainer` for fine-tuning

In [None]:

trainer = Trainer(
    model,
    args,
    train_dataset=encoded_train,
    eval_dataset=encoded_test,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
#trainer.train(resume_from_checkpoint="imdb") # resume from a saved checkpoint directory instead
trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 25000
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 15625
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


  0%|          | 0/15625 [00:00<?, ?it/s]

{'loss': 0.4275, 'learning_rate': 1.936e-05, 'epoch': 0.16}
{'loss': 0.3626, 'learning_rate': 1.8720000000000004e-05, 'epoch': 0.32}
{'loss': 0.3528, 'learning_rate': 1.8080000000000003e-05, 'epoch': 0.48}
{'loss': 0.3536, 'learning_rate': 1.7440000000000002e-05, 'epoch': 0.64}
{'loss': 0.358, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.8}
{'loss': 0.348, 'learning_rate': 1.616e-05, 'epoch': 0.96}


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 25000
  Batch size = 8


  0%|          | 0/3125 [00:00<?, ?it/s]

Saving model checkpoint to bert-imdb-classification/checkpoint-3125
Configuration saved in bert-imdb-classification/checkpoint-3125/config.json


{'eval_loss': 0.28213560581207275, 'eval_f1': 0.8837697179973373, 'eval_accuracy': 0.88476, 'eval_precision': 0.8914299666313991, 'eval_recall': 0.87624, 'eval_runtime': 138.1343, 'eval_samples_per_second': 180.983, 'eval_steps_per_second': 22.623, 'epoch': 1.0}


Model weights saved in bert-imdb-classification/checkpoint-3125/pytorch_model.bin
tokenizer config file saved in bert-imdb-classification/checkpoint-3125/tokenizer_config.json
Special tokens file saved in bert-imdb-classification/checkpoint-3125/special_tokens_map.json


{'loss': 0.2584, 'learning_rate': 1.552e-05, 'epoch': 1.12}
{'loss': 0.2522, 'learning_rate': 1.4880000000000002e-05, 'epoch': 1.28}
{'loss': 0.2445, 'learning_rate': 1.4240000000000001e-05, 'epoch': 1.44}
{'loss': 0.2464, 'learning_rate': 1.3600000000000002e-05, 'epoch': 1.6}
{'loss': 0.2327, 'learning_rate': 1.2960000000000001e-05, 'epoch': 1.76}
{'loss': 0.2555, 'learning_rate': 1.232e-05, 'epoch': 1.92}


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 25000
  Batch size = 8


  0%|          | 0/3125 [00:00<?, ?it/s]

Saving model checkpoint to bert-imdb-classification/checkpoint-6250
Configuration saved in bert-imdb-classification/checkpoint-6250/config.json


{'eval_loss': 0.4703359305858612, 'eval_f1': 0.889341584348695, 'eval_accuracy': 0.88484, 'eval_precision': 0.8558851816231412, 'eval_recall': 0.92552, 'eval_runtime': 145.0214, 'eval_samples_per_second': 172.388, 'eval_steps_per_second': 21.549, 'epoch': 2.0}


Model weights saved in bert-imdb-classification/checkpoint-6250/pytorch_model.bin
tokenizer config file saved in bert-imdb-classification/checkpoint-6250/tokenizer_config.json
Special tokens file saved in bert-imdb-classification/checkpoint-6250/special_tokens_map.json


{'loss': 0.1775, 'learning_rate': 1.168e-05, 'epoch': 2.08}
{'loss': 0.1128, 'learning_rate': 1.1040000000000001e-05, 'epoch': 2.24}
{'loss': 0.1427, 'learning_rate': 1.04e-05, 'epoch': 2.4}
{'loss': 0.1298, 'learning_rate': 9.760000000000001e-06, 'epoch': 2.56}
{'loss': 0.149, 'learning_rate': 9.12e-06, 'epoch': 2.72}
{'loss': 0.1278, 'learning_rate': 8.48e-06, 'epoch': 2.88}


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 25000
  Batch size = 8


  0%|          | 0/3125 [00:00<?, ?it/s]

Saving model checkpoint to bert-imdb-classification/checkpoint-9375
Configuration saved in bert-imdb-classification/checkpoint-9375/config.json


{'eval_loss': 0.5590696334838867, 'eval_f1': 0.8859392758564703, 'eval_accuracy': 0.8876, 'eval_precision': 0.8992254449571523, 'eval_recall': 0.87304, 'eval_runtime': 138.8808, 'eval_samples_per_second': 180.01, 'eval_steps_per_second': 22.501, 'epoch': 3.0}


Model weights saved in bert-imdb-classification/checkpoint-9375/pytorch_model.bin
tokenizer config file saved in bert-imdb-classification/checkpoint-9375/tokenizer_config.json
Special tokens file saved in bert-imdb-classification/checkpoint-9375/special_tokens_map.json


{'loss': 0.1079, 'learning_rate': 7.840000000000001e-06, 'epoch': 3.04}
{'loss': 0.0606, 'learning_rate': 7.2000000000000005e-06, 'epoch': 3.2}
{'loss': 0.081, 'learning_rate': 6.560000000000001e-06, 'epoch': 3.36}
{'loss': 0.0662, 'learning_rate': 5.92e-06, 'epoch': 3.52}
{'loss': 0.0594, 'learning_rate': 5.28e-06, 'epoch': 3.68}
{'loss': 0.0678, 'learning_rate': 4.6400000000000005e-06, 'epoch': 3.84}


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 25000
  Batch size = 8


{'loss': 0.0513, 'learning_rate': 4.000000000000001e-06, 'epoch': 4.0}



Saving model checkpoint to bert-imdb-classification/checkpoint-12500
Configuration saved in bert-imdb-classification/checkpoint-12500/config.json


{'eval_loss': 0.6741758584976196, 'eval_f1': 0.8881728129848431, 'eval_accuracy': 0.88756, 'eval_precision': 0.8833583920234233, 'eval_recall': 0.89304, 'eval_runtime': 140.1278, 'eval_samples_per_second': 178.409, 'eval_steps_per_second': 22.301, 'epoch': 4.0}


Model weights saved in bert-imdb-classification/checkpoint-12500/pytorch_model.bin
tokenizer config file saved in bert-imdb-classification/checkpoint-12500/tokenizer_config.json
Special tokens file saved in bert-imdb-classification/checkpoint-12500/special_tokens_map.json


{'loss': 0.0273, 'learning_rate': 3.3600000000000004e-06, 'epoch': 4.16}
{'loss': 0.0401, 'learning_rate': 2.7200000000000002e-06, 'epoch': 4.32}
{'loss': 0.022, 'learning_rate': 2.08e-06, 'epoch': 4.48}
{'loss': 0.0256, 'learning_rate': 1.44e-06, 'epoch': 4.64}
{'loss': 0.0258, 'learning_rate': 8.000000000000001e-07, 'epoch': 4.8}
{'loss': 0.0299, 'learning_rate': 1.6e-07, 'epoch': 4.96}


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 25000
  Batch size = 8



Saving model checkpoint to bert-imdb-classification/checkpoint-15625
Configuration saved in bert-imdb-classification/checkpoint-15625/config.json


{'eval_loss': 0.7640025019645691, 'eval_f1': 0.8867014865407794, 'eval_accuracy': 0.8872, 'eval_precision': 0.890637610976594, 'eval_recall': 0.8828, 'eval_runtime': 140.3425, 'eval_samples_per_second': 178.136, 'eval_steps_per_second': 22.267, 'epoch': 5.0}


Model weights saved in bert-imdb-classification/checkpoint-15625/pytorch_model.bin
tokenizer config file saved in bert-imdb-classification/checkpoint-15625/tokenizer_config.json
Special tokens file saved in bert-imdb-classification/checkpoint-15625/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from bert-imdb-classification/checkpoint-6250 (score: 0.889341584348695).


{'train_runtime': 3251.3895, 'train_samples_per_second': 38.445, 'train_steps_per_second': 4.806, 'train_loss': 0.1664598349761963, 'epoch': 5.0}


TrainOutput(global_step=15625, training_loss=0.1664598349761963, metrics={'train_runtime': 3251.3895, 'train_samples_per_second': 38.445, 'train_steps_per_second': 4.806, 'train_loss': 0.1664598349761963, 'epoch': 5.0})

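The model starts to overfit after the second epoch (checkpoint-6250). From epoch 3 onwards the validation loss rises (0.559, 0.674, 0.764) while the training loss keeps decreasing. With just 2 epochs of fine-tuning we reach an F1 score of 0.89 and an accuracy of 0.885, which shows how powerful BERT is for simple classification tasks.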
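The checkpoint names can be mapped back to epochs with a little arithmetic. `TrainOutput` reports 15625 global steps over 5 epochs, which is 3125 steps per epoch (25000 examples / batch size 8, assuming the training batch size matches the evaluation batch size of 8 shown in the log). The sketch below uses only numbers copied from the log. Treating F1 as the selection metric is an assumption based on the reported best score of 0.8893.

```python
# Values copied from the training log above; the epoch-2 entry only has the
# score reported by "Loading best model from ... checkpoint-6250".
global_step, num_epochs = 15625, 5
steps_per_epoch = global_step // num_epochs  # optimizer steps per epoch

eval_f1 = {
    6250: 0.889341584348695,
    9375: 0.8859392758564703,
    12500: 0.8881728129848431,
    15625: 0.8867014865407794,
}
eval_loss = {
    9375: 0.5590696334838867,
    12500: 0.6741758584976196,
    15625: 0.7640025019645691,
}


def step_to_epoch(step: int) -> float:
    """Convert a checkpoint step number into the epoch at which it was saved."""
    return step / steps_per_epoch


# Best checkpoint by F1 (assumption: F1 is the metric used to pick the best model)
best_step = max(eval_f1, key=eval_f1.get)
print(f"best checkpoint: {best_step} (epoch {step_to_epoch(best_step):.0f})")

# Overfitting signal: the validation loss increases at every evaluation after epoch 2
losses = [eval_loss[s] for s in sorted(eval_loss)]
eval_loss_rising = all(a < b for a, b in zip(losses, losses[1:]))
print("eval loss rising after epoch 2:", eval_loss_rising)
```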

If we wanted to do inference on the fine-tuned model we could do it like this:

In [None]:
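text = "This movie was very beautiful and sad"

encoding = tokenizer(text, return_tensors="pt")
# Move the input tensors to the same device as the model
encoding = {k: v.to(trainer.model.device) for k, v in encoding.items()}

# Inference only: disable dropout and gradient tracking
trainer.model.eval()
with torch.no_grad():
    outputs = trainer.model(**encoding)
logits = outputs.logits
preds = torch.argmax(logits, dim=1).flatten()
print(preds)  # IMDB labels: 0 = negative, 1 = positive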

tensor([1], device='cuda:0')


# Exercise 3.2: Training a Question Answering model


I train a question answering model on the [medmcqa](https://github.com/medmcqa/medmcqa) dataset, a very difficult multiple-choice dataset. The current best accuracy is 0.72, obtained by Med-PaLM 2, while the expected accuracy of a bert-base model is 0.33, only slightly above random choice (0.25 with four options).
I still try fine-tuning distilbert-base-cased on this task to check whether I can reproduce these results.

In [None]:
from datasets import load_dataset
import numpy as np
dataset = load_dataset("medmcqa")

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

We define a preprocessing function similar to the one in the previous exercise, but this time we encode a (question, answer) pair for each of the four choices and use the index of the correct option (`cop`) as the classification target.
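A toy, tokenizer-free illustration of the reshaping done in `preprocess_function`: each question is repeated once per option, the nested lists are flattened with `sum(..., [])`, and the encoded rows are regrouped into chunks of 4. The example dict and the `[SEP]` string standing in for tokenization are made up for illustration.

```python
endings = ["opa", "opb", "opc", "opd"]
# Hypothetical batch of two questions with four options each
examples = {
    "id": ["q1", "q2"],
    "question": ["Q1?", "Q2?"],
    "opa": ["a1", "a2"], "opb": ["b1", "b2"],
    "opc": ["c1", "c2"], "opd": ["d1", "d2"],
}

# Repeat each question 4 times and list its options, then flatten both
first = sum([[q] * 4 for q in examples["question"]], [])
second = sum([[examples[e][i] for e in endings] for i in range(len(examples["id"]))], [])

# Stand-in for the tokenizer: one "encoded" row per (question, option) pair
flat = [f"{q} [SEP] {o}" for q, o in zip(first, second)]

# Regroup into chunks of 4, one chunk per question
grouped = [flat[i:i + 4] for i in range(0, len(flat), 4)]
print(len(flat), len(grouped), grouped[1][2])
```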

In [None]:
endings = ["opa","opb","opc","opd"]
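def preprocess_function(examples):
    # Repeat each question once per answer option: 4 (question, option) pairs per example
    first_sentences = [[context] * 4 for context in examples["question"]]
    second_sentences = [
        [examples[end][i] for end in endings] for i, _ in enumerate(examples['id'])
    ]
    # Flatten the nested lists so the tokenizer sees one pair per row
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])
    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    # Regroup the tokenized pairs into chunks of 4 (one chunk per question)
    encoding = {k: [v[i: i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
    # 'cop' is the index (0-3) of the correct option
    encoding['label'] = examples['cop']
    return encoding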

Encode dataset

In [None]:
tokenized_dataset = dataset.map(preprocess_function, batched=True)

We define a data collator that pads all the choices in a batch to the same length and reshapes them to (batch_size, num_choices, sequence_length).

In [None]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch

#Adapted from https://huggingface.co/docs/transformers/tasks/multiple_choice
@dataclass
class DataCollatorForMultipleChoice:

    tokenizer: PreTrainedTokenizerBase

    padding: Union[bool, str, PaddingStrategy] = True

    max_length: Optional[int] = None

    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_name = "label"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
        ]
        flattened_features = sum(flattened_features, [])

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch
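The core of `DataCollatorForMultipleChoice` is three steps: flatten the choices, pad them to the longest sequence, then view the result as `(batch_size, num_choices, seq_len)`. Below is a NumPy sketch of the same steps. It assumes a hypothetical pad id of 0 and toy token ids, and the helper `pad_and_group` is illustrative (the real collator delegates padding to `tokenizer.pad`).

```python
import numpy as np

PAD_ID = 0  # hypothetical padding token id


def pad_and_group(features, pad_id=PAD_ID):
    """Pad every choice to the longest one in the batch, then regroup per question.

    features: list (one entry per question) of num_choices token-id lists.
    Returns (input_ids, attention_mask), each of shape (batch_size, num_choices, max_len).
    """
    batch_size = len(features)
    num_choices = len(features[0])
    # Flatten to batch_size * num_choices sequences, like flattened_features
    flat = [choice for feature in features for choice in feature]
    max_len = max(len(ids) for ids in flat)
    input_ids = np.array([ids + [pad_id] * (max_len - len(ids)) for ids in flat])
    lengths = np.array([len(ids) for ids in flat])
    # 1 for real tokens, 0 for padding
    attention_mask = (np.arange(max_len)[None, :] < lengths[:, None]).astype(np.int64)
    # Same role as v.view(batch_size, num_choices, -1) in the collator
    return (input_ids.reshape(batch_size, num_choices, -1),
            attention_mask.reshape(batch_size, num_choices, -1))


features = [
    [[5, 6], [7], [8, 9, 10], [11]],
    [[1], [2, 3], [4], [5, 6, 7, 8]],
]
input_ids, attention_mask = pad_and_group(features)
print(input_ids.shape)
```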

Define metrics for evaluation

In [None]:
import evaluate

accuracy = evaluate.load("accuracy")
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)


from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

model = AutoModelForMultipleChoice.from_pretrained("distilbert-base-cased")
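`compute_metrics` receives logits of shape `(num_examples, 4)` and takes the argmax over the choices. A minimal NumPy sketch of the same computation, with a hypothetical helper `multiple_choice_accuracy`, also checks the random baseline mentioned earlier: with four options, random scores give an expected accuracy of 1/4.

```python
import numpy as np


def multiple_choice_accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Accuracy for multiple choice: logits has shape (num_examples, num_choices)."""
    predictions = np.argmax(logits, axis=1)  # index of the highest-scoring option per row
    return float(np.mean(predictions == labels))


# Tiny hand-made check: row 0 picks option 1 (correct), row 1 picks option 0 (label is 3)
toy_logits = np.array([[0.1, 0.9, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]])
toy_labels = np.array([1, 3])
print(multiple_choice_accuracy(toy_logits, toy_labels))

# Random scores over 4 options: expected accuracy is 1/4
rng = np.random.default_rng(0)
labels = rng.integers(0, 4, size=100_000)
random_acc = multiple_choice_accuracy(rng.normal(size=(labels.size, 4)), labels)
print(round(random_acc, 3))
```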

Set up the training arguments

In [None]:
training_args = TrainingArguments(
    output_dir="medqa_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=6,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

In [None]:
#trainer.train('medqa_model') #for pretrained model
trainer.train()
trainer.save_state()


After fine-tuning for 5 epochs we obtain a validation accuracy of 0.321, which is in line with the bert-uncased result reported in the MedMCQA paper.

# Exercise 3.3: Training a Retrieval Model
I approached text retrieval from an interesting standpoint: using a sequence-to-sequence model.<br>
The model used is the recently proposed Flan-T5 [[Scaling Instruction-Finetuned Language Models](https://arxiv.org/abs/2210.11416)], an improvement over the commonly used T5 model [[Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683)] from Google.<br>
While searching for ways to implement text retrieval, I found the approach described in the paper [[Document Ranking with a Pretrained Sequence-to-Sequence Model](https://aclanthology.org/2020.findings-emnlp.63/)] particularly ingenious. <br>
The entire premise is training the model to complete a sentence of the type:<br>
`Query [Q] Document [D] Relevant:` <br>
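The model is trained to generate the token "true" or "false", and the relevance score is obtained by applying a softmax to the logits $z_{\text{true}}$ and $z_{\text{false}}$ of these two tokens at the first decoding step:
$$ P(\text{relevant} = 1 \mid q, d) = \frac{e^{z_{\text{true}}}}{e^{z_{\text{true}}} + e^{z_{\text{false}}}} $$
The training dataset is a subset of MS MARCO from the BEIR information retrieval benchmark.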
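The scoring rule can be sketched in plain Python. `relevance_log_prob` is a hypothetical helper that mirrors what `MonoT5.rescore` does further down with `log_softmax(batch_scores/1.5, dim=1)[:, 1]`, and the logit values are made up. Because the score depends only on $(z_{\text{true}} - z_{\text{false}})/T$, a temperature $T$ changes the values but not the ranking.

```python
import math


def relevance_log_prob(logit_false: float, logit_true: float, temperature: float = 1.0) -> float:
    """log P(relevant = 1 | q, d): log-softmax over the {false, true} logits, true entry."""
    zf, zt = logit_false / temperature, logit_true / temperature
    m = max(zf, zt)  # subtract the max before exponentiating for numerical stability
    log_normalizer = m + math.log(math.exp(zf - m) + math.exp(zt - m))
    return zt - log_normalizer


# Hypothetical (false, true) logit pairs for three documents
pairs = [(-2.0, 3.0), (0.5, 0.7), (1.0, -1.0)]
for t in (1.0, 1.5):
    scores = [relevance_log_prob(f, tr, t) for f, tr in pairs]
    print(t, [round(s, 4) for s in scores])
```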

In [None]:
# Import libraries
from transformers import T5Tokenizer, T5ForConditionalGeneration, PreTrainedModel, PreTrainedTokenizer
from datasets import load_dataset
from copy import deepcopy
import heapq
import os
import pickle
import numpy as np
from collections.abc import Iterable
from tqdm.auto import tqdm
import torch
from torch.utils.data import Dataset
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from typing import List, Optional, Union, Mapping, Any
from dataclasses import dataclass

I use some helper classes from PyGaggle to simplify the reranking procedure

In [None]:
# Define helper classes adapted from https://github.com/castorini/pygaggle/
class MonoT5Dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        text = f'Query: {sample[0]} Document: {sample[1]} Relevant:'
        return {
            'labels': sample[2],
            'input_ids': text,
        }


class Text:
    def __init__(self,
                 text: str,
                 metadata: Mapping[str, Any] = None,
                 score: Optional[float] = 0,
                 title: Optional[str] = None):
        self.text = text
        if metadata is None:
            metadata = dict()
        self.metadata = metadata
        self.score = score
        self.title = title

class Query:
    def __init__(self, text: str, id: Optional[str] = None):
        self.text = text
        self.id = id

TokenizerReturnType = Mapping[str, Union[torch.Tensor, List[int],
                                         List[List[int]],
                                         List[List[str]]]]

@dataclass
class QueryDocumentBatch:
    query: Query
    documents: List[Text]
    output: Optional[TokenizerReturnType] = None

    def __len__(self):
        return len(self.documents)


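The qrels only contain positive (query, relevant document) pairs, so we need to generate negative samples: with only positive examples the model could simply learn to always answer "true". For every positive pair we add one negative pair whose document is drawn at random from the 10,000 most frequently cited documents.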
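The same sampling scheme can be sketched in pure Python with `collections.Counter` on hypothetical `(query_id, doc_id)` pairs. Unlike the notebook code, this sketch redraws the negative when it coincides with the positive document. `build_pairs` and its parameters are illustrative names, not part of any library.

```python
import random
from collections import Counter


def build_pairs(qrels, top_k=10_000, seed=123):
    """qrels: list of (query_id, doc_id) positive pairs -> list of labelled triples."""
    rng = random.Random(seed)
    # Negatives are drawn from the top_k most frequently cited documents
    doc_counts = Counter(doc_id for _, doc_id in qrels)
    candidates = [doc_id for doc_id, _ in doc_counts.most_common(top_k)]
    samples = []
    for query_id, doc_id in qrels:
        samples.append((query_id, doc_id, "true"))
        negative = rng.choice(candidates)
        # Redraw if we picked the positive document (only possible with >1 candidate)
        while negative == doc_id and len(candidates) > 1:
            negative = rng.choice(candidates)
        samples.append((query_id, negative, "false"))
    return samples


samples = build_pairs([(1, "a"), (2, "a"), (3, "b"), (4, "c")], top_k=2)
for s in samples:
    print(s)
```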

In [None]:
# Training
def train(args):
    device = torch.device('cuda')
    torch.manual_seed(123)
    # Use flan-t5-base instead of t5
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", device_map="auto")

    # Load datasets
    # Contains queries
    queries = load_dataset("BeIR/msmarco", 'queries')['queries']
    # Contains documents
    corpus = load_dataset("BeIR/msmarco", 'corpus')['corpus']
    # Contains pre-computed relevance between queries and documents
    qrels = load_dataset("BeIR/msmarco-qrels", split="train")

    # Count the number of occurrences of each value of a feature in the dataset
    def count_feature(dataset, feature_name):
        count_dict = {}
        for data in dataset:
            key = data[feature_name]
            count_dict[key] = count_dict.get(key, 0) + 1
        return count_dict

    # Get the k keys with the highest counts
    def get_topk(count_dict, k):
        topk = heapq.nlargest(k, count_dict, key=count_dict.get)
        return topk


    # Generates samples
    def generate_samples(out_path):

        train_samples = []
        # If pickle already exists, don't recompute
        if os.path.exists(out_path):
            with open(out_path, "rb") as fp:
                train_samples = pickle.load(fp)
        else:
            # Count how many times each document appears in the qrels
            count_docs_rel = count_feature(qrels, 'corpus-id')
            # Get the 10000 most cited documents
            doc_ids = get_topk(count_docs_rel, 10000)

            for rel in tqdm(qrels, desc="Generating samples:"):
                query = rel['query-id']
                positive = rel['corpus-id']
                train_samples.append((query, positive, 'true'))
                # Random negative drawn from the most cited documents
                # (it can occasionally coincide with the positive document)
                negative = np.random.choice(doc_ids)
                train_samples.append((query, negative, 'false'))
            with open(out_path, "wb") as fp:
                pickle.dump(train_samples, fp)
        return train_samples

    def get_texts_from_ids(dataset, column_name):
        q_pd = dataset.to_pandas()
        q_pd['_id'] = q_pd['_id'].astype(int)
        return q_pd.loc[q_pd['_id'].isin(qrels[column_name])]['text'].to_list()

    id_samples = generate_samples('./samples.pkl')
    q_pd = queries.to_pandas()
    q_pd['_id'] = q_pd['_id'].astype(int)
    c_pd = corpus.to_pandas()
    c_pd['_id'] = c_pd['_id'].astype(int)
    def get_q(id):
        return q_pd.loc[q_pd['_id'] == id]['text'].item()
    def get_c(id):
        return c_pd.loc[c_pd['_id'] == id]['text'].item()

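    train_samples = []
    # Reuse the text samples if they were already generated
    if os.path.exists('train_samples.pkl'):
        with open('train_samples.pkl', "rb") as fp:
            train_samples = pickle.load(fp)
    else:
        # Map (query id, document id, label) triples to (query text, document text, label)
        train_samples = [(get_q(q), get_c(d), label)
                         for q, d, label in tqdm(id_samples, desc="Fetching texts")]
        with open('train_samples.pkl', "wb") as fp:
            pickle.dump(train_samples, fp)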

    # Custom function for batch collate
    def smart_batching_collate_text_only(batch):
        texts = [example['input_ids'] for example in batch]
        tokenized = tokenizer(texts, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)
        tokenized['labels'] = tokenizer([example['labels'] for example in batch], return_tensors='pt')['input_ids']

        for name in tokenized:
            tokenized[name] = tokenized[name].to(device)

        return tokenized

    dataset_train = MonoT5Dataset(train_samples)

    if args.save_every_n_steps:
        steps = args.save_every_n_steps
        strategy = 'steps'
    else:
        steps = 1
        strategy = 'epoch'

    train_args = Seq2SeqTrainingArguments(
        output_dir=args.output_model_path,
        do_train=True,
        save_strategy=strategy,
        save_steps = steps,
        logging_steps=args.logging_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=5e-5,
        num_train_epochs=1,
        warmup_steps=1000,
        seed=1,
        disable_tqdm=False,
        load_best_model_at_end=False,
        predict_with_generate=True,
        dataloader_pin_memory=False,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=train_args,
        train_dataset=dataset_train,
        tokenizer=tokenizer,
        data_collator=smart_batching_collate_text_only,

    )

    trainer.train()

    trainer.save_model(args.output_model_path)
    trainer.save_state()


In [None]:
# Helper classes for inference
class TokenizerEncodeMixin:
    tokenizer: PreTrainedTokenizer = None
    tokenizer_kwargs = None

    def encode(self, strings: List[str]):
        assert self.tokenizer and self.tokenizer_kwargs is not None, \
                'mixin used improperly'
        ret = self.tokenizer.batch_encode_plus(strings,
                                               **self.tokenizer_kwargs)
        ret['tokens'] = list(map(self.tokenizer.tokenize, strings))
        return ret

class QueryDocumentBatchTokenizer(TokenizerEncodeMixin):
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 batch_size: int,
                 pattern: str = '{query} {document}',
                 **tokenizer_kwargs):
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.tokenizer_kwargs = tokenizer_kwargs
        self.pattern = pattern

    def traverse_query_document(
            self,
            batch_input: QueryDocumentBatch) -> Iterable[QueryDocumentBatch]:
        query = batch_input.query
        for batch_idx in range(0, len(batch_input), self.batch_size):
            docs = batch_input.documents[batch_idx:batch_idx + self.batch_size]
            outputs = self.encode([self.pattern.format(
                                        query=query.text,
                                        document=doc.text) for doc in docs])
            yield QueryDocumentBatch(query, docs, outputs)

class T5BatchTokenizer(QueryDocumentBatchTokenizer):
    def __init__(self, *args, **kwargs):
        kwargs['pattern'] = 'Query: {query} Document: {document} Relevant:'
        if 'return_attention_mask' not in kwargs:
            kwargs['return_attention_mask'] = True
        if 'padding' not in kwargs:
            kwargs['padding'] = 'longest'
        if 'truncation' not in kwargs:
            kwargs['truncation'] = True
        if 'return_tensors' not in kwargs:
            kwargs['return_tensors'] = 'pt'
        if 'max_length' not in kwargs:
            kwargs['max_length'] = 512
        super().__init__(*args, **kwargs)


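# Relevance tokens per checkpoint, used when token_false/token_true are not given
prediction_tokens = {
    'castorini/monot5-base-msmarco-10k': ['▁false', '▁true'],
}


class MonoT5():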
    def __init__(self,
                 pretrained_model_name_or_path: str  = 'castorini/monot5-base-msmarco-10k',
                 model: T5ForConditionalGeneration = None,
                 tokenizer: QueryDocumentBatchTokenizer = None,
                 token_false = None,
                 token_true  = None):
        self.model = model or self.get_model(pretrained_model_name_or_path)
        self.tokenizer = tokenizer or self.get_tokenizer(pretrained_model_name_or_path)
        self.token_false_id, self.token_true_id = self.get_prediction_tokens(
                pretrained_model_name_or_path, self.tokenizer, token_false, token_true)
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.device = next(self.model.parameters(), None).device

    @staticmethod
    def get_model(pretrained_model_name_or_path: str,
                  *args, device: str = None, **kwargs) -> T5ForConditionalGeneration:
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device(device)
        return AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path,
                                                          *args, **kwargs).to(device).eval()

    @staticmethod
    def get_tokenizer(pretrained_model_name_or_path: str,
                      *args, batch_size: int = 8, **kwargs) -> T5BatchTokenizer:
        return T5BatchTokenizer(
            AutoTokenizer.from_pretrained(pretrained_model_name_or_path, use_fast=False, *args, **kwargs),
            batch_size=batch_size
        )
    @staticmethod
    def get_prediction_tokens(pretrained_model_name_or_path: str,
            tokenizer, token_false, token_true):
        if not (token_false and token_true):
            if pretrained_model_name_or_path in prediction_tokens:
                token_false, token_true = prediction_tokens[pretrained_model_name_or_path]
                token_false_id = tokenizer.tokenizer.get_vocab()[token_false]
                token_true_id  = tokenizer.tokenizer.get_vocab()[token_true]
                return token_false_id, token_true_id
            else:
                raise Exception("We don't know the indexes for the non-relevant/relevant tokens for "
                                f"the checkpoint {pretrained_model_name_or_path} and you did not provide any.")
        else:
            token_false_id = tokenizer.tokenizer.get_vocab()[token_false]
            token_true_id  = tokenizer.tokenizer.get_vocab()[token_true]
            return token_false_id, token_true_id


    def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
        texts = deepcopy(texts)
        batch_input = QueryDocumentBatch(query=query, documents=texts)
        for batch in self.tokenizer.traverse_query_document(batch_input):

            input_ids = batch.output['input_ids'].to(self.device)
            attn_mask = batch.output['attention_mask'].to(self.device)
            batch_scores = greedy_decode(self.model,
                                            input_ids,
                                            length=1,
                                            attention_mask=attn_mask)

            batch_scores = batch_scores[:, [self.token_false_id, self.token_true_id]]
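            # Temperature of 1.5: dividing the logits softens the distribution (higher entropy),
            # so confident scores saturate less near 0 and are easier to tell apart.
            # The ranking itself is unchanged. Flan-T5 has more uniform logits than T5.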
            batch_scores = torch.nn.functional.log_softmax(batch_scores/1.5, dim=1)
            batch_log_probs = batch_scores[:, 1].tolist()
            for doc, score in zip(batch.documents, batch_log_probs):
                doc.score = score

        return texts

    def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
        return sorted(self.rescore(query, texts), key=lambda x: x.score, reverse=True)

@torch.no_grad()
def greedy_decode(model: PreTrainedModel,
                  input_ids: torch.Tensor,
                  length: int,
                  attention_mask: torch.Tensor = None):
    decode_ids = torch.full((input_ids.size(0), 1),
                            model.config.decoder_start_token_id,
                            dtype=torch.long).to(input_ids.device)
    encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
    next_token_logits = None
    for _ in range(length):
        model_inputs = model.prepare_inputs_for_generation(
            decode_ids,
            encoder_outputs=encoder_outputs,
            past=None,
            attention_mask=attention_mask,
            use_cache=True)
        outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
        next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
    return next_token_logits


In [None]:
# Reranking
# Define a query
query = Query('who proposed the geocentric theory')
# Load subset of corpus
corpus = load_dataset("BeIR/msmarco", 'corpus', split='corpus[0:1000]')
# Define a document where the answer is
passages = [['7744105',
             'For Earth-centered it was  Geocentric Theory proposed by greeks under the guidance of Ptolemy and Sun-centered was Heliocentric theory proposed by Nicolas Copernicus in 16th century A.D. In short, Your Answers are: 1st blank - Geo-Centric Theory. 2nd blank - Heliocentric Theory.']]

# Dataset of Text from corpus
texts = [Text(p['text'], {'docid': p['_id']}, 0) for p in tqdm(corpus)]
# Add corpus with high relevance
texts.extend([Text(p[1], {'docid': p[0]}, 0) for p in passages])
# Load trained model
model = T5ForConditionalGeneration.from_pretrained("./rerank", device_map="auto")

tokenizer = T5BatchTokenizer(
        AutoTokenizer.from_pretrained("google/flan-t5-base", use_fast=False),
        batch_size=2)
# Objective tokens for the msmarco-10k dataset
token_false = '▁false'
token_true = '▁true'
# Define reranker
reranker = MonoT5('castorini/monot5-base-msmarco-10k',model=model, tokenizer=tokenizer, token_false = token_false, token_true = token_true)
# Inference
reranked = reranker.rerank(query, texts)

# Print out reranked results:
for i in range(0, 10):
    print(f'{i + 1:2} {reranked[i].metadata["docid"]:15} {reranked[i].score:.5f} {reranked[i].text}')

Found cached dataset msmarco (/home/muduard/.cache/huggingface/datasets/BeIR___msmarco/corpus/0.0.0/093f1fe2ffa7a9c72fa48239c8f279b51d6b171abd77737c7fd1406125307599)



loading configuration file ./checkpoint/config.json
Model config T5Config {
  "_name_or_path": "google/flan-t5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
 

 1 7744105         -0.00019 For Earth-centered it was  Geocentric Theory proposed by greeks under the guidance of Ptolemy and Sun-centered was Heliocentric theory proposed by Nicolas Copernicus in 16th century A.D. In short, Your Answers are: 1st blank - Geo-Centric Theory. 2nd blank - Heliocentric Theory.
 2 247             -0.04919 Keynesian economics gets its name, theories, and principles from British economist John Maynard Keynes (1883–1946), who is regarded as the founder of modern macroeconomics. His most famous work, The General Theory of Employment, Interest and Money, was published in 1936.
 3 248             -0.31595 DEFINITION of 'Keynesian Economics'. An economic theory of total spending in the economy and its effects on output and inflation. Keynesian economics was developed by the British economist John Maynard Keynes during the 1930s in an attempt to understand the Great Depression. Keynes advocated increased government expenditures and lower taxes to stimulate demand
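The printed scores are log-probabilities of the token "true" (computed with the temperature of 1.5), so it can help to convert them back to probabilities. The scores below are copied from the output above.

```python
import math

# Scores copied from the output above: log P(relevant = 1 | q, d), temperature 1.5
scores = {"7744105": -0.00019, "247": -0.04919, "248": -0.31595}
probs = {doc_id: math.exp(s) for doc_id, s in scores.items()}
for doc_id, p in probs.items():
    print(f"{doc_id:>8}  P(relevant) = {p:.4f}")
```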

The correct document is ranked first, but the other documents are not separated enough from it: the second one, which is unrelated to the query, still scores -0.049. To evaluate the model more formally, I ran the benchmark described in https://github.com/castorini/pygaggle/blob/master/docs/experiments-msmarco-document.md.


| metric | monoT5-base (paper) | Flan-T5-base (ours) |
| --- | --- | --- |
| precision@1 | 0.2 | 0.08 |
| recall@3 | 0.56 | 0.36 |
| recall@50 | 0.84 | 0.76 |
| recall@1000 | 0.88 | 0.88 |
| MRR | 0.38882 | 0.24596 |
| MRR@10 | 0.38271 | 0.23683 |
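For reading the table, here is a minimal sketch of the standard definitions of these metrics over per-query ranked lists. The pygaggle evaluation scripts may differ in details (for example how multiple relevant documents are handled), so this is not their exact implementation, and the function names are illustrative.

```python
def reciprocal_rank(ranked, relevant, k=None):
    """1/rank of the first relevant document (0 if none within the first k)."""
    for rank, doc_id in enumerate(ranked[:k], start=1):
        if doc_id in relevant:
            return 1.0 / rank
    return 0.0


def recall_at_k(ranked, relevant, k):
    return len(set(ranked[:k]) & set(relevant)) / len(relevant)


def precision_at_k(ranked, relevant, k):
    return len(set(ranked[:k]) & set(relevant)) / k


def evaluate_runs(runs, k_values=(3, 50, 1000)):
    """runs: list of (ranked doc ids, set of relevant doc ids), one per query."""
    n = len(runs)
    metrics = {
        "precision@1": sum(precision_at_k(r, rel, 1) for r, rel in runs) / n,
        "MRR": sum(reciprocal_rank(r, rel) for r, rel in runs) / n,
        "MRR@10": sum(reciprocal_rank(r, rel, k=10) for r, rel in runs) / n,
    }
    for k in k_values:
        metrics[f"recall@{k}"] = sum(recall_at_k(r, rel, k) for r, rel in runs) / n
    return metrics


# Two toy queries: relevant document at rank 2 and at rank 1
runs = [(["d1", "d2", "d3"], {"d2"}), (["d4", "d5"], {"d4"})]
print(evaluate_runs(runs, k_values=(1, 3)))
```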

On the full benchmark our model has more difficulty ranking the relevant document first (precision@1 of 0.08 vs 0.2), but recall@1000 is identical (0.88) and recall@50 is only slightly lower (0.76 vs 0.84), so the relevant document is usually still among the top-ranked ones. The lower Mean Reciprocal Rank (MRR, 0.246 vs 0.389) reflects that it tends to appear lower in the ranking.
These results are encouraging, considering that our model was trained for only 1 epoch with batch size 2, whereas the paper's model was trained for 10 epochs with batch size 128.
Looking at the paper's ablation study of MRR with respect to the number of training epochs, our model achieves performance comparable to T5-3B while having roughly ten times fewer parameters!
This shows how much stronger the new Flan-T5 model is compared to the original T5, and that the implementation is correct.
