In [None]:
!pip install -q transformers sentencepiece scikit-learn

import re
import json
import torch
import torch.nn as nn
import torch.nn.functional as F 
import time
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
import os
import shutil
import math
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
import itertools
import gc
# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Softening temperature for the teacher/student KL consistency loss.
KL_TEMPERATURE = 2.0

# Prefer deterministic cuDNN kernels and disable cuDNN autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Pretrained backbone identifier passed to the Hugging Face loaders.
MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"

# Optimisation / data-loading hyper-parameters.
BATCH_SIZE = 32
EVAL_BATCH_SIZE = 32
ACCUMULATION_STEPS = 1
EPOCHS = 10
LEARNING_RATE = 2e-5
WARMUP_RATIO = 0.06
WEIGHT_DECAY = 0.01
NUM_WORKERS = 2
FP16 = True

# Group-DRO weights are only updated from this epoch index onwards.
DRO_WARMUP_EPOCHS = 1
# Tokenised length of each (conclusion, premises) pair.
MAX_LEN = 96
# Early stopping: epochs without an improvement larger than MIN_DELTA.
PATIENCE = 2
MIN_DELTA = 0.001

# Filesystem layout (Kaggle paths).
OUTPUT_DIR = "/kaggle/working"
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "trained_model_task1")
JSON_FILE_PATH_AUG = "/kaggle/input/st1-exp-dataset/st1_combined.json"
JSON_FILE_PATH_GOLD = "/kaggle/input/task-1-dataset-v3/gold.json"


def safe_mean(tensor, eps=1e-8):
    """Return ``tensor.mean()``, falling back to a zero scalar when that is not usable.

    The fallback (same device and dtype as ``tensor``) is used when the tensor
    is empty or when its mean is NaN/Inf; the latter case also prints a
    warning. ``eps`` is accepted but not used.
    """
    if tensor.numel() == 0:
        return torch.zeros((), device=tensor.device, dtype=tensor.dtype)

    mean_value = tensor.mean()
    if not torch.isfinite(mean_value):
        print(" WARNING: NaN/Inf detected in mean computation. Returning 0.")
        return torch.zeros((), device=tensor.device, dtype=tensor.dtype)
    return mean_value

def safe_division(numerator, denominator, eps=1e-8):
    """Element-wise ``numerator / (denominator + eps)`` with a NaN/Inf guard.

    If any element of the quotient is NaN or Inf, a warning showing both
    operands is printed and an all-zero tensor shaped like the quotient is
    returned instead.
    """
    quotient = numerator / (denominator + eps)
    if torch.isfinite(quotient).all():
        return quotient
    print(f" WARNING: NaN/Inf in division. Numerator: {numerator}, Denominator: {denominator}")
    return torch.zeros_like(quotient)

def check_tensor_health(tensor, name="tensor", verbose=False):
    """Return True if ``tensor`` holds no NaN/Inf values, otherwise report and return False.

    On failure an alert naming ``name`` is printed; with ``verbose=True`` the
    tensor's min, max and mean are printed as well.
    """
    nan_found = torch.isnan(tensor).any()
    inf_found = torch.isinf(tensor).any()

    healthy = not (nan_found or inf_found)
    if healthy:
        return True

    print(f" ALERT: {name} contains NaN: {nan_found}, Inf: {inf_found}")
    if verbose:
        print(f" Stats - Min: {tensor.min()}, Max: {tensor.max()}, Mean: {tensor.mean()}")
    return False


class SyllogismDataset(Dataset):
    """Dataset turning syllogism records into tokenised (conclusion, premises) pairs.

    Each record is a dict. Records flagged ``is_gold`` (or any record with a
    ``"syllogism"`` key) use that single text for both the primary and the
    alternative view. Other records use ``"syllogism_simple"`` as the primary
    view and ``"syllogism_complex"`` (falling back to the simple text) as the
    alternative view. ``"validity"`` is required; ``"plausibility"`` defaults
    to ``"neutral"``.
    """

    def __init__(self, data, tokenizer, max_len=256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Sentence delimiters: ASCII full stop, ideographic full stop (U+3002)
        # and Devanagari danda (U+0964).
        self.split_pattern = r'[\.\u3002\u0964]+'

    @staticmethod
    def encode_plausibility_index(p):
        """Map a plausibility label to 0 (implausible), 1 (neutral/unknown) or 2 (plausible).

        Booleans map directly. Strings are matched case-insensitively after
        stripping. Anything unrecognised counts as neutral.
        """
        if isinstance(p, bool):
            return 2 if p else 0
        if isinstance(p, str):
            t = p.strip().lower()
            if t in ("neutral", "neither", "uncertain"):
                return 1
            if t in ("true", "plausible", "yes"):
                return 2
            if t in ("false", "implausible", "no"):
                return 0
        return 1

    @staticmethod
    def encode_group_id(validity, plausibility):
        """Combine validity and plausibility into one of six ids: ``3 * validity + plausibility_index``."""
        v_idx = 1 if bool(validity) else 0
        p_idx = SyllogismDataset.encode_plausibility_index(plausibility)
        return v_idx * 3 + p_idx

    def __len__(self):
        return len(self.data)

    def process_syllogism(self, raw_text):
        """Tokenise one syllogism as the pair (conclusion, premises joined by ``sep_token``).

        The text is split on sentence delimiters. The last sentence is the
        conclusion and the others are the premises. Empty or missing (``None``)
        text becomes the single placeholder sentence ``"Empty"``.

        Returns:
            ``(input_ids, attention_mask)`` with the batch dimension squeezed
            out, padded/truncated to ``self.max_len``.
        """
        # `or ""` guards against records whose text field is missing/None.
        sentences = [s.strip() for s in re.split(self.split_pattern, raw_text or "") if s.strip()]
        if not sentences:
            sentences = ["Empty"]

        conclusion = sentences[-1]
        premises_text = self.tokenizer.sep_token.join(sentences[:-1])

        encoding = self.tokenizer(
            conclusion,
            premises_text,
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt",
        )
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        item = self.data[idx]

        if item.get("is_gold") or "syllogism" in item:
            primary_text = item.get("syllogism")
            alt_text = primary_text
        else:
            primary_text = item.get("syllogism_simple", "")
            alt_text = item.get("syllogism_complex", primary_text)

        input_ids, attention_mask = self.process_syllogism(primary_text)
        if alt_text == primary_text:
            # Identical text tokenises identically: reuse the encoding instead
            # of running the tokenizer a second time.
            alt_input_ids, alt_attention_mask = input_ids.clone(), attention_mask.clone()
        else:
            alt_input_ids, alt_attention_mask = self.process_syllogism(alt_text)

        validity = 1.0 if item["validity"] else 0.0
        group_id = SyllogismDataset.encode_group_id(item["validity"], item.get("plausibility", "neutral"))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            'alt_input_ids': alt_input_ids,
            'alt_attention_mask': alt_attention_mask,
            "validity_label": torch.tensor(validity, dtype=torch.float),
            "group_id": torch.tensor(group_id, dtype=torch.long),
            "id": item.get("id", str(idx)),
        }


