# Final result on complete data set

## Assign BDEs and other descriptors to data points

In [1]:
import pandas as pd
from pathlib import Path
from mystructmatch_utils import add_ref_idx
import numpy as np

In [22]:
# Tab-separated BDE table; column names are supplied here because the file has no header row.
data = pd.read_csv(
    '/hits/fast/mbm/treydewk/optimized_test/BDEs_compl.txt',
    sep = '\t\t',
    names = ['names', 'BDE_H', 'BDE_G', 'charge', 'pdb', 'pdb_H', 'ref_comp'],
    engine='python'
)

### read in manually assigned alternative indices
alt_ind = pd.read_csv(
    '/hits/fast/mbm/treydewk/optimized_test/alt_ind_compl.txt',
    sep = '\t\t',
    names = ['names', 'alt_idx'], engine='python'
)


def _to_zero_based(raw_indices):
    """Convert one raw ``alt_idx`` entry into a list of zero-based integers.

    Args:
        raw_indices: A comma-separated string of one-based indices, a single
            number, or a missing value (None/NaN).

    Returns:
        A list of ints, each reduced by one; missing values are returned unchanged.
    """
    if isinstance(raw_indices, str):
        return [int(token) - 1 for token in raw_indices.split(',')]
    # pandas reports missing entries as NaN rather than None, so test for both.
    if pd.isna(raw_indices):
        return raw_indices
    # A lone index may have been parsed as a number instead of a string.
    return [int(raw_indices) - 1]


### the indices returned by the matching function are zero-indexed, whereas the PDB indices are one-indexed
alt_ind['alt_idx'] = alt_ind['alt_idx'].map(_to_zero_based)

In [23]:
### Manual structure matching for Gly_x-1-9.pdb
# Both terminus rows are matched against the same reference structure.
gly_reference = Path('/hits/basement/mbm/riedmiki/structures/KR0008/reference_structures/Gly_x-1-9.pdb')

# 192 -> 8, 194 -> 10
termini_df = data.iloc[[192, 194]].reset_index(drop=True)

ref_termini = [gly_reference, gly_reference]
idx_termini = [8, 10]
idx_termini_list = [[atom_idx] for atom_idx in idx_termini]

# One row per terminus: reference file, matched index, and the list of accepted indices.
matched = pd.DataFrame({
    'ref': ref_termini,
    'ref_idx': idx_termini,
    'alt_idx': idx_termini_list,
})
termini_df = termini_df.join(matched)

In [24]:
# to pad missing BDEs
# The row is selected from `data` as loaded, i.e. before the following cell drops rows.
N_term_BDE_row = data[data['names'] == 'Nterminus_amino']
# Take the element explicitly: float() on a Series is deprecated in recent pandas.
# NOTE(review): if several rows were named 'Nterminus_amino' this uses the first one.
N_term_BDE = float(N_term_BDE_row['BDE_H'].iloc[0])

In [25]:
# Rows removed from both tables (labels refer to the freshly loaded files):
#   neutral arginine [8-13], aspartic acid [19-21], backbone [22], C termini [23-24],
#   glutamic acid [54-57], cationic histidine [59-64], His_pi_beta [66],
#   hlknl crosslinks [75-86], uncharged lysine [126-131], N termini [136-138],
#   pyd crosslinks [152-168], acetylated and N-amino formylated termini [192-195]
# NOTE(review): labels 93 and 94 are dropped as well but are not explained by the
# comments above - confirm which entries they are.
rows_to_drop = [
    *range(8, 14), *range(19, 25),
    *range(54, 58), *range(59, 65), 66,
    *range(75, 87), 93, 94,
    *range(126, 132), *range(136, 139),
    *range(152, 169), *range(192, 196),
]

# The same labels are removed from both frames, which are then renumbered from zero.
data = data.drop(index=rows_to_drop).reset_index(drop=True)
alt_ind = alt_ind.drop(index=rows_to_drop).reset_index(drop=True)

In [26]:
### average BDEs for cis and trans conformers of proline and hydroxyproline
# Hyp
hyp_sites = ('alpha', 'pyrr3', 'pyrr4_onC', 'pyrr4_onO', 'pyrr5')

# Snapshot the cis and trans row of every site before anything is modified.
hyp_cis = {site: data[data['names'] == 'Hyp_cis_' + site] for site in hyp_sites}
hyp_trans = {site: data[data['names'] == 'Hyp_trans_' + site] for site in hyp_sites}

# The cis row is kept and becomes the averaged entry.
# Its index label is used as a position below, which relies on `data` carrying a
# freshly reset 0..n-1 index.
hyp_row = {site: hyp_cis[site].index[0] for site in hyp_sites}

# Column 0: conformer-independent name.
for site in hyp_sites:
    data.iloc[hyp_row[site], 0] = 'Hyp_' + site

