In [4]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import gc
from torch.utils.data import Dataset, DataLoader, TensorDataset
from transformers import (
    AutoTokenizer, BertForMaskedLM, DistilBertForMaskedLM,
    DistilBertForSequenceClassification, DistilBertForTokenClassification,
    TrainingArguments, Trainer
)

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tqdm import tqdm
from datasets import load_dataset
from huggingface_hub import login
import re
import ast
import json
import os
from kaggle_secrets import UserSecretsClient
import nltk
from nltk.tokenize import sent_tokenize
from sklearn.model_selection import train_test_split
from torch.optim import AdamW

# Download NLTK data for sentence tokenization
nltk.download('punkt', quiet=True)

# Create directories
os.makedirs("/kaggle/working/data/processed", exist_ok=True)
os.makedirs("/kaggle/working/models", exist_ok=True)

# Set device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

def preprocess_data():
    # Securely load Hugging Face token
    user_secrets = UserSecretsClient()
    HF_TOKEN = user_secrets.get_secret("HF_TOKEN")
    
    if not HF_TOKEN:
        raise ValueError("Hugging Face token not found in Kaggle secrets. "
                         "Please add your HF_TOKEN to Kaggle secrets.")
    
    # Authenticate with Hugging Face
    login(token=HF_TOKEN)
    
    # Load CJPE dataset
    print("Loading CJPE dataset...")
    dataset = load_dataset("Exploration-Lab/IL-TUR", "cjpe")
    print("Dataset loaded successfully!")
    
    # Convert to DataFrames
    df_train = pd.DataFrame(dataset["single_train"])
    df_dev = pd.DataFrame(dataset["single_dev"])
    df_test = pd.DataFrame(dataset["test"])
    df_expert = pd.DataFrame(dataset["expert"]) if "expert" in dataset else None
    
    # Create raw_text column for all splits
    df_train["raw_text"] = df_train["text"]
    df_dev["raw_text"] = df_dev["text"]
    df_test["raw_text"] = df_test["text"]
    
    # Add raw_text for expert split if it exists
    if df_expert is not None:
        df_expert["raw_text"] = df_expert["text"]
    
    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    MAX_LEN = 512
    
    def tokenize_batch(texts):
        return tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=MAX_LEN,
            return_tensors="pt",
            return_offsets_mapping=True
        )
    
    print("Tokenizing data...")
    train_tokens = tokenize_batch(df_train["raw_text"].tolist())
    dev_tokens = tokenize_batch(df_dev["raw_text"].tolist())
    test_tokens = tokenize_batch(df_test["raw_text"].tolist())
    
    # Initialize rationale masks with zeros
    df_train["rationale_mask"] = [[0]*MAX_LEN for _ in range(len(df_train))]
    df_dev["rationale_mask"] = [[0]*MAX_LEN for _ in range(len(df_dev))]
    
    # Process expert split for rationales
    if df_expert is not None:
        print("Processing expert annotations for rationale extraction...")
        df_expert["rationale_mask"] = [[0]*MAX_LEN for _ in range(len(df_expert))]
        expert_tokens = tokenize_batch(df_expert["raw_text"].tolist())
        
        for idx in tqdm(range(len(df_expert)), desc="Expert rationale masks"):
            row = df_expert.iloc[idx]
            text = row["raw_text"]
            offsets = expert_tokens["offset_mapping"][idx].tolist()
            
            # Extract sentences using NLTK
            sentences = sent_tokenize(text)
            
            # Collect expert rationale sentences
            expert_sentences = set()
            for i in range(1, 6):
                expert = row.get(f"expert_{i}")
                if not expert:
                    continue
                
                # Handle different annotation formats
                if isinstance(expert, str):
                    try:
                        # Try to parse as Python literal
                        expert = ast.literal_eval(expert)
                    except (ValueError, SyntaxError):
                        try:
                            # Try to parse as JSON
                            expert = json.loads(expert)
                        except json.JSONDecodeError:
                            # Skip if both parsing methods fail
                            continue
                
                # Extract sentences from expert annotations
                for rank in ['rank1', 'rank2', 'rank3', 'rank4', 'rank5']:
                    if rank in expert:
                        sentences_list = expert[rank]
                        if isinstance(sentences_list, str):
                            try:
                                # Parse string representation of list
                                sentences_list = ast.literal_eval(sentences_list)
                            except (ValueError, SyntaxError):
                                continue
                        if isinstance(sentences_list, list):
                            # Add cleaned sentences to the set
                            expert_sentences.update([s.strip() for s in sentences_list])
            
            # Create rationale mask
            mask = [0] * MAX_LEN
            for sent in sentences:
                if sent.strip() in expert_sentences:
                    # Find all occurrences of the sentence
                    pattern = re.escape(sent)
                    for match in re.finditer(pattern, text):
                        start_idx = match.start()
                        end_idx = match.end()
                        
                        # Mark tokens within rationale span
                        for i, (start, end) in enumerate(offsets):
                            if i >= MAX_LEN:
                                break
                            if start == 0 and end == 0:  # Skip special tokens
                                continue
                            if not (end <= start_idx or start >= end_idx):
                                mask[i] = 1
            
            df_expert.at[idx, "rationale_mask"] = mask
        
        # Save expert data
        df_expert.to_csv("/kaggle/working/data/processed/expert.csv", index=False)
        torch.save(expert_tokens, "/kaggle/working/data/processed/expert_tokens.pt")
    
    # Save processed data
    OUTPUT_DIR = "/kaggle/working/data/processed"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    df_train.to_csv(f"{OUTPUT_DIR}/train.csv", index=False)
    df_dev.to_csv(f"{OUTPUT_DIR}/dev.csv", index=False)
    df_test.to_csv(f"{OUTPUT_DIR}/test.csv", index=False)
    
    torch.save({
        "input_ids": train_tokens["input_ids"],
        "attention_mask": train_tokens["attention_mask"]
    }, f"{OUTPUT_DIR}/train_tokens.pt")
    
    torch.save({
        "input_ids": dev_tokens["input_ids"],
        "attention_mask": dev_tokens["attention_mask"],
        "offset_mapping": dev_tokens["offset_mapping"]
    }, f"{OUTPUT_DIR}/dev_tokens.pt")
    
    torch.save({
        "input_ids": test_tokens["input_ids"],
        "attention_mask": test_tokens["attention_mask"],
        "offset_mapping": test_tokens["offset_mapping"]
    }, f"{OUTPUT_DIR}/test_tokens.pt")
    
    print(f"Preprocessing complete! Files saved to {OUTPUT_DIR}")
    
    # Calculate rationale coverage for expert split
    if df_expert is not None:
        def calculate_coverage(masks):
            total_tokens = 0
            positive_tokens = 0
            for m in masks:
                # Only consider actual text tokens (ignore padding)
                valid_tokens = len([x for x in m if x != -1])
                total_tokens += valid_tokens
                positive_tokens += sum(m[:valid_tokens])
            coverage = positive_tokens / total_tokens if total_tokens > 0 else 0
            return coverage
        
        expert_coverage = calculate_coverage(df_expert["rationale_mask"])
        print(f"Expert rationale coverage: {expert_coverage:.4%}")
    
    return df_train, df_dev, df_test, df_expert if df_expert is not None else None

