# Evaluation Notebook: Image → Music Generation

Quick evaluation of the image- and emotion-conditioned music generation model.

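## What's in here
1. Load the tokenizer and the trained checkpoint
2. Generate samples for each emotion
3. Check diversity
4. Check musicality
5. Compare note density across emotions
6. Plot the training history

Each step visualizes its results.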

In [None]:
import sys
sys.path.insert(0, '../src')

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json

from midi_tokenizer import MIDITokenizer
from model import ConditionalTransformer
from evaluate import calculate_diversity_metrics, calculate_musicality_metrics

# IPython magic: render matplotlib figures inline in the notebook.
# (Only valid inside IPython/Jupyter, not in a plain .py script.)
%matplotlib inline
# Global seaborn style applied to all subsequent plots.
sns.set_style('whitegrid')

## 1. Load tokenizer and checkpoint

In [None]:
# Load the tokenizer vocabulary from the processed-data directory.
vocab_path = '../data/processed/vocab.json'
tokenizer = MIDITokenizer.load_vocab(vocab_path)
print(f"Vocabulary size: {tokenizer.vocab_size}")

# Best checkpoint saved by training. Later cells re-check that this file
# exists and only then read `checkpoint`, so `checkpoint` is left undefined
# when the file is missing.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; the default of its `weights_only` argument differs between
# torch versions — confirm which behaviour this environment gets.
checkpoint_path = '../models/best_model.pt'
have_checkpoint = Path(checkpoint_path).exists()
if not have_checkpoint:
    print("No checkpoint found. Train a model first!")
else:
    # Remap every stored tensor onto the CPU so loading does not need a GPU.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
    print(f"Best validation loss: {checkpoint['best_val_loss']:.4f}")

## 2. Generate samples

In [None]:
if Path(checkpoint_path).exists():
    # Rebuild the model from the hyper-parameters stored in the checkpoint.
    # The vocabulary size comes from the tokenizer loaded in the previous cell.
    config = checkpoint['config']['model']
    model = ConditionalTransformer(
        vocab_size=tokenizer.vocab_size,
        **config
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # The list index is passed to the model as the emotion label id.
    # NOTE(review): this order must match the label ids used in training — confirm.
    emotions = ['happy', 'sad', 'calm', 'angry', 'surprised']
    samples_per_emotion = 5

    # Seed torch's global RNG so the random image embeddings (and, presumably,
    # any torch-based sampling inside model.generate) are reproducible.
    # Without it, every run of the notebook reports different metrics.
    seed = 0
    torch.manual_seed(seed)

    # NOTE(review): 512 is assumed to be the image-embedding width the model
    # was trained with — confirm against the model config.
    image_embed_dim = 512

    all_samples = []

    # Inference only: disable autograd once for the whole generation loop.
    with torch.no_grad():
        for emotion_id, emotion in enumerate(emotions):
            print(f"Generating {emotion} melodies...")
            for _ in range(samples_per_emotion):
                # Random stand-in for a real image embedding (batch of 1).
                image_embed = torch.randn(1, image_embed_dim)
                emotion_label = torch.tensor([emotion_id])

                tokens = model.generate(
                    image_embed,
                    emotion_label,
                    max_length=128,
                    temperature=0.9,
                    top_k=40,
                    bos_token=tokenizer.bos_id,
                    eos_token=tokenizer.eos_id
                )

                # Keep the first (only) sequence of the batch as a plain list.
                all_samples.append({
                    'emotion': emotion,
                    'tokens': tokens[0].tolist()
                })

    print(f"Generated {len(all_samples)} total samples")

## 3. Diversity

In [None]:
if Path(checkpoint_path).exists():
    # Token lists of every generated sample. The musicality cell reads this too.
    token_sequences = [s['tokens'] for s in all_samples]
    diversity = calculate_diversity_metrics(token_sequences)

    print("\nDiversity Metrics:")
    print("-" * 50)
    for key, value in diversity.items():
        print(f"{key}: {value:.4f}")

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    metric_names = list(diversity.keys())
    metric_values = list(diversity.values())

    ax.bar(metric_names, metric_values, color='steelblue')
    ax.set_ylabel('Ratio')
    ax.set_title('Diversity Metrics')
    # The plot expects ratio-valued metrics in [0, 1]. If any value exceeds 1,
    # widen the y-axis instead of silently clipping that bar.
    ax.set_ylim(0, max(1.0, max(metric_values, default=0.0)))
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

## 4. Musicality

In [None]:
if Path(checkpoint_path).exists():
    # Rebuild the token lists from `all_samples` so this cell does not depend
    # on the diversity cell having been run first (it previously raised
    # NameError on `token_sequences` when run on its own).
    token_sequences = [s['tokens'] for s in all_samples]
    musicality = calculate_musicality_metrics(token_sequences, tokenizer)

    print("\nMusicality Metrics:")
    print("-" * 50)
    for key, value in musicality.items():
        print(f"{key}: {value:.2f}")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: mean number of notes per generated sequence.
    axes[0].bar(['Avg Notes'], [musicality['avg_notes_per_sequence']], color='coral')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Average Notes per Sequence')

    # Right panel: mean melodic interval (the axis is labelled in semitones).
    axes[1].bar(['Avg Interval'], [musicality['avg_melodic_interval']], color='lightgreen')
    axes[1].set_ylabel('Semitones')
    axes[1].set_title('Average Melodic Interval')

    plt.tight_layout()
    plt.show()

## 5. Emotion comparison

In [None]:
if Path(checkpoint_path).exists():
    # Musicality metrics per emotion. Emotions with no samples are skipped.
    emotion_stats = {}

    for emotion in emotions:
        emotion_samples = [s['tokens'] for s in all_samples if s['emotion'] == emotion]

        if emotion_samples:
            # Stored directly rather than via a `metrics` temporary, which
            # clobbered the diversity cell's global of the same name.
            emotion_stats[emotion] = calculate_musicality_metrics(emotion_samples, tokenizer)

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    emotions_list = list(emotion_stats.keys())
    note_counts = [emotion_stats[e]['avg_notes_per_sequence'] for e in emotions_list]

    # Colours are keyed by emotion name rather than by position. A skipped
    # emotion therefore no longer shifts every later bar onto the wrong
    # colour. Emotions without an entry fall back to grey.
    emotion_colors = {
        'happy': 'gold',
        'sad': 'steelblue',
        'calm': 'lightgreen',
        'angry': 'tomato',
        'surprised': 'purple',
    }
    colors = [emotion_colors.get(e, 'gray') for e in emotions_list]
    ax.bar(emotions_list, note_counts, color=colors)
    ax.set_ylabel('Average Notes')
    ax.set_title('Note Density by Emotion')
    ax.set_xlabel('Emotion')

    plt.tight_layout()
    plt.show()

## 6. Training history

In [None]:
if Path(checkpoint_path).exists() and 'history' in checkpoint:
    history = checkpoint['history']

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Each series gets its own 1-based x positions. Previously every series
    # was plotted against len(train_loss), so a series of a different length
    # made matplotlib raise a length-mismatch ValueError.
    # NOTE(review): x positions assume one entry per epoch for every series —
    # confirm the logging cadence of val_loss and learning_rates.
    train_loss = history['train_loss']
    ax1.plot(range(1, len(train_loss) + 1), train_loss, label='Train Loss', marker='o')

    # Optional series: skipped instead of raising KeyError when not recorded.
    val_loss = history.get('val_loss')
    if val_loss is not None:
        ax1.plot(range(1, len(val_loss) + 1), val_loss, label='Val Loss', marker='s')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True)

    learning_rates = history.get('learning_rates')
    if learning_rates is not None:
        ax2.plot(range(1, len(learning_rates) + 1), learning_rates, marker='o', color='green')
        ax2.set_title('Learning Rate Schedule')
    else:
        ax2.set_title('Learning Rate Schedule (not recorded)')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Learning Rate')
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

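## Summary

This notebook runs a quick evaluation:
- Load the trained model checkpoint and tokenizer
- Generate music for each emotion (conditioned on random image embeddings, not real images)
- Check diversity (n-gram uniqueness)
- Check musicality (note density, intervals)
- Compare note density across emotions
- Plot the training history (loss and learning rate)
- Visualize results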