In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# ========================
# 🔧 CONFIG
# ========================
# model_path = "/home/liorkob/M.Sc/thesis/t5/mt5-mlm-final"
# model_path = "imvladikon/het5-base"
model_path = "google/mt5-base"
# model_path="/home/liorkob/M.Sc/thesis/t5/mt5-punishment-regression"
train_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_train.csv"
val_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_val.csv"
test_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv"
batch_size = 4
max_len = 512
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========================
# 🧠 Dataset for T5
# ========================
class T5CitationDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.inputs = df.apply(lambda row: f"predict citation: {row['gpt_facts_a']} </s> {row['gpt_facts_b']}", axis=1).tolist()
        self.targets = df["label"].apply(lambda l: "yes" if l == 1 else "no").tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_enc = self.tokenizer(self.inputs[idx], padding='max_length', truncation=True, max_length=self.max_len, return_tensors="pt")
        target_enc = self.tokenizer(self.targets[idx], padding='max_length', truncation=True, max_length=4, return_tensors="pt")
        return {
            "input_ids": input_enc["input_ids"].squeeze(0),
            "attention_mask": input_enc["attention_mask"].squeeze(0),
            "labels": target_enc["input_ids"].squeeze(0)
        }

# ========================
# 📥 Load Data
# ========================
# tokenizer = AutoTokenizer.from_pretrained('google/mt5-large')
# model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)
# model.gradient_checkpointing_enable()  # ✅ חיסכון בזיכרון

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
model.gradient_checkpointing_enable()  # ✅ חיסכון בזיכרון

df_train = pd.read_csv(train_file)
df_val = pd.read_csv(val_file)
df_test = pd.read_csv(test_file)

train_dataset = T5CitationDataset(df_train, tokenizer, max_len=max_len)
val_dataset = T5CitationDataset(df_val, tokenizer, max_len=max_len)
test_dataset = T5CitationDataset(df_test, tokenizer, max_len=max_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# ========================
# 🔁 Training Loop
# ========================
optimizer = AdamW(model.parameters(), lr=2e-5)

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
        torch.cuda.empty_cache()  # ✅ ריקון בין צעדים

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

# ========================
# 📊 Evaluation
# ========================
def evaluate(model, dataloader, tokenizer, device):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            pred_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            label_texts = tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)

            preds.extend([1 if p.strip().lower() == "yes" else 0 for p in pred_texts])
            labels.extend([1 if l.strip().lower() == "yes" else 0 for l in label_texts])

    preds = np.array(preds)
    labels = np.array(labels)

    print(f"AUC-ROC: {roc_auc_score(labels, preds):.4f}")
    print(f"F1 Score: {f1_score(labels, preds):.4f}")
    print(f"Precision: {precision_score(labels, preds):.4f}")
    print(f"Recall: {recall_score(labels, preds):.4f}")

    return preds, labels

print("\n🔍 Validation Set:")
evaluate(model, val_loader, tokenizer, device)
torch.cuda.empty_cache()

print("\n🧪 Test Set:")
evaluate(model, test_loader, tokenizer, device)

In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, AutoTokenizer
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, classification_report
from tqdm import tqdm
import numpy as np

# ========================
# 🔧 CONFIG
# ========================
model_path = "imvladikon/het5-base"
train_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_train.csv"
val_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_val.csv"
test_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv"
batch_size = 4
max_len = 512
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# ========================
# 🎯 CORE: Improved Logits Classification
# ========================
def classify_with_threshold_search(model, tokenizer, input_ids, attention_mask, threshold=0.0):
    """Classify using threshold-based method"""
    with torch.no_grad():
        batch_size = input_ids.shape[0]
        
        decoder_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=input_ids.device)
        decoder_input_ids[:, 0] = tokenizer.pad_token_id
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids
        )
        
        logits = outputs.logits[:, -1, :]
        
        yes_tokens = [259, 1903]  # כן
        no_tokens = [1124]        # לא
        
        predictions = []
        scores = []
        
        for batch_idx in range(batch_size):
            batch_logits = logits[batch_idx]
            
            yes_score = torch.mean(batch_logits[yes_tokens]).item()
            no_score = torch.mean(batch_logits[no_tokens]).item()
            
            score_diff = yes_score - no_score
            
            if score_diff > threshold:
                prediction = 1
                predicted_text = "כן"
            else:
                prediction = 0
                predicted_text = "לא"
            
            predictions.append(prediction)
            scores.append({
                'prediction': prediction,
                'predicted_text': predicted_text,
                'score_diff': score_diff,
                'yes_score': yes_score,
                'no_score': no_score
            })
        
        return predictions, scores

def find_best_threshold(model, tokenizer, dataloader, device, true_labels):
    """Find optimal threshold for balanced predictions"""
    print("🔍 Finding best threshold...")
    
    all_score_diffs = []
    
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Collecting scores"):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            _, scores = classify_with_threshold_search(
                model, tokenizer,
                batch["input_ids"],
                batch["attention_mask"],
                threshold=0.0
            )
            
            for score in scores:
                all_score_diffs.append(score['score_diff'])
    
    # Test thresholds
    thresholds = np.linspace(min(all_score_diffs), max(all_score_diffs), 50)
    best_threshold = 0.0
    best_f1 = 0.0
    
    for threshold in thresholds:
        predictions = []
        
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            batch_preds, _ = classify_with_threshold_search(
                model, tokenizer,
                batch["input_ids"],
                batch["attention_mask"],
                threshold=threshold
            )
            
            predictions.extend(batch_preds)
        
        predictions = np.array(predictions)
        f1 = f1_score(true_labels, predictions, zero_division=0)
        
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
    
    print(f"Best threshold: {best_threshold:.4f} (F1: {best_f1:.4f})")
    return best_threshold

# ========================
# 🧠 IMPROVED Dataset with Specific Legal Prompts
# ========================
class LegalSentencingCitationDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512, prompt_version=1):
        """
        Dataset with legally-specific prompts for sentencing citation prediction
        """
        
        # Multiple versions of detailed legal prompts
        self.legal_prompts = {
            1: {
                "hebrew": """משימה: חיזוי ציטוטים לתמיכה במדיניות גזר דין
הקשר: בפסקי דין פליליים, שופטים מצטטים פסקי דין קודמים כדי לתמוך בטווח הענישה שהם מציעים. לא כל הציטוטים רלוונטיים - אנו מתמקדים רק בציטוטים התומכים בהחלטות טווח הענישה.
שאלה: האם פסק דין א' יצטט פסק דין ב' כדי לתמוך במדיניות גזר הדין שלו, על בסיס עובדות כתב האישום?""",
                "english": """Task: Predict citations supporting sentencing policy decisions
Context: In criminal verdicts, judges cite previous rulings to support their proposed sentencing range. Not all citations are relevant - we focus specifically on citations supporting sentencing range decisions.
Question: Will verdict A cite verdict B to support its sentencing policy, based on indictment facts?"""
            },
            
            2: {
                "hebrew": """ניתוח ציטוטים בפסקי דין פליליים
מטרה: זיהוי ציטוטים הרלוונטיים למדיניות ענישה
הגדרה: ציטוט רלוונטי = הפניה לפסק דין קודם המשמש כתקדים לטווח העונש המוצע
מיקום: בדרך כלל בחלק "מדיניות הענישה" או "טווח הענישה" של פסק הדין
שאלה: בהתבסס על עובדות כתב האישום, האם צפוי שפסק דין א' יצטט פסק דין ב' לתמיכה בטווח הענישה?""",
                "english": """Criminal verdict citation analysis
Goal: Identify citations relevant to sentencing policy
Definition: Relevant citation = reference to prior ruling used as precedent for proposed punishment range
Location: Typically found in "Sentencing Policy" or "Sentencing Range" sections
Question: Based on indictment facts, will verdict A likely cite verdict B to support sentencing range?"""
            },
            
            3: {
                "hebrew": """מערכת חיזוי ציטוטים משפטיים מתמחה
תחום: דין פלילי - מדיניות ענישה
מטרה: חיזוי ציטוטים בין פסקי דין על בסיס דמיון בעובדות כתב האישום
קריטריונים: ציטוט רלוונטי אם הוא תומך בהחלטת טווח העונש (לא הליכים, הגדרות, או פסקי דין לא קשורים)
פסק דין א' יצטט פסק דין ב' אם יש דמיון בעבירות ובנסיבות העוולות המוצגות בכתבי האישום.""",
                "english": """Specialized legal citation prediction system
Domain: Criminal law - sentencing policy
Purpose: Predict citations between verdicts based on indictment facts similarity
Criteria: Citation is relevant if it supports sentencing range decision (not procedures, definitions, or unrelated verdicts)
Verdict A will cite verdict B if there is similarity in offenses and circumstances presented in indictments."""
            }
        }
        
        chosen_prompt = self.legal_prompts[prompt_version]["hebrew"]
        
        # Create more detailed inputs with legal context
        self.inputs = []
        for idx, row in df.iterrows():
            # Format with detailed legal context
            legal_input = f"""{chosen_prompt}

עובדות כתב אישום - פסק דין א':
{row['gpt_facts_a']}

עובדות כתב אישום - פסק דין ב':
{row['gpt_facts_b']}

על בסיס דמיון העבירות והנסיבות, האם פסק דין א' יצטט פסק דין ב' לתמיכה במדיניות הענישה?"""
            
            self.inputs.append(legal_input)
        
        self.targets = df["label"].apply(lambda l: "כן" if l == 1 else "לא").tolist()
        self.labels = df["label"].values
        
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        print(f"Legal Dataset created: {len(self.inputs)} samples")
        print(f"Label distribution: {np.bincount(self.labels)}")
        print(f"Sample input length: {len(self.inputs[0])} characters")
        print(f"Prompt version: {prompt_version}")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_text = self.inputs[idx]
        target_text = self.targets[idx]
        
        # Tokenize with longer sequences due to detailed prompt
        input_enc = self.tokenizer(
            input_text, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_len, 
            return_tensors="pt"
        )
        
        target_enc = self.tokenizer(
            target_text, 
            padding='max_length', 
            truncation=True, 
            max_length=5,
            return_tensors="pt"
        )
        
        labels = target_enc["input_ids"].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        return {
            "input_ids": input_enc["input_ids"].squeeze(0),
            "attention_mask": input_enc["attention_mask"].squeeze(0),
            "labels": labels,
            "numeric_label": self.labels[idx]
        }

