In [None]:
import pandas as pd
import random
import dgl
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dgllife.model import model_zoo
from dgllife.utils import smiles_to_bigraph
from dgllife.utils import EarlyStopping, Meter
from dgllife.utils import AttentiveFPAtomFeaturizer
from dgllife.utils import AttentiveFPBondFeaturizer

import torch
import os
import random
import numpy as np
import ast

import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import pandas as pd
from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from IPython.display import SVG, display
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import pickle
import argparse
from rdkit import RDLogger 
import warnings
# Quiet the notebook output: RDKit's own logger first, then Python's warning
# machinery. Note the latter is a blanket filter -- it silences every Python
# warning, including numpy RuntimeWarnings and sklearn metric warnings.
RDLogger.DisableLog('rdApp.*')  # switch off RDKit warning messages
warnings.simplefilter("ignore")

In [None]:
from utils import get_values_at_positions, atom_finder, smiles_augmentation, concat_feature_reactive_atom, collate_molgraphs, Canon_SMILES_similarity
from model import AttentiveFPPredictor_rxn, weighted_binary_cross_entropy

In [None]:
# AttentiveFP featurizers from dgllife: atom features are stored under the
# node field 'hv' and bond features under the edge field 'he'.
atom_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv')
n_feats = atom_featurizer.feat_size('hv')  # width of the atom feature vector

bond_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he')
e_feats = bond_featurizer.feat_size('he')  # width of the bond feature vector

print('Number of features in graph : ', n_feats)

In [None]:
# Seed every RNG the notebook draws from (Python `random`, NumPy, PyTorch) so
# that model initialisation, dropout and DataLoader shuffling are repeatable.
# torch.manual_seed seeds all devices, CUDA included. Bitwise reproducibility
# on GPU is still not guaranteed because some CUDA kernels are nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Assign device: first CUDA GPU if available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

In [None]:
# Load the elementary-step reaction table and show it (pandas truncates the
# displayed repr of long frames by default).
ELEMENTARY_STEPS_CSV = 'elementary_step_100000.csv'
df_elementary = pd.read_csv(ELEMENTARY_STEPS_CSV)
df_elementary

In [None]:
# Contiguous 70 / 10 / 20 split in file order: 20% test first, then 12.5% of
# the remaining 80% (= 10% of the total) for validation. With shuffle=False,
# sklearn does not use random_state; it is kept for the case where shuffling
# is switched on.
TEST_FRACTION = 0.2
VALID_FRACTION_OF_REST = 0.125
split_options = dict(random_state=42, shuffle=False)

train_datasets_, test_datasets = train_test_split(
    df_elementary, test_size=TEST_FRACTION, **split_options)
train_datasets, valid_datasets = train_test_split(
    train_datasets_, test_size=VALID_FRACTION_OF_REST, **split_options)

In [None]:
# Give every split a fresh 0..n-1 RangeIndex. Because the split above is not
# shuffled, train_datasets already starts at 0 and the reset is a no-op for it
# today. Resetting all three keeps the splits consistent, so downstream
# helpers see the same kind of index if the split is ever shuffled.
train_datasets = train_datasets.reset_index(drop=True)
valid_datasets = valid_datasets.reset_index(drop=True)
test_datasets = test_datasets.reset_index(drop=True)

In [None]:
# Convert each split into the per-reaction records consumed by
# graph_generation and concat_feature_reactive_atom. augmentation=False is
# passed explicitly. The meaning of the positional 5 is defined in
# utils.smiles_augmentation (not visible here) -- TODO confirm it is ignored
# when augmentation is off.
train_augm_smiles, valid_augm_smiles, test_augm_smiles = (
    smiles_augmentation(split_df, 5, augmentation=False)
    for split_df in (train_datasets, valid_datasets, test_datasets)
)

In [None]:
n_total_steps = sum(len(split) for split in (train_augm_smiles, valid_augm_smiles, test_augm_smiles))
print('Total number of reaction steps after SMILES augmentation : ', n_total_steps)

In [None]:
def graph_generation(df_augm_smiles):
    """Build one graph per record with dgllife's ``smiles_to_bigraph``.

    Only the first field of each record (``record[0]``, a SMILES string) is
    used. Graphs are featurized with the module-level ``atom_featurizer`` and
    ``bond_featurizer`` and built with ``canonical_atom_order=False``, so
    atoms keep the order in which they appear in the SMILES.

    Args:
        df_augm_smiles: indexable sequence of records, e.g. the output of
            ``smiles_augmentation``.

    Returns:
        list with one ``smiles_to_bigraph`` result per record, in input order.
    """
    return [
        smiles_to_bigraph(
            df_augm_smiles[idx][0],
            node_featurizer=atom_featurizer,
            edge_featurizer=bond_featurizer,
            canonical_atom_order=False,
        )
        for idx in range(len(df_augm_smiles))
    ]

