# In [None]:
import re
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from peft import PeftModel

model_name_or_path = "MTCFLLM-llama-3.2-fine-tuned-model"  # saved both tokenizer + adapter
base_model_id = "meta-llama/Llama-3.2-1B"  # the *original* upstream

# 1) Load the fine-tuned tokenizer. It is expected to already contain the
#    7 added class tokens.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# 2) Load the *base* LLaMA in bfloat16. device_map="auto" hands weight
#    placement to accelerate, so the model is not moved again with .to() below.
base = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# 3) Grow the embedding matrix to the extended vocabulary size so the added
#    token ids are valid rows before the adapter is attached.
base.resize_token_embeddings(len(tokenizer))

# 4) Load adapter weights on top.
# NOTE(review): torch_dtype is presumably just forwarded as an extra kwarg by
# PeftModel.from_pretrained — confirm it has any effect here.
model = PeftModel.from_pretrained(
    base,
    model_name_or_path,
    torch_dtype=torch.bfloat16,
)

model.eval()

# Input tensors must be created on the device the model actually lives on.
# Reading it back from the first registered parameter (rather than guessing
# "cuda"/"cpu" and calling model.to(device)) stays correct whatever placement
# device_map="auto" chose, and never tries to move an accelerate-dispatched model.
device = next(model.parameters()).device



# In [None]:
# ─── SPECIAL TOKENS SETUP ────────────────────────────────────────────────────────
# The seven class labels. The evaluation below scores each class by the
# probability of a SINGLE next token, so each label must exist in the tokenizer
# vocabulary as one token with its own id.
special_tokens = [
    "<Speed and Stopping Violations>",
    "<Right-of-Way and Traffic Control Violations>",
    "<Lane and Direction Violations>",
    "<Maneuvering and Signaling Errors>",
    "<General Unsafe Driving>",
    "<NO HAZARDOUS ACTION>",
    "<BOTH DRIVERS TOOK HAZARDOUS ACTION>"
]
special_token_ids = {
    token: tokenizer.convert_tokens_to_ids(token)
    for token in special_tokens
}

# convert_tokens_to_ids does not raise for out-of-vocabulary tokens: it may
# return the unk id, or None when the tokenizer has no unk token. Either case
# would silently turn every metric below into noise, so fail loudly instead.
_unk_token_id = tokenizer.unk_token_id
_missing_tokens = [
    token
    for token, token_id in special_token_ids.items()
    if token_id is None or token_id == _unk_token_id
]
if _missing_tokens:
    raise ValueError(
        f"Special tokens missing from the tokenizer vocabulary: {_missing_tokens}"
    )
if len(set(special_token_ids.values())) != len(special_token_ids):
    raise ValueError(f"Special tokens do not map to distinct ids: {special_token_ids}")



# ─── HELPERS ─────────────────────────────────────────────────────────────────────
def extract_prompt(text: str) -> str:
    """Cut *text* immediately after the first assistant header.

    The returned prefix ends with the assistant marker, so the model's very next
    token is its answer. When the marker is absent, *text* is returned as-is.
    """
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n"
    try:
        cut = text.index(assistant_header) + len(assistant_header)
    except ValueError:
        return text
    return text[:cut]

def extract_ground_truth(text: str) -> "str | None":
    """
    Return the assistant's answer: the text after the assistant marker up to the
    first <|eot_id|>, with surrounding whitespace stripped.

    A newline before <|eot_id|> is optional. Without that, "<tok><|eot_id|>"
    would not match, or the lazy match would run on to a later "\\n<|eot_id|>".
    Returns None when there is no assistant turn terminated by <|eot_id|>.
    """
    pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>\n(.*?)\n?<\|eot_id\|>"
    match = re.search(pattern, text, re.DOTALL)
    return match.group(1).strip() if match else None

# ─── EVALUATION LOOP ────────────────────────────────────────────────────────────
# Constrained classification: for each sample, look only at the seven class
# tokens at the position right after the assistant header and predict the one
# the model ranks highest.
predicted_tokens = []
ground_truth_tokens = []

# Hoisted out of the loop: class names and their vocabulary ids in matching
# order. One indexing op + one argmax replaces seven per-sample `.item()` syncs.
# A plain Python list works as an index on a tensor of any device.
_eval_class_names = list(special_token_ids)
_eval_class_ids = [special_token_ids[name] for name in _eval_class_names]

