In [None]:
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from torch.optim import Adam
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from sklearn.metrics import roc_curve
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")




: 

## data prep:

#### for citation based data :

In [None]:
import os
import pandas as pd

# Paths
tag_dir = '/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/verdicts_tagged_citations'
gpt_facts_path = '/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/processed_verdicts_with_gpt.csv'
output_path = '/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/verdict_pairs_with_similarity.csv'

# Load verdict facts
verdict_facts = pd.read_csv(gpt_facts_path)
facts_dict = dict(zip(verdict_facts['verdict'], verdict_facts['extracted_gpt_facts']))

# Collect data
rows = []
for file in os.listdir(tag_dir):
    if file.endswith('.csv'):
        verdict_a = file.replace('.csv', '')
        file_path = os.path.join(tag_dir, file)
        
        # Skip empty files
        if os.path.getsize(file_path) == 0:
            continue
        
        try:
            df = pd.read_csv(file_path)
        except pd.errors.EmptyDataError:
            continue

        for _, row in df.iterrows():
            if row['predicted_label'] == 1:
                verdict_b = row['citation']
                a_facts = facts_dict.get(verdict_a, "")
                b_facts = facts_dict.get(verdict_b, "")
                rows.append([verdict_a, a_facts, verdict_b, b_facts, 1])

# Save result
result_df = pd.DataFrame(rows, columns=[
    'verdict_a_name', 'verdict_a_extracted_gpt_facts',
    'verdict_b_name', 'verdict_b_extracted_gpt_facts', 'similarity_score'
])
result_df.to_csv(output_path, index=False)


#### genrate non similar pairs

In [None]:
import pandas as pd
import random
import networkx as nx
from tqdm import tqdm
from openai import OpenAI
import os

# ========== Setup ==========
os.environ["OPENAI_API_KEY"] = "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"
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# ========== Load Similar Pairs ==========
df_pairs = pd.read_csv("/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/valid_pairs_with_log.csv")
valid_pairs = df_pairs[df_pairs["log"] != "missing cited verdict"].copy()
similar_pairs = set(tuple(sorted([a, b])) for a, b in zip(valid_pairs["verdict_a"], valid_pairs["verdict_b"]))

# Build similarity graph
G = nx.Graph()
G.add_edges_from(similar_pairs)

# ========== Load Facts from All Sources ==========
facts1 = pd.read_csv("/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/processed_verdicts_with_gpt.csv")  
# facts2 = pd.read_csv("/home/liorkob/M.Sc/thesis/data/5k/gpt/processed_appeals_with_gpt_2.csv")

pairs_facts_a = df_pairs[["verdict_a", "gpt_facts_a"]].rename(columns={"verdict_a": "verdict", "gpt_facts_a": "extracted_gpt_facts"})
pairs_facts_b = df_pairs[["verdict_b", "gpt_facts_b"]].rename(columns={"verdict_b": "verdict", "gpt_facts_b": "extracted_gpt_facts"})
all_facts_df = pd.concat([facts1, pairs_facts_a, pairs_facts_b])

# all_facts_df = pd.concat([facts1, facts2, pairs_facts_a, pairs_facts_b])
all_facts_df = all_facts_df.dropna(subset=["verdict", "extracted_gpt_facts"]).drop_duplicates(subset="verdict")
facts_dict = dict(zip(all_facts_df["verdict"], all_facts_df["extracted_gpt_facts"]))

# ========== Prepare Verdict Sets ==========
verdicts_in_pairs = set(valid_pairs["verdict_a"]) | set(valid_pairs["verdict_b"])
all_verdicts = set(all_facts_df["verdict"])
extra_verdicts = list(all_verdicts - verdicts_in_pairs)
combined_verdicts = list(verdicts_in_pairs | set(extra_verdicts))

# ========== Generate Candidate Non-Similar Pairs ==========
def generate_candidate_pairs(verdicts, existing_pairs, G, num_pairs):
    non_similar = set()
    attempts = 0
    max_attempts = num_pairs * 20

    while len(non_similar) < num_pairs and attempts < max_attempts:
        a, b = random.sample(list(verdicts), 2)
        pair = tuple(sorted([a, b]))
        attempts += 1

        if pair in existing_pairs or G.has_edge(*pair):
            continue

        # ✅ Only check transitive path if both nodes exist in graph
        if a in G and b in G:
            if nx.has_path(G, a, b):
                continue

        non_similar.add(pair)

    return list(non_similar)

# ========== GPT Verification ==========
def verify_with_gpt(pair, facts_dict):
    print(f"Checking: {pair}, verified so far: {len(verified_non_similar)}")

    fact_a = facts_dict.get(pair[0], "")
    fact_b = facts_dict.get(pair[1], "")

    if not fact_a or not fact_b:
        return False
    prompt = f"""אתה עוזר משפטי. תפקידך לבדוק האם שתי מערכות עובדתיות מתארות **מצבים משפטיים דומים**, כך שניתן יהיה להסתמך על האחד לצורך גזירת העונש בשני.

התייחס רק לנסיבות שקשורות ישירות לביצוע העבירה – כגון סוג העבירה, מהות המעשה, אופן הביצוע, משך הזמן, כוונה פלילית, ומאפיינים רלוונטיים של הנאשם או הקורבן *שנוגעים למעשה עצמו*.

מאפיינים כלליים כמו גיל הנאשם, מקום המגורים או תוצאה מקרית שאינה נובעת מהמעשה – אינם רלוונטיים לדמיון המשפטי.

ענה רק "כן" או "לא". אל תסביר. אל תוסיף סימני פיסוק.

עובדות א: {fact_a}
עובדות ב: {fact_b}
תשובה:"""
    try:
        response = client.chat.completions.create(
            model="gpt-4.1-mini", 
            messages=[
                {"role": "system", "content": "You are an AI trained to analyz legal text."},
                {"role": "user", "content": prompt}
            ]
        )

        answer = response.choices[0].message.content.strip().lower()
        print(answer)
        return answer == "לא"
    
    except Exception as e:
        print(f"🚨 GPT API error: {e}")
        return 