# Execute preprocessing
train_df, dev_df, test_df, expert_df = preprocess_data()

# Define ClassificationDataset
class ClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }

# Step 2: Corrected Distillation Training
def distillation_training():
    # Configuration
    TEACHER_MODEL = "bert-base-uncased"
    STUDENT_MODEL = "distilbert-base-uncased"
    DATA_PATH = "/kaggle/working/data/processed/train_tokens.pt"
    SAVE_PATH = "/kaggle/working/models/distilled_model"
    BATCH_SIZE = 8
    EPOCHS = 1
    ALPHA, BETA, GAMMA = 0.5, 0.4, 0.1  # Adjusted loss weights
    
    # Load token data
    print("Loading token data...")
    token_data = torch.load(DATA_PATH, map_location='cpu')
    
    # Extract tensors
    input_ids = token_data["input_ids"]
    attention_mask = token_data["attention_mask"]
    print(f"Loaded {len(input_ids)} samples")
    
    # Create DataLoader
    dataset = TensorDataset(input_ids, attention_mask)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    
    # Initialize models
    print("Loading teacher and student models...")
    teacher = BertForMaskedLM.from_pretrained(TEACHER_MODEL).to(DEVICE)
    student = DistilBertForMaskedLM.from_pretrained(STUDENT_MODEL).to(DEVICE)
    teacher.eval()
    
    # Loss functions and optimizer
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    cosine_loss = nn.CosineEmbeddingLoss()
    optimizer = AdamW(student.parameters(), lr=5e-5, weight_decay=1e-4)
    
    # Gradient accumulation steps
    grad_accum_steps = 2
    print(f"\nStarting distillation training for {EPOCHS} epochs")
    print(f"Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * grad_accum_steps})")
    print(f"Total batches: {len(dataloader)}")
    
    # Training loop
    for epoch in range(EPOCHS):
        student.train()
        total_loss = 0
        optimizer.zero_grad()
        
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), 
                           desc=f"Epoch {epoch+1}/{EPOCHS}")
        
        for batch_idx, batch in progress_bar:
            batch_input_ids, batch_attention_mask = [t.to(DEVICE) for t in batch]
            
            # Create masked inputs (15% masking probability)
            masked_input_ids = batch_input_ids.clone()
            mask_prob = torch.rand(masked_input_ids.shape, device=DEVICE)
            
            # Only mask non-special tokens
            special_tokens_mask = (batch_input_ids == 0) | (batch_input_ids == 101) | (batch_input_ids == 102)
            mask = (mask_prob < 0.15) & ~special_tokens_mask
            masked_input_ids[mask] = 103  # DistilBERT's mask token ID
            
            # Teacher forward pass
            with torch.no_grad():
                teacher_outputs = teacher(
                    input_ids=masked_input_ids,
                    attention_mask=batch_attention_mask,
                    output_hidden_states=True
                )
            
            # Student forward pass with labels for proper MLM loss
            student_outputs = student(
                input_ids=masked_input_ids,
                attention_mask=batch_attention_mask,
                output_hidden_states=True,
                labels=batch_input_ids  # Add labels for built-in MLM loss
            )
            
            # Calculate losses
            # Use log probabilities for numerical stability
            student_log_probs = torch.nn.functional.log_softmax(student_outputs.logits, dim=-1)
            
            # Apply temperature scaling to teacher outputs
            teacher_logits = teacher_outputs.logits / 2.0
            teacher_probs = torch.nn.functional.softmax(teacher_logits, dim=-1)
            
            distil_loss = kl_loss(student_log_probs, teacher_probs)
            
            # Hidden states loss
            student_hidden = student_outputs.hidden_states[-1]
            teacher_hidden = teacher_outputs.hidden_states[-1]
            
            # Flatten hidden states for cosine loss
            student_flat = student_hidden.view(-1, student_hidden.size(-1))
            teacher_flat = teacher_hidden.view(-1, teacher_hidden.size(-1))
            target = torch.ones(student_flat.size(0), device=DEVICE)
            
            cos_loss = cosine_loss(student_flat, teacher_flat, target)
            
            # Use built-in MLM loss (only calculates loss on masked tokens)
            mlm_loss = student_outputs.loss
            
            # Combined loss
            loss = (ALPHA * distil_loss + 
                    BETA * mlm_loss + 
                    GAMMA * cos_loss) / grad_accum_steps
            
            # Backpropagation with gradient accumulation
            loss.backward()
            
            if (batch_idx + 1) % grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
            
            total_loss += loss.item() * grad_accum_steps
            
            # Update progress bar
            if batch_idx % 10 == 0:
                progress_bar.set_postfix({
                    "batch_loss": f"{loss.item() * grad_accum_steps:.4f}",
                    "avg_loss": f"{total_loss/(batch_idx+1):.4f}",
                    "d_loss": f"{distil_loss.item():.4f}",
                    "m_loss": f"{mlm_loss.item():.4f}",
                    "c_loss": f"{cos_loss.item():.4f}"
                })
            
            # Clear memory
            del masked_input_ids, mask_prob, mask, teacher_outputs, student_outputs
            torch.cuda.empty_cache()
            gc.collect()
        
        # Final gradient step if needed
        if len(dataloader) % grad_accum_steps != 0:
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
        
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} Complete - Avg Loss: {avg_loss:.4f}")
    
    # Save distilled model
    os.makedirs(SAVE_PATH, exist_ok=True)
    student.save_pretrained(SAVE_PATH)
    tokenizer = AutoTokenizer.from_pretrained(STUDENT_MODEL)
    tokenizer.save_pretrained(SAVE_PATH)
    print(f"\nDistilled model saved to {SAVE_PATH}")
    return

