# Imports and setup

In [2]:
import os

# Change the working directory so that the relative config and data paths
# used below resolve.
# A raw string is used because '\S', '\D' and '\m' are not valid escape
# sequences; Python 3.12+ reports them as a SyntaxWarning in a normal string
# literal. The resulting path is character-for-character the same as before.
# NOTE(review): this absolute path is machine-specific — consider making it
# configurable before sharing the notebook.
os.chdir(r'D:\S_Cat\DL\molflux_git')

from itertools import chain
from pathlib import Path
from typing import Optional, List, Tuple, Union, Any, Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xgboost
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.manifold import TSNE
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.preprocessing import StandardScaler

from src.data_processing.data_loader import DataLoader
from src.molecular_descriptors.descriptors.morfeus_descriptor import MorfeusDescriptor
from src.utils.config import get_config
from src.utils.functool import get_mol_indices
from src.utils.molecule_operator import merge_intermediates, merge_products

%matplotlib inline
plt.style.use('seaborn-v0_8-paper')
plt.rc('font', weight='bold')

In [3]:
# Build the project data loader from the shared configuration and read the data.
config = get_config()
loader = DataLoader(config)
loader.load_data()

# Keep only the magnitude of label_ee; its sign is discarded from here on.
loader._raw_data['label_ee'] = loader._raw_data['label_ee'].abs()

# Remove duplicates: for every (com_1, com_2, com_cat) combination keep the
# single row with the largest label_ee. The surviving row labels are sorted so
# the original row order is preserved before the index is rebuilt.
# NOTE(review): later cells iterate `loader.data`, not `loader._raw_data` —
# confirm that `loader.data` reflects this de-duplicated table.
duplicate_key = ['com_1', 'com_2', 'com_cat']
best_row_labels = loader._raw_data.groupby(duplicate_key)['label_ee'].idxmax()
loader._raw_data = (
    loader._raw_data
    .loc[sorted(best_row_labels)]
    .reset_index(drop=True)
)

# Pull the individual views of the data set from the loader.
metadata = loader.get_metadata()

# Species data comes back as a triple: columns, SMILES and calculators.
species_cols, species_smiles, species_calculators = loader.get_species_data()

# Reaction conditions.
condition_data = loader.get_condition_data()

# Labels.
label_data = loader.get_label_data()

# Morgan Fingerprint

In [4]:
mol_registry = pd.read_csv('data/calculators/mol_list.csv')

# Morgan fingerprint settings: bit-vector length and neighbourhood radius.
n_bits = 1024
radius = 3

inter_des_cols = [f'IM_{i}' for i in range(n_bits)]
product_des_cols = [f'Product_{i}' for i in range(n_bits)]
ddg_col = ['ddg']
transfer_dataset = []