# Columns 1 and 2: mean of the cis and trans values.
for column in (1, 2):
    for site in hyp_sites:
        cis_value = hyp_cis[site].iloc[0, column]
        trans_value = hyp_trans[site].iloc[0, column]
        data.iloc[hyp_row[site], column] = (cis_value + trans_value) / 2

# The trans rows are now redundant; remove them from both tables using the labels from `data`.
hyp_trans_labels = [hyp_trans[site].index[0] for site in hyp_sites]
data = data.drop(index=hyp_trans_labels).reset_index(drop=True)
alt_ind = alt_ind.drop(index=hyp_trans_labels).reset_index(drop=True)

In [27]:
# Pro
pro_sites = ('alpha', 'pyrr3', 'pyrr4', 'pyrr5')

# Snapshot the cis and trans row of every site before anything is modified.
pro_cis = {site: data[data['names'] == 'Pro_cis_' + site] for site in pro_sites}
pro_trans = {site: data[data['names'] == 'Pro_trans_' + site] for site in pro_sites}

# The cis row is kept and becomes the averaged entry (label used as position,
# which relies on `data` carrying a freshly reset 0..n-1 index).
pro_row = {site: pro_cis[site].index[0] for site in pro_sites}

# Column 0: conformer-independent name.
for site in pro_sites:
    data.iloc[pro_row[site], 0] = 'Pro_' + site

# Columns 1 and 2: mean of the cis and trans values.
for column in (1, 2):
    for site in pro_sites:
        cis_value = pro_cis[site].iloc[0, column]
        trans_value = pro_trans[site].iloc[0, column]
        data.iloc[pro_row[site], column] = (cis_value + trans_value) / 2

# The trans rows are now redundant; remove them from both tables using the labels from `data`.
pro_trans_labels = [pro_trans[site].index[0] for site in pro_sites]
data = data.drop(index=pro_trans_labels).reset_index(drop=True)
alt_ind = alt_ind.drop(index=pro_trans_labels).reset_index(drop=True)

In [28]:
# Run the structure matcher on every row; whatever it returns is expanded into
# columns and attached next to the original data.
ref_columns = data.apply(add_ref_idx, axis=1, result_type="expand")
results = data.join(ref_columns)

Structure:  Ala_alpha--> Ala-0.pdb
Structure:  Ala_beta--> Ala-0.pdb
Structure:  Arg_c1_alpha--> Arg_c1-0.pdb
Structure:  Arg_c1_beta--> Arg_c1-0.pdb
Structure:  Arg_c1_gamma--> Arg_c1-0.pdb
Structure:  Arg_c1_delta--> Arg_c1-0.pdb
Structure:  Arg_c1_epsilon--> Arg_c1-0.pdb
Structure:  Arg_c1_guan--> Arg_c1-0.pdb
Structure:  Asn_alpha--> Asn-0.pdb
Structure:  Asn_beta--> Asn-0.pdb
Structure:  Asn_amide--> Asn-0.pdb
Structure:  Asp_c-1_alpha--> Asp_c-1-0.pdb
Structure:  Asp_c-1_beta--> Asp_c-1-0.pdb
Structure:  Cys_alpha--> Cys-0.pdb
Structure:  Cys_beta--> Cys-0.pdb
Structure:  Cys_sulphur--> Cys-0.pdb
Structure:  Dopa_alpha--> Dop-0.pdb
Structure:  Dopa_beta--> Dop-0.pdb
Structure:  Dopa_ortho1--> Dop-0.pdb
Structure:  Dopa_OH_meta--> Dop-0.pdb
Structure:  Dopa_OH_para--> Dop-0.pdb
Structure:  Dopa_meta2--> Dop-0.pdb
Structure:  Dopa_ortho2--> Dop-0.pdb
Structure:  Dopa_c-1meta_alpha--> Dop_c-1_b-0.pdb
Structure:  Dopa_c-1meta_beta--> Dop_c-1_b-0.pdb
Structure:  Dopa_c-1meta_ortho1-->

In [29]:
# `results` already has a 'names' column, so only 'alt_idx' is joined in
# (joined on the shared index).
alt_ind = alt_ind.drop(columns=['names'])
results = results.join(alt_ind)

# Add the manually matched terminus rows. DataFrame.append was deprecated in
# pandas 1.4 and removed in 2.0; pd.concat with ignore_index=True gives the same
# stacked frame with a fresh 0..n-1 index.
results = pd.concat([results, termini_df], ignore_index=True)

  results = results.append(termini_df)


In [48]:
# Persist the matched BDE table (relative path: written to the current working directory).
results.to_pickle('BDE_df_compl')

In [49]:
# Load the reaction table; the cells below read its 'rad_ref', 'rad_ref_idx',
# 'h_ref' and 'h_ref_idx' columns.
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
df_tidy_idx = pd.read_pickle('/hits/basement/mbm/riedmiki/structures/KR0008/df_tidy_pckl_220401_idx')

