In [1]:
# pip install -U transformers accelerate bitsandbytes torch  # install/upgrade first
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging

logging.set_verbosity_error()  # silence transformers' info/warning logs

# Never hardcode the Hugging Face token: notebooks (code *and* outputs) get shared.
# The token that used to be written here is exposed and must be revoked at
# https://huggingface.co/settings/tokens. Supply a fresh one through the HF_TOKEN
# environment variable (e.g. populated from a Kaggle/Colab secret). When unset this
# is None, and huggingface_hub falls back to any locally cached login.
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_ID = "khalidrajan/Llama-3.1-8B-Instruct-Legal-NLI"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)

def load_model():
    """Load ``MODEL_ID`` through a chain of progressively cheaper fallbacks.

    Strategies, tried in order until one succeeds:

    1. accelerate-dispatched load (``device_map="auto"``) in bf16 on GPUs with
       native bf16 support, fp16 on older CUDA GPUs, fp32 without CUDA, with
       disk offload for layers that fit nowhere else.
    2. 4-bit NF4 quantisation via bitsandbytes -- only attempted when CUDA is
       available.
    3. fp32 entirely on CPU (slow; needs roughly 4 bytes of RAM per parameter).

    Each failure is printed, cached GPU memory from the failed attempt is
    released, and the next strategy is tried.

    Returns:
        The loaded ``AutoModelForCausalLM`` instance.

    Raises:
        RuntimeError: if every strategy fails (chained to the last attempt's error).
    """
    import gc

    from transformers import BitsAndBytesConfig

    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        compute_dtype = torch.float32
    elif all(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count())):
        # Native bf16 requires compute capability >= 8.0 (Ampere or newer).
        compute_dtype = torch.bfloat16
    else:
        # Pre-Ampere GPUs (e.g. T4, P100) lack native bf16; fp16 is the fast path there.
        compute_dtype = torch.float16

    # (label, kwargs factory). Factories defer building the kwargs until the attempt
    # runs, so e.g. a missing bitsandbytes install becomes a caught failure of that
    # one strategy instead of aborting the whole chain.
    strategies = [
        ("Primary load", lambda: dict(
            device_map="auto",                    # let accelerate place layers on GPU/CPU/disk
            torch_dtype=compute_dtype,
            low_cpu_mem_usage=True,               # keep peak RAM near 1x model size while loading
            offload_folder="/tmp/model_offload",  # where accelerate spills layers that fit nowhere else
            offload_state_dict=True,              # temporarily spill the CPU state dict to disk
            use_safetensors=True,                 # require .safetensors weights (no pickle loading)
        )),
    ]
    if has_cuda:
        strategies.append(("4-bit (bitsandbytes) fallback", lambda: dict(
            device_map="auto",
            # Passing load_in_4bit / bnb_4bit_* straight to from_pretrained is deprecated;
            # a BitsAndBytesConfig in quantization_config is the supported form.
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            ),
            low_cpu_mem_usage=True,
        )))
    else:
        print("No CUDA device: skipping the 4-bit (bitsandbytes) fallback.")
    strategies.append(("CPU fallback (may be very slow)", lambda: dict(
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
    )))

    for index, (label, make_kwargs) in enumerate(strategies):
        print(f"Trying {label}...")
        try:
            return AutoModelForCausalLM.from_pretrained(MODEL_ID, token=HF_TOKEN, **make_kwargs())
        except Exception as exc:  # best-effort chain: report, then fall through to the next strategy
            print(f"{label} failed:", repr(exc))
            if index == len(strategies) - 1:
                raise RuntimeError("All model load attempts failed. See traces above.") from exc
        # `exc` is unbound once the except block exits, so the failed attempt's frames
        # (and any tensors they still reference) become collectable. Collect them and
        # hand cached GPU blocks back before the next, smaller attempt.
        gc.collect()
        if has_cuda:
            torch.cuda.empty_cache()

def describe_param_devices(model):
    """Return the sorted, de-duplicated device names holding ``model``'s parameters.

    Example result: ``["cpu", "cuda:0"]``. Parameters that accelerate offloaded
    (or that were never materialised) report the ``"meta"`` device.
    """
    return sorted({str(param.device) for param in model.parameters()})


# Usage
if __name__ == "__main__":
    # upgrade accelerate & transformers if you run into weird device_map behaviour:
    # pip install -U accelerate transformers
    model = load_model()
    # Print the distinct devices, not the (thousands of) parameter names.
    print("Model loaded on devices:", describe_param_devices(model))


Device: cuda
Base model: Equall/Saul-7B-Instruct-v1
Judge model: mistralai/Mistral-7B-Instruct-v0.2
Total pairs: 40506
Unique sentences to classify: 1000
Loading base model...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

2026-01-05 06:14:11.478490: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767593651.653041      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767593651.700938      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/4.25G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Running base model in batches...


  0%|          | 0/250 [00:00<?, ?it/s]

Base pass done. Distribution: Counter({'non-argumentative': 585, 'premise': 301, 'conclusion': 114})
Freed base model from GPU.
Loading judge model...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Running judge model in batches...


  0%|          | 0/250 [00:00<?, ?it/s]

Intermediate saved up to 4/1000
Intermediate saved up to 204/1000
Intermediate saved up to 404/1000
Intermediate saved up to 604/1000
Intermediate saved up to 804/1000
Judge pass done.
Freed judge model from GPU.
Final distributions (base): Counter({'non-argumentative': 585, 'premise': 301, 'conclusion': 114})
Final distributions (final): Counter({'premise': 714, 'conclusion': 189, 'non-argumentative': 97})
Saved: /kaggle/working/sentence_argument_labels_with_big_judge.csv


In [2]:
import pandas as pd
from collections import Counter
import numpy as np

