# **Install and Imports**

In [None]:
!pip install transformers datasets scikit-learn

In [None]:
import copy

import torch
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW, get_scheduler
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
import numpy as np
from datasets import load_dataset

# **Load dataset and Create splits**
### Hold out 10% of the training data as a validation set, as done in the original paper; the dataset's official validation split is used as the test set

In [None]:
# Load CoS-E v1.11. The hub's "validation" split is stored under the key
# "test" because it is used as the held-out test set in this notebook.
ds = load_dataset(
    "Salesforce/cos_e",
    "v1.11",
    split={
        "train": "train",
        "test": "validation",
    },
)

def modify_example(example):
    """Turn a raw multiple-choice record into an ``input``/``label`` text pair.

    Args:
        example: Mapping with a ``question`` string, a ``choices`` sequence of
            strings and an ``answer`` string.

    Returns:
        dict with two keys:
            ``input`` -- the question, a literal "Answer Choices:" line, and
            one line per choice prefixed with "(a)", "(b)", ...;
            ``label`` -- the ``answer`` value, unchanged.
    """
    # Derive the letter from the choice index instead of a fixed five-entry
    # list, so records with more than five choices no longer raise IndexError.
    # Letters are consecutive from "a"; past 26 choices they stop being letters.
    formatted_choices = "\n".join(
        f"({chr(ord('a') + position)}) {choice}"
        for position, choice in enumerate(example["choices"])
    )
    input_text = f"{example['question']}\nAnswer Choices:\n{formatted_choices}"
    return {
        "input": input_text,
        "label": example["answer"]
    }

# Reformat every split into "input"/"label" pairs and drop the raw columns.
dataset = dict(
    (
        split,
        data.map(
            modify_example,
            remove_columns=[
                'id',
                'question',
                'choices',
                'answer',
                'abstractive_explanation',
                'extractive_explanation',
            ],
        ),
    )
    for split, data in ds.items()
)

# NOTE: despite its name, `val_dataset` is the full original training split;
# the validation subset is carved out of it below.
val_dataset = dataset["train"]
val_size = int(0.1 * len(val_dataset))  # 10% held out for validation
train_size = len(val_dataset) - val_size