In [None]:
# Resolved reference path with its last six characters removed
# (presumably a '-N.pdb' style suffix - confirm against the reference file names).
ref_stems = []
for ref in results['ref']:
    resolved = str(ref.resolve())
    ref_stems.append(resolved[:-6])

ref_stem = pd.DataFrame({'ref_stem': ref_stems})
results = results.join(ref_stem)

In [None]:
def _collect_bde_matches(ref_paths, ref_indices, bde_table):
    """Look up the BDE entry for each (reference structure, atom index) pair.

    A row of ``bde_table`` matches when its 'ref_stem' equals the resolved
    reference path with the last six characters removed and the atom index is
    contained in the row's 'alt_idx' list. The first matching row is used.

    Args:
        ref_paths: Iterable of path objects providing ``resolve()``.
        ref_indices: Iterable of atom indices (converted with ``int``).
        bde_table: DataFrame with unique index labels and the columns
            'ref_stem', 'alt_idx', 'names', 'BDE_H', 'BDE_G' and 'ref_comp'.

    Returns:
        Four lists aligned with the inputs: names, BDE_H values, BDE_G values
        and ref_comp values. Pairs without a match get ``np.nan`` in all four.
    """
    names, bde_h, bde_g, ref_comps = [], [], [], []
    for ref, ref_idx in zip(ref_paths, ref_indices):
        stem = str(ref.resolve())[:-6]
        candidates = bde_table[bde_table['ref_stem'] == stem]
        # Filter on a plain list of labels instead of dropping rows in place
        # from a filtered slice of the table.
        matching = [
            label
            for label, alt in zip(candidates.index, candidates['alt_idx'])
            if int(ref_idx) in alt
        ]
        if matching:
            label = matching[0]
            # Label-based access, so correctness does not depend on labels
            # coinciding with row positions.
            names.append(bde_table.at[label, 'names'])
            bde_h.append(bde_table.at[label, 'BDE_H'])
            bde_g.append(bde_table.at[label, 'BDE_G'])
            ref_comps.append(bde_table.at[label, 'ref_comp'])
        else:
            names.append(np.nan)
            bde_h.append(np.nan)
            bde_g.append(np.nan)
            ref_comps.append(np.nan)
    return names, bde_h, bde_g, ref_comps


# Radical side of each reaction.
name_rad, BDEs_sorted_rad, BDEs_G_sorted_rad, ref_comp_rad = _collect_bde_matches(
    df_tidy_idx['rad_ref'], df_tidy_idx['rad_ref_idx'], results
)

# Hydrogen-donor side of each reaction.
name_H, BDEs_sorted_H, BDEs_G_sorted_H, ref_comp_H = _collect_bde_matches(
    df_tidy_idx['h_ref'], df_tidy_idx['h_ref_idx'], results
)

In [None]:
# One row per reaction; the per-side lists are all aligned with df_tidy_idx.
BDEs_df = pd.DataFrame({
    'rad_chem_name': name_rad,
    'H_chem_name': name_H,
    'rad_ref_comp': ref_comp_rad,
    'H_ref_comp': ref_comp_H,
    'rad_BDE': BDEs_sorted_rad,
    'H_BDE': BDEs_sorted_H,
    'rad_BDE_G': BDEs_G_sorted_rad,
    'H_BDE_G': BDEs_G_sorted_H,
})

In [None]:
# reactions for which a BDE is missing involve reactions with an amine group after a backbone 
# break, so let's pad these values with the BDE for a neutral N terminus
# NOTE(review): fillna with a scalar fills every column, so missing entries in the
# chem_name, ref_comp and BDE_G columns also receive this BDE_H value - confirm
# that this is intended.
BDEs_df = BDEs_df.fillna(N_term_BDE)

In [None]:
# Attach the BDE columns to the reaction table; the join aligns on the index of both frames.
complete = df_tidy_idx.join(BDEs_df)

In [None]:
# Persist the reaction table with BDEs (relative path: written to the current working directory).
complete.to_pickle('data_complete_compl')

In [None]:
BDE_data = pd.read_pickle('BDE_df_compl')

# Placeholder reference columns for the N-terminus row that is used for padding.
# NOTE(review): 'alt_idx' becomes the scalar 1 here, whereas other rows hold lists,
# and 'ref' is a plain string rather than a path - confirm these placeholders.
n_term_placeholder = pd.DataFrame({'ref': 'ref', 'ref_idx': 1, 'alt_idx': [1]})

# join() aligns on the index. The N-terminus row still carries its label from the
# originally loaded table, so it is renumbered to 0 to line up with the placeholder;
# otherwise the placeholder columns would be filled with NaN. Dropping previously
# joined placeholder columns first keeps this cell re-runnable.
N_term_BDE_row = (
    N_term_BDE_row
    .drop(columns=n_term_placeholder.columns, errors='ignore')
    .reset_index(drop=True)
    .join(n_term_placeholder)
)