# ========================
# 📊 Evaluation Function
# ========================
def evaluate_legal_model(model, dataloader, tokenizer, device, use_threshold_tuning=True):
    """Evaluate the legal citation model"""
    model.eval()
    
    # Collect true labels
    true_labels = []
    for batch in dataloader:
        true_labels.extend(batch["numeric_label"].numpy())
    true_labels = np.array(true_labels)
    
    all_predictions = []
    all_confidence_scores = []

    if use_threshold_tuning:
        best_threshold = find_best_threshold(model, tokenizer, dataloader, device, true_labels)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Legal Evaluation"):
                batch = {k: v.to(device) for k, v in batch.items()}

                predictions, confidence_scores = classify_with_threshold_search(
                    model, tokenizer,
                    batch["input_ids"],
                    batch["attention_mask"],
                    threshold=best_threshold
                )

                all_predictions.extend(predictions)
                all_confidence_scores.extend(confidence_scores)
    else:
        best_threshold = None
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Legal Evaluation (No Threshold Tuning)"):
                batch = {k: v.to(device) for k, v in batch.items()}
                generated = model.generate(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    max_length=5
                )
                decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=True)
                predictions = [1 if p.strip() == "כן" else 0 for p in decoded_preds]

                all_predictions.extend(predictions)
                all_confidence_scores.extend([{} for _ in predictions])  # <- דמmy score dicts

    predictions = np.array(all_predictions)

    # Calculate metrics
    f1 = f1_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions, zero_division=0)
    recall = recall_score(true_labels, predictions, zero_division=0)
    accuracy = np.mean(predictions == true_labels)

    print(f"\n📊 LEGAL CITATION PREDICTION RESULTS:")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Accuracy: {accuracy:.4f}")

    print(f"\nPrediction Distribution: {np.bincount(predictions)}")
    print(f"True Label Distribution: {np.bincount(true_labels)}")

    if len(np.unique(predictions)) > 1 and len(np.unique(true_labels)) > 1:
        auc = roc_auc_score(true_labels, predictions)
        print(f"AUC-ROC: {auc:.4f}")

    print(f"\nClassification Report:")
    print(classification_report(true_labels, predictions))

    return f1, {
        'precision': precision,
        'recall': recall,
        'accuracy': accuracy,
        'predictions': predictions,
        'threshold': best_threshold,
        'scores': all_confidence_scores  # ← תמיד קיים
    }


# ========================
# 📥 Load Everything
# ========================
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading data...")
df_train = pd.read_csv(train_file)
df_val = pd.read_csv(val_file)
df_test = pd.read_csv(test_file)

# Try different prompt versions
prompt_versions_to_try = [1, 2, 3]
best_prompt_version = 1
best_baseline_f1 = 0

print("\n🔍 TESTING DIFFERENT LEGAL PROMPTS:")
print("="*60)

for prompt_version in prompt_versions_to_try:
    print(f"\n📋 Testing Prompt Version {prompt_version}:")
    
    # Create datasets with this prompt version
    val_dataset = LegalSentencingCitationDataset(df_val, tokenizer, max_len=max_len, prompt_version=prompt_version)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    # Test baseline performance
    baseline_f1, baseline_metrics = evaluate_legal_model(model, val_loader, tokenizer, device, use_threshold_tuning=True)
    
    print(f"Prompt {prompt_version} Baseline F1: {baseline_f1:.4f}")
    
    if baseline_f1 > best_baseline_f1:
        best_baseline_f1 = baseline_f1
        best_prompt_version = prompt_version

print(f"\n🏆 BEST PROMPT VERSION: {best_prompt_version} (F1: {best_baseline_f1:.4f})")

# ========================
# 🏗️ Create Final Datasets with Best Prompt
# ========================
print(f"\nCreating final datasets with prompt version {best_prompt_version}...")
train_dataset = LegalSentencingCitationDataset(df_train, tokenizer, max_len=max_len, prompt_version=best_prompt_version)
val_dataset = LegalSentencingCitationDataset(df_val, tokenizer, max_len=max_len, prompt_version=best_prompt_version)
test_dataset = LegalSentencingCitationDataset(df_test, tokenizer, max_len=max_len, prompt_version=best_prompt_version)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# ========================
# 🔁 Training (if needed)
# ========================
if best_baseline_f1 < 0.6:
    print(f"\nBaseline F1 ({best_baseline_f1:.4f}) needs improvement. Starting training...")
    
    optimizer = AdamW(model.parameters(), lr=2e-5)
    best_val_f1 = best_baseline_f1
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )
            
            loss = outputs.loss
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        print(f"\nEpoch {epoch+1}: Average Loss = {total_loss / len(train_loader):.4f}")
        
        # Validation
        val_f1, val_metrics = evaluate_legal_model(model, val_loader, tokenizer, device, use_threshold_tuning=True)
        
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), "best_legal_citation_model.pt")
            print(f"✅ New best model! F1: {best_val_f1:.4f}")
else:
    print(f"Baseline F1 ({best_baseline_f1:.4f}) is already good!")

# ========================
# 🧪 Final Test Evaluation
# ========================
print("\n" + "="*80)
print("🧪 FINAL LEGAL CITATION PREDICTION TEST:")
print("="*80)

test_f1, test_metrics = evaluate_legal_model(model, test_loader, tokenizer, device, use_threshold_tuning=True)

print(f"\n🎯 FINAL RESULTS:")
print(f"Best Prompt Version: {best_prompt_version}")
print(f"Baseline F1: {best_baseline_f1:.4f}")
print(f"Final Test F1: {test_f1:.4f}")
print(f"Final Test Accuracy: {test_metrics['accuracy']:.4f}")
print(f"Optimal Threshold: {test_metrics['threshold']:.4f}")

# Save results
results = {
    'best_prompt_version': best_prompt_version,
    'baseline_f1': best_baseline_f1,
    'test_f1': test_f1,
    'test_accuracy': test_metrics['accuracy'],
    'optimal_threshold': test_metrics['threshold'],
    'model_type': 'legal_sentencing_citation_prediction'

}

import json
with open('legal_citation_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

if test_f1 > 0.6:
    print(f"\n🎉 SUCCESS! Legal citation model achieved {test_f1:.4f} F1 score!")
    print(f"🏆 The detailed legal prompts significantly improved performance!")
else:
    print(f"\n📈 F1: {test_f1:.4f} - Consider fine-tuning hyperparameters or trying longer training")

print(f"\n💾 Results saved to 'legal_citation_results.json'")

Using device: cuda
Loading model and tokenizer...
Loading data...

🔍 TESTING DIFFERENT LEGAL PROMPTS:

📋 Testing Prompt Version 1:
Legal Dataset created: 869 samples
Label distribution: [579 290]
Sample input length: 3919 characters
Prompt version: 1
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:06<00:00, 32.23it/s]


Best threshold: 3.4868 (F1: 0.5009)


Legal Evaluation: 100%|██████████| 218/218 [00:06<00:00, 31.32it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5009
Precision: 0.3382
Recall: 0.9655
Accuracy: 0.3579

Prediction Distribution: [ 41 828]
True Label Distribution: [579 290]
AUC-ROC: 0.5095

Classification Report:
              precision    recall  f1-score   support

           0       0.76      0.05      0.10       579
           1       0.34      0.97      0.50       290

    accuracy                           0.36       869
   macro avg       0.55      0.51      0.30       869
weighted avg       0.62      0.36      0.23       869

Prompt 1 Baseline F1: 0.5009

📋 Testing Prompt Version 2:
Legal Dataset created: 869 samples
Label distribution: [579 290]
Sample input length: 3919 characters
Prompt version: 2
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:06<00:00, 31.23it/s]


Best threshold: 3.5317 (F1: 0.5085)


Legal Evaluation: 100%|██████████| 218/218 [00:06<00:00, 31.15it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5085
Precision: 0.3589
Recall: 0.8724
Accuracy: 0.4373

Prediction Distribution: [164 705]
True Label Distribution: [579 290]
AUC-ROC: 0.5459

Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.22      0.34       579
           1       0.36      0.87      0.51       290

    accuracy                           0.44       869
   macro avg       0.57      0.55      0.43       869
weighted avg       0.64      0.44      0.40       869

Prompt 2 Baseline F1: 0.5085

📋 Testing Prompt Version 3:
Legal Dataset created: 869 samples
Label distribution: [579 290]
Sample input length: 3924 characters
Prompt version: 3
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:06<00:00, 31.20it/s]