# Run distillation training
distillation_training()

# Step 3: Classification Fine-tuning with Fallback
def train_classifier():
    # Configuration
    MODEL_PATH = "/kaggle/working/models/distilled_model"
    TRAIN_CSV = "/kaggle/working/data/processed/train.csv"
    DEV_CSV = "/kaggle/working/data/processed/dev.csv"
    SAVE_PATH = "/kaggle/working/models/classification_model"
    BATCH_SIZE = 8
    EPOCHS = 3
    LEARNING_RATE = 3e-5
    
    # Load data
    train_df = pd.read_csv(TRAIN_CSV)
    dev_df = pd.read_csv(DEV_CSV)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    
    # Try loading distilled model, fallback to pretrained if failed
    try:
        model = DistilBertForSequenceClassification.from_pretrained(
            MODEL_PATH, 
            num_labels=2,
            ignore_mismatched_sizes=True
        ).to(DEVICE)
        print("Loaded distilled model for classification fine-tuning")
    except:
        print("Failed to load distilled model, using pretrained as fallback")
        model = DistilBertForSequenceClassification.from_pretrained(
            "distilbert-base-uncased", 
            num_labels=2
        ).to(DEVICE)
    
    # Create datasets and dataloaders
    train_dataset = ClassificationDataset(
        train_df["text"].tolist(), 
        train_df["label"].tolist(), 
        tokenizer
    )
    dev_dataset = ClassificationDataset(
        dev_df["text"].tolist(), 
        dev_df["label"].tolist(), 
        tokenizer
    )
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)
    
    # Optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader)*EPOCHS)
    
    # Training loop
    best_f1 = 0
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        
        for batch in progress_bar:
            inputs = {k: v.to(DEVICE) for k, v in batch.items() if k != "labels"}
            labels = batch["labels"].to(DEVICE)
            
            # Forward pass
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            total_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}", "lr": f"{scheduler.get_last_lr()[0]:.2e}"})
        
        avg_train_loss = total_loss / len(train_loader)
        
        # Evaluation
        model.eval()
        all_preds = []
        all_labels = []
        val_loss = 0
        
        with torch.no_grad():
            for batch in dev_loader:
                inputs = {k: v.to(DEVICE) for k, v in batch.items() if k != "labels"}
                labels = batch["labels"].to(DEVICE)
                
                outputs = model(**inputs, labels=labels)
                preds = torch.argmax(outputs.logits, dim=1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                val_loss += outputs.loss.item()
        
        # Calculate metrics
        if len(all_labels) > 0:
            accuracy = accuracy_score(all_labels, all_preds)
            f1 = f1_score(all_labels, all_preds, average="binary", zero_division=0)
            precision = precision_score(all_labels, all_preds, average="binary", zero_division=0)
            recall = recall_score(all_labels, all_preds, average="binary", zero_division=0)
            avg_val_loss = val_loss / len(dev_loader)
            
            print(f"\nEpoch {epoch+1} Evaluation:")
            print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print(f"Accuracy: {accuracy:.4f}, F1: {f1:.4f}")
            print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")
            
            # Save best model
            if f1 > best_f1:
                best_f1 = f1
                model.save_pretrained(SAVE_PATH)
                tokenizer.save_pretrained(SAVE_PATH)
                print(f"New best model saved to {SAVE_PATH} with F1: {f1:.4f}")
        else:
            print(f"\nEpoch {epoch+1} Evaluation: No valid labels to evaluate")
    
    print("Training complete!")
    return

# Run classification training
train_classifier()

# Define RationaleDataset class
class RationaleDataset(Dataset):
    def __init__(self, texts, rationale_masks, tokenizer, max_len=512):
        self.texts = texts
        self.rationale_masks = rationale_masks
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        
        # Get rationale mask for this sample
        mask = self.rationale_masks[idx]
        # Pad or truncate mask to max length
        mask = mask[:self.max_len] + [0] * (self.max_len - len(mask))
        
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(mask, dtype=torch.long)
        }

