In [1]:
!pip install -q transformers sentencepiece scikit-learn

import re
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import os
import shutil
import math
from collections import Counter, defaultdict
from torch.cuda.amp import autocast, GradScaler
import gc
from tqdm.auto import tqdm
import itertools

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

MODEL_NAME = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
BATCH_SIZE = 6       
EVAL_BATCH_SIZE = 6
ACCUMULATION_STEPS = 10   
EPOCHS = 10               
LEARNING_RATE = 5e-06    
WARMUP_RATIO = 0.06        
WEIGHT_DECAY = 0.01        
NUM_WORKERS = 2
FP16 = True
DRO_WARMUP_EPOCHS = 1 
CONSISTENCY_LAMBDA = 0.5   

MAX_LEN = 232
BUFFER_SIZE = 9
PATIENCE = 2
MIN_DELTA = 0.001

OUTPUT_DIR = "/kaggle/working"
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "trained_model")
JSON_FILE_PATH = "/kaggle/input/datasets/tahazahid001/st2-experiments-kl-me/st2_all_combined.json"  
GOLD_FILE_PATH = "/kaggle/input/datasets/abdullahshheikh/task-2-data/gold_changes.json"

class PreparingInput(Dataset):
    def __init__(self, data, tokenizer, max_len=512, buffer_size=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.buffer_size = buffer_size
        self.sep_token = tokenizer.sep_token
        self.sep_id = tokenizer.sep_token_id

        self.split_pattern = r'[\.\u3002\u0964]+'

    def __len__(self):
        return len(self.data)

    def _process_syllogism(self, raw_text):
        sentences = [s.strip() for s in re.split(self.split_pattern, raw_text) if s.strip()]
        if not sentences:
            sentences = ["Empty"]

        conclusion = sentences[-1]
        premises = sentences[:-1]
        premises_text = self.sep_token.join(premises)

        encoding = self.tokenizer(
            conclusion,
            premises_text,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        premise_spans = torch.zeros(self.buffer_size, 2, dtype=torch.long)
        premise_mask = torch.zeros(self.buffer_size, dtype=torch.float)

        all_sep_indices = (input_ids == self.sep_id).nonzero(as_tuple=True)[0]
        if len(all_sep_indices) > 1:
            sep_positions = all_sep_indices.tolist()
            segment_starts = sep_positions[:-1]
            segment_ends = sep_positions[1:]

            num_premises = len(segment_starts)
            limit = min(num_premises, self.buffer_size)

            for i in range(limit):
                start_pos = segment_starts[i] + 1
                end_pos = segment_ends[i] - 1
                if end_pos >= start_pos:
                    premise_spans[i, 0] = start_pos
                    premise_spans[i, 1] = end_pos
                    premise_mask[i] = 1.0

        return input_ids, attention_mask, premise_spans, premise_mask

    def __getitem__(self, idx):
        item = self.data[idx]

        primary_text = item.get("syllogism_simple", item.get("syllogism", ""))
        input_ids, attention_mask, premise_spans, premise_mask = self._process_syllogism(primary_text)

        alt_text = item.get("syllogism_complex", item.get("syllogism_variant"))
        if alt_text:
            alt_input_ids, alt_attention_mask, alt_premise_spans, alt_premise_mask = self._process_syllogism(alt_text)
        else:
            alt_input_ids = input_ids.clone()
            alt_attention_mask = attention_mask.clone()
            alt_premise_spans = premise_spans.clone()
            alt_premise_mask = premise_mask.clone()

        if 'validity' in item:
            is_valid = item['validity']
            validity_label = torch.tensor(1 if is_valid else 0, dtype=torch.float)

            selection_label = torch.zeros(self.buffer_size, dtype=torch.float)
            if is_valid:
                for p_idx in item.get('premises', []):
                    if p_idx < self.buffer_size:
                        selection_label[p_idx] = 1.0
        else:
            validity_label = torch.tensor(0, dtype=torch.long)
            selection_label = torch.zeros(self.buffer_size, dtype=torch.float)

        if 'plausibility' in item:
            pl = item['plausibility']
            if isinstance(pl, bool):
                plausibility_label = torch.tensor(2 if pl else 0, dtype=torch.long)
            elif isinstance(pl, str):
                pl = pl.strip().lower()
                if pl in ("plausible", "true", "yes"):
                    plausibility_label = torch.tensor(2, dtype=torch.long)
                elif pl in ("implausible", "false", "no"):
                    plausibility_label = torch.tensor(0, dtype=torch.long)
                else:
                    plausibility_label = torch.tensor(1, dtype=torch.long)  
            else:
                plausibility_label = torch.tensor(1, dtype=torch.long)  
        else:
            plausibility_label = torch.tensor(1, dtype=torch.long)  

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "premise_spans": premise_spans,
            "premise_mask": premise_mask,
            "alt_input_ids": alt_input_ids,
            "alt_attention_mask": alt_attention_mask,
            "alt_premise_spans": alt_premise_spans,
            "alt_premise_mask": alt_premise_mask,
            "validity_label": validity_label,
            "selection_label": selection_label,
            "plausibility_label": plausibility_label
        }

class VariableSyllogismModel(nn.Module):
    def __init__(self, model_name, dropout_rate=0.1, max_len=512):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden_size = self.backbone.config.hidden_size
        self.max_len = max_len

        self.validity_head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1)  
        )

        self.selection_head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, input_ids=None, attention_mask=None, premise_spans=None, premise_mask=None,
        validity_labels=None, selection_labels=None, inputs_embeds=None):

        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
        last_hidden_state = outputs.last_hidden_state

        cls_vectors = last_hidden_state[:, 0, :]
        validity_logits = self.validity_head(cls_vectors).squeeze(dim=-1)

        batch_size, buffer_size, _ = premise_spans.shape
        
        premise_vectors = []
        for i in range(buffer_size):
            starts = premise_spans[:, i, 0]
            ends = premise_spans[:, i, 1]
            
            slot_means = []
            for b in range(batch_size):
                span = last_hidden_state[b, starts[b]:ends[b]+1, :]
                if span.size(0) > 0:
                    slot_means.append(span.mean(dim=0))
                else:
                    slot_means.append(torch.zeros(last_hidden_state.size(-1), device=input_ids.device if input_ids is not None else inputs_embeds.device))
            
            premise_vectors.append(torch.stack(slot_means))

        premise_vectors = torch.stack(premise_vectors, dim=1)
        
        selection_logits = self.selection_head(premise_vectors).squeeze(dim=-1)

        return {
            'validity_logits': validity_logits,
            'selection_logits': selection_logits
        }