In [None]:
# One graph per reaction record, for each split (train, valid, test in order).
train_graph_for_rxn, valid_graph_for_rxn, test_graph_for_rxn = map(
    graph_generation,
    (train_augm_smiles, valid_augm_smiles, test_augm_smiles),
)

In [None]:
# Pair each split's graphs with its records. concat_feature_reactive_atom
# (utils) presumably appends the per-atom reactive label to the 'hv' node
# features -- the epoch functions read a label back from column n_feats.
train_graph_dataset, valid_graph_dataset, test_graph_dataset = (
    concat_feature_reactive_atom(graphs, records)
    for graphs, records in (
        (train_graph_for_rxn, train_augm_smiles),
        (valid_graph_for_rxn, valid_augm_smiles),
        (test_graph_for_rxn, test_augm_smiles),
    )
)

In [None]:
# Mini-batch loaders; collate_molgraphs (utils) assembles each batch.
# The training loader reshuffles every epoch, so consecutive optimizer steps
# are not tied to the CSV row order (the split above uses shuffle=False).
# Evaluation loaders keep a fixed order.
BATCH_SIZE = 256
train_loader = DataLoader(train_graph_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_molgraphs)
valid_loader = DataLoader(valid_graph_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collate_molgraphs)
test_loader = DataLoader(test_graph_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=collate_molgraphs)


In [None]:
# Out-of-distribution (OOD) data loader preparation

In [None]:
# Out-of-distribution set, pushed through the same pipeline as the main splits.
# NOTE(review): the variable name says 1000, but the file name says 3647 steps.
OOD_CSV = 'OOD_elementary_step_3647.csv'
df_ood_1000 = pd.read_csv(OOD_CSV)
print(df_ood_1000)

ood_augm_smiles = smiles_augmentation(df_ood_1000, 5, augmentation=False)
ood_graph_for_rxn = graph_generation(ood_augm_smiles)
ood_graph_dataset = concat_feature_reactive_atom(ood_graph_for_rxn, ood_augm_smiles)
ood_loader = DataLoader(
    ood_graph_dataset,
    batch_size=256,
    shuffle=False,
    collate_fn=collate_molgraphs,
)


In [None]:
# Modify the model to fit your classification task
# n_tasks is the width of the graph-level output that the epoch functions
# feed to CrossEntropyLoss (presumably 8 reaction classes -- confirm against
# the label column of the CSV).
model_hparams = dict(
    node_feat_size=n_feats,  # atom features only; the label column is sliced off before the forward pass
    edge_feat_size=e_feats,
    num_layers=2,
    num_timesteps=1,
    graph_feat_size=200,
    n_tasks=8,
    dropout=0.1,
)
model = AttentiveFPPredictor_rxn(**model_hparams)



In [None]:
# Move parameters to the selected device. nn.Module.to works in place and
# returns the module, whose repr becomes this cell's output.
model.to(device=device)

In [None]:
# Define loss function and optimizer.
# loss_fn_graph: CrossEntropyLoss on the graph-level logits.
# loss_fn_node: passed to the epoch functions as loss_criterion2; the node loss
#   in those functions is computed with weighted_binary_cross_entropy instead.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-6

