In [None]:
"""
SemEval Task 9 - Subtask 2: SIMPLIFIED Multi-Label Classification
Target: macro F1 >= 0.65 for both languages

STRATEGY - BACK TO BASICS:
✓ Standard BCE loss (proven to work for multi-label)
✓ Positive class weighting per label
✓ More epochs with early stopping
✓ Per-label threshold search on the validation split
✓ No complex sampling - keep it simple
"""

import os, pandas as pd, numpy as np, torch, torch.nn as nn
import re, random, gc, warnings
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                         get_linear_schedule_with_warmup, set_seed)
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
# Ignore every Python warning for the rest of the run. Note that this also
# hides deprecation and data-related warnings, which can mask real problems.
warnings.filterwarnings('ignore')

# Colab-only: mount Google Drive at /content/drive so the CSV paths under
# Config.BASE_PATH are readable. This import fails outside Google Colab.
from google.colab import drive
drive.mount('/content/drive')
# Global seeding via transformers.set_seed (per its docs it seeds random,
# NumPy and torch). The literal should stay in sync with Config.SEED.
set_seed(42)

class Config:
    """Run-wide settings, used as a namespace of class attributes.

    The class itself (never an instance) is passed around and read as
    ``Config.X`` / ``config.X``. Defining the class creates ``OUTPUT_DIR`` and
    ``PREDICTIONS_DIR`` as a side effect.
    """

    BASE_PATH = '/content/drive/MyDrive/NLP'
    TRAIN_ENG = f'{BASE_PATH}/subtask2/train/eng.csv'
    TRAIN_SWA = f'{BASE_PATH}/subtask2/train/swa.csv'
    DEV_ENG = f'{BASE_PATH}/subtask2/dev/eng.csv'
    DEV_SWA = f'{BASE_PATH}/subtask2/dev/swa.csv'
    OUTPUT_DIR = '/content/subtask2/models'   # best_<lang>.pt checkpoints
    PREDICTIONS_DIR = '/content/subtask_2'    # pred_*.csv files that get zipped
    os.makedirs(PREDICTIONS_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    MODEL_ENG = 'microsoft/deberta-v3-base'
    MODEL_SWA = 'xlm-roberta-base'
    MAX_LENGTH = 128

    LABELS = ['political', 'racial/ethnic', 'religious', 'gender/sexual', 'other']
    # Derived from LABELS so the classifier head size can never drift out of
    # sync with the label list.
    NUM_LABELS = len(LABELS)

    # Simple training params
    BATCH_SIZE = 16
    GRAD_ACCUM = 2
    EPOCHS = 18
    LR_ENG = 2e-5
    LR_SWA = 2.5e-5
    WEIGHT_DECAY = 0.01
    WARMUP_RATIO = 0.15
    DROPOUT = 0.2
    MAX_GRAD_NORM = 1.0

    VAL_SIZE = 0.18
    USE_AUGMENTATION = False  # Disable augmentation - might add noise

    SEED = 42
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Mixed precision goes through torch.cuda.amp (autocast / GradScaler),
    # which only does anything on a CUDA device. On CPU it is switched off
    # here instead of having torch warn and silently disable it.
    USE_FP16 = torch.cuda.is_available()

print(f"Device: {Config.DEVICE}")

class TextPreprocessor:
    """Normalise raw social-media text before tokenisation.

    ``clean`` does the following, in order:

    * lower-cases the text;
    * replaces URLs with ``[url]`` and @-mentions with ``[user]``;
    * drops the ``#`` from hashtags;
    * shortens runs of 4+ identical characters to 2;
    * collapses repeated ``!``/``?``/``.`` to a single character;
    * squeezes whitespace.

    Missing values and empty results become ``"[empty]"``.
    """

    # Compiled once at class creation instead of on every call, because
    # clean() runs for every sample inside MultiLabelDataset.__getitem__.
    _URL_RE = re.compile(r'http\S+|www\.\S+')
    _MENTION_RE = re.compile(r'@\w+')
    _HASHTAG_RE = re.compile(r'#(\w+)')
    _ELONGATION_RE = re.compile(r'(.)\1{3,}')
    _PUNCT_RUN_RE = re.compile(r'([!?.]){2,}')
    _SPACE_RE = re.compile(r'\s+')

    @staticmethod
    def clean(text):
        """Return a normalised copy of ``text``; never returns an empty string.

        Args:
            text: raw text. Non-strings are converted with ``str()``. ``None``
                and NaN-like missing values (e.g. empty CSV cells read by
                pandas) map to ``"[empty]"`` rather than the literal strings
                ``"none"`` / ``"nan"``.

        Returns:
            str: the cleaned text, or ``"[empty]"`` if nothing remains.
        """
        if pd.api.types.is_scalar(text) and pd.isna(text):
            return "[empty]"

        cls = TextPreprocessor
        text = str(text).strip().lower()
        text = cls._URL_RE.sub(' [url] ', text)
        text = cls._MENTION_RE.sub(' [user] ', text)
        text = cls._HASHTAG_RE.sub(r' \1 ', text)
        text = cls._ELONGATION_RE.sub(r'\1\1', text)
        text = cls._PUNCT_RUN_RE.sub(r'\1', text)
        text = cls._SPACE_RE.sub(' ', text).strip()
        return text if text else "[empty]"

class WeightedBCELoss(nn.Module):
    """Binary cross-entropy on logits with one positive-class weight per label.

    Args:
        pos_weights: sequence holding one weight per label. Each label's
            positive term is multiplied by its weight (the ``pos_weight``
            argument of ``binary_cross_entropy_with_logits``).
    """

    def __init__(self, pos_weights):
        super().__init__()
        # Kept as a plain tensor attribute (not a registered buffer) and moved
        # to the logits' device on each forward call.
        self.pos_weights = torch.tensor(pos_weights, dtype=torch.float32)

    def forward(self, logits, targets):
        """Return the element-wise weighted BCE averaged over the whole batch."""
        weight_here = self.pos_weights.to(logits.device)
        return nn.functional.binary_cross_entropy_with_logits(
            logits,
            targets,
            pos_weight=weight_here,
            reduction='mean',
        )

class MultiLabelModel(nn.Module):
    """Thin wrapper around a Hugging Face sequence-classification model.

    The backbone is loaded with a ``num_labels``-way head and
    ``problem_type="multi_label_classification"``. The single ``dropout``
    value is used for both the hidden and the attention dropout config keys.
    """

    def __init__(self, model_name, num_labels=5, dropout=0.2):
        super().__init__()
        config_overrides = dict(
            num_labels=num_labels,
            problem_type="multi_label_classification",
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
        )
        self.transformer = AutoModelForSequenceClassification.from_pretrained(
            model_name, **config_overrides
        )

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model. Callers read ``.logits`` from the returned object."""
        return self.transformer(
            attention_mask=attention_mask,
            input_ids=input_ids,
        )

class MultiLabelDataset(Dataset):
    """Tokenised multi-label dataset whose targets come from a DataFrame.

    Args:
        texts: indexable sequence of texts; each item is passed through
            ``TextPreprocessor.clean`` before tokenisation.
        labels: DataFrame containing one column per entry of ``label_names``.
            Rows are matched to ``texts`` by position.
        tokenizer: Hugging Face tokenizer (called with ``return_tensors='pt'``).
        max_len: pad/truncate length in tokens.
        label_names: label column names, in model-output order.
    """

    def __init__(self, texts, labels, tokenizer, max_len, label_names):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_names = label_names
        self.prep = TextPreprocessor()
        # Extract the (n_rows, n_labels) float32 label matrix once. The
        # previous per-item ``DataFrame.iloc[idx][cols]`` lookup is slow and
        # ran for every sample on every epoch.
        self.label_matrix = labels[list(label_names)].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` (1-D, ``max_len``) and float ``labels``."""
        text = self.prep.clean(self.texts[idx])

        enc = self.tokenizer(text, max_length=self.max_len, padding='max_length',
                             truncation=True, return_tensors='pt')

        return {
            'input_ids': enc['input_ids'].squeeze(0),
            'attention_mask': enc['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.label_matrix[idx], dtype=torch.float32),
        }

def compute_pos_weights(df, label_names, max_weight=10.0):
    """Compute per-label ``pos_weight`` values for BCE-with-logits.

    A label's weight is ``neg / pos`` (negative rows over positive rows),
    capped at ``max_weight`` so very rare labels do not dominate the loss.
    Labels with no positive rows get a neutral weight of 1.0.

    Args:
        df: DataFrame with one 0/1 column per label (int or float dtype).
        label_names: label columns to inspect, in model-output order.
        max_weight: upper bound for any weight (default 10.0).

    Returns:
        list[float]: one plain Python float per entry of ``label_names``.
    """
    pos_weights = []

    print("  Computing pos_weights:")
    for label in label_names:
        pos = df[label].sum()
        neg = len(df) - pos

        weight = 1.0 if pos == 0 else float(min(float(max_weight), neg / pos))
        pos_weights.append(weight)
        # ``.0f`` prints integers exactly like ``d`` but does not raise for
        # float-typed label columns (e.g. 1.0 / 0.0).
        print(f"    {label:20s}: pos={pos:5.0f}, neg={neg:5.0f}, weight={weight:.2f}")

    return pos_weights

def load_data(path, label_names):
    """Load a labelled CSV, clean its texts and print the label distribution.

    Rows with a missing ``text`` value, or whose text is empty after
    cleaning, are dropped.

    Args:
        path: CSV path. The file must have a ``text`` column and one column
            per entry of ``label_names``.
        label_names: label columns to summarise.

    Returns:
        pandas.DataFrame: the cleaned frame with a fresh 0..n-1 index.
    """
    df = pd.read_csv(path)
    # Drop missing texts up front; str(NaN) would otherwise turn into "nan".
    df = df.dropna(subset=['text'])
    df['text'] = df['text'].apply(TextPreprocessor.clean)
    # TextPreprocessor.clean maps empty input to the "[empty]" placeholder,
    # so a plain length check can never remove anything; filter on it instead.
    df = df[df['text'] != "[empty]"].reset_index(drop=True)

    n_rows = len(df)
    print(f"  Samples: {n_rows}")
    print(f"  Label Distribution:")

    for label in label_names:
        count = df[label].sum()
        pct = count / n_rows * 100 if n_rows else 0.0
        # ``.0f`` matches ``d`` for ints and also accepts float label columns.
        print(f"    {label:20s}: {count:5.0f} ({pct:5.2f}%)")

    return df

def _collect_probs_and_labels(model, loader, device, use_fp16=False):
    """Run ``model`` over ``loader``; return (sigmoid probs, labels) as float32 arrays."""
    model.eval()
    prob_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="  Probs", leave=False):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            with autocast(enabled=bool(use_fp16)):
                out = model(ids, mask)
                probs = torch.sigmoid(out.logits)

            prob_chunks.append(probs.cpu().numpy())
            label_chunks.append(batch['labels'].numpy())

    return (np.vstack(prob_chunks).astype(np.float32),
            np.vstack(label_chunks).astype(np.float32))


def find_optimal_thresholds(model, loader, device, label_names, use_fp16=False,
                            min_positives=3, grid=None):
    """Choose, per label, the decision threshold that maximises binary F1.

    Makes one inference pass over ``loader``. For each label it then scans the
    candidate thresholds (default ``np.linspace(0.1, 0.9, 81)``) and keeps the
    first one reaching the highest F1. A label falls back to 0.5 when it has
    fewer than ``min_positives`` positives or when no candidate gives F1 > 0.

    Args:
        model: module returning an object with ``.logits`` of shape
            (batch, n_labels).
        loader: yields dicts with ``input_ids``, ``attention_mask``, ``labels``.
        device: device the inputs are moved to.
        label_names: label names, in model-output order (used for printing).
        use_fp16: run the forward pass under ``torch.cuda.amp.autocast``.
        min_positives: minimum positive count needed to search a label.
        grid: optional iterable of candidate thresholds.

    Returns:
        list[float]: one plain Python float per label. NumPy scalars are not
        on torch's ``weights_only`` allow-list by default, so plain floats keep
        checkpoints that store these thresholds loadable with
        ``weights_only=True``.
    """
    print("  Finding thresholds...")
    all_probs, all_labels = _collect_probs_and_labels(model, loader, device, use_fp16)
    candidates = (np.linspace(0.1, 0.9, 81) if grid is None
                  else np.asarray(list(grid), dtype=np.float64))

    thresholds = []
    for i, label in enumerate(label_names):
        label_probs = all_probs[:, i]
        label_true = all_labels[:, i]

        pos_count = int(label_true.sum())

        if pos_count < min_positives:
            print(f"    {label:20s}: too few ({pos_count}), using 0.5")
            thresholds.append(0.5)
            continue

        best_t, best_f1 = 0.5, 0.0
        for t in candidates:
            preds = (label_probs >= t).astype(np.float32)
            f1 = f1_score(label_true, preds, average='binary', zero_division=0)
            if f1 > best_f1:  # strict '>' keeps the lowest threshold on ties
                best_f1, best_t = f1, float(t)

        thresholds.append(float(best_t))
        print(f"    {label:20s}: t={best_t:.3f}, F1={best_f1:.4f}")

    return thresholds

def train_epoch(model, loader, opt, sched, crit, device, grad_accum, scaler=None):
    """Train ``model`` for one epoch with gradient accumulation.

    Gradients are accumulated over ``grad_accum`` batches before each
    clip / optimizer / scheduler step. The last, possibly partial, group of
    the epoch is stepped too. Otherwise its gradients would be discarded by
    the ``zero_grad`` at the start of the next epoch.

    When ``scaler`` (a ``GradScaler``) is given, the forward pass runs under
    autocast and backward uses loss scaling. Gradients are unscaled before
    clipping.

    Returns:
        tuple[float, float]: mean per-batch loss (undoing the ``/grad_accum``)
        and the macro F1 of 0.5-thresholded training predictions.
    """
    model.train()
    use_amp = scaler is not None
    num_batches = len(loader)
    total_loss = 0.0
    all_preds = []
    all_labels = []

    opt.zero_grad()

    for step, batch in enumerate(tqdm(loader, desc="  Train", leave=False)):
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        with autocast(enabled=use_amp):
            out = model(ids, mask)
            loss = crit(out.logits, labels) / grad_accum

        total_loss += loss.item() * grad_accum
        with torch.no_grad():
            probs = torch.sigmoid(out.logits)
        all_preds.append((probs >= 0.5).long().cpu().numpy())
        all_labels.append(labels.cpu().numpy())

        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step after each full accumulation group and after the final batch.
        if (step + 1) % grad_accum == 0 or step + 1 == num_batches:
            if use_amp:
                scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), Config.MAX_GRAD_NORM)
            if use_amp:
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()
            sched.step()
            opt.zero_grad()

    all_preds = np.vstack(all_preds).astype(np.float32)
    all_labels = np.vstack(all_labels).astype(np.float32)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    return total_loss / len(loader), f1_macro