class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
            return False

        if self.mode == 'max':
            improved = score > (self.best_score + self.min_delta)
        else:
            improved = score < (self.best_score - self.min_delta)

        if improved:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop

def save_model_for_inference(model, tokenizer, save_dir, config_dict):
    os.makedirs(save_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(save_dir, 'pytorch_model.bin'))

    config = {
        'model_name': config_dict['model_name'],
        'max_len': config_dict['max_len'],
        'buffer_size': config_dict['buffer_size'],
        'dropout_rate': 0.1
    }
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    tokenizer.save_pretrained(save_dir)
    print(f" Full Model (Backbone + Heads) saved to '{save_dir}/'")

def load_model_for_inference(save_dir, device='cpu'):
    with open(os.path.join(save_dir, 'config.json'), 'r') as f:
        config = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained(save_dir)

    model = VariableSyllogismModel(
        model_name=config['model_name'],
        dropout_rate=config['dropout_rate'],
        max_len=config['max_len']
    )

    model_path = os.path.join(save_dir, 'pytorch_model.bin')
    state_dict = torch.load(model_path, map_location=device)

    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model, tokenizer, config

def get_group_ids(validity_labels, plausibility_labels):
    """
    Maps each sample to a group 0..5
    Group = validity * 3 + plausibility_index
    plausibility_index: 0=implausible, 1=neutral, 2=plausible
    """
    plaus_idx = plausibility_labels.clone()
    plaus_idx[plaus_idx == 0] = 0   
    plaus_idx[plaus_idx == 1] = 1   
    plaus_idx[plaus_idx == 2] = 2   

    group_ids = (validity_labels.long() * 3) + plaus_idx.long()
    return group_ids

def compute_semeval_metrics(predictions, ground_truth):
    gt_map = {item['id']: item for item in ground_truth}
    
    total_precision, total_recall, valid_f1_count = 0.0, 0.0, 0

    for pred in predictions:
        item_id = pred['id']
        if item_id not in gt_map: 
            continue
            
        gt_item = gt_map[item_id]
        
        if not gt_item.get('validity', False):
            continue

        pred_set = set(pred.get('premises_pred', []))
        true_set = set(gt_item.get('premises', []))
        
        tp = len(true_set.intersection(pred_set))
        fp = len(pred_set - true_set)
        fn = len(true_set - pred_set)
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        
        total_precision += precision
        total_recall += recall
        valid_f1_count += 1

    f1_premises = 0.0
    if valid_f1_count > 0:
        macro_prec = total_precision / valid_f1_count
        macro_rec = total_recall / valid_f1_count
        if (macro_prec + macro_rec) > 0:
            f1_premises = 100 * (2 * (macro_prec * macro_rec) / (macro_prec + macro_rec))

    correct_validity = 0
    total_validity = 0
    subgroups = {
        (True, True): [0, 0], (True, False): [0, 0],
        (False, True): [0, 0], (False, False): [0, 0]
    }

    for pred in predictions:
        item_id = pred['id']
        if item_id not in gt_map: 
            continue
            
        gt_item = gt_map[item_id]
        p_val = pred['validity_pred']
        t_val = gt_item['validity']
        
        is_correct = (p_val == t_val)
        total_validity += 1
        if is_correct: 
            correct_validity += 1
            
        pl = gt_item.get('plausibility', None)
        if isinstance(pl, str):
            pl = pl.strip().lower()
            if pl in ("neutral", "neither", "uncertain"):
                continue
            pl = True if pl in ("plausible", "true", "yes") else False
        
        if isinstance(pl, bool):
            key = (t_val, pl)
            subgroups[key][1] += 1 
            if is_correct:
                subgroups[key][0] += 1 

    accuracy = (correct_validity / total_validity * 100) if total_validity > 0 else 0.0

    def get_acc(v, p):
        corr, tot = subgroups[(v, p)]
        return (corr / tot * 100) if tot > 0 else 0.0

    c_intra = (abs(get_acc(True, True) - get_acc(True, False)) +
               abs(get_acc(False, True) - get_acc(False, False))) / 2.0
    c_inter = (abs(get_acc(True, True) - get_acc(False, True)) +
               abs(get_acc(True, False) - get_acc(False, False))) / 2.0

    tot_content_effect = max(0.0, (c_intra + c_inter) / 2.0)
    log_penalty = math.log(1 + tot_content_effect)
    
    overall_performance = (accuracy + f1_premises) / 2.0
    ranking_score = overall_performance / (1 + log_penalty)

    return {
        "validity_accuracy": round(accuracy, 4),
        "premise_f1": round(f1_premises, 4),
        "overall_performance": round(overall_performance, 4),
        "total_bias": round(tot_content_effect, 4),
        "ranking_score": round(ranking_score, 4)
    }

def kl_symmetric(prob_p, prob_q):
    """Symmetric KL for binary probs shaped [B, 2]."""
    log_p = torch.log(prob_p + 1e-8)
    log_q = torch.log(prob_q + 1e-8)
    kl_pq = torch.sum(prob_p * (log_p - log_q), dim=-1)
    kl_qp = torch.sum(prob_q * (log_q - log_p), dim=-1)
    return 0.5 * (kl_pq + kl_qp)



