In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# All paths below are relative to the project root, so switch into it first.
# The default is the original author's checkout; set the KINOME_PROJECT_ROOT
# environment variable to run the notebook on another machine without editing it.
PROJECT_ROOT = os.environ.get(
    'KINOME_PROJECT_ROOT',
    '/Users/zhenyamordan/PyCharmProjects/Kinome-Regularization/',
)
os.chdir(PROJECT_ROOT)

In [3]:
import pandas as pd
import plotly.express as px
from sklearn.metrics import r2_score, mean_squared_error
import numpy as np

In [4]:
# Shell escape: list the current working directory to confirm the chdir above
# landed in the project root.
!ls

CV_utils.py                   graph_utils.py
Linreg_clean.ipynb            model_sanity_check.ipynb
README.md                     phenotype_preprocess.py
RegressionModels.py           phenotype_preprocessing.ipynb
[34mfof[m[m                           plot_utils.py


In [16]:
# Folder (relative to the working directory) that holds one sub-folder per
# hyperparameter-tuning experiment.
base_path = './fof/data/r/hyperparameter_tunning/'

In [18]:
# Each experiment lives in its own sub-directory of base_path. Keep directories
# only: a stray file (macOS .DS_Store, exported CSV, ...) would otherwise be
# treated as an experiment and break the per-experiment loading later on.
# Note: os.listdir returns entries in arbitrary order.
experiments = [
    entry
    for entry in os.listdir(base_path)
    if '.DS_Store' not in entry and os.path.isdir(os.path.join(base_path, entry))
]
experiments

['v0_1_Kp_7_Kt_4_A_1_B_1', 'v0_1_Kp_7_Kt_4_A_0.5_B_0.5']

In [23]:
def parse_hyperparameters(experiment):
    """Turn an experiment folder name into a readable hyperparameter summary.

    The name is split on underscores and consecutive tokens are paired up as
    ``name: value``, e.g. ``'Kp_7_Kt_4'`` -> ``'Kp: 7, Kt: 4'``.

    Args:
        experiment: Folder name made of underscore-separated tokens.

    Returns:
        A string of comma-separated ``name: value`` pairs. If the name has an
        odd number of tokens, the last (unpaired) token is appended as-is
        instead of raising ``IndexError``.
    """
    tokens = experiment.split('_')
    # Largest even prefix of the token list: only these can form full pairs.
    paired_count = len(tokens) - len(tokens) % 2
    parts = [f'{tokens[i]}: {tokens[i + 1]}' for i in range(0, paired_count, 2)]
    # Keep a trailing token that has no value so unexpected folder names still
    # yield a usable label.
    parts.extend(tokens[paired_count:])
    return ', '.join(parts)

In [24]:
# Quick check of the parser on the first experiment folder name.
parse_hyperparameters(experiments[0])

'v0: 1, Kp: 7, Kt: 4, A: 1, B: 1'

In [45]:
def _read_x_column(folder_path, file_name):
    """Load the ``x`` column of ``file_name`` inside ``folder_path`` as a NumPy array."""
    return pd.read_csv(os.path.join(folder_path, file_name))['x'].to_numpy()


def _show_and_save_scatter(predicted, true_values, title, html_path):
    """Display a predicted-vs-true scatter plot and also write it to ``html_path``."""
    fig = px.scatter(x=predicted, y=true_values, title=title)
    fig.update_xaxes(title_text="Predicted Values")  # X-axis label
    fig.update_yaxes(title_text="True Standarized Values")  # Y-axis label
    fig.show()
    fig.write_html(html_path)


def plot_results(base_path, experiments, gamma_threshold=0.01):
    """Compute summary metrics and draw train/test scatter plots per experiment.

    For every experiment folder under ``base_path`` this reads the ``x`` column
    of ``y_train_predicted.csv``, ``y_train_true_standartized.csv``,
    ``y_test_predicted.csv``, ``y_test_true_standartized.csv`` and
    ``gamma.csv``, prints the metrics, shows the two scatter plots and saves
    them as HTML files inside the experiment folder.

    Args:
        base_path: Folder containing one sub-folder per experiment. A trailing
            slash is optional because paths are joined with ``os.path.join``.
        experiments: Iterable of experiment folder names.
        gamma_threshold: Gamma values strictly greater than this are counted in
            ``gamma_count``. Defaults to 0.01.

    Returns:
        Dict mapping each experiment name to a dict with the keys
        ``train_mse``, ``test_mse``, ``gamma_sum`` and ``gamma_count``.
    """
    aggregated_result = {}
    for i, experiment in enumerate(experiments):
        formatted_hyperparameters = parse_hyperparameters(experiment)
        print(i, formatted_hyperparameters)

        folder_path = os.path.join(base_path, experiment)
        plots_folder_path = os.path.join(folder_path, 'plots')
        # makedirs also creates any missing parent, i.e. the experiment folder.
        # NOTE(review): the plots/ sub-folder is created but the HTML files are
        # written next to the CSVs; kept unchanged so existing outputs do not
        # move -- confirm which location is intended.
        os.makedirs(plots_folder_path, exist_ok=True)

        y_train_predicted = _read_x_column(folder_path, 'y_train_predicted.csv')
        y_train_true_standartized = _read_x_column(folder_path, 'y_train_true_standartized.csv')
        y_test_predicted = _read_x_column(folder_path, 'y_test_predicted.csv')
        y_test_true_standartized = _read_x_column(folder_path, 'y_test_true_standartized.csv')
        gamma = _read_x_column(folder_path, 'gamma.csv')

        metrics = {
            'train_mse': mean_squared_error(y_train_true_standartized, y_train_predicted),
            'test_mse': mean_squared_error(y_test_true_standartized, y_test_predicted),
            'gamma_sum': gamma.sum(),
            'gamma_count': (gamma > gamma_threshold).sum(),
        }
        aggregated_result[experiment] = metrics
        print('Train MSE\t', metrics['train_mse'])
        print('Test MSE\t', metrics['test_mse'])
        print('Gamma sum\t', metrics['gamma_sum'])
        print(f'Gamma > {gamma_threshold} count\t', metrics['gamma_count'])

        _show_and_save_scatter(
            y_train_predicted,
            y_train_true_standartized,
            f'Train, {formatted_hyperparameters}',
            os.path.join(folder_path, 'true_predicted_scatter_plot_train.html'),
        )
        _show_and_save_scatter(
            y_test_predicted,
            y_test_true_standartized,
            f'Test, {formatted_hyperparameters}',
            os.path.join(folder_path, 'true_predicted_scatter_plot_test.html'),
        )
    return aggregated_result

In [46]:
# Run the evaluation for every experiment: prints the metrics, shows and saves
# the scatter plots, and keeps the per-experiment metrics dict for the summary below.
aggregated_result = plot_results(base_path, experiments)

0 v0: 1, Kp: 7, Kt: 4, A: 1, B: 1
Train MSE	 0.0884717718622331
Test MSE	 766.2870884784836
Gamma sum	 5.060951242036054
Gamma > 0.01 count	 6


1 v0: 1, Kp: 7, Kt: 4, A: 0.5, B: 0.5
Train MSE	 0.018386097419923156
Test MSE	 610.5869943162883
Gamma sum	 8.947189741606943
Gamma > 0.01 count	 9


In [67]:
# Metrics table: one row per metric, one column per experiment. Column labels
# are converted from raw folder names to the readable hyperparameter summary.
result_df = pd.DataFrame(aggregated_result).rename(columns=parse_hyperparameters)

### Aggregated results overview

In [68]:
# Grouped bar chart of the whole metrics table: one group per metric (row),
# one bar per experiment (column). All metrics share one y-axis, so the large
# test MSE values dwarf the others; the per-metric charts further down are
# easier to read.
px.bar(result_df, barmode='group')

### Aggregated results, one chart per metric

In [69]:
# Draw a separate bar chart for each metric (each row of result_df) so that
# every metric gets its own y-axis scale.
for metric_name, metric_values in result_df.iterrows():
    px.bar(metric_values, barmode='group').show()