In [3]:
import os
import json
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
import evaluate
from tqdm import tqdm

In [4]:
# Step 1: Set model path and load PEFT config
peft_model_path = "output-qlora-latest"  # <-- Change to your QLoRA output directory
config = PeftConfig.from_pretrained(peft_model_path)

In [5]:
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

In [6]:
# Step 3: Setup 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

In [7]:
base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto"
)

In [8]:
# Step 5: Load LoRA weights into quantized base model
model = PeftModel.from_pretrained(base_model, peft_model_path)
model = model.merge_and_unload()  # Merges LoRA into the base model for inference
model.eval()



LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear4bit(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), e

In [9]:
# Step 6: Move model to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear4bit(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), e

In [10]:
# Step 7: Load ROUGE metric
rouge = evaluate.load("rouge")

In [11]:
# Step 8: Load test dataset
def load_test_dataset(jsonl_file, max_input_length=1024, max_samples=None):
    system_prompt = "Summarize the following legal text."
    inputs, references = [], []

    with open(jsonl_file, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if max_samples and i >= max_samples:
                break
            item = json.loads(line)
            judgement = item["judgement"].strip()[:max_input_length]
            summary = item["summary"].strip()
            prompt = f"""### Instruction: {system_prompt}

### Input:
{judgement}

### Response:"""
            inputs.append(prompt)
            references.append(summary)
    return inputs, references

In [12]:
# Step 9: Set test file path
test_jsonl_path = r"processed-IN-Abs/test-data/full_summaries.jsonl"
test_inputs, test_references = load_test_dataset(test_jsonl_path, max_samples=10)

In [13]:
# Step 10: Generate summaries
def generate_summary(text, max_new_tokens=256):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return output_text


In [14]:
# Step 11: Inference loop
predictions = []

start_time = time.time()
for inp in tqdm(test_inputs, desc="Running Inference"):
    pred = generate_summary(inp)
    predictions.append(pred)
inference_time = time.time() - start_time

Running Inference: 100%|██████████| 10/10 [01:16<00:00,  7.62s/it]


In [15]:
pred

'### Instruction: Summarize the following legal text.\n\n### Input:\nAppeal No. 251 of 1963.\nAppeal by special leave from the judgment and order dated March 20, 1957, of the Patna High Court in Civil Revision No. 40 of 1956.\nM. C. Setalvad, and R. C. Prasad, for the appellants.\nThe respondent did not appear.\nMarch 24, 1964.\nThe short question which arises in this appeal is whether the term "wages" as defined by section 2(vi) of the (No. 4 of 1936) (hereinafter called \'the Act \') includes wages fixed by an award in an industrial dispute between the employer and his employees.\nThis question has to be answered in the light of the definition prescribed by section 2(vi) before it was amended in 1958.\nThe subsequent amendment expressly provides by section 2(vi) (a) that any remuneration payable under any award or settlement between the parties or order of a Court, would be included in the main definition under section 2(vi).\nThe point which we have to decide in the present appeal i

In [16]:
rouge_result = rouge.compute(predictions=predictions, references=test_references)

In [17]:
# Step 13: Print results
print(f"\n🕒 Inference time for {len(test_inputs)} samples: {inference_time:.2f} seconds")
print("\n📊 ROUGE scores:")
for key, value in rouge_result.items():
    print(f"  {key}: {value:.4f}")


🕒 Inference time for 10 samples: 76.15 seconds

📊 ROUGE scores:
  rouge1: 0.3427
  rouge2: 0.0851
  rougeL: 0.1900
  rougeLsum: 0.3130