BDE_data = pd.concat([BDE_data, N_term_BDE_row], ignore_index=True)

In [None]:
rad_BDE = complete['rad_BDE'].to_list()
H_BDE = complete['H_BDE'].to_list()


def _pdb_for_bde(bde_value):
    """Return column 4 of the first BDE_data row whose 'BDE_H' equals ``bde_value``.

    NOTE(review): rows are identified by exact float equality of the BDE, and the
    index label is used as a row position - both rely on BDE_data having a reset
    index and unambiguous BDE values. Column 4 is presumably 'pdb'; confirm.
    """
    first_label = BDE_data[BDE_data['BDE_H'] == bde_value].index[0]
    return BDE_data.iloc[first_label, 4]


rad_PDB = [_pdb_for_bde(value) for value in rad_BDE]
H_PDB = [_pdb_for_bde(value) for value in H_BDE]

In [None]:
# Descriptor table; later cells rely on its 'names', 'pdb' and 'SMILES' columns.
descriptors_df = pd.read_csv('/hits/fast/mbm/treydewk/optimized_test/descriptors.csv')


In [None]:
# Mordred errors
# A column is treated as failed when its value in the second row is a string.
# NOTE(review): only that one row is inspected - confirm this is representative.
to_drop = [
    column
    for column in descriptors_df.columns
    if isinstance(descriptors_df[column].to_list()[1], str)
]

# These identifier columns are legitimately textual and must be kept.
for text_column in ('names', 'pdb', 'SMILES'):
    to_drop.remove(text_column)

In [None]:
# Remove the flagged columns in place.
# NOTE(review): because this mutates descriptors_df, running the cell a second time
# raises a KeyError for the already-removed columns.
descriptors_df.drop(to_drop, axis=1, inplace=True)

In [None]:
### average descriptors for cis and trans conformers of proline and hydroxyproline
# Hyp
# NOTE(review): the cis entries are looked up with a lower-case 'hyp_' prefix and
# the trans entries with 'Hyp_' - confirm this matches the names in the CSV.
hyp_desc_sites = ('pyrr2', 'pyrr3', 'pyrr4_onC', 'pyrr4_onO', 'pyrr5')

# Snapshot the cis and trans row of every site before anything is modified.
hyp_desc_cis = {
    site: descriptors_df[descriptors_df['names'] == 'hyp_cis_' + site + '_1']
    for site in hyp_desc_sites
}
hyp_desc_trans = {
    site: descriptors_df[descriptors_df['names'] == 'Hyp_trans_' + site + '_1']
    for site in hyp_desc_sites
}

# The cis row is kept and overwritten with the average (label used as position).
hyp_desc_row = {site: hyp_desc_cis[site].index[0] for site in hyp_desc_sites}

# Columns from position 4 onwards are averaged; the first four are left untouched.
for i in range(4, descriptors_df.shape[1]):
    for site in hyp_desc_sites:
        cis_value = hyp_desc_cis[site].iloc[0, i]
        trans_value = hyp_desc_trans[site].iloc[0, i]
        descriptors_df.iloc[hyp_desc_row[site], i] = (cis_value + trans_value) / 2

# Remove the now redundant trans rows and renumber.
hyp_desc_trans_labels = [hyp_desc_trans[site].index[0] for site in hyp_desc_sites]
descriptors_df = descriptors_df.drop(index=hyp_desc_trans_labels).reset_index(drop=True)

In [None]:
# Pro
# NOTE(review): the cis entries are looked up with a lower-case 'pro_' prefix and
# the trans entries with 'Pro_' - confirm this matches the names in the CSV.
pro_desc_sites = ('pyrr2', 'pyrr3', 'pyrr4', 'pyrr5')

# Snapshot the cis and trans row of every site before anything is modified.
pro_desc_cis = {
    site: descriptors_df[descriptors_df['names'] == 'pro_cis_' + site + '_1']
    for site in pro_desc_sites
}
pro_desc_trans = {
    site: descriptors_df[descriptors_df['names'] == 'Pro_trans_' + site + '_1']
    for site in pro_desc_sites
}

# The cis row is kept and overwritten with the average (label used as position).
pro_desc_row = {site: pro_desc_cis[site].index[0] for site in pro_desc_sites}

# Columns from position 4 onwards are averaged; the first four are left untouched.
for i in range(4, descriptors_df.shape[1]):
    for site in pro_desc_sites:
        cis_value = pro_desc_cis[site].iloc[0, i]
        trans_value = pro_desc_trans[site].iloc[0, i]
        descriptors_df.iloc[pro_desc_row[site], i] = (cis_value + trans_value) / 2

# Remove the now redundant trans rows and renumber.
pro_desc_trans_labels = [pro_desc_trans[site].index[0] for site in pro_desc_sites]
descriptors_df = descriptors_df.drop(index=pro_desc_trans_labels).reset_index(drop=True)