def asymmetric_kl_forward(student_probs, teacher_probs):
    """
    Forward KL: KL(student || teacher)
    Student (complex variant) learns from teacher (simple variant).
    
    Args:
        student_probs: [B, 2] probability distribution from complex syllogism
        teacher_probs: [B, 2] probability distribution from simple syllogism (more reliable)
    
    Returns:
        [B] KL divergence values
    """
    eps = 1e-8
    
    student_probs = torch.clamp(student_probs, min=eps, max=1.0 - eps)
    teacher_probs = torch.clamp(teacher_probs, min=eps, max=1.0 - eps)
    
    log_student = torch.log(student_probs)
    log_teacher = torch.log(teacher_probs)
    
    kl = torch.sum(student_probs * (log_student - log_teacher), dim=-1)
    
    return kl

def train_engine(model, train_loader, val_loader, epochs, lr, device, 
                 patience=3, min_delta=0.001, save_dir=MODEL_SAVE_DIR, 
                 bias_lambda=1.0, fp16=True):
    torch.cuda.empty_cache()
    gc.collect()

    backbone_params = list(model.backbone.named_parameters())
    head_params = list(model.validity_head.named_parameters()) + list(model.selection_head.named_parameters())
    
    no_decay = ["bias", "LayerNorm.weight"]
    
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in backbone_params if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY, 
            "lr": lr
        },
        {
            "params": [p for n, p in backbone_params if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0, 
            "lr": lr
        },
        {
            "params": [p for n, p in head_params if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY, 
            "lr": lr * 10  
        },
        {
            "params": [p for n, p in head_params if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0, 
            "lr": lr * 10
        }
    ]
    
    optimizer = AdamW(optimizer_grouped_parameters)
    num_update_steps_per_epoch = len(train_loader) // ACCUMULATION_STEPS
    if len(train_loader) % ACCUMULATION_STEPS != 0:
        num_update_steps_per_epoch += 1
        
    total_training_steps = num_update_steps_per_epoch * epochs
    num_warmup_steps = int(total_training_steps * WARMUP_RATIO)
    
    print(f" Total Steps: {total_training_steps} | Warmup Steps: {num_warmup_steps} (Ratio: {WARMUP_RATIO})")
    
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=num_warmup_steps, 
        num_training_steps=total_training_steps
    )
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta, mode='max')

    scaler = GradScaler()

    print(f" Training on {device}")
    val_ground_truth = val_loader.dataset.data
    best_ranking_score = 0.0

    MAX_LAMBDA = bias_lambda

    num_groups = 6
    group_weights = torch.ones(num_groups, device=device) / num_groups
    group_lr_lst = [0.0, 0.01, 0.02, 0.05, 0.07, 0.1, 0.1, 0.1, 0.1, 0.1]
    consistency_lambda_lst = [0.1, 0.1, 0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0]

    for epoch in range(epochs):
        if epoch < 1:
            current_lambda = 0.0      
        elif epoch == 1:
            current_lambda = 0.5      
        elif epoch == 2:
            current_lambda = 1.0
        else:
            current_lambda = MAX_LAMBDA

        if epoch < len(group_lr_lst):
            group_lr = group_lr_lst[epoch]
            CONSISTENCY_LAMBDA = consistency_lambda_lst[epoch]
        else:
            group_lr = group_lr_lst[-1]
            CONSISTENCY_LAMBDA = consistency_lambda_lst[-1]

        
        start_time = time.time()

        model.train()
        total_train_loss = 0
        total_train_bias = 0
        total_train_cons = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

        running_group_losses = torch.zeros(num_groups, device=device)
        running_group_counts = torch.zeros(num_groups, device=device)

        for step, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            premise_spans = batch['premise_spans'].to(device)
            premise_mask = batch['premise_mask'].to(device)
            validity_labels = batch['validity_label'].to(device)
            selection_labels = batch['selection_label'].to(device)
            plausibility_labels = batch['plausibility_label'].to(device)

            alt_input_ids = batch['alt_input_ids'].to(device)
            alt_mask = batch['alt_attention_mask'].to(device)
            alt_premise_spans = batch['alt_premise_spans'].to(device)
            alt_premise_mask = batch['alt_premise_mask'].to(device)

            batch_size = input_ids.size(0)

            if step % ACCUMULATION_STEPS == 0:
                optimizer.zero_grad()

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=fp16):
                outputs_main = model(
                    input_ids=input_ids,  
                    attention_mask=mask,
                    premise_spans=premise_spans,
                    premise_mask=premise_mask
                )
                
                outputs_alt = model(
                    input_ids=alt_input_ids,
                    attention_mask=alt_mask,
                    premise_spans=alt_premise_spans,
                    premise_mask=alt_premise_mask
                )
                
                validity_logits = outputs_main['validity_logits']
                alt_validity_logits = outputs_alt['validity_logits']
                selection_logits = outputs_main['selection_logits']
                alt_selection_logits = outputs_alt['selection_logits']


                bce_fct = nn.BCEWithLogitsLoss(reduction='none')
                
                raw_val_loss_main = bce_fct(validity_logits, validity_labels)
                raw_val_loss_alt = bce_fct(alt_validity_logits, validity_labels)
                raw_val_loss = (raw_val_loss_main + raw_val_loss_alt) / 2.0

                group_ids = get_group_ids(validity_labels, plausibility_labels)
                group_losses = torch.zeros(num_groups, device=device)

                for g in range(num_groups):
                    mask_group = group_ids == g
                    if mask_group.any():
                        group_losses[g] = raw_val_loss[mask_group].mean()
                        
                        running_group_losses[g] += raw_val_loss[mask_group].sum().detach()
                        running_group_counts[g] += mask_group.sum().detach()

                loss_fct = nn.BCEWithLogitsLoss(reduction='none')
                sel_loss_raw_main = loss_fct(selection_logits, selection_labels)
                sel_loss_main = (sel_loss_raw_main * premise_mask).sum() / (premise_mask.sum() + 1e-9)
                
                sel_loss_raw_alt = loss_fct(alt_selection_logits, selection_labels)
                sel_loss_alt = (sel_loss_raw_alt * alt_premise_mask).sum() / (alt_premise_mask.sum() + 1e-9)

                sel_loss = (sel_loss_main + sel_loss_alt) / 2.0


                val_loss = torch.sum(group_weights * group_losses)

                bias_penalty = torch.tensor(0.0, device=device)

                probs_main = torch.sigmoid(validity_logits.clamp(min=-15, max=15))
                probs_alt = torch.sigmoid(alt_validity_logits.clamp(min=-15, max=15))
                eps = 1e-7
                probs_main = torch.clamp(probs_main, min=eps, max=1.0 - eps)
                probs_alt = torch.clamp(probs_alt, min=eps, max=1.0 - eps)
                probs = (probs_main + probs_alt) / 2.0

                confidences = torch.where(validity_labels > 0.5, probs, 1.0 - probs)


                intra_diffs = []
                valid_plaus_mask = plausibility_labels != 1
                unique_plaus = torch.unique(plausibility_labels[valid_plaus_mask])

                for p in unique_plaus:
                    mask_p = (plausibility_labels == p) & valid_plaus_mask
                    lbls_p = validity_labels[mask_p]
                    confs_p = confidences[mask_p]

                    if (lbls_p > 0.5).any() and (lbls_p < 0.5).any():
                        mean_conf_valid = confs_p[lbls_p > 0.5].mean()
                        mean_conf_invalid = confs_p[lbls_p < 0.5].mean()
                        diff = F.smooth_l1_loss(mean_conf_valid, mean_conf_invalid, beta=0.1)
                        if not (torch.isnan(diff) or torch.isinf(diff)):
                            intra_diffs.append(diff)

                intra_loss = torch.stack(intra_diffs).mean() if intra_diffs else torch.tensor(0.0, device=device)

                cross_diffs = []
                unique_val = torch.unique(validity_labels)

                for v in unique_val:
                    valid_mask = (plausibility_labels == 0) | (plausibility_labels == 2)
                    mask_v = (torch.abs(validity_labels - v) < 1e-5) & valid_mask

                    lbls_plaus = plausibility_labels[mask_v]
                    confs_v = confidences[mask_v]

                    if (lbls_plaus == 2).any() and (lbls_plaus == 0).any():
                        mean_conf_plaus = confs_v[lbls_plaus == 2].mean()
                        mean_conf_implaus = confs_v[lbls_plaus == 0].mean()
                        diff = F.smooth_l1_loss(mean_conf_plaus, mean_conf_implaus, beta=0.1)
                        if not (torch.isnan(diff) or torch.isinf(diff)):
                            cross_diffs.append(diff)

                cross_loss = torch.stack(cross_diffs).mean() if cross_diffs else torch.tensor(0.0, device=device)

                if intra_diffs and cross_diffs:
                    bias_penalty = (intra_loss + cross_loss) / 2.0
                else:
                    bias_penalty = torch.tensor(0.0, device=device)

                probs_simple_dist = torch.stack([1.0 - probs_main, probs_main], dim=1)
                probs_complex_dist = torch.stack([1.0 - probs_alt, probs_alt], dim=1)
                
                consistency_loss = asymmetric_kl_forward(
                    student_probs=probs_complex_dist,
                    teacher_probs=probs_simple_dist
                ).mean()

                total_step_loss = (val_loss + 
                                  (1.0 * sel_loss) + 
                                  (current_lambda * bias_penalty) + 
                                  (CONSISTENCY_LAMBDA * consistency_loss))
                
                loss = total_step_loss / ACCUMULATION_STEPS

            scaler.scale(loss).backward()


            if (step + 1) % ACCUMULATION_STEPS == 0 or (step + 1) == len(train_loader):
                if epoch >= DRO_WARMUP_EPOCHS:
                    with torch.no_grad():
                        active = running_group_counts > 0
                        avg_group_losses = torch.zeros_like(running_group_losses)
                        avg_group_losses[active] = running_group_losses[active] / running_group_counts[active]
                        
                        update_factor = torch.exp(group_lr * avg_group_losses)
                        group_weights *= update_factor
                        group_weights /= group_weights.sum()
                        group_weights = torch.clamp(group_weights, min=0.01, max=0.8)
                        group_weights /= group_weights.sum()

                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                
                running_group_losses.zero_()
                running_group_counts.zero_()

            total_train_loss += loss.item() * ACCUMULATION_STEPS
            total_train_bias += bias_penalty.item()
            total_train_cons += consistency_loss.item()
            progress_bar.set_postfix({
                'loss': f"{loss.item() * ACCUMULATION_STEPS:.4f}", 
                'bias_pen': f"{bias_penalty.item():.4f}",
                'cons': f"{consistency_loss.item():.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}"
            })

        avg_train_loss = total_train_loss / len(train_loader)
        avg_train_bias = total_train_bias / len(train_loader)
        avg_train_cons = total_train_cons / len(train_loader)

        torch.cuda.empty_cache()
        gc.collect()

        model.eval()
        total_val_loss = 0
        total_val_bias = 0
        total_val_cons = 0
        val_predictions = []
        global_idx = 0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                premise_spans = batch['premise_spans'].to(device)
                premise_mask = batch['premise_mask'].to(device)
                validity_labels = batch['validity_label'].to(device)
                selection_labels = batch['selection_label'].to(device)
                plausibility_labels = batch['plausibility_label'].to(device)

                alt_input_ids = batch['alt_input_ids'].to(device)
                alt_mask = batch['alt_attention_mask'].to(device)
                alt_premise_spans = batch['alt_premise_spans'].to(device)
                alt_premise_mask = batch['alt_premise_mask'].to(device)

                batch_size = input_ids.size(0)

                with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=fp16):
                    outputs_main = model(input_ids, mask, premise_spans, premise_mask)
                    outputs_alt = model(alt_input_ids, alt_mask, alt_premise_spans, alt_premise_mask)

                    validity_logits = outputs_main['validity_logits']
                    alt_validity_logits = outputs_alt['validity_logits']
                    selection_logits = outputs_main['selection_logits']
                    alt_selection_logits = outputs_alt['selection_logits']

                    bce_fct = nn.BCEWithLogitsLoss(reduction='none')
                    raw_val_loss_main = bce_fct(validity_logits, validity_labels)
                    raw_val_loss_alt = bce_fct(alt_validity_logits, validity_labels)
                    raw_val_loss = (raw_val_loss_main + raw_val_loss_alt) / 2.0


                    group_ids = get_group_ids(validity_labels, plausibility_labels)
                    group_losses = torch.zeros(num_groups, device=device)

                    for g in range(num_groups):
                        mask_group = group_ids == g
                        if mask_group.any():
                            group_losses[g] = raw_val_loss[mask_group].mean()


                    loss_fct = nn.BCEWithLogitsLoss(reduction='none')
                    s_loss_raw_main = loss_fct(selection_logits, selection_labels)
                    s_loss_main = (s_loss_raw_main * premise_mask).sum() / (premise_mask.sum() + 1e-9)
                    s_loss_raw_alt = loss_fct(alt_selection_logits, selection_labels)
                    s_loss_alt = (s_loss_raw_alt * alt_premise_mask).sum() / (alt_premise_mask.sum() + 1e-9)
                    s_loss = (s_loss_main + s_loss_alt) / 2.0


                    probs_main = torch.sigmoid(validity_logits.clamp(min=-15, max=15))
                    probs_alt = torch.sigmoid(alt_validity_logits.clamp(min=-15, max=15))
                    eps = 1e-7
                    probs_main = torch.clamp(probs_main, min=eps, max=1.0 - eps)
                    probs_alt = torch.clamp(probs_alt, min=eps, max=1.0 - eps)
                    probs = (probs_main + probs_alt) / 2.0
                    confidences = torch.where(validity_labels > 0.5, probs, 1.0 - probs)

                    bias_penalty = torch.tensor(0.0, device=device)

                    intra_diffs = []
                    valid_plaus_mask = plausibility_labels != 1  
                    unique_plaus = torch.unique(plausibility_labels[valid_plaus_mask])

                    for p in unique_plaus:
                        mask_p = (plausibility_labels == p) & valid_plaus_mask
                        lbls_p = validity_labels[mask_p]
                        confs_p = confidences[mask_p]
                        if (lbls_p > 0.5).any() and (lbls_p < 0.5).any():
                            diff = F.smooth_l1_loss(confs_p[lbls_p > 0.5].mean(), confs_p[lbls_p < 0.5].mean(), beta=0.1)
                            if not (torch.isnan(diff) or torch.isinf(diff)):
                                intra_diffs.append(diff)
                    intra_loss = torch.stack(intra_diffs).mean() if intra_diffs else torch.tensor(0.0, device=device)

                    cross_diffs = []
                    unique_val = torch.unique(validity_labels)

                    for v in unique_val:
                        valid_mask = (plausibility_labels == 0) | (plausibility_labels == 2)
                        mask_v = (torch.abs(validity_labels - v) < 1e-5) & valid_mask

                        lbls_plaus = plausibility_labels[mask_v]
                        confs_v = confidences[mask_v]

                        if (lbls_plaus == 2).any() and (lbls_plaus == 0).any():
                            mean_conf_plaus = confs_v[lbls_plaus == 2].mean()
                            mean_conf_implaus = confs_v[lbls_plaus == 0].mean()
                            diff = F.smooth_l1_loss(mean_conf_plaus, mean_conf_implaus, beta=0.1)
                            if not (torch.isnan(diff) or torch.isinf(diff)):
                                cross_diffs.append(diff)

                    cross_loss = torch.stack(cross_diffs).mean() if cross_diffs else torch.tensor(0.0, device=device)

                    if intra_diffs and cross_diffs:
                        bias_penalty = (intra_loss + cross_loss) / 2.0
                    else:
                        bias_penalty = torch.tensor(0.0, device=device)

                    probs_simple_dist = torch.stack([1.0 - probs_main, probs_main], dim=1)
                    probs_complex_dist = torch.stack([1.0 - probs_alt, probs_alt], dim=1)
                    
                    consistency_loss = asymmetric_kl_forward(
                        student_probs=probs_complex_dist,
                        teacher_probs=probs_simple_dist
                    ).mean()

                    v_loss = torch.sum(group_weights * group_losses)
                    batch_val_loss = v_loss + (1.0 * s_loss) + (current_lambda * bias_penalty) + (CONSISTENCY_LAMBDA * consistency_loss)
                    
                    total_val_loss += batch_val_loss.item()
                    total_val_bias += bias_penalty.item()
                    total_val_cons += consistency_loss.item()

                val_probs = torch.sigmoid(validity_logits)
                sel_probs = torch.sigmoid(selection_logits)

                for i in range(len(val_probs)):
                    current_id = val_ground_truth[global_idx]['id']
                    global_idx += 1
                    p_mask = premise_mask[i].bool()
                    real_s_probs = sel_probs[i][p_mask]

                    result = {'id': current_id}
                    result['validity_pred'] = bool((val_probs[i] >= 0.5).item())

                    if result['validity_pred']:
                        k = min(2, len(real_s_probs))
                        if k > 0:
                            top_indices = torch.argsort(real_s_probs, descending=False)[-k:]
                            result['premises_pred'] = sorted(top_indices.tolist())
                        else:
                            result['premises_pred'] = []
                    else:
                        result['premises_pred'] = []

                    val_predictions.append(result)

        avg_val_loss = total_val_loss / len(val_loader)
        avg_val_bias = total_val_bias / len(val_loader)
        avg_val_cons = total_val_cons / len(val_loader)

        metrics = compute_semeval_metrics(val_predictions, val_ground_truth)
        current_ranking_score = metrics['ranking_score']

        epoch_time = time.time() - start_time

        print(f"\nEpoch {epoch + 1}/{epochs} Completed in {epoch_time:.0f}s")

        print(f"Loss  (Total): Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
        
        print(f"Bias Penalty : Train: {avg_train_bias:.4f} | Val: {avg_val_bias:.4f}")
        
        print(f"Consistency : Train: {avg_train_cons:.4f} | Val: {avg_val_cons:.4f}")

        print(f"Val Acc: {metrics['validity_accuracy']:.2f}% | Premise F1: {metrics['premise_f1']:.2f}%")
        print(f"Metric Bias: {metrics['total_bias']:.4f} (Calculated by SemEval script)")
        print(f"Ranking Score: {metrics['ranking_score']:.4f}")

        if current_ranking_score > best_ranking_score:
            best_ranking_score = current_ranking_score
            save_model_for_inference(model, train_loader.dataset.tokenizer, save_dir,
                                     {'model_name': MODEL_NAME, 'max_len': MAX_LEN, 'buffer_size': BUFFER_SIZE})
            print("    Checkpoint Saved!")

        if early_stopping(current_ranking_score):
            print(f"\n Early Stopping triggered.")
            break

        torch.cuda.empty_cache()
        gc.collect()

    return model

def predict(model, dataloader, device):
    model.eval()
    predictions = []
    raw_data = dataloader.dataset.data
    global_idx = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting"):
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            premise_spans = batch['premise_spans'].to(device)
            premise_mask = batch['premise_mask'].to(device)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(FP16 and device.type == 'cuda')):
                outputs = model(input_ids, mask, premise_spans, premise_mask)

            val_probs = torch.sigmoid(outputs['validity_logits'])
            sel_probs = torch.sigmoid(outputs['selection_logits'])

            for i in range(len(val_probs)):
                current_id = raw_data[global_idx]['id']
                global_idx += 1
                p_mask = premise_mask[i].bool()
                real_s_probs = sel_probs[i][p_mask]

                result = {'id': current_id}
                result['validity_pred'] = bool((val_probs[i] >= 0.5).item())

                if result['validity_pred']:
                    k = min(2, len(real_s_probs))
                    if k > 0:
                        top_indices = torch.argsort(real_s_probs, descending=False)[-k:]
                        result['premises_pred'] = sorted(top_indices.tolist())
                    else:
                        result['premises_pred'] = []
                else:
                    result['premises_pred'] = []

                predictions.append(result)
    return predictions

