In [None]:
import sys
import os
from os.path import join
import numpy as np
import pickle
import random
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

sys.path.append('.\\additional_code')
from xgboost_training_KM import *

CURRENT_DIR = os.getcwd()
print(CURRENT_DIR)

## 1. Loading and preprocessing data:

In [2]:
df_train = pd.read_pickle(join(CURRENT_DIR, ".." ,"data","KM", "training_data_KM_new_with_unchecked_data.pkl"))
df_test = pd.read_pickle(join(CURRENT_DIR, ".." ,"data","KM", "test_data_KM_new_with_unchecked_data.pkl"))

df_train.rename(columns = {"KEGG ID" : "molecule ID"}, inplace = True)
df_test.rename(columns = {"KEGG ID" : "molecule ID"}, inplace = True)

df_train["Uniprot ID"] = ["Enzyme:train:" + str(ind) for ind in df_train.index]
df_test["Uniprot ID"] = ["Enzyme:test:" + str(ind) for ind in df_test.index]

### (a) Create dictionary with all target values

In [3]:
mol_files = list(set(df_train["molecule ID"])) + list(set(df_test["molecule ID"]))
mol_files = list(set(mol_files))

target_variable_dict_KM = {}
target_variable_dict_KM = create_target_dict_KM(df = df_train, target_variable_dict = target_variable_dict_KM)
target_variable_dict_KM = create_target_dict_KM(df = df_test, target_variable_dict = target_variable_dict_KM)

### (c) Get list with input combinations of Uniprot ID and metabolite ID

In [4]:
train_IDs = get_uid_cid_IDs(df_train)
test_IDs = get_uid_cid_IDs(df_test)

print(len(train_IDs), len(test_IDs))

7580 812


## 2. Calculating input matrices for metabolites

### (a) Creating input matrices:

In [5]:
calculate_atom_and_bond_feature_vectors(mol_files = mol_files)

In [6]:
for mol_ID in mol_files:
    calculate_and_save_input_matrixes(molecule_ID = mol_ID)

More than 70 (75) atoms in molcuele C21471
Could not create input for substrate ID C21471
More than 70 (90) atoms in molcuele C06509
Could not create input for substrate ID C06509
More than 70 (113) atoms in molcuele C06510
Could not create input for substrate ID C06510
More than 70 (77) atoms in molcuele C04702
Could not create input for substrate ID C04702
More than 70 (96) atoms in molcuele C02015
Could not create input for substrate ID C02015
More than 70 (130) atoms in molcuele C05893
Could not create input for substrate ID C05893
More than 70 (91) atoms in molcuele C00853
Could not create input for substrate ID C00853
More than 70 (91) atoms in molcuele C00541
Could not create input for substrate ID C00541


###  (b) Removing all datapoints without molecule input file:

In [7]:
valid_mols = os.listdir(join(CURRENT_DIR, ".." ,"data", "substrate_data_KM", "GNN_input_matrices"))
valid_mols = [mol.split("_A")[0] for mol in valid_mols]

df_train = df_train.loc[df_train["molecule ID"].isin(valid_mols)]
df_test = df_test.loc[df_test["molecule ID"].isin(valid_mols)]

train_IDs = get_uid_cid_IDs(df_train)
test_IDs = get_uid_cid_IDs(df_test)
df_train