class EarlyStopping:
    """Signal when a monitored score has stopped improving.

    With ``mode='max'`` larger scores are better; any other mode treats
    smaller scores as better. A new score only counts as an improvement when
    it beats the best score so far by more than ``min_delta``. After
    ``patience`` consecutive non-improving calls, ``early_stop`` becomes True
    and stays True.
    """

    def __init__(self, patience=3, min_delta=0.001, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def _beats_best(self, score):
        """Return True if ``score`` improves on ``best_score`` by more than ``min_delta``."""
        if self.mode == 'max':
            return score > self.best_score + self.min_delta
        return score < self.best_score - self.min_delta

    def __call__(self, score):
        """Record ``score`` and return the (latched) early-stop flag."""
        # The very first observation only sets the baseline.
        if self.best_score is None:
            self.best_score = score
            return False

        if not self._beats_best(score):
            self.counter += 1
            self.early_stop = self.early_stop or self.counter >= self.patience
            return self.early_stop

        self.best_score = score
        self.counter = 0
        return self.early_stop

class BinarySyllogismModel(nn.Module):
    """Pretrained encoder with a single-logit head on the first token's hidden state.

    Args:
        model_name: identifier passed to ``AutoModel.from_pretrained``.
        dropout_rate: dropout probability applied before the head's first linear layer.
    """

    def __init__(self, model_name, dropout_rate=0.1):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        width = self.backbone.config.hidden_size

        # Dropout -> Linear -> Tanh -> Linear(1). The submodule order matters:
        # it determines the state_dict keys of saved checkpoints.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 1),
        )

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        """Return one raw (pre-sigmoid) logit per example, shape ``(batch,)``."""
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        ).last_hidden_state
        first_token = hidden_states[:, 0, :]
        return self.classifier(first_token).squeeze(-1)

def compute_task1_metrics(predictions, ground_truth):
    """Score validity predictions: accuracy, content-effect bias and ranking score.

    Args:
        predictions: dicts with ``"id"`` and ``"validity_pred"`` (``"validity"``
            is used when ``"validity_pred"`` is absent).
        ground_truth: dicts with ``"id"``, ``"validity"`` and optionally
            ``"plausibility"`` (bool or string; a missing value counts as False).

    Predictions whose id is not in ``ground_truth`` are ignored. Items whose
    plausibility string is neutral/neither/uncertain, or not recognised, are
    excluded from every metric, including accuracy.

    Content effect is measured over the four (validity, plausibility)
    subgroups:
      * ``intra`` averages the plausible-vs-implausible accuracy gaps within
        each validity class;
      * ``inter`` averages the valid-vs-invalid gaps within each plausibility
        class.
    Only gaps whose two subgroups both contain items are averaged. An empty
    subgroup has no accuracy, and scoring it as 0% would add a spurious
    ~100-point gap. The ranking score is
    ``accuracy / (1 + ln(1 + total_bias))``.

    Returns:
        dict with ``accuracy`` (percent), ``content_effect_intra``,
        ``content_effect_inter``, ``total_bias`` (percentage points) and
        ``ranking_score``.
    """
    gt_map = {item['id']: item for item in ground_truth}

    correct = 0
    total = 0
    # (validity, plausibility) -> [n_correct, n_total]
    subgroups = {(v, p): [0, 0] for v in (True, False) for p in (True, False)}

    for pred in predictions:
        item = gt_map.get(pred['id'])
        if item is None:
            continue

        plaus = item.get('plausibility', False)
        if isinstance(plaus, str):
            label = plaus.strip().lower()
            if label in ("true", "plausible", "yes"):
                plaus = True
            elif label in ("false", "implausible", "no"):
                plaus = False
            else:
                # neutral / neither / uncertain / unrecognised: not scored.
                continue

        pred_val = pred.get('validity_pred', pred.get('validity'))
        true_val = item['validity']
        hit = pred_val == true_val

        total += 1
        if hit:
            correct += 1

        bucket = subgroups.get((true_val, plaus))
        if bucket is not None:
            bucket[1] += 1
            if hit:
                bucket[0] += 1

    accuracy = (correct / total * 100) if total > 0 else 0.0

    # Subgroup accuracy in percent, or None when the subgroup is empty.
    group_acc = {
        key: (n_correct / n_total * 100 if n_total > 0 else None)
        for key, (n_correct, n_total) in subgroups.items()
    }

    def mean_gap(pairs):
        """Mean absolute accuracy gap over pairs whose subgroups are both non-empty (0.0 if none)."""
        gaps = [
            abs(group_acc[a] - group_acc[b])
            for a, b in pairs
            if group_acc[a] is not None and group_acc[b] is not None
        ]
        return sum(gaps) / len(gaps) if gaps else 0.0

    content_effect_intra = mean_gap([
        ((True, True), (True, False)),    # valid: plausible vs implausible
        ((False, True), (False, False)),  # invalid: plausible vs implausible
    ])
    content_effect_inter = mean_gap([
        ((True, True), (False, True)),    # plausible: valid vs invalid
        ((True, False), (False, False)),  # implausible: valid vs invalid
    ])

    tot_content_effect = (content_effect_intra + content_effect_inter) / 2.0

    log_penalty = math.log(1 + tot_content_effect)
    combined_score = accuracy / (1 + log_penalty)

    return {
        "accuracy": accuracy,
        "content_effect_intra": content_effect_intra,
        "content_effect_inter": content_effect_inter,
        "total_bias": tot_content_effect,
        "ranking_score": combined_score,
    }

