# In [None]:
import random
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
import pandas as pd
from sklearn.metrics import f1_score, classification_report
from tqdm import tqdm
import gc


# config
# Hugging Face checkpoint loaded as tokenizer and (frozen) encoder.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
# Every text is truncated / padded to exactly this many tokens.
MAX_LEN = 128
# Mini-batch size shared by the train and eval DataLoaders.
BATCH_SIZE = 8
# Number of passes over the training split.
EPOCHS = 3
# AdamW learning rate for the classification head.
LR = 1e-4
# Prefer the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU + CUDA) random generators.

    ``PYTHONHASHSEED`` is also written to the environment; Python reads that
    variable at interpreter start-up, so it only influences processes
    launched after this call.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


# Fix the seed once at import time so runs are repeatable.
set_seed(2026)

# dataset
class SarcasmDataset(Dataset):
    """Map-style dataset that tokenises one text per item for binary classification.

    Args:
        df: DataFrame providing a ``"text"`` column and a ``"label"`` column.
        tokenizer: Callable tokenizer exposing ``pad_token`` and ``eos_token``
            attributes. If it has no pad token, EOS is assigned as pad token
            (the tokenizer object is modified in place in that case only).
    """

    def __init__(self, df, tokenizer):
        self.df = df
        self.texts = self.df["text"].tolist()
        self.labels = self.df["label"].tolist()
        # ``padding="max_length"`` needs a pad token. Only fall back to EOS
        # when the tokenizer does not define one, so a caller's dedicated pad
        # token is never silently overwritten.
        if getattr(tokenizer, "pad_token", None) is None:
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenise the ``idx``-th text and attach its label.

        Returns:
            A dict holding every tensor produced by the tokenizer (with the
            leading batch dimension of size 1 removed) plus ``"labels"``, a
            float scalar tensor.
        """
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        # The tokenizer returns tensors shaped [1, ...]; drop that batch axis
        # so the DataLoader can stack items itself.
        item = {k: v.squeeze(0) for k, v in enc.items()}
        # Float labels, as the BCE-with-logits loss expects float targets.
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item


# model
class SarcasmClassifier(nn.Module):
    """Two-layer MLP head producing one logit per input embedding.

    Layout: Linear(hidden_size -> 512), ReLU, Dropout(0.2), Linear(512 -> 1).

    Args:
        hidden_size: Size of the last dimension of the input embeddings.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Built in this exact order so the parameter names stay
        # ``net.0.*`` / ``net.3.*`` in the state_dict.
        layers = [
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits with the trailing size-1 dimension removed."""
        scores = self.net(x)
        return scores.squeeze(dim=-1)

class SarcasmModel(nn.Module):
    """Frozen encoder followed by a trainable binary classification head.

    The encoder is always run without gradients; its last hidden state is
    mean-pooled over the non-padding tokens and fed to ``SarcasmClassifier``.

    Args:
        encoder: Module called as ``encoder(input_ids=..., attention_mask=...,
            output_hidden_states=True)`` whose result exposes ``hidden_states``.
        hidden_size: Width of the encoder's hidden states.
        pos_weight: Tensor passed to ``nn.BCEWithLogitsLoss(pos_weight=...)``.
    """

    def __init__(self, encoder, hidden_size, pos_weight):
        super().__init__()
        self.encoder = encoder
        self.classifier = SarcasmClassifier(hidden_size)
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    def train(self, mode=True):
        """Switch the head's mode while keeping the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        re-enable training-only behaviour (e.g. dropout) inside the encoder.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``(logits, loss)``; ``loss`` is ``None`` when no labels are given."""
        with torch.no_grad():
            out = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            # Pool in float32: the encoder may run in reduced precision.
            h = out.hidden_states[-1].float()     # [B,T,D]

            # Mean pooling over the positions selected by the attention mask.
            mask = attention_mask.unsqueeze(-1).to(h.dtype)
            summed = (h * mask).sum(dim=1)
            # Clamp so a row with no attended tokens yields zeros, not NaN.
            counts = mask.sum(dim=1).clamp(min=1.0)
            emb = summed / counts

        logits = self.classifier(emb)

        loss = None
        if labels is not None:
            # BCE-with-logits needs floating-point targets.
            loss = self.loss_fn(logits, labels.to(logits.dtype))

        return logits, loss



# train function
def train_model(train_df, test_df, save_path):
    """Train a sarcasm classification head on a frozen encoder and evaluate it.

    Loads the ``MODEL_NAME`` tokenizer and encoder, freezes the encoder, fits
    only the classifier head for ``EPOCHS`` epochs on ``train_df``, scores it
    on ``test_df``, saves the head's ``state_dict`` to ``save_path`` and
    releases the model before returning.

    Args:
        train_df: DataFrame with ``"text"`` and ``"label"`` columns (labels
            are compared against 0 and 1 to compute the class weight).
        test_df: DataFrame with ``"text"``, ``"label"``, ``"variety"``,
            ``"source"`` and ``"task"`` columns; the first row's
            variety/source/task values tag the returned report.
        save_path: Destination passed to ``torch.save`` for the head weights.

    Returns:
        The transposed ``classification_report`` as a DataFrame with
        ``variety``, ``source`` and ``task`` columns added.

    Raises:
        ValueError: If ``train_df`` or ``test_df`` has no rows.
    """
    # Fail fast: an empty split would otherwise only blow up after the
    # multi-gigabyte encoder has been loaded.
    if len(train_df) == 0 or len(test_df) == 0:
        raise ValueError("train_df and test_df must both contain at least one row")

    print("\n=== Start Training  ===")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    encoder = AutoModel.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        # Everything on a single device: GPU 0 when available, else the CPU,
        # consistent with the DEVICE fallback used for the batches.
        device_map={"": 0 if DEVICE == "cuda" else "cpu"},
    )

    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad = False

    train_ds = SarcasmDataset(train_df, tokenizer)
    test_ds = SarcasmDataset(test_df, tokenizer)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

    # class imbalance weight: negatives/positives ratio, capped at 2.0
    labels = torch.tensor(train_ds.labels)
    N_pos = (labels == 1).sum().item()
    N_neg = (labels == 0).sum().item()
    # With no positive example the ratio is unbounded, so the cap applies.
    ratio = N_neg / N_pos if N_pos else float("inf")
    pos_weight = torch.tensor([min(ratio, 2.0)]).to(DEVICE)

    print(f"Pos weight: {pos_weight.item():.2f}")

    hidden_size = encoder.config.hidden_size
    model = SarcasmModel(encoder, hidden_size, pos_weight).to(DEVICE)

    # Only the head is optimised; the encoder stays frozen.
    optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=LR)

    # train loop
    for epoch in range(EPOCHS):
        model.train()
        # model.train() recurses into the encoder; put the frozen encoder
        # back in eval mode so it behaves the same in training and evaluation.
        encoder.eval()
        total_loss = 0

        for batch in tqdm(train_loader):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            logits, loss = model(input_ids, attention_mask, labels)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        print(f"Epoch {epoch+1} | loss = {total_loss/len(train_loader):.4f}")

    # eval
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)

            logits, _ = model(input_ids, attention_mask)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).int()

            all_preds.extend(preds.cpu().numpy())
            # Labels are only needed on the host for the metrics, so they are
            # taken straight from the batch rather than via the device.
            all_labels.extend(batch["labels"].numpy())

    f1 = f1_score(all_labels, all_preds)
    report_dict = classification_report(all_labels, all_preds, output_dict=True)

    print("labels:", [int(x) for x in all_labels])
    print("preds: ", [int(x) for x in all_preds])

    print("\nF1-score sarcasm:", f1)
    print("\nClassification report:\n", report_dict)

    report_df = pd.DataFrame(report_dict).transpose()
    # Tag every report row with the (first-row) metadata of the test split.
    report_df["variety"] = test_df["variety"].iloc[0]
    report_df["source"] = test_df["source"].iloc[0]
    report_df["task"] = test_df["task"].iloc[0]

    # save model (classification head only; the encoder is unchanged)
    torch.save(model.classifier.state_dict(), save_path)
    print(f"\nModel saved to {save_path}")

    # free GPU: drop the references, collect cycles first, and only then ask
    # the caching allocator to give the now-unused blocks back.
    del model
    del encoder
    del optimizer
    gc.collect()
    torch.cuda.empty_cache()

    return report_df