In [None]:
def _descriptor_rows(pdb_names):
    """Return one descriptor row per entry of ``pdb_names``, in the same order.

    For each name the first row of ``descriptors_df`` whose 'pdb' value equals it
    is selected (an IndexError is raised if there is none). All rows are gathered
    with a single positional take instead of growing a frame row by row with
    DataFrame.append, which was removed in pandas 2.0 and is quadratic in the
    number of rows.

    The index labels are used as row positions, which relies on descriptors_df
    carrying a freshly reset 0..n-1 index.
    """
    positions = [
        descriptors_df[descriptors_df['pdb'] == pdb].index[0]
        for pdb in pdb_names
    ]
    return descriptors_df.iloc[positions].copy()


descriptors_sorted_rad = _descriptor_rows(rad_PDB)
descriptors_sorted_H = _descriptor_rows(H_PDB)

In [None]:
# Drop the first column (whatever it is named) together with 'names' from both tables.
# NOTE(review): re-running this cell would drop a further leading column.
first_column_rad = descriptors_sorted_rad.columns[0]
descriptors_sorted_rad = descriptors_sorted_rad.drop(columns=[first_column_rad, 'names'])

first_column_H = descriptors_sorted_H.columns[0]
descriptors_sorted_H = descriptors_sorted_H.drop(columns=[first_column_H, 'names'])

In [None]:
# Tag every descriptor column with the side of the reaction it describes.
# add_suffix renames all columns in one step, so a column whose new name equals
# another existing column name cannot be renamed a second time.
descriptors_sorted_rad = descriptors_sorted_rad.add_suffix('_rad')
descriptors_sorted_H = descriptors_sorted_H.add_suffix('_H')

In [None]:
# Give both descriptor tables the index of `complete` so the joins in the next cell
# line up row by row. This assumes the tables already have one row per row of
# `complete`, in the same order.
indices = complete.index
descriptors_sorted_rad = descriptors_sorted_rad.set_index(indices)
descriptors_sorted_H = descriptors_sorted_H.set_index(indices)

In [None]:
# Reaction table followed by the radical-side and then the H-side descriptor columns.
final_results = (
    complete
    .join(descriptors_sorted_rad)
    .join(descriptors_sorted_H)
)

In [None]:
# Persist the full table; the GNN preparation and training cells read this file back.
final_results.to_pickle('data_complete_w_descriptors_020422')

## Prepare data for GNNs

In [None]:
from kgcnn.utils.adj import coordinates_to_distancematrix, define_adjacency_from_distance, distance_to_gauss_basis, get_angle_indices, sort_edge_indices
from ase.io import read
import pandas as pd
import numpy as np
import re
from pathlib import Path
from itertools import chain

In [None]:
root = Path('/hits/basement/mbm/riedmiki/structures/KR0008/')

# NOTE(review): se_folders is assembled here but not used anywhere below in this cell.
se_folders = list((root/'traj').glob('batch*/se'))

se_folders = se_folders + [root/'start_end_prod_1', root/'start_end_prod_2', root/'start_end_prod_3', root/'start_end_prod_4', root/'start_end_prod_6',
root/'start_end_prod_7', root/'start_end_prod_8', root/'start_end_prod_9', root/'start_end_prod_10', root/'start_end_prod_11', root/'start_end_prod_intra_2']

# All structure files below root whose name contains '_1' or '_2' followed by any
# character and 'pdb' (the dot in the pattern is an unescaped wildcard).
pdb_files_all = [f for f in root.glob('**/*.pdb') if re.search('_[12].pdb', f.name)]

data = pd.read_pickle('data_complete_w_descriptors_020422')

# Expected file names of the two structures of every reaction.
hashes_se = []
for hash1, hash2 in zip(data['hash_u1'], data['hash_u2']):
    pair_prefix = str(hash1) + '_' + str(hash2)
    hashes_se.append([pair_prefix + '_1.pdb', pair_prefix + '_2.pdb'])

# Build the membership set once instead of re-flattening the list for every file.
wanted_names = set(chain.from_iterable(hashes_se))
pdb_files = [f for f in pdb_files_all if f.name in wanted_names]
directions = data['reaction'].to_list()

stems = [f.stem for f in pdb_files]

# First file found for every stem; mirrors list.index(), which also returns the
# first occurrence, but with constant-time lookup.
first_file_by_stem = {}
for f in pdb_files:
    first_file_by_stem.setdefault(f.stem, f)


def _structure_file(pdb_name):
    """Return the first collected file whose stem matches ``pdb_name`` without '.pdb'.

    Raises:
        ValueError: If no such file was collected.
    """
    stem = pdb_name[:-4]
    try:
        return first_file_by_stem[stem]
    except KeyError:
        raise ValueError('{!r} is not in list'.format(stem)) from None


pdb_files_sorted_start = []
pdb_files_sorted_end = []