def evaluate(model, loader, crit, device, thresholds, label_names, use_fp16=False, show_report=True):
    """Compute loss and thresholded macro F1 of ``model`` over ``loader``.

    Label column ``i`` is predicted positive when its sigmoid probability is
    ``>= thresholds[i]``. Columns without a threshold entry stay 0.

    Args:
        model: module returning an object with ``.logits``.
        loader: yields dicts with ``input_ids``, ``attention_mask``, ``labels``.
        crit: loss callable ``crit(logits, targets)``.
        device: device the batch tensors are moved to.
        thresholds: per-label decision thresholds.
        label_names: label names used in the printed report.
        use_fp16: run the forward pass under ``torch.cuda.amp.autocast``.
        show_report: print per-label F1 and the macro average.

    Returns:
        tuple[float, float]: mean per-batch loss and macro F1.
    """
    model.eval()
    loss_sum = 0.0
    prob_parts = []
    target_parts = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="  Eval", leave=False):
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            with autocast(enabled=bool(use_fp16)):
                outputs = model(input_ids, attn_mask)
                batch_loss = crit(outputs.logits, targets)
                batch_probs = torch.sigmoid(outputs.logits)

            loss_sum += batch_loss.item()
            prob_parts.append(batch_probs.cpu().numpy())
            target_parts.append(targets.cpu().numpy())

    probs = np.vstack(prob_parts).astype(np.float32)
    truth = np.vstack(target_parts).astype(np.float32)

    preds = np.zeros_like(probs, dtype=np.float32)
    for col, cutoff in enumerate(thresholds):
        preds[:, col] = probs[:, col] >= cutoff

    macro = f1_score(truth, preds, average='macro', zero_division=0)

    if show_report:
        print("\n  Per-label F1:")
        for col, name in enumerate(label_names):
            label_f1 = f1_score(truth[:, col], preds[:, col], average='binary', zero_division=0)
            print(f"    {name:20s}: {label_f1:.4f}")
        print(f"  Macro: {macro:.4f}")

    return loss_sum / len(loader), macro

