# Imports and setup

In [1]:
import sys
import warnings
# Ignore every warning category from this point on.
# NOTE(review): a blanket filter also hides deprecation warnings that the
# pandas/seaborn calls further down may emit; consider narrowing it.
warnings.filterwarnings('ignore')
# !{sys.executable} -m pip install seaborn

In [2]:
import sklearn
import pandas as pd
import numpy as np
from sklearn.model_selection import cross_validate,train_test_split
import ast
from sklearn.linear_model import LinearRegression,Lasso,LassoCV,MultiTaskLassoCV
from sklearn.metrics import roc_auc_score, mean_squared_error
import scipy
import os
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy import stats
# Relative directory path (presumably the data root).
# NOTE(review): no visible cell reads `base_path`; the CSVs below are opened
# relative to the working directory. Confirm it is still needed.
base_path= '../../../../../../T5 EVO/'
from sklearn.preprocessing import StandardScaler
import torch
# LassoCV is already imported from sklearn.linear_model above; the repeat is harmless.
from sklearn.linear_model import LogisticRegression,LassoCV
import random
# Global matplotlib font size; applies to every figure created after this line.
plt.rcParams["font.size"] = 35
# [x[0] for x in os.walk(base_path)]
# input_file_keller_pom = '/local_storage/datasets/farzaneh/openpom/data/curated_datasets/embeddings/pom/keller_pom_embeddings_Apr17.csv'
# 

In [3]:
# One seed shared by every random number generator used in the notebook.
seed = 2024

# Python standard library and NumPy global generators.
random.seed(seed)
np.random.seed(seed)

# PyTorch: default generator, then the explicit CUDA seeding calls.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [4]:
# NOTE(review): neither constant is referenced by the visible cells of this
# notebook; confirm they are still needed.
times, n_components = 30, 20

# Helper Methods

In [5]:
def combine_visualize(df1, df2, tasks, ax, title, type="corr", figure_name="def"):
    """Draw a grouped bar chart comparing two result tables descriptor by descriptor.

    The two frames are stacked and melted to long form on their ``model``
    column; every other column becomes one bar group on the x axis, coloured
    by ``model``, with standard-error error bars.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Wide tables that each carry a ``model`` column.
    tasks : sequence of str
        Tick labels written onto the x axis, in order.
    ax : matplotlib axes
        Axes to draw on.
    title : str
        Label of the y axis.
    type : str
        ``"corr"`` plots absolute values; any other value plots the data as-is.
    figure_name : str
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(axes returned by sns.barplot, melted long-form DataFrame)``.
    """
    long_scores = pd.concat((df1, df2)).melt(id_vars=['model'], var_name='descritpor')

    # Correlations are compared by magnitude only.
    if type == "corr":
        long_scores['value'] = long_scores['value'].abs()

    bar_colors = ['#4d79a4', '#ecc947', '#b07aa0']
    bar_ax = sns.barplot(
        data=long_scores,
        x="descritpor",
        y="value",
        hue="model",
        errorbar="se",
        ax=ax,
        palette=bar_colors,
        linewidth=7,
    )
    bar_ax.set(xlabel='Model', ylabel=title)
    for side in ('top', 'right'):
        bar_ax.spines[side].set_visible(False)

    # NOTE(review): labels are assigned by position, so `tasks` is assumed to
    # list the descriptors in the same order (and number) as the plotted
    # columns; confirm against the callers.
    bar_ax.set_xticklabels(tasks, rotation=90)
    return bar_ax, long_scores

In [6]:
def combine_visualize_separate(df1, df2, tasks, ax, title, type="corr"):
    """Draw one bar per model, pooling the values of every descriptor column.

    The two frames are stacked and melted to long form on their ``model``
    column, then plotted on ``ax`` with ``model`` on the x axis and
    standard-error error bars. Nothing is returned.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Wide tables that each carry a ``model`` column.
    tasks : sequence
        Accepted for call compatibility; not used by this function.
    ax : matplotlib axes
        Axes to draw on.
    title : str
        Label of the y axis.
    type : str
        ``"corr"`` plots absolute values; any other value plots the data as-is.
    """
    long_scores = pd.concat((df1, df2)).melt(id_vars=['model'], var_name='descritpor')

    # Correlations are compared by magnitude only.
    if type == "corr":
        long_scores['value'] = long_scores['value'].abs()

    model_ax = sns.barplot(
        data=long_scores,
        x="model",
        y="value",
        errorbar="se",
        palette="dark",
        alpha=.6,
        ax=ax,
    )
    model_ax.set(xlabel='Model', ylabel=title)

