# In [None]:
# pip install -U transformers accelerate bitsandbytes torch  # install/upgrade first
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging

logging.set_verbosity_error()  # quieter

# Read the Hub token from the environment rather than hard-coding it. A missing
# or empty value becomes None so huggingface_hub uses its default token
# resolution (e.g. a cached login) instead of sending an empty bearer token.
HF_TOKEN = os.environ.get("HF_TOKEN") or None
MODEL_ID = "khalidrajan/Llama-3.1-8B-Instruct-Legal-NLI"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True)
# Some causal-LM tokenizers define no pad token; reuse EOS so any
# `padding=True` tokenization call does not raise.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def load_model():
    """Load ``MODEL_ID`` using progressively more conservative strategies.

    Attempts, in order:

    1. bf16 (CUDA) / fp32 (CPU) weights dispatched by accelerate via
       ``device_map="auto"``, allowing offload to ``/tmp/model_offload``.
    2. 4-bit NF4 quantization through bitsandbytes (``BitsAndBytesConfig``).
    3. fp32 weights pinned entirely to the CPU.

    Each failure is printed, and memory held by the failed attempt is released
    before the next attempt starts.

    Returns:
        The loaded ``AutoModelForCausalLM`` model.

    Raises:
        RuntimeError: if every attempt fails (chained to the last error).
    """
    import gc

    use_cuda = torch.cuda.is_available()
    compute_dtype = torch.bfloat16 if use_cuda else torch.float32

    def _release_memory():
        # A failed (e.g. out-of-memory) attempt can leave tensors cached on the
        # GPU; free them so the next, lighter attempt has room to load.
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

    # 1) Regular load; accelerate decides where each layer goes.
    try:
        return AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            device_map="auto",                    # let accelerate dispatch
            torch_dtype=compute_dtype,
            low_cpu_mem_usage=True,               # avoid materialising an extra full copy in CPU RAM
            offload_folder="/tmp/model_offload",  # where layers that fit nowhere else are spilled
            offload_state_dict=True,
            use_safetensors=True,
        )
    except Exception as e:
        print("Primary load failed:", repr(e))
    _release_memory()

    # 2) 4-bit NF4 quantized load (requires bitsandbytes). Quantization options
    #    go through BitsAndBytesConfig; passing load_in_4bit / bnb_4bit_*
    #    straight to from_pretrained is deprecated in recent transformers.
    try:
        print("Attempting 4-bit (bitsandbytes) load as fallback...")
        from transformers import BitsAndBytesConfig

        return AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
            ),
            low_cpu_mem_usage=True,
        )
    except Exception as e2:
        print("4-bit fallback also failed:", repr(e2))
    _release_memory()

    # 3) Last resort: everything on CPU in fp32 (slow, and may still OOM).
    try:
        print("Final fallback: loading onto CPU (may be very slow).")
        return AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            torch_dtype=torch.float32,
            device_map={"": "cpu"},
            low_cpu_mem_usage=True,
        )
    except Exception as e3:
        print("CPU fallback failed:", repr(e3))
        raise RuntimeError("All model load attempts failed. See traces above.") from e3

# Usage
if __name__ == "__main__":
    # upgrade accelerate & transformers if you run into weird device_map behaviour:
    # pip install -U accelerate transformers
    model = load_model()
    # Report the distinct devices holding weights. Prefer accelerate's
    # `hf_device_map` (it also shows cpu/disk offload); otherwise inspect the
    # parameters directly.
    _hf_map = getattr(model, "hf_device_map", None)
    if _hf_map:
        loaded_devices = sorted({str(dev) for dev in _hf_map.values()})
    else:
        loaded_devices = sorted({str(p.device) for p in model.parameters()})
    print("Model loaded on devices:", loaded_devices)


# In [None]:
import torch
import torch.nn.functional as F
from typing import List, Tuple

# Everything below relies on `model` and `tokenizer` from the loading cell.
# Inputs are sent to the device that holds the model's first parameter.
device = next(iter(model.parameters())).device

LABELS = ["Entailed", "Contradicted", "Neutral"]

# Token ids of each label as it would follow "Label:" in the prompt; the leading
# space is encoded too so the ids match the model's next-token continuation.
# A label may span several tokens — the scorer sums over all of them.
label_token_ids = [tokenizer.encode(f" {name}", add_special_tokens=False) for name in LABELS]
print("Label token lengths:", dict(zip(LABELS, map(len, label_token_ids))))

