In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
def _plot_metric_trend(plot_df, attr, method, out_path):
    """Draw one Threshold-vs-`attr` line plot (one hue per Model), save it as SVG, show it, then close it."""
    fig, ax = plt.subplots(figsize=(16, 8))  # wide canvas so overlapping model lines stay readable
    sns.lineplot(data=plot_df, x='Threshold', y=attr, hue='Model', ax=ax)

    # Axis labels and ticks use a large font size.
    ax.set_xlabel(f'Percentile Threshold\n(Using "{method}" Threshold Handling)', fontsize=20)
    ax.set_ylabel(attr, fontsize=20)
    ax.tick_params(axis='both', labelsize=20)

    # Legend goes in the bottom-left corner.
    ax.legend(loc='lower left', fontsize=16)

    fig.tight_layout()
    fig.savefig(out_path, format='svg')
    plt.show()
    # Release the figure explicitly: this function is called once per
    # metric x handling method x model, and unclosed figures accumulate in memory.
    plt.close(fig)


def compare_phylo_cv_results(ds,encoding_method,pre_gs_results_df,post_gs_results_df,atts_of_intrst,models_to_keep=None,max_threshold=100, suffix='gs'):
    """Plot metric-vs-threshold trends for two result sets side by side and save each plot as SVG.

    Rows of `post_gs_results_df` get `_<suffix>` appended to their `Model` value so that
    they appear as separate lines next to the rows of `pre_gs_results_df`. For every
    metric in `atts_of_intrst` and every distinct `Relation_Handling` value, one overview
    plot with all models is produced, followed by one plot per model containing only that
    model and its suffixed counterpart.

    Both frames must provide the columns `Model`, `Threshold`, `Relation_Handling` and
    every column named in `atts_of_intrst`. Neither input frame is modified.

    Parameters
    ----------
    ds : str
        Label used as the first component of the output file names.
    encoding_method : str
        Label used as the second component of the output file names.
    pre_gs_results_df : pandas.DataFrame
        Baseline results; model names are kept as they are.
    post_gs_results_df : pandas.DataFrame
        Comparison results; model names are suffixed with `_<suffix>`.
    atts_of_intrst : iterable of str
        Metric columns to plot on the y axis.
    models_to_keep : sequence of str, optional
        Models to include from both frames. When None, every model present in
        `post_gs_results_df` is used.
    max_threshold : number, default 100
        Rows with `Threshold` above this value are dropped before plotting.
    suffix : str, default 'gs'
        Text appended to the model names of `post_gs_results_df`.

    Returns
    -------
    None
        Figures are written to the current working directory and displayed.
    """
    # `is not None` rather than `!= None`: the latter is evaluated element-wise for
    # NumPy arrays (e.g. the output of `.unique()`), which makes the `if` raise.
    if models_to_keep is not None:
        post_gs_results_df = post_gs_results_df[post_gs_results_df['Model'].isin(models_to_keep)]
    else:
        models_to_keep = post_gs_results_df['Model'].unique()

    # Work on a copy so the caller's frame keeps its original model names.
    post_labeled = post_gs_results_df.copy()
    post_labeled['Model'] = post_labeled['Model'] + '_' + suffix

    # The combined frame does not depend on the metric, so build it once up front.
    pre_kept = pre_gs_results_df[pre_gs_results_df['Model'].isin(models_to_keep)]
    combined = pd.concat([pre_kept, post_labeled], ignore_index=True)
    combined = combined[combined['Threshold'] <= max_threshold]
    relation_handling_methods = combined['Relation_Handling'].unique()

    for attr in atts_of_intrst:
        for method in relation_handling_methods:
            # Keep only this handling method and the rows that actually have the metric.
            method_df = combined[combined['Relation_Handling'] == method].dropna(subset=[attr])

            # Overview: every kept model, baseline and suffixed.
            _plot_metric_trend(
                method_df, attr, method,
                f'./{ds}_{encoding_method}_pre_v_post_gs_analysis_{attr}_{method}_handling_performance_trend.svg',
            )

            # One plot per model: the baseline line against its suffixed counterpart.
            for model_type in models_to_keep:
                model_pair = [model_type, f'{model_type}_{suffix}']
                _plot_metric_trend(
                    method_df[method_df['Model'].isin(model_pair)], attr, method,
                    f'./{ds}_{encoding_method}_{model_type}_pre_v_post_gs_analysis_{attr}_{method}_handling_performance_trend.svg',
                )