In [7]:
def normalize_rmse(df, min_max, j):
    """Turn error columns into range-normalised root values.

    Columns that are zero in every row are dropped from both inputs. Each of
    the remaining ``df.columns[:j]`` columns is then replaced by
    ``sqrt(value) / (max - min)``, where ``max`` and ``min`` are taken over
    the rows of the column at the same position in ``min_max`` (after its
    ``Dataset`` and ``Type`` columns are set aside). Columns are matched by
    position, not by name.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose leading columns are scaled (the callers pass MSE tables);
        the columns excluded by ``j`` are returned unchanged.
    min_max : pandas.DataFrame
        Table with ``Dataset`` and ``Type`` columns plus one column per
        scaled column of ``df``, in the same order.
    j : int or None
        Slice end applied to ``df.columns``; e.g. ``-2`` leaves the last two
        columns untouched.

    Returns
    -------
    pandas.DataFrame
        A new frame; neither input is modified.

    Raises
    ------
    KeyError
        If ``min_max`` has no ``Dataset`` or ``Type`` column.
    ValueError
        If the number of range columns differs from the number of columns
        selected by ``j``.
    """
    # Explicit copy: the column assignments below must not write into a view
    # of the caller's frame (and must not trigger SettingWithCopy warnings).
    df = df.loc[:, (df != 0).any(axis=0)].copy()
    min_max = min_max.loc[:, (min_max != 0).any(axis=0)]

    value_cols = df.columns[:j]
    bounds = min_max.drop(columns=['Dataset', 'Type'])
    if bounds.shape[1] != len(value_cols):
        raise ValueError(
            "min_max provides %d range column(s) but %d column(s) of df are to be "
            "normalised" % (bounds.shape[1], len(value_cols))
        )

    # Positional array: label-based Series lookups with an integer key are
    # deprecated in pandas 2.x and no longer positional in pandas 3.
    span = (bounds.max() - bounds.min()).to_numpy(dtype=float)
    for pos, col in enumerate(value_cols):
        df[col] = np.sqrt(df[col]) / span[pos]
    return df



