# mBERT Baseline — Sarcasm Detection on Cleaned Hinglish Dataset

This notebook trains a **multilingual BERT (mBERT)** baseline for binary sarcasm classification on the cleaned Hinglish dataset.

Expected test-set macro-F1: **~75%** (on a split with no data leakage between train, validation, and test).

**Run this on Google Colab / Kaggle with a GPU runtime.**

## 1. Install Dependencies

In [None]:
!pip install transformers datasets accelerate scikit-learn -q

## 2. Imports & Configuration

In [None]:
import os
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    get_linear_schedule_with_warmup,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report,
    f1_score,
    accuracy_score,
    confusion_matrix,
)
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# ────────────────────────────────────────────
# CONFIG — tweak these as needed
# ────────────────────────────────────────────
MODEL_NAME = 'bert-base-multilingual-cased'   # Hugging Face checkpoint id
MAX_LEN    = 128                              # tokens per example (truncate / pad)
BATCH_SIZE = 32
EPOCHS     = 5
LR         = 2e-5                             # peak learning rate for AdamW
SEED       = 42
DATA_PATH  = 'Data/sarcasm_hinghlish_dataset_cleaned.csv'
SAVE_DIR   = 'mbert_baseline'                 # where the best checkpoint is written


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# Reproducibility
_seed_everything(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 3. Load & Split Data

In [None]:
# Load the cleaned dataset and show its size, class balance and a preview.
df = pd.read_csv(DATA_PATH)
label_counts = df['label'].value_counts()
print(f'Total samples: {df.shape[0]}')
print(f'Label distribution:\n{label_counts}')
print('\nSample rows:')
df.head()

In [None]:
# 70 / 15 / 15 split, stratified at both stages so each split keeps the
# overall label ratio: hold out 30%, then halve the holdout into val / test.
train_df, temp_df = train_test_split(
    df, test_size=0.30, random_state=SEED, stratify=df['label']
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.50, random_state=SEED, stratify=temp_df['label']
)

print(f'Train: {len(train_df)}  |  Val: {len(val_df)}  |  Test: {len(test_df)}')

# Leakage check: an identical text present in two splits means the model is
# scored on something it was trained on (or tuned against), which inflates
# the reported F1. This only warns; the split itself is left unchanged.
_split_texts = {
    'train': set(train_df['text']),
    'val':   set(val_df['text']),
    'test':  set(test_df['text']),
}
for _a, _b in (('train', 'val'), ('train', 'test'), ('val', 'test')):
    _shared = len(_split_texts[_a] & _split_texts[_b])
    if _shared:
        print(f'WARNING: {_shared} identical texts in both {_a} and {_b} (possible leakage)')

## 4. Tokenizer & Dataset

In [None]:
# WordPiece tokenizer matching the pretrained checkpoint named in MODEL_NAME.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

class SarcasmDataset(Dataset):
    """Map-style dataset that tokenizes one (text, label) pair per access.

    Args:
        texts: pandas Series of raw texts; its original index is discarded.
        labels: pandas Series of class labels, aligned row-for-row with
            ``texts``; each value is converted with ``int``.
        tokenizer: tokenizer exposing ``encode_plus`` (a Hugging Face
            tokenizer in this notebook).
        max_len: fixed sequence length; every text is truncated or padded
            to exactly this many tokens.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        # Re-number rows from 0 so DataLoader indices map straight onto rows.
        self.texts = texts.reset_index(drop=True)
        self.labels = labels.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return self.texts.shape[0]

    def __getitem__(self, idx):
        label_value = int(self.labels.loc[idx])
        encoded = self.tokenizer.encode_plus(
            str(self.texts.loc[idx]),
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension of size 1; drop
        # it so the DataLoader can stack samples along a new batch axis.
        sample = {key: encoded[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(label_value, dtype=torch.long)
        return sample

# One dataset per split, all sharing the same tokenizer and sequence length.
train_dataset, val_dataset, test_dataset = (
    SarcasmDataset(split['text'], split['label'], tokenizer, MAX_LEN)
    for split in (train_df, val_df, test_df)
)

# Only the training loader shuffles; validation and test are read in order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE) for ds in (val_dataset, test_dataset)
)

print(f'Batches  →  train: {len(train_loader)}, val: {len(val_loader)}, test: {len(test_loader)}')

## 5. Model, Optimizer & Scheduler

In [None]:
# Pretrained mBERT encoder with a newly initialised 2-way classification head.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=2
).to(device)

# Apply weight decay only to weight matrices. Biases and LayerNorm weights
# are excluded, as in the original BERT fine-tuning recipe and the Hugging
# Face Trainer default; decaying them is generally considered unhelpful.
_NO_DECAY = ('bias', 'LayerNorm.weight')
_decay_params, _no_decay_params = [], []
for _name, _param in model.named_parameters():
    if any(marker in _name for marker in _NO_DECAY):
        _no_decay_params.append(_param)
    else:
        _decay_params.append(_param)

optimizer = torch.optim.AdamW(
    [
        {'params': _decay_params, 'weight_decay': 0.01},
        {'params': _no_decay_params, 'weight_decay': 0.0},
    ],
    lr=LR,
)

# Linear warmup over the first 10% of steps, then linear decay to 0.
# The scheduler is stepped once per batch, so steps = batches * epochs.
total_steps = len(train_loader) * EPOCHS
scheduler   = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps,
)

print(f'Total training steps: {total_steps}')

## 6. Training Loop

In [None]:
def train_epoch(model, loader, optimizer, scheduler, device):
    """Run one optimisation pass over ``loader`` and return training metrics.

    Each batch goes through a forward pass with labels, backprop,
    gradient-norm clipping at 1.0, an optimizer step and a scheduler step.
    The scheduler therefore advances per batch, not per epoch.

    Args:
        model: classification model whose forward accepts ``labels=`` and
            returns an object with ``.loss`` and ``.logits``.
        loader: iterable of dicts holding 'input_ids', 'attention_mask'
            and 'labels' tensors.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per batch.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, macro_f1)``. ``avg_loss`` is the mean
        loss per sample. Accuracy and F1 are computed from predictions made
        while the model is in train mode, so they are noisier than a
        separate evaluation pass.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    preds_all, labels_all = [], []

    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attention_mask'].to(device)
        labels    = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
        loss    = outputs.loss
        logits  = outputs.logits

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        # outputs.loss is a mean over the batch (the HF classification head's
        # default CrossEntropyLoss). Weight it by batch size so a smaller
        # final batch is not over-counted in the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
        preds_all.extend(torch.argmax(logits, dim=1).cpu().numpy())
        labels_all.extend(labels.cpu().numpy())

    avg_loss = total_loss / n_samples
    acc  = accuracy_score(labels_all, preds_all)
    f1   = f1_score(labels_all, preds_all, average='macro')
    return avg_loss, acc, f1


