In [1]:
# two_model_pipeline_lowmem.py
import os
import json
import time
import re
from collections import Counter
from typing import List
import gc

import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm

# -------------------------
# CONFIG
# -------------------------
DATA_DIR = "/kaggle/input/lrec-dataset"
FILE_POS = os.path.join(DATA_DIR, "sentencePair.txt")
FILE_NEG = os.path.join(DATA_DIR, "sentencePair_neg.txt")

BASE_MODEL_NAME = "Equall/Saul-7B-Instruct-v1"
JUDGE_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

MAX_SENTENCES = 1000  # or None
BATCH_SIZE_BASE = 4    # small to avoid OOM
BATCH_SIZE_JUDGE = 4   # small to avoid OOM
MAX_NEW_TOKENS_BASE = 64
MAX_NEW_TOKENS_JUDGE = 64
MAX_PROMPT_TOKENS = 1024  # truncate long prompts
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

OUTPUT_CSV = "/kaggle/working/sentence_argument_labels_with_big_judge.csv"
CHECKPOINT_INTERVAL = 200
SEED = 42

CLASS_LABELS = ["premise", "conclusion", "non-argumentative"]

torch.manual_seed(SEED)

print("Device:", DEVICE)
print("Base model:", BASE_MODEL_NAME)
print("Judge model:", JUDGE_MODEL_NAME)

# -------------------------
# LOAD DATA
# -------------------------
col_names = [
    "pair_id",
    "doc1",
    "line1",
    "sent1",
    "doc2",
    "line2",
    "sent2",
    "rel_code",
    "rel_label",
]

df_pos = pd.read_csv(FILE_POS, sep="\t", header=None, names=col_names, quoting=3, encoding="utf-8")
df_neg = pd.read_csv(FILE_NEG, sep="\t", header=None, names=col_names, quoting=3, encoding="utf-8")

df_pairs = pd.concat([df_pos, df_neg], ignore_index=True)
print("Total pairs:", len(df_pairs))

sent1_df = df_pairs[["doc1", "line1", "sent1"]].rename(columns={"doc1": "doc", "line1": "line", "sent1": "text"})
sent2_df = df_pairs[["doc2", "line2", "sent2"]].rename(columns={"doc2": "doc", "line2": "line", "sent2": "text"})

df_sentences = (
    pd.concat([sent1_df, sent2_df], ignore_index=True)
    .drop_duplicates(subset=["doc", "line", "text"])
    .reset_index(drop=True)
)
df_sentences.insert(0, "sent_id", range(1, len(df_sentences) + 1))

if MAX_SENTENCES is not None:
    df_sentences = df_sentences.head(MAX_SENTENCES).reset_index(drop=True)

print("Unique sentences to classify:", len(df_sentences))

# -------------------------
# MODEL LOADER
# -------------------------
def load_model_and_tokenizer(name: str):
    tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=dtype,
        device_map="auto" if DEVICE == "cuda" else None,
    )
    model.eval()
    return tokenizer, model

# -------------------------
# PROMPT TEMPLATES (placeholders)
# -------------------------
BASE_PROMPT_TEMPLATE = """
You are an expert in legal argument mining.

Task:
Given a SINGLE sentence from a court case, classify it into exactly ONE of these categories:

1. "premise"            - a reason, supporting fact, or evidence offered in support of some conclusion.
2. "conclusion"         - a main claim, decision, ruling, or statement that is being supported by reasons.
3. "non-argumentative"  - purely narrative, descriptive, procedural, factual background, citations, headings, etc.

Important rules:
- Focus ONLY on the content of the sentence itself.
- Ignore any dataset labels, case numbers, or line numbers.
- Output ONLY a single JSON object with this exact schema:

  {"label": "premise"}            OR
  {"label": "conclusion"}         OR
  {"label": "non-argumentative"}

Do NOT add explanations, comments, or extra text.

Examples:

Sentence:
"The defendant was seen leaving the scene of the crime carrying the victim's briefcase."
JSON:
{"label": "premise"}

Sentence:
"Therefore, the defendant is liable to be convicted under Section 420 of the Indian Penal Code."
JSON:
{"label": "conclusion"}

Sentence:
"The present appeal is directed against the judgment dated 12.03.2019 passed by the High Court of Delhi."
JSON:
{"label": "non-argumentative"}

Now classify the following sentence.

Sentence:
\"\"\"@@SENT@@\"\"\"

JSON:
""".strip()