# Seeded generator keeps the train/validation partition reproducible.
train_set, val_set = random_split(
    val_dataset,
    lengths=[train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Sanity check: report the size and the first record of each split.
data_splits = dict(train=train_set, val=val_set, test=dataset['test'])

for split_name in data_splits:
    split_data = data_splits[split_name]
    print(f"Length of {split_name}_set:", len(split_data))
    print(f"First element of {split_name}_set:", split_data[0])

Length of train_set: 8767
First element of train_set: {'input': 'Where might someone keep personal soap?\nAnswer Choices:\n(a) birthday party\n(b) supermarket\n(c) own home\n(d) jail\n(e) cabinet', 'label': 'own home'}
Length of val_set: 974
First element of val_set: {'input': 'What do you have to do to learn to play violin?\nAnswer Choices:\n(a) tune\n(b) practise\n(c) relaxing\n(d) ask questions\n(e) take lessons', 'label': 'take lessons'}
Length of test_set: 1221
First element of test_set: {'input': 'A beaver is know for building prowess, their supplies come from where?\nAnswer Choices:\n(a) british columbia\n(b) body of water\n(c) wooded area\n(d) pay debts\n(e) zoo', 'label': 'wooded area'}


# **Custom Dataset Class**

In [None]:
class T5FineTuningDataset(Dataset):
    """Wrap ``input``/``label`` text records as fixed-length tensors for T5.

    Each item is tokenized on access. Both the input and the label are padded
    or truncated to ``max_length`` tokens, and padding positions in the label
    are replaced with -100.
    """

    def __init__(self, dataset, tokenizer, max_length=128):
        """Store the underlying records, the tokenizer and the target length.

        Args:
            dataset: Indexable collection of dicts with ``input`` and ``label``
                string fields.
            tokenizer: Callable tokenizer that accepts ``padding``,
                ``truncation``, ``max_length`` and ``return_tensors`` and
                exposes ``pad_token_id``.
            max_length: Number of tokens every input and label is padded or
                truncated to.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def _encode(self, text):
        """Tokenize ``text`` to exactly ``max_length`` tokens, batch dim of 1."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one record."""
        example = self.dataset[idx]

        input_encoding = self._encode(example['input'])
        label_encoding = self._encode(example['label'])

        # Replace padding token IDs with -100 in the label sequence.
        labels = label_encoding.input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        # squeeze(0) removes only the leading batch dimension. A bare squeeze()
        # would also drop the sequence axis when max_length == 1, producing
        # 0-d tensors.
        return {
            'input_ids': input_encoding.input_ids.squeeze(0),
            'attention_mask': input_encoding.attention_mask.squeeze(0),
            'labels': labels.squeeze(0)
        }

# **Validation Function**

In [None]:
def validate_model(model, val_loader, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to eval mode and gradients are disabled. Every batch
    must provide ``input_ids``, ``attention_mask`` and ``labels`` tensors,
    which are moved to ``device`` before the forward pass.
    """
    model.eval()
    running_loss = 0
    batch_keys = ('input_ids', 'attention_mask', 'labels')

    with torch.no_grad():
        for batch in val_loader:
            model_inputs = {key: batch[key].to(device) for key in batch_keys}
            running_loss += model(**model_inputs).loss.item()

    # Average over the number of batches, not the number of examples.
    return running_loss / len(val_loader)

# **Train the model**

In [None]:
def train_model(train_set, val_set, model_name="google-t5/t5-small", max_length=1024,
                batch_size=16, num_epochs=25, learning_rate=3e-5, weight_decay=0.01):
    """Fine-tune a T5 checkpoint and return the weights with the lowest validation loss.

    Args:
        train_set: Indexable collection of ``input``/``label`` records used for training.
        val_set: Indexable collection of ``input``/``label`` records used for validation.
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        max_length: Token length handed to ``T5FineTuningDataset``, which pads
            every input and label to this length.
        batch_size: Batch size for both data loaders.
        num_epochs: Number of passes over the training loader.
        learning_rate: Initial learning rate for AdamW.
        weight_decay: Weight decay for AdamW.

    Returns:
        Tuple ``(model, tokenizer)``; the model carries the parameters from the
        epoch with the lowest validation loss.
    """
    # Load model and tokenizer
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # Fall back to CPU when no GPU is present instead of crashing on "cuda".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare training and validation datasets
    train_dataset = T5FineTuningDataset(train_set, tokenizer, max_length)
    val_dataset = T5FineTuningDataset(val_set, tokenizer, max_length)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # NOTE(review): the AdamW exported by transformers has been flagged as
    # deprecated in favour of torch.optim.AdamW -- confirm against the
    # installed transformers version.
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    total_steps = len(train_loader) * num_epochs

    # Linear schedule with no warmup, stepped once per optimizer step.
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # Initialize variables to track best model
    best_val_loss = float('inf')
    best_model_state = None

    # Training loop
    print("Starting training...")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Training
        model.train()
        total_train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")

        for batch in progress_bar:
            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            total_train_loss += loss.item()

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            progress_bar.set_postfix({"loss": loss.item()})

        avg_train_loss = total_train_loss / len(train_loader)
        print(f"Epoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")

        # Validation
        print("Running validation...")
        val_loss = validate_model(model, val_loader, device)
        print(f"Epoch {epoch+1} - Validation loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameters, and a
            # shallow dict.copy() keeps those references, so later optimizer
            # steps would overwrite the "best" snapshot. Deep-copy the tensors.
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"New best model saved with validation loss: {val_loss:.4f}")

    # Load best model for final evaluation
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with validation loss: {best_val_loss:.4f}")

    print("Training completed!")
    return model, tokenizer

In [None]:
# Fine-tune with the default hyperparameters of train_model.
model, tokenizer = train_model(train_set=train_set, val_set=val_set)

# **Evaluate the model**

In [None]:
def evaluate_model(model, tokenizer, test_dataset, max_length=1024, batch_size=16):
    """Generate answers for ``test_dataset`` and score them by exact string match.

    Args:
        model: Model exposing ``generate``; it is moved to the selected device.
        tokenizer: Tokenizer used to encode the records and decode outputs.
        test_dataset: Indexable collection of ``input``/``label`` records.
        max_length: Token length used for encoding and passed to ``generate``
            as ``max_length``.
        batch_size: Batch size of the evaluation loader.

    Returns:
        Tuple ``(accuracy, predictions, references)`` where ``accuracy`` is a
        percentage and the other two are lists of decoded strings.
    """
    # Fall back to CPU when no GPU is present instead of crashing on "cuda".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare test dataset
    test_dataset_processed = T5FineTuningDataset(test_dataset, tokenizer, max_length)
    test_loader = DataLoader(test_dataset_processed, batch_size=batch_size)

    model.eval()
    predictions = []
    references = []

    print("Starting evaluation...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Generate predictions
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length
            )

            # Decode predictions
            preds = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

            # Labels are only decoded, never fed to the model, so they stay on
            # the CPU. Restore pad ids where the dataset wrote -100 so the
            # tokenizer can decode (and then skip) them.
            label_ids = batch["labels"].clone()
            label_ids[label_ids == -100] = tokenizer.pad_token_id
            refs = [tokenizer.decode(label, skip_special_tokens=True) for label in label_ids]

            predictions.extend(preds)
            references.extend(refs)

    # Calculate accuracy (exact match between decoded strings), as a percentage
    accuracy = accuracy_score(references, predictions) * 100
    return accuracy, predictions, references

In [None]:
# Score the fine-tuned model on the held-out test split.
accuracy, predictions, references = evaluate_model(
    model=model,
    tokenizer=tokenizer,
    test_dataset=dataset["test"],
)

# **Results**

In [None]:
def display_evaluation_results(accuracy, predictions, references, num_samples=5):
    """Print the accuracy and the first few reference/prediction pairs.

    At most ``num_samples`` pairs are shown, capped at ``len(predictions)``.
    Each pair is followed by a 50-character dashed separator line.
    """
    print(f"Final Evaluation Accuracy: {accuracy:.2f}%")

    # Show a handful of examples for a quick qualitative check.
    print("\nSample predictions:")
    separator = "-" * 50
    shown = min(num_samples, len(predictions))
    for idx in range(shown):
        reference, prediction = references[idx], predictions[idx]
        print(f"Reference: {reference}")
        print(f"Prediction: {prediction}")
        print(separator)

In [None]:
# Report the overall accuracy and the first five reference/prediction pairs.
display_evaluation_results(
    accuracy=accuracy,
    predictions=predictions,
    references=references,
    num_samples=5,
)

Final Evaluation Accuracy: 41.20%

Sample predictions:
Reference: wooded area
Prediction: wooded area
--------------------------------------------------
Reference: go downtown
Prediction: east
--------------------------------------------------
Reference: play tag
Prediction: play tag
--------------------------------------------------
Reference: great outdoors
Prediction: fairytale
--------------------------------------------------
Reference: appliance store
Prediction: appliance store
--------------------------------------------------