# ========== Main Loop ==========
target_count = 2000
verified_non_similar = set()
used_pairs = set()

pbar = tqdm(total=target_count, desc="🔍 Verifying with GPT")

while len(verified_non_similar) < target_count:
    needed = (target_count - len(verified_non_similar)) * 2
    candidate_pairs = generate_candidate_pairs(combined_verdicts, similar_pairs | used_pairs, G, needed)

    for pair in candidate_pairs:
        if pair in used_pairs:
            continue
        used_pairs.add(pair)

        if verify_with_gpt(pair, facts_dict):
            verified_non_similar.add(pair)
            pbar.update(1)

pbar.close()

# ========== Save Output with Facts ==========
pairs_list = []
for a, b in verified_non_similar:
    pairs_list.append({
        "verdict_a": a,
        "verdict_b": b,
        "gpt_facts_a": facts_dict.get(a, ""),
        "gpt_facts_b": facts_dict.get(b, ""),
        "label": 0
    })

output_df = pd.DataFrame(pairs_list)
output_df.to_csv("/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/verified_non_similar_pairs_2.csv", index=False)
print(f"\n✅ Saved {len(output_df)} pairs with facts to verified_non_similar_pairs_2.csv")


### data statistics

In [None]:
import pandas as pd

# Load your CSVs
similar = pd.read_csv("/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/valid_pairs_with_log.csv")
non_similar = pd.read_csv("/home/liorkob/M.Sc/thesis/data/drugs_3k/gpt/verified_non_similar_pairs_2.csv")

# Filter relevant rows
similar = similar[similar["log"] != "missing cited verdict"].copy()
similar["label"] = 1

non_similar = non_similar[non_similar["label"] == 0].copy()
non_similar["label"] = 0

# Select only required columns
cols = ["verdict_a", "verdict_b", "gpt_facts_a", "gpt_facts_b", "label"]
similar = similar[cols]
non_similar = non_similar[cols]

print(f"✅ Similar pairs: {len(similar)}")
print(f"✅ Non-similar pairs: {len(non_similar)}")

# Normalize order to detect duplicates
def normalize_pair(row):
    a, b = row['verdict_a'], row['verdict_b']
    return tuple(sorted((a, b)))

similar['pair_key'] = similar.apply(normalize_pair, axis=1)
non_similar['pair_key'] = non_similar.apply(normalize_pair, axis=1)

# Combine and remove duplicates
data = pd.concat([similar, non_similar], ignore_index=True)
print(f"📦 Total before removing duplicates: {len(data)}")

data = data.drop_duplicates(subset='pair_key').drop(columns='pair_key')
print(f"🧹 Total after removing duplicates: {len(data)}")

# Shuffle
data = data.sample(frac=1, random_state=42).reset_index(drop=True)

# Save
data.to_csv("data_pairs_5k.csv", index=False)

# Print distribution
print("📊 Label distribution:")
print(data["label"].value_counts())


### split to test-tain:

In [None]:
# === Data Load ===
df = pd.read_csv("/home/liorkob/M.Sc/thesis/similarity-model/data_pairs_5k.csv")
df_pos = df[df['label'] == 1]
df_neg = df[df['label'] == 0]
df_pos_train, df_pos_val = train_test_split(df_pos, test_size=0.3, random_state=42)
df_neg_train, df_neg_val = train_test_split(df_neg, test_size=0.3, random_state=42)
df_train = pd.concat([df_pos_train, df_neg_train]).sample(frac=1, random_state=42).reset_index(drop=True)
df_val = pd.concat([df_pos_val, df_neg_val]).sample(frac=1, random_state=42).reset_index(drop=True)
df_val, df_test = train_test_split(df_val, test_size=0.5, random_state=42, stratify=df_val['label'])

# Save splits
df_train.to_csv("crossencoder_train.csv", index=False, encoding="utf-8-sig")
df_val.to_csv("crossencoder_val.csv", index=False, encoding="utf-8-sig")
df_test.to_csv("crossencoder_test.csv", index=False, encoding="utf-8-sig")



## CrossEncoderHeBERT 

In [2]:
# --- Cross-Encoder Dataset ---
class CrossEncoderVerdictDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = f"[CLS] {row['gpt_facts_a']} [SEP] {row['gpt_facts_b']} [SEP]"
        enc = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')
        return {
            'input_ids': enc['input_ids'].squeeze(),
            'attention_mask': enc['attention_mask'].squeeze(),
            'label': torch.tensor(row['label'], dtype=torch.float)
        }

# --- Cross-Encoder Model ---
class CrossEncoderHeBERT(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] token
        return self.classifier(pooled).squeeze(-1)

# --- Training Loop ---
from sklearn.metrics import roc_auc_score