def train_model(train_df, lang, model_name, lr, config):
    """Fine-tune one multi-label classifier and return its best state.

    ``train_df`` is split into train/validation parts (``config.VAL_SIZE``).
    From epoch index 5 onwards, every epoch does the following:

    * tunes per-label thresholds on the validation split;
    * re-scores the validation split with those thresholds;
    * checkpoints to ``{config.OUTPUT_DIR}/best_{lang}.pt`` whenever the tuned
      macro F1 improves.

    Thresholds are tuned and scored on the same split, so the reported F1 is
    optimistic.

    Args:
        train_df: labelled DataFrame (``text`` plus one column per label).
        lang: name used in logs and in the checkpoint filename.
        model_name: Hugging Face model id for tokenizer and backbone.
        lr: AdamW learning rate.
        config: the ``Config`` class (read for all other hyper-parameters).

    Returns:
        tuple: ``(model, tokenizer, best_f1, best_thresholds)``. ``model``
        carries the best checkpoint's weights.
    """
    print(f"\n{'='*70}")
    print(f"TRAINING: {lang.upper()}")
    print(f"{'='*70}")

    # Split first so pos_weights are estimated from the training rows only;
    # otherwise validation labels leak into the loss weighting.
    train_data, val_data = train_test_split(
        train_df, test_size=config.VAL_SIZE, random_state=config.SEED
    )

    pos_weights = compute_pos_weights(train_data, config.LABELS)

    print(f"\n  Train: {len(train_data)}, Val: {len(val_data)}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    train_ds = MultiLabelDataset(train_data['text'].values, train_data, tokenizer,
                                 config.MAX_LENGTH, config.LABELS)
    val_ds = MultiLabelDataset(val_data['text'].values, val_data, tokenizer,
                               config.MAX_LENGTH, config.LABELS)

    train_loader = DataLoader(train_ds, config.BATCH_SIZE, shuffle=True,
                              num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, config.BATCH_SIZE * 2, shuffle=False,
                            num_workers=2, pin_memory=True)

    model = MultiLabelModel(model_name, config.NUM_LABELS, config.DROPOUT).to(config.DEVICE)

    opt = torch.optim.AdamW(model.parameters(), lr=lr,
                            weight_decay=config.WEIGHT_DECAY, eps=1e-8)

    # Optimizer steps per epoch, rounded up so that a trailing partial
    # accumulation group is still covered by the LR schedule.
    steps_per_epoch = (len(train_loader) + config.GRAD_ACCUM - 1) // config.GRAD_ACCUM
    steps = steps_per_epoch * config.EPOCHS
    warmup = int(steps * config.WARMUP_RATIO)
    sched = get_linear_schedule_with_warmup(opt, warmup, steps)

    crit = WeightedBCELoss(pos_weights)
    scaler = GradScaler() if config.USE_FP16 else None

    ckpt_path = f"{config.OUTPUT_DIR}/best_{lang}.pt"
    # Start below any achievable F1 so the first tuned epoch always writes a
    # checkpoint. With 0.0, an all-zero F1 would leave nothing to load and
    # the load below would crash or pick up a stale file from an earlier run.
    best_f1 = -1.0
    best_thresholds = [0.5] * config.NUM_LABELS
    saved_this_run = False
    patience, p_cnt = 5, 0
    val_f1 = 0.0

    for ep in range(config.EPOCHS):
        print(f"\n[Epoch {ep+1}/{config.EPOCHS}]")

        tr_loss, tr_f1 = train_epoch(model, train_loader, opt, sched, crit,
                                     config.DEVICE, config.GRAD_ACCUM, scaler)

        val_loss, val_f1 = evaluate(model, val_loader, crit, config.DEVICE,
                                    [0.5] * config.NUM_LABELS, config.LABELS,
                                    config.USE_FP16, show_report=False)

        print(f"  Train: Loss={tr_loss:.4f}, F1={tr_f1:.4f}")
        print(f"  Val:   Loss={val_loss:.4f}, F1={val_f1:.4f}")

        # Threshold search from epoch 5
        if ep >= 5:
            # Plain floats keep the checkpoint loadable with weights_only=True.
            thresholds = [float(t) for t in find_optimal_thresholds(
                model, val_loader, config.DEVICE, config.LABELS, config.USE_FP16)]
            _, val_f1 = evaluate(model, val_loader, crit, config.DEVICE, thresholds,
                                 config.LABELS, config.USE_FP16,
                                 show_report=(ep >= config.EPOCHS - 2))

            if val_f1 > best_f1:
                best_f1, best_thresholds, p_cnt = float(val_f1), thresholds, 0
                torch.save({
                    'model': model.state_dict(),
                    'thresholds': thresholds,
                    'f1': float(val_f1)
                }, ckpt_path, _use_new_zipfile_serialization=True)
                saved_this_run = True
                print(f"  ✓ Saved (F1={best_f1:.4f})")
            else:
                p_cnt += 1
                print(f"  No improvement ({p_cnt}/{patience})")

        if ep >= 8 and p_cnt >= patience:
            print("  Early stopping")
            break

    if saved_this_run:
        # Restore the best epoch; training continued past it.
        ckpt = torch.load(ckpt_path, map_location=config.DEVICE, weights_only=True)
        model.load_state_dict(ckpt['model'])
        best_f1 = ckpt['f1']
        best_thresholds = ckpt['thresholds']
    else:
        # Threshold search never ran (config.EPOCHS <= 5), so this run wrote
        # no checkpoint: keep the final weights with 0.5 thresholds.
        print("  No checkpoint saved this run; keeping final weights")
        best_f1 = float(val_f1)

    print(f"\n{'='*70}")
    print(f"FINAL: F1-Macro={best_f1:.4f}")
    print(f"{'='*70}\n")

    return model, tokenizer, best_f1, best_thresholds