for sample in test_data:  # assumes each sample is a dict with a "text" key
    text = sample["text"]
    true_token = extract_ground_truth(text)
    if true_token is None:
        continue  # no parsable assistant answer -> nothing to score against

    prompt = extract_prompt(text)
    # NOTE(review): encode() adds the tokenizer's special tokens (BOS) by default;
    # if the prompt text already starts with <|begin_of_text|> the BOS is doubled.
    # Confirm this matches how the training data was tokenized.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)

    # inference_mode: no autograd bookkeeping for this pure forward pass.
    # torch.autocast keyed on device.type keeps the bf16 autocast on GPU and
    # also works on CPU, unlike the CUDA-only torch.cuda.amp.autocast.
    with torch.inference_mode(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        next_token_logits = model(input_ids, attention_mask=attention_mask).logits[0, -1]
        # softmax is monotonic, so the highest logit is also the highest
        # probability; argmax returns the first maximal index on ties, like
        # max() over an insertion-ordered dict.
        best = int(torch.argmax(next_token_logits[_eval_class_ids]))

    predicted_tokens.append(_eval_class_names[best])
    ground_truth_tokens.append(true_token)

# ─── METRICS & PLOTTING ─────────────────────────────────────────────────────────
# Confusion matrix: rows are the true class, columns the predicted class, both
# ordered like `special_tokens`. Labels outside that list are not counted.
cm = confusion_matrix(ground_truth_tokens, predicted_tokens, labels=special_tokens)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=special_tokens)

cm_axes = plt.figure(figsize=(10, 8)).gca()
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45, ax=cm_axes)
plt.title("Confusion Matrix for Hazardous Action Prediction")
plt.tight_layout()
plt.show()

# Per-class precision / recall / F1. zero_division=0 reports undefined ratios
# (a class never predicted or never present) as 0 instead of warning.
print("Classification Report:\n")
report_text = classification_report(
    ground_truth_tokens,
    predicted_tokens,
    labels=special_tokens,
    zero_division=0,
)
print(report_text)

# Rebuild the token -> vocabulary-id mapping; convert_tokens_to_ids accepts the
# whole list and returns the ids in the same order, so zip pairs them up.
special_token_ids = dict(
    zip(special_tokens, tokenizer.convert_tokens_to_ids(special_tokens))
)

def extract_prompt(text):
    """
    Return the prompt portion of a test sample.

    That is everything up to and including the first assistant marker, so
    generation starts exactly where the answer should be. If the marker does
    not occur, the whole text is returned.
    """
    head, marker, _answer = text.partition("<|start_header_id|>assistant<|end_header_id|>\n")
    # partition gives (text, "", "") when the marker is missing, so this
    # concatenation falls back to the full input.
    return head + marker

def extract_ground_truth(text):
    """
    Extract the ground-truth label from a test sample's text.

    The label is whatever follows the assistant marker up to the first
    <|eot_id|>, with surrounding whitespace stripped. A newline before
    <|eot_id|> is optional, so both "<tok>\\n<|eot_id|>" and "<tok><|eot_id|>"
    are recognised.

    Returns None when the text contains no assistant turn terminated by
    <|eot_id|>.
    """
    pattern = r"<\|start_header_id\|>assistant<\|end_header_id\|>\n(.*?)\n?<\|eot_id\|>"
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        return None
    return match.group(1).strip()

# Per class token: its next-token probability on every scored sample, in order.
probabilities = {token: [] for token in special_tokens}

# extract_prompt returns the WHOLE text (answer included) when this marker is
# missing, which would leak the label into the input; such samples are skipped.
_prob_assistant_marker = "<|start_header_id|>assistant<|end_header_id|>\n"
# Vocabulary ids in special_tokens order. A plain list indexes a tensor on any device.
_prob_class_ids = [special_token_ids[token] for token in special_tokens]

model.eval()
# Loop over the test data samples (assuming test_data is a list of dicts with key 'text')
for sample in test_data:
    text = sample['text']
    if _prob_assistant_marker not in text:
        continue

    # Extract the input prompt from the text.
    prompt = extract_prompt(text)

    # Tokenize the prompt and place it on `device`, so CPU-only runs work too.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits[:, -1, :]  # next-token logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # One device->host transfer for all seven values instead of one .item() each.
        class_probs = probs[0, _prob_class_ids].tolist()

    for token, token_prob in zip(special_tokens, class_probs):
        probabilities[token].append(token_prob)

# Prepare data for box plot: a list of probability lists in the order of special_tokens
data_to_plot = [probabilities[token] for token in special_tokens]

# Create a box plot to show the probability distribution for each token
plt.figure(figsize=(12, 6))
# NOTE(review): Matplotlib 3.9 renamed boxplot's `labels=` to `tick_labels=`
# (the old name is deprecated); switch once the installed version allows it.
plt.boxplot(data_to_plot, labels=special_tokens, showfliers=False)
plt.xticks(rotation=45, ha="right")
plt.ylabel("Probability")
plt.title("Probability Distribution for Each Predicted Token")
plt.tight_layout()
plt.show()

import numpy as np

# data_to_plot holds one list per special token (in special_tokens order), and
# the loop that fills it appends one value to every list per scored sample.
# The lists are therefore equal-length and stack into a 2-D array with shape
# (len(special_tokens), n_samples).
arr = np.asarray(data_to_plot)
np.save("data_to_plot.npy", arr)