def train_rationale_model():
    # Configuration - ONLY USE EXPERT DATA FOR RATIONALE TRAINING
    MODEL_PATH = "/kaggle/working/models/distilled_model"
    EXPERT_CSV = "/kaggle/working/data/processed/expert.csv"
    SAVE_PATH = "/kaggle/working/models/rationale_model"
    BATCH_SIZE = 8
    EPOCHS = 10  # More epochs for small dataset
    LEARNING_RATE = 3e-5
    MAX_LEN = 512
    
    # Load expert data
    print("Loading expert data for rationale training...")
    expert_df = pd.read_csv(EXPERT_CSV)
    
    # Convert rationale masks
    expert_df["rationale_mask"] = expert_df["rationale_mask"].apply(
        lambda x: ast.literal_eval(x) if isinstance(x, str) else x
    )
    
    # Split expert data into train and validation
    train_df, val_df = train_test_split(expert_df, test_size=0.2, random_state=42)
    
    # Enhanced class weight calculation
    def calculate_class_weights(masks):
        total_tokens = 0
        positive_tokens = 0
        for mask in masks:
            # Consider only first MAX_LEN tokens
            valid_mask = mask[:MAX_LEN]
            total_tokens += len(valid_mask)
            positive_tokens += sum(valid_mask)
        
        print(f"Positive tokens: {positive_tokens}/{total_tokens} ({positive_tokens/total_tokens:.4%})")
        
        # Handle case where there are no positive tokens
        if positive_tokens == 0:
            return torch.tensor([1.0, 1.0]).to(DEVICE)
        
        weight_positive = total_tokens / (2.0 * positive_tokens)
        weight_negative = total_tokens / (2.0 * (total_tokens - positive_tokens))
        
        print(f"Class weights - Negative: {weight_negative:.2f}, Positive: {weight_positive:.2f}")
        return torch.tensor([weight_negative, weight_positive]).to(DEVICE)
    
    # Calculate class weights using training data only
    class_weights = calculate_class_weights(train_df["rationale_mask"].tolist())
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    
    # Create datasets and dataloaders
    train_dataset = RationaleDataset(
        train_df["text"].tolist(),
        train_df["rationale_mask"].tolist(),
        tokenizer
    )
    val_dataset = RationaleDataset(
        val_df["text"].tolist(),
        val_df["rationale_mask"].tolist(),
        tokenizer
    )
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    
    # Initialize model from distilled base
    model = DistilBertForTokenClassification.from_pretrained(
        MODEL_PATH, 
        num_labels=2,
        ignore_mismatched_sizes=True
    ).to(DEVICE)
    
    # Optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader)*EPOCHS)
    
    # Loss function with class weights
    loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
    
    # Training loop
    best_f1 = 0
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        
        for batch in progress_bar:
            inputs = {
                'input_ids': batch['input_ids'].to(DEVICE),
                'attention_mask': batch['attention_mask'].to(DEVICE),
            }
            labels = batch['labels'].to(DEVICE)
            
            # Forward pass
            outputs = model(**inputs)
            logits = outputs.logits
            
            # Calculate loss only on active tokens
            active_loss = inputs['attention_mask'].view(-1) == 1
            active_logits = logits.view(-1, 2)[active_loss]
            active_labels = labels.view(-1)[active_loss]
            
            if active_labels.numel() > 0:
                loss = loss_fn(active_logits, active_labels)
            else:
                loss = torch.tensor(0.0, requires_grad=True).to(DEVICE)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            total_loss += loss.item()
            progress_bar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}"
            })
        
        avg_train_loss = total_loss / len(train_loader)
        
        # Evaluation
        model.eval()
        all_preds = []
        all_labels = []
        val_loss = 0
        
        with torch.no_grad():
            for batch in val_loader:
                inputs = {
                    'input_ids': batch['input_ids'].to(DEVICE),
                    'attention_mask': batch['attention_mask'].to(DEVICE),
                }
                labels = batch['labels'].to(DEVICE)
                
                outputs = model(**inputs)
                logits = outputs.logits
                
                # Get predictions
                active_mask = inputs['attention_mask'].view(-1) == 1
                active_logits = logits.view(-1, 2)[active_mask]
                active_labels = labels.view(-1)[active_mask]
                
                preds = torch.argmax(active_logits, dim=-1)
                
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(active_labels.cpu().numpy())
                val_loss += loss_fn(active_logits, active_labels).item() if active_labels.numel() > 0 else 0
        
        # Calculate metrics
        if len(all_labels) > 0:
            accuracy = accuracy_score(all_labels, all_preds)
            f1 = f1_score(all_labels, all_preds, average="binary", zero_division=0)
            precision = precision_score(all_labels, all_preds, average="binary", zero_division=0)
            recall = recall_score(all_labels, all_preds, average="binary", zero_division=0)
            avg_val_loss = val_loss / len(val_loader)
            
            print(f"\nEpoch {epoch+1} Evaluation:")
            print(f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print(f"Accuracy: {accuracy:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")
            
            # Save best model
            if f1 > best_f1:
                best_f1 = f1
                model.save_pretrained(SAVE_PATH)
                tokenizer.save_pretrained(SAVE_PATH)
                print(f"New best model saved to {SAVE_PATH} with F1: {f1:.4f}")
        else:
            print(f"\nEpoch {epoch+1} Evaluation: No valid labels to evaluate")
    
    # Always save final model
    if not os.path.exists(SAVE_PATH):
        model.save_pretrained(SAVE_PATH)
        tokenizer.save_pretrained(SAVE_PATH)
        print(f"Final model saved to {SAVE_PATH}")
    
    print("Training complete!")
    return model

# Run rationale training
rationale_model = train_rationale_model()

# Step 5: Enhanced Inference Demo
class InferenceSystem:
    def __init__(self):
        # Configuration
        self.CLASSIFIER_PATH = "/kaggle/working/models/classification_model"
        self.RATIONALE_PATH = "/kaggle/working/models/rationale_model"
        self.DISTILLED_PATH = "/kaggle/working/models/distilled_model"
        
        # Load components
        self.tokenizer = self._load_tokenizer()
        self.classifier = self._load_classifier()
        self.rationale_extractor = self._load_rationale_extractor()
    
    def _load_tokenizer(self):
        # Try rationale path first, then fallback to classifier
        try:
            tokenizer = AutoTokenizer.from_pretrained(self.RATIONALE_PATH)
            print("Loaded tokenizer from rationale model")
            return tokenizer
        except:
            try:
                tokenizer = AutoTokenizer.from_pretrained(self.CLASSIFIER_PATH)
                print("Loaded tokenizer from classification model")
                return tokenizer
            except:
                tokenizer = AutoTokenizer.from_pretrained(self.DISTILLED_PATH)
                print("Loaded tokenizer from distilled model")
                return tokenizer
    
    def _load_classifier(self):
        try:
            return DistilBertForSequenceClassification.from_pretrained(
                self.CLASSIFIER_PATH
            ).to(DEVICE).eval()
        except Exception as e:
            print(f"Error loading classifier: {e}")
            return None
    
    def _load_rationale_extractor(self):
        try:
            model = DistilBertForTokenClassification.from_pretrained(
                self.RATIONALE_PATH
            ).to(DEVICE).eval()
            print("Loaded rationale extraction model")
            return model
        except Exception as e:
            print(f"Could not load rationale model: {e}")
            print("Using classification model with attention fallback")
            return None
    
    def predict(self, text):
        # Tokenize input
        inputs = self.tokenizer(
            text, 
            return_tensors="pt", 
            max_length=512, 
            padding="max_length", 
            truncation=True
        ).to(DEVICE)
        
        with torch.no_grad():
            # Predict outcome
            if self.classifier is None:
                raise RuntimeError("Classifier model failed to load")
                
            clf_output = self.classifier(**inputs)
            label = clf_output.logits.argmax(-1).item()
            outcome = "ALLOWED" if label == 1 else "DISMISSED"
            
            # Extract rationale
            if self.rationale_extractor:
                token_output = self.rationale_extractor(**inputs)
                mask = token_output.logits.argmax(-1).squeeze().cpu().numpy()
            else:
                # Fallback: use attention weights from classifier
                outputs = self.classifier(**inputs, output_attentions=True)
                attentions = torch.stack(outputs.attentions).mean(0).mean(1)[0, 0].cpu().numpy()
                mask = (attentions > attentions.mean()).astype(int)
            
            tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze().cpu().numpy())
            attention_mask = inputs["attention_mask"].squeeze().cpu().numpy()
            
            # Filter out special tokens and create highlighted text
            highlighted_tokens = []
            for i, (tok, m) in enumerate(zip(tokens, mask)):
                # Skip special tokens and padding
                if tok in [self.tokenizer.cls_token, 
                          self.tokenizer.sep_token, 
                          self.tokenizer.pad_token] or attention_mask[i] == 0:
                    continue
                    
                # Clean up token representation
                if tok.startswith("##"):
                    tok = tok[2:]
                    if highlighted_tokens:
                        highlighted_tokens[-1] += tok
                        continue
                
                if m == 1:
                    highlighted_tokens.append(f"[{tok}]")
                else:
                    highlighted_tokens.append(tok)
            
            # Convert to readable text
            rationale_text = " ".join(highlighted_tokens)
            # Clean up spacing around punctuation
            rationale_text = re.sub(r'\s+([.,;:!?])', r'\1', rationale_text)
            rationale_text = re.sub(r'\[\s+', '[', rationale_text)
            rationale_text = re.sub(r'\s+\]', ']', rationale_text)
        
        return outcome, rationale_text