for (name_1, name_2), direction in zip(hashes_se, directions):
    if direction == 1:
        start_name, end_name = name_1, name_2
    elif direction == 2:
        start_name, end_name = name_2, name_1
    else:
        # NOTE(review): any other direction value is skipped silently, which would
        # shift these lists relative to the rows of `data` - confirm only 1 and 2 occur.
        continue
    start_file = _structure_file(start_name)
    end_file = _structure_file(end_name)
    pdb_files_sorted_start.append(start_file)
    pdb_files_sorted_end.append(end_file)

all_nodes_start = []
all_pos_start = []

for file_start, file_end in zip(pdb_files_sorted_start, pdb_files_sorted_end):

    mol_start = read(str(file_start.resolve()))
    mol_end = read(str(file_end.resolve()))

    an_start = mol_start.get_atomic_numbers()
    pos_start = mol_start.positions
    pos_end = mol_end.positions

    # reacting H atom in its final position:
    # an extra node with atomic number 0 is prepended, placed at the position of
    # the first atom of the end structure
    nodes_start = np.concatenate((np.array([0]), an_start), axis=0)
    pos_start_compl = np.concatenate((np.array([pos_end[0]]), pos_start), axis=0)

    all_nodes_start.append(nodes_start)
    all_pos_start.append(pos_start_compl)

# create distance matrices, adjacency matrices and edge indices
dist_mat_start = [coordinates_to_distancematrix(x) for x in all_pos_start]
# One adjacency call per graph; element 0 is the adjacency matrix, element 1 the edge indices.
adjacency_results = [define_adjacency_from_distance(x) for x in dist_mat_start]
adj_mat_start = [result[0] for result in adjacency_results]
edge_idx_start = [result[1] for result in adjacency_results]

# The column name 'egde_idx_start' (sic) is the key under which the training cell
# reads the edge indices back, so the spelling must stay as it is.
graph_input = pd.DataFrame(
    zip(all_nodes_start, all_pos_start, edge_idx_start),
    columns = [
        'nodes_start', 'pos_start', 'egde_idx_start'
    ]
)

graph_input.to_pickle('graph_input_pickled_020422')

## Train and save final model

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow as tf
from kgcnn.utils.data import ragged_tensor_from_nested_numpy
from kgcnn.layers.conv.painn_conv import PAiNNUpdate, EquivariantInitialize
from kgcnn.layers.conv.painn_conv import PAiNNconv
from kgcnn.layers.geom import NodeDistanceEuclidean, BesselBasisLayer, EdgeDirectionNormalized, CosCutOffEnvelope, NodePosition
from kgcnn.layers.modules import LazyAdd, OptionalInputEmbedding
from kgcnn.layers.mlp import MLP
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.gather import GatherNodes, GatherNodesIngoing
from kgcnn.utils.adj import coordinates_to_distancematrix, define_adjacency_from_distance, distance_to_gauss_basis, get_angle_indices, sort_edge_indices