def score_label_sequence(prompt_input_ids: torch.Tensor, label_ids: List[int]) -> float:
    """Return log P(label_ids | prompt) under the global ``model``.

    The prompt and label tokens are concatenated and run through the model in
    one forward pass. The result is the sum of the log-probabilities assigned to
    each label token given everything before it:
    ``log P(label_0, label_1, ... | prompt)``.

    Because this is a plain sum, labels that tokenize into more pieces collect
    more (negative) terms, so comparing labels with different token lengths
    carries a length bias.

    Args:
        prompt_input_ids: prompt token ids, shape ``(1, L_prompt)``. The label
            tensor is built on the same device and dtype.
        label_ids: token ids of the label continuation.

    Returns:
        The summed log-probability as a Python float.

    Raises:
        ValueError: if the prompt is empty (no position predicts the first label
            token) or the label is empty (nothing to score).
    """
    prompt_len = prompt_input_ids.shape[1]
    label_len = len(label_ids)
    if prompt_len == 0:
        raise ValueError("prompt_input_ids must contain at least one token")
    if label_len == 0:
        raise ValueError("label_ids must contain at least one token")

    label_tensor = torch.tensor(
        label_ids, dtype=prompt_input_ids.dtype, device=prompt_input_ids.device
    ).unsqueeze(0)
    concat = torch.cat([prompt_input_ids, label_tensor], dim=1)

    with torch.no_grad():
        logits = model(concat).logits  # (1, prompt_len + label_len, vocab_size)

    # In a causal LM the logits at position t predict token t + 1, so label
    # token i is predicted at position prompt_len - 1 + i.
    # Upcast before log_softmax: bf16 keeps only ~3 significant digits, which
    # can blur or tie close label scores.
    step_logits = logits[0, prompt_len - 1 : prompt_len - 1 + label_len].float()
    log_probs = F.log_softmax(step_logits, dim=-1)
    # Logits may live on a different device than the inputs under multi-device dispatch.
    targets = label_tensor[0].to(log_probs.device).unsqueeze(-1)
    return log_probs.gather(-1, targets).sum().item()

def classify_one(premise: str, hypothesis: str, labels: List[str] = LABELS) -> str:
    """Return the label from ``labels`` that the model scores highest for one pair.

    The pair is formatted as a plain-text prompt ending in ``"Label:"``. Each
    candidate label (with a leading space) is scored with
    ``score_label_sequence``. The label with the highest summed log-probability
    wins; ties go to the earliest label.

    Args:
        premise: the legal premise text.
        hypothesis: the hypothesis to check against the premise.
        labels: candidate labels. Their token ids are derived from this list, so
            custom label sets are scored correctly, and the prompt lists them as
            the allowed answers.

    Returns:
        The winning label string.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    labels = list(labels)
    if not labels:
        raise ValueError("labels must contain at least one label")

    # For the default labels this yields "Entailed, Contradicted, or Neutral".
    if len(labels) <= 2:
        choices = " or ".join(labels)
    else:
        choices = ", ".join(labels[:-1]) + ", or " + labels[-1]
    system = ("You are a legal NLI assistant. Given a legal PREMISE and a HYPOTHESIS, "
              f"respond with exactly one word: {choices}.")
    user = f"PREMISE: {premise}\nHYPOTHESIS: {hypothesis}\nLabel:"
    prompt = system + "\n\n" + user

    # Keep special tokens for the prompt. The target is a Llama-3.1 fine-tune,
    # and Llama models are trained with a BOS token at sequence start, which
    # add_special_tokens=False would omit. Label tokens below must NOT get
    # special tokens, because they are appended mid-sequence.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    input_ids = inputs["input_ids"].to(device)

    # Tokenize the labels actually passed in. A precomputed table built from
    # the default label list would be misaligned with a custom `labels`.
    label_token_lists = [tokenizer.encode(" " + lbl, add_special_tokens=False) for lbl in labels]

    scores = [score_label_sequence(input_ids, toks) for toks in label_token_lists]
    best_index = max(range(len(labels)), key=scores.__getitem__)
    return labels[best_index]

# ---- Batched classification (multiple pairs) ----
def classify_batch(pairs: List[Tuple[str, str]], labels: List[str] = LABELS, batch_size: int = 4) -> List[str]:
    """Classify a list of (premise, hypothesis) pairs.

    Each pair is scored with ``classify_one``. Prompts differ in length and
    every pair needs its own label-scoring forward passes, so padding the
    prompts into a tensor saves no model work. Padding would only add failure
    modes: the tokenizer may have no pad token, and left padding would break
    length-based slicing.

    Args:
        pairs: (premise, hypothesis) tuples.
        labels: candidate labels, forwarded to ``classify_one``.
        batch_size: kept for backward compatibility; it does not change how the
            model is called. Must be >= 1.

    Returns:
        Predicted labels, in the same order as ``pairs``.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [classify_one(premise, hypothesis, labels) for premise, hypothesis in pairs]

# ---- tiny demo / eval ----
if __name__ == "__main__":
    demo_pairs = [
        (
            "The tenant must provide 30 days' written notice before terminating the lease.",
            "A tenant can end the lease immediately without notifying the landlord.",
        ),
        (
            "A party who signs a contract is bound by its terms unless fraud is proven.",
            "If someone signs a contract, they are never bound by it.",
        ),
        (
            "All employees are entitled to one day off weekly according to the policy.",
            "Employees get at least one day off every week.",
        ),
    ]

    # Score each pair on its own, echoing the inputs next to the prediction.
    for premise_text, hypothesis_text in demo_pairs:
        print("P:", premise_text)
        print("H:", hypothesis_text)
        print("->", classify_one(premise_text, hypothesis_text))
        print()

    # The same pairs through the list-based entry point.
    print("Batched results:", classify_batch(demo_pairs))