# Test with sample case
sample_case = """
The appellant was charged under Section 302 of the Indian Penal Code for the murder of his neighbor.
Evidence shows the accused was present at the crime scene, and fingerprints match those found on the weapon.
However, the defense argues there was no motive and the forensic evidence was mishandled by police.
The prosecution maintains the circumstantial evidence is sufficient for conviction.
"""

print("Initializing inference system...")
inference_system = InferenceSystem()

print("\nRunning sample prediction...")
outcome, rationale = inference_system.predict(sample_case)

print("\n" + "="*60)
print(f"Predicted Outcome: {outcome}")
print("\nExtracted Rationale (key phrases in brackets):")
print("-"*60)
print(rationale)
print("="*60)

Using device: cuda
Loading CJPE dataset...
Dataset loaded successfully!
Tokenizing data...
Processing expert annotations for rationale extraction...


Expert rationale masks: 100%|██████████| 56/56 [00:00<00:00, 86.76it/s]


Preprocessing complete! Files saved to /kaggle/working/data/processed
Expert rationale coverage: 33.5658%
Loading token data...
Loaded 5082 samples
Loading teacher and student models...


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Starting distillation training for 1 epochs
Batch size: 8 (effective: 16)
Total batches: 636


Epoch 1/1: 100%|██████████| 636/636 [14:21<00:00,  1.36s/it, batch_loss=29.8640, avg_loss=38.5053, d_loss=58.5412, m_loss=1.4276, c_loss=0.2237]   


