In [42]:
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datasets import Dataset, DatasetDict
import warnings
warnings.filterwarnings('ignore')

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Load Data from All Models

In [43]:
# Define base directory and model names
base_dir = Path("../outputs/caption_inference")
models = ['zero_shot', 'base_random', 'base_vae', 'ft_random', 'ft_vae']
splits = ['train', 'validation', 'test']

# Load all predictions and metrics
predictions = {}
metrics = {}

for model in models:
    predictions[model] = {}
    metrics[model] = {}
    
    for split in splits:
        # Load predictions
        pred_path = base_dir / model / f"{split}_predictions.csv"
        if pred_path.exists():
            predictions[model][split] = pd.read_csv(pred_path)
            print(f"Loaded {model}/{split}: {len(predictions[model][split])} samples")
        
        # Load quality metrics
        metrics_path = base_dir / model / f"{split}_quality_metrics.json"
        if metrics_path.exists():
            with open(metrics_path, 'r') as f:
                metrics[model][split] = json.load(f)

print(f"\nLoaded data for {len(models)} models across {len(splits)} splits")

Loaded zero_shot/train: 1000 samples
Loaded zero_shot/validation: 100 samples
Loaded zero_shot/test: 100 samples
Loaded base_random/train: 1000 samples
Loaded base_random/validation: 100 samples
Loaded base_random/test: 100 samples
Loaded base_vae/train: 1000 samples
Loaded base_vae/validation: 100 samples
Loaded base_vae/test: 100 samples
Loaded ft_random/train: 1000 samples
Loaded ft_random/validation: 100 samples
Loaded ft_random/test: 100 samples
Loaded ft_vae/train: 1000 samples
Loaded ft_vae/validation: 100 samples
Loaded ft_vae/test: 100 samples

Loaded data for 5 models across 3 splits


## 2. Filter Empty Samples

Remove samples with empty or missing captions/predictions

In [44]:
filtered_predictions = {}
filtering_stats = {}

for model in models:
    filtered_predictions[model] = {}
    filtering_stats[model] = {}
    
    for split in splits:
        if split not in predictions[model]:
            continue
            
        df = predictions[model][split].copy()
        original_count = len(df)
        
        pred_col = 'prediction'
        df = df[df[pred_col].notna()]
        df = df[df[pred_col].astype(str).str.strip() != '']
        df = df[df[pred_col].astype(str).str.lower() != 'nan']
        df['prediction_length'] = df[pred_col].astype(str).apply(len)
        
        filtered_predictions[model][split] = df

In [45]:
# Merge all splits for each model and trim to first 500 samples
merged_predictions = {}
for model in models:
    merged_df = pd.DataFrame()
    
    for split in splits:
        if split in filtered_predictions[model]:
            merged_df = pd.concat([merged_df, filtered_predictions[model][split]], ignore_index=True)
    
    merged_predictions[model] = merged_df.head(500)
    print(f"Merged {model}: {len(merged_predictions[model])} samples")

Merged zero_shot: 500 samples
Merged base_random: 500 samples
Merged base_vae: 500 samples
Merged ft_random: 500 samples
Merged ft_vae: 500 samples


In [46]:
display(merged_predictions['base_random'].head())

Unnamed: 0,id,aspect_list,prediction,perplexity,llm_judge_score,llm_judge_reasoning,prediction_length
0,sample_0000,"punchy kick, happy, passionate, scary, eerie, ...",describe this song.\n\nThis up-tempo track is ...,8.472451,9.0,The description accurately incorporates all th...,634
1,sample_0001,"e-bass, fun, hip hop, rhythmic patter",info.\n\nThis track starts off with an infecti...,7.61044,5.0,The description is somewhat accurate but lacks...,631
2,sample_0002,"no percussion, punchy kick, electric guitar, v...",":\n\nIn this hauntingly beautiful composition,...",14.413651,8.0,The description offers a rich and detailed acc...,691
3,sample_0003,"acoustic drums, shimmering shakers, male voice...","only this:\n**Song Description**\n\nIn ""Slow S...",12.282628,7.0,The description is fairly accurate and coheren...,667
4,sample_0004,"keyboard accompaniment, shimmering cymbals, ba...",atmospheric.\n\nWhat an intriguing combination...,9.046556,6.0,The description attempts to incorporate most o...,652


## Analyze overall statistics

In [49]:
# Table: model, split, avg_prediction_length, avg_perplexity, median_perplexity, avg_llm_judge_score, median_llm_judge_score
summary_stats = []

for model, df in merged_predictions.items():
    avg_pred_length = df['prediction_length'].mean()
    avg_perplexity = df['perplexity'].mean()
    median_perplexity = df['perplexity'].median()
    avg_llm_judge_score = df['llm_judge_score'].mean()
    median_llm_judge_score = df['llm_judge_score'].median()
    
    summary_stats.append({
        'model': model,
        'split': split,
        'avg_prediction_length': avg_pred_length,
        'avg_perplexity': avg_perplexity,
        'median_perplexity': median_perplexity,
        'avg_llm_judge_score': avg_llm_judge_score,
        'median_llm_judge_score': median_llm_judge_score
    })

summary_df = pd.DataFrame(summary_stats)
display(summary_df)

Unnamed: 0,model,split,avg_prediction_length,avg_perplexity,median_perplexity,avg_llm_judge_score,median_llm_judge_score
0,zero_shot,test,695.348,9.310202,8.858738,5.306,5.0
1,base_random,test,660.634,11.850457,10.596377,7.111,7.0
2,base_vae,test,666.776,9.975931,9.486908,7.298,7.0
3,ft_random,test,446.91,13.391273,12.269531,5.769,6.0
4,ft_vae,test,385.442,14.550156,13.015625,6.717,7.0


In [51]:
# Create huggingface datasets for further analysis
for model in models:
    df = merged_predictions[model]
    hf_dataset = Dataset.from_pandas(df)
    hf_dataset_dict = DatasetDict({ 'test': hf_dataset })
    hf_dataset_dict.push_to_hub(f"bsienkiewicz/{model}-caption-inference-dataset")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md:   0%|          | 0.00/519 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md:   0%|          | 0.00/519 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md:   0%|          | 0.00/519 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md:   0%|          | 0.00/519 [00:00<?, ?B/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

README.md:   0%|          | 0.00/519 [00:00<?, ?B/s]