# Model Analysis for NeMo QA Chatbot

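This notebook evaluates the fine-tuned NeMo QA Chatbot model on held-out test data. It reports exact-match, F1, and relevance scores, shows the best and worst predictions, tries the model on a custom question, and breaks performance down by question type.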

In [None]:
import json
import os
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

# Add parent directory to path
sys.path.append(os.path.abspath('..'))

# Import NeMo QA modules
from nemo_qa.modeling.model import load_model
from nemo_qa.modeling.evaluation import evaluate_model, compute_exact_match, compute_f1_score

## Load Model

Let's load the fine-tuned model.

In [None]:
# Replace with your model path
model_path = '../models/finetuned/final_model'

# Load model
model = load_model(model_path)

print(f"Model loaded from {model_path}")

## Load Test Data

Let's load the test data.

In [None]:
# Replace with your test data path
test_data_path = '../data/datasets/test.jsonl'

# Load test data
test_data = []
with open(test_data_path, 'r') as f:
    for line in f:
        test_data.append(json.loads(line.strip()))

print(f"Loaded {len(test_data)} test samples")

## Evaluate Model

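Now let's evaluate the model on the first `max_samples` test examples (kept small so this step runs quickly). Per-sample results are written to `../evaluation_results.json` for the analysis below.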

In [None]:
# Limit number of samples for faster evaluation
max_samples = 10
eval_data = test_data[:max_samples]

# Evaluate model
metrics = evaluate_model(
    model=model.model,
    test_data=eval_data,
    output_path='../evaluation_results.json',
    temperature=0.0
)

# Print metrics
print(f"Exact Match: {metrics['exact_match']:.4f}")
print(f"F1 Score: {metrics['f1_score']:.4f}")
print(f"Relevance Score: {metrics['relevance_score']:.4f}")

## Analyze Model Outputs

Let's load the per-sample results saved by the evaluation step and look at how the exact-match, F1, and relevance scores are distributed.

In [None]:
# Load evaluation results
with open('../evaluation_results.json', 'r') as f:
    eval_results = json.load(f)

# Convert results to DataFrame
results_df = pd.DataFrame(eval_results['results'])

# Add length columns
results_df['question_length'] = results_df['question'].apply(len)
results_df['reference_length'] = results_df['reference'].apply(len)
results_df['prediction_length'] = results_df['prediction'].apply(len)

# Plot score distributions
fig, ax = plt.subplots(1, 3, figsize=(18, 5))

sns.histplot(results_df['exact_match'], bins=2, ax=ax[0])
ax[0].set_title('Exact Match Distribution')
ax[0].set_xlabel('Exact Match')
ax[0].set_ylabel('Count')

sns.histplot(results_df['f1_score'], bins=10, ax=ax[1])
ax[1].set_title('F1 Score Distribution')
ax[1].set_xlabel('F1 Score')
ax[1].set_ylabel('Count')

sns.histplot(results_df['relevance_score'], bins=10, ax=ax[2])
ax[2].set_title('Relevance Score Distribution')
ax[2].set_xlabel('Relevance Score')
ax[2].set_ylabel('Count')

plt.tight_layout()
plt.show()

## Example Predictions

Let's look at the three best and three worst predictions, ranked by F1 score.

In [None]:

def show_predictions(title, rows):
    """Print a titled listing of question, reference, prediction and F1 for each row.

    Args:
        title: Heading printed before the rows.
        rows: DataFrame with 'question', 'reference', 'prediction' and 'f1_score' columns.
    """
    print(title)
    for _, row in rows.iterrows():
        print(f"Question: {row['question']}")
        print(f"Reference: {row['reference']}")
        print(f"Prediction: {row['prediction']}")
        print(f"F1 Score: {row['f1_score']:.4f}")
        print("---\n")


n_examples = 3

# Rank into a new frame so `results_df` keeps its original row order for later cells.
ranked_df = results_df.sort_values('f1_score', ascending=False)

show_predictions("Best predictions:", ranked_df.head(n_examples))
# Reverse the tail so the single worst prediction is listed first.
show_predictions("Worst predictions:", ranked_df.tail(n_examples).iloc[::-1])

## Interactive Testing

Let's test the model interactively.

In [None]:
def generate_response(question, temperature=0.7, max_length=512, top_p=0.9, top_k=50, qa_model=None):
    """Generate an answer to `question` with the fine-tuned model.

    Args:
        question: User question, wrapped in the ``Human: ... Assistant:`` prompt template.
        temperature: Sampling temperature forwarded to ``generate``.
        max_length: Forwarded as ``max_length``; whether it includes prompt tokens
            depends on the model's ``generate`` implementation (not visible here).
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        top_k: Top-k sampling cutoff forwarded to ``generate``.
        qa_model: Object exposing ``generate(prompt, **kwargs)`` that returns a mapping
            with a ``'text'`` list. Defaults to the notebook-global ``model``.

    Returns:
        The first generated text, with a leading echo of the prompt removed (if present)
        and surrounding whitespace stripped.
    """
    if qa_model is None:
        qa_model = model
    prompt = f"Human: {question}\nAssistant:"

    # No gradients are needed for generation.
    with torch.inference_mode():
        output = qa_model.generate(
            prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    text = output['text'][0]
    # Strip only a leading echo of the prompt. str.replace would also delete any later
    # occurrence of the prompt inside the generated answer itself.
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()

# Try the model on a hand-written question.
question = "What is LoRA fine-tuning?"
response = generate_response(question)

print(f"Question: {question}\nResponse: {response}")

## Analyze Model Performance by Question Type

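Let's group questions by their leading word (What/How/Why/When/Where/Who/Which, or Yes/No for questions opening with an auxiliary such as *is*, *are*, *do* or *can*) and compare the average scores per group.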

In [None]:
# Categorize questions by starting word
def get_question_type(question):
    """Get question type based on starting word."""
    question = question.strip().lower()
    
    if question.startswith('what'):
        return 'What'
    elif question.startswith('how'):
        return 'How'
    elif question.startswith('why'):
        return 'Why'
    elif question.startswith('when'):
        return 'When'
    elif question.startswith('where'):
        return 'Where'
    elif question.startswith('who'):
        return 'Who'
    elif question.startswith('which'):
        return 'Which'
    elif question.startswith('can') or question.startswith('do') or question.startswith('is') or question.startswith('are'):
        return 'Yes/No'
    else:
        return 'Other'

# Tag each evaluated sample with its question type.
results_df['question_type'] = results_df['question'].apply(get_question_type)

# Mean score and sample count per question type, most common types first.
question_type_metrics = (
    results_df.groupby('question_type')
    .agg(
        exact_match=('exact_match', 'mean'),
        f1_score=('f1_score', 'mean'),
        relevance_score=('relevance_score', 'mean'),
        count=('question', 'count'),
    )
    .reset_index()
    .sort_values('count', ascending=False)
)

# Grouped bar chart: one cluster per question type, one bar per metric.
metric_columns = [
    ('exact_match', 'Exact Match'),
    ('f1_score', 'F1 Score'),
    ('relevance_score', 'Relevance Score'),
]
x = np.arange(len(question_type_metrics))
width = 0.25

fig, ax = plt.subplots(figsize=(14, 6))
for bar_index, (metric_column, metric_label) in enumerate(metric_columns):
    # Centre each cluster on its tick: offsets -width, 0, +width for three metrics.
    offset = (bar_index - (len(metric_columns) - 1) / 2) * width
    ax.bar(x + offset, question_type_metrics[metric_column], width, label=metric_label)

ax.set_xticks(x)
ax.set_xticklabels(question_type_metrics['question_type'])
ax.set_ylabel('Score')
ax.set_title('Model Performance by Question Type')
ax.legend()

plt.tight_layout()
plt.show()

# Print question type statistics
print(question_type_metrics.to_string(index=False))