def save_model_for_inference(model, tokenizer, save_dir, config_dict):
    """Persist a checkpoint plus what is needed to rebuild it.

    Inside ``save_dir`` (created if needed) this writes:
      * the model's ``state_dict`` to ``pytorch_model.bin``;
      * ``config_dict`` as indented JSON to ``config.json``;
      * the tokenizer files, via ``tokenizer.save_pretrained``.
    """
    os.makedirs(save_dir, exist_ok=True)
    weights_path = os.path.join(save_dir, 'pytorch_model.bin')
    config_path = os.path.join(save_dir, 'config.json')

    torch.save(model.state_dict(), weights_path)
    with open(config_path, 'w') as handle:
        json.dump(config_dict, handle, indent=2)
    tokenizer.save_pretrained(save_dir)

    print(f" Full Model (Backbone + Heads) saved to '{save_dir}/'")

def load_model_for_inference(model_class, save_dir, device):
    """Rebuild a model saved by ``save_model_for_inference`` and move it to ``device``.

    ``config.json`` in ``save_dir`` must provide ``model_name`` and
    ``dropout_rate``. These are passed to ``model_class``, then the weights in
    ``pytorch_model.bin`` are loaded with ``load_state_dict`` (default strict
    key matching).
    """
    with open(os.path.join(save_dir, 'config.json'), 'r') as handle:
        saved_config = json.load(handle)

    model = model_class(saved_config['model_name'], dropout_rate=saved_config['dropout_rate'])
    weights = torch.load(os.path.join(save_dir, 'pytorch_model.bin'), map_location=device)
    model.load_state_dict(weights)
    model.to(device)

    print(f" Best model loaded from '{save_dir}/'")
    return model


def kl_consistency_loss(teacher_logits, student_logits, temperature=None):
    """KL(teacher || student) between the Bernoulli distributions defined by two logit tensors.

    The main/simple syllogism is the TEACHER and the alt/complex syllogism is
    the STUDENT. Each logit ``z`` defines the two-class distribution
    ``[P(invalid), P(valid)] = [sigmoid(-z/T), sigmoid(z/T)]``. The student is
    trained to match the teacher's softened distribution. Teacher logits are
    detached, so gradients flow only through ``student_logits``.

    Log-probabilities are computed in float32 with ``logsigmoid`` rather than
    as ``log(clamp(1 - sigmoid(z)))``. Under fp16 autocast, ``1 - sigmoid``
    saturates to 0 for confident logits and a 1e-8 clamp rounds to 0 in fp16.
    That yields -inf log-probabilities and NaN losses.

    Args:
        teacher_logits: logits from the main/simple syllogism (no gradient needed).
        student_logits: logits from the alt/complex syllogism (gradients flow here).
        temperature: softening temperature T (>1 softens, 1 is standard).
            ``None`` (default) uses the module-level ``KL_TEMPERATURE``.

    Returns:
        Scalar loss: KL summed over the two classes, divided by the batch size
        (``reduction='batchmean'``) and scaled by ``T ** 2``.
    """
    T = KL_TEMPERATURE if temperature is None else temperature

    teacher_scaled = teacher_logits.detach().float() / T
    student_scaled = student_logits.float() / T

    # log [P(invalid), P(valid)], using log(1 - sigmoid(x)) == logsigmoid(-x).
    teacher_log_dist = torch.stack([F.logsigmoid(-teacher_scaled), F.logsigmoid(teacher_scaled)], dim=-1)
    student_log_dist = torch.stack([F.logsigmoid(-student_scaled), F.logsigmoid(student_scaled)], dim=-1)

    # log_target=True keeps both sides in log-space, so no 0 * log(0) terms arise.
    loss = F.kl_div(student_log_dist, teacher_log_dist, reduction='batchmean', log_target=True) * (T ** 2)
    return loss

def _optimizer_param_groups(model, lr):
    """AdamW groups: backbone at ``lr``, classifier at ``10 * lr``; no decay on biases/LayerNorm weights."""
    no_decay = ("bias", "LayerNorm.weight")
    groups = []
    for module, module_lr in ((model.backbone, lr), (model.classifier, lr * 10)):
        named = list(module.named_parameters())
        decayed = [p for n, p in named if not any(nd in n for nd in no_decay)]
        undecayed = [p for n, p in named if any(nd in n for nd in no_decay)]
        groups.append({"params": decayed, "weight_decay": WEIGHT_DECAY, "lr": module_lr})
        groups.append({"params": undecayed, "weight_decay": 0.0, "lr": module_lr})
    return groups