In [None]:
# Root directory of the phylo-weighted CV result files. The default is the path
# originally hard-coded here; set the VPOD_REPORT_DIR environment variable to
# point the notebook at a different location without editing this cell.
report_dir = os.environ.get(
    'VPOD_REPORT_DIR',
    'e:/safra/Documents/GitHub/visual-physiology-opsin-db/result_files/phylo_weighted_cv',
)
# Metric columns passed to compare_phylo_cv_results as the y-axis attributes.
atts_of_intrst = ['R2', 'MAE', 'MAPE', 'MSE', 'RMSE']

In [None]:
# Read the `wt` results CSV stored under one_hot_encoded/pre_grid_search and preview it.
wt_pre_gs_results_df = pd.read_csv(
    f"{report_dir}/one_hot_encoded/pre_grid_search/wt_vpod_1.2_LG_F_R7_phylo_cv_2024-08-21_18-35-49/wt_vpod_1.2_one_hot_LG_F_R7_phylo_cv_results.csv"
)
wt_pre_gs_results_df.head()

In [None]:
# Read the `wt` results CSV stored under one_hot_encoded/post_grid_search and preview it.
wt_post_gs_results_df = pd.read_csv(
    f"{report_dir}/one_hot_encoded/post_grid_search/wt_vpod_1.2_LG_F_R7_phylo_cv_2025-03-19_20-11-06/wt_vpod_1.2_one_hot_gs_LG_F_R7_phylo_cv_results.csv"
)
wt_post_gs_results_df.head()

In [None]:
# Compare the one-hot `wt` frames (pre vs. post grid search) for thresholds up to 50.
# No model filter is passed, so the function falls back to its default model selection.
encoding_method = 'onehot'
compare_phylo_cv_results(
    ds='wt',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_pre_gs_results_df,
    post_gs_results_df=wt_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    max_threshold=50,
)

In [None]:
# Read the `wt_vert` results CSV stored under one_hot_encoded/pre_grid_search and preview it.
wt_vert_pre_gs_results_df = pd.read_csv(
    f"{report_dir}/one_hot_encoded/pre_grid_search/wt_vert_vpod_1.2_LG_F_R6_phylo_cv_2024-08-26_18-18-20/wt_vert_vpod_1.2_LG_F_R6_phylo_cv_results.csv"
)
wt_vert_pre_gs_results_df.head()

In [None]:
# Read the `wt_vert` results CSV stored under one_hot_encoded/post_grid_search and preview it.
wt_vert_post_gs_results_df = pd.read_csv(
    f"{report_dir}/one_hot_encoded/post_grid_search/wt_vert_vpod_1.2_LG_F_R6_phylo_cv_2025-03-19_20-10-57/wt_vert_vpod_1.2_LG_F_R6_phylo_cv_results.csv"
)
wt_vert_post_gs_results_df.head()

In [None]:
# The wt_vert frames were read from the one_hot_encoded directory, so set the label
# explicitly here instead of relying on whatever value an earlier cell left in
# `encoding_method`; otherwise running cells out of order mislabels the saved SVGs.
encoding_method = 'onehot'
models_to_keep = ['gbr','xgb','rf','BayesianRidge']
compare_phylo_cv_results(
    ds='wt_vert',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_vert_pre_gs_results_df,
    post_gs_results_df=wt_vert_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    models_to_keep=models_to_keep,
    max_threshold=35,
)

In [None]:
# Read the `wt` results CSV stored under aa_prop_encoded/pre_grid_search and preview it.
wt_aap_pre_gs_results_df = pd.read_csv(
    f"{report_dir}/aa_prop_encoded/pre_grid_search/wt_vpod_1.2_LG_F_R7_phylo_cv_2025-03-23_19-11-21/wt_vpod_1.2_LG_F_R7_phylo_cv_results.csv"
)
wt_aap_pre_gs_results_df.head()

In [None]:
# Read the `wt` results CSV stored under aa_prop_encoded/post_grid_search and preview it.
wt_aap_post_gs_results_df = pd.read_csv(
    f"{report_dir}/aa_prop_encoded/post_grid_search/wt_vpod_1.2_LG_F_R7_phylo_cv_2025-03-23_18-48-16/wt_vpod_1.2_LG_F_R7_phylo_cv_results.csv"
)
wt_aap_post_gs_results_df.head()

In [None]:
# Compare the aa_prop `wt` frames (pre vs. post grid search) for three models, thresholds up to 50.
encoding_method = 'aaprop'
models_to_keep = ['gbr', 'xgb', 'rf']
compare_phylo_cv_results(
    ds='wt',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_aap_pre_gs_results_df,
    post_gs_results_df=wt_aap_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    models_to_keep=models_to_keep,
    max_threshold=50,
)

