In [None]:
# Imports and global plotting configuration for the dipeptide / CO2
# interaction-energy analysis.
import sys
# !{sys.executable} -m pip install shap
import lightgbm as lgb
import xgboost
import shap
from xgboost import XGBRegressor
from collections import Counter

from tqdm.notebook import trange, tqdm
from time import sleep, perf_counter

import os
from glob import glob
import pandas as pd
import numpy as np
import networkx as nx
from scipy.spatial import distance_matrix

# Scikit-learn
from sklearn.model_selection import KFold
from sklearn.kernel_ridge import KernelRidge
# NOTE(review): r2_score is imported again on the metrics line below; one import is enough.
from sklearn.metrics import r2_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import r2_score,mean_squared_error,mean_absolute_error,mean_absolute_percentage_error
from sklearn.preprocessing import normalize, MinMaxScaler


# Torch
import torch
# Record which torch build the kernel picked up (version and install path).
print(torch.__version__,torch.__path__)
import torch.nn as nn

# Reps: molecular-representation / descriptor libraries (dscribe SOAP, RDKit,
# Open Babel, project-local helpers).
from alchemical_cms import genpaddedCMs
from dscribe.descriptors import SOAP
from dscribe.kernels import REMatchKernel
from mendeleev.fetch import fetch_table
from ase.io import read
from Element_PI import VariancePersist
from Element_PI import VariancePersistv1

from rdkit import DataStructs
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw, MACCSkeys
from openbabel import openbabel as ob
from openbabel import pybel

# xyz2graph turns .xyz geometries into molecular graphs / plotly figures.
from xyz2graph import MolGraph, to_networkx_graph, to_plotly_figure
from plotly.offline import offline
#Plotting
import seaborn as sns
# NOTE(review): called with no style argument this appears to keep the current
# style unchanged; sns.set_theme() may have been intended — confirm.
sns.set_style()
import matplotlib as mpl
import matplotlib.pyplot as plt
# Default resolution for every figure rendered in the notebook.
mpl.rcParams['figure.dpi'] = 200


In [None]:
# Per-pair site table; the first spreadsheet column becomes the index.
df = pd.read_excel('site_data.xlsx', index_col=0)

# Site columns, grouped by residue (A/B) and chemical group, in a fixed order.
sitelist = [
    'A_carboxy', 'A_amine',
    'B_carboxy', 'B_amine',
    'A_link', 'B_link',
    'A_side', 'B_side',
]

# Keep only pairs that have a value in every site column.
inter_df = df.loc[:, sitelist].dropna()
# Optional binarisation of the site columns: inter_df[inter_df != 0] = 1

In [None]:


# Interaction energies. The index is the structure label; its first three
# '_'-separated fields form the key that Sheet2 uses for each pair.
y = pd.read_excel('400_dipeptides_interaction_energy.xlsx', usecols=[1, 2], index_col=0)
y['keypair'] = ['_'.join(i.split('_')[:3]) for i in y.index]

# Sheet2 rows look like "<key>: <dipeptide name>"; spaces and the word 'acid'
# are removed from the name.
_sheet2 = pd.read_excel('400_dipeptides_interaction_energy.xlsx', 'Sheet2', header=None)
pairs = pd.DataFrame(
    [(row[0].split(":")[0], row[0].split(":")[1].replace(" ", "").replace('acid', ''))
     for row in _sheet2.values]
).set_index(0)


# Locate the CO2-complex geometry and the bare-dipeptide geometry for every key,
# storing both under the dipeptide name. Keys with a missing file are reported
# and skipped (instead of a bare `except:` that would hide any other error).
monomerfiles = {}
co2files = {}
for k, v in pairs.to_dict()[1].items():
    # sorted() makes the choice deterministic if several files match the pattern.
    co2_matches = sorted(glob(f'./dipeptides_co2_coordinates/{k}_*.xyz'))
    monpath = f'./dipeptides_coordinates/{k}.xyz'
    if not co2_matches:
        print(f'{k}: no CO2-complex .xyz file found')
        continue
    if not os.path.exists(monpath):
        print(f'{k}: missing {monpath}')
        continue
    co2files[v] = co2_matches[0]
    monomerfiles[v] = monpath


