In [None]:
import ast
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
from scipy.stats import ks_2samp, mannwhitneyu, wasserstein_distance, norm
from pathlib import Path

# Reproducibility: pin NumPy's legacy global RNG in case any later step draws
# random numbers.
seed = 42
np.random.seed(seed)

# Plot appearance shared by every figure in this notebook.
sns.set_style(style="whitegrid")
plt.rcParams.update({'figure.figsize': (14, 8)})


# Training Dataset vs VAE-Generated Tags Comparison

This notebook compares the original training dataset with the VAE-generated tags dataset to assess generation quality, i.e. how well the VAE preserves the statistical properties of the original data: the number of tempo, genre, mood, and instrument tags per row, and the correlations between those counts.


## Load and Prepare Datasets


In [None]:

# Load the training dataset with a clean 0..n-1 index; missing cells become
# empty strings so the tag parser below only ever sees strings.
df_train = (
    pd.read_csv("../data/musiccaps_tags_to_description_dataset.csv")
    .reset_index(drop=True)
    .fillna('')
)


def _split_tag_string(value):
    """Split a ", "-separated tag string into a list; '' or non-strings give []."""
    if not isinstance(value, str) or not value:
        return []
    return value.split(', ')


# The training CSV stores each tag column as a ", "-joined string.
for tag_column in ('aspect_list', 'instrument_tags', 'genre_tags', 'mood_tags', 'tempo_tags'):
    df_train[tag_column] = df_train[tag_column].apply(_split_tag_string)

print("Training Dataset Loaded")
print(f"  Shape: {df_train.shape}")
print(f"  Sample row:")
print(df_train.iloc[0])


In [None]:

# Load VAE-generated dataset (one row per generated sample; the `temperature`
# column is presumably the sampling temperature used for that row).
df_vae = pd.read_csv("../data/vae_mtg_tags/all.csv")
df_vae = df_vae.reset_index(drop=True)
df_vae = df_vae.fillna('')


def _parse_list_literal(value):
    """Parse a Python literal string (expected to be a list repr) into an object.

    Empty or whitespace-only strings and non-string values yield ``[]``.
    The guard matters because ``fillna('')`` above turns missing cells into
    ``''``, and ``ast.literal_eval('')`` raises ``SyntaxError``, which would
    abort the whole load.
    """
    if isinstance(value, str) and value.strip():
        # literal_eval only evaluates Python literals, never arbitrary code.
        return ast.literal_eval(value)
    return []


# Unlike the training CSV (", "-joined strings), these columns are parsed with
# ast.literal_eval, i.e. they presumably hold list reprs such as "['a', 'b']".
for list_column in ['aspect_list', 'generated_tempo_tags', 'generated_genre_tags',
                    'generated_mood_tags', 'generated_instrument_tags']:
    df_vae[list_column] = df_vae[list_column].apply(_parse_list_literal)

print("\nVAE-Generated Dataset Loaded")
print(f"  Shape: {df_vae.shape}")
print(f"  Unique temperatures: {df_vae['temperature'].unique()}")
print("  Sample row:")
print(df_vae.iloc[0])


In [None]:

# For easier comparison, focus on a single temperature (1.0 is the baseline).
BASELINE_TEMPERATURE = 1.0

# Match with a tolerance rather than exact float equality. Coerce to numeric
# first because `fillna('')` may have left '' in the column; those become NaN,
# and np.isclose treats NaN as "not close", so such rows are simply not selected.
vae_temperatures = pd.to_numeric(df_vae['temperature'], errors='coerce')
df_vae_baseline = df_vae[np.isclose(vae_temperatures, BASELINE_TEMPERATURE)].reset_index(drop=True)

# Every later cell compares against this subset; fail loudly here rather than
# producing NaN statistics and empty plots when the temperature is absent.
assert len(df_vae_baseline) > 0, (
    f"No VAE rows at temperature={BASELINE_TEMPERATURE}; "
    f"available temperatures: {sorted(vae_temperatures.dropna().unique())}"
)

print(f"\nUsing VAE dataset at temperature={BASELINE_TEMPERATURE}")
print(f"  Shape: {df_vae_baseline.shape}")


## Compute Tag Distribution Statistics


In [None]:

# Each tag category maps to its (training column, VAE column) pair.
tag_categories = {
    'Tempo': ('tempo_tags', 'generated_tempo_tags'),
    'Genre': ('genre_tags', 'generated_genre_tags'),
    'Mood': ('mood_tags', 'generated_mood_tags'),
    'Instrument': ('instrument_tags', 'generated_instrument_tags')
}


def _describe_tag_counts(counts, category, dataset):
    """Summarise a Series of per-row tag counts as one row of the stats table."""
    return {
        'Category': category,
        'Dataset': dataset,
        'Mean': counts.mean(),
        'Median': counts.median(),
        'Std': counts.std(),
        'Min': counts.min(),
        'Max': counts.max(),
        'Q25': counts.quantile(0.25),
        'Q75': counts.quantile(0.75),
        'Count': len(counts),
    }


# Two rows per category, Training first then VAE-Generated, in category order.
stats_data = []
for cat_name, (train_col, vae_col) in tag_categories.items():
    train_counts = df_train[train_col].apply(len)
    vae_counts = df_vae_baseline[vae_col].apply(len)
    stats_data.extend([
        _describe_tag_counts(train_counts, cat_name, 'Training'),
        _describe_tag_counts(vae_counts, cat_name, 'VAE-Generated'),
    ])

stats_df = pd.DataFrame(stats_data)

banner = "=" * 100
print("\n" + banner)
print("TAG COUNT STATISTICS COMPARISON")
print(banner)
print(stats_df.to_string(index=False))


## Compare Category-wise Tag Counts


In [None]:

# Overlaid histograms of tags-per-row for each category.
# The two datasets need not have the same number of rows, so raw frequencies are
# not comparable; plot proportions instead (density=True). With unit-width bins,
# each bar height is the fraction of rows having exactly that many tags.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for idx, (cat_name, (train_col, vae_col)) in enumerate(tag_categories.items()):
    ax = axes[idx]

    train_counts = df_train[train_col].apply(len)
    vae_counts = df_vae_baseline[vae_col].apply(len)

    # Edges -0.5, 0.5, ..., max+0.5: one unit-width bin centred on each integer count.
    max_count = max(train_counts.max(), vae_counts.max())
    bins = np.arange(-0.5, max_count + 1.5)

    ax.hist(train_counts, bins=bins, density=True, alpha=0.6, label='Training',
            color='blue', edgecolor='black')
    ax.hist(vae_counts, bins=bins, density=True, alpha=0.6, label='VAE-Generated',
            color='orange', edgecolor='black')

    ax.set_xlabel('Number of Tags')
    ax.set_ylabel('Proportion of Rows')
    ax.set_title(f'{cat_name} Tags Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Distribution comparison plots created successfully")


In [None]:

# Side-by-side box plots of tags-per-row, one panel per tag category.
n_categories = len(tag_categories)
fig, axes = plt.subplots(1, n_categories, figsize=(4 * n_categories, 5), squeeze=False)
axes = axes[0]
box_colors = ['lightblue', 'lightsalmon']

for idx, (cat_name, (train_col, vae_col)) in enumerate(tag_categories.items()):
    ax = axes[idx]

    train_counts = df_train[train_col].apply(len)
    vae_counts = df_vae_baseline[vae_col].apply(len)

    # Tick labels are set explicitly: boxplot's `labels=` keyword is deprecated
    # (renamed `tick_labels=` in Matplotlib 3.9), and this form works on every
    # version. Boxes sit at the default positions 1 and 2.
    bp = ax.boxplot([train_counts, vae_counts], patch_artist=True)
    ax.set_xticks([1, 2])
    ax.set_xticklabels(['Training', 'VAE-Generated'])

    for patch, color in zip(bp['boxes'], box_colors):
        patch.set_facecolor(color)

    ax.set_ylabel('Number of Tags')
    ax.set_title(f'{cat_name} Tags')
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("Box plot comparison created successfully")


## Analyze Tag Correlation Matrices


In [None]:

# Per-row tag counts, one column per category, so .corr() (pandas default:
# Pearson) shows whether rows with more tags in one category tend to have more
# in another. A category whose counts are constant yields NaN correlations.
category_columns = {
    'Tempo': 'tempo_tags',
    'Genre': 'genre_tags',
    'Mood': 'mood_tags',
    'Instrument': 'instrument_tags',
}
train_tag_counts = pd.DataFrame(
    {label: df_train[column].apply(len) for label, column in category_columns.items()}
)
# VAE columns share the training names with a "generated_" prefix.
vae_tag_counts = pd.DataFrame(
    {label: df_vae_baseline[f'generated_{column}'].apply(len)
     for label, column in category_columns.items()}
)

train_corr = train_tag_counts.corr()
vae_corr = vae_tag_counts.corr()

for heading, matrix in (
    ("CORRELATION MATRIX - TRAINING DATASET", train_corr),
    ("CORRELATION MATRIX - VAE-GENERATED DATASET", vae_corr),
    ("CORRELATION DIFFERENCE (VAE - Training)", vae_corr - train_corr),
):
    print("\n" + "=" * 100)
    print(heading)
    print("=" * 100)
    print(matrix.round(3))


## Visualize Distribution Differences


In [None]:

# Three heatmaps: training correlations, VAE correlations, and their difference.
corr_delta = vae_corr - train_corr
heatmap_panels = [
    (train_corr, 'coolwarm', -1, 1, 'Correlation', 'Training Dataset: Tag Count Correlations'),
    (vae_corr, 'coolwarm', -1, 1, 'Correlation', 'VAE-Generated Dataset: Tag Count Correlations'),
    # Narrower colour range for the difference so small deviations stay visible
    # (values beyond +/-0.5 saturate).
    (corr_delta, 'RdBu_r', -0.5, 0.5, 'Difference', 'Correlation Difference (VAE - Training)'),
]

fig, axes = plt.subplots(1, 3, figsize=(16, 5))
for ax, (matrix, cmap, lo, hi, cbar_label, title) in zip(axes, heatmap_panels):
    sns.heatmap(matrix, annot=True, fmt='.3f', cmap=cmap, vmin=lo, vmax=hi,
                ax=ax, cbar_kws={'label': cbar_label})
    ax.set_title(title)

plt.tight_layout()
plt.show()

print("Correlation heatmaps created successfully")


In [None]:

# Violin plots of tags-per-row (KDE-smoothed shape plus mean and median lines).
# Tag counts are integers, so the smoothed shape also spans values that cannot occur.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for idx, (cat_name, (train_col, vae_col)) in enumerate(tag_categories.items()):
    ax = axes[idx]

    train_counts = df_train[train_col].apply(len)
    vae_counts = df_vae_baseline[vae_col].apply(len)

    parts = ax.violinplot([train_counts, vae_counts], positions=[1, 2],
                          showmeans=True, showmedians=True)
    # Mean and median lines share a default style; colour the means so the two
    # can be told apart, and label both in a legend.
    parts['cmeans'].set_color('red')
    ax.legend([parts['cmeans'], parts['cmedians']], ['Mean', 'Median'], loc='upper right')

    ax.set_xticks([1, 2])
    ax.set_xticklabels(['Training', 'VAE-Generated'])
    ax.set_ylabel('Number of Tags')
    ax.set_title(f'{cat_name} Tags Distribution')
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("Violin plot comparison created successfully")


## Statistical Tests for Distribution Differences


In [None]:

# Two-sample tests on tags-per-row for each category.
# NOTE: tag counts are small integers with many ties; ks_2samp assumes continuous
# distributions, so on discrete data its p-values tend to be conservative.
test_results = []

for cat_name, (train_col, vae_col) in tag_categories.items():
    train_counts = df_train[train_col].apply(len).values
    vae_counts = df_vae_baseline[vae_col].apply(len).values

    # Kolmogorov-Smirnov test
    ks_stat, ks_pvalue = ks_2samp(train_counts, vae_counts)

    # Mann-Whitney U test (non-parametric); two-sided made explicit.
    mw_stat, mw_pvalue = mannwhitneyu(train_counts, vae_counts, alternative='two-sided')

    # Wasserstein distance, in units of "tags per row"
    ws_dist = wasserstein_distance(train_counts, vae_counts)

    # Keep values numeric so the table can be reused; formatting happens at print time.
    test_results.append({
        'Category': cat_name,
        'KS Statistic': ks_stat,
        'KS P-value': ks_pvalue,
        'MW Statistic': mw_stat,
        'MW P-value': mw_pvalue,
        'Wasserstein Distance': ws_dist,
    })

test_results_df = pd.DataFrame(test_results)

# '.3g' keeps very small p-values readable (e.g. 1.2e-45) instead of 0.000000.
pvalue_format = '{:.3g}'.format
display_formatters = {
    'KS Statistic': '{:.4f}'.format,
    'KS P-value': pvalue_format,
    'MW Statistic': '{:.1f}'.format,
    'MW P-value': pvalue_format,
    'Wasserstein Distance': '{:.4f}'.format,
}

print("\n" + "="*120)
print("STATISTICAL TESTS FOR DISTRIBUTION DIFFERENCES")
print("="*120)
print("KS Test: Kolmogorov-Smirnov test (null: distributions are identical)")
print("MW Test: Mann-Whitney U test (null: both samples come from the same distribution;")
print("         most sensitive to one sample tending to have larger values)")
print("Wasserstein Distance: Optimal transport distance between distributions (tags per row)")
print("="*120)
print(test_results_df.to_string(index=False, formatters=display_formatters))
print("="*120)
print("\nInterpretation:")
print("  - Low p-values (<0.05) indicate significant differences between distributions")
print("  - A p-value >= 0.05 means no difference was detected, not that the distributions are equivalent")
print("  - With large samples even tiny differences become significant; use the KS statistic")
print("    and Wasserstein distance as effect sizes")
print("  - Lower Wasserstein distance indicates more similar distributions")


In [None]:

# Bar chart of per-category Wasserstein distances between the tags-per-row
# distributions of the training data and the baseline VAE subset.
fig, ax = plt.subplots(figsize=(10, 6))

categories_list = list(tag_categories)
ws_distances = [
    wasserstein_distance(df_train[train_col].apply(len).values,
                         df_vae_baseline[vae_col].apply(len).values)
    for train_col, vae_col in tag_categories.values()
]

bar_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
bars = ax.bar(categories_list, ws_distances, color=bar_colors)
ax.set_ylabel('Wasserstein Distance', fontsize=12)
ax.set_xlabel('Tag Category', fontsize=12)
ax.set_title('Distribution Distance: Training vs VAE-Generated (Temperature=1.0)', fontsize=14)
ax.grid(True, alpha=0.3, axis='y')

# Annotate each bar with its value, centred just above the bar top.
for bar, distance in zip(bars, ws_distances):
    x_center = bar.get_x() + bar.get_width() / 2.
    ax.text(x_center, bar.get_height(), f'{distance:.4f}',
            ha='center', va='bottom', fontsize=11)

plt.tight_layout()
plt.show()

print("Wasserstein distance visualization created successfully")


## Generate Comparison Report


In [None]:

# Text report comparing the training data with the baseline-temperature VAE subset.
ALPHA = 0.05  # significance level for the KS test below
banner = "=" * 120

print("\n" + banner)
print("COMPREHENSIVE VAE GENERATION FIDELITY REPORT")
print(banner)

print("\n### DATASET OVERVIEW ###")
print(f"Training Dataset: {len(df_train)} samples")
print(f"VAE-Generated Dataset (T=1.0): {len(df_vae_baseline)} samples")

print("\n### MEAN TAG COUNT COMPARISON ###")
for cat_name, (train_col, vae_col) in tag_categories.items():
    train_mean = df_train[train_col].apply(len).mean()
    vae_mean = df_vae_baseline[vae_col].apply(len).mean()
    diff = vae_mean - train_mean
    # A relative difference is undefined when the training mean is 0; say so
    # instead of printing a misleading "+0.0%".
    pct_text = f"{diff / train_mean * 100:+.1f}%" if train_mean > 0 else "n/a"

    print(f"\n{cat_name}:")
    print(f"  Training Mean: {train_mean:.3f}")
    print(f"  VAE Mean:      {vae_mean:.3f}")
    print(f"  Difference:    {diff:+.3f} ({pct_text})")

print("\n### CORRELATION PRESERVATION ###")
# Walk the upper triangle of the correlation matrices (each category pair once).
correlation_diffs = []
corr_labels = train_corr.columns
for i in range(len(corr_labels)):
    for j in range(i + 1, len(corr_labels)):
        train_r = train_corr.iloc[i, j]
        vae_r = vae_corr.iloc[i, j]
        diff = abs(vae_r - train_r)
        correlation_diffs.append(diff)
        print(f"{corr_labels[i]} vs {corr_labels[j]}: Training={train_r:.3f}, VAE={vae_r:.3f}, Diff={diff:.3f}")

# A category with constant counts has NaN correlations; average only the defined
# differences so a single NaN does not turn the whole summary into NaN.
valid_diffs = [d for d in correlation_diffs if not np.isnan(d)]
avg_corr_diff = float(np.mean(valid_diffs)) if valid_diffs else float('nan')
print(f"\nAverage Correlation Difference: {avg_corr_diff:.4f} "
      f"(over {len(valid_diffs)} of {len(correlation_diffs)} category pairs)")

print("\n### DISTRIBUTION SIMILARITY ASSESSMENT ###")
for cat_name, (train_col, vae_col) in tag_categories.items():
    train_counts = df_train[train_col].apply(len).values
    vae_counts = df_vae_baseline[vae_col].apply(len).values
    ws_dist = wasserstein_distance(train_counts, vae_counts)
    ks_stat, ks_pvalue = ks_2samp(train_counts, vae_counts)

    print(f"\n{cat_name}:")
    print(f"  Wasserstein Distance: {ws_dist:.4f}")
    print(f"  KS Statistic: {ks_stat:.4f}, P-value: {ks_pvalue:.6f}")

    # Failing to reject the null does not show the distributions are similar,
    # and p == ALPHA belongs to the "significant" branch.
    if ks_pvalue > ALPHA:
        print(f"  ✓ No significant difference detected (p > {ALPHA})")
    else:
        print(f"  ✗ Distributions differ significantly (p ≤ {ALPHA})")

print("\n" + banner)
print("CONCLUSION")
print(banner)
print("\nThe VAE-generated dataset shows how well the model preserves the statistical")
print("properties of the original training dataset. Lower Wasserstein distances and")
print("lower KS statistics indicate better preservation of distribution characteristics.")
print("A KS p-value above the significance level only means no difference was detected;")
print("it does not establish equivalence, and p-values shrink as sample sizes grow.")
print(banner)


In [None]:

# How does sampling temperature change the mean number of generated tags per category?
print("\n" + "="*120)
print("TEMPERATURE SENSITIVITY ANALYSIS")
print("="*120)

temperatures_in_data = sorted(df_vae['temperature'].unique())
temp_stats = []

for temp in temperatures_in_data:
    df_temp = df_vae[df_vae['temperature'] == temp].reset_index(drop=True)

    temp_row = {'Temperature': temp}
    for cat_name, (_, vae_col) in tag_categories.items():
        temp_row[f'{cat_name}_Mean'] = df_temp[vae_col].apply(len).mean()

    temp_stats.append(temp_row)

# Fixed column list so the table (and the plots below) still work if no
# temperatures are present.
mean_columns = [f'{cat_name}_Mean' for cat_name in tag_categories]
temp_stats_df = pd.DataFrame(temp_stats, columns=['Temperature'] + mean_columns)
print("\nMean Tag Counts Across Different Temperatures:")
print(temp_stats_df.to_string(index=False))

# Plot straight from the table above instead of re-filtering df_vae for every
# (category, temperature) pair.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for idx, (cat_name, (train_col, _)) in enumerate(tag_categories.items()):
    ax = axes[idx]

    # Training mean is drawn as a horizontal reference line.
    train_mean = df_train[train_col].apply(len).mean()

    ax.plot(temp_stats_df['Temperature'], temp_stats_df[f'{cat_name}_Mean'],
            marker='o', linewidth=2, markersize=8, label='VAE Generated')
    ax.axhline(y=train_mean, color='r', linestyle='--', linewidth=2, label='Training Baseline')

    ax.set_xlabel('Temperature', fontsize=11)
    ax.set_ylabel('Mean Tag Count', fontsize=11)
    ax.set_title(f'{cat_name} Tags: Temperature Sensitivity', fontsize=12)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nTemperature sensitivity analysis completed")