In [35]:
def post_process_dataframe(corrss, msess, df_cor_pom, df_mse_pom, tasks, figure_name="def"):
    """Plot the model comparison and the per-layer trend, and save both figures.

    Two PDF files are written:

    * ``<figure_name>_h.pdf``: two side-by-side bar charts drawn with
      ``combine_visualize``: correlation coefficients (left) and NRMSE
      (right). Rows of ``corrss``/``msess`` with ``layer == 12`` are compared
      with ``df_cor_pom``/``df_mse_pom``.
    * ``<figure_name>_molformer_trend.pdf``: one line plot per descriptor of
      the ``model == "molformer"`` rows of ``corrss`` against ``layer``.

    Parameters
    ----------
    corrss, msess : pandas.DataFrame
        Per-layer tables with ``layer`` and ``model`` columns plus one column
        per descriptor.
    df_cor_pom, df_mse_pom : pandas.DataFrame
        Tables with a ``model`` column plus one column per descriptor.
    tasks : sequence of str
        Descriptor tick labels forwarded to ``combine_visualize``.
    figure_name : str
        Prefix of the two output file names.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(melted_corrss_molformer, melted_df_keller)``: the long-form
        per-layer correlation frame behind the trend figure, and the
        long-form NRMSE frame behind the right-hand bar chart. Earlier
        revisions returned the NRMSE frame in both positions, so a caller
        tagging the first element also modified the second.
    """
    # Drop descriptor columns that are zero everywhere or contain missing values.
    corrss = corrss.loc[:, (corrss != 0).any(axis=0)].dropna(axis=1)
    df_cor_pom = df_cor_pom.loc[:, (df_cor_pom != 0).any(axis=0)].dropna(axis=1)

    # ---- Figure 1: bar charts for the layer-12 rows -------------------------
    f2, ax = plt.subplots(1, 2, figsize=(30,12))

    g1, _ = combine_visualize(
        corrss.loc[corrss["layer"]==12].drop(columns='layer'),
        df_cor_pom, tasks, ax[0], 'Correlation Coefficient',
        figure_name="Correlation_"+figure_name,
    )
    g1.set_xlabel('')

    g2, melted_df_keller = combine_visualize(
        msess.loc[msess["layer"]==12].drop(columns='layer'),
        df_mse_pom, tasks, ax[1], 'NRMSE', type="mse",
        figure_name="MSE__"+figure_name,
    )
    g2.set_xlabel('Descriptor')

    # The per-axes legends are replaced by one shared figure-level legend.
    handles, _ = g2.get_legend_handles_labels()
    g1.legend().remove()
    g2.legend().remove()

    f2.subplots_adjust(bottom=0,left=0.1,right=0.95,top=0.7)
    labels = ['MoLFormer', 'Open-POM']
    f2.legend(handles, labels, ncol=3, columnspacing=1, prop={'size': 25}, handlelength=1.5, loc="lower center",
              borderpad=0.3,
              bbox_to_anchor=(0.54, -0.05),
              frameon=True, labelspacing=0.4, handletextpad=0.2)

    f2.subplots_adjust(hspace=0.65)
    f2.subplots_adjust(wspace=0.8)
    f2.tight_layout()
    f2.savefig(figure_name+"_h.pdf", bbox_inches='tight')

    # ---- Figure 2: per-layer correlation trend for the molformer rows -------
    corrss_molformer = corrss.loc[corrss["model"]=="molformer"].drop(columns="model")
    melted_corrss_molformer = corrss_molformer.melt(id_vars=['layer'], var_name='descritpor')

    g = sns.FacetGrid(melted_corrss_molformer, col='descritpor', col_wrap=5, height=4, aspect=1.5)
    g.map(sns.lineplot, 'layer', 'value',palette=['#4d79a4','#ecc947','#b07aa0'])
    g.set_titles(col_template='{col_name}')

    for facet_ax in g.axes.flat:
        # NOTE(review): ticks sit at layer values 1, 3, ..., 11 but are
        # labelled 2, 4, ..., 12, while the bar charts above select
        # layer == 12; confirm the intended layer numbering.
        facet_ax.set_xticks([1,3,5,7,9,11])
        facet_ax.set_xticklabels([2,4,6,8,10,12])

    # Per-facet axis labels are blanked in favour of one shared label per axis.
    g.set_axis_labels('', '')
    g.fig.text(0.5, 0.04, 'Layer', ha='center', va='center', fontsize=35)
    g.fig.text(0.0, 0.5, 'Correlation Coefficient', ha='center', va='center', rotation='vertical', fontsize=35)
    g.fig.set_size_inches(25, 15)

    g.savefig(figure_name+"_molformer_trend.pdf")

    return melted_corrss_molformer, melted_df_keller
    


# Extracting Representations

## Chemical-descriptor results per dataset (Keller, Ravia, Snitz, Sagar, GSLF, all)

In [19]:
# Descriptor names handed to post_process_dataframe as x-axis tick labels,
# in plotting order.
# Left out of the current selection: RDF035v, G1m, G1v, G1e, G3s, R8u+.
chemical_features_r = [
    "nCIR", "ZM1", "GNar", "S1K", "piPC08",
    "MATS1v", "MATS7v", "GATS1v",
    "Eig05_AEA(bo)", "SM02_AEA(bo)",
    "SM03_AEA(dm)", "SM10_AEA(dm)", "SM13_AEA(dm)",
    "SpMin3_Bh(v)",
    "nRCOSR",
]

In [20]:
# df_mse_pom

In [21]:
# min_max

In [22]:
# df_corrs_molfomer=pd.read_csv("df_"+ds+"_corrs_chemical_molfomer.csv")
# df_mse_pom

In [36]:
ds = 'keller'