def eval_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without updating its weights.

    Args:
        model: classification model whose forward accepts ``labels=`` and
            returns an object with ``.loss`` and ``.logits``.
        loader: iterable of dicts holding 'input_ids', 'attention_mask'
            and 'labels' tensors.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, macro_f1, preds, labels)``.
        ``avg_loss`` is the mean loss per sample. ``preds`` and ``labels``
        are flat lists of per-example predicted and true class ids, in
        loader order.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    preds_all, labels_all = [], []

    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attention_mask'].to(device)
            labels    = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
            # Weight the batch-mean loss by batch size so a smaller final
            # batch is not over-counted in the average.
            batch_size = labels.size(0)
            total_loss += outputs.loss.item() * batch_size
            n_samples += batch_size
            preds_all.extend(torch.argmax(outputs.logits, dim=1).cpu().numpy())
            labels_all.extend(labels.cpu().numpy())

    avg_loss = total_loss / n_samples
    acc  = accuracy_score(labels_all, preds_all)
    f1   = f1_score(labels_all, preds_all, average='macro')
    return avg_loss, acc, f1, preds_all, labels_all

In [None]:
history = {'train_loss': [], 'val_loss': [], 'train_f1': [], 'val_f1': []}
# Start below any attainable F1 so the first epoch always writes a checkpoint.
# With 0.0, a run whose val F1 never exceeded 0 would leave SAVE_DIR empty (or
# holding a stale model from an earlier run) for the test-set cell to load.
best_val_f1 = float('-inf')

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc, train_f1 = train_epoch(
        model, train_loader, optimizer, scheduler, device
    )
    val_loss, val_acc, val_f1, _, _ = eval_epoch(
        model, val_loader, device
    )

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_f1'].append(train_f1)
    history['val_f1'].append(val_f1)

    print(
        f'Epoch {epoch}/{EPOCHS}  |  '
        f'Train Loss: {train_loss:.4f}  Acc: {train_acc:.4f}  F1: {train_f1:.4f}  |  '
        f'Val Loss: {val_loss:.4f}  Acc: {val_acc:.4f}  F1: {val_f1:.4f}'
    )

    # Keep only the checkpoint with the best validation macro-F1; the test
    # cell reloads it from SAVE_DIR.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        os.makedirs(SAVE_DIR, exist_ok=True)
        model.save_pretrained(SAVE_DIR)
        tokenizer.save_pretrained(SAVE_DIR)
        print(f'  ✓ Best model saved (val F1: {best_val_f1:.4f})')