In [None]:
def painn(inputs=[{"shape": (None,), "name": "node_attributes", "dtype": "float32", "ragged": True},
                            {"shape": (None, 3), "name": "node_coordinates", "dtype": "float32", "ragged": True},
                            {"shape": (None, 2), "name": "edge_indices", "dtype": "int64", "ragged": True},
                            {'shape': (None, 2), 'name': 'node_radical_input', 'dtype': 'int64', 'ragged': True},
                            {'shape': (None, 1), 'name': 'edge_radical_input', 'dtype': 'int64', 'ragged': True},
                            {'shape': (25,), 'name': 'input_desc', 'dtype': 'float32', 'ragged': True}],
               input_embedding={"node": {"input_dim": 95, "output_dim": 128}},
               bessel_basis={"num_radial": 20, "cutoff": 5.0, "envelope_exponent": 5},
               depth=2,
               pooling_args={"pooling_method": "sum"},
               conv_args={"units": 128, "cutoff": None, "conv_pool": "sum"},
               update_args={"units": 128},
               output_mlp={"use_bias": [True, True, True, True, True],
                                "units": [512, 256, 128, 64, 1], "activation": ["swish", "swish", "swish", "swish", "linear"]}
               ):
    """Make a PAiNN-based graph network with a radical read-out via the functional API.

    After ``depth`` message/update rounds, node features are gathered at the
    indices given by the node-radical input and radial-basis edge features at
    the indices given by the edge-radical input. Both are concatenated, pooled,
    joined with the descriptor input and passed through two MLP blocks; the
    second block receives the first block's output concatenated to its input.

    NOTE(review): the default arguments are mutable objects shared between
    calls. They are only read here, never modified.

    Args:
        inputs : list
            List of six dictionaries unpacked in :obj:`tf.keras.layers.Input`, in this
            order: node attributes, node coordinates, edge indices, node radical
            indices, edge radical indices, descriptors.
        input_embedding : dict
            Dictionary of embedding arguments; the ``'node'`` entry is unpacked in
            `OptionalInputEmbedding`. The embedding is only enabled when the node
            input shape has fewer than two dimensions.
        bessel_basis : dict
            Dictionary of layer arguments unpacked in the `BesselBasisLayer` layer.
        depth : int
            Number of message/update rounds.
        pooling_args : dict
            Dictionary of layer arguments unpacked in `PoolingNodes` layer.
        conv_args : dict
            Dictionary of layer arguments unpacked in `PAiNNconv` layer. Its
            ``"cutoff"`` entry is also passed to `CosCutOffEnvelope`.
        update_args : dict
            Dictionary of layer arguments unpacked in `PAiNNUpdate` layer.
        output_mlp : dict
            Dictionary of layer arguments unpacked in both `MLP` blocks (two
            separate layer instances built from the same configuration).
            Defines number of model outputs and activation.

    Returns:
        tf.keras.models.Model
            Model taking the six inputs listed above and returning the output of
            the second MLP block.
    """

    # Make input
    node_input = tf.keras.layers.Input(**inputs[0])
    xyz_input = tf.keras.layers.Input(**inputs[1])
    bond_index_input = tf.keras.layers.Input(**inputs[2])
    node_radical_index = tf.keras.layers.Input(**inputs[3])
    edge_radical_index = tf.keras.layers.Input(**inputs[4])
    # Short aliases used in the read-out section below.
    eri_n = node_radical_index
    eri_e = edge_radical_index
    descriptors = tf.keras.layers.Input(**inputs[5])
    # Embed the node input only when its declared shape has fewer than two dimensions.
    z = OptionalInputEmbedding(**input_embedding['node'],
                               use_embedding=len(inputs[0]['shape']) < 2)(node_input)

    equiv_input = EquivariantInitialize(dim=3)(z)

    edi = bond_index_input
    x = xyz_input
    v = equiv_input

    # Geometric edge features derived from the coordinates and edge indices.
    pos1, pos2 = NodePosition()([x, edi])
    rij = EdgeDirectionNormalized()([pos1, pos2])
    d = NodeDistanceEuclidean()([pos1, pos2])
    # The envelope takes its cutoff from conv_args, the radial basis from bessel_basis.
    env = CosCutOffEnvelope(conv_args["cutoff"])(d)
    rbf = BesselBasisLayer(**bessel_basis)(d)

    # Each round adds the message and update results to the running features z and v.
    for i in range(depth):
        # Message
        ds, dv = PAiNNconv(**conv_args)([z, v, rbf, env, rij, edi])
        z = LazyAdd()([z, ds])
        v = LazyAdd()([v, dv])
        # Update
        ds, dv = PAiNNUpdate(**update_args)([z, v])
        z = LazyAdd()([z, ds])
        v = LazyAdd()([v, dv])
    n = z

    # Read-out: node features at the node-radical indices, radial-basis edge
    # features at the edge-radical indices.
    n_radical = GatherNodes()([n, eri_n])
    e_radical = GatherNodesIngoing()([rbf, eri_e])

    rad_embedd = tf.keras.layers.Concatenate(axis=-1)([n_radical, e_radical])
    rad_embedd = PoolingNodes(**pooling_args)(rad_embedd)

    # Append the descriptor input to the pooled embedding.
    out = tf.keras.layers.Concatenate()([rad_embedd, descriptors])

    # Two MLP blocks: the second one sees `out` together with the first block's output.
    initial_output = MLP(**output_mlp)(out)
    concat_input = tf.keras.layers.Concatenate(axis=-1)([out, initial_output])
    main_output = MLP(**output_mlp)(concat_input)

    model = tf.keras.models.Model(inputs=[
        node_input, xyz_input, bond_index_input, node_radical_index, edge_radical_index, descriptors
    ], outputs=main_output)
    
    return model

In [None]:
# Build the network under its own name so the `painn` factory function is not
# overwritten and this cell can be run again without re-defining it.
painn_model = painn()

data = pd.read_pickle('data_complete_w_descriptors_020422')
target = data['Ea'].to_numpy()
# 25 descriptor columns, matching the (25,) shape of the model's descriptor input.
descriptors = data[[
    'translation', 'rad_BDE', 'H_BDE', 'max_spin_rad',
    'mull_charge_rad', 'bur_vol_iso_rad', 'nBase_rad', 'SpMax_A_rad',
    'ATSC2s_rad', 'ATSC1Z_rad', 'ATSC2i_rad', 'NdNH_rad', 'SMR_VSA4_rad',
    'max_spin_H', 'mull_charge_H', 'bur_vol_iso_H', 'nBase_H', 'SpMax_A_H',
    'ATSC2s_H', 'ATSC1Z_H', 'ATSC2i_H', 'GATS2dv_H', 
    'BCUTdv-1h_H', 'SMR_VSA4_H', 'VSA_EState7_H'
]]
del data