# In [None]:
# ================================
# NLI Evaluation on 1000 Random Pairs
# ================================

import pandas as pd
import random
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# -------------------------
# CONFIG
# -------------------------
DATA_DIR = "/kaggle/input/lrec-dataset"                  # dataset root
POS_FILE = os.path.join(DATA_DIR, "sentencePair.txt")
NEG_FILE = os.path.join(DATA_DIR, "sentencePair_neg.txt")
SAMPLE_SIZE = 1_000   # number of pairs drawn for evaluation
SEED = 42             # random_state for the sample, for reproducibility
BATCH_SIZE = 2        # forwarded to classify_batch

# -------------------------
# LABEL MAPPING
# -------------------------
# Dataset relation tags -> NLI labels understood by the classifier.
LABEL_MAP = dict(
    SUPPORT="Entailed",
    ATTACK="Contradicted",
    REFUTE="Contradicted",
    NO_REL="Neutral",
    NEUTRAL="Neutral",
)

# -------------------------
# LOAD + CLEAN DATA
# -------------------------
def load_sentencepair_file(path):
    """Read one tab-separated sentence-pair file into a three-column DataFrame.

    Only columns 3, 6 and 8 (0-based) are parsed. They become ``premise``,
    ``hypothesis`` and ``raw_label``. Rows where any of those fields is empty
    are dropped.

    Args:
        path: path to a headerless TSV file.

    Returns:
        DataFrame with string columns ``premise``, ``hypothesis``, ``raw_label``
        (original row index preserved).
    """
    import csv

    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        usecols=[3, 6, 8],
        quoting=csv.QUOTE_NONE,  # quote characters are literal text, not delimiters
        dtype=str,
        # Only empty fields count as missing. The default NA list would turn
        # sentences such as "NA" or "None" into NaN and silently drop them.
        keep_default_na=False,
        na_values=[""],
    )
    df = df.rename(columns={
        3: "premise",
        6: "hypothesis",
        8: "raw_label"
    })
    return df[["premise", "hypothesis", "raw_label"]].dropna()

df_pos = load_sentencepair_file(POS_FILE)
df_neg = load_sentencepair_file(NEG_FILE)

df = pd.concat([df_pos, df_neg], ignore_index=True)

# Normalise raw tags before mapping so stray whitespace or lower-case variants
# are not silently discarded, and report anything that still fails to map.
df["label"] = df["raw_label"].str.strip().str.upper().map(LABEL_MAP)
unmapped = df["label"].isna()
if unmapped.any():
    print(f"Dropping {int(unmapped.sum())} rows with unmapped labels, e.g.: "
          f"{sorted(df.loc[unmapped, 'raw_label'].unique())[:10]}")
df = df[~unmapped].reset_index(drop=True)

print("Label distribution (full dataset):")
print(df["label"].value_counts(), "\n")

# -------------------------
# RANDOM SAMPLE
# -------------------------
df_sample = df.sample(
    n=min(SAMPLE_SIZE, len(df)),
    random_state=SEED
).reset_index(drop=True)

pairs = list(zip(df_sample["premise"], df_sample["hypothesis"]))
true_labels = df_sample["label"].tolist()
if not pairs:
    raise RuntimeError("No labelled sentence pairs found; check DATA_DIR and LABEL_MAP.")

print(f"Running NLI on {len(pairs)} sentence pairs...\n")

# -------------------------
# NLI INFERENCE
# -------------------------
pred_labels = classify_batch(
    pairs,
    batch_size=BATCH_SIZE
)

# -------------------------
# EVALUATION
# -------------------------
EVAL_LABELS = ["Entailed", "Contradicted", "Neutral"]

acc = accuracy_score(true_labels, pred_labels)

print("===== Accuracy =====")
print(f"{acc:.4f}\n")

print("===== Classification Report =====")
print(classification_report(
    true_labels,
    pred_labels,
    labels=EVAL_LABELS,
    zero_division=0,  # a never-predicted label scores 0 instead of emitting a warning
))

cm = confusion_matrix(
    true_labels,
    pred_labels,
    labels=EVAL_LABELS
)

cm_df = pd.DataFrame(
    cm,
    index=[f"True_{name}" for name in EVAL_LABELS],
    columns=[f"Pred_{name}" for name in EVAL_LABELS],
)

print("===== Confusion Matrix =====")
try:
    show = display  # rich rendering when running under IPython/Jupyter
except NameError:
    show = print
show(cm_df)

# -------------------------
# SAVE RESULTS
# -------------------------
df_sample["predicted_label"] = pred_labels
df_sample.to_csv("/kaggle/working/nli_eval_1000_pairs.csv", index=False)

print("\nSaved predictions to /kaggle/working/nli_eval_1000_pairs.csv")