In [17]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Read the model-test results produced by an earlier run.
df = pd.read_csv('./results/df_model_tests.csv')

print(f"Df before dropping [] values: \n{df.shape}")

# Keep only rows whose "generated_midi_data" is not the literal string '[]'.
df = df.loc[df['generated_midi_data'].ne('[]')]

print(f"Df after dropping [] values: \n{df.shape}")

# Discard the columns that play no part in the analysis below.
df = df.drop(columns=['original_midi_data', 'generated_midi_data', 'individual'])

# Names of the parameter columns examined by the analysis.
param_columns = [
    'threshold',
    'fan_out',
    'max_distance_atan',
    'onset_threshold',
    'frame_threshold',
    'max_key_distance',
]

#function to perform analysis
def analyze_data(data, title_prefix):
    """Print summary statistics for *data* and save a set of diagnostic plots.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a 'dissimilarity' column and every column listed in the
        module-level ``param_columns``.
    title_prefix : str
        Used in printed headings and plot titles. Lower-cased, with spaces
        replaced by underscores, it also prefixes every image file name.

    Returns
    -------
    None
        Results are printed and images are written to ``./results_images/``.
    """
    import os  # local import: only needed to create the output directory

    output_dir = './results_images'
    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)
    file_prefix = title_prefix.lower().replace(" ", "_")

    def save_and_close(suffix):
        # bbox_inches='tight' keeps titles placed slightly above the figure
        # (suptitle y=1.02) inside the saved image instead of clipping them.
        plt.savefig(f'{output_dir}/{file_prefix}_{suffix}.png', bbox_inches='tight')
        plt.close()

    #1) Basic statistics and information
    print(f"{title_prefix} - Basic Statistics:")
    print(data.describe())
    # info() prints its report itself and returns None; wrapping it in print()
    # would emit a stray "None" line.
    data.info()

    #2) Find the row with the lowest dissimilarity (first one on ties)
    best_params = data.loc[data['dissimilarity'].idxmin()]
    print(f"{title_prefix} - Best parameters:")
    print(best_params)

    #3) Correlation analysis
    correlation_matrix = data.corr()
    plt.figure(figsize=(12, 10))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm')
    plt.title(f'{title_prefix} - Correlation Heatmap')
    save_and_close('correlation_heatmap')

    #4) Pairplot
    sns.pairplot(data, vars=['dissimilarity'] + param_columns)
    plt.suptitle(f'{title_prefix} - Pairplot', y=1.02)
    save_and_close('pairplot')

    #5) Feature importance using PCA
    scaler = StandardScaler()
    scaled_data = scaler.fit_transform(data[param_columns])
    pca = PCA()
    pca.fit(scaled_data)

    # explained_variance_ratio_ holds one value per principal COMPONENT, not
    # per input feature, so it cannot be paired with param_columns directly.
    # components_ has shape (n_components, n_features); weighting each
    # feature's absolute loading by the variance ratio of its component gives
    # one score per feature.
    # NOTE: PCA is unsupervised, so this reflects contribution to the variance
    # of the parameters, not the strength of their link to dissimilarity.
    loadings = abs(pca.components_)
    importance_scores = loadings.T @ pca.explained_variance_ratio_
    total_score = importance_scores.sum()
    if total_score > 0:
        # Normalise so that the scores sum to 1.
        importance_scores = importance_scores / total_score

    feature_importance = pd.DataFrame({
        'feature': param_columns,
        'importance': importance_scores
    })
    feature_importance = feature_importance.sort_values('importance', ascending=False)

    plt.figure(figsize=(10, 6))
    sns.barplot(x='importance', y='feature', data=feature_importance)
    plt.title(f'{title_prefix} - Feature Importance')
    save_and_close('feature_importance')

    #6) Scatter plots for top 3 important features vs dissimilarity
    top_features = feature_importance['feature'].head(3).tolist()
    _, axes = plt.subplots(1, 3, figsize=(20, 6))
    for i, feature in enumerate(top_features):
        sns.scatterplot(x=feature, y='dissimilarity', data=data, ax=axes[i])
        axes[i].set_title(f'{feature} vs Dissimilarity')
    plt.suptitle(f'{title_prefix} - Top Features Scatter Plots', y=1.02)
    plt.tight_layout()
    save_and_close('top_features_scatter')

    #7) 3D scatter plot for top 3 features, coloured by dissimilarity
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(data[top_features[0]], data[top_features[1]], data[top_features[2]],
                         c=data['dissimilarity'], cmap='viridis')
    ax.set_xlabel(top_features[0])
    ax.set_ylabel(top_features[1])
    ax.set_zlabel(top_features[2])
    plt.colorbar(scatter, label='Dissimilarity')
    plt.title(f'{title_prefix} - 3D Scatter Plot of Top 3 Features')
    save_and_close('3d_scatter')

    #8) Histogram of dissimilarity values
    plt.figure(figsize=(10, 6))
    sns.histplot(data['dissimilarity'], kde=True)
    plt.title(f'{title_prefix} - Distribution of Dissimilarity Values')
    plt.xlabel('Dissimilarity')
    save_and_close('dissimilarity_distribution')

# Run the analysis on every individual (non-aggregated) row.
print("Analysis on full dataset:")
analyze_data(df, "Full Data")

# Collapse rows that share a parameter combination into the mean and the
# standard deviation of their dissimilarity.
grouped_df = (
    df.groupby(param_columns)['dissimilarity']
    .agg(dissimilarity_mean='mean', dissimilarity_std='std')
    .reset_index()
)

# Repeat the analysis on the per-combination mean, exposed under the
# 'dissimilarity' column name that analyze_data reads.
print("\nAnalysis on grouped dataset (mean dissimilarity):")
analyze_data(
    grouped_df[param_columns + ['dissimilarity_mean']].rename(
        columns={'dissimilarity_mean': 'dissimilarity'}
    ),
    "Grouped Data",
)

Df before dropping [] values: 
(960, 10)
Df after dropping [] values: 
(951, 10)
Analysis on full dataset:
Full Data - Basic Statistics:
       dissimilarity   threshold     fan_out  max_distance_atan  \
count     951.000000  951.000000  951.000000         951.000000   
mean      308.313720    2.501052   14.994742          95.005258   
std       384.894960    0.408462    5.002628           5.002628   
min         0.252577    2.000000   10.000000          90.000000   
25%        52.900602    2.000000   10.000000          90.000000   
50%       102.923642    2.500000   10.000000         100.000000   
75%       381.686112    3.000000   20.000000         100.000000   
max      1000.000000    3.000000   20.000000         100.000000   

       onset_threshold  frame_threshold  max_key_distance  
count       951.000000       951.000000        951.000000  
mean          0.499685         0.499054         29.989485  
std           0.100052         0.100048         10.005256  
min           0.400