# main
if __name__ == "__main__":
    # Rows missing any of these fields cannot be filtered or trained on.
    required_cols = ['text', 'label', 'variety', 'source', 'task']

    train = pd.read_csv("/kaggle/input/besstie/train.csv")
    test = pd.read_csv("/kaggle/input/besstie/valid.csv")
    train = train.dropna(subset=required_cols)
    test = test.dropna(subset=required_cols)

    # One head per English variety, trained in this order: UK, IN, AU.
    runs = (
        ("en-UK", "/kaggle/working/sarcasm_head_uk.pt"),
        ("en-IN", "/kaggle/working/sarcasm_head_in.pt"),
        ("en-AU", "/kaggle/working/sarcasm_head_au.pt"),
    )

    all_reports = []
    for variety, head_path in runs:
        # Keep only the Reddit sarcasm rows of this variety in both splits.
        subsets = []
        for frame in (train, test):
            keep = (
                (frame['variety'] == variety)
                & (frame['source'] == "Reddit")
                & (frame['task'] == "Sarcasm")
            )
            subsets.append(frame[keep])
        train_part, test_part = subsets

        report = train_model(
            train_df=train_part,
            test_df=test_part,
            save_path=head_path,
        )
        all_reports.append(report)

    # Stack the per-variety reports and write them out as one CSV.
    final_report = pd.concat(all_reports)
    final_report.to_csv("/kaggle/working/mistral_head_reddit_sarcasm.csv", index=True)