Epoch 1 Complete - Avg Loss: 38.4480

Distilled model saved to /kaggle/working/models/distilled_model


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at /kaggle/working/models/distilled_model and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded distilled model for classification fine-tuning


Epoch 1/3: 100%|██████████| 636/636 [05:28<00:00,  1.94it/s, loss=0.6919, lr=2.25e-05]



Epoch 1 Evaluation:
Train Loss: 0.6642, Val Loss: 0.6994
Accuracy: 0.5062, F1: 0.0343
Precision: 0.8800, Recall: 0.0175
New best model saved to /kaggle/working/models/classification_model with F1: 0.0343


Epoch 2/3: 100%|██████████| 636/636 [05:28<00:00,  1.94it/s, loss=0.6399, lr=7.50e-06]



Epoch 2 Evaluation:
Train Loss: 0.6399, Val Loss: 0.7036
Accuracy: 0.5627, F1: 0.3754
Precision: 0.6613, Recall: 0.2621
New best model saved to /kaggle/working/models/classification_model with F1: 0.3754


Epoch 3/3: 100%|██████████| 636/636 [05:28<00:00,  1.94it/s, loss=0.4454, lr=0.00e+00]



Epoch 3 Evaluation:
Train Loss: 0.5935, Val Loss: 0.7024
Accuracy: 0.5850, F1: 0.5144
Precision: 0.6223, Recall: 0.4384


Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at /kaggle/working/models/distilled_model and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