def _file_key(path, n_fields):
    """Return the first `n_fields` '_'-separated fields of a file name, extension dropped."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return '_'.join(stem.split('_')[:n_fields])


# Name check: the key encoded in each file name. Parsing the basename is required;
# splitting the full path on '/' and taking [-2] yields the directory name.
co2check = [_file_key(p, 3) for p in co2files.values()]
moncheck = [_file_key(p, 4) for p in monomerfiles.values()]

In [None]:
# Alphabetical list of distinct first-residue names taken from the pair names.
AAs = sorted({pair_name.split('-')[0].replace('acid', '') for pair_name in pairs[1]})

In [None]:
# Pairs whose carboxy columns are 0 on both residues (variable name kept as-is).
noncarboyx = inter_df.loc[inter_df[['A_carboxy', 'B_carboxy']].eq(0).all(axis=1)]

In [None]:
# Display pairs with no carboxy-site contribution on either residue.
noncarboyx

In [None]:

def _load_graphs(files):
    """Read every .xyz file into an xyz2graph MolGraph and its NetworkX graph.

    Parameters
    ----------
    files : dict
        Maps a dipeptide name to the path of its .xyz file.

    Returns
    -------
    (dict, dict)
        ``{name: MolGraph}`` and ``{name: networkx graph}``, with the same keys.
    """
    molgraphs, graphs = {}, {}
    for name, path in files.items():
        molgraph = MolGraph()
        molgraph.read_xyz(path)
        molgraphs[name] = molgraph
        graphs[name] = to_networkx_graph(molgraph)
    return molgraphs, graphs


# Same procedure for the CO2 complexes and the bare dipeptides.
co2mgs, co2dipepgraphs = _load_graphs(co2files)
mgs, dipepgraphs = _load_graphs(monomerfiles)


In [None]:
# FLAG!
# df.loc['Asparagine-Alanine']
# fig = to_plotly_figure(co2mgs['Asparagine-Alanine'])
# offline.plot(fig)

# FLAG THIS ONE!!
# fig = to_plotly_figure(co2mgs['Alanine-Lysine'])
# offline.plot(fig)

In [None]:
# Sanity check: print any pair name in which a residue still contains 'acid'
# (once per offending residue).
for pair_name in pairs.to_dict()[1].values():
    residues = pair_name.split('-')
    for residue in residues:
        if 'acid' not in residue:
            continue
        print(residues)

In [None]:
# Display the sorted amino-acid names.
AAs

In [None]:
# Pair interaction-energy matrix: column = first residue, row = second residue
# of the 'A-B' dipeptide name. Entries with no matching pair stay 0.
# The key -> energy lookup is built once instead of on every iteration.
_pair_energy = y.set_index('keypair')['Interaction_Energy']
dfpairE = pd.DataFrame(np.zeros((len(AAs), len(AAs))), index=AAs, columns=AAs)
for k, v in pairs.to_dict()[1].items():
    names = v.split('-')
    # Single .loc write; the chained form dfpairE[col][row] = ... does not write
    # back to dfpairE under pandas copy-on-write.
    dfpairE.loc[names[1], names[0]] = float(_pair_energy.loc[k])

In [None]:
# For every pair of distinct residues (i before j in AAs), the absolute
# difference between the energies of "i-j" and "j-i".
_dev_records = []
for idx_j, aa_j in enumerate(AAs):
    for aa_i in AAs[:idx_j]:
        gap = abs(dfpairE[aa_i][aa_j] - dfpairE[aa_j][aa_i])
        _dev_records.append(('-'.join((aa_i, aa_j)), gap))
devpairs = pd.DataFrame(_dev_records, columns=['Pairs', 'Deviation'])

In [None]:
# Fraction of AB/BA pairs whose energies differ by at most 1.
int((devpairs['Deviation'] <= 1).sum()) / len(devpairs)

In [None]:
# Interaction energy of every pair relative to the Glycine-Glycine reference.
_gly_gly_energy = df.loc['Glycine-Glycine', 'Interaction_Energy']
df['dev_gly'] = df['Interaction_Energy'] - _gly_gly_energy

In [None]:
def _split_pair_names(index):
    """Split 'A-B' pair labels into an (n, 2) array whose columns are [A, B].

    The reshape keeps the 2-column shape even for an empty subset, so later
    ``gt[:, 0]`` indexing does not fail.
    """
    return np.array([label.split('-') for label in index]).reshape(-1, 2)


# dev_gly value at or below which a pair belongs to the strongly-shifted subset df2.
DEV_GLY_STRONG = -6

df0 = df[df['dev_gly'] > 0]
gt0 = _split_pair_names(df0.index)
df1 = df[df['dev_gly'] < 0]
gt1 = _split_pair_names(df1.index)
# df2/gt2 are read by the df2 bar-chart cells further down and must be defined
# here for Restart & Run All. df2 is a subset of df1.
df2 = df[df['dev_gly'] <= DEV_GLY_STRONG]
gt2 = _split_pair_names(df2.index)

In [None]:
# Number of distinct site patterns among dev_gly > 0 pairs, next to the number of such pairs.
# NOTE(review): inter_df had rows with NaN dropped; if any df0 label was dropped,
# .loc with that label raises KeyError.
np.unique(inter_df.loc[df0.index].values,axis=0).shape,df0.index.shape

In [None]:
# Per-site column totals over the dev_gly > 0 pairs, as a two-column frame.
inter_df.loc[df0.index].sum(axis=0).reset_index()

In [None]:
def _residue_site_totals(residue):
    """Row-wise integer total of one residue's site columns, keeping only non-zero rows."""
    columns = [f'{residue}_{site}' for site in ('carboxy', 'amine', 'link', 'side')]
    totals = inter_df[columns].sum(axis=1).astype(int)
    return totals[totals != 0]