Best threshold: 3.5402 (F1: 0.5142)


Legal Evaluation: 100%|██████████| 218/218 [00:07<00:00, 31.05it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5142
Precision: 0.3837
Recall: 0.7793
Accuracy: 0.5086

Prediction Distribution: [280 589]
True Label Distribution: [579 290]
AUC-ROC: 0.5762

Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.37      0.50       579
           1       0.38      0.78      0.51       290

    accuracy                           0.51       869
   macro avg       0.58      0.58      0.51       869
weighted avg       0.64      0.51      0.51       869

Prompt 3 Baseline F1: 0.5142

🏆 BEST PROMPT VERSION: 3 (F1: 0.5142)

Creating final datasets with prompt version 3...
Legal Dataset created: 4052 samples
Label distribution: [2699 1353]
Sample input length: 1989 characters
Prompt version: 3
Legal Dataset created: 869 samples
Label distribution: [579 290]
Sample input length: 3924 characters
Prompt version: 3
Legal Dataset created: 870 samples
Label distribution: [579 291]
Sample input length: 2199 chara

Epoch 1/5: 100%|██████████| 1013/1013 [02:05<00:00,  8.07it/s, loss=0.3464]



Epoch 1: Average Loss = 3.2148
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:07<00:00, 29.34it/s]


Best threshold: -5.6755 (F1: 0.5115)


Legal Evaluation: 100%|██████████| 218/218 [00:07<00:00, 29.23it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5115
Precision: 0.3488
Recall: 0.9586
Accuracy: 0.3890

Prediction Distribution: [ 72 797]
True Label Distribution: [579 290]
AUC-ROC: 0.5311

Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.10      0.18       579
           1       0.35      0.96      0.51       290

    accuracy                           0.39       869
   macro avg       0.59      0.53      0.35       869
weighted avg       0.67      0.39      0.29       869



Epoch 2/5: 100%|██████████| 1013/1013 [02:05<00:00,  8.10it/s, loss=0.3013]



Epoch 2: Average Loss = 0.3555
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:07<00:00, 29.48it/s]


Best threshold: -5.2674 (F1: 0.5227)


Legal Evaluation: 100%|██████████| 218/218 [00:07<00:00, 29.33it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5227
Precision: 0.3763
Recall: 0.8552
Accuracy: 0.4787

Prediction Distribution: [210 659]
True Label Distribution: [579 290]
AUC-ROC: 0.5727

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.29      0.43       579
           1       0.38      0.86      0.52       290

    accuracy                           0.48       869
   macro avg       0.59      0.57      0.47       869
weighted avg       0.66      0.48      0.46       869

✅ New best model! F1: 0.5227


Epoch 3/5: 100%|██████████| 1013/1013 [02:04<00:00,  8.11it/s, loss=0.4273]



Epoch 3: Average Loss = 0.3281
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:07<00:00, 29.39it/s]


Best threshold: -5.2343 (F1: 0.5260)


Legal Evaluation: 100%|██████████| 218/218 [00:07<00:00, 29.26it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5260
Precision: 0.3625
Recall: 0.9586
Accuracy: 0.4235

Prediction Distribution: [102 767]
True Label Distribution: [579 290]
AUC-ROC: 0.5570

Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.16      0.26       579
           1       0.36      0.96      0.53       290

    accuracy                           0.42       869
   macro avg       0.62      0.56      0.40       869
weighted avg       0.71      0.42      0.35       869

✅ New best model! F1: 0.5260


Epoch 4/5: 100%|██████████| 1013/1013 [02:05<00:00,  8.10it/s, loss=0.2293]



Epoch 4: Average Loss = 0.3010
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:07<00:00, 29.45it/s]


Best threshold: -4.6886 (F1: 0.5330)


Legal Evaluation: 100%|██████████| 218/218 [00:07<00:00, 29.32it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5330
Precision: 0.3947
Recall: 0.8207
Accuracy: 0.5201

Prediction Distribution: [266 603]
True Label Distribution: [579 290]
AUC-ROC: 0.5951

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.37      0.51       579
           1       0.39      0.82      0.53       290

    accuracy                           0.52       869
   macro avg       0.60      0.60      0.52       869
weighted avg       0.67      0.52      0.52       869

✅ New best model! F1: 0.5330


Epoch 5/5: 100%|██████████| 1013/1013 [02:04<00:00,  8.11it/s, loss=0.3243]



Epoch 5: Average Loss = 0.3167
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:07<00:00, 29.35it/s]


Best threshold: -6.5397 (F1: 0.5533)


Legal Evaluation: 100%|██████████| 218/218 [00:07<00:00, 29.26it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5533
Precision: 0.4239
Recall: 0.7966
Accuracy: 0.5708

Prediction Distribution: [324 545]
True Label Distribution: [579 290]
AUC-ROC: 0.6271

Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.46      0.59       579
           1       0.42      0.80      0.55       290

    accuracy                           0.57       869
   macro avg       0.62      0.63      0.57       869
weighted avg       0.69      0.57      0.58       869

✅ New best model! F1: 0.5533

🧪 FINAL LEGAL CITATION PREDICTION TEST:
🔍 Finding best threshold...


Collecting scores: 100%|██████████| 218/218 [00:07<00:00, 29.64it/s]


Best threshold: -6.5047 (F1: 0.5338)


Legal Evaluation: 100%|██████████| 218/218 [00:07<00:00, 29.37it/s]


📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5338
Precision: 0.4157
Recall: 0.7457
Accuracy: 0.5644

Prediction Distribution: [348 522]
True Label Distribution: [579 291]
AUC-ROC: 0.6095

Classification Report:
              precision    recall  f1-score   support

           0       0.79      0.47      0.59       579
           1       0.42      0.75      0.53       291

    accuracy                           0.56       870
   macro avg       0.60      0.61      0.56       870
weighted avg       0.66      0.56      0.57       870


🎯 FINAL RESULTS:
Best Prompt Version: 3
Baseline F1: 0.5142
Final Test F1: 0.5338
Final Test Accuracy: 0.5644
Optimal Threshold: -6.5047

📈 F1: 0.5338 - Consider fine-tuning hyperparameters or trying longer training

💾 Results saved to 'legal_citation_results.json'





In [4]:
def show_sample_predictions(dataset, predictions, scores, tokenizer, num_samples=10):
    print(f"\n🔎 Showing {num_samples} Sample Predictions:")
    print("="*80)

    indices = np.random.choice(len(dataset), size=min(num_samples, len(dataset)), replace=False)

    for idx in indices:
        input_ids = dataset[idx]["input_ids"]
        decoded_input = tokenizer.decode(input_ids, skip_special_tokens=True)

        label = dataset[idx]["numeric_label"]
        prediction = predictions[idx]
        score = scores[idx]
        
        print(f"\n📄 Input #{idx}")
        print("-" * 80)
        print(f"🔹 Decoded Input:\n{decoded_input[:1000]}...")  # Truncate long text
        print(f"✅ True Label: {'כן' if label == 1 else 'לא'}")
        print(f"🧠 Predicted: {'כן' if prediction == 1 else 'לא'}")
        print(f"🧮 Score Diff: {score['score_diff']:.4f} | Yes Score: {score['yes_score']:.4f} | No Score: {score['no_score']:.4f}")
        print("=" * 80)
show_sample_predictions(test_dataset, test_metrics["predictions"], test_metrics["scores"], tokenizer, num_samples=10)



🔎 Showing 10 Sample Predictions:

📄 Input #379
--------------------------------------------------------------------------------
🔹 Decoded Input:
מערכת חיזוי ציטוטים משפטיים מתמחה תחום: דין פלילי - מדיניות ענישה מטרה: חיזוי ציטוטים בין פסקי דין על בסיס דמיון בעובדות כתב האישום קריטריונים: ציטוט רלוונטי אם הוא תומך בהחלטת טווח העונש (לא הליכים, הגדרות, או פסקי דין לא קשורים) פסק דין א' יצטט פסק דין ב' אם יש דמיון בעבירות ובנסיבות העוולות המוצגות בכתבי האישום. עובדות כתב אישום - פסק דין א': הנאשם הורשע על פי הודאתו בעבירות של החזקת חלק של נשק או תחמושת, לפי סעיף 144 (א) לחוק העונשין, תשל"ז 1977 ונשיאה/הובלת חלק של נשק או תחמושת, לפי סעיף 144(ב) לחוק העונשין. על פי הנטען בכתב האישום, ביום 28.8.2022 בשעה 00:20 לערך, נהג הנאשם ברכב מסוג קיה ספורטג' עם לוחית רישוי מספר 13-608-201 לכיוון מעבר הל"ה בדרכו לשטחי האזור, כאשר מתחת למושב הנהג ברכב הייתה שקית ובה 6 מכלולים של נשק מסוג M16. בנוסף, בתא המטען של הרכב נשא שבעה ארגזי תחמושת וארגז קרטון שהכילו יחדיו כ-9000 כדורים בקוטר 5.56 מ"מ, שהיו מכוס

In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, AutoTokenizer
from torch.optim import AdamW
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, classification_report
from tqdm import tqdm
import numpy as np


from scipy import stats
from collections import defaultdict
from torch.utils.data import Subset
from sklearn.model_selection import KFold
import copy

# ========================
# 🔧 CONFIG
# ========================
train_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_train.csv"
val_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_val.csv"
test_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv"
batch_size = 4
max_len = 1024
epochs = 7
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# ========================
# 🎯 CORE: Improved Logits Classification
# ========================
def classify_with_threshold_search(model, tokenizer, input_ids, attention_mask, threshold=0.0):
    """Classify using threshold-based method"""
    with torch.no_grad():
        batch_size = input_ids.shape[0]
        
        decoder_input_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=input_ids.device)
        decoder_input_ids[:, 0] = tokenizer.pad_token_id
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids
        )
        
        logits = outputs.logits[:, -1, :]
        
        yes_tokens = [259, 1903]  # כן
        no_tokens = [1124]        # לא
        
        predictions = []
        scores = []
        
        for batch_idx in range(batch_size):
            batch_logits = logits[batch_idx]
            
            yes_score = torch.mean(batch_logits[yes_tokens]).item()
            no_score = torch.mean(batch_logits[no_tokens]).item()
            
            score_diff = yes_score - no_score
            
            if score_diff > threshold:
                prediction = 1
                predicted_text = "כן"
            else:
                prediction = 0
                predicted_text = "לא"
            
            predictions.append(prediction)
            scores.append({
                'prediction': prediction,
                'predicted_text': predicted_text,
                'score_diff': score_diff,
                'yes_score': yes_score,
                'no_score': no_score
            })
        
        return predictions, scores