print(f'\nTraining complete. Best validation F1: {best_val_f1:.4f}')

## 7. Training Curves

In [None]:
# The training log numbers epochs from 1, so plot against the same 1-based
# axis rather than the list index (which would start at 0).
epoch_axis = list(range(1, len(history['train_loss']) + 1))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(epoch_axis, history['train_loss'], label='Train Loss')
axes[0].plot(epoch_axis, history['val_loss'],   label='Val Loss')
axes[0].set_title('Loss over Epochs')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()

axes[1].plot(epoch_axis, history['train_f1'], label='Train F1')
axes[1].plot(epoch_axis, history['val_f1'],   label='Val F1')
axes[1].set_title('Macro F1 over Epochs')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('F1')
axes[1].legend()

# Epochs are discrete, so show one integer tick per epoch.
for ax in axes:
    ax.set_xticks(epoch_axis)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

## 8. Evaluate on Test Set

In [None]:
# Reload the checkpoint with the best validation macro-F1 written by the
# training loop, then score it once on the held-out test split.
best_model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(device)

test_loss, test_acc, test_f1, test_preds, test_labels = eval_epoch(
    best_model, test_loader, device
)

print('═' * 60)
print('  TEST RESULTS')
print(f'  Loss:     {test_loss:.4f}')
print(f'  Accuracy: {test_acc:.4f}')
print(f'  Macro F1: {test_f1:.4f}')
print('═' * 60)
print()

# Pin the label set so target_names always line up with classes 0/1 and the
# confusion matrix stays 2x2, even if a class is absent from the predictions
# or labels.
class_ids = [0, 1]
print(classification_report(
    test_labels, test_preds,
    labels=class_ids,
    target_names=['Not Sarcastic (0)', 'Sarcastic (1)'],
    digits=4,
    zero_division=0,
))
print('Confusion Matrix:')
print(confusion_matrix(test_labels, test_preds, labels=class_ids))

## 9. Quick Inference Demo

In [None]:
def predict(text, model, tokenizer, device, max_len=MAX_LEN):
    """Classify one text and report the winning class with its probability.

    Args:
        text: raw input string.
        model: classification model returning an object with ``.logits``.
        tokenizer: tokenizer exposing ``encode_plus``.
        device: device the encoded tensors are moved to.
        max_len: sequence length to truncate or pad to.

    Returns:
        Dict with keys 'text' (the input), 'label' ('Sarcastic' for class 1,
        otherwise 'Not Sarcastic') and 'confidence' (the softmax probability
        of the predicted class, formatted as a percentage string).
    """
    model.eval()
    encoded = tokenizer.encode_plus(
        text,
        max_length=max_len,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )

    with torch.no_grad():
        logits = model(
            encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
        ).logits

    # Single-example batch: take row 0 of the softmaxed logits.
    class_probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    winner = int(class_probs.argmax())
    if winner == 1:
        label_name = 'Sarcastic'
    else:
        label_name = 'Not Sarcastic'
    return {
        'text': text,
        'label': label_name,
        'confidence': f'{class_probs[winner]:.2%}',
    }

# Hand-written Hinglish / English sentences for a quick qualitative check.
examples = [
    "Haan bilkul, sab kuch perfect chal raha hai",
    "Aaj bahut productive din tha, sab kaam complete ho gaya",
    "Oh great, another selfie with your car",
    "Cricket News Dear Virat sir aaj bhi aap pichle match wali form jaari Rakhe",
]

print('\n─── Inference Demo ───')
for sentence in examples:
    outcome = predict(sentence, best_model, tokenizer, device)
    label, confidence, shown_text = outcome['label'], outcome['confidence'], outcome['text']
    print(f"  {label:>15} ({confidence})  │  {shown_text}")