dfA = _residue_site_totals('A')
dfB = _residue_site_totals('B')
# Column 0 = residue A total, column 1 = residue B total; NaN where that residue's total was 0.
dfAB = pd.concat([dfA, dfB], axis=1)

In [None]:
# Number of pairs with a non-zero A total and with a non-zero B total.
# (dfA/dfB already exclude zeros, so the != 0 masks remove nothing here.)
dfA[dfA!=0].shape,dfB[dfB!=0].shape

In [None]:
# Residue-A site total for Alanine-Glycine (KeyError if that total was 0 and filtered out).
dfA.loc['Alanine-Glycine']

In [None]:
# dev_gly distributions split by residue site totals from dfAB
# (column 0 = residue A total, column 1 = residue B total).
fig, ax = plt.subplots(2, 2)
panels = [
    (ax[0, 0], dfAB[0] == 3, 'A total = 3'),
    (ax[0, 1], dfAB[1] == 3, 'B total = 3'),
    (ax[1, 0], dfAB[0].notna() & (dfAB[0] != 3), 'A total != 3'),
]
# ax[1, 1] is left empty: the "B total != 3" panel is disabled.
for panel_ax, mask, title in panels:
    sns.histplot(df['dev_gly'].loc[dfAB[mask].index], ax=panel_ax)
    # Titles make the otherwise identical-looking panels distinguishable.
    panel_ax.set_title(title, fontsize=8)
fig.tight_layout()

In [None]:
# Number of pairs with a non-zero B total but no non-zero A total.
# (B - A) - (A - B) equals B - A because the two differences are disjoint.
# NOTE(review): if the symmetric difference was intended, use set(dfA.index) ^ set(dfB.index).
len(set(dfB.index) - set(dfA.index))

In [None]:
# Per-site totals over dev_gly > 0 pairs; bars ordered alphabetically by site column name.
sns.barplot(data=inter_df.loc[df0.index].sum(axis=0).reset_index().sort_values(by='index'),x='index',y=0)

In [None]:
# Per-site totals over the df2 pairs, bars ordered by total.
# Requires df2 to be defined before this cell runs.
sns.barplot(data=inter_df.loc[df2.index].sum(axis=0).reset_index().sort_values(by=0),x='index',y=0)

In [None]:
# Site columns of the df2 pairs (requires df2 to be defined before this cell runs).
inter_df.loc[df2.index]

In [None]:
# How often each residue appears among pairs with dev_gly > 0, separately for
# the first (A) and second (B) residue of the 'A-B' label.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
for col, (label, panel_ax) in enumerate(zip(('A', 'B'), (ax1, ax2))):
    counts = (pd.DataFrame.from_dict(dict(Counter(gt0[:, col])), orient='index')
              .reset_index()
              .rename(columns={'index': label, 0: 'count'})
              .sort_values(by='count'))
    sns.barplot(data=counts, x=label, y='count', ax=panel_ax)
    # Rotate only the residue-name labels; without axis='x' the count labels
    # on the y axis were rotated as well.
    panel_ax.tick_params(axis='x', labelrotation=90)
fig.suptitle('Pairs with dev_gly > 0')
plt.show()

In [None]:
# How often each residue appears among pairs with dev_gly < 0, separately for
# the first (A) and second (B) residue of the 'A-B' label.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
for col, (label, panel_ax) in enumerate(zip(('A', 'B'), (ax1, ax2))):
    counts = (pd.DataFrame.from_dict(dict(Counter(gt1[:, col])), orient='index')
              .reset_index()
              .rename(columns={'index': label, 0: 'count'})
              .sort_values(by='count'))
    sns.barplot(data=counts, x=label, y='count', ax=panel_ax)
    # Rotate only the residue-name labels; without axis='x' the count labels
    # on the y axis were rotated as well.
    panel_ax.tick_params(axis='x', labelrotation=90)