def _morgan_bits(smiles, radius, n_bits):
    """Return the Morgan bit vector of a molecule given as a SMILES string.

    Args:
        smiles: SMILES string of the molecule.
        radius: Morgan neighbourhood radius.
        n_bits: Length of the folded bit vector.

    Returns:
        The RDKit bit vector produced by ``GetMorganFingerprintAsBitVect``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``. ``Chem.MolFromSmiles``
            signals a parse failure by returning ``None``, which would
            otherwise surface as an opaque argument-type error inside the
            fingerprint call.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f'RDKit could not parse SMILES: {smiles!r}')
    # NOTE(review): recent RDKit releases appear to flag this function as
    # deprecated in favour of rdFingerprintGenerator.GetMorganGenerator —
    # confirm and verify identical bits before migrating.
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)


# One feature row per reaction: intermediate bits, product bits, then ddg.
for _, row in loader.data.iterrows():
    # Intermediate: merge com_2 with the catalyst and look it up in the registry.
    com_inter = merge_intermediates(row['com_2'], row['com_cat'])
    inter_idx = get_mol_indices(com_inter, mol_registry)[1][0]
    inter_des = _morgan_bits(mol_registry.loc[inter_idx, 'smiles'], radius, n_bits)

    # Product: merge product_base with com_2 and look it up in the registry.
    com_product = merge_products(row['product_base'], row['com_2'])
    product_idx = get_mol_indices(com_product, mol_registry)[1][0]
    product_des = _morgan_bits(mol_registry.loc[product_idx, 'smiles'], radius, n_bits)

    # Target value derived from |label_ee| and con_temperature1.
    # Presumably ddG = R*T*ln((1 + ee) / (1 - ee)) with the temperature given in
    # Celsius and the result converted from J/mol to kcal/mol — TODO confirm units.
    # NOTE(review): |label_ee| == 100 makes the denominator zero.
    ee_fraction = abs(row['label_ee']) / 100
    ddg = [(8.314 * (row['con_temperature1'] + 273.15) * np.log((1 + ee_fraction) / (1 - ee_fraction))) / 4184]

    transfer_dataset.append(list(chain(inter_des, product_des, ddg)))



In [5]:
# Assemble the fingerprint table: intermediate columns, product columns, then ddg.
fp_columns = list(chain(inter_des_cols, product_des_cols, ddg_col))
dataset_fp = pd.DataFrame(transfer_dataset, columns=fp_columns, dtype=float)

# Standardise every column except the last one (ddg), which is left untouched.
# NOTE(review): the scaler is fitted on the whole table — if a train/test split
# happens later, confirm this does not leak test statistics.
scaler = StandardScaler()
scaled_features = scaler.fit_transform(dataset_fp.iloc[:, :-1])
dataset_fp.iloc[:, :-1] = scaled_features

dataset_fp

Unnamed: 0,IM_0,IM_1,IM_2,IM_3,IM_4,IM_5,IM_6,IM_7,IM_8,IM_9,...,Product_1015,Product_1016,Product_1017,Product_1018,Product_1019,Product_1020,Product_1021,Product_1022,Product_1023,ddg
0,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,2.341975,0.0,0.008685
1,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,2.341975,0.0,0.352778
2,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,2.341975,0.0,0.435554
3,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,2.341975,0.0,0.315830
4,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,2.341975,0.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
209,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,4.772607,-0.426990,0.0,1.697295
210,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,-0.426990,0.0,1.938683
211,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,-0.426990,0.0,1.472363
212,-0.097129,0.197066,-0.197066,-0.264575,-0.363696,-0.097129,0.0,0.0,-0.119239,0.0,...,0.0,0.0,-0.119239,-0.221404,-0.221404,-0.595341,-0.209529,-0.426990,0.0,1.415376


# Unimol

In [9]:
mol_registry = pd.read_csv('data/calculators/mol_list.csv')
inter_des_dataset = pd.read_csv('data/calculators/optimizers/unimol/raw_data_inter_v.csv')
product_des_dataset = pd.read_csv('data/calculators/optimizers/unimol/raw_data_product_s.csv')

# Every column after the first is treated as a descriptor and gets a prefix
# (the first column is presumably the 'mol_index' key — TODO confirm).
inter_des_cols = 'IM_' + inter_des_dataset.columns[1:]
product_des_cols = 'Product_' + product_des_dataset.columns[1:]

# The cell that builds `dataset_unimol` needs `ddg_col`. Define it in this
# section, as the other descriptor sections do, instead of relying on a value
# left in the kernel by a different section.
ddg_col = ['ddg']

transfer_dataset = []

# One feature row per reaction: intermediate descriptors, product descriptors, then ddg.
for _, row in loader.data.iterrows():
    # Intermediate: merge com_2 with the catalyst, resolve its registry index and
    # take the first matching descriptor row (minus the leading column).
    com_inter = merge_intermediates(row['com_2'], row['com_cat'])
    inter_idx = get_mol_indices(com_inter, mol_registry)[1][0]
    inter_des = inter_des_dataset[inter_des_dataset['mol_index'] == inter_idx].values[0][1:]

    # Product: same lookup against the product descriptor table.
    com_product = merge_products(row['product_base'], row['com_2'])
    product_idx = get_mol_indices(com_product, mol_registry)[1][0]
    product_des = product_des_dataset[product_des_dataset['mol_index'] == product_idx].values[0][1:]

    # Target value derived from |label_ee| and con_temperature1.
    # Presumably ddG = R*T*ln((1 + ee) / (1 - ee)) with the temperature given in
    # Celsius and the result converted from J/mol to kcal/mol — TODO confirm units.
    ee_fraction = abs(row['label_ee']) / 100
    ddg = [(8.314 * (row['con_temperature1'] + 273.15) * np.log((1 + ee_fraction) / (1 - ee_fraction))) / 4184]

    transfer_dataset.append(list(chain(inter_des, product_des, ddg)))

In [10]:
# Assemble the Uni-Mol table: intermediate columns, product columns, then ddg.
unimol_columns = list(chain(inter_des_cols, product_des_cols, ddg_col))
dataset_unimol = pd.DataFrame(transfer_dataset, columns=unimol_columns, dtype=float)

# Standardise every column except the last one (ddg), which is left untouched.
scaler = StandardScaler()
scaled_features = scaler.fit_transform(dataset_unimol.iloc[:, :-1])
dataset_unimol.iloc[:, :-1] = scaled_features

dataset_unimol

Unnamed: 0,IM_0,IM_1,IM_2,IM_3,IM_4,IM_5,IM_6,IM_7,IM_8,IM_9,...,Product_759,Product_760,Product_761,Product_762,Product_763,Product_764,Product_765,Product_766,Product_767,ddg
0,0.530284,-0.088286,0.086139,0.075797,0.000203,0.171243,-0.096752,-0.100565,0.381315,0.420819,...,-1.018849,-1.484706,0.326135,-1.919743,0.925217,1.136834,0.450484,0.261357,-0.614082,0.008685
1,0.750725,-1.555876,0.894305,0.042022,-0.886936,-0.175608,0.150251,0.751107,1.064364,0.233236,...,-1.018849,-1.484706,0.326135,-1.919743,0.925217,1.136834,0.450484,0.261357,-0.614082,0.352778
2,-0.629945,-0.233179,-0.484730,0.495893,-0.163752,0.811876,-0.107152,-0.706991,-0.004820,-0.199688,...,-1.018849,-1.484706,0.326135,-1.919743,0.925217,1.136834,0.450484,0.261357,-0.614082,0.435554
3,-0.122469,-0.379503,-0.221538,0.313055,0.363620,0.907601,0.380655,0.441350,0.486269,0.710422,...,-1.018849,-1.484706,0.326135,-1.919743,0.925217,1.136834,0.450484,0.261357,-0.614082,0.315830
4,0.241262,-0.563214,0.825832,0.529376,-0.002014,-0.303327,0.215769,1.291266,1.262006,0.608786,...,-1.018849,-1.484706,0.326135,-1.919743,0.925217,1.136834,0.450484,0.261357,-0.614082,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
386,0.841913,-1.570067,-0.632231,-0.769054,-2.109380,1.467026,-1.221665,-0.359151,-1.972196,-1.514197,...,0.494704,0.100591,-0.760473,-0.357386,0.278732,0.163532,0.852378,-0.380506,-0.934191,0.572981
387,0.346669,-1.806607,-0.723755,-0.925540,-1.964396,1.542761,-1.646060,-0.102382,-2.425199,-1.735600,...,0.494704,0.100591,-0.760473,-0.357386,0.278732,0.163532,0.852378,-0.380506,-0.934191,0.734607
388,-0.332129,0.377146,-0.439182,0.600874,-0.217716,0.477270,-0.273767,-0.672700,-0.202238,-0.751760,...,0.494704,0.100591,-0.760473,-0.357386,0.278732,0.163532,0.852378,-0.380506,-0.934191,0.671777
389,0.038818,-0.797087,-0.379537,0.093163,-0.201857,0.576847,-0.961537,-0.456968,-0.097388,-0.806565,...,0.494704,0.100591,-0.760473,-0.357386,0.278732,0.163532,0.852378,-0.380506,-0.934191,0.521415


# Morfeus

In [6]:
mol_registry = pd.read_csv('data/calculators/mol_list.csv')
inter_des_dataset = pd.read_csv('data/calculators/descriptors/morfeus/raw_data_inter_chg1_xtb_intermediate.csv')
product_des_dataset = pd.read_csv('data/calculators/descriptors/morfeus/raw_data_product_s_xtb_product_s.csv')

# Descriptor columns start at offset metadata.shape[1] + 1 — presumably the CSVs
# hold one extra leading column followed by the metadata columns; TODO confirm.
descriptor_start_idx = metadata.shape[1] + 1
inter_des_cols = 'IM_' + inter_des_dataset.columns[descriptor_start_idx:]
product_des_cols = 'Product_' + product_des_dataset.columns[descriptor_start_idx:]
ddg_col = ['ddg']

transfer_dataset = []
transfer_metadata = []

# One feature row per reaction, using the first descriptor row found for each molecule.
for _, row in loader.data.iterrows():
    # Se-containing compounds are currently NOT skipped; the filter is kept disabled:
    # if 'Se' in row['com_2'] or 'Se' in row['com_cat']:
    #     continue

    # Intermediate: resolve the registry index, then take the first matching row.
    com_inter = merge_intermediates(row['com_2'], row['com_cat'])
    inter_idx = get_mol_indices(com_inter, mol_registry)[1][0]
    inter_matches = inter_des_dataset[inter_des_dataset['index'] == inter_idx]
    inter_des = inter_matches.values[0][descriptor_start_idx:]

    # Product: same lookup against the product descriptor table.
    com_product = merge_products(row['product_base'], row['com_2'])
    product_idx = get_mol_indices(com_product, mol_registry)[1][0]
    product_matches = product_des_dataset[product_des_dataset['index'] == product_idx]
    product_des = product_matches.values[0][descriptor_start_idx:]

    # Target value derived from |label_ee| and con_temperature1.
    ee_fraction = abs(row['label_ee']) / 100
    ddg = [(8.314 * (row['con_temperature1'] + 273.15) * np.log((1 + ee_fraction) / (1 - ee_fraction))) / 4184]

    transfer_dataset.append(list(chain(inter_des, product_des, ddg)))
    # Remember which reaction ID each feature row belongs to.
    transfer_metadata.append(row['ID'])

In [7]:
# Assemble the Morfeus table: intermediate columns, product columns, then ddg.
morfeus_columns = list(chain(inter_des_cols, product_des_cols, ddg_col))
dataset_morfeus = pd.DataFrame(transfer_dataset, columns=morfeus_columns)

# Standardise every column except the last one (ddg), which is left untouched.
scaler = StandardScaler()
scaled_features = scaler.fit_transform(dataset_morfeus.iloc[:, :-1])
dataset_morfeus.iloc[:, :-1] = scaled_features

dataset_morfeus

Unnamed: 0,IM_area,IM_volume,IM_P_bv,IM_S_cat_bv,IM_S_S_bv,IM_S_cat_darea,IM_S_cat_pint,IM_S_S_darea,IM_S_S_pint,IM_S_cat_P,...,Product_S_sarea,Product_S_svol,Product_S_toC1_Lvalue,Product_S_toC1_B1value,Product_S_toC1_B5value,Product_S_toC2_Lvalue,Product_S_toC2_B1value,Product_S_toC2_B5value,Product_S_charge,ddg
0,-0.833509,-0.995285,-0.534082,-1.099627,-0.516550,0.673135,-1.698068,0.533280,-1.459036,0.656226,...,-0.460393,-0.513940,-1.170500,-0.823927,1.869959,2.187683,0.129829,-1.461302,1.343890,0.008685
1,-0.829646,-0.985986,-1.450911,-0.322112,-1.111342,0.691723,-0.323518,1.422839,-1.127034,-0.502269,...,-0.460393,-0.513940,-1.170500,-0.823927,1.869959,2.187683,0.129829,-1.461302,1.343890,0.352778
2,-0.519034,-0.493099,0.403724,1.607775,1.248712,-1.295073,-0.334774,-0.745461,0.560810,-0.531386,...,-0.460393,-0.513940,-1.170500,-0.823927,1.869959,2.187683,0.129829,-1.461302,1.343890,0.435554
3,-0.280945,-0.363175,-0.125869,0.938287,-0.477941,-0.161531,0.742236,0.700073,-0.042578,1.647746,...,-0.460393,-0.513940,-1.170500,-0.823927,1.869959,2.187683,0.129829,-1.461302,1.343890,0.315830
4,-1.379383,-1.341461,-2.015150,0.344202,-0.846745,-0.109818,-0.185315,1.478437,-0.011101,-0.910004,...,-0.460393,-0.513940,-1.170500,-0.823927,1.869959,2.187683,0.129829,-1.461302,1.343890,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
209,0.069181,0.354442,-0.034657,0.383681,1.411505,-0.777124,0.442731,-1.496026,1.400807,-0.994492,...,-0.976149,-1.054103,-1.164548,0.294382,-0.388388,0.176891,2.442160,-0.792096,-0.455351,1.697295
210,1.059990,1.250997,0.794024,0.852094,0.823820,-1.450458,0.237118,-1.273636,0.289487,-0.089840,...,-0.937040,-0.984965,0.351385,0.301634,-0.392734,0.176803,2.464622,0.169632,-0.471975,1.938683
211,0.638332,0.843898,0.786011,0.836480,0.814024,-1.372766,0.341709,-1.357032,0.161935,-0.094581,...,-0.937040,-1.006174,0.350697,0.291860,-0.404945,0.175939,2.494797,0.032960,-0.386539,1.472363
212,-0.081243,-0.085826,0.109349,2.164285,1.354263,-1.217381,1.452569,-1.106844,1.197040,-0.189370,...,0.522233,0.759765,0.422595,-0.236381,-0.574237,-0.516668,-0.794064,0.115387,-0.733454,1.415376


# Morfeus with conformer data augmentation (N = 3, 5, 10, 20)

In [8]:
mol_registry = pd.read_csv('data/calculators/mol_list.csv')
inter_des_dataset = pd.read_csv('data/calculators/descriptors/morfeus/raw_data_inter_chg1_xtb_intermediate.csv')
product_des_dataset = pd.read_csv('data/calculators/descriptors/morfeus/raw_data_product_s_xtb_product_s.csv')

# ------------------------------------------------------------
# Number of descriptor rows taken per molecule.
# Values listed for this section: 3, 5, 10, 20.
# ------------------------------------------------------------
N_CONFORMERS = 3

# Descriptor columns start at offset metadata.shape[1] + 1 — presumably the CSVs
# hold one extra leading column followed by the metadata columns; TODO confirm.
descriptor_start_idx = metadata.shape[1] + 1
inter_des_cols = 'IM_' + inter_des_dataset.columns[descriptor_start_idx:]
product_des_cols = 'Product_' + product_des_dataset.columns[descriptor_start_idx:]
ddg_col = ['ddg']

processed_data = []

for idx, row in loader.data.iterrows():
    # Se-containing compounds are currently NOT skipped; the filter is kept disabled:
    # if 'Se' in row['com_2'] or 'Se' in row['com_cat']:
    #     continue

    # Resolve the registry indices of the intermediate and of the product.
    com_inter = merge_intermediates(row['com_2'], row['com_cat'])
    inter_idx = get_mol_indices(com_inter, mol_registry)[1][0]

    com_product = merge_products(row['product_base'], row['com_2'])
    product_idx = get_mol_indices(com_product, mol_registry)[1][0]

    # Target value derived from |label_ee| and con_temperature1; it is shared by
    # every conformer combination of this reaction.
    ee_fraction = abs(row['label_ee']) / 100
    ddg = [(8.314 * (row['con_temperature1'] + 273.15) * np.log((1 + ee_fraction) / (1 - ee_fraction))) / 4184]

    # Up to N_CONFORMERS descriptor rows per molecule (fewer if the table holds fewer).
    inter_matches = inter_des_dataset[inter_des_dataset['index'] == inter_idx]
    intermediate_conformer_data = inter_matches.iloc[:N_CONFORMERS].values[:, descriptor_start_idx:]

    product_matches = product_des_dataset[product_des_dataset['index'] == product_idx]
    product_conformer_data = product_matches.iloc[:N_CONFORMERS].values[:, descriptor_start_idx:]

    # Cross every intermediate row with every product row; each combination
    # becomes one sample tagged with the reaction's row label and ID.
    for inter_conf in intermediate_conformer_data:
        for product_conf in product_conformer_data:
            processed_data.append(
                ([idx, row['ID']], list(chain(inter_conf, product_conf, ddg)))
            )

# Split the (metadata, features) pairs into two parallel lists.
transfer_metadata_3 = [meta for meta, _ in processed_data]
transfer_dataset = [features for _, features in processed_data]

aug_columns = list(chain(inter_des_cols, product_des_cols, ddg_col))
dataset_morfeus_aug_3 = pd.DataFrame(transfer_dataset, columns=aug_columns)

# Standardise every column except the last one (ddg), which is left untouched.
scaler = StandardScaler()
scaled_features = scaler.fit_transform(dataset_morfeus_aug_3.iloc[:, :-1])
dataset_morfeus_aug_3.iloc[:, :-1] = scaled_features

# Keep the sample -> reaction mapping as a table with one row per sample.
transfer_metadata_3 = pd.DataFrame(transfer_metadata_3, columns=['idx', 'ID'])

In [9]:
# Display the augmented Morfeus table (the rendered view is truncated by pandas).
dataset_morfeus_aug_3

Unnamed: 0,IM_area,IM_volume,IM_P_bv,IM_S_cat_bv,IM_S_S_bv,IM_S_cat_darea,IM_S_cat_pint,IM_S_S_darea,IM_S_S_pint,IM_S_cat_P,...,Product_S_sarea,Product_S_svol,Product_S_toC1_Lvalue,Product_S_toC1_B1value,Product_S_toC1_B5value,Product_S_toC2_Lvalue,Product_S_toC2_B1value,Product_S_toC2_B5value,Product_S_charge,ddg
0,-0.848932,-0.996012,-0.512547,-1.112726,-0.547722,0.691677,-1.631557,0.556399,-1.456017,0.613467,...,-0.471405,-0.530941,-1.229016,-0.832130,1.822434,2.163218,-0.030208,-1.473961,1.454611,0.008685
1,-0.848932,-0.996012,-0.512547,-1.112726,-0.547722,0.691677,-1.631557,0.556399,-1.456017,0.613467,...,-0.018197,0.282961,-1.228565,-1.092483,1.944980,2.043602,1.167801,-1.565309,0.882271,0.008685
2,-0.848932,-0.996012,-0.512547,-1.112726,-0.547722,0.691677,-1.631557,0.556399,-1.456017,0.613467,...,-0.310510,-0.436818,-1.229030,-0.649946,1.724016,2.096733,0.276699,-1.451546,1.376592,0.008685
3,-0.880235,-1.036277,-1.323217,0.095746,-1.109029,0.121190,-0.826024,1.343034,-1.487643,2.735874,...,-0.471405,-0.530941,-1.229016,-0.832130,1.822434,2.163218,-0.030208,-1.473961,1.454611,0.008685
4,-0.880235,-1.036277,-1.323217,0.095746,-1.109029,0.121190,-0.826024,1.343034,-1.487643,2.735874,...,-0.018197,0.282961,-1.228565,-1.092483,1.944980,2.043602,1.167801,-1.565309,0.882271,0.008685
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1921,-0.632518,-0.516511,-0.482505,-0.580593,1.293948,-0.138122,-0.386751,-0.832963,1.266009,0.095749,...,0.263028,-0.081608,-1.560244,1.212194,-1.206498,-0.788272,-0.861836,-1.595928,0.935495,1.002258
1922,-0.632518,-0.516511,-0.482505,-0.580593,1.293948,-0.138122,-0.386751,-0.832963,1.266009,0.095749,...,1.306989,1.250853,0.345039,-0.570496,-1.266396,-0.793179,-1.221278,0.188348,0.503629,1.002258
1923,-0.497200,-0.493926,0.374636,-0.043152,-0.009338,-0.215916,0.137765,-0.020500,0.371799,0.347494,...,0.236008,0.083108,0.330150,-1.191512,-1.329684,-0.791822,-1.332750,0.204549,2.581916,1.002258
1924,-0.497200,-0.493926,0.374636,-0.043152,-0.009338,-0.215916,0.137765,-0.020500,0.371799,0.347494,...,0.263028,-0.081608,-1.560244,1.212194,-1.206498,-0.788272,-0.861836,-1.595928,0.935495,1.002258
