In [None]:
import math,os,random
import numpy as np
from rdkit import Chem
from rdkit.Chem import MolFromSmiles,MolToSmiles
import argparse
from rdkit.Chem import Draw,AllChem,DataStructs,rdFMCS
from rdkit.Chem import Descriptors
from rdkit.Chem.Scaffolds import MurckoScaffold

from mol_generation import *
from utils import Penalized_logp, Similarity, DockingScore, prop_all
from utils import prepare_ligand, prepare_rep, calculate_center, mol2sdf
from tqdm import tqdm_notebook
import os

import re
import csv
from itertools import islice
import pickle
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.optim as optim

In [None]:
def mol_with_atom_index(mol):
    """Tag each atom with its own index as its atom-map number.

    Sets the 'molAtomMapNumber' property of every atom to str(atom index), so
    depictions/SMILES show atom indices. Mutates `mol` in place and returns it.
    """
    for atom_idx in range(mol.GetNumAtoms()):
        atom = mol.GetAtomWithIdx(atom_idx)
        atom.SetProp('molAtomMapNumber', str(atom.GetIdx()))
    return mol

In [None]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and CUDA) so
# data splits, dropout masks and weight initialisation are reproducible.
seed = 2
random.seed(seed)
# NOTE: setting PYTHONHASHSEED here does not change hash randomisation of the
# already-running interpreter; it only affects processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(seed)
for _seed_fn in (np.random.seed, torch.manual_seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# load data

In [None]:
# Load the preprocessed dataset: five objects are read in order; the first
# (input_x in earlier versions of this notebook) is read and discarded.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open("datafinal_0105.pkl", 'rb') as f:
    pickle.load(f)  # skip input_x
    input_y, all_mol, all_index, all_freq_list = [pickle.load(f) for _ in range(4)]

print(len(all_mol))

In [None]:
# Drop every molecule whose per-atom frequency list contains any of the values
# 9, 10, 11 or 12, keeping input_y / all_mol / all_index / all_freq_list aligned.
EXCLUDED_FREQS = (9, 10, 11, 12)
del_list = [
    i for i, freqs in enumerate(all_freq_list)
    if any(v in freqs for v in EXCLUDED_FREQS)
]

# Set membership keeps the filtering linear; `i not in del_list` on a list made
# each comprehension O(len(data) * len(del_list)).
del_set = set(del_list)
all_mol = [m for i, m in enumerate(all_mol) if i not in del_set]
all_index = [v for i, v in enumerate(all_index) if i not in del_set]
all_freq_list = [v for i, v in enumerate(all_freq_list) if i not in del_set]
input_y = [v for i, v in enumerate(input_y) if i not in del_set]

In [None]:
# Sanity check: after filtering, the per-molecule lists below should all have
# the same length.
# print(len(input_x))
print(len(input_y))
print(len(all_mol))
print(len(all_index))
print(len(all_freq_list))

# Inspect one sample. NOTE(review): hard-coded index; raises IndexError if
# fewer than 22458 molecules remain after filtering.
i = 22457
# print(input_x[i].shape)
print(input_y[i])
print(all_index[i])
print(all_freq_list[i])
all_mol[i]

In [None]:
# Store each atom's frequency value on the molecule itself by overwriting the
# atom's isotope field with the freq at that atom position (mutates all_mol).
for mol_idx, cur_mol in enumerate(all_mol):
    freqs = all_freq_list[mol_idx]
    for atom_idx in range(cur_mol.GetNumAtoms()):
        cur_mol.GetAtomWithIdx(atom_idx).SetIsotope(freqs[atom_idx])

In [None]:
# Turn each target atom index in all_index into a one-hot list whose length is
# the length of that molecule's freq list.
new_input_y_onehot = []
for mol_idx, target_atom in enumerate(all_index):
    onehot = [0] * len(all_freq_list[mol_idx])
    onehot[target_atom] = 1
    new_input_y_onehot.append(onehot)

# Model definition (new version)

In [None]:
def get_split_data(X, y, ratio=0.8):
    """Randomly partition paired sequences into train and test subsets.

    Uses the global NumPy RNG, so the split is reproducible only after
    ``np.random.seed``.

    Args:
        X: indexable sequence of samples.
        y: indexable sequence of labels, same length as X (y[k] labels X[k]).
        ratio: fraction in [0, 1] of samples assigned to training; the train
            size is ``round(ratio * len(X))``.

    Returns:
        (X_train, X_test, y_train, y_test) as lists.

    Raises:
        ValueError: if len(X) != len(y) or ratio is outside [0, 1].
    """
    n = len(X)
    if len(y) != n:
        raise ValueError("X and y must have the same length, got {} and {}".format(n, len(y)))
    if not 0 <= ratio <= 1:
        raise ValueError("ratio must be within [0, 1], got {}".format(ratio))

    permutation = np.random.choice(n, n, replace=False)
    train_size = int(np.round(ratio * n))
    train_idx, test_idx = permutation[:train_size], permutation[train_size:]

    X_train = [X[i] for i in train_idx]
    y_train = [y[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_test = [y[i] for i in test_idx]

    return (X_train, X_test, y_train, y_test)

In [None]:
def get_split_data_val(X, y, ratio=0.8, ratio2=0.9):
    """Randomly partition paired sequences into train / test / validation subsets.

    After one permutation from the global NumPy RNG, the first ``ratio`` fraction
    goes to train, the slice between ``ratio`` and ``ratio2`` goes to test, and the
    remainder goes to validation. Boundaries are ``round(ratio * n)`` and
    ``round(ratio2 * n)``.

    Args:
        X: indexable sequence of samples.
        y: indexable sequence of labels, same length as X (y[k] labels X[k]).
        ratio: end of the training fraction.
        ratio2: end of the test fraction (start of validation).

    Returns:
        (X_train, X_test, X_val, y_train, y_test, y_val) as lists.

    Raises:
        ValueError: if len(X) != len(y), or unless 0 <= ratio <= ratio2 <= 1
            (ratio2 < ratio would make the validation slice overlap training).
    """
    n = len(X)
    if len(y) != n:
        raise ValueError("X and y must have the same length, got {} and {}".format(n, len(y)))
    if not 0 <= ratio <= ratio2 <= 1:
        raise ValueError("need 0 <= ratio <= ratio2 <= 1, got ratio={}, ratio2={}".format(ratio, ratio2))

    permutation = np.random.choice(n, n, replace=False)
    train_size = int(np.round(ratio * n))
    val_size = int(np.round(ratio2 * n))  # index where the validation slice starts

    train_idx = permutation[:train_size]
    test_idx = permutation[train_size:val_size]
    val_idx = permutation[val_size:]

    X_train = [X[i] for i in train_idx]
    y_train = [y[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_test = [y[i] for i in test_idx]
    X_val = [X[i] for i in val_idx]
    y_val = [y[i] for i in val_idx]

    return (X_train, X_test, X_val, y_train, y_test, y_val)

In [None]:
def plot_metrics(epoch, training_losses, training_accuracies, validation_losses, validation_accuracies, testing_losses, testing_accuracies):
    """Plot per-epoch loss (left) and accuracy (right) curves side by side.

    Args:
        epoch: 0-based epoch index to highlight with a dashed vertical line
            (get_CLSmodel_val returns the epoch with the lowest validation loss
            here); pass None to skip the marker.
        training_losses, training_accuracies, validation_losses,
        validation_accuracies, testing_losses, testing_accuracies: one value per
            completed epoch. Each series is plotted against its own 1-based epoch
            axis, and empty series (e.g. no validation/test set) are skipped
            instead of crashing on mismatched lengths.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (loss_ax, 'Loss', ((training_losses, 'Training Loss'),
                           (validation_losses, 'Validation Loss'),
                           (testing_losses, 'Testing Loss'))),
        (acc_ax, 'Accuracy', ((training_accuracies, 'Training Accuracy'),
                              (validation_accuracies, 'Validation Accuracy'),
                              (testing_accuracies, 'Testing Accuracy'))),
    )

    for ax, metric, series in panels:
        for values, label in series:
            if len(values) == 0:
                continue
            ax.plot(np.arange(1, len(values) + 1), values, label=label)
        if epoch is not None:
            # Epoch axis is 1-based while `epoch` is a 0-based loop index.
            ax.axvline(epoch + 1, color='gray', linestyle='--', label='Best epoch')
        ax.set_title('{} Over Epochs'.format(metric))
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
# Sizes passed to MolecularPredictionNetwork: atom_input_size is the width of
# each per-atom feature row fed to AtomPredictionNetwork's first Linear layer.
atom_input_size, atom_hidden_size = 119, 300
# AtomPredictionNetwork's final layer width is fixed at 1 and does not read this.
atom_output_size = 1
num_epochs = 100

In [None]:
from model1224 import *
from mol_generation1224 import *

class AtomPredictionNetwork(nn.Module):
    """MLP head mapping each atom's feature vector to one scalar score.

    Layout: Linear(input -> hidden) + ReLU, then (num_layers - 1) blocks of
    Linear(hidden -> hidden) + Dropout(0.2) + ReLU, then Linear(hidden -> 1).
    Accepts any input whose last dimension is ``input_size``.
    """

    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        stack = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        for _ in range(num_layers - 1):
            stack += [nn.Linear(hidden_size, hidden_size), nn.Dropout(p=0.2), nn.ReLU()]
        stack.append(nn.Linear(hidden_size, 1))
        # The attribute name `layers` fixes the state_dict keys ("layers.N.weight").
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)

class MolecularPredictionNetwork(nn.Module):
    """Score every atom of every molecule in a batch.

    A GIN-based ``GenerativeModel2`` embeds each molecule's atoms and the shared
    ``AtomPredictionNetwork`` head maps each atom's feature row to one scalar.
    """

    def __init__(self, atom_input_size, atom_hidden_size):
        super(MolecularPredictionNetwork, self).__init__()
        # atom_input_size must equal the per-atom feature width produced by self.gnn.
        self.atom_network = AtomPredictionNetwork(atom_input_size, atom_hidden_size)
        self.gnn = GenerativeModel2(num_layer=5, emb_dim=300, gnn_type="gin")

    def forward(self, mol_list):
        """Return a list holding one 1-D tensor of per-atom scores per molecule.

        Args:
            mol_list: sequence of molecules accepted by ``new_mol_to_graph_data``.

        Returns:
            list of tensors; entry i has one score per row of the GNN output for
            mol_list[i].
        """
        molecule_scores_list = []
        for cur_mol in mol_list:
            graph_feature = new_mol_to_graph_data(cur_mol)
            # assumes the GNN returns a 2-D (n_atoms, atom_input_size) tensor — TODO confirm
            atom_features = self.gnn(graph_feature)
            # Score all atoms in one batched call rather than one call per row into a
            # CPU-allocated zeros buffer; the result keeps the features' device/dtype.
            molecule_scores_list.append(self.atom_network(atom_features).squeeze(-1))
        return molecule_scores_list

In [None]:
# Instantiate once to display the architecture; get_CLSmodel_val builds its own
# fresh instance for training.
model = MolecularPredictionNetwork(atom_input_size, atom_hidden_size)
model

In [None]:
def calculate_accuracy(y_pred, y, k=5):
    """Fraction of molecules whose true atom is among the top-k predicted atoms.

    Args:
        y_pred: sequence of 1-D score tensors, one per molecule.
        y: sequence of one-hot target tensors aligned with y_pred; the true atom
            is ``y[i].argmax()``.
        k: how many highest-scoring atoms count as a hit (default 5).

    Returns:
        float in [0, 1].

    Raises:
        ValueError: if y_pred is empty.
    """
    if len(y_pred) == 0:
        raise ValueError("calculate_accuracy received an empty batch")
    hits = 0
    for i, scores in enumerate(y_pred):
        true_atom = int(y[i].argmax())
        # Indices of the k highest scores (fewer if the molecule has < k atoms).
        top_k = scores.argsort(descending=True)[:k].tolist()
        hits += true_atom in top_k
    return hits / len(y_pred)

In [None]:
def calculate_loss(y_pred, y):
    """Mean per-molecule cross-entropy between atom scores and one-hot targets.

    Each y_pred[i] is a score vector over molecule i's atoms and y[i] the
    matching one-hot target; both are cast to float, so CrossEntropyLoss treats
    y[i] as class probabilities. An empty batch raises ZeroDivisionError.
    """
    criterion = nn.CrossEntropyLoss()
    per_molecule = [criterion(y_pred[i].float(), y[i].float()) for i in range(len(y_pred))]
    return sum(per_molecule) / len(per_molecule)

# Training with validation and test sets

In [None]:
# Slice bounds applied to the molecule and label lists. end=None keeps every
# molecule through the last one (end=-1 would silently drop the final molecule).
star = 0
end = None

inputs = all_mol[star:end]
labels = [torch.tensor(one_hot) for one_hot in new_input_y_onehot[star:end]]

In [None]:
# 80% train, next 10% test, final 10% validation.
X_train, X_test, X_val, y_train, y_test, y_val = get_split_data_val(inputs, labels, ratio=0.8, ratio2=0.9)
print(*map(len, (X_train, X_test, X_val, y_train, y_test, y_val)))

In [None]:
def _evaluate_split(mlp, X, y, batchsize):
    """Mini-batch evaluation of `mlp` on (X, y) without gradient tracking.

    Returns (mean batch loss, mean batch accuracy). The caller must put `mlp`
    in eval mode first.
    """
    losses, accs = [], []
    n_batches = int(math.ceil(len(X) / batchsize))
    with torch.no_grad():
        for b in range(n_batches):
            batch = X[b * batchsize:(b + 1) * batchsize]
            target = y[b * batchsize:(b + 1) * batchsize]
            pred = mlp.forward(batch)
            losses.append(calculate_loss(pred, target).item())
            accs.append(calculate_accuracy(pred, target))
    return np.mean(losses), np.mean(accs)


def get_CLSmodel_val(X_train, y_train, X_val, y_val, X_test, y_test, num_epoch=500, batch_size=128, lambda_reg=0.01, learning_rate=0.001, patience=5):
    """Train a fresh MolecularPredictionNetwork with Adam and validation early stopping.

    Each epoch trains over X_train in fixed order (no per-epoch shuffling), then
    evaluates on the test set (if X_test is not None) and the validation set (if
    X_val is not None). Training stops once the validation loss has not improved
    for more than `patience` epochs.

    Args:
        X_train, y_train: training molecules and one-hot atom targets.
        X_val, y_val: validation split used for early stopping (or X_val=None).
        X_test, y_test: test split, monitored only (or X_test=None).
        num_epoch: maximum number of epochs.
        batch_size: molecules per mini-batch.
        lambda_reg: currently unused; kept for backward compatibility.
        learning_rate: Adam learning rate.
        patience: epochs without validation improvement tolerated before stopping.

    Returns:
        (model, best_epoch, training_losses, training_accuracies,
         validation_losses, validation_accuracies, testing_losses,
         testing_accuracies). best_epoch is the 0-based epoch with the lowest
        validation loss; the returned model holds the weights of the last epoch.
    """
    batchsize = batch_size
    train_bs = int(math.ceil(len(X_train) / batchsize))

    mlp = MolecularPredictionNetwork(atom_input_size, atom_hidden_size)
    optimizer = optim.Adam(mlp.parameters(), lr=learning_rate)

    training_losses, training_accuracies = [], []
    validation_losses, validation_accuracies = [], []
    testing_losses, testing_accuracies = [], []

    best_val_loss = float('inf')
    iepoch = 0
    for i in range(num_epoch):
        print("---------------------------------- Epoch --------------------------------------", i)
        print("************** Training **************")
        # Evaluation below switches to eval mode; re-enable dropout for every epoch.
        mlp.train()
        epoch_training_loss = 0.0
        epoch_training_accuracy = 0.0

        for k in range(train_bs):
            batch = X_train[k * batchsize:(k + 1) * batchsize]
            target = y_train[k * batchsize:(k + 1) * batchsize]

            optimizer.zero_grad()
            pred = mlp.forward(batch)
            loss = calculate_loss(pred, target)
            loss.backward()
            optimizer.step()

            epoch_training_loss += loss.item()
            epoch_training_accuracy += calculate_accuracy(pred, target)

        epoch_training_loss /= train_bs
        epoch_training_accuracy /= train_bs
        training_losses.append(epoch_training_loss)
        training_accuracies.append(epoch_training_accuracy)
        print("epoch: {}, loss: {}, acc: {}".format(i, epoch_training_loss, epoch_training_accuracy))

        if X_test is not None:
            print("************** Testing **************")
            mlp.eval()
            avg_test_loss, avg_test_acc = _evaluate_split(mlp, X_test, y_test, batchsize)
            testing_losses.append(avg_test_loss)
            testing_accuracies.append(avg_test_acc)
            print("Testing loss: {}, acc: {}".format(avg_test_loss, avg_test_acc))

        if X_val is not None:
            print("************** Validation **************")
            mlp.eval()
            avg_val_loss, avg_val_acc = _evaluate_split(mlp, X_val, y_val, batchsize)
            validation_losses.append(avg_val_loss)
            validation_accuracies.append(avg_val_acc)
            print("Validation loss: {}, acc: {}".format(avg_val_loss, avg_val_acc))

            # Early stopping on validation loss.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                iepoch = i
            if i - iepoch > patience:
                break

    return mlp, iepoch, training_losses, training_accuracies, validation_losses, validation_accuracies, testing_losses, testing_accuracies

In [None]:
# Train with validation early stopping; kepoch is the 0-based epoch with the
# lowest validation loss. The epoch budget comes from the hyperparameter cell.
mlp, kepoch, training_losses, training_accuracies, validation_losses, validation_accuracies, testing_losses, testing_accuracies = \
    get_CLSmodel_val(X_train, y_train, X_val, y_val, X_test, y_test,
                     num_epoch=num_epochs, batch_size=128, lambda_reg=0.01, learning_rate=0.001)

In [None]:
# Per-epoch loss/accuracy curves for the train, validation and test splits.
plot_metrics(kepoch, training_losses, training_accuracies, validation_losses, validation_accuracies, testing_losses, testing_accuracies)

In [None]:
# Split molecules with their raw target atom indices.
# NOTE(review): get_split_data_val draws a new permutation from the global NumPy
# RNG, so this split will in general NOT match the X_train/X_test/X_val split
# used for training above — confirm whether alignment is intended.
X_train_mol , X_test_mol, X_val_mol, y_train_idx , y_test_idx, y_val_index = get_split_data_val(all_mol[star:end], all_index[star:end], 0.8, 0.9)