def train_cross_encoder_with_early_stopping(model, train_loader, val_loader, optimizer, device, epochs=10, patience=3):
    criterion = nn.BCEWithLogitsLoss()
    best_auc = 0
    no_improve = 0

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            for key in batch:
                batch[key] = batch[key].to(device)

            logits = model(batch['input_ids'], batch['attention_mask'])
            loss = criterion(logits, batch['label'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = (torch.sigmoid(logits) >= 0.5).long()
            correct += (preds == batch['label'].long()).sum().item()
            total += batch['label'].size(0)

        acc = correct / total
        print(f"Epoch {epoch+1} | Train Loss: {total_loss / len(train_loader):.4f}, Accuracy: {acc:.4f}")

        # Evaluate on validation set
        model.eval()
        val_probs, val_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                for key in batch:
                    batch[key] = batch[key].to(device)
                logits = model(batch['input_ids'], batch['attention_mask'])
                prob = torch.sigmoid(logits).cpu().numpy()
                label = batch['label'].cpu().numpy()
                val_probs.extend(prob)
                val_labels.extend(label)
        val_auc = roc_auc_score(val_labels, val_probs)
        print(f"Epoch {epoch+1} | Validation AUC: {val_auc:.4f}")

        # Early stopping check
        if val_auc > best_auc:
            best_auc = val_auc
            no_improve = 0
            torch.save(model.state_dict(), "best_crossencoder_ft_mlm.pt")
            print("✅ New best model saved.")
        else:
            no_improve += 1
            if no_improve >= patience:
                print("⏹️ Early stopping triggered.")
                break

def train_cross_encoder(model, dataloader, optimizer, device, epochs=15):
    criterion = nn.BCEWithLogitsLoss()
    model.train()
    for epoch in range(epochs):
        total_loss, correct, total = 0, 0, 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            for key in batch:
                batch[key] = batch[key].to(device)

            logits = model(batch['input_ids'], batch['attention_mask'])
            loss = criterion(logits, batch['label'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = (torch.sigmoid(logits) >= 0.5).long()
            correct += (preds == batch['label'].long()).sum().item()
            total += batch['label'].size(0)

        acc = correct / total
        print(f"Epoch {epoch+1} | Loss: {total_loss / len(dataloader):.4f}, Accuracy: {acc:.4f}")

        
from sklearn.metrics import f1_score

def find_best_threshold(y_true, y_probs):
    best_thresh = 0.5
    best_f1 = 0
    for t in np.linspace(0.01, 0.99, 100):
        preds = (np.array(y_probs) >= t).astype(int)
        f1 = f1_score(y_true, preds)
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = t
    return best_thresh, best_f1

# --- Evaluation ---
def evaluate_cross_encoder(model, dataloader, device):
    model.eval()
    probs, targets = [], []
    with torch.no_grad():
        for batch in dataloader:
            for key in batch:
                batch[key] = batch[key].to(device)
            logits = model(batch['input_ids'], batch['attention_mask'])
            prob = torch.sigmoid(logits).cpu().numpy()
            label = batch['label'].cpu().numpy()
            probs.extend(prob)
            targets.extend(label)

    preds = (np.array(probs) >= 0.5).astype(int)
    print(f"[Default @0.5] AUC-ROC: {roc_auc_score(targets, probs):.4f}")
    print(f"[Default @0.5] F1 Score: {f1_score(targets, preds):.4f}")
    print(f"[Default @0.5] Precision: {precision_score(targets, preds):.4f}")
    print(f"[Default @0.5] Recall: {recall_score(targets, preds):.4f}")

    best_thresh, best_f1 = find_best_threshold(targets, probs)
    best_preds = (np.array(probs) >= best_thresh).astype(int)
    print(f"🔍 Best threshold: {best_thresh:.4f}")
    print(f"[Best] F1 Score: {f1_score(targets, best_preds):.4f}")
    print(f"[Best] Precision: {precision_score(targets, best_preds):.4f}")
    print(f"[Best] Recall: {recall_score(targets, best_preds):.4f}")
    
    return probs, targets, best_thresh  

def evaluate_with_threshold(model, dataloader, device, threshold):
    model.eval()
    probs, targets = [], []
    with torch.no_grad():
        for batch in dataloader:
            for key in batch:
                batch[key] = batch[key].to(device)
            logits = model(batch['input_ids'], batch['attention_mask'])
            prob = torch.sigmoid(logits).cpu().numpy()
            label = batch['label'].cpu().numpy()
            probs.extend(prob)
            targets.extend(label)
    print("-------test-------")       
    preds = (np.array(probs) >= 0.5).astype(int)
    print(f"[Default @0.5] AUC-ROC: {roc_auc_score(targets, probs):.4f}")
    print(f"[Default @0.5] F1 Score: {f1_score(targets, preds):.4f}")
    print(f"[Default @0.5] Precision: {precision_score(targets, preds):.4f}")
    print(f"[Default @0.5] Recall: {recall_score(targets, preds):.4f}")

    preds = (np.array(probs) >= threshold).astype(int)
    print(f"[Test @{threshold:.4f}] AUC-ROC: {roc_auc_score(targets, probs):.4f}")
    print(f"[Test @{threshold:.4f}] F1 Score: {f1_score(targets, preds):.4f}")
    print(f"[Test @{threshold:.4f}] Precision: {precision_score(targets, preds):.4f}")
    print(f"[Test @{threshold:.4f}] Recall: {recall_score(targets, preds):.4f}")

# # --- Usage ---
# model_name = "/home/liorkob/M.Sc/thesis/pre-train/mlm/Legal-heBERT-mlm-3k-drugs/final"
# tokenizer = AutoTokenizer.from_pretrained("/home/liorkob/M.Sc/thesis/pre-train/mlm/Legal-heBERT-mlm-3k-drugs/final")

# model_name = "/home/liorkob/M.Sc/thesis/pre-train/hebert-mlm-3k-drugs-punishment"
# tokenizer = AutoTokenizer.from_pretrained("/home/liorkob/M.Sc/thesis/pre-train/hebert-mlm-3k-drugs-punishment")


# model_name ="avichr/Legal-heBERT"
# tokenizer = AutoTokenizer.from_pretrained("avichr/Legal-heBERT")
# # df_train, df_val = train_test_split(df, stratify=df.label, test_size=0.2, random_state=42)
# df_train = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_train.csv")
# df_val = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_val.csv")
# df_test= pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv")


# train_dataset = CrossEncoderVerdictDataset(df_train, tokenizer)
# val_dataset = CrossEncoderVerdictDataset(df_val, tokenizer)
# test_dataset = CrossEncoderVerdictDataset(df_test, tokenizer)

# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=8)
# test_loader = DataLoader(test_dataset, batch_size=8)

# model = CrossEncoderHeBERT(model_name).to(device)
# optimizer = Adam(model.parameters(), lr=2e-5)

# train_cross_encoder_with_early_stopping(model, train_loader, val_loader, optimizer, device, epochs=20, patience=5)

# print("✅ CrossEncoderHeBERT model saved.")
# # model.load_state_dict(torch.load("best_crossencoder.pt"))  

# _, _, best_thresh = evaluate_cross_encoder(model, val_loader, device)
# evaluate_with_threshold(model, test_loader, device, best_thresh)


In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from torch.optim import Adam
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from scipy.stats import ttest_rel
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Your existing classes (unchanged) ---
class CrossEncoderVerdictDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = f"[CLS] {row['gpt_facts_a']} [SEP] {row['gpt_facts_b']} [SEP]"
        enc = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')
        return {
            'input_ids': enc['input_ids'].squeeze(),
            'attention_mask': enc['attention_mask'].squeeze(),
            'label': torch.tensor(row['label'], dtype=torch.float)
        }

class CrossEncoderHeBERT(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # [CLS] token
        return self.classifier(pooled).squeeze(-1)

# --- Modified training function for k-fold ---
def train_model_fold(model, train_loader, val_loader, optimizer, device, epochs=15, patience=3, verbose=False):
    """Train model for one fold with early stopping"""
    criterion = nn.BCEWithLogitsLoss()
    best_auc = 0
    no_improve = 0
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            for key in batch:
                batch[key] = batch[key].to(device)

            logits = model(batch['input_ids'], batch['attention_mask'])
            loss = criterion(logits, batch['label'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Validation
        model.eval()
        val_probs, val_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                for key in batch:
                    batch[key] = batch[key].to(device)
                logits = model(batch['input_ids'], batch['attention_mask'])
                prob = torch.sigmoid(logits).cpu().numpy()
                label = batch['label'].cpu().numpy()
                val_probs.extend(prob)
                val_labels.extend(label)
        
        val_auc = roc_auc_score(val_labels, val_probs)
        
        if verbose:
            print(f"Epoch {epoch+1} | Loss: {total_loss/len(train_loader):.4f}, Val AUC: {val_auc:.4f}")
        
        # Early stopping
        if val_auc > best_auc:
            best_auc = val_auc
            no_improve = 0
            best_state = model.state_dict().copy()
        else:
            no_improve += 1
            if no_improve >= patience:
                if verbose:
                    print("Early stopping triggered")
                break
    
    # Load best model
    model.load_state_dict(best_state)
    return model

def evaluate_model_fold(model, test_loader, device):
    """Evaluate model on test set and return AUC"""
    model.eval()
    probs, targets = [], []
    
    with torch.no_grad():
        for batch in test_loader:
            for key in batch:
                batch[key] = batch[key].to(device)
            logits = model(batch['input_ids'], batch['attention_mask'])
            prob = torch.sigmoid(logits).cpu().numpy()
            label = batch['label'].cpu().numpy()
            probs.extend(prob)
            targets.extend(label)
    
    auc = roc_auc_score(targets, probs)
    return auc, probs, targets

def run_kfold_comparison(df_full, baseline_model_name, finetuned_model_name, k=5, epochs=15, patience=3, batch_size=8, random_state=42):
    """
    Run k-fold cross-validation comparing baseline vs fine-tuned model
    
    Args:
        df_full: Full dataset DataFrame
        baseline_model_name: Name/path of baseline model
        finetuned_model_name: Name/path of fine-tuned model
        k: Number of folds
        epochs: Max epochs per fold
        patience: Early stopping patience
        batch_size: Batch size
        random_state: Random seed
    """
    
    # Initialize tokenizers
    baseline_tokenizer = AutoTokenizer.from_pretrained(baseline_model_name)
    finetuned_tokenizer = AutoTokenizer.from_pretrained(finetuned_model_name)
    
    # Set up k-fold
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=random_state)
    
    baseline_aucs = []
    finetuned_aucs = []
    fold_results = []
    
    print(f"🚀 Starting {k}-Fold Cross-Validation")
    print(f"Baseline Model: {baseline_model_name}")
    print(f"Fine-tuned Model: {finetuned_model_name}")
    print("=" * 80)
    
    for fold, (train_idx, test_idx) in enumerate(skf.split(df_full, df_full['label'])):
        print(f"\n📁 FOLD {fold + 1}/{k}")
        print("-" * 40)
        
        # Split data
        df_train_fold = df_full.iloc[train_idx].reset_index(drop=True)
        df_test_fold = df_full.iloc[test_idx].reset_index(drop=True)
        
        # Further split training into train/val (80/20)
        train_size = int(0.8 * len(df_train_fold))
        df_train = df_train_fold[:train_size]
        df_val = df_train_fold[train_size:]
        
        print(f"Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test_fold)}")
        
        # === BASELINE MODEL ===
        print("\n🔵 Training Baseline Model...")
        
        # Create datasets and loaders for baseline
        train_dataset_base = CrossEncoderVerdictDataset(df_train, baseline_tokenizer)
        val_dataset_base = CrossEncoderVerdictDataset(df_val, baseline_tokenizer)
        test_dataset_base = CrossEncoderVerdictDataset(df_test_fold, baseline_tokenizer)
        
        train_loader_base = DataLoader(train_dataset_base, batch_size=batch_size, shuffle=True)
        val_loader_base = DataLoader(val_dataset_base, batch_size=batch_size)
        test_loader_base = DataLoader(test_dataset_base, batch_size=batch_size)
        
        # Initialize and train baseline model
        baseline_model = CrossEncoderHeBERT(baseline_model_name).to(device)
        baseline_optimizer = Adam(baseline_model.parameters(), lr=2e-5)
        
        baseline_model = train_model_fold(
            baseline_model, train_loader_base, val_loader_base, 
            baseline_optimizer, device, epochs, patience, verbose=False
        )
        
        # Evaluate baseline
        baseline_auc, _, _ = evaluate_model_fold(baseline_model, test_loader_base, device)
        baseline_aucs.append(baseline_auc)
        print(f"Baseline AUC: {baseline_auc:.4f}")
        
        # Clean up baseline model
        del baseline_model, baseline_optimizer
        torch.cuda.empty_cache()
        
        # === FINE-TUNED MODEL ===
        print("\n🟢 Training Fine-tuned Model...")
        
        # Create datasets and loaders for fine-tuned model
        train_dataset_ft = CrossEncoderVerdictDataset(df_train, finetuned_tokenizer)
        val_dataset_ft = CrossEncoderVerdictDataset(df_val, finetuned_tokenizer)
        test_dataset_ft = CrossEncoderVerdictDataset(df_test_fold, finetuned_tokenizer)
        
        train_loader_ft = DataLoader(train_dataset_ft, batch_size=batch_size, shuffle=True)
        val_loader_ft = DataLoader(val_dataset_ft, batch_size=batch_size)
        test_loader_ft = DataLoader(test_dataset_ft, batch_size=batch_size)
        
        # Initialize and train fine-tuned model
        finetuned_model = CrossEncoderHeBERT(finetuned_model_name).to(device)
        finetuned_optimizer = Adam(finetuned_model.parameters(), lr=2e-5)
        
        finetuned_model = train_model_fold(
            finetuned_model, train_loader_ft, val_loader_ft, 
            finetuned_optimizer, device, epochs, patience, verbose=False
        )
        
        # Evaluate fine-tuned
        finetuned_auc, _, _ = evaluate_model_fold(finetuned_model, test_loader_ft, device)
        finetuned_aucs.append(finetuned_auc)
        print(f"Fine-tuned AUC: {finetuned_auc:.4f}")
        
        # Calculate improvement
        improvement = finetuned_auc - baseline_auc
        print(f"Improvement: {improvement:+.4f}")
        
        # Store fold results
        fold_results.append({
            'fold': fold + 1,
            'baseline_auc': baseline_auc,
            'finetuned_auc': finetuned_auc,
            'improvement': improvement
        })
        
        # Clean up fine-tuned model
        del finetuned_model, finetuned_optimizer
        torch.cuda.empty_cache()
    
    # === STATISTICAL ANALYSIS ===
    print("\n" + "=" * 80)
    print("📊 STATISTICAL ANALYSIS")
    print("=" * 80)
    
    # Convert to numpy arrays
    baseline_aucs = np.array(baseline_aucs)
    finetuned_aucs = np.array(finetuned_aucs)
    improvements = finetuned_aucs - baseline_aucs
    
    # Summary statistics
    print(f"\n📈 Summary Statistics:")
    print(f"Baseline AUC:    {baseline_aucs.mean():.4f} ± {baseline_aucs.std():.4f}")
    print(f"Fine-tuned AUC:  {finetuned_aucs.mean():.4f} ± {finetuned_aucs.std():.4f}")
    print(f"Mean Improvement: {improvements.mean():+.4f} ± {improvements.std():.4f}")
    
    # Paired t-test
    t_stat, p_value = ttest_rel(finetuned_aucs, baseline_aucs)
    
    print(f"\n🧪 Paired T-Test Results:")
    print(f"t-statistic: {t_stat:.4f}")
    print(f"p-value: {p_value:.2e}")
    
    # Significance interpretation
    alpha = 0.05
    if p_value < alpha:
        significance = "✅ SIGNIFICANT"
        interpretation = f"The improvement is statistically significant (p < {alpha})"
    else:
        significance = "❌ NOT SIGNIFICANT"
        interpretation = f"The improvement is not statistically significant (p ≥ {alpha})"
    
    print(f"Result: {significance}")
    print(f"Interpretation: {interpretation}")
    
    # Effect size (Cohen's d)
    pooled_std = np.sqrt((baseline_aucs.var() + finetuned_aucs.var()) / 2)
    cohens_d = improvements.mean() / pooled_std
    print(f"Effect Size (Cohen's d): {cohens_d:.4f}")
    
    # Effect size interpretation
    if abs(cohens_d) < 0.2:
        effect_size_interp = "negligible"
    elif abs(cohens_d) < 0.5:
        effect_size_interp = "small"
    elif abs(cohens_d) < 0.8:
        effect_size_interp = "medium"
    else:
        effect_size_interp = "large"
    
    print(f"Effect Size Interpretation: {effect_size_interp}")
    
    # Fold-wise results table
    print(f"\n📋 Fold-wise Results:")
    print("Fold | Baseline | Fine-tuned | Improvement")
    print("-" * 45)
    for result in fold_results:
        print(f"{result['fold']:4d} | {result['baseline_auc']:8.4f} | {result['finetuned_auc']:10.4f} | {result['improvement']:+10.4f}")
    
    # Confidence interval for mean improvement
    from scipy.stats import t
    confidence_level = 0.95
    df_ci = len(improvements) - 1
    t_critical = t.ppf((1 + confidence_level) / 2, df_ci)
    margin_error = t_critical * (improvements.std() / np.sqrt(len(improvements)))
    ci_lower = improvements.mean() - margin_error
    ci_upper = improvements.mean() + margin_error
    
    print(f"\n🎯 {confidence_level*100}% Confidence Interval for Mean Improvement:")
    print(f"[{ci_lower:+.4f}, {ci_upper:+.4f}]")
    
    return {
        'baseline_aucs': baseline_aucs,
        'finetuned_aucs': finetuned_aucs,
        'improvements': improvements,
        't_statistic': t_stat,
        'p_value': p_value,
        'cohens_d': cohens_d,
        'mean_improvement': improvements.mean(),
        'std_improvement': improvements.std(),
        'confidence_interval': (ci_lower, ci_upper),
        'fold_results': fold_results
    }

# # === MAIN EXECUTION ===
# if __name__ == "__main__":
#     # Load your full dataset
#     print("📂 Loading datasets...")
#     df_train = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_train.csv")
#     df_val = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_val.csv")
#     df_test = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv")
    
#     # Combine all splits for k-fold CV
#     df_full = pd.concat([df_train, df_val, df_test], ignore_index=True)
#     print(f"Total dataset size: {len(df_full)} samples")
#     print(f"Label distribution: {df_full['label'].value_counts().to_dict()}")
    
#     # Define model names/paths
#     baseline_model_name = "avichr/Legal-heBERT"  # Your baseline
    
#     # Choose your fine-tuned model (uncomment the one you want to test)
#     finetuned_model_name = "/home/liorkob/M.Sc/thesis/pre-train/mlm/Legal-heBERT-mlm-3k-drugs/final"
#     # finetuned_model_name = "/home/liorkob/M.Sc/thesis/pre-train/hebert-mlm-3k-drugs-punishment"
    
#     # Run k-fold comparison
#     results = run_kfold_comparison(
#         df_full=df_full,
#         baseline_model_name=baseline_model_name,
#         finetuned_model_name=finetuned_model_name,
#         k=5,  # 5-fold CV
#         epochs=15,
#         patience=3,
#         batch_size=8,
#         random_state=42
#     )
    
#     print(f"\n🏁 Analysis Complete!")
#     print(f"Your fine-tuned HeBERT shows a mean improvement of {results['mean_improvement']:+.4f} AUC-ROC")
    
#     if results['p_value'] < 0.05:
#         print("🎉 The improvement is statistically significant!")
#     else:
#         print("⚠️  The improvement is not statistically significant.")

Using device: cuda


In [None]:
experiments = [
    {
        "name": "HeBERT MLM vs Baseline",
        "baseline": "avichr/heBERT",
        "finetuned": "/home/liorkob/M.Sc/thesis/pre-train/models/hebert-mlm-3k-drugs/final"
    },
    {
        "name": "mBERT MLM vs Baseline",
        "baseline": "bert-base-multilingual-cased",
        "finetuned": "/home/liorkob/M.Sc/thesis/pre-train/models/mBERT-mlm-3k-drugs/final"
    },
    {
        "name": "Legal-HeBERT MLM vs Baseline",
        "baseline": "avichr/Legal-heBERT",
        "finetuned": "/home/liorkob/M.Sc/thesis/pre-train/models/Legal-heBERT-mlm-3k-drugs/final"
    }
]

for exp in experiments:
        # Load your full dataset
    print("📂 Loading datasets...")
    df_train = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_train.csv")
    df_val = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_val.csv")
    df_test = pd.read_csv("/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv")
    
    # Combine all splits for k-fold CV
    df_full = pd.concat([df_train, df_val, df_test], ignore_index=True)
    print(f"Total dataset size: {len(df_full)} samples")
    print(f"Label distribution: {df_full['label'].value_counts().to_dict()}")

    print(f"\n\n🚨 Running Experiment: {exp['name']}")
    results = run_kfold_comparison(
        df_full=df_full,
        baseline_model_name=exp["baseline"],
        finetuned_model_name=exp["finetuned"],
        k=5,
        epochs=15,
        patience=3,
        batch_size=8,
        random_state=42
    )
    print(f"📌 {exp['name']} mean AUC improvement: {results['mean_improvement']:+.4f}")
    if results["p_value"] < 0.05:
        print("✅ Statistically significant improvement!\n")
    else:
        print("❌ Not statistically significant.\n")


📂 Loading datasets...
Total dataset size: 5791 samples
Label distribution: {0: 3857, 1: 1934}


🚨 Running Experiment: HeBERT MLM vs Baseline


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


🚀 Starting 5-Fold Cross-Validation
Baseline Model: avichr/heBERT
Fine-tuned Model: /home/liorkob/M.Sc/thesis/pre-train/models/hebert-mlm-3k-drugs/final

📁 FOLD 1/5
----------------------------------------
Train: 3705, Val: 927, Test: 1159

🔵 Training Baseline Model...


Some weights of BertModel were not initialized from the model checkpoint at avichr/heBERT and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Baseline AUC: 0.8258

🟢 Training Fine-tuned Model...


Some weights of BertModel were not initialized from the model checkpoint at /home/liorkob/M.Sc/thesis/pre-train/models/hebert-mlm-3k-drugs/final and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, classification_report
from tqdm import tqdm

# ========================
# CONFIG
# ========================
test_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv"
batch_size = 4
max_len = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========================
# Dataset
# ========================
class LegalSentencingCitationDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        prompt = """מערכת חיזוי ציטוטים משפטיים מתמחה
תחום: דין פלילי - מדיניות ענישה
מטרה: חיזוי ציטוטים בין פסקי דין על בסיס דמיון בעובדות כתב האישום
קריטריונים: ציטוט רלוונטי אם הוא תומך בהחלטת טווח העונש
שאלה: בהתבסס על עובדות כתב האישום, האם צפוי שפסק דין א' יצטט פסק דין ב'?"""

        self.inputs = []
        for idx, row in df.iterrows():
            text = f"""{prompt}

עובדות כתב אישום - פסק דין א':
{row['gpt_facts_a']}

עובדות כתב אישום - פסק דין ב':
{row['gpt_facts_b']}

על בסיס דמיון העבירות והנסיבות, האם פסק דין א' יצטט פסק דין ב'?"""
            self.inputs.append(text)

        self.targets = df["label"].apply(lambda l: "כן" if l == 1 else "לא").tolist()
        self.labels = df["label"].values
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_enc = self.tokenizer(
            self.inputs[idx], padding='max_length', truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )
        target_enc = self.tokenizer(
            self.targets[idx], padding='max_length', truncation=True,
            max_length=5, return_tensors="pt"
        )
        labels = target_enc["input_ids"].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": input_enc["input_ids"].squeeze(0),
            "attention_mask": input_enc["attention_mask"].squeeze(0),
            "labels": labels,
            "numeric_label": self.labels[idx]
        }

# ========================
# Inference Function
# ========================
def classify_with_threshold_search(model, tokenizer, input_ids, attention_mask, threshold=0.0):
    with torch.no_grad():
        decoder_input_ids = torch.full((input_ids.shape[0], 1),
                                       tokenizer.pad_token_id, dtype=torch.long, device=input_ids.device)
        logits = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       decoder_input_ids=decoder_input_ids).logits[:, -1, :]

        yes_tokens = [259, 1903]
        no_tokens = [1124]

        predictions, scores = [], []
        for batch_logits in logits:
            yes_score = torch.mean(batch_logits[yes_tokens]).item()
            no_score = torch.mean(batch_logits[no_tokens]).item()
            score_diff = yes_score - no_score
            predictions.append(1 if score_diff > threshold else 0)
            scores.append(score_diff)
        return predictions, scores

def find_best_threshold(model, tokenizer, dataloader, labels):
    all_diffs = []
    for batch in tqdm(dataloader, desc="Scoring"):
        batch = {k: v.to(device) for k, v in batch.items()}
        _, diffs = classify_with_threshold_search(model, tokenizer,
                                                  batch["input_ids"],
                                                  batch["attention_mask"])
        all_diffs.extend(diffs)

    best_f1, best_th = 0, 0
    for th in np.linspace(min(all_diffs), max(all_diffs), 50):
        preds = []
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            batch_preds, _ = classify_with_threshold_search(model, tokenizer,
                                                            batch["input_ids"],
                                                            batch["attention_mask"],
                                                            threshold=th)
            preds.extend(batch_preds)
        f1 = f1_score(labels, preds, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_th = th

    print(f"✅ Best Threshold: {best_th:.4f} (F1: {best_f1:.4f})")
    return best_th

def evaluate(model_path):
    print(f"\n📦 Loading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    df_test = pd.read_csv(test_file)
    test_dataset = LegalSentencingCitationDataset(df_test, tokenizer, max_len=max_len)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    true_labels = test_dataset.labels
    best_th = find_best_threshold(model, tokenizer, test_loader, true_labels)

    all_preds = []
    for batch in tqdm(test_loader, desc="Final Eval"):
        batch = {k: v.to(device) for k, v in batch.items()}
        preds, _ = classify_with_threshold_search(model, tokenizer,
                                                  batch["input_ids"],
                                                  batch["attention_mask"],
                                                  threshold=best_th)
        all_preds.extend(preds)

    print("\n📊 Test Results:")
    print("F1:", f1_score(true_labels, all_preds))
    print("Precision:", precision_score(true_labels, all_preds, zero_division=0))
    print("Recall:", recall_score(true_labels, all_preds, zero_division=0))
    print("Accuracy:", np.mean(np.array(all_preds) == true_labels))
    if len(set(all_preds)) > 1:
        print("AUC-ROC:", roc_auc_score(true_labels, all_preds))
    print(classification_report(true_labels, all_preds))

# ========================
# Run on baseline models
# ========================
if __name__ == "__main__":
    baseline_models = [
        "/home/liorkob/M.Sc/thesis/t5/het5-mlm-final",
        "imvladikon/het5-base"
    ]
    for path in baseline_models:
        evaluate(path)


  from .autonotebook import tqdm as notebook_tqdm
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers



📦 Loading model: /home/liorkob/M.Sc/thesis/t5/het5-mlm-final


Scoring: 100%|██████████| 218/218 [00:15<00:00, 14.07it/s]


✅ Best Threshold: 4.7972 (F1: 0.5074)


Final Eval: 100%|██████████| 218/218 [00:16<00:00, 13.50it/s]



📊 Test Results:
F1: 0.5074365704286964
Precision: 0.3403755868544601
Recall: 0.9965635738831615
Accuracy: 0.3528735632183908
AUC-ROC: 0.5129622705339815
              precision    recall  f1-score   support

           0       0.94      0.03      0.06       579
           1       0.34      1.00      0.51       291

    accuracy                           0.35       870
   macro avg       0.64      0.51      0.28       870
weighted avg       0.74      0.35      0.21       870


📦 Loading model: imvladikon/het5-base


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Scoring: 100%|██████████| 218/218 [00:15<00:00, 14.15it/s]


✅ Best Threshold: 3.5125 (F1: 0.5030)


Final Eval: 100%|██████████| 218/218 [00:15<00:00, 13.91it/s]


📊 Test Results:
F1: 0.5030461270670148
Precision: 0.3368298368298368
Recall: 0.993127147766323
Accuracy: 0.34367816091954023
AUC-ROC: 0.5051991524669267
              precision    recall  f1-score   support

           0       0.83      0.02      0.03       579
           1       0.34      0.99      0.50       291

    accuracy                           0.34       870
   macro avg       0.59      0.51      0.27       870
weighted avg       0.67      0.34      0.19       870






In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, classification_report
from tqdm import tqdm
import pandas as pd
import numpy as np

# -----------------------
# Config
# -----------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

test_file = "/home/liorkob/M.Sc/thesis/citation-prediction/data_splits/crossencoder_test.csv"
batch_size = 8
max_len = 512

models_to_evaluate = [
    ("avichr/heBERT", "HeBERT (baseline)"),
    ("/home/liorkob/M.Sc/thesis/pre-train/models/hebert-mlm-3k-drugs/final", "HeBERT-MLM-3K-Drugs (fine-tuned)"),
    ("bert-base-multilingual-cased", "mBERT (baseline)"),
    ("/home/liorkob/M.Sc/thesis/pre-train/models/mBERT-mlm-3k-drugs/final", "mBERT-MLM-3K-Drugs (fine-tuned)"),
    ("avichr/Legal-heBERT", "Legal-HeBERT (baseline)"),
    ("/home/liorkob/M.Sc/thesis/pre-train/models/Legal-heBERT-mlm-3k-drugs/final", "Legal-HeBERT-MLM-3K-Drugs (fine-tuned)")
]

# -----------------------
# Dataset
# -----------------------
class CrossEncoderVerdictDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        text = f"[CLS] {row['gpt_facts_a']} [SEP] {row['gpt_facts_b']} [SEP]"
        enc = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')
        return {
            'input_ids': enc['input_ids'].squeeze(),
            'attention_mask': enc['attention_mask'].squeeze(),
            'label': torch.tensor(row['label'], dtype=torch.float)
        }

# -----------------------
# Model
# -----------------------
class CrossEncoderModel(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        hidden = encoder.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]
        return self.classifier(pooled).squeeze(-1)
# -----------------------
# Evaluation
# -----------------------
def evaluate_model(model, dataloader):
    model.eval()
    probs, labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            label = batch["label"].cpu().numpy()

            logits = model(input_ids, attention_mask)
            prob = torch.sigmoid(logits).cpu().numpy()

            probs.extend(prob)
            labels.extend(label)

    probs = np.array(probs)
    preds = (probs >= 0.5).astype(int)
    labels = np.array(labels)

    print("F1:", f1_score(labels, preds))
    print("Precision:", precision_score(labels, preds, zero_division=0))
    print("Recall:", recall_score(labels, preds, zero_division=0))
    print("Accuracy:", np.mean(preds == labels))
    if len(set(preds)) > 1 and len(set(labels)) > 1:
        print("AUC-ROC:", roc_auc_score(labels, probs))
    print("Classification Report:")
    print(classification_report(labels, preds))

# -----------------------
# Main
# -----------------------
if __name__ == "__main__":
    df_test = pd.read_csv(test_file)

    for model_path, label in models_to_evaluate:
        print(f"\n\n====================")
        print(f"🔍 Evaluating: {label}")
        print(f"====================")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        test_dataset = CrossEncoderVerdictDataset(df_test, tokenizer, max_len=max_len)
        test_loader = DataLoader(test_dataset, batch_size=batch_size)
        from transformers import AutoModel, AutoTokenizer
        encoder = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        model = CrossEncoderModel(encoder).to(device)
        evaluate_model(model, test_loader)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


🔍 Evaluating: HeBERT (baseline)


Some weights of BertModel were not initialized from the model checkpoint at avichr/heBERT and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                             

F1: 0.49111807732497387
Precision: 0.35285285285285284
Recall: 0.8075601374570447
Accuracy: 0.44022988505747124
AUC-ROC: 0.548706443744102
Classification Report:
              precision    recall  f1-score   support

         0.0       0.73      0.26      0.38       579
         1.0       0.35      0.81      0.49       291

    accuracy                           0.44       870
   macro avg       0.54      0.53      0.43       870
weighted avg       0.60      0.44      0.42       870



🔍 Evaluating: HeBERT-MLM-3K-Drugs (fine-tuned)


Some weights of BertModel were not initialized from the model checkpoint at /home/liorkob/M.Sc/thesis/pre-train/models/hebert-mlm-3k-drugs/final and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                             

F1: 0.300187617260788
Precision: 0.3305785123966942
Recall: 0.27491408934707906
Accuracy: 0.5712643678160919
AUC-ROC: 0.48208191632688185
Classification Report:
              precision    recall  f1-score   support

         0.0       0.66      0.72      0.69       579
         1.0       0.33      0.27      0.30       291

    accuracy                           0.57       870
   macro avg       0.50      0.50      0.50       870
weighted avg       0.55      0.57      0.56       870



🔍 Evaluating: mBERT (baseline)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


F1: 0.5012919896640826
Precision: 0.33448275862068966
Recall: 1.0
Accuracy: 0.33448275862068966
Classification Report:
              precision    recall  f1-score   support

         0.0       0.00      0.00      0.00       579
         1.0       0.33      1.00      0.50       291

    accuracy                           0.33       870
   macro avg       0.17      0.50      0.25       870
weighted avg       0.11      0.33      0.17       870



🔍 Evaluating: mBERT-MLM-3K-Drugs (fine-tuned)


Some weights of BertModel were not initialized from the model checkpoint at /home/liorkob/M.Sc/thesis/pre-train/models/mBERT-mlm-3k-drugs/final and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


F1: 0.0
Precision: 0.0
Recall: 0.0
Accuracy: 0.6655172413793103
Classification Report:
              precision    recall  f1-score   support

         0.0       0.67      1.00      0.80       579
         1.0       0.00      0.00      0.00       291

    accuracy                           0.67       870
   macro avg       0.33      0.50      0.40       870
weighted avg       0.44      0.67      0.53       870



🔍 Evaluating: Legal-HeBERT (baseline)


Some weights of BertModel were not initialized from the model checkpoint at avichr/Legal-heBERT and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                             

F1: 0.38605898123324395
Precision: 0.31648351648351647
Recall: 0.4948453608247423
Accuracy: 0.4735632183908046
AUC-ROC: 0.49145938310512854
Classification Report:
              precision    recall  f1-score   support

         0.0       0.65      0.46      0.54       579
         1.0       0.32      0.49      0.39       291

    accuracy                           0.47       870
   macro avg       0.48      0.48      0.46       870
weighted avg       0.54      0.47      0.49       870



🔍 Evaluating: Legal-HeBERT-MLM-3K-Drugs (fine-tuned)


Some weights of BertModel were not initialized from the model checkpoint at /home/liorkob/M.Sc/thesis/pre-train/models/Legal-heBERT-mlm-3k-drugs/final and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                             

F1: 0.3780663780663781
Precision: 0.32587064676616917
Recall: 0.45017182130584193
Accuracy: 0.5045977011494253
AUC-ROC: 0.47610526503213857
Classification Report:
              precision    recall  f1-score   support

         0.0       0.66      0.53      0.59       579
         1.0       0.33      0.45      0.38       291

    accuracy                           0.50       870
   macro avg       0.49      0.49      0.48       870
weighted avg       0.55      0.50      0.52       870



