# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Download Data

In [None]:
# Download the aggregated dataset into the working directory; a later cell
# opens it by the name "aggregate.json".
!wget https://raw.githubusercontent.com/aliisharifi/NLP---Spring-1404/main/hw2/aggregate_data/aggregate.json

# Configs

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

In [None]:
# Hugging Face checkpoint used for the tokenizer, the config and the encoder.
model_name = "cis-lmu/glot500-base"

# Output width of the projection head placed on top of the encoder.
projection_size = 256

# Number of full passes over the training loader.
num_epochs = 5

# Training

In [None]:
# Both objects are resolved from the same checkpoint name so the tokenizer
# vocabulary and the model configuration stay in sync.
tokenizer, config = (
    AutoTokenizer.from_pretrained(model_name),
    AutoConfig.from_pretrained(model_name),
)

In [None]:
class ProjectionHead(nn.Module):
    """Two-layer MLP mapping encoder features to the projection space.

    The stack is Linear(hidden_size -> hidden_size), ReLU,
    Linear(hidden_size -> proj_size), stored under the ``proj`` attribute.

    Args:
        hidden_size: Width of the incoming feature vectors.
        proj_size: Width of the projected output vectors.
    """

    def __init__(self, hidden_size, proj_size=256):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, proj_size),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Return the projection of ``x`` through the MLP."""
        projected = self.proj(x)
        return projected

In [None]:
# Build the pretrained encoder and the projection head that sits on top of it.
base_encoder = AutoModel.from_pretrained(model_name, config=config)
proj_head = ProjectionHead(hidden_size=config.hidden_size, proj_size=projection_size)

# Move both modules to the selected device (Module.to returns the module itself).
base_encoder = base_encoder.to(device)
proj_head.to(device)

In [None]:
class QTDataset(Dataset):
    """Dataset of (context, question) pairs tokenized on access.

    Args:
        pairs: Indexable collection of mappings that each provide a
            ``"context"`` and a ``"question"`` entry.
        tokenizer: Callable tokenizer invoked with ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors`` keyword
            arguments; its result must expose ``"input_ids"`` and
            ``"attention_mask"``.
        max_len: Length every sequence is truncated / padded to.
    """

    def __init__(self, pairs, tokenizer, max_len=128):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def _encode(self, string):
        """Tokenize one string and return ``(input_ids, attention_mask)``."""
        encoded = self.tokenizer(
            string,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        # The tokenizer output carries a leading axis that is squeezed away
        # here so each item is returned without it.
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        record = self.pairs[idx]
        context, question = record["context"], record["question"]
        text_ids, text_mask = self._encode(context)
        ques_ids, ques_mask = self._encode(question)
        return {
            "text_ids": text_ids,
            "text_mask": text_mask,
            "ques_ids": ques_ids,
            "ques_mask": ques_mask,
        }

In [None]:
import json

# Read the downloaded dataset. The encoding is pinned to UTF-8 so that
# decoding does not depend on the platform's default locale encoding.
with open("aggregate.json", "r", encoding="utf-8") as f:
    train_pairs = json.load(f)

In [None]:
# Wrap the loaded pairs so every item is tokenized to a fixed length of 128.
train_dataset = QTDataset(pairs=train_pairs, tokenizer=tokenizer, max_len=128)

# Shuffled mini-batches of 64, loaded by two worker processes into pinned memory.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

In [None]:
def contrastive_loss(text_emb, ques_emb, temp=0.07):
    """Symmetric in-batch contrastive loss between texts and questions.

    Rows are L2-normalized along dim 1, so the similarity matrix holds cosine
    similarities. Row ``i`` of ``text_emb`` is treated as the match of row
    ``i`` of ``ques_emb``; every other row in the batch acts as a negative.

    Args:
        text_emb: 2-D tensor of text embeddings, one row per example.
        ques_emb: 2-D tensor of question embeddings, aligned row-wise with
            ``text_emb``.
        temp: Temperature the similarities are divided by.

    Returns:
        Scalar tensor: the mean of the text-to-question and the
        question-to-text cross-entropy losses.
    """
    similarity = F.normalize(text_emb, dim=1) @ F.normalize(ques_emb, dim=1).t()
    scaled = similarity / temp

    # The positive pair for row i sits on the diagonal, i.e. at column i.
    targets = torch.arange(text_emb.size(0), device=scaled.device)

    text_to_ques = F.cross_entropy(scaled, targets)
    ques_to_text = F.cross_entropy(scaled.t(), targets)
    return 0.5 * (text_to_ques + ques_to_text)

In [None]:
def _masked_mean(hidden_states, attention_mask):
    """Mean-pool token vectors over dim 1 while ignoring padded positions.

    Args:
        hidden_states: Encoder ``last_hidden_state``; dim 1 is the token axis.
        attention_mask: Mask aligned with the first two dims of
            ``hidden_states``; non-zero marks a real token, zero marks padding.

    Returns:
        One pooled vector per sequence.
    """
    weights = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    token_sum = (hidden_states * weights).sum(dim=1)
    # Clamp so a row without any real token cannot cause a division by zero.
    token_count = weights.sum(dim=1).clamp(min=1.0)
    return token_sum / token_count


# Encoder and projection head are optimized jointly.
params = list(base_encoder.parameters()) + list(proj_head.parameters())
optimizer = AdamW(params, lr=2e-5, weight_decay=0.01)

for epoch in range(1, num_epochs + 1):
    base_encoder.train()
    proj_head.train()
    total_loss = 0.0

    for batch in tqdm(train_loader):
        # The loader pins host memory, so the copies can be asynchronous.
        text_ids = batch["text_ids"].to(device, non_blocking=True)
        text_mask = batch["text_mask"].to(device, non_blocking=True)
        ques_ids = batch["ques_ids"].to(device, non_blocking=True)
        ques_mask = batch["ques_mask"].to(device, non_blocking=True)

        optimizer.zero_grad()

        # The same encoder embeds both the texts and the questions.
        out_text = base_encoder(
            input_ids=text_ids, attention_mask=text_mask
        ).last_hidden_state
        out_ques = base_encoder(
            input_ids=ques_ids, attention_mask=ques_mask
        ).last_hidden_state

        # Every sequence is padded to a fixed length, so a plain mean over the
        # token axis would be dominated by padding states for short inputs.
        # Pool over the real tokens only.
        text_vec = _masked_mean(out_text, text_mask)
        ques_vec = _masked_mean(out_ques, ques_mask)

        text_proj = proj_head(text_vec)
        ques_proj = proj_head(ques_vec)

        loss = contrastive_loss(text_proj, ques_proj, temp=0.05)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch} — Avg Loss: {avg_loss:.4f}")