def find_best_threshold(model, tokenizer, dataloader, device, true_labels):
    """Find optimal threshold for balanced predictions"""
    print("🔍 Finding best threshold...")
    
    all_score_diffs = []
    
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Collecting scores"):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            _, scores = classify_with_threshold_search(
                model, tokenizer,
                batch["input_ids"],
                batch["attention_mask"],
                threshold=0.0
            )
            
            for score in scores:
                all_score_diffs.append(score['score_diff'])
    
    # Test thresholds
    thresholds = np.linspace(min(all_score_diffs), max(all_score_diffs), 50)
    best_threshold = 0.0
    best_f1 = 0.0
    
    for threshold in thresholds:
        predictions = []
        
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            
            batch_preds, _ = classify_with_threshold_search(
                model, tokenizer,
                batch["input_ids"],
                batch["attention_mask"],
                threshold=threshold
            )
            
            predictions.extend(batch_preds)
        
        predictions = np.array(predictions)
        f1 = f1_score(true_labels, predictions, zero_division=0)
        
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
    
    print(f"Best threshold: {best_threshold:.4f} (F1: {best_f1:.4f})")
    return best_threshold

# ========================
# 🧠 IMPROVED Dataset with Specific Legal Prompts
# ========================
class LegalSentencingCitationDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        """
        Dataset with legally-specific prompts for sentencing citation prediction
        """
        
        # Multiple versions of detailed legal prompts
        prompt = """מערכת חיזוי ציטוטים משפטיים מתמחה
תחום: דין פלילי - מדיניות ענישה
מטרה: חיזוי ציטוטים בין פסקי דין על בסיס דמיון בעובדות כתב האישום
קריטריונים: ציטוט רלוונטי אם הוא תומך בהחלטת טווח העונש (לא הליכים, הגדרות, או פסקי דין לא קשורים)
פסק דין א' יצטט פסק דין ב' אם יש דמיון בעבירות ובנסיבות העוולות המוצגות בכתבי האישום.
שאלה: בהתבסס על עובדות כתב האישום, האם צפוי שפסק דין א' יצטט פסק דין ב' לתמיכה בטווח הענישה?
"""

#                 "english": """Specialized legal citation prediction system
# Domain: Criminal law - sentencing policy
# Purpose: Predict citations between verdicts based on indictment facts similarity
# Criteria: Citation is relevant if it supports sentencing range decision (not procedures, definitions, or unrelated verdicts)
# Verdict A will cite verdict B if there is similarity in offenses and circumstances presented in indictments.
# Question: Based on indictment facts, will verdict A likely cite verdict B to support sentencing range?"""


        
        # Create more detailed inputs with legal context
        self.inputs = []
        for idx, row in df.iterrows():
            # Format with detailed legal context
            legal_input = f"""{prompt}

עובדות כתב אישום - פסק דין א':
{row['gpt_facts_a']}

עובדות כתב אישום - פסק דין ב':
{row['gpt_facts_b']}

על בסיס דמיון העבירות והנסיבות, האם פסק דין א' יצטט פסק דין ב' לתמיכה במדיניות הענישה?"""
            
            self.inputs.append(legal_input)
        
        self.targets = df["label"].apply(lambda l: "כן" if l == 1 else "לא").tolist()
        self.labels = df["label"].values
        
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        print(f"Legal Dataset created: {len(self.inputs)} samples")
        print(f"Label distribution: {np.bincount(self.labels)}")
        print(f"Sample input length: {len(self.inputs[0])} characters")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_text = self.inputs[idx]
        target_text = self.targets[idx]
        
        # Tokenize with longer sequences due to detailed prompt
        input_enc = self.tokenizer(
            input_text, 
            padding='max_length', 
            truncation=True, 
            max_length=self.max_len, 
            return_tensors="pt"
        )
        
        target_enc = self.tokenizer(
            target_text, 
            padding='max_length', 
            truncation=True, 
            max_length=5,
            return_tensors="pt"
        )
        
        labels = target_enc["input_ids"].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100
        
        return {
            "input_ids": input_enc["input_ids"].squeeze(0),
            "attention_mask": input_enc["attention_mask"].squeeze(0),
            "labels": labels,
            "numeric_label": self.labels[idx]
        }

# ========================
# 📊 Evaluation Function
# ========================
def evaluate_legal_model(model, dataloader, tokenizer, device, use_threshold_tuning=True, fix_threshold=None):
    """Evaluate the legal citation model"""
    model.eval()
    
    # Collect true labels
    true_labels = []
    for batch in dataloader:
        true_labels.extend(batch["numeric_label"].numpy())
    true_labels = np.array(true_labels)
    
    all_predictions = []
    all_confidence_scores = []

    # THREE OPTIONS NOW:
    if fix_threshold is not None:
        # ✅ USE FIXED THRESHOLD (no tuning)
        print(f"Using fixed threshold: {fix_threshold}")
        best_threshold = fix_threshold
        
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Legal Evaluation (Fixed Threshold)"):
                batch = {k: v.to(device) for k, v in batch.items()}

                predictions, confidence_scores = classify_with_threshold_search(
                    model, tokenizer,
                    batch["input_ids"],
                    batch["attention_mask"],
                    threshold=best_threshold
                )

                all_predictions.extend(predictions)
                all_confidence_scores.extend(confidence_scores)
                
    elif use_threshold_tuning:
        # ❌ TUNE THRESHOLD ON THIS DATASET (causes data leakage if used on test set)
        best_threshold = find_best_threshold(model, tokenizer, dataloader, device, true_labels)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Legal Evaluation (Tuned Threshold)"):
                batch = {k: v.to(device) for k, v in batch.items()}

                predictions, confidence_scores = classify_with_threshold_search(
                    model, tokenizer,
                    batch["input_ids"],
                    batch["attention_mask"],
                    threshold=best_threshold
                )

                all_predictions.extend(predictions)
                all_confidence_scores.extend(confidence_scores)
    else:
        # 🔄 USE MODEL.GENERATE() (different approach)
        best_threshold = None
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Legal Evaluation (Generation)"):
                batch = {k: v.to(device) for k, v in batch.items()}
                generated = model.generate(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    max_length=5
                )
                decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=True)
                predictions = [1 if p.strip() == "כן" else 0 for p in decoded_preds]

                all_predictions.extend(predictions)
                all_confidence_scores.extend([{} for _ in predictions])

    predictions = np.array(all_predictions)

    # Calculate metrics
    f1 = f1_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions, zero_division=0)
    recall = recall_score(true_labels, predictions, zero_division=0)
    accuracy = np.mean(predictions == true_labels)

    print(f"\n📊 LEGAL CITATION PREDICTION RESULTS:")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"Accuracy: {accuracy:.4f}")

    print(f"\nPrediction Distribution: {np.bincount(predictions)}")
    print(f"True Label Distribution: {np.bincount(true_labels)}")

    if len(np.unique(predictions)) > 1 and len(np.unique(true_labels)) > 1:
        auc = roc_auc_score(true_labels, predictions)
        print(f"AUC-ROC: {auc:.4f}")

    print(f"\nClassification Report:")
    print(classification_report(true_labels, predictions))

    return f1, {
        'precision': precision,
        'recall': recall,
        'accuracy': accuracy,
        'predictions': predictions,
        'threshold': best_threshold,
        'scores': all_confidence_scores  # ← תמיד קיים
    }

# ========================
# 📥 Load Everything
# ========================

# model_paths = ["/home/liorkob/M.Sc/thesis/t5/het5-mlm-final","imvladikon/het5-base"]

# for model_path in model_paths:
#     print(f"Loading {model_path} and tokenizer...")
#     tokenizer = AutoTokenizer.from_pretrained(model_path)
#     model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)

#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token

