# In [None]:
import re
import json
import torch
import torch.nn as nn
import math
import os
import gc
from torch.utils.data import Dataset
from transformers import AutoModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from tqdm.auto import tqdm
import random
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import shutil
from torch.utils.data import Sampler
from collections import defaultdict
import numpy as np

class SyllogismDataset(Dataset):
    """Dataset that tokenizes syllogism records as (premises, conclusion) pairs.

    Each record is a mapping with a ``"syllogism"`` text and a ``"validity"``
    flag; ``"plausibility"`` and ``"id"`` are optional. The text is split into
    sentences on runs of ``.``, U+3002 (ideographic full stop) or U+0964
    (danda). The last sentence is used as the conclusion and all preceding
    ones as the premises.
    """

    def __init__(self, data, tokenizer, max_len=128):
        """Store the records, the tokenizer and the padded sequence length."""
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.split_pattern = r'[\.\u3002\u0964]+'

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return its tensors, labels and id."""
        record = self.data[idx]

        # Sentence segmentation; blank fragments (e.g. after a trailing
        # full stop) are dropped.
        fragments = re.split(self.split_pattern, record["syllogism"])
        sentences = [frag.strip() for frag in fragments if frag.strip()]
        if not sentences:
            # Placeholder so an empty text still yields a conclusion segment.
            sentences = ["Empty"]

        *premises, conclusion = sentences

        # Premises are glued together with the tokenizer's separator token and
        # passed as the first segment; the conclusion is the second segment.
        encoded = self.tokenizer(
            self.tokenizer.sep_token.join(premises),
            conclusion,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt",
        )

        validity_value = 1.0 if record["validity"] else 0.0
        # A missing plausibility flag is treated as "not plausible".
        plausibility_value = 1.0 if record.get("plausibility", False) else 0.0

        sample = {}
        # squeeze(0) drops the leading dimension added by return_tensors="pt"
        # (assumed to be a batch axis of size 1 — TODO confirm for the
        # tokenizer in use).
        sample["input_ids"] = encoded['input_ids'].squeeze(0)
        sample["attention_mask"] = encoded['attention_mask'].squeeze(0)
        sample["validity_label"] = torch.tensor(validity_value, dtype=torch.float)
        sample["plausibility_label"] = torch.tensor(plausibility_value, dtype=torch.float)
        # Fall back to the positional index when the record carries no id.
        sample["id"] = record.get("id", str(idx))
        return sample
    
class EarlyStopping:
    """Signal a stop once a monitored score has stalled for ``patience`` calls.

    ``mode='max'`` treats larger scores as better; any other value treats
    smaller scores as better. A score counts as an improvement only when it
    beats the best score so far by more than ``min_delta``.
    """

    def __init__(self, patience=3, min_delta=0.001, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score):
        """Record ``score`` and return whether training should stop."""
        # The very first score only establishes the baseline.
        if self.best_score is None:
            self.best_score = score
            return False

        if self.mode == 'max':
            improved = score > (self.best_score + self.min_delta)
        else:
            improved = score < (self.best_score - self.min_delta)

        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return self.early_stop

        self.best_score = score
        self.counter = 0
        # The flag is never cleared: once tripped, it stays tripped even if a
        # later score improves.
        return self.early_stop
    

def compute_semeval_metrics(predictions, ground_truth):
    """Score validity predictions: accuracy, content-bias and ranking score.

    Args:
        predictions: iterable of dicts with an ``'id'`` and the predicted
            label under ``'validity_pred'`` (or ``'validity'`` as fallback).
        ground_truth: iterable of dicts with ``'id'``, ``'validity'`` and an
            optional ``'plausibility'`` (treated as ``False`` when absent).

    Returns:
        Dict with ``accuracy`` (percent), ``total_bias`` (mean of the intra-
        and inter-group accuracy gaps over the four validity x plausibility
        cells) and ``ranking_score`` = accuracy / (1 + ln(1 + total_bias)).
        Predictions whose id is not in ``ground_truth`` are ignored.
    """
    truth_by_id = {entry['id']: entry for entry in ground_truth}

    # One [hits, seen] counter per (validity, plausibility) cell.
    cells = {
        (True, True): [0, 0],
        (True, False): [0, 0],
        (False, True): [0, 0],
        (False, False): [0, 0],
    }
    n_right = 0
    n_scored = 0

    for pred in predictions:
        entry = truth_by_id.get(pred['id'])
        if entry is None:
            continue

        if 'validity_pred' in pred:
            guess = pred['validity_pred']
        else:
            guess = pred.get('validity')
        gold = entry['validity']
        plausible = entry.get('plausibility', False)

        n_scored += 1
        if guess == gold:
            n_right += 1

        # Labels outside the four boolean cells count toward overall accuracy
        # only, not toward the bias terms.
        cell_key = (gold, plausible)
        if cell_key in cells:
            counter = cells[cell_key]
            counter[1] += 1
            if guess == gold:
                counter[0] += 1

    if n_scored > 0:
        accuracy = n_right / n_scored * 100
    else:
        accuracy = 0.0

    def cell_accuracy(validity, plausibility):
        hits, seen = cells[(validity, plausibility)]
        if seen > 0:
            return hits / seen * 100
        return 0.0

    acc_vp = cell_accuracy(True, True)
    acc_vi = cell_accuracy(True, False)
    acc_ip = cell_accuracy(False, True)
    acc_ii = cell_accuracy(False, False)

    # Gap between plausible and implausible items at fixed validity.
    intra_diff = (abs(acc_vp - acc_vi) + abs(acc_ip - acc_ii)) / 2.0
    # Gap between valid and invalid items at fixed plausibility.
    inter_diff = (abs(acc_vp - acc_ip) + abs(acc_vi - acc_ii)) / 2.0
    tot_bias = (intra_diff + inter_diff) / 2.0

    ranking_score = accuracy / (1 + math.log(1 + tot_bias))

    return {
        "accuracy": accuracy,
        "total_bias": tot_bias,
        "ranking_score": ranking_score,
    }

def monitor_language_performance(predictions, ground_truth):
    """Print a per-language table of accuracy, bias and ranking score.

    Ground-truth records are grouped by their ``'language'`` field
    (``'unknown'`` when absent); each group is scored with
    ``compute_semeval_metrics`` against the predictions sharing its ids.
    Nothing is returned.
    """
    preds_by_id = {pred['id']: pred for pred in predictions}

    per_language = {}
    for record in ground_truth:
        bucket = per_language.setdefault(
            record.get('language', 'unknown'), {'gt': [], 'preds': []}
        )
        bucket['gt'].append(record)
        if record['id'] in preds_by_id:
            bucket['preds'].append(preds_by_id[record['id']])

    rule = "-" * 65
    print(rule)
    print(f"{'LANG':<5} | {'COUNT':<5} | {'ACC':<6} | {'BIAS':<8} | {'SCORE':<7}")
    print(rule)
    for lang in sorted(per_language.keys()):
        bucket = per_language[lang]
        n_items = len(bucket['gt'])
        if n_items == 0:
            continue
        metrics = compute_semeval_metrics(bucket['preds'], bucket['gt'])
        print(f"{lang:<5} | {n_items:<5} | {metrics['accuracy']:<6.2f} | {metrics['total_bias']:<8.4f} | {metrics['ranking_score']:<7.2f}")
    print(rule)

def save_model_for_inference(model, tokenizer, save_dir, config_dict):
    """Write weights, a JSON config and the tokenizer files into ``save_dir``.

    The directory is created if needed. Weights go to ``pytorch_model.bin``
    (the model's ``state_dict``), ``config_dict`` goes to ``config.json``, and
    the tokenizer is written through its own ``save_pretrained``.
    """
    os.makedirs(save_dir, exist_ok=True)

    weights_path = os.path.join(save_dir, 'pytorch_model.bin')
    torch.save(model.state_dict(), weights_path)

    config_path = os.path.join(save_dir, 'config.json')
    with open(config_path, 'w') as handle:
        json.dump(config_dict, handle, indent=2)

    tokenizer.save_pretrained(save_dir)

class BinarySyllogismModel(nn.Module):
    """Pretrained encoder followed by a small MLP head emitting one logit.

    ``pooling_type="mean"`` averages the last hidden states over the positions
    selected by the attention mask; any other value uses the hidden state at
    position 0.
    """

    def __init__(self, model_name, dropout_rate=0.1, pooling_type="cls"):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.pooling_type = pooling_type
        width = self.backbone.config.hidden_size
        head_layers = [
            nn.Dropout(dropout_rate),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return the classifier output with its trailing size-1 axis squeezed."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded.last_hidden_state

        if self.pooling_type != "mean":
            pooled = hidden[:, 0, :]
        else:
            # Broadcast the mask over the hidden axis so masked-out positions
            # contribute nothing to the sum.
            weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
            summed = torch.sum(hidden * weights, 1)
            # The clamp guards against division by zero for an all-zero mask.
            pooled = summed / torch.clamp(weights.sum(1), min=1e-9)

        return self.classifier(pooled).squeeze(-1)


def _confidence_gap(conf, split_labels, contrast_labels):
    """Mean absolute confidence gap between contrast classes within each split.

    For each value (0 and 1) of ``split_labels`` the samples of that group are
    divided by ``contrast_labels`` into a 1-class and a 0-class; when both are
    present, the absolute difference of their mean confidence is recorded.
    Returns the mean of the recorded gaps, or a zero scalar when no group in
    the batch contains both classes.
    """
    gaps = []
    for group in (0, 1):
        in_group = split_labels == group
        group_conf = conf[in_group]
        group_contrast = contrast_labels[in_group]
        is_one = group_contrast == 1
        is_zero = group_contrast == 0
        if is_one.any() and is_zero.any():
            gaps.append(torch.abs(group_conf[is_one].mean() - group_conf[is_zero].mean()))
    if not gaps:
        return torch.tensor(0.0, device=conf.device)
    return torch.stack(gaps).mean()


def _bias_penalty(probs, v_labels, p_labels):
    """Batch-level bias term built from the confidence in the true label.

    ``conf`` is the probability assigned to the correct validity label. The
    penalty averages (a) the valid-vs-invalid confidence gap inside each
    plausibility group and (b) the plausible-vs-implausible confidence gap
    inside each validity group.
    """
    conf = torch.where(v_labels == 1, probs, 1.0 - probs)
    intra_loss = _confidence_gap(conf, p_labels, v_labels)
    cross_loss = _confidence_gap(conf, v_labels, p_labels)
    return (intra_loss + cross_loss) / 2.0


def train_engine(model, train_loader, val_loader, config, device):
    """Train ``model`` with BCE plus a scheduled bias penalty; keep the best epoch.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` and
            returning one logit per sample.
        train_loader / val_loader: loaders yielding dict batches with
            ``input_ids``, ``attention_mask``, ``validity_label``,
            ``plausibility_label`` and ``id``. ``val_loader.dataset.data`` is
            used as ground truth and ``train_loader.dataset.tokenizer`` is
            saved alongside the checkpoint.
        config: dict providing ``lr``, ``weight_decay``, ``acc_steps``,
            ``epochs``, ``warmup_ratio``, ``patience``, ``min_delta``,
            ``fp16``, ``lambda_schedule``, ``lambda_default``, ``save_dir``,
            ``model_name``, ``max_len``, ``dropout_rate`` and optionally
            ``pooling_type``.
        device: device the batches are moved to (``torch.device`` or string).

    Returns:
        The same ``model`` object. If at least one checkpoint was written
        during this call, the weights of the best-ranking epoch are loaded
        back before returning, so that downstream evaluation matches the
        checkpoint stored in ``config['save_dir']``.
    """
    torch.cuda.empty_cache()
    gc.collect()

    # Mixed precision and gradient scaling only make sense on CUDA; on other
    # devices they are switched off instead of relying on library fallbacks.
    use_amp = bool(config['fp16']) and torch.device(device).type == "cuda"
    acc_steps = config['acc_steps']
    num_train_batches = len(train_loader)

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": config['weight_decay']},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['lr'])

    # The scheduler advances once per optimizer update, not once per batch.
    num_update_steps_per_epoch = math.ceil(num_train_batches / acc_steps)
    total_training_steps = num_update_steps_per_epoch * config['epochs']
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_training_steps * config['warmup_ratio']),
        num_training_steps=total_training_steps,
    )

    early_stopping = EarlyStopping(patience=config['patience'], min_delta=config['min_delta'], mode='max')
    scaler = GradScaler(enabled=use_amp)
    criterion = nn.BCEWithLogitsLoss()
    best_ranking_score = 0.0
    checkpoint_saved = False

    lambda_schedule = config["lambda_schedule"]
    default_lambda = config["lambda_default"]

    for epoch in range(config['epochs']):
        # Bias weight for this epoch; epochs past the schedule use the default.
        if epoch < len(lambda_schedule):
            current_lambda = lambda_schedule[epoch]
        else:
            current_lambda = default_lambda

        # \u03bb is the Greek letter lambda, written as an escape so the
        # message survives encoding round-trips of this file.
        print(f" Bias \u03bb (epoch {epoch + 1}): {current_lambda}")

        model.train()
        total_train_loss, total_train_bias = 0.0, 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config['epochs']}")

        for step, batch in enumerate(progress_bar):
            input_ids, mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
            v_labels, p_labels = batch['validity_label'].to(device), batch['plausibility_label'].to(device)

            with autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                logits = model(input_ids, mask)
                task_loss = criterion(logits, v_labels)
                if current_lambda > 0:
                    bias_penalty = _bias_penalty(torch.sigmoid(logits), v_labels, p_labels)
                else:
                    bias_penalty = torch.tensor(0.0, device=device)

                # Divide so that accumulated gradients average over micro-batches.
                loss = (task_loss + (current_lambda * bias_penalty)) / acc_steps

            scaler.scale(loss).backward()
            # Update every acc_steps micro-batches and on the final batch.
            if (step + 1) % acc_steps == 0 or (step + 1) == num_train_batches:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            step_loss = loss.item() * acc_steps
            step_bias = bias_penalty.item()
            total_train_loss += step_loss
            total_train_bias += step_bias
            progress_bar.set_postfix({'l': f"{step_loss:.4f}", 'b': f"{step_bias:.4f}", 'lr': f"{scheduler.get_last_lr()[0]:.2e}"})

        model.eval()
        val_predictions, total_val_loss, total_val_bias = [], 0.0, 0.0
        with torch.no_grad():
            for batch in val_loader:
                input_ids, mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                v_labels, p_labels = batch['validity_label'].to(device), batch['plausibility_label'].to(device)
                with autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                    logits = model(input_ids, mask)
                    total_val_loss += criterion(logits, v_labels).item()
                    probs = torch.sigmoid(logits)
                    total_val_bias += _bias_penalty(probs, v_labels, p_labels).item()

                preds = (probs >= 0.5).long().cpu().numpy()
                for i, uid in enumerate(batch['id']):
                    val_predictions.append({'id': uid, 'validity_pred': bool(preds[i])})

        # max(1, ...) keeps the report from dividing by zero on an empty loader.
        train_denominator = max(1, num_train_batches)
        val_denominator = max(1, len(val_loader))

        metrics = compute_semeval_metrics(val_predictions, val_loader.dataset.data)
        print(f"\n Epoch {epoch + 1} | Train Total Loss: {total_train_loss/train_denominator:.4f} | Train Bias Loss: {total_train_bias/train_denominator:.4f} | Validation Total Loss: {total_val_loss/val_denominator:.4f} | Validation Total Bias: {total_val_bias/val_denominator:.4f}")
        print(f"    Acc: {metrics['accuracy']:.4f}% | Bias: {metrics['total_bias']:.4f} | Rank: {metrics['ranking_score']:.4f}")
        print("\n    Per-Language Stats:")
        monitor_language_performance(val_predictions, val_loader.dataset.data)

        if metrics['ranking_score'] > (best_ranking_score + config['min_delta']):
            best_ranking_score = metrics['ranking_score']
            save_model_for_inference(
                model,
                train_loader.dataset.tokenizer,
                config['save_dir'],
                {
                    'model_name': config['model_name'],
                    'max_len': config['max_len'],
                    'dropout_rate': config['dropout_rate'],
                    'pooling_type': config.get('pooling_type', 'cls'),
                    "lambda_schedule": config["lambda_schedule"],
                    "lambda_default": config["lambda_default"]
                }
            )
            checkpoint_saved = True
            print("    Best Checkpoint Saved!")
        if early_stopping(metrics['ranking_score']):
            break
        torch.cuda.empty_cache()
        gc.collect()

    if checkpoint_saved:
        # Later epochs may have scored worse than the saved one; restore the
        # saved weights (file name as written by save_model_for_inference) so
        # the returned model is the one that was checkpointed.
        best_state = torch.load(os.path.join(config['save_dir'], 'pytorch_model.bin'), map_location="cpu")
        model.load_state_dict(best_state)
        del best_state
    return model

def predict(model, dataloader, device, config):
    """Run inference over ``dataloader`` and return thresholded predictions.

    Returns a list of ``{'id': ..., 'validity_pred': bool}`` dicts, one per
    sample, where the prediction is ``sigmoid(logit) >= 0.5``. Autocast to
    float16 is requested according to ``config['fp16']``.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting"):
            with autocast(device_type="cuda", dtype=torch.float16, enabled=config['fp16']):
                ids_on_device = batch['input_ids'].to(device)
                mask_on_device = batch['attention_mask'].to(device)
                logits = model(ids_on_device, mask_on_device)
            flags = (torch.sigmoid(logits) >= 0.5).long().cpu().numpy()
            predictions.extend(
                {'id': uid, 'validity_pred': bool(flags[pos])}
                for pos, uid in enumerate(batch['id'])
            )
    return predictions

class StratifiedBatchSampler(Sampler):
    """Batch sampler that draws the same number of indices from every group.

    Each yielded batch holds ``samples_per_group`` indices from each distinct
    value in ``group_ids``, so its length is ``samples_per_group`` times the
    number of distinct groups (which need not equal ``batch_size``). The
    number of batches per pass is bounded by the smallest group; indices of
    larger groups beyond that bound are not visited in that pass.
    """

    def __init__(self, group_ids, batch_size, num_groups=6):
        """Index samples by group and derive per-group quota and batch count.

        Args:
            group_ids: one group label per dataset index.
            batch_size: nominal batch size used to derive the per-group quota.
            num_groups: divisor for the quota, ``batch_size // num_groups``
                (at least 1).
        """
        self.group_ids = np.array(group_ids)
        self.batch_size = batch_size
        self.num_groups = num_groups

        self.groups = defaultdict(list)
        for position, label in enumerate(self.group_ids):
            self.groups[label].append(position)
        self.group_keys = list(self.groups.keys())

        self.samples_per_group = max(1, batch_size // num_groups)

        smallest_group = min(len(members) for members in self.groups.values())
        self.num_batches = smallest_group // self.samples_per_group

    def __iter__(self):
        """Yield ``num_batches`` shuffled index lists using NumPy's global RNG."""
        # Fresh per-group permutation on every pass.
        order_by_group = {}
        for key in self.group_keys:
            order = np.array(self.groups[key]).copy()
            np.random.shuffle(order)
            order_by_group[key] = order

        quota = self.samples_per_group
        for batch_no in range(self.num_batches):
            lo = batch_no * quota
            hi = (batch_no + 1) * quota
            batch = []
            for key in self.group_keys:
                batch.extend(order_by_group[key][lo:hi].tolist())
            # Mix the groups so samples of one group are not contiguous.
            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        """Return the number of batches produced per pass."""
        return self.num_batches

# Run configuration shared by the dataset, model, training loop and inference.
CONFIG = {
    # Pretrained checkpoint used for both the tokenizer and the encoder.
    "model_name": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
    # Data loading / optimisation.
    "batch_size": 144,
    "epochs": 3,
    "lr": 2e-05,
    "warmup_ratio": 0.06,
    "weight_decay": 0.01,
    "acc_steps": 1,
    "fp16": True,
    # Early stopping / checkpointing on the validation ranking score.
    "patience": 2,
    "min_delta": 0.001,
    "save_dir": "/kaggle/working/trained_model_task3",
    # Model head.
    "pooling_type": "cls",
    # Per-epoch weight of the bias penalty, and the value used afterwards.
    "lambda_schedule": [0.5, 1.5, 2.5],
    "lambda_default": 2.5,
    "dropout_rate": 0.1,
    # Token length every example is padded/truncated to.
    "max_len": 96,
}


if __name__ == "__main__":
    # End-to-end run: load data, split gold set, train, evaluate on the held-out
    # gold test split, then write predictions and a zipped checkpoint.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # JSON inputs are read explicitly as UTF-8 rather than with the
    # platform-default encoding (the sentence splitter handles non-ASCII
    # punctuation, so the texts are expected to be multilingual).
    with open("/kaggle/input/task-3-dataset/task_3_final_dataset.json", 'r', encoding="utf-8") as f:
        train_aug_data = json.load(f)
    with open("/kaggle/input/task-3-dataset/curator_combined_dataset.json", 'r', encoding="utf-8") as f:
        gold_data = json.load(f)
    with open("/kaggle/input/task3-hans/task3_dataset_hans_translated.json", 'r', encoding="utf-8") as f:
        hans_data = json.load(f)
    print(f" Loaded {len(train_aug_data)} aug samples, {len(gold_data)} gold samples, and {len(hans_data)} hans samples.")

    # Records removed from the augmented training set by id.
    wrong_ids_lst = [
        "9f087d4f-5b6f-4d89-b9a0-2ed39ea55950",
        "213df683-1fcc-4372-9833-3120c641d5eb",
        "49f81753-fe0d-492c-91fc-c565f8e5ee1e",
        "202e8697-32e7-4ac9-ae52-4cf462be9766",
        "e1d810fa-9d81-4a2f-899b-4b9de70ec62d",
        "b9303323-806b-4664-b29a-1a81fe0e7af0",
        "6a2a14d5-6a43-4610-a539-6a5afe905356",
        "080b1667-1426-47c6-95c5-61451e0deee6",
        "191acf6e-20dd-4c59-ae05-db7585ecef52",
        "f0402a7a-f2a2-4430-b18d-fa56ada5acf3",
    ]
    # Set lookup instead of scanning the list once per record.
    wrong_ids = set(wrong_ids_lst)

    train_aug_data = [e for e in train_aug_data if e.get("id") not in wrong_ids]

    if train_aug_data:
        # Seed every RNG the run depends on: `random` (list shuffles below),
        # NumPy (the batch sampler's shuffles) and torch (classifier-head
        # initialisation, dropout). Seeding only `random` left the last two
        # unseeded despite the cudnn determinism flags above.
        random.seed(42)
        np.random.seed(42)
        torch.manual_seed(42)

        random.shuffle(train_aug_data)
        random.shuffle(gold_data)
        random.shuffle(hans_data)

        # 80/10/10 split of the gold data, stratified on
        # language x validity x plausibility.
        stratify_labels = [f"{x['language']}_{x['validity']}_{x.get('plausibility', False)}" for x in gold_data]

        train_gold, temp, _, temp_lbls = train_test_split(gold_data, stratify_labels, test_size=0.20, random_state=42, stratify=stratify_labels)
        val_data, test_data = train_test_split(temp, test_size=0.50, random_state=42, stratify=temp_lbls)

        train_data = train_aug_data + train_gold + hans_data

        def make_group_id(item):
            """Group key for the batch sampler; tolerant of a missing language."""
            return f"{item.get('language','unk')}_{item['validity']}_{item.get('plausibility', False)}"

        train_group_ids = [make_group_id(x) for x in train_data]

        tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])

        train_dataset = SyllogismDataset(train_data, tokenizer, CONFIG['max_len'])

        # num_groups=72 with batch_size=144 gives 2 samples per group per batch.
        # NOTE(review): 72 presumably equals the number of distinct
        # language/validity/plausibility groups in train_data — confirm.
        batch_sampler = StratifiedBatchSampler(
            group_ids=train_group_ids,
            batch_size=CONFIG['batch_size'],
            num_groups=72
        )

        train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            num_workers=4,
            pin_memory=True
        )
        val_loader = DataLoader(SyllogismDataset(val_data, tokenizer, CONFIG['max_len']), batch_size=CONFIG['batch_size'], shuffle=False, num_workers=4, pin_memory=True)
        test_loader = DataLoader(SyllogismDataset(test_data, tokenizer, CONFIG['max_len']), batch_size=CONFIG['batch_size'], shuffle=False, num_workers=4, pin_memory=True)

        model = BinarySyllogismModel(CONFIG['model_name'], CONFIG['dropout_rate'], CONFIG['pooling_type'])

        # Gradient checkpointing on the encoder; the cache flag is turned off
        # alongside it.
        model.backbone.config.use_cache = False
        model.backbone.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

        model = model.to(device)
        os.makedirs(CONFIG['save_dir'], exist_ok=True)

        trained_model = train_engine(model, train_loader, val_loader, CONFIG, device)

        print("\n Predicting on Test Set...")
        test_preds = predict(trained_model, test_loader, device, CONFIG)
        metrics = compute_semeval_metrics(test_preds, test_data)

        print(f"\n FINAL: Acc: {metrics['accuracy']:.2f}% | Bias: {metrics['total_bias']:.4f} | Rank: {metrics['ranking_score']:.4f}")
        print("\n Final Per-Language Analysis:")
        monitor_language_performance(test_preds, test_data)

        output_base = "/kaggle/working/subtask_3_bias_scheduler"
        with open(f"{output_base}_predictions.json", "w") as f:
            json.dump(test_preds, f, indent=4)
        with open(f"{output_base}_ground_truth.json", "w") as f:
            json.dump(test_data, f, indent=4)
        shutil.make_archive(f"{output_base}_model", 'zip', CONFIG['save_dir'])
        print("\n Predictions and Model Saved/Zipped.")
