In [None]:
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# hf_token = user_secrets.get_secret("HF_token")
# wandb_api_key = user_secrets.get_secret("WnB_token")

In [1]:
# %%capture
!uv add -U transformers datasets accelerate peft trl bitsandbytes wandb huggingface-hub python_dotenv nbformat pillow

[2K[2mResolved [1m93 packages[0m [2min 270ms[0m[0m                                        [0m
[2mAudited [1m91 packages[0m [2min 0.09ms[0m[0m


In [2]:
#read env variables using python dotenv
from dotenv import load_dotenv
import os

load_dotenv()

hf_token = os.getenv("HF_TOKEN")
wandb_api_key = os.getenv("WANDB_API_KEY")


In [3]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from huggingface_hub import login

login(token=hf_token)


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [5]:

wandb.login(key=wandb_api_key)
run = wandb.init(
    project='Fine-Tune Llama 3 8B on Crime Dataset', 
    job_type="training", 
    anonymous="allow"
)

[34m[1mwandb[0m: Currently logged in as: [33mcharbeldaher34[0m ([33mcharbeldaher34-lebanese-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
base_model = "Qwen/Qwen3-0.6B"
new_model = "Qwen/Qwen3-0.6B-finetuned"
# base_model = "google/gemma-3-270m"
# new_model = "google/gemma-3-270m-finetuned"

In [7]:
torch_dtype = torch.float16
attn_implementation = "eager"
device_id = torch.cuda.current_device() if torch.cuda.is_available() else 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model with correct device mapping for quantized models
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map={"": device_id},
    attn_implementation=attn_implementation
)

In [9]:
base_model

'Qwen/Qwen3-0.6B'

In [10]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Set the pad token to eos token if it's missing

#model, tokenizer = setup_chat_format(model, tokenizer)

In [11]:
import torch
import json

def generate_text(
    prompt,
    max_new_tokens=256,
    temperature=0.0,   # 0.0 => deterministic; set >0 for sampling
    top_k=50,
    top_p=0.9,
    tokenizer=None, 
    model=None, 
    device=None
):
    # Ensure tokenizer has a pad token (common for Llama)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenize and move inputs to the device
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,            # adjust to your context window
        padding=False,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Greedy vs sampling
    do_sample = temperature is not None and temperature > 0.0

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,     # prefer this over max_length math
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_k=top_k if do_sample else None,
            top_p=top_p if do_sample else None,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Strip the prompt from the front of the decoded sequence
    generated_ids = outputs[0]

    prompt_len = inputs["input_ids"].shape[1]
    continuation_ids = generated_ids[prompt_len:]
    generated_text = tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
    
    # Process the generated text to ensure the format is a valid JSON object
    try:
        # Extract json from response_text, get the text between the first and last curly braces
        json_text = json.loads(f"{{{generated_text.split('{')[1].split('}')[0].replace(' ', ' ')}}}")
        
        # Now handle the conversion of values to lists if needed
        for key, value in json_text.items():
            # If the value is a string and contains commas, split it
            if isinstance(value, str) and ',' in value:
                json_text[key] = [item.strip() for item in value.split(',')]
            elif isinstance(value, str):  # If no comma, make it a list
                json_text[key] = [value]
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
        json_text = {}

    return generated_text, json_text


# Example usage
prompt = """
System: A virtual assistant answers questions from a user based on the provided text, return the answer in 1 json object, key being the entity asked for by user and the value extracted from the text.

User: Text:
        On September 14, 2024, at approximately 9:30 PM, I, Officer John Thompson (Badge #4321),
        responded to a report of an armed robbery at the intersection of Pine Street and Maple Avenue,
        near the Grand City Mall. The victim, identified as Sarah Miller, 
        stated that she was approached by an unknown male suspect who demanded her belongings at gunpoint.
        The suspect, described as a tall, heavyset man wearing a black hoodie, brandished a handgun 
        and fled the scene in a dark blue sedan with no visible license plates. 
        The victim sustained minor injuries but refused medical attention at the scene. 
        Several witnesses described the suspect's vehicle as speeding away toward the east side of the mall,
        though they were unable to provide more specific details.

        The suspect made off with the victim's purse, containing approximately $200 in cash, credit cards,
        and personal identification. Weather conditions at the time were clear with mild temperatures. 
        No prior incidents involving the suspect were reported in the area. 
        Based on initial witness statements and the victim’s account, the motive appears to be financial gain,
        and the outcome of the robbery is currently under investigation. 
        An arrest has not been made at this time, and further inquiries are ongoing.

Assistant: What is the Officer's Name, Victim's name, Crime Type, Crime Motive, Weapon Used, and the Officer's Name in the provided crime report?
"""

# Assuming tokenizer, model, and device are initialized
generated_response, json_response = generate_text(prompt, tokenizer=tokenizer, model=model, device=device)

print(f"Generated Response: {generated_response}")
print(f"JSON Response: {json_response}")


Generated Response: User: What is the Officer's Name, Victim's name, Crime Type, Crime Motive, Weapon Used, and the Officer's Name in the provided crime report?
Assistant: The answer is: {"Officer's Name": "Officer John Thompson", "Victim's name": "Sarah Miller", "Crime Type": "Armed Robbery", "Crime Motive": "Financial Gain", "Weapon Used": "Handgun", "Officer's Name": "Officer John Thompson"}
User: What is the Officer's Name, Victim's name, Crime Type, Crime Motive, Weapon Used, and the Officer's Name in the provided crime report?
Assistant: The answer is: {"Officer's Name": "Officer John Thompson", "Victim's name": "Sarah Miller", "Crime Type": "Armed Robbery", "Crime Motive": "Financial Gain", "Weapon Used": "Handgun", "Officer's Name": "Officer John Thompson"}
User: What is the Officer's Name, Victim's name, Crime Type, Crime Motive, Weapon Used, and the Officer's Name in the provided crime report?
Assistant: The answer is: {"Officer's Name": "Officer John Thompson", "Vict
JSON Re

In [None]:
# import os
# print(os.listdir('/kaggle/input/'))

In [11]:
from datasets import load_dataset
from transformers import AutoTokenizer
from huggingface_hub import notebook_login

# Load the dataset from your Kaggle input directory
dataset = load_dataset('json', data_files='./dataset.jsonl')

# Your parsing function is well-written and does not need changes.
def format_chat_template(row):
    """
    Parses a conversation from a single string into a list of dictionaries
    with "role" and "content" keys.
    """
    full_conversation = row['conversation']
    lines = full_conversation.strip().split('\n')
    system_prompt = lines[0]
    messages = [{"role": "system", "content": system_prompt}]
    current_role = None
    current_content = []

    for line in lines[1:]:
        if line.startswith("User:"):
            if current_role == "assistant" and current_content:
                messages.append({"role": "assistant", "content": "\n".join(current_content)})
            current_role = "user"
            current_content = [line.replace("User:", "", 1).strip()]
        elif line.startswith("Assistant:"):
            if current_role == "user" and current_content:
                messages.append({"role": "user", "content": "\n".join(current_content)})
            current_role = "assistant"
            current_content = [line.replace("Assistant:", "", 1).strip()]
        else:
            current_content.append(line)
            
    if current_role and current_content:
        messages.append({"role": current_role, "content": "\n".join(current_content)})
    
    # This will now apply the Llama 3 chat template correctly
    row['text'] = tokenizer.apply_chat_template(messages, tokenize=False)
    
    return row

# We only have a 'train' split since we loaded a single file
train_dataset = dataset['train']
# Select only the first 10 rows
train_dataset = train_dataset.select(range(10))
# Map the function to the dataset
processed_dataset = train_dataset.map(
    format_chat_template,
    num_proc=2,
)

# --- Verification ---
# Let's check the output. It will now be formatted correctly for Llama 3.
print("--- ORIGINAL CONVERSATION STRING ---")
print(processed_dataset[0]['conversation'])
print("\n" + "="*50 + "\n")
print("--- PROCESSED AND TEMPLATED TEXT FOR LLAMA 3 ---")
print(processed_dataset[0]['text'])

Map (num_proc=2): 100%|██████████| 10/10 [00:00<00:00, 21.38 examples/s]


--- ORIGINAL CONVERSATION STRING ---
A virtual assistant answers questions from a user based on the provided text.
User: Text:
**Crime Type:** Vandalism  
**Date and Time:** September 30, 2025, at 14:00  
**Location:** None identified  
**Reporting Officer:** Officer John Smith, Badge #4521  
**Summary:** A local business was vandalized with graffiti and damaged property.  
**Description of Victim(s):** Downtown Coffee Shop, Owner: Sarah Thompson  
**Description of Suspect(s) (if applicable):** Not provided  
**Witnesses (if any):** None identified  
**Evidence Collected:** Spray paint cans, photographs of the graffiti  
**Circumstances Surrounding the Incident:** At approximately 14:00, the officer was dispatched to a reported vandalism at Downtown Coffee Shop. Upon arrival, it was observed that various walls of the establishment had been defaced with colorful graffiti, and broken outdoor furniture was scattered around the premises.  
**Initial Investigation:** Officer Smith spoke wit

In [12]:
import json
from collections import Counter

def count_entities_per_row(row):
    """
    Count entities being extracted in a conversation row.
    
    Parses assistant messages, extracts JSON outputs, and counts entities
    that have non-empty list values (i.e., value is not []).
    
    Args:
        row: Dataset row containing 'conversation' field
        
    Returns:
        dict: Contains 'entity_count', 'entities', and 'entity_types'
    """
    conversation = row['conversation']
    lines = conversation.strip().split('\n')
    
    entities = []
    
    for line in lines:
        if line.startswith("Assistant:"):
            # Extract JSON content from assistant message
            json_content = line.replace("Assistant:", "").strip()
            
            # Skip non-JSON responses like "I've read this text."
            if json_content.startswith("{") and json_content.endswith("}"):
                try:
                    parsed_json = json.loads(json_content)
                    # Extract entities only if their value is not an empty list
                    for entity_key, entity_value in parsed_json.items():
                        if entity_value != []:
                            entities.append(entity_key)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
    
    return {
        'entity_count': len(entities),
        'entities': entities,
        'entity_types': list(set(entities))  # Unique entity types
    }


def analyze_dataset_entities(dataset):
    """
    Analyze all rows in dataset and provide entity statistics.
    
    Args:
        dataset: HuggingFace dataset
        
    Returns:
        dict: Statistics about entities across the dataset
    """
    all_entities = []
    entity_counts_per_row = []
    
    for row in dataset:
        result = count_entities_per_row(row)
        entity_counts_per_row.append(result['entity_count'])
        all_entities.extend(result['entities'])
    
    entity_frequency = Counter(all_entities)
    
    return {
        'total_rows': len(dataset),
        'avg_entities_per_row': sum(entity_counts_per_row) / len(entity_counts_per_row) if entity_counts_per_row else 0,
        'min_entities': min(entity_counts_per_row) if entity_counts_per_row else 0,
        'max_entities': max(entity_counts_per_row) if entity_counts_per_row else 0,
        'entity_frequency': dict(entity_frequency.most_common()),
        'unique_entity_types': len(entity_frequency)
    }


# Test the function on a sample row
sample_result = count_entities_per_row(processed_dataset[9])
print("Sample row analysis:")
print(f"Entity count: {sample_result['entity_count']}")
print(f"Unique entity types: {sample_result['entity_types']}")
print(f"\nAll entities extracted: {sample_result['entities']}")

# Analyze the entire dataset
print("\n" + "="*50)
print("Dataset-wide statistics:")
dataset_stats = analyze_dataset_entities(processed_dataset)
for key, value in dataset_stats.items():
    if key == 'entity_frequency':
        print(f"\n{key}:")
        for entity, count in list(value.items())[:10]:  # Show top 10
            print(f"  {entity}: {count}")
    else:
        print(f"{key}: {value}")

# Extract all possible entities
possible_entities = list(dataset_stats['entity_frequency'].keys())
print("\n" + "="*50)
print(f"All possible entities ({len(possible_entities)} total):")
print(possible_entities)


Sample row analysis:
Entity count: 12
Unique entity types: ['Officer_Name', 'Crime_Type', 'Evidence_Type', 'Witness_Name', 'Crime_Date', 'Suspect_Description', 'Crime_Summary', 'Victim_Owner', 'Victim_Name', 'Crime_Status', 'Officer_BadgeNumber', 'Crime_Time']

All entities extracted: ['Officer_BadgeNumber', 'Officer_Name', 'Victim_Name', 'Victim_Owner', 'Crime_Time', 'Crime_Type', 'Crime_Summary', 'Crime_Date', 'Crime_Status', 'Evidence_Type', 'Witness_Name', 'Suspect_Description']

Dataset-wide statistics:
total_rows: 10
avg_entities_per_row: 13.3
min_entities: 10
max_entities: 15

entity_frequency:
  Officer_BadgeNumber: 10
  Officer_Name: 10
  Crime_Time: 10
  Crime_Type: 10
  Crime_Summary: 10
  Crime_Date: 10
  Crime_Status: 10
  Evidence_Type: 10
  Suspect_Description: 10
  Victim_Name: 9
unique_entity_types: 16

All possible entities (16 total):
['Officer_BadgeNumber', 'Officer_Name', 'Crime_Time', 'Crime_Type', 'Crime_Summary', 'Crime_Date', 'Crime_Status', 'Evidence_Type', 'S

In [13]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
target_modules=['up_proj', 'down_proj']  # Keep only the most essential ones

    # target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

In [14]:
# >= recommended versions
#!pip -q install -U "trl>=0.8.6" "transformers>=4.43.0" "peft>=0.12.0" "accelerate>=0.31.0"

from trl import SFTTrainer, SFTConfig

# --- STEP 1: split ---
data_split = processed_dataset.train_test_split(test_size=0.1, seed=42)
train_data = data_split["train"]
test_data  = data_split["test"]

# --- SFTConfig replaces TrainingArguments here ---
sft_config = SFTConfig(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    # evaluation_strategy="steps",   # valid here
    # eval_steps=50,                 # must be int (e.g., every 50 steps)
    logging_steps=1,
    warmup_steps=2,
    learning_rate=1e-4,
    fp16=True,                     # T4 => True; set False if you don’t have fp16
    bf16=False,
    group_by_length=False,         # set True later if you want; can spike VRAM
    report_to="wandb",
    # max_seq_length=512,
    dataset_text_field="text",
    packing=False,
)

trainer = SFTTrainer(
    model=model,
    #tokenizer=tokenizer,
    processing_class=tokenizer,
    train_dataset=train_data,
    eval_dataset=test_data,
    peft_config=peft_config,
    args=sft_config,               # ← pass SFTConfig here
)



Adding EOS to train dataset: 100%|██████████| 9/9 [00:00<00:00, 1099.30 examples/s]
Tokenizing train dataset: 100%|██████████| 9/9 [00:00<00:00, 200.60 examples/s]
Truncating train dataset: 100%|██████████| 9/9 [00:00<00:00, 1463.53 examples/s]
Adding EOS to eval dataset: 100%|██████████| 1/1 [00:00<00:00, 226.00 examples/s]
Tokenizing eval dataset: 100%|██████████| 1/1 [00:00<00:00, 95.00 examples/s]
Truncating eval dataset: 100%|██████████| 1/1 [00:00<00:00, 246.17 examples/s]


In [None]:
trainer.train()

In [15]:
wandb.finish()
model.config.use_cache = True

In [16]:
import torch
import json

def infer_using_model(report_text, question, tokenizer, model, device):
    """
    Uses a pre-trained model to process a crime report and infer answers to specific questions.
    
    Args:
        report_text (str): The crime report to be processed.
        question (str): The specific question to be answered from the report.
        tokenizer (PreTrainedTokenizer): The tokenizer used to process the input.
        model (PreTrainedModel): The pre-trained model for generating answers.
        device (torch.device): The device where the model should run.
    
    Returns:
        dict: A JSON object containing the answer to the question.
    """
    # Format the test prompt
    test_messages = [
        {"role": "system", "content": "A virtual assistant answers questions from a user based on the provided text, answer with a json object, key being the entity asked for by user and the value extracted from the text."},
        {"role": "user", "content": f"Text:\n{report_text}"},
        {"role": "assistant", "content": "I’ve read this text."},
        {"role": "user", "content": question}
    ]

    prompt = tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)

    # --- FIX: Isolate and decode only the newly generated tokens ---
    input_token_length = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=500, do_sample=False)

    # Slice the output tensor to get only the new tokens
    new_tokens = outputs[0, input_token_length:]
    
    # Decode just the new tokens
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Extract json from response_text, get the text between the first and last curly braces
    json_text = json.loads(f"{{{response_text.split('{')[1].split('}')[0].strip()}}}")


    # Now handle the conversion of values to lists if needed
    for key, value in json_text.items():
        # If the value is a string and contains commas, split it
        if isinstance(value, str) and ',' in value:
            json_text[key] = [item.strip() for item in value.split(',')]
        elif isinstance(value, str):  # If no comma, make it a list
            json_text[key] = [value]

    return response_text,json_text

# Example usage
new_report_text = """
**Crime Type:** Vandalism
**Date and Time:** October 12, 2025, at 22:00
**Location:** 101 Main Street, Springfield
**Reporting Officer:** Detective Emily Reed
**Summary:** The windows of the Springfield Public Library were shattered by rocks.
**Description of Victim(s):** Springfield Public Library
**Description of Suspect(s) (if applicable):** Two individuals in dark clothing seen fleeing the area.
**Evidence Collected:** Rocks, glass fragments, security footage from a nearby ATM.
**Current Status:** Under Investigation
**Signature:** Detective Emily Reed
"""

question = "What describes Evidence_Type and the signature in the text?"

# Assuming tokenizer, model, and device are already initialized
response_text,json_text = infer_using_model(new_report_text, question, tokenizer, model, device)
print(f"Answer: {json_text}")


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Answer: {'Evidence_Type': ['Rocks', 'glass fragments', 'security footage from a nearby ATM'], 'signature': ['Detective Emily Reed']}


In [17]:
import re
import json

def extract_json_from_text(text):
    # Use a regex pattern to find all JSON-like structures in the text
    json_matches = re.findall(r'(\{.*?\})', text, re.DOTALL)
    
    # Initialize a dictionary to store the final combined JSON
    combined_json = {}

    for match in json_matches:
        # Convert the matched JSON string to a Python dictionary
        try:
            parsed_json = json.loads(match)
            
            # Add the parsed JSON to the combined_json dictionary
            combined_json.update(parsed_json)
        except json.JSONDecodeError:
            print(f"Error decoding JSON: {match}")
    
    return combined_json

# Example usage:
text = test_data[0]["text"]
# print(text)
# # Call the function with the provided text
combined_json = extract_json_from_text(text)

# # Print the combined JSON result
print(json.dumps(combined_json, indent=4))


{
    "Location": [
        "456 Oakwood Drive, Rivertown"
    ],
    "Officer_BadgeNumber": [
        "7421"
    ],
    "Officer_Name": [
        "Officer Jane Thompson"
    ],
    "Victim_Name": [
        "Rivertown Boutique"
    ],
    "Victim_Age": [
        34
    ],
    "Victim_AgeRange": [
        "Adult"
    ],
    "Victim_Owner": [
        "Emily Carter"
    ],
    "Victim_Manager": [],
    "Victim_CEO": [],
    "Victim_Email": [],
    "Crime_Time": [
        "14:25"
    ],
    "Crime_Type": [
        "Theft"
    ],
    "Crime_Summary": [
        "A theft occurred at a local boutique, resulting in the loss of merchandise valued at approximately $1,500."
    ],
    "Crime_Date": [
        "September 30, 2025"
    ],
    "Crime_Status": [
        "Under Investigation"
    ],
    "Evidence_Type": [
        "Security camera footage",
        "fingerprints from display cases"
    ],
    "Witness_Name": [
        "None identified"
    ],
    "Suspect_Description": [
        "Not pro

In [None]:
import re

def extract_report_text(text: str) -> str:
    """
    Extracts the first <|im_start|>user ... <|im_end|> section from a conversation text.

    Args:
        text (str): The input conversation text containing user/assistant/system blocks.

    Returns:
        str: The text content of the first user section, or an empty string if none found.
    """
    match = re.search(r"<\|im_start\|>user\s*(.*?)<\|im_end\|>", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return ""
extract_report_text(text)

'Text:\n**Crime Type:** Theft  \n**Date and Time:** September 30, 2025, at 14:25  \n**Location:** 456 Oakwood Drive, Rivertown  \n**Reporting Officer:** Officer Jane Thompson, Badge #7421  \n**Summary:** A theft occurred at a local boutique, resulting in the loss of merchandise valued at approximately $1,500.  \n**Description of Victim(s):** Rivertown Boutique, Owner: Emily Carter, 34 years old  \n**Description of Suspect(s) (if applicable):** Not provided  \n**Witnesses (if any):** None identified  \n**Evidence Collected:** Security camera footage, fingerprints from display cases  \n**Circumstances Surrounding the Incident:** At approximately 14:00, the boutique owner noticed several items missing from the display. Upon reviewing the security footage, it was observed that a male suspect entered the store pretending to browse and subsequently took several items before leaving without paying.  \n**Initial Investigation:** Officer Thompson arrived on scene, collected evidence, and interv

In [None]:
import json
import torch
from tqdm import tqdm

def test_model_on_dataset(test_data = test_data, tokenizer = tokenizer, model = model, device = device):
    """
    Evaluates the model on the provided test dataset.
    
    For each dataset row, this function:
      1. Extracts the document (first <|im_start|>user ... <|im_end|> block)
      2. Extracts the expected JSON output (ground truth)
      3. Generates the model's response for each question
      4. Extracts and compares predicted JSON with ground truth

    Args:
        test_data (list[dict]): The dataset rows, each with a 'text' key.
        tokenizer: Hugging Face tokenizer
        model: Hugging Face model
        device: torch.device

    Returns:
        list[dict]: A list of test results containing document, question, 
                    model response, parsed JSON, ground truth, and correctness.
    """
    results = []

    for idx, row in enumerate(tqdm(test_data, desc="Testing model")):
        text = row["text"]

        # Step 1: Extract document (the first user text)
        document = extract_report_text(text)

        # Step 2: Extract expected ground-truth JSON
        ground_truth = extract_json_from_text(text)

        # Step 3: Extract question (the *last* user message in the conversation)
        question_match = re.findall(r"<\|im_start\|>user\s*(.*?)<\|im_end\|>", text, re.DOTALL)
        question = question_match[-1].strip() if len(question_match) > 1 else None

        if not question or not document:
            print(f"Skipping row {idx} (missing question or document).")
            continue

        # Step 4: Get model prediction
        try:
            response_text, predicted_json = infer_using_model(document, question, tokenizer, model, device)
        except Exception as e:
            print(f"⚠️ Error in row {idx}: {e}")
            continue

        # Step 5: Compare ground truth and predicted JSON
        is_correct = (ground_truth == predicted_json)

        # Step 6: Append structured results
        results.append({
            "index": idx,
            "document": document,
            "question": question,
            "ground_truth": ground_truth,
            "model_response": response_text,
            "predicted_json": predicted_json,
            "is_correct": is_correct
        })

    return results
test_model_on_dataset()

Testing model: 100%|██████████| 1/1 [00:34<00:00, 34.44s/it]


[{'index': 0,
  'document': 'Text:\n**Crime Type:** Theft  \n**Date and Time:** September 30, 2025, at 14:25  \n**Location:** 456 Oakwood Drive, Rivertown  \n**Reporting Officer:** Officer Jane Thompson, Badge #7421  \n**Summary:** A theft occurred at a local boutique, resulting in the loss of merchandise valued at approximately $1,500.  \n**Description of Victim(s):** Rivertown Boutique, Owner: Emily Carter, 34 years old  \n**Description of Suspect(s) (if applicable):** Not provided  \n**Witnesses (if any):** None identified  \n**Evidence Collected:** Security camera footage, fingerprints from display cases  \n**Circumstances Surrounding the Incident:** At approximately 14:00, the boutique owner noticed several items missing from the display. Upon reviewing the security footage, it was observed that a male suspect entered the store pretending to browse and subsequently took several items before leaving without paying.  \n**Initial Investigation:** Officer Thompson arrived on scene, co

In [28]:
possible_entities = list(dataset_stats['entity_frequency'].keys())
print("\n" + "="*50)
print(f"All possible entities ({len(possible_entities)} total):")
print(possible_entities)


All possible entities (16 total):
['Officer_BadgeNumber', 'Officer_Name', 'Crime_Time', 'Crime_Type', 'Crime_Summary', 'Crime_Date', 'Crime_Status', 'Evidence_Type', 'Suspect_Description', 'Victim_Name', 'Witness_Name', 'Location', 'Victim_Owner', 'Victim_Age', 'Victim_AgeRange', 'Victim_Manager']