def compute_group_id(item):
    validity = int(item['validity'])
    pl = item.get('plausibility', 1)
    if isinstance(pl, bool):
        pl = 2 if pl else 0
    elif isinstance(pl, str):
        pl = 1
    return validity * 3 + pl

class StratifiedBatchSampler(Sampler):
    def __init__(self, group_ids, batch_size, num_groups=6):
        self.group_ids = np.array(group_ids)
        self.batch_size = batch_size
        self.num_groups = num_groups
        
        self.group_indices = [np.where(self.group_ids == i)[0] for i in range(num_groups)]
        
        self.samples_per_group = max(1, batch_size // num_groups)
        
        self.num_batches = max(len(inds) // self.samples_per_group for inds in self.group_indices if len(inds) > 0)
        
        self.all_indices = np.arange(len(group_ids))

    def __iter__(self):
        for inds in self.group_indices:
            np.random.shuffle(inds)
            
        group_iters = [itertools.cycle(inds) for inds in self.group_indices if len(inds) > 0]
        
        for _ in range(self.num_batches):
            batch = []
            
            for g_it in group_iters:
                for _ in range(self.samples_per_group):
                    batch.append(next(g_it))
            
            remaining_slots = self.batch_size - len(batch)
            if remaining_slots > 0:
                extra_indices = np.random.choice(self.all_indices, remaining_slots, replace=False)
                batch.extend(extra_indices.tolist())
            
            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    try:
        with open(JSON_FILE_PATH, 'r') as f:
            aug_pool = json.load(f) 
        print(f" Loaded {len(aug_pool)} augmented samples.")

        with open(GOLD_FILE_PATH, 'r') as f:
            gold_pool = json.load(f)
        print(f" Loaded {len(gold_pool)} gold samples.")
    except Exception as e:
        print(f" Error loading data: {e}")
        exit()

    random.seed(42)
    
    gold_stratify = [f"{x['validity']}_{x.get('plausibility', 'neutral')}" for x in gold_pool]
    
    train_gold, temp_gold, _, temp_gold_labels = train_test_split(
        gold_pool, gold_stratify, test_size=0.20, random_state=42, stratify=gold_stratify
    )
    val_gold, test_gold = train_test_split(
        temp_gold, test_size=0.5, random_state=42, stratify=temp_gold_labels
    )

    aug_stratify = [f"{x['validity']}_{x.get('plausibility', 'neutral')}" for x in aug_pool]
    
    train_aug, temp_aug, _, temp_aug_labels = train_test_split(
        aug_pool, aug_stratify, test_size=0.20, random_state=42, stratify=aug_stratify
    )
    val_aug, test_aug = train_test_split(
        temp_aug, test_size=0.5, random_state=42, stratify=temp_aug_labels
    )

    train_data = train_gold + train_aug
    val_data = val_gold + val_aug
    test_data = test_gold + test_aug


    random.shuffle(train_data) 
    random.shuffle(val_data)
    random.shuffle(test_data)

    train_group_ids = [compute_group_id(x) for x in train_data]
    sampler = StratifiedBatchSampler(train_group_ids, BATCH_SIZE)


    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    train_ds = PreparingInput(train_data, tokenizer, max_len=MAX_LEN, buffer_size=BUFFER_SIZE)
    val_ds = PreparingInput(val_data, tokenizer, max_len=MAX_LEN, buffer_size=BUFFER_SIZE)
    test_ds = PreparingInput(test_data, tokenizer, max_len=MAX_LEN, buffer_size=BUFFER_SIZE)

    

    train_loader = DataLoader(
        train_ds,
        batch_sampler=sampler,
        num_workers=NUM_WORKERS,
        pin_memory=True
    )
    val_loader = DataLoader(val_ds, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    model = VariableSyllogismModel(MODEL_NAME, max_len=MAX_LEN).to(device)

    trained_model = train_engine(
        model, train_loader, val_loader,
        epochs=EPOCHS, lr=LEARNING_RATE, device=device,
        patience=PATIENCE, min_delta=MIN_DELTA, save_dir=MODEL_SAVE_DIR,
        bias_lambda=2.0,
        fp16=FP16
    )

    print("\n Loading best model checkpoint for test predictions...")
    best_model, _, _ = load_model_for_inference(MODEL_SAVE_DIR, device=device)
    

    print("\n Predicting on Test Set...")
    test_predictions = predict(best_model, test_loader, device)
    metrics = compute_semeval_metrics(test_predictions, test_data)

    print("\n" + "="*40)
    print(" FINAL TEST RESULTS")
    print("="*40)
    print(f"Validity Acc: {metrics['validity_accuracy']:.2f}%")
    print(f"Premise F1:   {metrics['premise_f1']:.2f}%")
    print(f"Bias:         {metrics['total_bias']:.4f}")
    print(f"Ranking Score: {metrics['ranking_score']:.4f}")

    pred_path = os.path.join(OUTPUT_DIR, "train_predictions.json")
    with open(pred_path, "w") as f:
        json.dump(test_predictions, f, indent=4)
    print(f"\n Predictions saved to {pred_path}")

    gt_path = os.path.join(OUTPUT_DIR, "train_ground_truth.json")
    with open(gt_path, "w") as f:
        json.dump(test_data, f, indent=4)
    print(f" Ground Truth saved to {gt_path}")

    print("\n Zipping model files for download...")
    shutil.make_archive(os.path.join(OUTPUT_DIR, "subtask_2_trained_model"), 'zip', MODEL_SAVE_DIR)
    print(" Created 'subtask_2_trained_model_pack.zip'")


Using device: cuda
GPU: Tesla P100-PCIE-16GB
GPU Memory: 17.06 GB
 Loaded 27028 augmented samples.
 Loaded 949 gold samples.


tokenizer_config.json:   0%|          | 0.00/395 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/18.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

2026-02-15 18:07:40.147072: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771178860.335917      24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771178860.386567      24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771178860.828457      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771178860.828501      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771178860.828504      24 computation_placer.cc:177] computation placer alr

model.safetensors:   0%|          | 0.00/870M [00:00<?, ?B/s]

 Total Steps: 3800 | Warmup Steps: 228 (Ratio: 0.06)
 Training on cuda


  scaler = GradScaler()


Epoch 1/10:   0%|          | 0/3795 [00:00<?, ?it/s]


Epoch 1/10 Completed in 4687s
Loss  (Total): Train: 0.7887 | Val: 0.2889
Bias Penalty : Train: 0.1279 | Val: 0.0709
Consistency : Train: 0.0764 | Val: 0.1323
Val Acc: 94.46% | Premise F1: 97.71%
Metric Bias: 5.6962 (Calculated by SemEval script)
Ranking Score: 33.1160
 Full Model (Backbone + Heads) saved to '/kaggle/working/trained_model/'
    Checkpoint Saved!


Epoch 2/10:   0%|          | 0/3795 [00:00<?, ?it/s]


Epoch 2/10 Completed in 4696s
Loss  (Total): Train: 0.1836 | Val: 0.0969
Bias Penalty : Train: 0.0529 | Val: 0.0170
Consistency : Train: 0.1069 | Val: 0.0843
Val Acc: 98.68% | Premise F1: 99.82%
Metric Bias: 1.5825 (Calculated by SemEval script)
Ranking Score: 50.9296
 Full Model (Backbone + Heads) saved to '/kaggle/working/trained_model/'
    Checkpoint Saved!


Epoch 3/10:   0%|          | 0/3795 [00:00<?, ?it/s]


Epoch 3/10 Completed in 4694s
Loss  (Total): Train: 0.0806 | Val: 0.0524
Bias Penalty : Train: 0.0130 | Val: 0.0049
Consistency : Train: 0.0458 | Val: 0.0311
Val Acc: 99.68% | Premise F1: 99.86%
Metric Bias: 0.5274 (Calculated by SemEval script)
Ranking Score: 70.0821
 Full Model (Backbone + Heads) saved to '/kaggle/working/trained_model/'
    Checkpoint Saved!


Epoch 4/10:   0%|          | 0/3795 [00:00<?, ?it/s]


Epoch 4/10 Completed in 4674s
Loss  (Total): Train: nan | Val: nan
Bias Penalty : Train: 0.0065 | Val: 0.0057
Consistency : Train: nan | Val: nan
Val Acc: 99.46% | Premise F1: 99.93%
Metric Bias: 0.4219 (Calculated by SemEval script)
Ranking Score: 73.7386
 Full Model (Backbone + Heads) saved to '/kaggle/working/trained_model/'
    Checkpoint Saved!


Epoch 5/10:   0%|          | 0/3795 [00:00<?, ?it/s]


Epoch 5/10 Completed in 4646s
Loss  (Total): Train: nan | Val: nan
Bias Penalty : Train: 0.0000 | Val: 0.0000
Consistency : Train: nan | Val: nan
Val Acc: 49.96% | Premise F1: 0.00%
Metric Bias: 50.0000 (Calculated by SemEval script)
Ranking Score: 5.0655


Epoch 6/10:   0%|          | 0/3795 [00:00<?, ?it/s]


Epoch 6/10 Completed in 4647s
Loss  (Total): Train: nan | Val: nan
Bias Penalty : Train: 0.0000 | Val: 0.0000
Consistency : Train: nan | Val: nan
Val Acc: 49.96% | Premise F1: 0.00%
Metric Bias: 50.0000 (Calculated by SemEval script)
Ranking Score: 5.0655

 Early Stopping triggered.

 Loading best model checkpoint for test predictions...

 Predicting on Test Set...


Predicting:   0%|          | 0/467 [00:00<?, ?it/s]


 FINAL TEST RESULTS
Validity Acc: 99.64%
Premise F1:   100.00%
Bias:         0.5274
Ranking Score: 70.1197

 Predictions saved to /kaggle/working/train_predictions.json
 Ground Truth saved to /kaggle/working/train_ground_truth.json

 Zipping model files for download...
 Created 'subtask_2_trained_model_pack.zip'


In [2]:
print("\n" + "="*40)
print(" RUNNING INFERENCE ON UNLABELED TEST DATA")
print("="*40)

try:
    with open("/kaggle/input/datasets/tahazahid001/test-data-subtask-2/test_data_subtask_2.json", "r") as f:
        unlabeled_test_data = json.load(f)
    print(f" Loaded {len(unlabeled_test_data)} unlabeled test samples.")
except Exception as e:
    print(f" Could not load unlabeled test data: {e}")
    unlabeled_test_data = None

if unlabeled_test_data:
    predictions = []
    split_pattern = r'[\.\u3002\u0964]+'
    
    for i in tqdm(range(0, len(unlabeled_test_data), EVAL_BATCH_SIZE), desc="Inference"):
        batch_items = unlabeled_test_data[i : i + EVAL_BATCH_SIZE]
        ids = [item['id'] for item in batch_items]
        
        b_input_ids, b_masks, b_spans, b_pmasks = [], [], [], []
        
        for item in batch_items:
            raw_text = item['syllogism']
            sentences = [s.strip() for s in re.split(split_pattern, raw_text) if s.strip()]
            
            if not sentences:
                sentences = ["Empty"]
            
            conclusion = sentences[-1]
            premises = sentences[:-1]
            premises_text = tokenizer.sep_token.join(premises)

            encoding = tokenizer(
                conclusion,
                premises_text,
                truncation=True,
                max_length=MAX_LEN,
                padding="max_length",
                return_tensors="pt"
            )

            input_ids = encoding['input_ids'].squeeze(0)
            attention_mask = encoding['attention_mask'].squeeze(0)

            premise_spans = torch.zeros(BUFFER_SIZE, 2, dtype=torch.long)
            premise_mask = torch.zeros(BUFFER_SIZE, dtype=torch.float)
            
            sep_id = tokenizer.sep_token_id
            all_sep_indices = (input_ids == sep_id).nonzero(as_tuple=True)[0]
        
            if len(all_sep_indices) > 1:
                sep_positions = all_sep_indices.tolist()
                segment_starts = sep_positions[:-1]
                segment_ends = sep_positions[1:]

                num_premises = len(segment_starts)
                limit = min(num_premises, BUFFER_SIZE)
        
                for j in range(limit):
                    start_pos = segment_starts[j] + 1
                    end_pos = segment_ends[j] - 1
                    
                    if end_pos >= start_pos:
                        premise_spans[j, 0] = start_pos
                        premise_spans[j, 1] = end_pos
                        premise_mask[j] = 1.0
        
            b_input_ids.append(input_ids)
            b_masks.append(attention_mask)
            b_spans.append(premise_spans)
            b_pmasks.append(premise_mask)
        
        input_ids = torch.stack(b_input_ids).to(device)
        mask = torch.stack(b_masks).to(device)
        premise_spans = torch.stack(b_spans).to(device)
        premise_mask = torch.stack(b_pmasks).to(device)
        
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(FP16 and device.type == "cuda")):
                outputs = best_model(input_ids, mask, premise_spans, premise_mask)

            val_logits = outputs['validity_logits']
            sel_logits = outputs['selection_logits']
            
            val_probs = torch.sigmoid(val_logits).cpu().numpy()
            sel_probs = torch.sigmoid(sel_logits).cpu().numpy()
        
        for idx in range(len(batch_items)):
            pred_label = bool(val_probs[idx] >= 0.5)  
            
            pred_indices = []
            if pred_label:
                p_mask = premise_mask[idx].bool().cpu().numpy()
                valid_sel_probs = sel_probs[idx][p_mask]
                
                k = min(2, len(valid_sel_probs))
                if k > 0:
                    top_k_indices = np.argsort(valid_sel_probs)[-k:]
                    pred_indices = sorted(top_k_indices.tolist())
                
            predictions.append({
                "id": ids[idx],
                "validity": pred_label,
                "relevant_premises": pred_indices
            })

    final_pred_path = os.path.join(OUTPUT_DIR, "predictions.json")
    with open(final_pred_path, "w") as f:
        json.dump(predictions, f, indent=4)

    print(f"\n Saved {len(predictions)} predictions to {final_pred_path}")


 RUNNING INFERENCE ON UNLABELED TEST DATA
 Loaded 192 unlabeled test samples.


Inference:   0%|          | 0/32 [00:00<?, ?it/s]


 Saved 192 predictions to /kaggle/working/predictions.json
