# SWAG MCQ Reasoning with BERT and In-Context Learning

This notebook implements **Task 3** from the *Introduction to Data Science - Assignment 5&6*. The goal is to explore how **Large Language Models (LLMs)**, specifically BERT, can solve **multiple-choice commonsense reasoning questions** from the **SWAG dataset**.

Classical machine-learning models, which rely mostly on surface features, struggle with this kind of semantic inference: SWAG's wrong endings were adversarially filtered so that shallow stylistic cues do not reveal the answer. Here, we apply LLM techniques, including **zero-shot prediction**, **in-context learning (ICL)**, and **fine-tuning with LoRA**, to the task.

### 🔍 Task Overview
We work with the **SWAG dataset** (113k+ MCQs) to:
- Preprocess and tokenize inputs in multiple-choice format.
- Evaluate BERT in a zero-shot setting.
- Apply in-context learning using prompt engineering.
- Fine-tune the BERT model with LoRA for performance gains.
- Compare model performance across zero-shot, ICL, and fine-tuned configurations.

### 🧠 Learning Objectives
- Understand the limitations of classical models for reasoning tasks.
- Apply prompt-based and fine-tuning strategies for LLMs.
- Use HuggingFace Transformers to load, preprocess, evaluate, and fine-tune models.
- Analyze results using accuracy, confusion matrix, and perplexity.

**Dataset:** [SWAG - Situations With Adversarial Generations](https://huggingface.co/datasets/allenai/swag)
            
### 📌 Note
Make sure to
- Use your Hugging Face token for access.
- Run on Kaggle for reliable GPU access and smoother performance. 


In [None]:
# 1) Uninstall any leftovers
!pip uninstall -y transformers tokenizers

# 2) Install the last version before the helpers were removed
!pip install --no-cache-dir transformers==4.47.0

# 3) Now install your other libraries (they won’t bump Transformers because 4.47.0 satisfies PEFT)
!pip install --no-cache-dir peft evaluate packaging



## 1. Hugging Face Access Token

In [None]:
# Hugging Face Login
from huggingface_hub import notebook_login, login
notebook_login()

## 2. Loading Dataset

In [None]:
from datasets import load_dataset
import pandas as pd

# Load the SWAG dataset
dataset = load_dataset("swag", "regular")

df_train = dataset['train'].to_pandas()
df_val = dataset['validation'].to_pandas()
df_test = dataset['test'].to_pandas()

# adding a column to specify the split
df_train['split'] = 'train'
df_val['split'] = 'validation'
df_test['split'] = 'test'
# Concatenate all into one DataFrame
df = pd.concat([df_train, df_val, df_test], ignore_index=True)

print(dataset)
dataset["train"][0]
  

## 3. Analyze the Dataset

### Schema

In [None]:
df.head(4)

In [None]:
df.info()

In [None]:
df.describe(include='all')

### Class distributions

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

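# Test rows carry label -1 (gold answers are withheld), so plot labelled rows only
sns.countplot(x='label', data=df[df['label'] >= 0])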
plt.title("Distribution of Correct Answer Labels")
plt.xlabel("Correct Ending (0-3)")
plt.ylabel("Count")
plt.show()
 

### Average length of sent1, sent2, and the endings

In [None]:
df['sent1_len'] = df['sent1'].apply(lambda x: len(x.split()))
df['sent2_len'] = df['sent2'].apply(lambda x: len(x.split()))
for i in range(4):
    df[f'ending{i}_len'] = df[f'ending{i}'].apply(lambda x: len(x.split()))
    
# Display averages
print("Average lengths:")
print(f"sent1: {df['sent1_len'].mean():.2f} words")
print(f"sent2: {df['sent2_len'].mean():.2f} words")
for i in range(4):
    print(f"ending{i}: {df[f'ending{i}_len'].mean():.2f} words")
    
# Display maxima (word counts are integers)
print("Max lengths:")
print(f"sent1: {df['sent1_len'].max()} words")
print(f"sent2: {df['sent2_len'].max()} words")
for i in range(4):
    print(f"ending{i}: {df[f'ending{i}_len'].max()} words")
  

### One full MCQ example

In [None]:
i = 0
print("sent1:", df.loc[i, 'sent1'])
print("sent2:", df.loc[i, 'sent2'])
for j in range(4):
    print(f"ending{j}:", df.loc[i, f'ending{j}'])
print("Correct label:", df.loc[i, 'label'])
  

## 4. Load Tokenizer

In [None]:
from transformers import AutoTokenizer
# loading tokenizer
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")  

## 5. Preprocess the Dataset

In [None]:
def preprocess_swag(batch):
    first_sentences = [s1 for s1 in batch['sent1'] for _ in range(4)]
    second_sentences = [
        batch['sent2'][i] + " " + batch[f'ending{j}'][i]
        for i in range(len(batch['sent1']))
        for j in range(4)
    ]
    
    tokenized = tokenizer(
        first_sentences,
        second_sentences,
        truncation=True,
        padding="longest"
        )
    # Unflatten: group every 4 items as a list
    def regroup(values):
        return [values[i:i + 4] for i in range(0, len(values), 4)]
    
    return {key: regroup(val) for key, val in tokenized.items()}
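The flatten-then-regroup logic above is easiest to check on a toy batch. The sketch below repeats the list manipulation from `preprocess_swag` in plain Python, with a hypothetical `fake_tokenize` (whitespace splitting) standing in for the real tokenizer, so the shapes can be inspected without downloading anything. The batch contents are invented for illustration.

```python
# Toy batch in the same column layout as SWAG (values invented for illustration)
batch = {
    "sent1": ["A man sits at a piano.", "She opens the fridge."],
    "sent2": ["He", "She"],
    "ending0": ["plays a song.", "takes out milk."],
    "ending1": ["eats the keys.", "climbs inside."],
    "ending2": ["swims away.", "paints it blue."],
    "ending3": ["closes the lid.", "sings to it."],
}

def fake_tokenize(first, second):
    """Hypothetical stand-in for the tokenizer: one whitespace-split list per pair."""
    return {"input_ids": [f.split() + ["[SEP]"] + s.split() for f, s in zip(first, second)]}

# Flatten: each context is repeated 4 times, once per candidate ending
first_sentences = [s1 for s1 in batch["sent1"] for _ in range(4)]
second_sentences = [
    batch["sent2"][i] + " " + batch[f"ending{j}"][i]
    for i in range(len(batch["sent1"]))
    for j in range(4)
]
tokenized = fake_tokenize(first_sentences, second_sentences)

# Regroup: every 4 consecutive pairs belong to one question
def regroup(values):
    return [values[i:i + 4] for i in range(0, len(values), 4)]

grouped = {key: regroup(val) for key, val in tokenized.items()}
print(len(first_sentences))          # 8 flat (context, ending) pairs
print(len(grouped["input_ids"]))     # 2 questions
print(len(grouped["input_ids"][0]))  # 4 choices per question
print(second_sentences[1])           # He eats the keys.
```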

## 6. Apply the preprocessing function

In [None]:
# Apply preprocessing function with batch
tokenized_dataset = {
    split: dataset[split].map(preprocess_swag, batched=True)
    for split in ["train", "validation", "test"]
}
  

In [None]:
# clean and finalize the dataset
for split in ["train", "validation", "test"]:
    if "label" in tokenized_dataset[split].column_names:
        tokenized_dataset[split] = tokenized_dataset[split].rename_column("label", "labels")

    # Keep only the model inputs and the labels
    keep_cols = ['input_ids', 'attention_mask', 'token_type_ids', 'labels']
    tokenized_dataset[split] = tokenized_dataset[split].remove_columns(
        [col for col in tokenized_dataset[split].column_names if col not in keep_cols]
    )

## 7. Padding

In [None]:
from transformers import DataCollatorForMultipleChoice

data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)
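A multiple-choice collator has to pad across both the batch and the four choices. A minimal sketch of the idea, assuming token id 0 for padding (`[PAD]` in the `bert-base-uncased` vocabulary); `collate_multiple_choice` is a hypothetical helper, whereas the real `DataCollatorForMultipleChoice` delegates padding to the tokenizer (including `token_type_ids`) and returns PyTorch tensors.

```python
def collate_multiple_choice(features, pad_id=0):
    """Hypothetical sketch: pad every choice of every example to one common length."""
    num_choices = len(features[0]["input_ids"])
    # Flatten (batch, num_choices) -> batch * num_choices sequences
    flat_ids = [ids for f in features for ids in f["input_ids"]]
    max_len = max(len(ids) for ids in flat_ids)
    padded = [ids + [pad_id] * (max_len - len(ids)) for ids in flat_ids]
    # Attention mask: 1 for real tokens, 0 for padding
    mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in flat_ids]

    # Reshape back to (batch, num_choices, max_len)
    def unflatten(rows):
        return [rows[i:i + num_choices] for i in range(0, len(rows), num_choices)]

    return {
        "input_ids": unflatten(padded),
        "attention_mask": unflatten(mask),
        "labels": [f["labels"] for f in features],
    }

features = [
    {"input_ids": [[101, 7, 102], [101, 8, 9, 102], [101, 102], [101, 5, 102]], "labels": 1},
    {"input_ids": [[101, 3, 4, 5, 6, 102], [101, 102], [101, 2, 102], [101, 1, 102]], "labels": 0},
]
batch = collate_multiple_choice(features)
print(len(batch["input_ids"]), len(batch["input_ids"][0]), len(batch["input_ids"][0][0]))  # 2 4 6
print(batch["attention_mask"][0][2])  # [1, 1, 0, 0, 0, 0]
```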

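## 8. Load Model

`AutoModelForMultipleChoice` places a small classification layer on top of BERT that assigns one score to each (context, ending) pair. The `bert-base-uncased` checkpoint does not include this layer, so it starts from random weights (Transformers warns that the `classifier` weights are newly initialized). Until the model is fine-tuned, expect accuracy near the 25% chance level for four options.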

In [None]:
from transformers import AutoModelForMultipleChoice

model = AutoModelForMultipleChoice.from_pretrained("google-bert/bert-base-uncased") 
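Each (context, ending) pair receives one logit; the four logits of a question are compared with a softmax, the training loss is the cross-entropy of the gold ending, and the prediction is the argmax, as in the evaluation loops below. The sketch reproduces that arithmetic in plain Python on made-up logits; `choice_probs` is a hypothetical helper.

```python
import math

def choice_probs(logits):
    """Softmax over the candidate endings of one question (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for one question with four endings
logits = [0.2, 1.7, -0.5, 0.9]
probs = choice_probs(logits)
# Argmax over choices, like torch.argmax(logits, dim=1) for a single question
prediction = max(range(len(logits)), key=lambda j: logits[j])
# Cross-entropy if the gold ending were choice 1
loss_if_gold_is_1 = -math.log(probs[1])

print(prediction)            # 1
print(round(sum(probs), 6))  # 1.0
```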

## 9. Test Model on Dataset

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# This will store (method, accuracy)
icl_results = []

### Test case analysis

In [None]:
import torch

# Device setup (should match your model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Pick a single example to inspect
i = 0  # You can change this index to test other samples
example = tokenized_dataset["validation"][i]
raw_example = dataset["validation"][i]  # un-tokenized version for readable display

# Convert inputs to tensor format (batch size = 1) and move to device
inputs = {
    k: torch.tensor([example[k]]).to(device)
    for k in ["input_ids", "attention_mask", "token_type_ids"]
}

# Run the model and get prediction
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).item()

# Display the context
context = raw_example["sent1"] + " " + raw_example["sent2"]
print(f"📌 Context: {context}")
# Display all choices with prediction and ground truth
print("🔘 Options:")
for j in range(4):
    option_text = raw_example[f"ending{j}"]
    prefix = "✅" if j == raw_example["label"] else "❌"
    marker = "👉" if j == predicted_class else "  "
    print(f"{marker} {prefix} Choice {j}: {option_text}")

# Final Summary
print(f"🧠 Model Prediction: Choice {predicted_class}")
print(f"✔️ Ground Truth: Choice {raw_example['label']}")
   

### Full validation set evaluation

In [None]:
val_loader = DataLoader(
    tokenized_dataset["validation"],
    batch_size=16,
    collate_fn=data_collator)

model.eval()
correct = 0
total = 0

for batch in tqdm(val_loader, desc="Evaluating Baseline"):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
        predictions = torch.argmax(outputs.logits, dim=1)
        correct += (predictions == batch["labels"]).sum().item()
        total += batch["labels"].size(0)
        
baseline_acc = 100*(correct / total)
print(f"📊 Baseline Accuracy: {baseline_acc:.2f}%")
icl_results.append(("Baseline (No ICL)", baseline_acc))
   

## 10. In-Context Learning

### Few-shot

In [None]:
from transformers import AutoModelForMultipleChoice
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorForMultipleChoice
from datasets import Dataset
from tqdm import tqdm


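# Select 5 demonstration examples from the validation set
# (these 5 items are also scored in the evaluation below, so they are not strictly held out)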
fewshot_indices = [0, 1, 2, 3, 4]
fewshot_examples = [dataset["validation"][i] for i in fewshot_indices]

# Build the static few-shot prompt
fewshot_prompt = ""
for ex in fewshot_examples:
    fewshot_prompt += f"Context: {ex['sent1']} {ex['sent2']}\n"
    for j in range(4):
        fewshot_prompt += f"{chr(65+j)}. {ex[f'ending{j}']}\n"
    fewshot_prompt += f"Answer: {chr(65 + ex['label'])}\n\n"
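To see what the prompt looks like, the sketch below wraps the same loop in a hypothetical `build_fewshot_prompt` function and applies it to one invented demonstration. Whitespace word counts give only a lower bound on the number of WordPiece tokens, but they help judge how quickly five demonstrations plus the query approach BERT's 512-token limit.

```python
# Hypothetical demonstration in the SWAG field layout (not a real dataset row)
fewshot_examples = [
    {"sent1": "A chef slices onions.", "sent2": "The chef",
     "ending0": "adds them to a pan.", "ending1": "flies a kite.",
     "ending2": "reads a map.", "ending3": "sleeps on the stove.", "label": 0},
]

def build_fewshot_prompt(examples):
    """Same layout as the notebook: context, lettered options, then the answer letter."""
    prompt = ""
    for ex in examples:
        prompt += f"Context: {ex['sent1']} {ex['sent2']}\n"
        for j in range(4):
            prompt += f"{chr(65 + j)}. {ex[f'ending{j}']}\n"
        prompt += f"Answer: {chr(65 + ex['label'])}\n\n"
    return prompt

prompt = build_fewshot_prompt(fewshot_examples)
print(prompt)
# Rough size check: each whitespace word becomes at least one WordPiece token
print(len(prompt.split()), "words in this demonstration")
```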


In [None]:
# Define the ICL preprocessing function
def preprocess_fewshot_icl_batch(batch):
    prompts = []
    first_sentences = []
    second_sentences = []

    for i in range(len(batch['sent1'])):
        prompt = fewshot_prompt + f"\nContext: {batch['sent1'][i]} {batch['sent2'][i]}"
        first_sentences.extend([prompt] * 4)
        second_sentences.extend([
            f"{chr(65 + j)}. {batch[f'ending{j}'][i]}" for j in range(4)
        ])

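    # truncation=True uses the "longest_first" strategy: once a pair exceeds BERT's
    # 512-token limit, tokens are dropped from the end of the longer sequence (default
    # right-side truncation). Here that is the prompt, so the query context is cut first.
    tokenized = tokenizer(first_sentences, second_sentences, truncation=True, padding="longest")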

    def regroup(values): return [values[i:i+4] for i in range(0, len(values), 4)]
    result = {k: regroup(v) for k, v in tokenized.items()}
    result["labels"] = batch["label"]
    return result


In [None]:
fewshot_icl_dataset = dataset["validation"].map(preprocess_fewshot_icl_batch, batched=True)

In [None]:
# 1) Clean & finalize the few-shot ICL dataset

# If the old 'label' column still exists, drop it
if "label" in fewshot_icl_dataset.column_names:
    fewshot_icl_dataset = fewshot_icl_dataset.remove_columns(["label"])

# Now keep only the required MCQ fields plus 'labels'
keep_cols = ["input_ids", "attention_mask", "token_type_ids", "labels"]
fewshot_icl_dataset = fewshot_icl_dataset.remove_columns(
    [c for c in fewshot_icl_dataset.column_names if c not in keep_cols]
)     

In [None]:
# 2) Create DataLoader with your MCQ collator
fewshot_icl_loader = DataLoader(
    fewshot_icl_dataset,
    batch_size=8,
    collate_fn=data_collator
)

# 3) Evaluate Few-Shot ICL
correct = 0
total = 0
model.eval()

with torch.no_grad():
    for batch in tqdm(fewshot_icl_loader, desc="Evaluating Few-Shot ICL"):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=1)
        correct += (preds == batch["labels"]).sum().item()
        total += batch["labels"].size(0)
    
few_shot_accuracy = 100*(correct / total)
print(f"📊 Few-Shot ICL Accuracy: {few_shot_accuracy:.2f}%")
icl_results.append(("Few shot learning (ICL)", few_shot_accuracy))
    

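### Zero-shot with chain of thought (CoT)

BERT is an encoder and cannot generate intermediate reasoning steps, so "chain of thought" here only means prepending the trigger phrase `Let's think step by step.` to each context; the model still scores the four endings directly.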

In [None]:
# 1) Build the CoT prompt prefix
cot_prefix = "Let's think step by step.\n\n"

In [None]:
# 2) Define the CoT preprocessing fn
def preprocess_cot_batch(batch):
    first_sentences = []
    second_sentences = []

    for i in range(len(batch["sent1"])):
        # prepend the chain-of-thought instruction before each example
        prompt = cot_prefix + f"Context: {batch['sent1'][i]} {batch['sent2'][i]}"
        # repeat for the 4 choices
        first_sentences.extend([prompt] * 4)
        second_sentences.extend([
            f"{chr(65 + j)}. {batch[f'ending{j}'][i]}" for j in range(4)
        ])

    # tokenize exactly as before
    tokenized = tokenizer(first_sentences, second_sentences,
                          truncation=True, padding="longest")

    # regroup back into (batch_size, 4, seq_len)
    def regroup(vals): return [vals[i : i + 4] for i in range(0, len(vals), 4)]
    result = {k: regroup(v) for k, v in tokenized.items()}
    # carry label forward
    result["labels"] = batch["label"]
    return result


In [None]:
# 3) Map over validation set
cot_dataset = dataset["validation"].map(
    preprocess_cot_batch,
    batched=True,
    remove_columns=dataset["validation"].column_names  # drop raw cols
)


In [None]:
# 4) Clean up columns (drop old 'label' if still there)
if "label" in cot_dataset.column_names:
    cot_dataset = cot_dataset.remove_columns(["label"])
keep = ["input_ids", "attention_mask", "token_type_ids", "labels"]
cot_dataset = cot_dataset.remove_columns(
    [c for c in cot_dataset.column_names if c not in keep]
)

In [None]:
# 5) Create a DataLoader
cot_loader = DataLoader(
    cot_dataset,
    batch_size=8,
    collate_fn=data_collator
)

# 6) Evaluate Zero-Shot CoT
correct = 0
total = 0
model.eval()

with torch.no_grad():
    for batch in tqdm(cot_loader, desc="Evaluating Zero-Shot CoT"):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        preds = torch.argmax(outputs.logits, dim=1)
        correct += (preds == batch["labels"]).sum().item()
        total += batch["labels"].size(0)

cot_accuracy = 100*(correct / total)
print(f"\n📊 Zero-Shot CoT Accuracy: {cot_accuracy:.2f}%")
icl_results.append(("Chain of Thought (ICL)", cot_accuracy))


### Comparison table

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

comparison_df = pd.DataFrame(icl_results, columns=["Method", "Accuracy"])
comparison_df = comparison_df.sort_values("Accuracy", ascending=False).reset_index(drop=True)

plt.figure(figsize=(8, 4))
sns.barplot(data=comparison_df, x="Accuracy", y="Method", palette="crest")
plt.title("Comparison of ICL Methods on SWAG Validation Set")
plt.xlim(0, 100)
plt.xlabel("Validation Accuracy (%)")
plt.tight_layout()
plt.show()


## 11. Fine-Tune BERT with LoRA

In [None]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer, SchedulerType
from peft       import LoraConfig, get_peft_model, TaskType
import numpy    as np
from evaluate   import load


In [None]:
train_ds = tokenized_dataset["train"]
eval_ds  = tokenized_dataset["validation"]

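import copy

base_model = model
base_model.to(device)

# Config (wrap BERT with LoRA)
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_CLS,
)
# get_peft_model injects the LoRA layers into the model it receives (in place),
# so wrap a copy to keep `base_model` as the untouched pre-fine-tuning baseline
model_with_lora = get_peft_model(copy.deepcopy(base_model), lora_cfg)
model_with_lora.print_trainable_parameters()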
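LoRA freezes each targeted weight matrix W and learns a low-rank update, so the adapted projection computes Wx + (lora_alpha / r) · B(Ax), where A has shape r × d_in and B has shape d_out × r. The arithmetic below estimates how many parameters this configuration trains, assuming BERT-base shapes (12 layers, square 768 × 768 `query` and `value` projections) plus the 768 → 1 multiple-choice head, which PEFT should keep trainable for `TaskType.SEQ_CLS` under the module name `classifier`. `lora_param_count` is a hypothetical helper; `print_trainable_parameters()` reports the real figure.

```python
def lora_param_count(num_layers, d_in, d_out, r, modules_per_layer):
    """Parameters added by LoRA: A is (r x d_in) and B is (d_out x r) per adapted matrix."""
    per_matrix = r * d_in + d_out * r
    return num_layers * modules_per_layer * per_matrix

# Assumed BERT-base shapes; r and lora_alpha match lora_cfg above
r, lora_alpha = 8, 16
adapter_params = lora_param_count(num_layers=12, d_in=768, d_out=768, r=r, modules_per_layer=2)
head_params = 768 * 1 + 1  # Linear(768 -> 1) classifier: weights + bias
scaling = lora_alpha / r   # multiplier applied to B(Ax)

print(adapter_params)                # 294912
print(adapter_params + head_params)  # 295681
print(scaling)                       # 2.0
```

Against the roughly 110M parameters of BERT-base, this is well under 1% of the model.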


### Fine-tuning loop

In [None]:
# Fine-tuning with visible progress (compatible with Transformers 4.47)

import os
import numpy as np
from evaluate import load
from transformers import (
    TrainingArguments, Trainer,
    IntervalStrategy, SchedulerType,
    TrainerCallback,
)
from transformers.utils import logging as hf_logging

# Make logs visible in the cell; silence wandb noise
hf_logging.set_verbosity_info()
os.environ["WANDB_DISABLED"] = "true"   # or: os.environ["WANDB_MODE"] = "offline"

# Metrics
accuracy_metric = load("accuracy")
def compute_metrics(pred):
    labels = pred.label_ids
    preds  = np.argmax(pred.predictions, axis=1)
    return accuracy_metric.compute(predictions=preds, references=labels)

# Training arguments (step-level eval/save + frequent logging)
training_args = TrainingArguments(
    output_dir="swag_lora_ft",

    do_eval=True,
    eval_strategy=IntervalStrategy.STEPS,   # 4.47 naming
    eval_steps=500,
    save_strategy=IntervalStrategy.STEPS,
    save_steps=500,
    save_total_limit=2,

    learning_rate=4e-5,
    warmup_ratio=0.1,
    lr_scheduler_type=SchedulerType.LINEAR,

    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    num_train_epochs=3,

    # visible logs / progress
    logging_strategy=IntervalStrategy.STEPS,
    logging_steps=50,         # print every 50 steps
    disable_tqdm=False,       # show progress bar
    report_to="none",         # no wandb/tensorboard

    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

# Tiny callback to guarantee prints even if buffering happens
class PrintCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            print(f"[step {state.global_step}] {logs}")

# Trainer (reuses your existing objects: model_with_lora, train_ds, eval_ds, data_collator, tokenizer)
trainer = Trainer(
    model=model_with_lora,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,          # ok; may show a future deprecation warning
    compute_metrics=compute_metrics,
    callbacks=[PrintCallback()],
)

# Train
trainer.train()
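With gradient accumulation, one optimizer step consumes `per_device_train_batch_size × gradient_accumulation_steps` examples per device. The sketch below approximates the schedule implied by the `TrainingArguments` above, assuming 73,546 training questions (the reported size of SWAG's `regular` train split) and a single GPU; under multi-GPU data parallelism the per-device batch is multiplied by the device count, which the optional `n_devices` argument models. `schedule` is a hypothetical helper, not a Trainer API.

```python
import math

def schedule(num_examples, per_device_bs, grad_accum, epochs, warmup_ratio, n_devices=1):
    """Approximate optimizer-step schedule for epoch-based training with accumulation."""
    batches_per_epoch = math.ceil(num_examples / (per_device_bs * n_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = math.ceil(epochs * steps_per_epoch)
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    effective_batch = per_device_bs * n_devices * grad_accum
    return effective_batch, steps_per_epoch, total_steps, warmup_steps

# Assumed train-split size; the other values come from training_args
print(schedule(73_546, per_device_bs=16, grad_accum=4, epochs=3, warmup_ratio=0.1))
# (64, 1149, 3447, 345)
```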


### Metrics

In [None]:
# Evaluate & report
metrics = trainer.evaluate()

# Perplexity & confusion matrix
perplexity = np.exp(metrics["eval_loss"])
print(f"→ Validation Accuracy:  {100*(metrics['eval_accuracy']):.2f}%")
print(f"→ Validation Perplexity: {perplexity:.2f}")


# Confusion matrix
from sklearn.metrics import confusion_matrix
preds_output = trainer.predict(eval_ds)
preds = np.argmax(preds_output.predictions, axis=1)
cm = confusion_matrix(preds_output.label_ids, preds)
print("→ Confusion Matrix:\n", cm)
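For a four-way choice, `np.exp(eval_loss)` is the exponential of the mean cross-entropy of the gold ending: it equals 1 when the gold ending always gets probability 1, 4 when the model spreads probability uniformly, and it exceeds 4 when the model is confidently wrong. A plain-Python check of those reference points, with `mc_perplexity` as a hypothetical helper:

```python
import math

def mc_perplexity(gold_probs):
    """exp(mean negative log-likelihood of the gold ending), as in np.exp(eval_loss)."""
    nll = [-math.log(p) for p in gold_probs]
    return math.exp(sum(nll) / len(nll))

print(mc_perplexity([0.25, 0.25, 0.25]))  # uniform guessing over 4 endings
print(mc_perplexity([1.0, 1.0]))          # full confidence in the gold ending -> 1.0
print(mc_perplexity([0.9, 0.6, 0.1]))     # mixed predictions
```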


### Comparison

In [None]:
ft_accuracy = 100*(metrics["eval_accuracy"])
icl_results.append(("Fine-Tuned BERT + LoRA", ft_accuracy))


import pandas as pd
results_df = pd.DataFrame(icl_results, columns=["Method", "Accuracy"])
print(results_df)


## 12. In-Context Learning with the Fine-Tuned Model

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import math

# make sure model is on device
device = "cuda" if torch.cuda.is_available() else "cpu"
model_with_lora.to(device)


def evaluate_loader(model, loader, description):
    model.eval()            # set eval mode
    all_preds, all_labels = [], []
    total_loss, total_examples = 0.0, 0

    with torch.no_grad():
        for batch in tqdm(loader, desc=description):
            batch = {k: v.to(device) for k,v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss.item()
            bs = batch["labels"].size(0)
            total_loss    += loss * bs
            total_examples+= bs

            preds = torch.argmax(outputs.logits, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(batch["labels"].cpu().tolist())

    avg_loss   = total_loss / total_examples
    perplexity = math.exp(avg_loss)
    acc        = sum(p==l for p,l in zip(all_preds, all_labels)) / len(all_labels)

    print(f"\n📊 {description} Accuracy:    {100*(acc):.4f}%")
    print(f"📊 {description} Perplexity: {perplexity:.4f}")

    cm = confusion_matrix(all_labels, all_preds)
    print(f"\n{description} Confusion Matrix:\n{cm}\n")

    cr = classification_report(all_labels, all_preds, digits=4)
    print(f"{description} Classification Report:\n{cr}\n")

    return 100 * acc  # percentage, same scale as the other icl_results entries
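`confusion_matrix` places the gold ending on the rows and the predicted ending on the columns, so the diagonal counts correct answers and a heavy column points to a bias toward one answer position. A plain-Python sketch of that layout on toy labels (`confusion_counts` is a hypothetical helper):

```python
def confusion_counts(labels, preds, num_classes=4):
    """Rows = gold ending, columns = predicted ending (same layout as sklearn)."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for gold, pred in zip(labels, preds):
        cm[gold][pred] += 1
    return cm

labels = [0, 1, 2, 3, 1, 2]
preds = [0, 1, 1, 3, 1, 0]
cm = confusion_counts(labels, preds)
for row in cm:
    print(row)
# Accuracy is the diagonal sum divided by the number of examples
accuracy = sum(cm[i][i] for i in range(4)) / len(labels)
print(f"{100 * accuracy:.2f}%")  # 66.67%
```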


### Few-shot ICL with fine-tuned model

In [None]:
fs_acc = evaluate_loader(model_with_lora, fewshot_icl_loader, "Fine-Tuned Few-Shot ICL")
icl_results.append(("Fine-Tuned Few-Shot ICL", fs_acc))

### Chain of thought ICL with fine-tuned model

In [None]:
cot_acc = evaluate_loader(model_with_lora, cot_loader, "Fine-Tuned CoT ICL")
icl_results.append(("Fine-Tuned CoT ICL", cot_acc))


### Comparison

In [None]:
import pandas as pd
results_df = pd.DataFrame(icl_results, columns=["Method", "Accuracy"])
print(results_df)

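## 13. Analyze the Results

**Caution:** in the Hugging Face `swag` release, every `label` in the test split is `-1` because the gold answers are withheld. Accuracy and confusion matrices are therefore undefined on that split, and the loss computed inside `evaluate_loader` fails for a target of `-1`. To report held-out numbers, score a labelled split instead, such as a slice of the training data set aside before fine-tuning.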

In [None]:
# Prepare the test datasets

# --- a) Baseline test (no ICL, pretrained BERT) ---
test_baseline = tokenized_dataset["test"]
test_baseline_loader = DataLoader(test_baseline, batch_size=8, collate_fn=data_collator)

# --- b) Few-Shot ICL on test ---
test_fewshot = dataset["test"].map(
    preprocess_fewshot_icl_batch,
    batched=True,
    remove_columns=dataset["test"].column_names
)
# drop old label column if it exists, keep only MCQ fields
if "label" in test_fewshot.column_names:
    test_fewshot = test_fewshot.remove_columns(["label"])
keep = ["input_ids", "attention_mask", "token_type_ids", "labels"]
test_fewshot = test_fewshot.remove_columns([c for c in test_fewshot.column_names if c not in keep])
test_fewshot_loader = DataLoader(test_fewshot, batch_size=8, collate_fn=data_collator)

# --- c) Chain-of-Thought ICL on test ---
test_cot = dataset["test"].map(
    preprocess_cot_batch,
    batched=True,
    remove_columns=dataset["test"].column_names
)
if "label" in test_cot.column_names:
    test_cot = test_cot.remove_columns(["label"])
test_cot = test_cot.remove_columns([c for c in test_cot.column_names if c not in keep])
test_cot_loader = DataLoader(test_cot, batch_size=8, collate_fn=data_collator)


test_results = []
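Before scoring any split, it is cheap to confirm that it actually carries gold answers. A minimal sketch with hypothetical helpers, assuming unlabelled rows are marked with `-1`; it could be applied to `list(dataset["test"]["label"])`:

```python
def split_has_gold_labels(labels, missing=-1):
    """True only if the split is non-empty and every example has a real answer index."""
    return len(labels) > 0 and all(label != missing for label in labels)

def labelled_accuracy(preds, labels, missing=-1):
    """Accuracy over labelled examples only; raises if no example is labelled."""
    pairs = [(p, l) for p, l in zip(preds, labels) if l != missing]
    if not pairs:
        raise ValueError("no gold labels: accuracy is undefined for this split")
    return sum(p == l for p, l in pairs) / len(pairs)

print(split_has_gold_labels([0, 3, 1]))          # True
print(split_has_gold_labels([-1, -1, -1]))       # False
print(labelled_accuracy([0, 2, 1], [0, 1, -1]))  # 0.5
```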


In [None]:
# a) Baseline
baseline_test_acc = evaluate_loader(base_model, test_baseline_loader, "Test Baseline")
test_results.append(("Baseline (No ICL)", baseline_test_acc))


In [None]:
# b) Few‐Shot ICL (pre-fine-tune)
fs_test_acc = evaluate_loader(base_model, test_fewshot_loader, "Test Few-Shot ICL")
test_results.append(("Few-Shot ICL (pre-FT)", fs_test_acc))
   

In [None]:
# c) CoT ICL (pre-fine-tune)
cot_test_acc = evaluate_loader(base_model, test_cot_loader, "Test CoT ICL")
test_results.append(("CoT ICL (pre-FT)", cot_test_acc))
  

In [None]:
# d) Fine-tuned BERT + LoRA (no ICL)
ft_test_acc = evaluate_loader(model_with_lora, test_baseline_loader, "Test Fine-Tuned BERT+LoRA")
test_results.append(("Fine-Tuned BERT+LoRA", ft_test_acc))
   

In [None]:
# e) Few-Shot ICL on fine-tuned model
fs_ft_test_acc = evaluate_loader(model_with_lora, test_fewshot_loader, "Test Fine-Tuned Few-Shot ICL")
test_results.append(("Few-Shot ICL (FT)", fs_ft_test_acc))
   

In [None]:
# f) CoT ICL on fine-tuned model
cot_ft_test_acc = evaluate_loader(model_with_lora, test_cot_loader, "Test Fine-Tuned CoT ICL")
test_results.append(("CoT ICL (FT)", cot_ft_test_acc))
  

## 14. Report All Test Results

In [None]:
import pandas as pd
# Use a new name so the dataset DataFrame `df` from Section 2 is not overwritten
test_df = pd.DataFrame(test_results, columns=["Method", "Test Accuracy"])
print(test_df)