#     print("Loading data...")
#     df_train = pd.read_csv(train_file)
#     df_val = pd.read_csv(val_file)
#     df_test = pd.read_csv(test_file)

#     train_dataset = LegalSentencingCitationDataset(df_train, tokenizer, max_len=max_len)
#     val_dataset = LegalSentencingCitationDataset(df_val, tokenizer, max_len=max_len)
#     test_dataset = LegalSentencingCitationDataset(df_test, tokenizer, max_len=max_len)

#     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#     val_loader = DataLoader(val_dataset, batch_size=batch_size)
#     test_loader = DataLoader(test_dataset, batch_size=batch_size)

        
#     optimizer = AdamW(model.parameters(), lr=2e-5)
#     best_val_f1 = 0
#     best_val_threshold=0
#     for epoch in range(epochs):
#         model.train()
#         total_loss = 0
        
#         progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
#         for batch in progress_bar:
#             batch = {k: v.to(device) for k, v in batch.items()}
            
#             outputs = model(
#                 input_ids=batch["input_ids"],
#                 attention_mask=batch["attention_mask"],
#                 labels=batch["labels"]
#             )
            
#             loss = outputs.loss
#             loss.backward()
            
#             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#             optimizer.step()
#             optimizer.zero_grad()
            
#             total_loss += loss.item()
#             progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
        
#         print(f"\nEpoch {epoch+1}: Average Loss = {total_loss / len(train_loader):.4f}")
        
#         # Validation
#         val_f1, val_metrics = evaluate_legal_model(model, val_loader, tokenizer, device, use_threshold_tuning=True)
        
#         if val_f1 > best_val_f1:
#             best_val_f1 = val_f1
#             torch.save(model.state_dict(), "best_legal_clm_citation_model.pt")
#             best_val_threshold = val_metrics['threshold']
#             print(f"✅ New best model! F1: {best_val_f1:.4f}")

#     # ========================
#     # 🧪 Final Test Evaluation
#     # ========================
#     print("\n" + "="*80)
#     print("🧪 FINAL LEGAL CITATION PREDICTION TEST:")
#     print("="*80)




#     test_f1, test_metrics = evaluate_legal_model(model, test_loader, tokenizer, device, 
#                                             use_threshold_tuning=False, 
#                                             fix_threshold=best_val_threshold)

#     print(f"\n🎯 FINAL RESULTS:")
#     print(f"Best Prompt Version: {best_prompt_version}")
#     print(f"Baseline F1: {best_baseline_f1:.4f}")
#     print(f"Final Test F1: {test_f1:.4f}")
#     print(f"Final Test Accuracy: {test_metrics['accuracy']:.4f}")
#     print(f"Optimal Threshold: {test_metrics['threshold']:.4f}")

#     # Save results
#     results = {
#         'best_prompt_version': best_prompt_version,
#         'baseline_f1': best_baseline_f1,
#         'test_f1': test_f1,
#         'test_accuracy': test_metrics['accuracy'],
#         'optimal_threshold': test_metrics['threshold'],
#         'model_type': 'legal_sentencing_citation_prediction'

#     }

# K-fold cross-validation setup
K_FOLDS = 5
RANDOM_SEED = 42

# Store results for statistical comparison
model_fold_results = defaultdict(list)  # {model_name: [fold1_f1, fold2_f1, ...]}
model_fold_accuracy = defaultdict(list)  # {model_name: [fold1_acc, fold2_acc, ...]}

def reset_model_weights(model):
    """Reset model weights to initial state for each fold"""
    for layer in model.children():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()

def create_k_fold_datasets(full_dataset, k_folds=K_FOLDS, random_seed=RANDOM_SEED):
    """Create k-fold splits of the dataset"""
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed)
    dataset_size = len(full_dataset)
    indices = list(range(dataset_size))
    
    fold_splits = []
    for train_indices, val_indices in kfold.split(indices):
        fold_splits.append((train_indices.tolist(), val_indices.tolist()))
    
    return fold_splits

# Load and prepare the full dataset (combine train, val, test for k-fold)
print("Loading and combining datasets for k-fold cross-validation...")
df_train = pd.read_csv(train_file)
df_val = pd.read_csv(val_file)
df_test = pd.read_csv(test_file)

# Combine all data for k-fold CV
df_full = pd.concat([df_train, df_val, df_test], ignore_index=True)
print(f"Total samples for k-fold CV: {len(df_full)}")

model_paths = ["/home/liorkob/M.Sc/thesis/t5/m5-mlm-final", "google/mt5-base"]

for model_idx, model_path in enumerate(model_paths):
    model_name = model_path.split('/')[-1]
    print(f"\n{'='*80}")
    print(f"🔄 STARTING K-FOLD EVALUATION FOR: {model_name}")
    print(f"{'='*80}")
    
    # Load tokenizer once
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Create full dataset for k-fold splitting
    full_dataset = LegalSentencingCitationDataset(df_full, tokenizer, max_len=max_len)
    fold_splits = create_k_fold_datasets(full_dataset, k_folds=K_FOLDS, random_seed=RANDOM_SEED)
    
    fold_f1_scores = []
    fold_accuracies = []
    
    for fold, (train_indices, test_indices) in enumerate(fold_splits):
        print(f"\n📁 FOLD {fold + 1}/{K_FOLDS}")
        print(f"Train samples: {len(train_indices)}, Test samples: {len(test_indices)}")
        
        # Create fold-specific datasets
        train_fold_dataset = Subset(full_dataset, train_indices)
        test_fold_dataset = Subset(full_dataset, test_indices)
        
        # Split training data into train/validation (80/20 split)
        train_size = int(0.8 * len(train_indices))
        val_size = len(train_indices) - train_size
        
        # Create train/val split from training indices
        train_train_indices = train_indices[:train_size]
        train_val_indices = train_indices[train_size:]
        
        train_train_dataset = Subset(full_dataset, train_train_indices)
        train_val_dataset = Subset(full_dataset, train_val_indices)
        
        # Create data loaders
        train_loader = DataLoader(train_train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(train_val_dataset, batch_size=batch_size)
        test_loader = DataLoader(test_fold_dataset, batch_size=batch_size)
        
        # Load fresh model for this fold
        print(f"Loading fresh model for fold {fold + 1}...")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
        optimizer = AdamW(model.parameters(), lr=2e-5)
        
        best_val_f1 = 0
        best_val_threshold = 0.5
        best_model_state = None
        
        # Training loop for this fold
        for epoch in range(epochs):
            model.train()
            total_loss = 0
            
            progress_bar = tqdm(train_loader, desc=f"Fold {fold+1}, Epoch {epoch+1}/{epochs}")
            
            for batch in progress_bar:
                batch = {k: v.to(device) for k, v in batch.items()}
                
                outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"]
                )
                
                loss = outputs.loss
                loss.backward()
                
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                
                total_loss += loss.item()
                progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
            
            # Validation on this fold
            val_f1, val_metrics = evaluate_legal_model(model, val_loader, tokenizer, device, use_threshold_tuning=True)
            
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_val_threshold = val_metrics['threshold']
                best_model_state = copy.deepcopy(model.state_dict())
                
            print(f"Fold {fold+1}, Epoch {epoch+1}: Val F1 = {val_f1:.4f}, Best = {best_val_f1:.4f}")
        
        # Load best model for testing
        model.load_state_dict(best_model_state)
        
        # Test evaluation for this fold
        test_f1, test_metrics = evaluate_legal_model(
            model, test_loader, tokenizer, device,
            use_threshold_tuning=False,
            fix_threshold=best_val_threshold
        )
        
        fold_f1_scores.append(test_f1)
        fold_accuracies.append(test_metrics['accuracy'])
        
        print(f"✅ Fold {fold+1} Results: F1 = {test_f1:.4f}, Accuracy = {test_metrics['accuracy']:.4f}")
        
        # Clean up GPU memory
        del model
        torch.cuda.empty_cache()
    
    # Store results for this model
    model_fold_results[model_name] = fold_f1_scores
    model_fold_accuracy[model_name] = fold_accuracies
    
    print(f"\n📊 {model_name} - K-Fold Summary:")
    print(f"F1 Scores: {fold_f1_scores}")
    print(f"Mean F1: {np.mean(fold_f1_scores):.4f} ± {np.std(fold_f1_scores):.4f}")
    print(f"Accuracy: {fold_accuracies}")
    print(f"Mean Accuracy: {np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")

# ========================
# 📊 Statistical Significance Testing with K-Fold Results
# ========================
print("\n" + "="*80)
print("📊 K-FOLD STATISTICAL SIGNIFICANCE ANALYSIS")
print("="*80)