def _group_loss_stats(per_sample_loss, group_ids, num_groups, device):
    """Per-group (mean, detached sum, count) of ``per_sample_loss``; absent groups get zeros.

    The means keep their autograd graph so they can be weighted into the loss.
    """
    means = torch.zeros(num_groups, device=device)
    sums = torch.zeros(num_groups, device=device)
    counts = torch.zeros(num_groups, device=device)
    for g in range(num_groups):
        mask_g = group_ids == g
        if mask_g.any():
            group_loss = per_sample_loss[mask_g]
            means[g] = group_loss.mean()
            sums[g] = group_loss.sum().detach()
            counts[g] = mask_g.sum()
    return means, sums, counts


def _content_bias_penalty(confidences, validity_labels, plausibility_labels, device):
    """Batch-level differentiable proxy for the content-effect bias.

    ``confidences`` is the probability assigned to each item's gold label.
    Plausibility index 1 (neutral) is ignored.
      * intra: smooth-L1 gap between the mean confidence on valid and on
        invalid items, within each plausibility class (0, then 2).
      * cross: gap between plausible and implausible items, within each
        validity class (invalid, then valid).
    A gap only counts when both sides are present in the batch; non-finite
    gaps are dropped. Returns ``(intra + cross) / 2`` as a scalar tensor.
    """
    is_valid = validity_labels > 0.5

    def mean_gap(pairs):
        gaps = []
        for side_a, side_b in pairs:
            if side_a.any() and side_b.any():
                gap = F.smooth_l1_loss(confidences[side_a].mean(), confidences[side_b].mean(), beta=0.1)
                if torch.isfinite(gap):
                    gaps.append(gap)
        return torch.stack(gaps).mean() if gaps else torch.tensor(0.0, device=device)

    intra = mean_gap(
        ((plausibility_labels == p) & is_valid, (plausibility_labels == p) & ~is_valid)
        for p in (0, 2)
    )
    cross = mean_gap(
        ((plausibility_labels == 2) & v_mask, (plausibility_labels == 0) & v_mask)
        for v_mask in (~is_valid, is_valid)
    )
    return (intra + cross) / 2.0