# Correlation / MSE tables of the "pom" model for this dataset.
df_cor_pom = pd.read_csv(f"df_{ds}_cor_chemical_pom.csv")
df_mse_pom = pd.read_csv(f"df_{ds}_mse_chemical_pom.csv")

# Per-layer correlation / MSE tables of the "molfomer" model.
df_corrs_molfomer = pd.read_csv(f"df_{ds}_corrs_chemical_molfomer.csv")
df_mses_molfomer = pd.read_csv(f"df_{ds}_mses_chemical_molfomer.csv")

# Table whose per-column max - min range scales the errors.
min_max = pd.read_csv(f"{ds}_min_max_alva.csv")

# j=-2 / j=-1: the last two / last one column(s) are left unscaled.
df_mses_molfomer = normalize_rmse(df_mses_molfomer, min_max, j=-2)
df_mse_pom = normalize_rmse(df_mse_pom, min_max, j=-1)

trend_learning_molformer, melted_df_keller = post_process_dataframe(
    df_corrs_molfomer,
    df_mses_molfomer,
    df_cor_pom,
    df_mse_pom,
    chemical_features_r,
    "keller_chems",
)
# NOTE(review): the dataset tag is disabled here but active in the
# snitz/sagar/gslf/all cells; confirm which is intended.
# trend_learning_molformer['dataset']=ds

In [24]:
ds = 'ravia'

# Correlation / MSE tables of the "pom" model for this dataset.
df_cor_pom = pd.read_csv(f"df_{ds}_cor_chemical_pom.csv")
df_mse_pom = pd.read_csv(f"df_{ds}_mse_chemical_pom.csv")

# Per-layer correlation / MSE tables of the "molfomer" model.
df_corrs_molfomer = pd.read_csv(f"df_{ds}_corrs_chemical_molfomer.csv")
df_mses_molfomer = pd.read_csv(f"df_{ds}_mses_chemical_molfomer.csv")

# Table whose per-column max - min range scales the errors.
min_max = pd.read_csv(f"{ds}_min_max_alva.csv")

# j=-2 / j=-1: the last two / last one column(s) are left unscaled.
df_mses_molfomer = normalize_rmse(df_mses_molfomer, min_max, j=-2)
df_mse_pom = normalize_rmse(df_mse_pom, min_max, j=-1)

# NOTE(review): the second result keeps the name `melted_df_keller` although
# it now holds the ravia data.
trend_learning_molformer, melted_df_keller = post_process_dataframe(
    df_corrs_molfomer,
    df_mses_molfomer,
    df_cor_pom,
    df_mse_pom,
    chemical_features_r,
    "ravia_chems",
)
# NOTE(review): the dataset tag is disabled here but active in the
# snitz/sagar/gslf/all cells; confirm which is intended.
# trend_learning_molformer['dataset']=ds

In [25]:
ds = 'snitz'

# Correlation / MSE tables of the "pom" model for this dataset.
df_cor_pom = pd.read_csv(f"df_{ds}_cor_chemical_pom.csv")
df_mse_pom = pd.read_csv(f"df_{ds}_mse_chemical_pom.csv")

# Per-layer correlation / MSE tables of the "molfomer" model.
df_corrs_molfomer = pd.read_csv(f"df_{ds}_corrs_chemical_molfomer.csv")
df_mses_molfomer = pd.read_csv(f"df_{ds}_mses_chemical_molfomer.csv")

# Table whose per-column max - min range scales the errors.
min_max = pd.read_csv(f"{ds}_min_max_alva.csv")

# j=-2 / j=-1: the last two / last one column(s) are left unscaled.
df_mses_molfomer = normalize_rmse(df_mses_molfomer, min_max, j=-2)
df_mse_pom = normalize_rmse(df_mse_pom, min_max, j=-1)

# NOTE(review): the second result keeps the name `melted_df_keller` although
# it now holds the snitz data.
trend_learning_molformer, melted_df_keller = post_process_dataframe(
    df_corrs_molfomer,
    df_mses_molfomer,
    df_cor_pom,
    df_mse_pom,
    chemical_features_r,
    "snitz_chems",
)
# Tag the returned frame with the dataset it came from.
trend_learning_molformer['dataset'] = ds

In [26]:
ds = 'sagar'

