# Post-Finetuning Analysis

This notebook analyzes the finetuned Gemma model from the `gemma-text-to-sql` checkpoint directory. We'll:
1. Load the finetuned model
2. Analyze the training metrics
3. Visualize the training process
4. Test the model on some examples

In [1]:
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, PeftConfig
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

# Set plotting style
sns.set_style("whitegrid")
plt.rcParams.update({'font.size': 12})

## 1. Load the Finetuned Model

We'll load the model from the checkpoint directory.

In [2]:
# Define checkpoint path
checkpoint_path = "/home/kosmas/projects/llm-in-cybersecurity/final-project/gemma-text-to-sql/checkpoint-24410"
base_model_id = "google/gemma-3-1b-pt"

# Check if CUDA is available
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    
# Check if the device benefits from bfloat16
torch_dtype = torch.float16
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
    print("Using bfloat16")
else:
    print("Using float16")

CUDA available: True
CUDA device name: NVIDIA GeForce GTX 1080
Using float16


In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Define paths
base_model_id = "google/gemma-3-1b-pt"  
adapter_path = "/home/kosmas/projects/llm-in-cybersecurity/final-project/gemma-text-to-sql/checkpoint-24410"

# Setup quantization config (same as during training)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=torch_dtype,
)

# Load base model with quantization
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch_dtype,
    attn_implementation="eager",
)

# Load adapter on top of the base model
model = PeftModel.from_pretrained(base_model, adapter_path)

# Load tokenizer (from original instruction tokenizer to be safe)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

In [11]:
# Load test data
test_df = pd.read_csv(
    "/home/kosmas/projects/llm-in-cybersecurity/final-project/datasets/test_emails.csv"
)
print(f"Test set: {test_df.shape[0]} examples")

# Display class distribution
print("\nClass distribution:")
print(test_df["label"].value_counts(normalize=True).apply(lambda x: f"{x:.2%}"))

# Create the same prompt template used during training
user_prompt = """Analyze the <EMAIL> and determine if it's PHISHING or LEGITIMATE.

<EMAIL>
Body: {body}
</EMAIL>"""


# Function to classify an email
def classify_email(email_body):
    # Format the input with our prompt template
    messages = [
        {
            "role": "user", 
            "content": user_prompt.format(body=email_body),
        }
    ]
    
    # Use the tokenizer's chat template with generation
    input_ids = tokenizer.apply_chat_template(
        messages, 
        return_tensors="pt"
    ).to(model.device)
    
    # Generate prediction
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    # DEBUG: Print token counts to understand what's happening
    print(f"Input length: {input_ids.shape[1]} tokens")
    print(f"Output length: {outputs.shape[1]} tokens")
    print(f"Generated: {outputs.shape[1] - input_ids.shape[1]} new tokens")
    
    # Extract only the generated text
    generated_text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    
    # DEBUG: Print the raw generated text
    print(f"Raw generated: '{generated_text}'")
    
    # If the model's output doesn't contain clear labels, extract them
    if not (("PHISHING" in generated_text) or ("LEGITIMATE" in generated_text)):
        # Try to extract from full response
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Full response: '{full_response}'")
        
        # Extract the most likely classification
        if "PHISHING" in full_response:
            return "PHISHING"
        elif "LEGITIMATE" in full_response:
            return "LEGITIMATE"
        else:
            # Fall back to a simple keyword matching for common phishing indicators
            phishing_keywords = ["suspicious", "fraud", "scam", "phish", "fake"]
            for kw in phishing_keywords:
                if kw in generated_text.lower():
                    return "PHISHING"
            return "Unable to classify"
    
    # Return the cleaned result that has one of our target labels
    if "PHISHING" in generated_text:
        return "PHISHING"
    elif "LEGITIMATE" in generated_text:
        return "LEGITIMATE"
    
    return generated_text.strip()

# Test on a few examples
num_examples = 5
results = []

print("\nTesting model on examples:")
for i, row in test_df.sample(num_examples, random_state=42).iterrows():
    body = row["body"]
    true_label = "PHISHING" if row["label"] == 1 else "LEGITIMATE"

    # Truncate the email body for display
    display_body = body[:100] + "..." if len(body) > 100 else body

    print(f"\nExample {i}:")
    print(f"Email (truncated): {display_body}")
    print(f"True label: {true_label}")

    # Run model prediction
    prediction = classify_email(body)
    print(f"Model prediction: {prediction}")

    # Store result
    results.append(
        {
            "id": i,
            "true_label": true_label,
            "prediction": prediction,
            "correct": prediction.strip() == true_label,
        }
    )

# Calculate accuracy on this sample
correct = sum(1 for r in results if r["correct"])
accuracy = correct / len(results)
print(f"\nAccuracy on sample: {accuracy:.2%} ({correct}/{len(results)})")

Test set: 19702 examples

Class distribution:
label
1    51.31%
0    48.69%
Name: proportion, dtype: object

Testing model on examples:

Example 7615:
Email (truncated): corporate image can say a lot of things about your company . contemporary rhythm of life is too dyna...
True label: PHISHING




Input length: 161 tokens
Output length: 211 tokens
Generated: 50 new tokens
Raw generated: ''
Full response: 'user
Analyze the <EMAIL> and determine if it's PHISHING or LEGITIMATE.

<EMAIL>
Body: corporate image can say a lot of things about your company . contemporary rhythm of life is too dynamic . sometimes it takes only
several seconds for your company to be remembered or to be lost among competitors .
get your logo , business stationery or website done right now !
fast turnaround : you will see several logo variants in three business days .
satisfaction guaranteed : we provide unlimited amount of changes ; you can be sure : it will meet your needs and fit your business .
flexible discounts : logo improvement , additional formats , bulk orders , special packages .
creative design for competitive price : have a look at it right now !
</EMAIL>
'
Model prediction: PHISHING

Example 18456:
Email (truncated): Verner Kjærsgaard wrote:
> Hi list,
> 
> SuSE10.3, plain vanilla. Linksys 311 