Unnamed: 0,molecule ID,ESM1b,ECFP,RDKit FP,MACCS FP,PMID,MW,LogP,log10_KM,checked,GNN FP,Uniprot ID
0,C00387,"[0.100948475, 0.23829113, 0.0027401948, 0.0371...",0000000000000000000000000000000000000000010000...,1010111010101011101111011100011011000100100110...,0000000000000000000000000100000000000010000100...,17918964.0,283.091669,-2.6867,-0.728158,True,"[13.362184, 70.41528, 2.7784162, 60.74657, 0.0...",Enzyme:train:0
1,C00143,"[-0.09477718, 0.16472308, 0.09403025, 0.007433...",0100000000000000000000000000000000000000000000...,1111111000110111101101111011100011001011011110...,0000000000000000000000000100100000000010000100...,21858212.0,457.170981,-0.5219,-0.744727,True,"[18.767265, 151.67131, 25.194078, 66.9493, 0.6...",Enzyme:train:1
2,C00756,"[0.12043195, 0.17901447, -0.003300894, 0.07185...",0000000000000000000000000000000001000000000000...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,19383697.0,130.135765,2.3392,0.588832,True,"[0.053105697, 23.302288, 3.7088723, 5.9439626,...",Enzyme:train:2
3,C00002,"[0.068544716, 0.23684321, 0.080181114, -0.0251...",0000000001000000000000000000000000000000000000...,1010111010101011101011111000111010011100100111...,0000000000000000000000000000010000000010000100...,19509290.0,506.995745,-1.6290,-0.709965,True,"[15.331518, 103.84776, 6.569991, 63.609444, 0....",Enzyme:train:3
4,C00083,"[-0.062576994, 0.30821875, 0.101220384, -0.011...",0100000001000100000000000000000001000000010000...,1010111010101011101011111011111010011100111111...,0000000000000000000000000000010000000010000100...,17292360.0,853.115603,-1.8606,-2.246545,True,"[19.037132, 187.85568, 16.434797, 90.37692, 1....",Enzyme:train:4
...,...,...,...,...,...,...,...,...,...,...,...,...
7575,C20925,"[-0.12106511, 0.16286044, -0.05657043, 0.00162...",0100000000000010000000000000000001000000010000...,0000000000000011100011000011000011000000001000...,0000000000000000000000000000000000000000000000...,,390.175064,-2.1652,-1.000000,False,"[14.411318, 104.04242, 21.408749, 26.555807, 2...",Enzyme:train:7575
7576,C21181,"[-0.009757707, 0.1251226, 0.011750575, -0.0227...",0100000000000000000000000000000000000000000000...,0000000010000000000000000000000000000000000000...,0000000000000000000000000000000000000001100000...,,153.993594,-1.5660,-0.853872,False,,Enzyme:train:7576
7577,C21310,"[-0.0037425177, 0.06174834, -0.05052497, 0.063...",0000000000000100000000000000000000001000000000...,1011111011011111101111111101101011111111111111...,0000000000000000000000000100010000000010000100...,,522.990660,-2.9161,-3.102373,False,"[12.673787, 120.78082, 6.8376884, 88.71222, 0....",Enzyme:train:7577
7578,C21563,"[0.028141364, 0.16967583, -0.118034706, 0.1133...",0100010000000000000000000000000000000000010000...,0100000100001010100001000011001110000001001101...,0000000000000000000000000000000000000000000000...,,415.104936,-0.9324,-0.366532,False,"[10.382019, 105.68732, 18.721575, 34.91598, 2....",Enzyme:train:7578


### (c) Creating representations for the enzymes:

In [8]:
uids_list = list(set(df_train["Uniprot ID"])) + list(set(df_test["Uniprot ID"]))
uids_list = list(set(uids_list))
uid_to_emb = {}
embeddings = np.zeros((0,1280))
for uid in uids_list:
    try:
        emb = np.reshape(np.array(list(df_train["ESM1b"].loc[df_train["Uniprot ID"] == uid])[0]), (1,1280))
    except IndexError:
        try:
            emb = np.reshape(np.array(list(df_test["ESM1b"].loc[df_test["Uniprot ID"] == uid])[0]), (1,1280))
        except IndexError:
            emb = np.reshape(np.array(list(df_validation["ESM1b"].loc[df_validation["Uniprot ID"] == uid])[0]), (1,1280))
    embeddings = np.concatenate([embeddings, emb])
    uid_to_emb[uid] = emb

We perform a PCA an the enzyme representations to get 50-dimensional representations

In [9]:
from sklearn.decomposition import PCA
dim = 50

pca = PCA(n_components = dim)
pca.fit(embeddings)
emb_pca = pca.transform(embeddings)

#Calculate mean and std to normalize the PCA-transformed vectors
mean = np.mean(emb_pca, axis = 0)
std = np.std(emb_pca, axis = 0)

uid_to_pca_emb = {}

for i, uid in enumerate(uids_list):
    uid_to_pca_emb[uid] = (emb_pca[i] - mean) / std

In [10]:
uid_to_emb = uid_to_pca_emb

## 3. Training GNN:

###  (a) Defining a DataGenerator:

In [16]:
class CustomDataSet(Dataset):
    def __init__(self, split_IDs, folder):
        self.all_IDs = split_IDs
        self.folder = folder

    def __len__(self):
        return len(self.all_IDs)

    def __getitem__(self, idx):
        ID = self.all_IDs[idx]
        try:
            [uid,cid1, cid2] = ID.split("_") 
            cid = cid1 +"_"+cid2
        except ValueError:
            [uid,cid] = ID.split("_")
            
        XE = torch.tensor(np.load(join(self.folder, cid + '_XE.npy')), dtype = torch.float32)
        X = torch.tensor(np.load(join(self.folder, cid + '_X.npy')), dtype = torch.float32)
        A = torch.tensor(np.load(join(self.folder, cid + '_A.npy')), dtype = torch.float32)
        ESM1b = torch.tensor(uid_to_emb[uid], dtype = torch.float32)
        label = torch.tensor(target_variable_dict_KM[ID], dtype= torch.float32)
        return XE,X,A,ESM1b, label

### (b) Splitting the training set in a validation and a training set:

In [17]:
n = len(train_IDs) 
random.seed(1)
random.shuffle(train_IDs)
test_IDs = train_IDs[int(0.8*n):]
train_IDs = train_IDs[:int(0.8*n)]

In [18]:
batch_size = 64

train_dataset = CustomDataSet(folder  = join(CURRENT_DIR, ".." ,"data", "substrate_data_KM",
                                             "GNN_input_matrices"), split_IDs = train_IDs)
train_loader = DataLoader(train_dataset , batch_size=batch_size, shuffle=True, drop_last=True)

test_dataset = CustomDataSet(folder  = join(CURRENT_DIR, ".." ,"data", "substrate_data_KM",
                                            "GNN_input_matrices"), split_IDs = test_IDs)
test_loader = DataLoader(test_dataset , batch_size=batch_size, shuffle=False, drop_last=True)

In [19]:
n_train_batches = int(len(train_dataset)/batch_size)
n_test_batches = int(len(test_dataset)/batch_size)
train_batches = list(range(n_train_batches))
test_batches = list(range(n_test_batches))

### (c) Training GNN:

In [21]:
from sklearn.metrics import r2_score

In [23]:
import torch.optim as optim

model = GNN(D= 100, N = 70, F1 = 32 , F2 = 10, F = F1+F2).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay= 0.00001)

for epoch in range(10):  # loop over the dataset multiple times
    model.train()
    running_loss = 0.0
    for i, [XE, X, A,ESM1b, labels] in enumerate(train_loader):
        # zero the parameter gradients
        optimizer.zero_grad()
        XE, X, A, ESM1b, labels = XE.to(device), X.to(device), A.to(device),ESM1b.to(device), labels.to(device)
        # forward + backward + optimize
        outputs = model(XE, X, A, ESM1b)
        loss = criterion(outputs, labels.view((batch_size,-1)))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 20 == 19:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 20))
            running_loss = 0.0
            
    #After each epoch, calculate the validation loss:
    running_mse = 0.0
    running_r2 = 0.0
    running_loss = 0.0
    model.eval()
    for i, [XE, X, A,ESM1b, labels] in enumerate(test_loader):
        XE, X, A, ESM1b, labels = XE.to(device), X.to(device), A.to(device),ESM1b.to(device), labels.to(device)
        
        with torch.no_grad():
            outputs = model(XE, X, A, ESM1b)
        loss = criterion(outputs, labels.view((batch_size,-1)))
        running_loss += loss.item()

        outputs2 = outputs.view(-1).cpu().detach().numpy()
        labels2 = labels.cpu().detach().numpy()
        mse = np.mean((np.array(outputs2) - np.array(labels2))**2)
        R2 = r2_score(np.array(labels2), np.array(outputs2))
        running_mse += mse
        running_r2 += R2

    print("Epoch: %s, Val. loss: %s, Val. mse: %s, Val R2: %s" % (epoch, np.round(running_loss/(i+1),2),
                                                                  np.round(running_mse/(i+1), 2), 
                                                                 np.round(running_r2/(i+1))))

print('Finished Training')

[1,    20] loss: 1.585
[1,    40] loss: 1.272
[1,    60] loss: 1.098
Epoch: 0, Val. loss: 1.05, Val. mse: 1.05, Val R2: 0.0
[2,    20] loss: 1.021
[2,    40] loss: 0.965
[2,    60] loss: 0.989
Epoch: 1, Val. loss: 1.0, Val. mse: 1.0, Val R2: 0.0
[3,    20] loss: 0.930
[3,    40] loss: 0.974
[3,    60] loss: 0.894
Epoch: 2, Val. loss: 0.93, Val. mse: 0.93, Val R2: 0.0
[4,    20] loss: 0.857
[4,    40] loss: 0.882
[4,    60] loss: 0.933
Epoch: 3, Val. loss: 0.92, Val. mse: 0.92, Val R2: 0.0
[5,    20] loss: 0.826
[5,    40] loss: 0.885
[5,    60] loss: 0.835
Epoch: 4, Val. loss: 0.9, Val. mse: 0.9, Val R2: 0.0
[6,    20] loss: 0.862
[6,    40] loss: 0.820
[6,    60] loss: 0.782
Epoch: 5, Val. loss: 0.89, Val. mse: 0.89, Val R2: 0.0
[7,    20] loss: 0.815
[7,    40] loss: 0.849
[7,    60] loss: 0.808
Epoch: 6, Val. loss: 0.92, Val. mse: 0.92, Val R2: 0.0
[8,    20] loss: 0.770
[8,    40] loss: 0.735
[8,    60] loss: 0.828
Epoch: 7, Val. loss: 0.87, Val. mse: 0.87, Val R2: 0.0
[9,    20] l

In [25]:
torch.save(model.state_dict(),join(CURRENT_DIR, ".." ,"data", "substrate_data_KM", "GNN", "Pytorch_GNN_KM"))