# Fixed count range; the y axis is shared, so this applies to both panels.
ax1.set_ylim(0, 25)
fig.suptitle('Pairs with dev_gly < 0')
plt.show()

In [None]:
# How often each residue appears among the df2 pairs, separately for the
# first (A) and second (B) residue of the 'A-B' label. Requires gt2.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
for col, (label, panel_ax) in enumerate(zip(('A', 'B'), (ax1, ax2))):
    counts = (pd.DataFrame.from_dict(dict(Counter(gt2[:, col])), orient='index')
              .reset_index()
              .rename(columns={'index': label, 0: 'count'})
              .sort_values(by='count'))
    sns.barplot(data=counts, x=label, y='count', ax=panel_ax)
    # Rotate only the residue-name labels; without axis='x' the count labels
    # on the y axis were rotated as well.
    panel_ax.tick_params(axis='x', labelrotation=90)
fig.suptitle('Pairs in df2')
plt.show()

In [None]:
# Histogram of dev_gly over all pairs; the returned Axes is kept in `ax`.
ax=sns.histplot(data=df,x='dev_gly')

In [None]:
# Every pair's interaction energy relative to Glycine-Glycine, in ascending order.
dev_sorted = (df['Interaction_Energy'] - df['Interaction_Energy']['Glycine-Glycine']).sort_values()
sorted_positions = np.arange(len(dev_sorted))

plt.figure(figsize=(30, 5))
plt.scatter(sorted_positions, dev_sorted)
# Tick labels must follow the sorted order; labelling with the unsorted df index
# attached the wrong pair name to each point.
plt.xticks(sorted_positions, dev_sorted.index, rotation=90, fontsize=6)
plt.xlim(-1, len(dev_sorted) + 1)
plt.ylabel('Interaction energy - Glycine-Glycine (kcal/mol)')
plt.show()

In [None]:
# |E(A-B) - E(B-A)| for every pair, with a grey band covering deviations from
# 0 to 1 kcal/mol across x = -1 .. n_pairs; saved to absolute_pair_dev.png.
n_pairs = len(devpairs)
tick_positions = range(n_pairs)

plt.figure(figsize=(25, 5))
plt.plot(tick_positions, devpairs['Deviation'], 'o--')

band_x = np.arange(-1, n_pairs + 1)
plt.fill_between(band_x, np.zeros(n_pairs + 2), np.ones(n_pairs + 2), color='gray')