model_names = list(model_fold_results.keys())
if len(model_names) == 2:
    model1, model2 = model_names
    
    # Extract k-fold results
    f1_scores_1 = model_fold_results[model1]
    f1_scores_2 = model_fold_results[model2]
    accuracy_scores_1 = model_fold_accuracy[model1]
    accuracy_scores_2 = model_fold_accuracy[model2]
    
    print(f"\nComparing {model1} vs {model2}")
    print(f"K-Fold validation with {K_FOLDS} folds")
    
    # Paired t-test for F1 scores (each fold is a paired observation)
    print("\n--- F1 Score Comparison (K-Fold) ---")
    print(f"{model1} - Mean F1: {np.mean(f1_scores_1):.4f} ± {np.std(f1_scores_1):.4f}")
    print(f"{model2} - Mean F1: {np.mean(f1_scores_2):.4f} ± {np.std(f1_scores_2):.4f}")
    print(f"Fold-wise F1 scores:")
    for i in range(K_FOLDS):
        print(f"  Fold {i+1}: {f1_scores_1[i]:.4f} vs {f1_scores_2[i]:.4f}")
    
    f1_t_stat, f1_p_value = stats.ttest_rel(f1_scores_1, f1_scores_2)
    print(f"\nPaired t-test (K-Fold): t={f1_t_stat:.4f}, p={f1_p_value:.6f}")
    
    # Determine significance level
    if f1_p_value < 0.001:
        significance_f1 = "***"
    elif f1_p_value < 0.01:
        significance_f1 = "**"
    elif f1_p_value < 0.05:
        significance_f1 = "*"
    else:
        significance_f1 = "ns"
    
    print(f"Significance: {significance_f1}")
    
    # Effect size (Cohen's d for paired samples)
    diff_f1 = np.array(f1_scores_1) - np.array(f1_scores_2)
    cohens_d_f1 = np.mean(diff_f1) / np.std(diff_f1)
    print(f"Cohen's d (effect size): {cohens_d_f1:.4f}")
    
    # Paired t-test for Accuracy
    print("\n--- Accuracy Comparison (K-Fold) ---")
    print(f"{model1} - Mean Accuracy: {np.mean(accuracy_scores_1):.4f} ± {np.std(accuracy_scores_1):.4f}")
    print(f"{model2} - Mean Accuracy: {np.mean(accuracy_scores_2):.4f} ± {np.std(accuracy_scores_2):.4f}")
    print(f"Fold-wise Accuracy scores:")
    for i in range(K_FOLDS):
        print(f"  Fold {i+1}: {accuracy_scores_1[i]:.4f} vs {accuracy_scores_2[i]:.4f}")
    
    acc_t_stat, acc_p_value = stats.ttest_rel(accuracy_scores_1, accuracy_scores_2)
    print(f"\nPaired t-test (K-Fold): t={acc_t_stat:.4f}, p={acc_p_value:.6f}")
    
    if acc_p_value < 0.001:
        significance_acc = "***"
    elif acc_p_value < 0.01:
        significance_acc = "**"
    elif acc_p_value < 0.05:
        significance_acc = "*"
    else:
        significance_acc = "ns"
    
    print(f"Significance: {significance_acc}")
    
    # Effect size for accuracy
    diff_acc = np.array(accuracy_scores_1) - np.array(accuracy_scores_2)
    cohens_d_acc = np.mean(diff_acc) / np.std(diff_acc)
    print(f"Cohen's d (effect size): {cohens_d_acc:.4f}")
    
    # Confidence intervals
    from scipy.stats import t
    alpha = 0.05
    dof = K_FOLDS - 1
    t_critical = t.ppf(1 - alpha/2, dof)
    
    f1_diff_mean = np.mean(diff_f1)
    f1_diff_se = stats.sem(diff_f1)
    f1_ci_lower = f1_diff_mean - t_critical * f1_diff_se
    f1_ci_upper = f1_diff_mean + t_critical * f1_diff_se
    
    acc_diff_mean = np.mean(diff_acc)
    acc_diff_se = stats.sem(diff_acc)
    acc_ci_lower = acc_diff_mean - t_critical * acc_diff_se
    acc_ci_upper = acc_diff_mean + t_critical * acc_diff_se
    
    print(f"\n95% Confidence Interval for F1 difference: [{f1_ci_lower:.4f}, {f1_ci_upper:.4f}]")
    print(f"95% Confidence Interval for Accuracy difference: [{acc_ci_lower:.4f}, {acc_ci_upper:.4f}]")
    
    # Wilcoxon signed-rank test (non-parametric)
    print("\n--- Non-parametric Tests ---")
    f1_wilcoxon_stat, f1_wilcoxon_p = stats.wilcoxon(f1_scores_1, f1_scores_2)
    acc_wilcoxon_stat, acc_wilcoxon_p = stats.wilcoxon(accuracy_scores_1, accuracy_scores_2)
    
    print(f"Wilcoxon signed-rank test (F1): p={f1_wilcoxon_p:.6f}")
    print(f"Wilcoxon signed-rank test (Accuracy): p={acc_wilcoxon_p:.6f}")
    
    # Summary table
    print("\n" + "="*70)
    print("📋 K-FOLD STATISTICAL SUMMARY TABLE")
    print("="*70)
    print(f"{'Metric':<15} {'Model 1':<15} {'Model 2':<15} {'p-value':<12} {'Significance':<12}")
    print("-" * 70)
    print(f"{'F1 Score':<15} {np.mean(f1_scores_1):<15.4f} {np.mean(f1_scores_2):<15.4f} {f1_p_value:<12.6f} {significance_f1:<12}")
    print(f"{'Accuracy':<15} {np.mean(accuracy_scores_1):<15.4f} {np.mean(accuracy_scores_2):<15.4f} {acc_p_value:<12.6f} {significance_acc:<12}")
    
    print(f"\nEffect Sizes (Cohen's d):")
    print(f"F1 Score: {cohens_d_f1:.4f}")
    print(f"Accuracy: {cohens_d_acc:.4f}")
    
    print("\n*** p < 0.001, ** p < 0.01, * p < 0.05, ns = not significant")
    print("Effect size interpretation: |d| < 0.2 (small), 0.2-0.8 (medium), > 0.8 (large)")
    
    # Save detailed k-fold results
    kfold_results = pd.DataFrame({
        'fold': range(1, K_FOLDS + 1),
        f'{model1}_f1': f1_scores_1,
        f'{model2}_f1': f1_scores_2,
        f'{model1}_accuracy': accuracy_scores_1,
        f'{model2}_accuracy': accuracy_scores_2,
        'f1_difference': diff_f1,
        'accuracy_difference': diff_acc
    })
    
    # Add summary statistics
    summary_stats = pd.DataFrame({
        'statistic': ['mean', 'std', 'sem', 't_stat', 'p_value', 'cohens_d'],
        f'{model1}_f1': [np.mean(f1_scores_1), np.std(f1_scores_1), stats.sem(f1_scores_1), 
                        f1_t_stat, f1_p_value, cohens_d_f1],
        f'{model2}_f1': [np.mean(f1_scores_2), np.std(f1_scores_2), stats.sem(f1_scores_2), 
                        f1_t_stat, f1_p_value, -cohens_d_f1],
        f'{model1}_accuracy': [np.mean(accuracy_scores_1), np.std(accuracy_scores_1), stats.sem(accuracy_scores_1),
                              acc_t_stat, acc_p_value, cohens_d_acc],
        f'{model2}_accuracy': [np.mean(accuracy_scores_2), np.std(accuracy_scores_2), stats.sem(accuracy_scores_2),
                              acc_t_stat, acc_p_value, -cohens_d_acc]
    })
    
    # Save results
    kfold_results.to_csv('kfold_detailed_results.csv', index=False)
    summary_stats.to_csv('kfold_statistical_summary.csv', index=False)
    
    print(f"\n💾 Results saved:")
    print(f"  - Detailed k-fold results: 'kfold_detailed_results.csv'")
    print(f"  - Statistical summary: 'kfold_statistical_summary.csv'")

else:
    print("Note: Statistical comparison requires exactly 2 models for paired testing.")

print(f"\n🎯 K-FOLD CROSS-VALIDATION COMPLETED")
print(f"Both models evaluated on the same {K_FOLDS} folds for fair comparison.")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Loading and combining datasets for k-fold cross-validation...


You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Total samples for k-fold CV: 5791

🔄 STARTING K-FOLD EVALUATION FOR: m5-mlm-final




Legal Dataset created: 5791 samples
Label distribution: [3857 1934]
Sample input length: 2083 characters

📁 FOLD 1/5
Train samples: 4632, Test samples: 1159
Loading fresh model for fold 1...


Fold 1, Epoch 1/7:   0%|          | 0/927 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Fold 1, Epoch 1/7: 100%|██████████| 927/927 [03:48<00:00,  4.06it/s, loss=0.4440]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.52it/s]


Best threshold: -5.2493 (F1: 0.5103)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.53it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5103
Precision: 0.3815
Recall: 0.7705
Accuracy: 0.5135

Prediction Distribution: [311 616]
True Label Distribution: [622 305]
AUC-ROC: 0.5790

Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.39      0.52       622
           1       0.38      0.77      0.51       305

    accuracy                           0.51       927
   macro avg       0.58      0.58      0.51       927
weighted avg       0.65      0.51      0.51       927

Fold 1, Epoch 1: Val F1 = 0.5103, Best = 0.5103


Fold 1, Epoch 2/7: 100%|██████████| 927/927 [03:51<00:00,  4.00it/s, loss=0.2527]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.51it/s]


Best threshold: -3.8732 (F1: 0.5514)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.50it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5514
Precision: 0.4462
Recall: 0.7213
Accuracy: 0.6138

Prediction Distribution: [434 493]
True Label Distribution: [622 305]
AUC-ROC: 0.6412

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.56      0.66       622
           1       0.45      0.72      0.55       305

    accuracy                           0.61       927
   macro avg       0.63      0.64      0.61       927
weighted avg       0.69      0.61      0.62       927

Fold 1, Epoch 2: Val F1 = 0.5514, Best = 0.5514


