In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from ast import literal_eval
from scipy.stats import ttest_ind, pearsonr, spearmanr
# Results of the randomized-finetuning experiment (path relative to the notebook's working directory).
DATA_PATH = 'result/randomize_accuracy/randomize_data_new_kl_2.csv'

df = pd.read_csv(DATA_PATH)


def _parse_loss_curve(raw):
    """Parse a stringified list of loss values (e.g. "[0.9, 0.5]") into a list of floats."""
    return [float(value) for value in literal_eval(raw)]


# The CSV stores `loss` as the string repr of a Python list; turn it back into floats.
df["loss"] = df["loss"].apply(_parse_loss_curve)
# Last recorded loss of each run. A run that logged no losses gets NaN
# instead of raising IndexError on `losses[-1]`.
df["final_loss"] = df["loss"].apply(lambda losses: losses[-1] if losses else np.nan)
print(df.columns)

In [None]:
# One fixed colour per finetune strategy, shared by every panel and by the fitted lines.
# A *list* palette is assigned to hue levels in order of appearance, which can silently
# swap colours between panels/groups; a dict pins each label to its colour.
# Assumes `finetune` only takes these three labels -- TODO confirm against the data.
FINETUNE_COLORS = {"All": "C0", "Community": "C1", "Random": "C2"}


def _finite_xy(x, y):
    """Return x and y as float arrays, keeping only positions where both are finite."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    keep = np.isfinite(x) & np.isfinite(y)
    return x[keep], y[keep]


def _safe_corr(corr_fn, x, y):
    """Correlation statistic from `corr_fn` (pearsonr / spearmanr), or NaN when undefined.

    Non-finite pairs are dropped first. The result is NaN when fewer than two
    pairs remain or either input is constant; pearsonr would otherwise raise
    ValueError (< 2 points) or warn and return NaN (constant input).
    """
    x, y = _finite_xy(x, y)
    if len(x) < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
        return np.nan
    statistic, _ = corr_fn(x, y)
    return statistic


def _safe_ttest(a, b):
    """scipy two-sample t-test (default settings); (NaN, NaN) if either sample has < 2 values."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if len(a) < 2 or len(b) < 2:
        return np.nan, np.nan
    t_stat, p_val = ttest_ind(a, b)
    return t_stat, p_val


def _plot_linear_fit(ax, x, y, color):
    """Overlay the least-squares line of y on x; skip it when the fit is degenerate."""
    x, y = _finite_xy(x, y)
    if len(x) < 2 or np.ptp(x) == 0:
        return  # polyfit fails on empty input and is rank-deficient for constant x
    slope, intercept = np.polyfit(x, y, 1)
    x_ends = np.array([x.min(), x.max()])
    ax.plot(x_ends, slope * x_ends + intercept, color=color)


grouped = df.groupby(['iteration', 'model', 'pruning_style', 'pruning_ratio'])
for group_name, grouped_df in grouped:
    print(group_name)
    all_df = grouped_df[grouped_df['finetune'] == 'All']
    actual_df = grouped_df[grouped_df['finetune'] == 'Community']
    random_df = grouped_df[grouped_df['finetune'] == 'Random']
    # Labels used in the printed output; the 'Community' subset is reported as "Actual".
    subsets = {'All': all_df, 'Actual': actual_df, 'Random': random_df}

    # Correlation between L2 norm and final loss across the runs in this group
    for label, subset in subsets.items():
        corr = _safe_corr(pearsonr, subset['model_l2'], subset['final_loss'])
        print(f'Correlation between L2 norm and final loss({label}): {corr}')

    # T-test on accuracy between Community ("actual") and Random runs
    t_stat, p_val = _safe_ttest(actual_df['accuracy'], random_df['accuracy'])
    print(f'T-test for accuracy difference between actual and Randoms: t-stat={t_stat}, p-value={p_val}')

    # Rank correlation between final loss and accuracy
    for label, subset in subsets.items():
        corr = _safe_corr(spearmanr, subset['final_loss'], subset['accuracy'])
        print(f'Spearman correlation between loss and accuracy ({label}): {corr}')

    fig, (ax_kde, ax_acc, ax_loss, ax_box) = plt.subplots(figsize=(24, 6), ncols=4)

    # Panel 1: L2-norm distribution per finetune label; panel 4: accuracy spread per label.
    sns.kdeplot(data=grouped_df, x='model_l2', hue='finetune', palette=FINETUNE_COLORS, ax=ax_kde)
    sns.boxplot(data=grouped_df, x='finetune', y='accuracy', palette=FINETUNE_COLORS, ax=ax_box)

    # Panels 2-3: L2 norm vs accuracy / final loss for the non-'All' runs only.
    non_all_df = grouped_df[grouped_df['finetune'] != 'All']
    sns.scatterplot(data=non_all_df, x='model_l2', y='accuracy', hue='finetune',
                    palette=FINETUNE_COLORS, ax=ax_acc)
    sns.scatterplot(data=non_all_df, x='model_l2', y='final_loss', hue='finetune',
                    palette=FINETUNE_COLORS, ax=ax_loss)

    # Line of best fit per subset, coloured to match its scatter points.
    for label, subset in [('Community', actual_df), ('Random', random_df)]:
        color = FINETUNE_COLORS[label]
        _plot_linear_fit(ax_acc, subset['model_l2'], subset['accuracy'], color)
        _plot_linear_fit(ax_loss, subset['model_l2'], subset['final_loss'], color)
    # Set once after all lines are drawn: set_ylim turns off y-autoscaling, so calling
    # it between fits could clip lines added afterwards.
    ax_loss.set_ylim(bottom=0)

    ax_kde.set_xlabel('Model L2 Norm')
    ax_acc.set_xlabel('Model L2 Norm')
    ax_acc.set_ylabel('Accuracy')
    ax_loss.set_xlabel('Model L2 Norm')
    ax_loss.set_ylabel('Final Loss')
    fig.suptitle(f'{group_name}\nMagnitude of Weight (L2) w.r.t Accuracy and Loss')
    plt.show()
    plt.close(fig)  # one figure per group; free it once rendered

    print("+"*500)
