In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# Read the results table from disk.
df = pd.read_csv('./results/df_model_test_3600_test.csv')

# Keep only the rows whose 'generated_midi_data' cell is not the literal
# string '[]', i.e. discard rows recorded with an empty list.
has_generated_midi = df['generated_midi_data'].ne('[]')
df = df[has_generated_midi]

# Columns whose relationship to the median dissimilarity is plotted,
# one saved figure per column.
columns_to_analyze = ["threshold", "fan_out", "max_distance_atan", "max_key_distance"]

for column in columns_to_analyze:
    # Median of 'dissimilarity' for every distinct value of the current column.
    median_dissimilarity = df.groupby(column)["dissimilarity"].median()
    pretty_name = column.capitalize()

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(
        median_dissimilarity.index,
        median_dissimilarity.values,
        color='b',
        linestyle='-',
        marker='o',
    )
    ax.set_xlabel(f'{pretty_name} Value')
    ax.set_ylabel('Median Dissimilarity')
    ax.set_title(f'Median Dissimilarity vs {pretty_name} Value')
    ax.grid(True)

    # Write the figure to disk, then release it so open figures do not
    # accumulate across iterations.
    fig.savefig(f'./results_images/median_dissimilarity_vs_{column}.png')
    plt.close(fig)

# Box plot of dissimilarity by threshold.
# The axes are created first and handed to pandas through `ax=`: when `by=` is
# used without an `ax`, DataFrame.boxplot opens a figure of its own. A figure
# created beforehand with plt.figure() would then stay empty and open (it is
# rendered as "<Figure size 1000x600 with 0 Axes>"), and the requested size
# would not apply to the plot that is actually saved.
fig, ax = plt.subplots(figsize=(10, 6))
df.boxplot(column="dissimilarity", by="threshold", grid=False, showfliers=True, ax=ax)

# Adjust the plot labels and title
ax.set_xlabel('Threshold Value')
ax.set_ylabel('Dissimilarity')
ax.set_title('Boxplot of Dissimilarity by Threshold Value')
fig.suptitle('')  # Removes the automatic title from pandas boxplot
ax.grid(True, axis='y')  # Horizontal grid lines only

# Save the boxplot, then release the figure
fig.savefig('./results_images/boxplot_dissimilarity_by_threshold.png')
plt.close(fig)

# Scatter plots for threshold vs fan_out, once for the mean and once for the
# median dissimilarity. Both plots share one helper so they cannot drift apart.
def _plot_threshold_fanout_scatter(data, statistic):
    """Save a threshold-vs-fan_out scatter plot of aggregated dissimilarity.

    Rows of ``data`` are grouped by ('threshold', 'fan_out') and the
    'dissimilarity' column is aggregated with ``statistic``. Every group
    becomes one marker whose colour and size both encode the aggregated
    value. The figure is written to
    ``./results_images/scatter_<statistic>_dissimilarity_threshold_vs_fanout.png``
    and then closed.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns 'threshold', 'fan_out' and 'dissimilarity'.
    statistic : str
        Aggregation name handed to ``GroupBy.agg`` ('mean' or 'median' here).
        It is also used, capitalized, in the colour-bar label and the title,
        and verbatim in the output file name.

    Returns
    -------
    pandas.DataFrame
        The aggregated table with the columns 'threshold', 'fan_out' and
        'dissimilarity'.
    """
    grouped = (
        data.groupby(['threshold', 'fan_out'])['dissimilarity']
        .agg(statistic)
        .reset_index()
    )
    label = f'{statistic.capitalize()} Dissimilarity'

    fig, ax = plt.subplots(figsize=(10, 6))
    # NOTE(review): `s` is the marker area in points**2 and is taken straight
    # from the aggregated dissimilarity -- confirm that the value range yields
    # legible markers (very small values give nearly invisible points).
    scatter = ax.scatter(grouped['threshold'], grouped['fan_out'],
                         c=grouped['dissimilarity'], s=grouped['dissimilarity'],
                         cmap='viridis', edgecolor='k', alpha=0.7)
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label(label)
    ax.set_xlabel('Threshold')
    ax.set_ylabel('Fan Out')
    ax.set_title(f'Scatter Plot of Threshold vs Fan Out (Size & Color by {label})')
    ax.grid(False)

    # Save the scatter plot, then release the figure
    fig.savefig(f'./results_images/scatter_{statistic}_dissimilarity_threshold_vs_fanout.png')
    plt.close(fig)
    return grouped


# The aggregated tables stay available under their original names.
grouped_mean = _plot_threshold_fanout_scatter(df, 'mean')
grouped_median = _plot_threshold_fanout_scatter(df, 'median')

# Distribution of all dissimilarity values, with the overall median marked
# by a dashed vertical line.
median_value = df['dissimilarity'].median()

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(df['dissimilarity'], kde=True, ax=ax)
ax.axvline(
    median_value,
    label=f'Median: {median_value:.2f}',
    color='red',
    linestyle='--',
    linewidth=2,
)
ax.set_title('Full Data Distribution of Dissimilarity Values')
ax.set_xlabel('Dissimilarity')
ax.legend()

# Persist the histogram, then release the figure.
fig.savefig('./results_images/histogram_dissimilarity_with_median.png')
plt.close(fig)


<Figure size 1000x600 with 0 Axes>