JUDGE_PROMPT_TEMPLATE = """
You are a highly capable legal argument mining JUDGE model.

You are given:
1. A single sentence from a legal judgment.
2. A candidate label predicted by a smaller model.

Your job:
- Analyze the sentence.
- Decide whether the candidate label is correct.
- If correct, KEEP it.
- If wrong, CHANGE it to the best label.

Possible labels (exactly one):
- "premise"
- "conclusion"
- "non-argumentative"

Output ONLY a single JSON object (no explanation), one of:

  {"label": "premise"}
  {"label": "conclusion"}
  {"label": "non-argumentative"}

Sentence:
\"\"\"@@SENT@@\"\"\"

Candidate label:
"@@CAND@@"

JSON:
""".strip()

# -------------------------
# LABEL EXTRACTION
# -------------------------
_json_re = re.compile(r"\{[^{}]*\blabel\b\s*:\s*\"?([a-zA-Z0-9\-\_ ]+)\"?[^{}]*\}", re.IGNORECASE)

def normalize_label(raw: str) -> str:
    if raw is None:
        return "non-argumentative"
    l = str(raw).strip().lower()
    if l in {"no relation", "no_relation", "no-relation", "none", "background", ""}:
        return "non-argumentative"
    if l not in CLASS_LABELS:
        if "premise" in l:
            return "premise"
        if "conclusion" in l:
            return "conclusion"
        if "non" in l or "none" in l or "background" in l:
            return "non-argumentative"
        return "non-argumentative"
    return l

def extract_label_from_text(text: str) -> str:
    if not isinstance(text, str):
        return "non-argumentative"
    text = text.strip()

    # 1) regex JSON-like
    m = _json_re.search(text)
    if m:
        try:
            label = normalize_label(m.group(1))
            return label
        except Exception:
            pass

    # 2) try parse some JSON inside
    try:
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end != -1 and end > start:
            maybe = text[start:end+1]
            data = json.loads(maybe.replace("'", '"'))
            label = normalize_label(data.get("label", ""))
            return label
    except Exception:
        pass

    # 3) heuristics
    lower = text.lower()
    if "premise" in lower:
        return "premise"
    if "conclusion" in lower:
        return "conclusion"
    if "non-argumentative" in lower or "no relation" in lower or "none" in lower:
        return "non-argumentative"
    return "non-argumentative"

# -------------------------
# BATCHED GENERATION
# -------------------------
def generate_batch(
    prompts: List[str],
    tokenizer,
    model,
    max_new_tokens: int = 64,
    gen_batch_size: int = 4,
    max_prompt_tokens: int = 1024,
):
    results = []
    model_device = next(model.parameters()).device

    for i in range(0, len(prompts), gen_batch_size):
        batch_prompts = prompts[i : i + gen_batch_size]
        enc = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_prompt_tokens,
        ).to(model_device)

        with torch.inference_mode():
            gen_ids = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,       # greedy
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        for orig, full in zip(batch_prompts, decoded):
            if full.startswith(orig):
                comp = full[len(orig):].strip()
            else:
                comp = full.strip()
            results.append(comp)

    return results

# -------------------------
# PIPELINE
# -------------------------
sentences = df_sentences["text"].astype(str).tolist()
n = len(sentences)

base_labels = [None] * n
final_labels = [None] * n

# -------------------------
# STEP 1: BASE MODEL (only base in GPU)
# -------------------------
print("Loading base model...")
tokenizer_base, model_base = load_model_and_tokenizer(BASE_MODEL_NAME)

base_prompts = [BASE_PROMPT_TEMPLATE.replace("@@SENT@@", s) for s in sentences]

