In [1]:
import os
import json
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import evaluate
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "final_layer_finetune"  # Update this to your saved model directory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.to(device)
model.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=256, bia

In [4]:
rouge = evaluate.load("rouge")

In [5]:
def load_test_dataset(jsonl_file, max_input_length=1024, max_samples=None):
    system_prompt = "Summarize the following legal text."
    inputs, references = [], []

    with open(jsonl_file, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if max_samples and i >= max_samples:
                break
            item = json.loads(line)
            judgement = item["judgement"].strip()[:max_input_length]
            summary = item["summary"].strip()
            prompt = f"""### Instruction: {system_prompt}

### Input:
{judgement}

### Response:"""
            inputs.append(prompt)
            references.append(summary)
    return inputs, references


In [6]:
test_jsonl_path = r"processed-IN-Abs/test-data/full_summaries.jsonl"  # Update this to your test dataset path
test_inputs, test_references = load_test_dataset(test_jsonl_path, max_samples=10)  # change max_samples=None for full data

In [7]:
def generate_summary(text, max_new_tokens=256):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return output_text

In [9]:
predictions = []
start_time = time.time()

for inp in tqdm(test_inputs, desc="Running Inference"):
    pred = generate_summary(inp)
    predictions.append(pred)

inference_time = time.time() - start_time

Running Inference: 100%|██████████| 10/10 [12:15<00:00, 73.60s/it]


In [11]:
pred

'### Instruction: Summarize the following legal text.\n\n### Input:\nAppeal No. 251 of 1963.\nAppeal by special leave from the judgment and order dated March 20, 1957, of the Patna High Court in Civil Revision No. 40 of 1956.\nM. C. Setalvad, and R. C. Prasad, for the appellants.\nThe respondent did not appear.\nMarch 24, 1964.\nThe short question which arises in this appeal is whether the term "wages" as defined by section 2(vi) of the (No. 4 of 1936) (hereinafter called \'the Act \') includes wages fixed by an award in an industrial dispute between the employer and his employees.\nThis question has to be answered in the light of the definition prescribed by section 2(vi) before it was amended in 1958.\nThe subsequent amendment expressly provides by section 2(vi) (a) that any remuneration payable under any award or settlement between the parties or order of a Court, would be included in the main definition under section 2(vi).\nThe point which we have to decide in the present appeal i

In [10]:
rouge_result = rouge.compute(predictions=predictions, references=test_references)

In [12]:
print(f"\n🕒 Inference time for {len(test_inputs)} samples: {inference_time:.2f} seconds")
print("\n📊 ROUGE scores:")
for key, value in rouge_result.items():
    print(f"  {key}: {value:.4f}")


🕒 Inference time for 10 samples: 735.97 seconds

📊 ROUGE scores:
  rouge1: 0.3216
  rouge2: 0.0843
  rougeL: 0.1721
  rougeLsum: 0.2972