New best model saved to /kaggle/working/models/classification_model with F1: 0.5144
Training complete!
Loading expert data for rationale training...
Positive tokens: 7361/22528 (32.6749%)
Class weights - Negative: 0.74, Positive: 1.53


Epoch 1/10: 100%|██████████| 6/6 [00:02<00:00,  2.23it/s, loss=0.8230, lr=2.93e-05]



Epoch 1 Evaluation:
Train Loss: 0.7487, Val Loss: 0.6922
Accuracy: 0.6337, F1: 0.1383, Precision: 0.5723, Recall: 0.0787
New best model saved to /kaggle/working/models/rationale_model with F1: 0.1383


Epoch 2/10: 100%|██████████| 6/6 [00:02<00:00,  2.27it/s, loss=0.7122, lr=2.71e-05]



Epoch 2 Evaluation:
Train Loss: 0.6875, Val Loss: 0.6453
Accuracy: 0.5339, F1: 0.5978, Precision: 0.4412, Recall: 0.9266
New best model saved to /kaggle/working/models/rationale_model with F1: 0.5978


Epoch 3/10: 100%|██████████| 6/6 [00:02<00:00,  2.25it/s, loss=0.6684, lr=2.38e-05]



Epoch 3 Evaluation:
Train Loss: 0.6566, Val Loss: 0.6192
Accuracy: 0.6477, F1: 0.5623, Precision: 0.5249, Recall: 0.6054