print("Running base model in batches...")
for start in tqdm(range(0, n, BATCH_SIZE_BASE)):
    end = min(n, start + BATCH_SIZE_BASE)
    batch_prompts = base_prompts[start:end]

    outputs = generate_batch(
        batch_prompts,
        tokenizer_base,
        model_base,
        max_new_tokens=MAX_NEW_TOKENS_BASE,
        gen_batch_size=BATCH_SIZE_BASE,
        max_prompt_tokens=MAX_PROMPT_TOKENS,
    )

    for idx_in_batch, out in enumerate(outputs):
        idx = start + idx_in_batch
        lbl = extract_label_from_text(out)
        base_labels[idx] = lbl

print("Base pass done. Distribution:", Counter(base_labels))

# Free base model from GPU
del model_base
del tokenizer_base
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()
print("Freed base model from GPU.")

# -------------------------
# STEP 2: JUDGE MODEL (load only now)
# -------------------------
print("Loading judge model...")
tokenizer_judge, model_judge = load_model_and_tokenizer(JUDGE_MODEL_NAME)

judge_prompts = [
    JUDGE_PROMPT_TEMPLATE.replace("@@SENT@@", s).replace("@@CAND@@", base_labels[i])
    for i, s in enumerate(sentences)
]

print("Running judge model in batches...")
for start in tqdm(range(0, n, BATCH_SIZE_JUDGE)):
    end = min(n, start + BATCH_SIZE_JUDGE)
    batch_prompts = judge_prompts[start:end]

    outputs = generate_batch(
        batch_prompts,
        tokenizer_judge,
        model_judge,
        max_new_tokens=MAX_NEW_TOKENS_JUDGE,
        gen_batch_size=BATCH_SIZE_JUDGE,
        max_prompt_tokens=MAX_PROMPT_TOKENS,
    )

    for idx_in_batch, out in enumerate(outputs):
        idx = start + idx_in_batch
        judged = extract_label_from_text(out)
        final_labels[idx] = judged if judged in CLASS_LABELS else base_labels[idx]

    # periodic checkpoint
    if (start // BATCH_SIZE_JUDGE) % max(1, (CHECKPOINT_INTERVAL // BATCH_SIZE_JUDGE)) == 0:
        tmp = df_sentences.copy()
        tmp["base_llm_label"] = base_labels
        tmp["final_llm_label"] = final_labels
        tmp.to_csv(OUTPUT_CSV + ".partial", index=False, encoding="utf-8")
        print(f"Intermediate saved up to {end}/{n}")

print("Judge pass done.")

# Optional: free judge model
del model_judge
del tokenizer_judge
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()
print("Freed judge model from GPU.")

# -------------------------
# SAVE RESULTS
# -------------------------
df_sentences["base_llm_label"] = base_labels
df_sentences["final_llm_label"] = final_labels

print("Final distributions (base):", Counter(base_labels))
print("Final distributions (final):", Counter(final_labels))

df_sentences.to_csv(OUTPUT_CSV, index=False, encoding="utf-8")
print("Saved:", OUTPUT_CSV)


Device: cuda
Base model: Equall/Saul-7B-Instruct-v1
Judge model: mistralai/Mistral-7B-Instruct-v0.2
Total pairs: 40506
Unique sentences to classify: 1000
Loading base model...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

2025-12-09 18:15:09.629100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765304109.820324      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765304109.870797      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/4.25G [00:00<?, ?B/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Running base model in batches...


  0%|          | 0/250 [00:00<?, ?it/s]

Base pass done. Distribution: Counter({'non-argumentative': 585, 'premise': 301, 'conclusion': 114})
Freed base model from GPU.
Loading judge model...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Running judge model in batches...


  0%|          | 0/250 [00:00<?, ?it/s]

Intermediate saved up to 4/1000
Intermediate saved up to 204/1000
Intermediate saved up to 404/1000
Intermediate saved up to 604/1000
Intermediate saved up to 804/1000
Judge pass done.
Freed judge model from GPU.
Final distributions (base): Counter({'non-argumentative': 585, 'premise': 301, 'conclusion': 114})
Final distributions (final): Counter({'premise': 714, 'conclusion': 189, 'non-argumentative': 97})
Saved: /kaggle/working/sentence_argument_labels_with_big_judge.csv
