# In [None]:
# Task 4: Training Loop Implementation (BONUS)

'''
Assumptions:
We have two synthetic datasets:

task_a_data = List[Tuple[str, int]] (sentence, label)

task_b_data = List[Tuple[str, int]]

Each training step takes one batch from Task A and one batch from Task B and updates the model on the sum of their losses.

Loss: nn.CrossEntropyLoss for both heads.

Key Decisions and Insights:

Loss function: CrossEntropyLoss per task; it is the standard loss function for multi-class classification problems.

Optimizer: AdamW;

Reason: AdamW is a variant of the Adam optimizer with decoupled weight decay. Compared to standard Adam, it regularizes more effectively, which helps prevent overfitting and improves generalization,
especially in transformer models (like BERT, RoBERTa, MiniLM). Hugging Face and Google recommend AdamW over Adam for training/fine-tuning transformers.

Multi-tasking logic: parallel batches via zip. Each step pairs one Task A batch with one Task B batch, which keeps the loop simple and modular (note that zip stops at the shorter loader).

Model outputs: raw logits, which is what nn.CrossEntropyLoss expects (it applies log-softmax internally).

Metric handling: accuracy is computed separately for each task.

Freeze support: controlled by setting .requires_grad before training (only trainable parameters are passed to the optimizer), consistent with the Task 3 setup.

'''

# Dependencies
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import random
from typing import List, Tuple

# Dummy Dataset Class
class SimpleTextDataset(Dataset):
    """Map-style dataset wrapping an in-memory list of ``(sentence, label)`` pairs.

    Each item is unpacked into its two fields and handed back as a
    ``(sentence, label)`` tuple, so a malformed entry (not exactly two
    fields) fails at access time rather than inside the training loop.
    """

    def __init__(self, data: List[Tuple[str, int]]):
        # Stored by reference; the list is not copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, target = self.data[idx]
        return (text, target)

# Training Loop
def train_multitask_model(model: nn.Module,
                          task_a_loader: DataLoader,
                          task_b_loader: DataLoader,
                          num_epochs=5,
                          lr=2e-5,
                          device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Jointly train a two-headed model on Task A and Task B.

    Each optimisation step pairs one batch from ``task_a_loader`` with one
    batch from ``task_b_loader``. It runs the model once per task
    (``task='A'`` / ``task='B'``) and back-propagates the sum of the two
    cross-entropy losses. Because the loaders are combined with ``zip``, an
    epoch ends as soon as the shorter loader is exhausted. Any remaining
    batches of the longer loader are skipped in that epoch.

    Only parameters with ``requires_grad=True`` are handed to the optimizer,
    so layers frozen beforehand stay fixed.

    Args:
        model: Module called as ``model(sentences, task=...)``. It must
            return raw (un-normalised) logits accepted by
            ``nn.CrossEntropyLoss``.
        task_a_loader: Yields ``(sentences, labels)`` batches for Task A.
        task_b_loader: Yields ``(sentences, labels)`` batches for Task B.
        num_epochs: Number of passes over the paired loaders.
        lr: AdamW learning rate.
        device: Device the model and the label tensors are moved to.

    Returns:
        A list with one ``(avg_loss_a, avg_loss_b)`` tuple per epoch. Each
        value is the mean per-batch loss over the steps actually taken in
        that epoch, or ``nan`` if the epoch produced no steps.
    """
    model = model.to(device)

    # Only trainable parameters are optimised, so frozen layers stay fixed.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=lr)
    criterion = nn.CrossEntropyLoss()

    history = []
    for epoch in range(num_epochs):
        model.train()
        total_loss_a = 0.0
        total_loss_b = 0.0
        steps = 0

        # zip stops at the shorter loader, so count the steps actually taken
        # instead of assuming len(loader) batches were seen.
        for batch_a, batch_b in zip(task_a_loader, task_b_loader):
            optimizer.zero_grad()

            # --- Task A: Sentence Classification ---
            sentences_a, labels_a = batch_a
            logits_a = model(sentences_a, task='A')
            loss_a = criterion(logits_a, labels_a.to(device))

            # --- Task B: Sentiment Analysis ---
            sentences_b, labels_b = batch_b
            logits_b = model(sentences_b, task='B')
            loss_b = criterion(logits_b, labels_b.to(device))

            # A single joint update on the summed loss.
            loss = loss_a + loss_b
            loss.backward()
            optimizer.step()

            total_loss_a += loss_a.item()
            total_loss_b += loss_b.item()
            steps += 1

        if steps:
            avg_loss_a = total_loss_a / steps
            avg_loss_b = total_loss_b / steps
        else:
            # No paired batches (an empty loader): the average is undefined.
            avg_loss_a = avg_loss_b = float('nan')

        history.append((avg_loss_a, avg_loss_b))
        print(f"Epoch {epoch+1}/{num_epochs} | Loss A: {avg_loss_a:.4f} | Loss B: {avg_loss_b:.4f}")

    return history

# Example: Dummy Data and Execution

# Simulate small datasets
task_a_data = [
    ("This is about sports.", 0),
    ("Finance news", 1),
    ("Politics today", 2),
]
task_b_data = [
    ("I love this!", 0),
    ("Terrible experience", 1),
    ("It's okay.", 2),
]

# One shuffled loader (batch size 2) per task, built from the same recipe
task_a_loader, task_b_loader = (
    DataLoader(SimpleTextDataset(rows), batch_size=2, shuffle=True)
    for rows in (task_a_data, task_b_data)
)

# MultiTaskSentenceTransformer is defined outside this section (presumably
# the Task 3 cell); both heads get 3 classes to match the labels above.
model = MultiTaskSentenceTransformer(task_a_num_classes=3, task_b_num_classes=3)

# Mock training run on the toy data using the function's default settings
train_multitask_model(model, task_a_loader, task_b_loader)

# Evaluation Function (Accuracy)

def evaluate_model(model, data_loader, task='A', device='cpu'):
    """Compute, print and return the accuracy of one task head.

    Args:
        model: Module called as ``model(sentences, task=...)``. It returns
            logits with the class scores along dim 1.
        data_loader: Yields ``(sentences, labels)`` batches.
        task: Which head to evaluate. It is case-insensitive and is
            normalised to upper case, so ``'a'`` selects the same head as
            the ``'A'`` used during training.
        device: Device on which predictions and labels are compared.

    Returns:
        The fraction of correctly classified examples, or ``0.0`` if the
        loader yields no examples.
    """
    # Match the 'A'/'B' task names used by train_multitask_model.
    task = task.upper()

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for sentences, labels in data_loader:
                logits = model(sentences, task=task).to(device)
                predictions = torch.argmax(logits, dim=1)
                # Compare on one device so the result does not depend on
                # where the model happens to live.
                correct += (predictions == labels.to(device)).sum().item()
                total += labels.size(0)
    finally:
        # Restore the caller's mode so evaluating mid-training does not leave
        # mode-dependent layers (e.g. dropout) in inference mode.
        model.train(was_training)

    accuracy = correct / total if total > 0 else 0.0
    print(f"[Task {task}] Evaluation Accuracy: {accuracy:.2%}")
    return accuracy

# Evaluate each head using the same upper-case 'A'/'B' task names as training.
# NOTE: these are the training loaders, so the figures measure fit on the
# training data rather than generalisation to unseen examples.
evaluate_model(model, task_a_loader, task='A')
evaluate_model(model, task_b_loader, task='B')