plt.xticks(tick_positions, devpairs['Pairs'], rotation=90, fontsize=8)
plt.xlim(-1, n_pairs + 1)
plt.ylim(0, 7)
plt.xlabel('Pairs')
plt.ylabel('Deviation (kcal/mol)')
plt.title('Deviation Between Pairs AB and BA')
plt.tight_layout()
plt.savefig('absolute_pair_dev.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Heatmap of the pair interaction-energy matrix, colour range fixed to [-12, 0];
# saved to pair_heat.png.
# NOTE(review): both axes are labelled 'Amino Acid'; columns (x) hold the first
# residue of the pair name and rows (y) the second — consider saying so in the labels.
sns.heatmap(data=dfpairE,vmin=-12,vmax=0,cbar_kws={'label': 'Interaction Energy (kcal/mol)'},linewidths=0.1,square=True)
plt.xlabel('Amino Acid')
plt.ylabel('Amino Acid')

plt.tight_layout()
plt.savefig('pair_heat.png',dpi=300,bbox_inches='tight')
plt.show()

In [None]:
# Shape of the transposed site table: (number of site columns, number of pairs).
inter_df.T.values.shape

In [None]:
# For each pair label, the number of site columns equal to 1.
dict_count = inter_df.eq(1).sum(axis=1).to_dict()

In [None]:
# Display the energy table (Interaction_Energy plus the derived keypair column).
y

In [None]:
# Pair each dipeptide's count of site columns equal to 1 with its interaction energy.
# df['label'] maps the dipeptide name to its key in y['keypair'].
# The lookup is built once (first row per key, as `.values[0]` picked before)
# instead of scanning y for every pair; a missing key now raises KeyError naming it.
# NOTE: a later cell rebinds y to an outlier-filtered frame, so re-running this
# cell after that one can fail for pairs whose energy was removed.
_energy_lookup = y.drop_duplicates('keypair').set_index('keypair')['Interaction_Energy']
intvsE = [(dict_count[name], _energy_lookup.loc[key])
          for name, key in df['label'].to_dict().items()]
intvsE = np.array(sorted(intvsE, key=lambda pair: pair[0]))

dfintvsE = pd.DataFrame(intvsE, columns=['Count', 'Interaction_Energy'])

In [None]:
# Interaction-energy distribution coloured by the per-pair count, with KDE curves.
sns.histplot(data=dfintvsE,x='Interaction_Energy',hue='Count',kde=True,stat='count')

In [None]:
# Part 1: histogram of interaction energies with a describe() summary table
# drawn below the axes; saved to spread.png.
stats=pd.DataFrame(y['Interaction_Energy'].describe()).round(2)

plt.figure(figsize=(5,5))
sns.histplot(data=y,x='Interaction_Energy')
# bbox = [left, bottom, width, height]; the negative bottom places the table under the axes.
table =plt.table(cellText=stats.values,
          rowLabels=stats.index,
          colLabels=stats.columns,
          cellLoc = 'center', rowLoc = 'center',
          loc='bottom', bbox=[0.25, -0.5, 0.5, 0.3])


table.auto_set_font_size(False)
table.set_fontsize(8)


# NOTE(review): tight_layout() on the next line recomputes the subplot margins,
# so this subplots_adjust call likely has no lasting effect — confirm.
plt.subplots_adjust(left=0, bottom=0.5)
plt.tight_layout()
plt.savefig('spread.png',dpi=300,bbox_inches='tight')
plt.show()

# Part 2: drop outliers outside the Tukey fences [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
Q1=y['Interaction_Energy'].quantile(0.25)
Q3=y['Interaction_Energy'].quantile(0.75)
IQR=Q3-Q1
upper = Q3 + 1.5*IQR

lower = Q1 - 1.5*IQR

# NOTE(review): this rebinds y to the filtered frame. Re-running the cell
# recomputes the quartiles on already-filtered data and can trim further, and
# earlier cells that read y (dfpairE, intvsE) see different data if re-run
# afterwards. Consider keeping the raw frame under a separate name.
y=y[(y['Interaction_Energy']>=lower)&(y['Interaction_Energy']<=upper)].dropna()
sns.histplot(data=y,x='Interaction_Energy')
plt.show()

In [None]:
# Column position assigned to each amino acid, in the sorted AAs order.
bitkey = {aa: position for position, aa in enumerate(AAs)}

In [None]:
# Display the amino-acid -> column-position mapping.
bitkey

In [None]:
# samples=len(y)
# X=np.zeros((samples,len(AAs)))
# Y=np.zeros((samples,1))

# for idx,(k,v) in enumerate(y.set_index('keypair').to_dict()['Interaction_Energy'].items()):
#     aa=pairs.loc[k].values[0].split('-')
#     a1=aa[0]
#     a2=aa[1]
#     if a1!=a2:
#         X[idx,bitkey[a1]]=1
#         X[idx,bitkey[a2]]=1
#     else:
#         X[idx,bitkey[a1]]=2
#     Y[idx]=v

In [None]:
# # Create the MolGraph object
# mg = MolGraph()

# # Read the data from the .xyz file
# mg.read_xyz(monomerfiles[0])


# # Convert the molecular graph to the NetworkX graph
# G = to_networkx_graph(mg)

# # G.nodes(data=True),G.edges(data=True)

In [None]:
# from rdkit.Chem import rdFingerprintGenerator
# mols=[Chem.MolFromSmiles(list(pybel.readfile('xyz',m))[0].write().split('\t')[0]) for m in monomerfiles]
# fpgen = rdFingerprintGenerator.GetMorganGenerator(radius=6)

# # info={}
# # X = np.vstack([fpgen.GetFingerprint(mol) for mol in mols ])

# fps = [MACCSkeys.GenMACCSKeys(x) for x in mols]
# X=np.vstack([f.ToList() for f in fps])
# Y=y.values

# keys=pd.read_excel('MACCS_keys_example.xlsx',index_col='Key').drop(columns=['Unnamed: 0'])

# mol_keys=[idx for idx, i in enumerate(X[0]) if i==1]

# {str(keys.loc[idx].values[0]):i for idx, i in enumerate(np.count_nonzero(X,axis=0)) if i!=0}

In [None]:

# sns.heatmap([[DataStructs.TanimotoSimilarity(i,j) for i in fps] for j in fps],vmin=0,vmax=1,cmap=sns.cm.rocket_r)
# plt.show()

In [None]:
# Draw.MolsToGridImage(mols,molsPerRow=10, subImgSize=(300,300))