# Correlation / MSE tables of the "pom" model for this dataset.
df_cor_pom = pd.read_csv(f"df_{ds}_cor_chemical_pom.csv")
df_mse_pom = pd.read_csv(f"df_{ds}_mse_chemical_pom.csv")

# Per-layer correlation / MSE tables of the "molfomer" model.
df_corrs_molfomer = pd.read_csv(f"df_{ds}_corrs_chemical_molfomer.csv")
df_mses_molfomer = pd.read_csv(f"df_{ds}_mses_chemical_molfomer.csv")

# Table whose per-column max - min range scales the errors.
min_max = pd.read_csv(f"{ds}_min_max_alva.csv")

# j=-2 / j=-1: the last two / last one column(s) are left unscaled.
df_mses_molfomer = normalize_rmse(df_mses_molfomer, min_max, j=-2)
df_mse_pom = normalize_rmse(df_mse_pom, min_max, j=-1)

# NOTE(review): the second result keeps the name `melted_df_keller` although
# it now holds the sagar data.
trend_learning_molformer, melted_df_keller = post_process_dataframe(
    df_corrs_molfomer,
    df_mses_molfomer,
    df_cor_pom,
    df_mse_pom,
    chemical_features_r,
    "sagar_chems",
)
# Tag the returned frame with the dataset it came from.
trend_learning_molformer['dataset'] = ds

In [27]:
ds = 'gslf'

# Correlation / MSE tables of the "pom" model for this dataset.
df_cor_pom = pd.read_csv(f"df_{ds}_cor_chemical_pom.csv")
df_mse_pom = pd.read_csv(f"df_{ds}_mse_chemical_pom.csv")

# Per-layer correlation / MSE tables of the "molfomer" model.
df_corrs_molfomer = pd.read_csv(f"df_{ds}_corrs_chemical_molfomer.csv")
df_mses_molfomer = pd.read_csv(f"df_{ds}_mses_chemical_molfomer.csv")

# Table whose per-column max - min range scales the errors.
min_max = pd.read_csv(f"{ds}_min_max_alva.csv")

# j=-2 / j=-1: the last two / last one column(s) are left unscaled.
df_mses_molfomer = normalize_rmse(df_mses_molfomer, min_max, j=-2)
df_mse_pom = normalize_rmse(df_mse_pom, min_max, j=-1)

# NOTE(review): the second result keeps the name `melted_df_keller` although
# it now holds the gslf data.
trend_learning_molformer, melted_df_keller = post_process_dataframe(
    df_corrs_molfomer,
    df_mses_molfomer,
    df_cor_pom,
    df_mse_pom,
    chemical_features_r,
    "gslf_chems",
)
# Tag the returned frame with the dataset it came from.
trend_learning_molformer['dataset'] = ds

In [37]:
ds = 'all'

# Correlation / MSE tables of the "pom" model for this dataset.
df_cor_pom = pd.read_csv(f"df_{ds}_cor_chemical_pom.csv")
df_mse_pom = pd.read_csv(f"df_{ds}_mse_chemical_pom.csv")

# Per-layer correlation / MSE tables of the "molfomer" model.
df_corrs_molfomer = pd.read_csv(f"df_{ds}_corrs_chemical_molfomer.csv")
df_mses_molfomer = pd.read_csv(f"df_{ds}_mses_chemical_molfomer.csv")

# Table whose per-column max - min range scales the errors.
min_max = pd.read_csv(f"{ds}_min_max_alva.csv")

# j=-2 / j=-1: the last two / last one column(s) are left unscaled.
df_mses_molfomer = normalize_rmse(df_mses_molfomer, min_max, j=-2)
df_mse_pom = normalize_rmse(df_mse_pom, min_max, j=-1)

# NOTE(review): the second result keeps the name `melted_df_keller` although
# it now holds the data of the 'all' dataset.
trend_learning_molformer, melted_df_keller = post_process_dataframe(
    df_corrs_molfomer,
    df_mses_molfomer,
    df_cor_pom,
    df_mse_pom,
    chemical_features_r,
    "all_chems",
)
# Tag the returned frame with the dataset it came from.
trend_learning_molformer['dataset'] = ds