Fold 1, Epoch 3/7: 100%|██████████| 927/927 [03:51<00:00,  4.00it/s, loss=0.2994]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.51it/s]


Best threshold: -4.5635 (F1: 0.6349)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.51it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6349
Precision: 0.6154
Recall: 0.6557
Accuracy: 0.7519

Prediction Distribution: [602 325]
True Label Distribution: [622 305]
AUC-ROC: 0.7274

Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.80      0.81       622
           1       0.62      0.66      0.63       305

    accuracy                           0.75       927
   macro avg       0.72      0.73      0.72       927
weighted avg       0.76      0.75      0.75       927

Fold 1, Epoch 3: Val F1 = 0.6349, Best = 0.6349


Fold 1, Epoch 4/7: 100%|██████████| 927/927 [03:51<00:00,  4.00it/s, loss=0.0180]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.53it/s]


Best threshold: -4.9943 (F1: 0.6707)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.50it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6707
Precision: 0.6268
Recall: 0.7213
Accuracy: 0.7670

Prediction Distribution: [576 351]
True Label Distribution: [622 305]
AUC-ROC: 0.7554

Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.79      0.82       622
           1       0.63      0.72      0.67       305

    accuracy                           0.77       927
   macro avg       0.74      0.76      0.75       927
weighted avg       0.78      0.77      0.77       927

Fold 1, Epoch 4: Val F1 = 0.6707, Best = 0.6707


Fold 1, Epoch 5/7: 100%|██████████| 927/927 [03:51<00:00,  4.00it/s, loss=0.0045]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.56it/s]


Best threshold: -3.2699 (F1: 0.6964)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.59it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6964
Precision: 0.7647
Recall: 0.6393
Accuracy: 0.8166

Prediction Distribution: [672 255]
True Label Distribution: [622 305]
AUC-ROC: 0.7714

Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.90      0.87       622
           1       0.76      0.64      0.70       305

    accuracy                           0.82       927
   macro avg       0.80      0.77      0.78       927
weighted avg       0.81      0.82      0.81       927

Fold 1, Epoch 5: Val F1 = 0.6964, Best = 0.6964


Fold 1, Epoch 6/7: 100%|██████████| 927/927 [03:50<00:00,  4.02it/s, loss=0.0031]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.65it/s]


Best threshold: -7.6882 (F1: 0.7072)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.60it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.7072
Precision: 0.6736
Recall: 0.7443
Accuracy: 0.7972

Prediction Distribution: [590 337]
True Label Distribution: [622 305]
AUC-ROC: 0.7837

Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.82      0.84       622
           1       0.67      0.74      0.71       305

    accuracy                           0.80       927
   macro avg       0.77      0.78      0.78       927
weighted avg       0.80      0.80      0.80       927

Fold 1, Epoch 6: Val F1 = 0.7072, Best = 0.7072


Fold 1, Epoch 7/7: 100%|██████████| 927/927 [03:49<00:00,  4.04it/s, loss=0.0451]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.69it/s]


Best threshold: -6.0102 (F1: 0.7141)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.68it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.7141
Precision: 0.6307
Recall: 0.8230
Accuracy: 0.7832

Prediction Distribution: [529 398]
True Label Distribution: [622 305]
AUC-ROC: 0.7933

Classification Report:
              precision    recall  f1-score   support

           0       0.90      0.76      0.83       622
           1       0.63      0.82      0.71       305

    accuracy                           0.78       927
   macro avg       0.76      0.79      0.77       927
weighted avg       0.81      0.78      0.79       927

Fold 1, Epoch 7: Val F1 = 0.7141, Best = 0.7141
Using fixed threshold: -6.010154292291524


Legal Evaluation (Fixed Threshold): 100%|██████████| 290/290 [00:22<00:00, 12.70it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6981
Precision: 0.6069
Recall: 0.8217
Accuracy: 0.7627

Prediction Distribution: [635 524]
True Label Distribution: [772 387]
AUC-ROC: 0.7774

Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.73      0.80       772
           1       0.61      0.82      0.70       387

    accuracy                           0.76      1159
   macro avg       0.75      0.78      0.75      1159
weighted avg       0.80      0.76      0.77      1159

✅ Fold 1 Results: F1 = 0.6981, Accuracy = 0.7627

📁 FOLD 2/5
Train samples: 4633, Test samples: 1158
Loading fresh model for fold 2...


Fold 2, Epoch 1/7: 100%|██████████| 927/927 [03:49<00:00,  4.04it/s, loss=0.3480]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.64it/s]


Best threshold: -5.6131 (F1: 0.5033)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.64it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5033
Precision: 0.3548
Recall: 0.8660
Accuracy: 0.4358

Prediction Distribution: [180 747]
True Label Distribution: [621 306]
AUC-ROC: 0.5449

Classification Report:
              precision    recall  f1-score   support

           0       0.77      0.22      0.35       621
           1       0.35      0.87      0.50       306

    accuracy                           0.44       927
   macro avg       0.56      0.54      0.43       927
weighted avg       0.63      0.44      0.40       927

Fold 2, Epoch 1: Val F1 = 0.5033, Best = 0.5033


Fold 2, Epoch 2/7: 100%|██████████| 927/927 [03:49<00:00,  4.05it/s, loss=0.2061]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.69it/s]


Best threshold: -4.7372 (F1: 0.5259)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.75it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5259
Precision: 0.3964
Recall: 0.7810
Accuracy: 0.5351

Prediction Distribution: [324 603]
True Label Distribution: [621 306]
AUC-ROC: 0.5974

Classification Report:
              precision    recall  f1-score   support

           0       0.79      0.41      0.54       621
           1       0.40      0.78      0.53       306

    accuracy                           0.54       927
   macro avg       0.59      0.60      0.53       927
weighted avg       0.66      0.54      0.54       927

Fold 2, Epoch 2: Val F1 = 0.5259, Best = 0.5259


Fold 2, Epoch 3/7: 100%|██████████| 927/927 [03:48<00:00,  4.06it/s, loss=0.1631]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.63it/s]


Best threshold: -5.8495 (F1: 0.5811)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.53it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5811
Precision: 0.4817
Recall: 0.7320
Accuracy: 0.6516

Prediction Distribution: [462 465]
True Label Distribution: [621 306]
AUC-ROC: 0.6720

Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.61      0.70       621
           1       0.48      0.73      0.58       306

    accuracy                           0.65       927
   macro avg       0.65      0.67      0.64       927
weighted avg       0.71      0.65      0.66       927

Fold 2, Epoch 3: Val F1 = 0.5811, Best = 0.5811


Fold 2, Epoch 4/7: 100%|██████████| 927/927 [03:50<00:00,  4.03it/s, loss=0.4540]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.58it/s]


Best threshold: -4.8180 (F1: 0.6510)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.75it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6510
Precision: 0.6690
Recall: 0.6340
Accuracy: 0.7756

Prediction Distribution: [637 290]
True Label Distribution: [621 306]
AUC-ROC: 0.7397

Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.85      0.83       621
           1       0.67      0.63      0.65       306

    accuracy                           0.78       927
   macro avg       0.75      0.74      0.74       927
weighted avg       0.77      0.78      0.77       927

Fold 2, Epoch 4: Val F1 = 0.6510, Best = 0.6510


Fold 2, Epoch 5/7: 100%|██████████| 927/927 [03:47<00:00,  4.07it/s, loss=0.1693]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.75it/s]


Best threshold: -5.4289 (F1: 0.6919)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.54it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6919
Precision: 0.6697
Recall: 0.7157
Accuracy: 0.7896

Prediction Distribution: [600 327]
True Label Distribution: [621 306]
AUC-ROC: 0.7709

Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.83      0.84       621
           1       0.67      0.72      0.69       306

    accuracy                           0.79       927
   macro avg       0.76      0.77      0.77       927
weighted avg       0.79      0.79      0.79       927

Fold 2, Epoch 5: Val F1 = 0.6919, Best = 0.6919


Fold 2, Epoch 6/7: 100%|██████████| 927/927 [03:50<00:00,  4.03it/s, loss=0.7968]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.63it/s]


Best threshold: -6.1773 (F1: 0.6978)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.56it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6978
Precision: 0.7133
Recall: 0.6830
Accuracy: 0.8047

Prediction Distribution: [634 293]
True Label Distribution: [621 306]
AUC-ROC: 0.7739

Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.86      0.86       621
           1       0.71      0.68      0.70       306

    accuracy                           0.80       927
   macro avg       0.78      0.77      0.78       927
weighted avg       0.80      0.80      0.80       927

Fold 2, Epoch 6: Val F1 = 0.6978, Best = 0.6978


Fold 2, Epoch 7/7: 100%|██████████| 927/927 [03:48<00:00,  4.05it/s, loss=0.0263]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.60it/s]


Best threshold: -6.6320 (F1: 0.7068)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.53it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.7068
Precision: 0.6696
Recall: 0.7484
Accuracy: 0.7950

Prediction Distribution: [585 342]
True Label Distribution: [621 306]
AUC-ROC: 0.7832

Classification Report:
              precision    recall  f1-score   support

           0       0.87      0.82      0.84       621
           1       0.67      0.75      0.71       306

    accuracy                           0.80       927
   macro avg       0.77      0.78      0.77       927
weighted avg       0.80      0.80      0.80       927