In [None]:
# Compare one-hot against aa_prop results for `gbr` on the `wt` dataset; the aa_prop
# frame takes the "post" slot, so its model name is suffixed with `_aaprop`.
# NOTE(review): the label says "pre_gs", but both frames passed here are the
# *_post_gs_* variables -- confirm whether the label or the inputs are intended.
encoding_method = 'aaprop_vs_onehot_pre_gs'
models_to_keep = ['gbr']
compare_phylo_cv_results(
    ds='wt',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_post_gs_results_df,
    post_gs_results_df=wt_aap_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    models_to_keep=models_to_keep,
    max_threshold=50,
    suffix='aaprop',
)

MNM Analysis

In [None]:
# Read the `wt_mnm` results CSV stored under one_hot_encoded/pre_grid_search and preview it.
wt_mnm_pre_gs_results_df = pd.read_csv(
    f"{report_dir}/one_hot_encoded/pre_grid_search/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_2025-03-12_18-57-08/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_results.csv"
)
wt_mnm_pre_gs_results_df.head()

In [None]:
# Read the `wt_mnm` results CSV stored under one_hot_encoded/post_grid_search and preview it.
wt_mnm_post_gs_results_df = pd.read_csv(
    f"{report_dir}/one_hot_encoded/post_grid_search/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_2025-03-18_19-36-44/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_results.csv"
)
wt_mnm_post_gs_results_df.head()

In [None]:
# Compare the one-hot `wt_mnm` frames (pre vs. post grid search) for `gbr`, thresholds up to 50.
encoding_method = 'onehot'
models_to_keep = ['gbr']
compare_phylo_cv_results(
    ds='wt_mnm',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_mnm_pre_gs_results_df,
    post_gs_results_df=wt_mnm_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    models_to_keep=models_to_keep,
    max_threshold=50,
)

In [None]:
# Read the `wt_mnm` results CSV stored under aa_prop_encoded/pre_grid_search and preview it.
wt_mnm_aap_pre_gs_results_df = pd.read_csv(
    f"{report_dir}/aa_prop_encoded/pre_grid_search/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_2025-04-16_14-57-51/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_results.csv"
)
wt_mnm_aap_pre_gs_results_df.head()

In [None]:
# Read the `wt_mnm` results CSV stored under aa_prop_encoded/post_grid_search and preview it.
wt_mnm_aap_post_gs_results_df = pd.read_csv(
    f"{report_dir}/aa_prop_encoded/post_grid_search/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_2025-04-16_19-45-55/wt_mnm_aligned_VPOD_1.2_het_phylo_cv_results.csv"
)
wt_mnm_aap_post_gs_results_df.head()

In [None]:
# Compare the aa_prop `wt_mnm` frames (pre vs. post grid search) for `gbr`, thresholds up to 50.
encoding_method = 'aaprop'
models_to_keep = ['gbr']
compare_phylo_cv_results(
    ds='wt_mnm',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_mnm_aap_pre_gs_results_df,
    post_gs_results_df=wt_mnm_aap_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    models_to_keep=models_to_keep,
    max_threshold=50,
)

aa_prop vs. one-hot (MNM)

In [None]:
# Compare post-grid-search one-hot against aa_prop results for `gbr` on `wt_mnm`;
# the aa_prop frame takes the "post" slot, so its model name is suffixed with `_aaprop`.
encoding_method = 'aaprop_vs_onehot'
models_to_keep = ['gbr']
compare_phylo_cv_results(
    ds='wt_mnm',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_mnm_post_gs_results_df,
    post_gs_results_df=wt_mnm_aap_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    models_to_keep=models_to_keep,
    max_threshold=50,
    suffix='aaprop',
)

MNM vs. het

In [None]:
# Compare post-grid-search aa_prop results of `wt` against `wt_mnm` for `gbr`;
# the `wt_mnm` frame takes the "post" slot, so its model name is suffixed with `_mnm`.
encoding_method = 'aaprop_mnm_v_het'
models_to_keep = ['gbr']
compare_phylo_cv_results(
    ds='wt_mnm_v_wt',
    encoding_method=encoding_method,
    pre_gs_results_df=wt_aap_post_gs_results_df,
    post_gs_results_df=wt_mnm_aap_post_gs_results_df,
    atts_of_intrst=atts_of_intrst,
    models_to_keep=models_to_keep,
    max_threshold=50,
    suffix='mnm',
)