def predict(model, tokenizer, test_file, out_file, thresholds, label_names, config):
    """Predict labels for ``test_file`` and write them to ``out_file`` as CSV.

    ``test_file`` must have ``id`` and ``text`` columns. Texts are cleaned and
    scored with ``model``. Label ``i`` is set to 1 when its sigmoid
    probability is ``>= thresholds[i]``. The output CSV holds ``id`` followed
    by one 0/1 column per entry of ``label_names``, and a per-label summary
    is printed.
    """
    print(f"\nPredicting: {test_file}")

    df = pd.read_csv(test_file)
    df['text'] = df['text'].apply(TextPreprocessor.clean)

    # MultiLabelDataset reads one column per label; fill zero placeholders
    # (their values are not used for prediction).
    for label in label_names:
        df[label] = 0

    ds = MultiLabelDataset(df['text'].values, df, tokenizer, config.MAX_LENGTH, label_names)
    loader = DataLoader(ds, config.BATCH_SIZE * 2, shuffle=False, num_workers=2, pin_memory=True)

    model.eval()
    prob_batches = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="  Predict", leave=False):
            ids = batch['input_ids'].to(config.DEVICE)
            mask = batch['attention_mask'].to(config.DEVICE)

            with autocast(enabled=bool(config.USE_FP16)):
                out = model(ids, mask)
                probs = torch.sigmoid(out.logits)

            prob_batches.append(probs.cpu().numpy())

    # np.vstack raises on an empty list, which a header-only CSV produces.
    if prob_batches:
        all_probs = np.vstack(prob_batches)
    else:
        all_probs = np.zeros((0, len(label_names)), dtype=np.float32)

    predictions = np.zeros_like(all_probs, dtype=int)
    for i, thresh in enumerate(thresholds):
        predictions[:, i] = (all_probs[:, i] >= thresh).astype(int)

    out_df = pd.DataFrame({'id': df['id']})
    for i, label in enumerate(label_names):
        out_df[label] = predictions[:, i]

    out_df.to_csv(out_file, index=False)

    print(f"✓ Saved: {out_file}")
    n_rows = len(predictions)
    for i, label in enumerate(label_names):
        count = int(predictions[:, i].sum())
        pct = count / n_rows * 100 if n_rows else 0.0
        print(f"  {label:20s}: {count:5d} ({pct:5.2f}%)")
    print()