Fold 2, Epoch 7: Val F1 = 0.7068, Best = 0.7068
Using fixed threshold: -6.631978000913348


Legal Evaluation (Fixed Threshold): 100%|██████████| 290/290 [00:23<00:00, 12.59it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.7170
Precision: 0.6566
Recall: 0.7896
Accuracy: 0.7927

Prediction Distribution: [695 463]
True Label Distribution: [773 385]
AUC-ROC: 0.7920

Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.79      0.84       773
           1       0.66      0.79      0.72       385

    accuracy                           0.79      1158
   macro avg       0.77      0.79      0.78      1158
weighted avg       0.81      0.79      0.80      1158

✅ Fold 2 Results: F1 = 0.7170, Accuracy = 0.7927

📁 FOLD 3/5
Train samples: 4633, Test samples: 1158
Loading fresh model for fold 3...


Fold 3, Epoch 1/7: 100%|██████████| 927/927 [03:49<00:00,  4.03it/s, loss=0.3951]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.71it/s]


Best threshold: -4.8456 (F1: 0.5148)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.67it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5148
Precision: 0.3639
Recall: 0.8795
Accuracy: 0.4509

Prediction Distribution: [185 742]
True Label Distribution: [620 307]
AUC-ROC: 0.5591

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.24      0.37       620
           1       0.36      0.88      0.51       307

    accuracy                           0.45       927
   macro avg       0.58      0.56      0.44       927
weighted avg       0.66      0.45      0.42       927

Fold 3, Epoch 1: Val F1 = 0.5148, Best = 0.5148


Fold 3, Epoch 2/7: 100%|██████████| 927/927 [03:48<00:00,  4.06it/s, loss=0.3832]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.80it/s]


Best threshold: -5.7337 (F1: 0.5392)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.65it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5392
Precision: 0.3920
Recall: 0.8632
Accuracy: 0.5113

Prediction Distribution: [251 676]
True Label Distribution: [620 307]
AUC-ROC: 0.6001

Classification Report:
              precision    recall  f1-score   support

           0       0.83      0.34      0.48       620
           1       0.39      0.86      0.54       307

    accuracy                           0.51       927
   macro avg       0.61      0.60      0.51       927
weighted avg       0.69      0.51      0.50       927

Fold 3, Epoch 2: Val F1 = 0.5392, Best = 0.5392


Fold 3, Epoch 3/7: 100%|██████████| 927/927 [03:49<00:00,  4.04it/s, loss=0.1313]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.68it/s]


Best threshold: -5.5333 (F1: 0.5532)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.64it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5532
Precision: 0.4600
Recall: 0.6938
Accuracy: 0.6289

Prediction Distribution: [464 463]
True Label Distribution: [620 307]
AUC-ROC: 0.6453

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.60      0.68       620
           1       0.46      0.69      0.55       307

    accuracy                           0.63       927
   macro avg       0.63      0.65      0.62       927
weighted avg       0.69      0.63      0.64       927

Fold 3, Epoch 3: Val F1 = 0.5532, Best = 0.5532


Fold 3, Epoch 4/7: 100%|██████████| 927/927 [03:48<00:00,  4.06it/s, loss=0.0893]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.77it/s]


Best threshold: -3.0634 (F1: 0.6291)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.62it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6291
Precision: 0.6054
Recall: 0.6547
Accuracy: 0.7443

Prediction Distribution: [595 332]
True Label Distribution: [620 307]
AUC-ROC: 0.7217

Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.79      0.80       620
           1       0.61      0.65      0.63       307

    accuracy                           0.74       927
   macro avg       0.71      0.72      0.72       927
weighted avg       0.75      0.74      0.75       927

Fold 3, Epoch 4: Val F1 = 0.6291, Best = 0.6291


Fold 3, Epoch 5/7: 100%|██████████| 927/927 [03:49<00:00,  4.03it/s, loss=0.2186]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.68it/s]


Best threshold: -5.3259 (F1: 0.6747)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.60it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6747
Precision: 0.6299
Recall: 0.7264
Accuracy: 0.7681

Prediction Distribution: [573 354]
True Label Distribution: [620 307]
AUC-ROC: 0.7575

Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.79      0.82       620
           1       0.63      0.73      0.67       307

    accuracy                           0.77       927
   macro avg       0.74      0.76      0.75       927
weighted avg       0.78      0.77      0.77       927

Fold 3, Epoch 5: Val F1 = 0.6747, Best = 0.6747


Fold 3, Epoch 6/7: 100%|██████████| 927/927 [03:49<00:00,  4.04it/s, loss=0.0262]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.64it/s]


Best threshold: -5.9222 (F1: 0.7014)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.63it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.7014
Precision: 0.6810
Recall: 0.7231
Accuracy: 0.7961

Prediction Distribution: [601 326]
True Label Distribution: [620 307]
AUC-ROC: 0.7777

Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.83      0.85       620
           1       0.68      0.72      0.70       307

    accuracy                           0.80       927
   macro avg       0.77      0.78      0.77       927
weighted avg       0.80      0.80      0.80       927

Fold 3, Epoch 6: Val F1 = 0.7014, Best = 0.7014


Fold 3, Epoch 7/7: 100%|██████████| 927/927 [03:48<00:00,  4.05it/s, loss=0.0519]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.71it/s]


Best threshold: -6.9139 (F1: 0.7071)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.61it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.7071
Precision: 0.6477
Recall: 0.7785
Accuracy: 0.7864

Prediction Distribution: [558 369]
True Label Distribution: [620 307]
AUC-ROC: 0.7844

Classification Report:
              precision    recall  f1-score   support

           0       0.88      0.79      0.83       620
           1       0.65      0.78      0.71       307

    accuracy                           0.79       927
   macro avg       0.76      0.78      0.77       927
weighted avg       0.80      0.79      0.79       927

Fold 3, Epoch 7: Val F1 = 0.7071, Best = 0.7071
Using fixed threshold: -6.91388987521736


Legal Evaluation (Fixed Threshold): 100%|██████████| 290/290 [00:22<00:00, 12.66it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.6969
Precision: 0.6450
Recall: 0.7580
Accuracy: 0.7694

Prediction Distribution: [682 476]
True Label Distribution: [753 405]
AUC-ROC: 0.7668

Classification Report:
              precision    recall  f1-score   support

           0       0.86      0.78      0.81       753
           1       0.64      0.76      0.70       405

    accuracy                           0.77      1158
   macro avg       0.75      0.77      0.76      1158
weighted avg       0.78      0.77      0.77      1158

✅ Fold 3 Results: F1 = 0.6969, Accuracy = 0.7694

📁 FOLD 4/5
Train samples: 4633, Test samples: 1158
Loading fresh model for fold 4...


Fold 4, Epoch 1/7: 100%|██████████| 927/927 [03:48<00:00,  4.05it/s, loss=0.0610]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.63it/s]


Best threshold: -5.3705 (F1: 0.5292)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.63it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5292
Precision: 0.3978
Recall: 0.7905
Accuracy: 0.5221

Prediction Distribution: [301 626]
True Label Distribution: [612 315]
AUC-ROC: 0.5872

Classification Report:
              precision    recall  f1-score   support

           0       0.78      0.38      0.51       612
           1       0.40      0.79      0.53       315

    accuracy                           0.52       927
   macro avg       0.59      0.59      0.52       927
weighted avg       0.65      0.52      0.52       927

Fold 4, Epoch 1: Val F1 = 0.5292, Best = 0.5292


Fold 4, Epoch 2/7: 100%|██████████| 927/927 [03:48<00:00,  4.05it/s, loss=0.3636]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.69it/s]


Best threshold: -5.7057 (F1: 0.5443)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.65it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5443
Precision: 0.4150
Recall: 0.7905
Accuracy: 0.5502

Prediction Distribution: [327 600]
True Label Distribution: [612 315]
AUC-ROC: 0.6085

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.43      0.56       612
           1       0.41      0.79      0.54       315

    accuracy                           0.55       927
   macro avg       0.61      0.61      0.55       927
weighted avg       0.67      0.55      0.55       927

Fold 4, Epoch 2: Val F1 = 0.5443, Best = 0.5443


Fold 4, Epoch 3/7: 100%|██████████| 927/927 [03:49<00:00,  4.04it/s, loss=0.3350]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.68it/s]


Best threshold: -3.8132 (F1: 0.5680)


Legal Evaluation (Tuned Threshold): 100%|██████████| 232/232 [00:18<00:00, 12.69it/s]



📊 LEGAL CITATION PREDICTION RESULTS:
F1 Score: 0.5680
Precision: 0.4551
Recall: 0.7556
Accuracy: 0.6095

Prediction Distribution: [404 523]
True Label Distribution: [612 315]
AUC-ROC: 0.6449

Classification Report:
              precision    recall  f1-score   support

           0       0.81      0.53      0.64       612
           1       0.46      0.76      0.57       315

    accuracy                           0.61       927
   macro avg       0.63      0.64      0.61       927
weighted avg       0.69      0.61      0.62       927

Fold 4, Epoch 3: Val F1 = 0.5680, Best = 0.5680


Fold 4, Epoch 4/7: 100%|██████████| 927/927 [03:49<00:00,  4.05it/s, loss=0.4517]


🔍 Finding best threshold...


Collecting scores: 100%|██████████| 232/232 [00:18<00:00, 12.65it/s]