loss_fn_graph = nn.CrossEntropyLoss()
loss_fn_node = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [None]:
def run_a_train_epoch(n_epochs, epoch, model, data_loader, loss_criterion1, loss_criterion2, optimizer,
                      threshold=0.5):
    """Train ``model`` for one pass over ``data_loader`` and print metrics.

    Each batch is ``(smiles, bg, labels)``. The node field ``bg.ndata['hv']``
    is popped from the batched graph. Its first ``n_feats`` columns (module
    global) are the model's atom features, and column ``n_feats`` is the
    per-atom reactive-atom label. The model returns graph-level logits
    (``prediction1``), per-atom scores (``prediction2``, thresholded at
    ``threshold``) and a graph embedding. The optimized loss is the graph
    classification loss plus a class-weighted node loss.

    Args:
        n_epochs: total number of epochs (only used in the progress message).
        epoch: zero-based index of this epoch (printed as ``epoch + 1``).
        model: network called as ``model(bg, node_feats, edge_feats)``.
        data_loader: iterable of ``(smiles, bg, labels)`` batches.
        loss_criterion1: graph-level loss, called as
            ``loss_criterion1(logits, class_indices)``.
        loss_criterion2: unused; the node loss is
            ``weighted_binary_cross_entropy``. Kept for call compatibility.
        optimizer: optimizer over ``model``'s parameters.
        threshold: score cut-off above which an atom is counted as reactive.

    Returns:
        ``(accuracy, total_loss, labels, prediction1, y_true_node,
        y_pred_node, model)``: ``total_loss`` is the mean per-batch graph
        loss, ``labels``/``prediction1`` are from the last batch, and
        ``y_true_node``/``y_pred_node`` hold one entry per atom.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    losses = []          # graph-classification loss per batch
    loss_node_app = []   # node (reactive-atom) loss per batch
    y_true, y_pred = [], []
    y_true_node, y_pred_node = [], []

    for batch_data in data_loader:
        _smiles, bg, labels = batch_data

        bg = bg.to(device)
        labels = labels.to(device)
        n_feats_w_l = bg.ndata.pop('hv').to(device)
        e_feats_ = bg.edata.pop('he').to(device)
        n_feats_ = n_feats_w_l[:, :n_feats]
        prediction1, prediction2, _graph_feat = model(bg, n_feats_, e_feats_)
        n_labels = n_feats_w_l[:, n_feats].unsqueeze(1)

        # Inverse-frequency weights for atom classes 0/1, normalised to sum to
        # 1. minlength=2 keeps two entries even if a batch has no reactive
        # atoms, and an absent class gets weight 0 rather than 1/0 = inf, so
        # no inf/NaN reaches the loss.
        counts = torch.bincount(n_labels.view(-1).long(), minlength=2).float()
        inv_freq = torch.where(counts > 0, 1.0 / counts, torch.zeros_like(counts))
        class_weights = inv_freq / inv_freq.sum()

        loss_graph = loss_criterion1(prediction1, labels.squeeze(1).long())
        loss_node = weighted_binary_cross_entropy(prediction2, n_labels, class_weights)
        loss = loss_graph + loss_node

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss_graph.item())
        loss_node_app.append(loss_node.item())
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(prediction1.detach().cpu().numpy())
        y_true_node.extend(n_labels.cpu().numpy())
        y_pred_node.extend(prediction2.detach().cpu().numpy())

    if not losses:
        raise ValueError('data_loader yielded no batches')

    total_loss = np.mean(losses)
    total_loss_node = np.mean(loss_node_app)

    y_pred_class = np.argmax(y_pred, axis=1)
    accuracy = accuracy_score(y_true, y_pred_class)
    print('F1 score classification task:', f1_score(y_true, y_pred_class, average='macro'))

    # Per-atom binary predictions from the scores.
    y_true_flat = np.concatenate(y_true_node)
    y_pred_binary = (np.concatenate(y_pred_node) >= threshold).astype(np.float32)
    accuracy_node = accuracy_score(y_true_flat, y_pred_binary)
    print('F1 score reactive atom task:', f1_score(y_true_flat, y_pred_binary, average='macro'))

    print('epoch {:d}/{:d},train_acc_classification {:.4f},train_node_acc {:.4f},train_loss {:.4f},train_node_loss {:.4f}'.format(
        epoch + 1, n_epochs, accuracy, accuracy_node, total_loss, total_loss_node))
    return accuracy, total_loss, labels, prediction1, y_true_node, y_pred_node, model

In [None]:
def run_a_valid_epoch(n_epochs, epoch, model, data_loader, loss_criterion1, loss_criterion2, threshold=0.5):
    """Evaluate ``model`` on ``data_loader`` (no gradient updates) and print metrics.

    Batches are unpacked exactly as in training. ``bg.ndata['hv']`` is popped,
    its first ``n_feats`` columns (module global) feed the model, and column
    ``n_feats`` is the per-atom reactive-atom label. The model runs in
    ``eval()`` mode under ``torch.no_grad()``.

    Args:
        n_epochs: total number of epochs (only used in the progress message).
        epoch: zero-based index of this epoch (printed as ``epoch + 1``).
        model: network called as ``model(bg, node_feats, edge_feats)``.
        data_loader: iterable of ``(smiles, bg, labels)`` batches.
        loss_criterion1: graph-level loss, called as
            ``loss_criterion1(logits, class_indices)``.
        loss_criterion2: unused; the node loss is
            ``weighted_binary_cross_entropy``. Kept for call compatibility.
        threshold: score cut-off above which an atom is counted as reactive.

    Returns:
        ``(accuracy, total_loss, labels, prediction1, y_true_node,
        y_pred_node, model)``: ``total_loss`` is the mean per-batch graph
        loss, ``labels``/``prediction1`` are from the last batch, and
        ``y_true_node``/``y_pred_node`` hold one entry per atom.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    losses = []          # graph-classification loss per batch
    loss_node_app = []   # node (reactive-atom) loss per batch
    y_true, y_pred = [], []
    y_true_node, y_pred_node = [], []

    with torch.no_grad():
        for batch_data in data_loader:
            _smiles, bg, labels = batch_data

            bg = bg.to(device)
            labels = labels.to(device)
            n_feats_w_l = bg.ndata.pop('hv').to(device)
            e_feats_ = bg.edata.pop('he').to(device)
            n_feats_ = n_feats_w_l[:, :n_feats]
            prediction1, prediction2, _graph_feat = model(bg, n_feats_, e_feats_)
            n_labels = n_feats_w_l[:, n_feats].unsqueeze(1)

            # Inverse-frequency weights for atom classes 0/1, normalised to
            # sum to 1. minlength=2 keeps two entries even if a batch has no
            # reactive atoms, and an absent class gets weight 0 instead of
            # 1/0 = inf.
            counts = torch.bincount(n_labels.view(-1).long(), minlength=2).float()
            inv_freq = torch.where(counts > 0, 1.0 / counts, torch.zeros_like(counts))
            class_weights = inv_freq / inv_freq.sum()

            loss_graph = loss_criterion1(prediction1, labels.squeeze(1).long())
            loss_node = weighted_binary_cross_entropy(prediction2, n_labels, class_weights)

            losses.append(loss_graph.item())
            loss_node_app.append(loss_node.item())
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(prediction1.cpu().numpy())
            y_true_node.extend(n_labels.cpu().numpy())
            y_pred_node.extend(prediction2.cpu().numpy())

    if not losses:
        raise ValueError('data_loader yielded no batches')

    total_loss = np.mean(losses)
    total_loss_node = np.mean(loss_node_app)

    y_pred_class = np.argmax(y_pred, axis=1)
    accuracy = accuracy_score(y_true, y_pred_class)
    print('F1 score classification task:', f1_score(y_true, y_pred_class, average='macro'))

    # Per-atom binary predictions from the scores.
    y_true_flat = np.concatenate(y_true_node)
    y_pred_binary = (np.concatenate(y_pred_node) >= threshold).astype(np.float32)
    accuracy_node = accuracy_score(y_true_flat, y_pred_binary)
    print('F1 score reactive atom task:', f1_score(y_true_flat, y_pred_binary, average='macro'))

    print('epoch {:d}/{:d},valid_acc_classification {:.4f}, valid_node_acc {:.4f},valid_loss {:.4f},valid_node_loss {:.4f}'.format(
        epoch + 1, n_epochs, accuracy, accuracy_node, total_loss, total_loss_node))
    return accuracy, total_loss, labels, prediction1, y_true_node, y_pred_node, model