if __name__ == "__main__":
    import shutil  # stdlib; builds the submission archive without a shell escape

    print("\n" + "="*70)
    print("SemEval Task 9 - Subtask 2: SIMPLIFIED Multi-Label")
    print("="*70)

    print("\n📊 English...")
    eng_train = load_data(Config.TRAIN_ENG, Config.LABELS)
    eng_model, eng_tok, eng_f1, eng_t = train_model(
        eng_train, 'english', Config.MODEL_ENG, Config.LR_ENG, Config
    )
    predict(eng_model, eng_tok, Config.DEV_ENG,
            f"{Config.PREDICTIONS_DIR}/pred_eng.csv", eng_t, Config.LABELS, Config)

    # Release the English model before loading the Kiswahili one.
    del eng_model
    gc.collect()
    torch.cuda.empty_cache()

    print("\n📊 Kiswahili...")
    swa_train = load_data(Config.TRAIN_SWA, Config.LABELS)
    swa_model, swa_tok, swa_f1, swa_t = train_model(
        swa_train, 'kiswahili', Config.MODEL_SWA, Config.LR_SWA, Config
    )
    predict(swa_model, swa_tok, Config.DEV_SWA,
            f"{Config.PREDICTIONS_DIR}/pred_swa.csv", swa_t, Config.LABELS, Config)

    print("\n" + "="*70)
    print("RESULTS")
    print("="*70)
    print(f"English:   F1={eng_f1:.4f}")
    print(f"Kiswahili: F1={swa_f1:.4f}")
    print(f"Average:   {(eng_f1+swa_f1)/2:.4f}")
    print("="*70)

    print("\n📦 Creating submission...")
    # Pure-Python replacement for the IPython-only `!zip -r -q subtask_2.zip
    # subtask_2`. It writes ./subtask_2.zip whose top-level folder is
    # PREDICTIONS_DIR's basename, independent of the current directory.
    pred_dir = os.path.abspath(Config.PREDICTIONS_DIR)
    shutil.make_archive('subtask_2', 'zip',
                        root_dir=os.path.dirname(pred_dir),
                        base_dir=os.path.basename(pred_dir))
    print("✅ DONE!")

KeyboardInterrupt: 