descriptors = np.array(descriptors)

graph_input = pd.read_pickle('graph_input_pickled_020422')
nodes = graph_input['nodes_start'].to_numpy()
pos = graph_input['pos_start'].to_numpy()
# 'egde_idx_start' (sic) is the column name under which the edge indices were stored.
edge_idx = graph_input['egde_idx_start'].to_numpy()

# NOTE(review): these two lists are recomputed here but not consumed below.
dist_mat_start = [coordinates_to_distancematrix(x) for x in pos]
adj_mat_start = [define_adjacency_from_distance(x)[0] for x in dist_mat_start]

# Make sure every graph contains the edge pair between node 0 and node 1: if the
# first stored edge does not end in node 1, both directions are added and the
# edge list is re-sorted.
edge_lists = [
    x if x[0,1]==1 else sort_edge_indices(np.concatenate([np.array([[0, 1], [1,0]]), x], axis=0))
    for x in edge_idx
]
# The per-graph edge arrays differ in length. Fill an object array element by
# element; np.array() on such a ragged list raises on NumPy >= 1.24.
edge_idx = np.empty(len(edge_lists), dtype=object)
for graph_no, edges in enumerate(edge_lists):
    edge_idx[graph_no] = edges

# Every graph uses the same read-out indices: nodes 0 and 1, and edge 0.
node_radical_index = [np.array([[0, 1]]) for _ in edge_idx]  
edge_radical_index = [np.array([[0]]) for _ in node_radical_index]
node_radical_index = np.array(node_radical_index)
edge_radical_index = np.array(edge_radical_index)
del graph_input

# NOTE(review): decay_steps is written in terms of 15122 samples, an 80 % split and
# a batch size of 256, while fit() below uses batch_size=128 - confirm the schedule.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    0.001, decay_steps=15122*0.8/256*1000, decay_rate=1, staircase=False
)
optimizer=tf.keras.optimizers.Adam(lr_schedule)
# NOTE(review): restore_best_weights is left at its default, so the weights after
# early stopping are those of the last epoch rather than the best one - confirm.
callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20)]

painn_model.compile(loss='mae', optimizer=optimizer)

targets_train, targets_val, nodes_train, nodes_val, pos_train, pos_val, \
    edge_idx_train, edge_idx_val, eri_n_train, eri_n_val, \
        eri_e_train, eri_e_val, descriptors_train, descriptors_val = train_test_split(
            target, nodes, pos, edge_idx, node_radical_index, edge_radical_index, descriptors,
            test_size = 0.2, random_state=1
        )

# normalize descriptors (statistics come from the training split only)
means = descriptors_train.mean(axis=0)
stds = descriptors_train.std(axis=0)
epsilon = 1e-7
descriptors_train = (descriptors_train - means) / (stds + epsilon)
descriptors_val = (descriptors_val - means) / (stds + epsilon)

nodes_train, nodes_val = ragged_tensor_from_nested_numpy(nodes_train), ragged_tensor_from_nested_numpy(nodes_val)
pos_train, pos_val = ragged_tensor_from_nested_numpy(pos_train), ragged_tensor_from_nested_numpy(pos_val)
edge_idx_train, edge_idx_val = ragged_tensor_from_nested_numpy(edge_idx_train), ragged_tensor_from_nested_numpy(edge_idx_val)
eri_n_train, eri_n_val = ragged_tensor_from_nested_numpy(eri_n_train), ragged_tensor_from_nested_numpy(eri_n_val)
eri_e_train, eri_e_val = ragged_tensor_from_nested_numpy(eri_e_train), ragged_tensor_from_nested_numpy(eri_e_val)

# Input order must match the order of the model's Input layers.
data_train = nodes_train, pos_train, edge_idx_train, eri_n_train, eri_e_train, descriptors_train
data_val = nodes_val, pos_val, edge_idx_val, eri_n_val, eri_e_val, descriptors_val

history = painn_model.fit(data_train, targets_train,
            batch_size=128,
            epochs=5000,
            verbose=0,
            validation_data=(data_val, targets_val),
            callbacks=callbacks)

# The reported scores are the compiled loss, i.e. the mean absolute error.
train_score = painn_model.evaluate(data_train, targets_train, verbose=0, batch_size=128)
val_score = painn_model.evaluate(data_val, targets_val, verbose=0, batch_size=128)
print('Train score:', train_score, '\nValidation score:', val_score)

# painn_model.save('saved_models/painn_compl')


# Results
Batch size: 64, 5000 epochs.

Train score: 1.147171974182129 

Validation score: 4.122030735015869

------------------------------------------------------------------------------------------------
Batch size: 128, 5000 epochs.

Train score: 0.8191611766815186 

Validation score: 4.069589138031006

------------------------------------------------------------------------------------------------
Batch size: 256, 2000 epochs

Train score: 1.3984484672546387 

Validation score: 4.312300682067871