In [None]:
import time

# Train for up to n_epochs. After each epoch, the validation classification
# accuracy is passed to dgllife's EarlyStopping (mode='higher': larger is
# better). EarlyStopping writes the best weights to a checkpoint file (default
# name/location chosen by dgllife) and signals a stop after `patience` epochs
# without improvement. The best checkpoint is restored afterwards, so later
# evaluations use the best-validation weights rather than the last epoch's.
st_time = time.time()
stopper = EarlyStopping(mode='higher', patience=5)
n_epochs = 5
for e in range(n_epochs):
    accuracy, total_loss, labels, prediction, y_true_node, y_pred_node, train_model = run_a_train_epoch(
        n_epochs, e, model, train_loader, loss_fn_graph, loss_fn_node, optimizer)
    accuracy_, total_loss_, labels_, prediction_, y_true_node_, y_pred_node_, train_model_ = run_a_valid_epoch(
        n_epochs, e, model, valid_loader, loss_fn_graph, loss_fn_node)
    if stopper.step(accuracy_, model):
        print('early stopping at epoch {:d}/{:d}'.format(e + 1, n_epochs))
        break
stopper.load_checkpoint(model)
en_time = time.time()
print('time required:', (en_time-st_time)/60)  # minutes

In [None]:
# Test accuracy calculation: a single evaluation pass. epoch is zero-based, so
# epoch=0 with n_epochs=1 makes the log line read "epoch 1/1".
accuracy_, total_loss_, labels_, prediction_, y_true_node_, y_pred_node_, train_model_ = run_a_valid_epoch(
    1, 0, model, test_loader, loss_fn_graph, loss_fn_node)

In [None]:
# OOD accuracy calculation: a single evaluation pass. epoch is zero-based, so
# epoch=0 with n_epochs=1 makes the log line read "epoch 1/1".
accuracy_, total_loss_, labels_, prediction_, y_true_node_, y_pred_node_, train_model_ = run_a_valid_epoch(
    1, 0, model, ood_loader, loss_fn_graph, loss_fn_node)

In [None]:
# To save the trained model, uncomment the two lines below and run this cell.
#fn = 'final_trained_ReactAIvate_model'
#torch.save(model.state_dict(), fn)