def train_engine(model, train_loader, val_loader, epochs, lr, device,
                 patience=3, min_delta=0.001, save_dir=MODEL_SAVE_DIR, bias_lambda=1.0, fp16=True, num_groups=6):
    """Fine-tune ``model`` with group-DRO, a content-bias penalty and a consistency loss.

    Every training batch is scored on two views of each syllogism (primary
    and alternative inputs). The step loss is

        dot(group_weights, per-group mean BCE)
        + bias_coef(epoch) * bias_penalty
        + consistency_coef(epoch) * KL(primary || alternative)

    The per-epoch coefficients follow hard-coded schedules; the bias
    coefficient ramps 0 -> 0.5 -> 1.0 -> ``bias_lambda``. From epoch
    ``DRO_WARMUP_EPOCHS`` on, the group weights are shifted towards groups
    with a higher average loss at every optimiser step.

    After each epoch the model is evaluated on ``val_loader``. Whenever the
    ranking score (``compute_task1_metrics``) beats the best so far, a
    checkpoint is written to ``save_dir``. Training stops early once the
    ranking score has not improved by more than ``min_delta`` for
    ``patience`` epochs.

    Args:
        model: module with ``backbone`` and ``classifier`` submodules mapping
            ``(input_ids, attention_mask)`` to one logit per example.
        train_loader, val_loader: yield dicts with ``input_ids``,
            ``attention_mask``, ``alt_input_ids``, ``alt_attention_mask``,
            ``validity_label``, ``group_id`` (and ``id`` for validation).
            ``val_loader.dataset.data`` is used as the validation ground truth.
        epochs, lr: number of epochs and base learning rate.
        device: device the batches are moved to.
        fp16: enable fp16 autocast and gradient scaling.
        num_groups: number of group ids used for group-DRO.

    Returns:
        The model as it is after the last epoch run. The best checkpoint is
        the one saved on disk.
    """
    torch.cuda.empty_cache()
    gc.collect()

    optimizer = AdamW(_optimizer_param_groups(model, lr), lr=lr)
    num_train_batches = len(train_loader)
    # Optimiser steps per epoch; a trailing partial accumulation window also steps.
    num_update_steps_per_epoch = math.ceil(num_train_batches / ACCUMULATION_STEPS)
    total_training_steps = num_update_steps_per_epoch * epochs
    num_warmup_steps = int(total_training_steps * WARMUP_RATIO)

    print(f"Total Steps: {total_training_steps} | Warmup Steps: {num_warmup_steps} (Ratio: {WARMUP_RATIO})")

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_training_steps,
    )
    early_stopping = EarlyStopping(patience=patience, min_delta=min_delta, mode='max')
    # Loss scaling is only needed when running fp16 autocast.
    scaler = GradScaler(enabled=fp16)
    bce = nn.BCEWithLogitsLoss(reduction='none')

    print(f"Training on {device}")
    val_ground_truth = val_loader.dataset.data
    best_ranking_score = 0.0

    group_weights = torch.ones(num_groups, device=device) / num_groups
    # Per-epoch schedules; epochs past the end of a list reuse its last entry.
    bias_lambda_ramp = (0.0, 0.5, 1.0)
    group_lr_schedule = [0.0, 0.01, 0.02, 0.05, 0.07, 0.1, 0.1, 0.1, 0.1, 0.1]
    consistency_lambda_schedule = [0.1, 0.1, 0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0]

    for epoch in range(epochs):
        current_lambda = bias_lambda_ramp[epoch] if epoch < len(bias_lambda_ramp) else bias_lambda
        schedule_idx = min(epoch, len(group_lr_schedule) - 1)
        group_lr = group_lr_schedule[schedule_idx]
        consistency_lambda = consistency_lambda_schedule[schedule_idx]

        # ----------------------------- training -----------------------------
        model.train()
        total_train_loss = 0.0
        total_train_bias = 0.0
        total_train_cons = 0.0
        running_group_losses = torch.zeros(num_groups, device=device)
        running_group_counts = torch.zeros(num_groups, device=device)
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

        for step, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            alt_input_ids = batch['alt_input_ids'].to(device)
            alt_attention_mask = batch['alt_attention_mask'].to(device)
            validity_labels = batch['validity_label'].to(device)
            group_ids = batch["group_id"].to(device)
            # group_id = 3 * validity + plausibility_index (see SyllogismDataset.encode_group_id).
            plausibility_labels = group_ids % 3

            if step % ACCUMULATION_STEPS == 0:
                optimizer.zero_grad()

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=fp16):
                logits = model(input_ids=input_ids, attention_mask=attention_mask)
                alt_logits = model(input_ids=alt_input_ids, attention_mask=alt_attention_mask)

                per_sample_loss = (bce(logits, validity_labels) + bce(alt_logits, validity_labels)) / 2
                group_losses, group_sums, group_counts = _group_loss_stats(
                    per_sample_loss, group_ids, num_groups, device
                )
                running_group_losses += group_sums
                running_group_counts += group_counts

                probs_main = torch.clamp(torch.sigmoid(logits), min=1e-7, max=1 - 1e-7)
                probs_alt = torch.clamp(torch.sigmoid(alt_logits), min=1e-7, max=1 - 1e-7)
                probs = (probs_main + probs_alt) / 2.0
                confidences = torch.where(validity_labels > 0.5, probs, 1.0 - probs)
                bias_penalty = _content_bias_penalty(confidences, validity_labels, plausibility_labels, device)

                consistency_loss = kl_consistency_loss(
                    teacher_logits=logits,
                    student_logits=alt_logits,
                    temperature=KL_TEMPERATURE,
                )

                dro_component = torch.dot(group_weights.detach(), group_losses)
                total_step_loss = (
                    dro_component
                    + current_lambda * bias_penalty
                    + consistency_lambda * consistency_loss
                )
                loss = total_step_loss / float(ACCUMULATION_STEPS)

            scaler.scale(loss).backward()

            # Step at the end of each accumulation window, and also on the last
            # batch so a trailing partial window is not silently dropped.
            if (step + 1) % ACCUMULATION_STEPS == 0 or step + 1 == num_train_batches:
                if epoch >= DRO_WARMUP_EPOCHS:
                    with torch.no_grad():
                        active = running_group_counts > 0
                        avg_group_losses = torch.zeros_like(running_group_losses)
                        avg_group_losses[active] = running_group_losses[active] / running_group_counts[active]

                        # Bound the step before applying it.
                        # NOTE(review): the +/-10 clamp suggests an exponentiated
                        # update (w * exp(step)) as in Group DRO may have been
                        # intended; the additive form is kept to preserve dynamics.
                        scaled_losses = torch.clamp(group_lr * avg_group_losses, min=-10, max=10)
                        group_weights = group_weights + scaled_losses
                        group_weights = torch.clamp(group_weights, min=0.01)
                        group_weights = group_weights / group_weights.sum()

                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                running_group_losses.zero_()
                running_group_counts.zero_()

            step_loss = loss.item() * ACCUMULATION_STEPS
            total_train_loss += step_loss
            total_train_bias += bias_penalty.item()
            total_train_cons += consistency_loss.item()
            progress_bar.set_postfix({
                'loss': f"{step_loss:.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}",
            })

        train_denominator = max(num_train_batches, 1)
        avg_train_loss = total_train_loss / train_denominator
        avg_train_bias = total_train_bias / train_denominator
        avg_train_cons = total_train_cons / train_denominator

        torch.cuda.empty_cache()
        gc.collect()

        # ---------------------------- validation ----------------------------
        model.eval()
        val_predictions = []
        total_val_loss = 0.0
        total_val_bias = 0.0
        total_val_cons = 0.0

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                alt_input_ids = batch['alt_input_ids'].to(device)
                alt_attention_mask = batch['alt_attention_mask'].to(device)
                validity_labels = batch['validity_label'].to(device)
                group_ids = batch["group_id"].to(device)
                plausibility_labels = group_ids % 3

                with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=fp16):
                    logits = model(input_ids, attention_mask)
                    alt_logits = model(alt_input_ids, alt_attention_mask)

                    per_sample_loss = (bce(logits, validity_labels) + bce(alt_logits, validity_labels)) / 2.0
                    group_losses, _, _ = _group_loss_stats(per_sample_loss, group_ids, num_groups, device)

                    probs_main = torch.sigmoid(logits)
                    probs_alt = torch.sigmoid(alt_logits)
                    consistency_loss = kl_consistency_loss(
                        teacher_logits=logits,
                        student_logits=alt_logits,
                        temperature=KL_TEMPERATURE,
                    )

                    avg_probs = (probs_main + probs_alt) / 2.0
                    confidences = torch.where(validity_labels > 0.5, avg_probs, 1.0 - avg_probs)
                    bias_penalty = _content_bias_penalty(confidences, validity_labels, plausibility_labels, device)

                    dro_loss = torch.dot(group_weights, group_losses)
                    v_loss = dro_loss + current_lambda * bias_penalty + consistency_lambda * consistency_loss

                total_val_loss += v_loss.item()
                total_val_bias += bias_penalty.item()
                total_val_cons += consistency_loss.item()

                # Predictions use only the primary view.
                preds = (probs_main >= 0.5).cpu().tolist()
                val_predictions.extend(
                    {'id': uid, 'validity_pred': bool(p)} for uid, p in zip(batch['id'], preds)
                )

        num_val_batches = len(val_loader)
        avg_val_loss = total_val_loss / num_val_batches if num_val_batches > 0 else 0.0
        avg_val_bias = total_val_bias / num_val_batches if num_val_batches > 0 else 0.0
        avg_val_cons = total_val_cons / num_val_batches if num_val_batches > 0 else 0.0

        metrics = compute_task1_metrics(val_predictions, val_ground_truth)

        print(f"\nEpoch {epoch + 1}/{epochs}")
        print(f"Total Loss: Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
        print(f"Bias Loss: Train: {avg_train_bias:.4f} | Val: {avg_val_bias:.4f}")
        print(f"Consistency Loss: Train: {avg_train_cons:.4f} | Val: {avg_val_cons:.4f}")
        print(f"Metrics : Acc: {metrics['accuracy']:.2f}% | Real Bias: {metrics['total_bias']:.4f}")
        print(f"Ranking Score: {metrics['ranking_score']:.4f}")

        if metrics['ranking_score'] > best_ranking_score:
            best_ranking_score = metrics['ranking_score']
            config = {'model_name': MODEL_NAME, 'max_len': MAX_LEN, 'dropout_rate': 0.1}
            save_model_for_inference(model, train_loader.dataset.tokenizer, save_dir, config)
            print(" Best Checkpoint Saved!")

        if early_stopping(metrics['ranking_score']):
            print("\n Early Stopping triggered.")
            break

    torch.cuda.empty_cache()
    gc.collect()

    return model

def predict(model, dataloader, device):
    """Predict validity for every example yielded by ``dataloader``.

    Only the primary ``input_ids`` / ``attention_mask`` view is used. An
    example is predicted valid when its sigmoid probability is >= 0.5.

    Args:
        model: module mapping ``(input_ids, attention_mask)`` to one logit per example.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and ``id``.
        device: target device (``torch.device`` or a device string).

    Returns:
        list of ``{'id': ..., 'validity_pred': bool}`` in dataloader order.
    """
    device = torch.device(device)
    # fp16 autocast only makes sense on CUDA. torch.autocast replaces the
    # deprecated torch.cuda.amp.autocast(), which warned on every call.
    use_amp = device.type == "cuda"

    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Predicting"):
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                logits = model(input_ids, mask)

            # Threshold in float32 so fp16 rounding near 0.5 cannot flip a prediction.
            preds = (torch.sigmoid(logits.float()) >= 0.5).cpu().tolist()
            predictions.extend(
                {'id': uid, 'validity_pred': bool(p)} for uid, p in zip(batch['id'], preds)
            )
    return predictions

class StratifiedBatchSampler(torch.utils.data.Sampler):
    """Batch sampler that takes a fixed quota of examples from every group in each batch.

    Each batch takes ``max(1, batch_size // num_groups)`` indices from every
    non-empty group, cycling through a freshly shuffled order per group. Any
    remaining slots are filled with indices drawn from the whole dataset
    (without replacement within that top-up). An epoch has
    ``max(len(group) // quota)`` batches, so smaller groups are oversampled.
    Randomness comes from NumPy's global RNG.

    Args:
        labels: group id per dataset index (ints in ``range(num_groups)``).
        batch_size: target number of indices per batch.
        num_groups: number of group ids.
    """

    def __init__(self, labels, batch_size, num_groups=6):
        labels = np.array(labels)
        self.indices = [np.where(labels == i)[0] for i in range(num_groups)]
        self.batch_size = batch_size
        self.samples_per_group = max(1, batch_size // num_groups)
        self.num_batches = max(len(inds) // self.samples_per_group for inds in self.indices)
        self.all_indices = np.arange(len(labels))

    def __iter__(self):
        for inds in self.indices:
            np.random.shuffle(inds)

        # Skip empty groups. Cycling an empty array is exhausted immediately,
        # and next() on it would abort this generator with a RuntimeError
        # (PEP 479).
        group_iters = [itertools.cycle(inds) for inds in self.indices if len(inds) > 0]

        for _ in range(self.num_batches):
            batch = [next(g_it) for g_it in group_iters for _ in range(self.samples_per_group)]

            remaining_slots = self.batch_size - len(batch)
            if remaining_slots > 0:
                extra_indices = np.random.choice(self.all_indices, remaining_slots, replace=False)
                batch.extend(extra_indices)

            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.num_batches

if __name__ == "__main__":
    # Training entry point: load data, split it, train, evaluate the best
    # checkpoint on the held-out split, then export predictions, ground truth
    # and a zipped model directory.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    try:
        # The syllogisms are multilingual, so read the JSON explicitly as UTF-8.
        with open(JSON_FILE_PATH_AUG, 'r', encoding='utf-8') as f:
            aug_pool = json.load(f)
        for item in aug_pool:
            item["is_gold"] = False
        print(f" Loaded {len(aug_pool)} augmented samples.")

        with open(JSON_FILE_PATH_GOLD, 'r', encoding='utf-8') as f:
            gold_pool = json.load(f)
        for item in gold_pool:
            item["is_gold"] = True
        print(f" Loaded {len(gold_pool)} gold samples.")
    except Exception as e:
        print(f" Error loading data: {e}")
        # Stop explicitly and keep the underlying error as the cause.
        raise SystemExit(1) from e

    # Seed every RNG the pipeline draws from:
    #   * `random` for the list shuffles below;
    #   * NumPy for StratifiedBatchSampler;
    #   * torch for head initialisation, dropout and worker seeds.
    # The cudnn determinism flags alone do not make runs reproducible.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    def stratify_keys(pool):
        """Stratification label per item: '<validity>_<plausibility>'."""
        return [f"{x['validity']}_{x.get('plausibility', 'neutral')}" for x in pool]

    # 80/10/10 stratified split, done separately for gold and augmented data.
    # NOTE(review): augmented items are split independently of gold items; if
    # they paraphrase gold syllogisms, near-duplicates may leak across splits.
    gold_stratify = stratify_keys(gold_pool)
    train_gold, temp_gold, _, temp_gold_labels = train_test_split(
        gold_pool, gold_stratify, test_size=0.20, random_state=42, stratify=gold_stratify
    )
    val_gold, test_gold = train_test_split(
        temp_gold, test_size=0.5, random_state=42, stratify=temp_gold_labels
    )

    aug_stratify = stratify_keys(aug_pool)
    train_aug, temp_aug, _, temp_aug_labels = train_test_split(
        aug_pool, aug_stratify, test_size=0.20, random_state=42, stratify=aug_stratify
    )
    val_aug, test_aug = train_test_split(
        temp_aug, test_size=0.5, random_state=42, stratify=temp_aug_labels
    )

    train_data = train_gold + train_aug
    val_data = val_gold + val_aug
    test_data = test_gold + test_aug

    random.shuffle(train_data)
    random.shuffle(val_data)
    random.shuffle(test_data)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    train_ds = SyllogismDataset(train_data, tokenizer, max_len=MAX_LEN)
    val_ds = SyllogismDataset(val_data, tokenizer, max_len=MAX_LEN)
    test_ds = SyllogismDataset(test_data, tokenizer, max_len=MAX_LEN)

    num_groups = 6
    train_groups = [
        SyllogismDataset.encode_group_id(item["validity"], item.get("plausibility", "neutral"))
        for item in train_data
    ]

    print("\n Debugging Group Counts:")
    for g in range(num_groups):
        count = train_groups.count(g)
        print(f" Group {g}: {count} samples")

    print(f"\n Total samples: {len(train_groups)}")
    print(f" Batch size: {BATCH_SIZE}")
    print(f" Samples per group: {BATCH_SIZE // num_groups}")

    stratified_sampler = StratifiedBatchSampler(train_groups, BATCH_SIZE, num_groups=num_groups)
    train_loader = DataLoader(
        train_ds,
        batch_sampler=stratified_sampler,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    val_loader = DataLoader(val_ds, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    model = BinarySyllogismModel(MODEL_NAME).to(device)

    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

    trained_model = train_engine(
        model, train_loader, val_loader,
        epochs=EPOCHS, lr=LEARNING_RATE, device=device,
        patience=PATIENCE, min_delta=MIN_DELTA,
        save_dir=MODEL_SAVE_DIR,
        bias_lambda=2.0,
        fp16=FP16,
        num_groups=num_groups,
    )

    print("\n Loading best model for test predictions...")
    best_model = load_model_for_inference(BinarySyllogismModel, MODEL_SAVE_DIR, device)

    print("\n Predicting on Test Set...")
    test_predictions = predict(best_model, test_loader, device)
    metrics = compute_task1_metrics(test_predictions, test_data)

    print("\n" + "=" * 40)
    print(" FINAL TEST RESULTS (SUBTASK 1)")
    print("=" * 40)
    print(f"Accuracy: {metrics['accuracy']:.2f}%")
    print(f"Total Bias: {metrics['total_bias']:.4f}")
    print(f"Ranking Score: {metrics['ranking_score']:.4f}")

    pred_path = os.path.join(OUTPUT_DIR, "subtask_1_bias_scheduler_predictions.json")
    with open(pred_path, "w", encoding="utf-8") as f:
        json.dump(test_predictions, f, indent=4)
    print(f"\n Predictions saved to {pred_path}")

    gt_path = os.path.join(OUTPUT_DIR, "subtask_1_bias_scheduler_ground_truth.json")
    with open(gt_path, "w", encoding="utf-8") as f:
        json.dump(test_data, f, indent=4)
    print(f" Ground Truth saved to {gt_path}")

    print("\n Zipping model...")
    shutil.make_archive(os.path.join(OUTPUT_DIR, "subtask_1_bias_scheduler_model"), 'zip', MODEL_SAVE_DIR)
    print("\nZipping finished!")


Using device: cuda
 Loaded 22564 augmented samples.
 Loaded 949 gold samples.


tokenizer_config.json: 0%| | 0.00/467 [00:00<?, ?B/s]

spm.model: 0%| | 0.00/4.31M [00:00<?, ?B/s]

tokenizer.json: 0%| | 0.00/16.3M [00:00<?, ?B/s]

added_tokens.json: 0%| | 0.00/23.0 [00:00<?, ?B/s]

special_tokens_map.json: 0%| | 0.00/173 [00:00<?, ?B/s]


 Debugging Group Counts:
 Group 0: 3199 samples
 Group 1: 3006 samples
 Group 2: 3198 samples
 Group 3: 3202 samples
 Group 4: 3006 samples
 Group 5: 3199 samples

 Total samples: 18810
 Batch size: 32
 Samples per group: 5


config.json: 0.00B [00:00, ?B/s]

2026-02-11 14:45:01.955243: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770821102.123663 24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770821102.166849 24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770821102.556192 24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770821102.556232 24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770821102.556235 24 computation_placer.cc:177] computation placer already registered. Please c

model.safetensors: 0%| | 0.00/558M [00:00<?, ?B/s]

Total Steps: 6400 | Warmup Steps: 384 (Ratio: 0.06)
Training on cuda


 scaler = GradScaler()


Epoch 1/10: 0%| | 0/640 [00:00<?, ?it/s]


Epoch 1/10
Total Loss: Train: 0.2923 | Val: 0.0669
Bias Loss: Train: 0.0439 | Val: 0.0169
Consistency Loss: Train: 0.1737 | Val: 0.1228
Metrics : Acc: 99.31% | Real Bias: 0.7500
Ranking Score: 63.6773
 Full Model (Backbone + Heads) saved to '/kaggle/working/trained_model_task1/'
 Best Checkpoint Saved!


Epoch 2/10: 0%| | 0/640 [00:00<?, ?it/s]


Epoch 2/10
Total Loss: Train: 0.0591 | Val: 0.0191
Bias Loss: Train: 0.0098 | Val: 0.0036
Consistency Loss: Train: 0.1316 | Val: 0.0351
Metrics : Acc: 99.87% | Real Bias: 0.1253
Ranking Score: 89.3287
 Full Model (Backbone + Heads) saved to '/kaggle/working/trained_model_task1/'
 Best Checkpoint Saved!


Epoch 3/10: 0%| | 0/640 [00:00<?, ?it/s]


Epoch 3/10
Total Loss: Train: 0.0505 | Val: 0.0103
Bias Loss: Train: 0.0035 | Val: 0.0010
Consistency Loss: Train: 0.0572 | Val: 0.0086
Metrics : Acc: 99.94% | Real Bias: 0.1250
Ranking Score: 89.4069
 Full Model (Backbone + Heads) saved to '/kaggle/working/trained_model_task1/'
 Best Checkpoint Saved!


Epoch 4/10: 0%| | 0/640 [00:00<?, ?it/s]


Epoch 4/10
Total Loss: Train: 0.0282 | Val: 0.0348
Bias Loss: Train: 0.0021 | Val: 0.0033
Consistency Loss: Train: 0.0279 | Val: 0.0258
Metrics : Acc: 99.87% | Real Bias: 0.1253
Ranking Score: 89.3287


Epoch 5/10: 0%| | 0/640 [00:00<?, ?it/s]


Epoch 5/10
Total Loss: Train: 0.0204 | Val: 0.0083
Bias Loss: Train: 0.0016 | Val: 0.0007
Consistency Loss: Train: 0.0188 | Val: 0.0038
Metrics : Acc: 99.94% | Real Bias: 0.1250
Ranking Score: 89.4069

 Early Stopping triggered.

 Loading best model for test predictions...
 Best model loaded from '/kaggle/working/trained_model_task1/'

 Predicting on Test Set...


Predicting: 0%| | 0/74 [00:00<?, ?it/s]

 with autocast():



 FINAL TEST RESULTS (SUBTASK 1)
Accuracy: 100.00%
Total Bias: 0.0000
Ranking Score: 100.0000

 Predictions saved to /kaggle/working/subtask_1_bias_scheduler_predictions.json
 Ground Truth saved to /kaggle/working/subtask_1_bias_scheduler_ground_truth.json

 Zipping model...

Zipping finished!


In [2]:
if __name__ == "__main__":
    # Inference on the unlabeled shared-task test file. This cell relies on
    # `best_model`, `tokenizer` and `device` created by the training cell.
    # The file holds multilingual text, so read it explicitly as UTF-8.
    with open("/kaggle/input/test-data-subtask-1/test_data_subtask_1.json", "r", encoding="utf-8") as f:
        test_data = json.load(f)

    predictions = []
    best_model.eval()
    print("\n Starting Inference...")

    # Must stay identical to SyllogismDataset.split_pattern so inputs are built
    # exactly as during training.
    split_re = re.compile(r'[\.\u3002\u0964]+')

    def build_pair(raw_text):
        """Return (conclusion, premises joined by sep_token), mirroring SyllogismDataset.process_syllogism."""
        sentences = [s.strip() for s in split_re.split(raw_text or "") if s.strip()]
        if not sentences:
            sentences = ["Empty"]
        return sentences[-1], tokenizer.sep_token.join(sentences[:-1])

    for start in tqdm(range(0, len(test_data), EVAL_BATCH_SIZE)):
        batch_items = test_data[start:start + EVAL_BATCH_SIZE]

        pairs = [build_pair(item['syllogism']) for item in batch_items]
        batch_conclusions = [conclusion for conclusion, _ in pairs]
        batch_premises = [premises for _, premises in pairs]
        ids = [str(item['id']) for item in batch_items]

        inputs = tokenizer(
            batch_conclusions,
            batch_premises,
            truncation=True,
            max_length=MAX_LEN,
            padding="max_length",
            return_tensors="pt",
        )

        with torch.no_grad():
            logits = best_model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device))
        probs = torch.sigmoid(logits).cpu().numpy()

        predictions.extend(
            {"id": uid, "validity": bool(prob >= 0.5)} for uid, prob in zip(ids, probs)
        )

    with open('/kaggle/working/predictions.json', "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=4)



 Starting Inference...


 0%| | 0/6 [00:00<?, ?it/s]