# Quantify how much the judge model changed the base model's labels.
# Prefer the frame left in memory by the labelling run; otherwise reload the CSV that
# run saved, so this cell also survives a kernel restart.
# NOTE(review): assumes the CSV keeps the base_llm_label / final_llm_label columns -- confirm.
RESULTS_CSV = "/kaggle/working/sentence_argument_labels_with_big_judge.csv"
_labelled = globals().get("df_sentences")
df = (_labelled if _labelled is not None else pd.read_csv(RESULTS_CSV)).copy()

LABEL_COLS = ["base_llm_label", "final_llm_label"]
missing_cols = set(LABEL_COLS) - set(df.columns)
assert not missing_cols, f"Missing label columns: {sorted(missing_cols)}"
# astype(str) below would silently turn NaN into a bogus "nan" class in every metric.
n_missing = int(df[LABEL_COLS].isna().sum().sum())
assert n_missing == 0, f"{n_missing} missing labels; resolve them before computing metrics"

base = df["base_llm_label"].astype(str)
final = df["final_llm_label"].astype(str)

N = len(df)
assert N > 0, "No labelled sentences to analyse"

# -------------------------
# 1. Overall Change Metrics
# -------------------------
changed_mask = base != final
num_changed = int(changed_mask.sum())

change_rate = num_changed / N
stability_rate = 1.0 - change_rate

print("===== Overall Judge Impact =====")
print(f"Total sentences          : {N}")
print(f"Labels changed by judge  : {num_changed}")
print(f"Change rate              : {change_rate:.4f} ({change_rate*100:.2f}%)")
print(f"Stability (agreement)    : {stability_rate:.4f} ({stability_rate*100:.2f}%)")

# -------------------------
# 2. Flip Confusion Matrix
# -------------------------
labels = sorted(set(base) | set(final))
# Reindex so the matrix is square over every label seen on either side; a plain
# crosstab drops the row/column of a label that only one of the two models produced.
flip_matrix = (
    pd.crosstab(base, final)
    .reindex(index=labels, columns=labels, fill_value=0)
    .rename_axis(index="Base", columns="Judge")
)

print("\n===== Label Flip Matrix (Base → Judge) =====")
display(flip_matrix)

# -------------------------
# 3. Normalized Flip Matrix
# -------------------------
# Rows for labels the base model never emitted sum to 0 -> NaN -> shown as 0.
flip_norm = flip_matrix.div(flip_matrix.sum(axis=1), axis=0).fillna(0)

print("\n===== Normalized Flip Matrix (Row-wise %) =====")
display((flip_norm * 100).round(2))

# -------------------------
# 4. Per-Class Change Rates
# -------------------------
# Row total = sentences the base model gave `lbl`; diagonal = those the judge kept.
print("\n===== Per-Class Change Rate =====")
for lbl in labels:
    total = int(flip_matrix.loc[lbl].sum())
    if total == 0:
        continue
    changed = total - int(flip_matrix.loc[lbl, lbl])
    rate = changed / total
    print(f"{lbl:18s}: {changed:5d}/{total:5d}  ({rate*100:6.2f}%)")

# -------------------------
# 5. Net Gain / Loss per Class
# -------------------------
base_counts = Counter(base)
final_counts = Counter(final)

print("\n===== Net Label Redistribution =====")
for lbl in labels:
    delta = final_counts[lbl] - base_counts[lbl]
    sign = "+" if delta >= 0 else ""
    print(f"{lbl:18s}: {base_counts[lbl]:6d} → {final_counts[lbl]:6d}  ({sign}{delta})")

# -------------------------
# 6. Judge Correction Focus
# -------------------------
corrections = df.loc[changed_mask, LABEL_COLS]
top_corrections = (
    corrections
    .value_counts()
    .rename("count")
    .reset_index()
)

print("\n===== Most Common Corrections (Top 10) =====")
display(top_corrections.head(10))

# -------------------------
# 7. Paper-Ready Summary Numbers
# -------------------------
summary = {
    "total_sentences": N,
    "changed_labels": num_changed,
    "change_rate_pct": round(change_rate * 100, 2),
    "agreement_pct": round(stability_rate * 100, 2),
}

summary_df = pd.DataFrame([summary])
print("\n===== Paper-Ready Summary =====")
display(summary_df)


===== Overall Judge Impact =====
Total sentences          : 1000
Labels changed by judge  : 610
Change rate              : 0.6100 (61.00%)
Stability (agreement)    : 0.3900 (39.00%)

===== Label Flip Matrix (Base → Judge) =====


Judge,conclusion,non-argumentative,premise
Base,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
conclusion,72,2,40
non-argumentative,49,90,446
premise,68,5,228



===== Normalized Flip Matrix (Row-wise %) =====


Judge,conclusion,non-argumentative,premise
Base,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
conclusion,63.16,1.75,35.09
non-argumentative,8.38,15.38,76.24
premise,22.59,1.66,75.75



===== Per-Class Change Rate =====
conclusion        :    42/  114  ( 36.84%)
non-argumentative :   495/  585  ( 84.62%)
premise           :    73/  301  ( 24.25%)

===== Net Label Redistribution =====
conclusion        :    114 →    189  (+75)
non-argumentative :    585 →     97  (-488)
premise           :    301 →    714  (+413)

===== Most Common Corrections (Top 10) =====


Unnamed: 0,base_llm_label,final_llm_label,count
0,non-argumentative,premise,446
1,premise,conclusion,68
2,non-argumentative,conclusion,49
3,conclusion,premise,40
4,premise,non-argumentative,5
5,conclusion,non-argumentative,2



===== Paper-Ready Summary =====


Unnamed: 0,total_sentences,changed_labels,change_rate_pct,agreement_pct
0,1000,610,61.0,39.0