Epoch 4/10: 100%|██████████| 6/6 [00:02<00:00,  2.27it/s, loss=0.5257, lr=1.96e-05]



Epoch 4 Evaluation:
Train Loss: 0.6069, Val Loss: 0.5814
Accuracy: 0.6500, F1: 0.5965, Precision: 0.5241, Recall: 0.6920


Epoch 5/10: 100%|██████████| 6/6 [00:02<00:00,  2.25it/s, loss=0.6070, lr=1.50e-05]



Epoch 5 Evaluation:
Train Loss: 0.5876, Val Loss: 0.5451
Accuracy: 0.6476, F1: 0.6265, Precision: 0.5187, Recall: 0.7910
New best model saved to /kaggle/working/models/rationale_model with F1: 0.6265


Epoch 6/10: 100%|██████████| 6/6 [00:02<00:00,  2.25it/s, loss=0.4812, lr=1.04e-05]



Epoch 6 Evaluation:
Train Loss: 0.5305, Val Loss: 0.5422
Accuracy: 0.6766, F1: 0.6152, Precision: 0.5540, Recall: 0.6916


Epoch 7/10: 100%|██████████| 6/6 [00:02<00:00,  2.25it/s, loss=0.4328, lr=6.18e-06]



Epoch 7 Evaluation:
Train Loss: 0.4893, Val Loss: 0.5259
Accuracy: 0.6803, F1: 0.6258, Precision: 0.5562, Recall: 0.7154


Epoch 8/10: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s, loss=0.4798, lr=2.86e-06]



Epoch 8 Evaluation:
Train Loss: 0.4690, Val Loss: 0.5208
Accuracy: 0.6827, F1: 0.6275, Precision: 0.5591, Recall: 0.7150
New best model saved to /kaggle/working/models/rationale_model with F1: 0.6275


Epoch 9/10: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s, loss=0.5694, lr=7.34e-07]



Epoch 9 Evaluation:
Train Loss: 0.4620, Val Loss: 0.5165
Accuracy: 0.6836, F1: 0.6342, Precision: 0.5583, Recall: 0.7340
New best model saved to /kaggle/working/models/rationale_model with F1: 0.6342


Epoch 10/10: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s, loss=0.4133, lr=0.00e+00]



Epoch 10 Evaluation:
Train Loss: 0.4444, Val Loss: 0.5162
Accuracy: 0.6841, F1: 0.6346, Precision: 0.5589, Recall: 0.7340
New best model saved to /kaggle/working/models/rationale_model with F1: 0.6346
Training complete!
Initializing inference system...
Loaded tokenizer from rationale model
Loaded rationale extraction model

Running sample prediction...

Predicted Outcome: DISMISSED

Extracted Rationale (key phrases in brackets):
------------------------------------------------------------
[the] [app]ellant [was] [charged] [under] [section] [302] [of] [the] [indian] [penal] [code] [for] [the] [murder] [of] [his] [neighbor] [.] [evidence] [shows] [the] [accused] [was] [present] [at] [the] [crime] [scene] [,] [and] [fingerprints] [match] [those] [found] [on] [the] [weapon] [.] [however] [,] [the] [defense] [argues] [there] [was] [no] [motive] [and] [the] [forensic] [evidence] [was] [mis]handled [by] [police] [.] [the] [prosecution] [maintains] [the] [ci]rcumstantial [evidence] [is] [suff