In [1]:
import warnings
from pathlib import Path

import tensorflow as tf
warnings.filterwarnings("ignore")
tf.get_logger().setLevel('ERROR')

import numpy as np
import pandas as pd
import deepchem as dc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from sklearn import metrics




Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
class Net(nn.Module):
    """Graph-level regressor: four stacked GCN layers, global max pooling, then a 3-layer MLP.

    The attribute names (conv1..conv4, fc1..fc3, dropout1, dropout2) must not change:
    pretrained checkpoints are restored with `load_state_dict`, which matches parameters by
    these names, and the training code replaces `fc3` and freezes `conv1`..`conv4` by name.

    Args:
        in_channels: node-feature width of the input graphs. The default of 30 presumably
            matches DeepChem's MolGraphConvFeaturizer output — TODO confirm.
        hidden_channels: width of every GCN layer and of `fc1`.
        head_hidden: width of `fc2`.
        out_channels: number of regression outputs.
        conv_dropout: dropout probability applied after the first GCN layer only.
        head_dropout: dropout probability applied after `fc1`.
    """

    def __init__(self, in_channels=30, hidden_channels=256, head_hidden=128, out_channels=1,
                 conv_dropout=0.2, head_dropout=0.35):
        super().__init__()
        # Creation order is kept fixed: it decides how the seeded global RNG is consumed
        # by parameter initialisation, which the training loop relies on for reproducibility.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, head_hidden)
        self.fc3 = nn.Linear(head_hidden, out_channels)
        self.dropout1 = nn.Dropout(p=conv_dropout)
        self.dropout2 = nn.Dropout(p=head_dropout)

    def forward(self, data):
        """Predict one value vector per graph in `data` (a PyG Data/Batch with x, edge_index, batch).

        Returns:
            Tensor of shape (num_graphs, out_channels).
        """
        x, edge_index = data.x, data.edge_index
        # Dropout is applied only after the first convolution.
        x = self.dropout1(F.relu(self.conv1(x, edge_index)))
        for conv in (self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index))
        # Collapse node embeddings into one vector per graph.
        x = global_max_pool(x, data.batch)
        x = self.dropout2(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [3]:
def custom_collate(batch):
    """Merge a list of (graph, target) pairs into one PyG Batch plus a stacked target tensor.

    Args:
        batch: sequence of 2-tuples `(graph, target)` where `target` is a tensor.

    Returns:
        `(Batch.from_data_list(graphs), torch.stack(targets))`.
    """
    # Transpose [(g0, y0), (g1, y1), ...] into (g0, g1, ...) and (y0, y1, ...).
    graphs, targets = zip(*batch)
    return Batch.from_data_list(graphs), torch.stack(targets)

In [4]:
def _summarize_runs(group, metric):
    """Summarize one metric across repeated runs for a single groupby group.

    Args:
        group: DataFrame slice for one group; must contain a `metric` column whose rows are
            in run order.
        metric: name of the column to summarize (e.g. 'r2_test', 'rmse_test').

    Returns:
        pd.Series with one `run{i}` entry per row (i = position within the group), followed by
        `{metric}_mean`, `{metric}_max`, `{metric}_min` and `{metric}_std` (population std, ddof=0).
    """
    values = group[metric]
    per_run = {f'run{i}': value for i, value in enumerate(values)}
    return pd.Series({
        **per_run,
        f'{metric}_mean': np.mean(values),
        f'{metric}_max': np.max(values),
        f'{metric}_min': np.min(values),
        f'{metric}_std': np.std(values, ddof=0),
    })


def calculate_statistics(group):
    """Per-run values plus mean/max/min/std of the `r2_test` column for one group."""
    return _summarize_runs(group, 'r2_test')


def calculate_statistics2(group):
    """Per-run values plus mean/max/min/std of the `rmse_test` column for one group."""
    return _summarize_runs(group, 'rmse_test')

In [5]:
# Fine-tune each pretrained source model on the real-data target over N_SPLITS random
# train/test splits and record test-set R^2 and RMSE on the original target scale.
epochs = 160
lr = 3e-3
wd = 1e-3

N_SPLITS = 10      # random_state values 0..N_SPLITS-1; run{i} in the summaries == random_state i
TEST_SIZE = 0.5
SOURCE_MODELS = ["abcgg", "aatsc3d", "atsc3d", "kappa2", "peoevsa6", "bertzct", "ggi10", "vsaestate3",
                 "atsc4i", "bcutp1l", "kappa3", "estatevsa3", "kier3", "aats8p", "kier2", "frnh0"]
TARGETS = ["Yield_CO_cl"]
device = torch.device('cpu')


def _to_pyg_graph(graph_data):
    """Convert one DeepChem GraphData (node_features / edge_index / edge_features) into a PyG Data."""
    return Data(
        x=torch.tensor(graph_data.node_features, dtype=torch.float32),
        edge_index=torch.tensor(graph_data.edge_index, dtype=torch.long),
        edge_attr=torch.tensor(graph_data.edge_features, dtype=torch.float32),
    )


def _full_batch_loader(graphs, targets):
    """Loader that yields all (graph, target) pairs as a single batch collated by `custom_collate`.

    torch's DataLoader is used on purpose: torch_geometric.loader.DataLoader discards any
    `collate_fn` argument and applies its own collater, so `custom_collate` was never used.
    """
    pairs = list(zip(graphs, targets))
    return torch.utils.data.DataLoader(pairs, batch_size=len(pairs), collate_fn=custom_collate)


def _predict(model, loader):
    """Return the model's predictions over every batch of `loader`, concatenated, without gradients."""
    outputs = []
    with torch.no_grad():
        for batch_graphs, _ in loader:
            outputs.append(model(batch_graphs.to(device)).cpu().numpy())
    return np.concatenate(outputs)


# Loop-invariant work, done once instead of once per (split, source, target):
# read the CSV, featurize the SMILES and convert them to PyG graphs.
raw_df = pd.read_csv('data_Real/data_real.csv')
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
all_graphs = [_to_pyg_graph(g) for g in featurizer.featurize(raw_df["SMILES"])]

results_r2 = []
results_rmse = []
for random_state in range(N_SPLITS):
    for dataset in SOURCE_MODELS:
        for t in TARGETS:
            # Same permutation as splitting the featurized array: it depends only on
            # the sample count and random_state.
            data_train_list, data_test_list, y_train, y_test = train_test_split(
                all_graphs, raw_df[t], test_size=TEST_SIZE, random_state=random_state)

            # Standardize targets using train-split statistics only.
            scaler = StandardScaler()
            target_train = torch.tensor(
                scaler.fit_transform(y_train.to_numpy().reshape(-1, 1)).ravel(), dtype=torch.float32)
            target_test = torch.tensor(
                scaler.transform(y_test.to_numpy().reshape(-1, 1)).ravel(), dtype=torch.float32)

            train_loader = _full_batch_loader(data_train_list, target_train)
            test_loader = _full_batch_loader(data_test_list, target_test)

            # Re-seed right before building the model so every run gets the same fresh fc3
            # initialisation and the same dropout/loader RNG stream.
            torch.manual_seed(0)
            model = Net()
            # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
            model.load_state_dict(torch.load(f'data_AI2+Human/model_{dataset}_sc.pth', map_location=device))
            model.fc3 = nn.Linear(128, 1)  # new, randomly initialised regression head
            for conv in (model.conv1, model.conv2, model.conv3, model.conv4):
                conv.requires_grad_(False)  # keep the pretrained GCN layers fixed
            model.to(device)
            model.train()

            # Only the trainable head is optimized; frozen layers never receive gradients.
            optimizer = torch.optim.Adam(
                [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=wd)
            criterion = nn.MSELoss()

            for epoch in range(epochs):
                for batch_graphs, batch_targets in train_loader:
                    batch_graphs = batch_graphs.to(device)
                    batch_targets = batch_targets.to(device)
                    optimizer.zero_grad()
                    loss = criterion(model(batch_graphs), batch_targets.view(-1, 1))
                    loss.backward()
                    optimizer.step()

            model.eval()
            # Map predictions back to original units and score them against the untouched
            # float64 targets (no float32 scale/unscale round trip).
            pred_test = scaler.inverse_transform(_predict(model, test_loader)).ravel()
            y_test_true = y_test.to_numpy()

            r2_test_score = metrics.r2_score(y_test_true, pred_test)
            rmse_test_score = metrics.root_mean_squared_error(y_test_true, pred_test)
            results_r2.append({'source': dataset, 'target': t, 'r2_test': r2_test_score})
            results_rmse.append({'source': dataset, 'target': t, 'rmse_test': rmse_test_score})

# Select only the metric column so apply() never receives the grouping columns
# (pandas deprecates passing them to DataFrameGroupBy.apply).
results_df = pd.DataFrame(results_r2)
gen_results = results_df.groupby(['source', 'target'])[['r2_test']].apply(calculate_statistics).reset_index()
results_df2 = pd.DataFrame(results_rmse)
gen_results2 = results_df2.groupby(['source', 'target'])[['rmse_test']].apply(calculate_statistics2).reset_index()

In [6]:
# Save the per-source R^2 summary (one column per source model) and display it.
# Create the output directory first so a fresh checkout doesn't fail with OSError in to_csv.
Path('result').mkdir(parents=True, exist_ok=True)
gen_results.T.to_csv('result/result_yield_cl_r2.csv', header=False)
gen_results.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl
run0,0.559866,0.605406,0.710951,0.639531,0.655074,0.50819,0.737261,0.309988,0.571664,0.643385,0.585011,0.684141,0.6557,0.601346,0.671352,0.554304
run1,0.631837,0.426544,0.753408,0.415276,0.583674,0.164326,0.815323,0.618053,0.628998,0.587676,0.711472,0.643812,0.653349,0.742516,0.432153,0.608591
run2,0.481096,0.671396,0.511675,0.684156,0.60804,0.450778,0.758835,0.381492,0.505854,0.716714,0.516196,0.406839,0.615071,0.565192,0.649232,0.394325
run3,0.77281,0.726887,0.773949,0.759878,0.762233,0.214235,0.790991,0.694914,0.739276,0.771787,0.742883,0.716094,0.748778,0.742811,0.691489,0.629653
run4,0.770966,0.699658,0.787069,0.719447,0.801396,0.44983,0.84968,0.705892,0.801838,0.773238,0.742108,0.750448,0.756889,0.781576,0.669639,0.631404
run5,0.707714,0.609404,0.764032,0.641244,0.61808,0.573468,0.77615,0.688842,0.795691,0.692217,0.725713,0.668251,0.691841,0.688697,0.649799,0.636918
run6,0.753715,0.66862,0.797739,0.655545,0.668422,0.375403,0.796594,0.72567,0.698338,0.646309,0.732124,0.66819,0.717371,0.734273,0.546415,0.62905
run7,0.685391,0.675824,0.681671,0.730545,0.704322,0.471839,0.751967,0.605822,0.678748,0.648736,0.756344,0.668401,0.684427,0.674936,0.695526,0.658939


In [7]:
# Save the per-source RMSE summary (one column per source model) and display it.
# Create the output directory first so a fresh checkout doesn't fail with OSError in to_csv.
Path('result').mkdir(parents=True, exist_ok=True)
gen_results2.T.to_csv('result/result_yield_cl_rmse.csv', header=False)
gen_results2.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl
run0,12.853059,12.169957,10.415968,11.63183,11.378282,13.586657,9.930623,16.093187,12.679621,11.569482,12.480512,10.888306,11.367949,12.232402,11.106548,12.934011
run1,9.8664,12.313708,8.074739,12.434102,10.491935,14.864743,6.987881,10.049413,9.904377,10.441391,8.734399,9.704618,9.573812,8.251143,12.253345,10.173122
run2,14.312323,11.389467,13.884217,11.166142,12.439064,14.724504,9.757169,15.625705,13.966727,10.574968,13.819798,15.302167,12.326992,13.101329,11.767295,15.462751
run3,9.219845,10.108816,9.196695,9.478622,9.432014,17.146486,8.84323,10.684159,9.876863,9.240568,9.808307,10.306608,9.695229,9.809684,10.743964,11.77154
run4,9.400146,10.76447,9.063674,10.403811,8.753427,14.569118,7.615396,10.652172,8.743696,9.353408,9.974789,9.81218,9.684719,9.179837,11.289616,11.925039
run5,9.408072,10.875773,8.453227,10.423079,10.75431,11.365073,8.233316,9.707037,7.865754,9.654251,9.11378,10.023073,9.66014,9.709307,10.298048,10.485732
run6,8.915346,10.341472,8.079332,10.543529,10.344559,14.19774,8.102166,9.409277,9.86688,10.683945,9.297924,10.348178,9.550533,9.26056,12.098982,10.9415
run7,10.674257,10.835351,10.737185,9.878607,10.348134,13.830423,9.477791,11.948099,10.786369,11.278956,9.39379,10.958699,10.690608,10.850183,10.500921,11.113945
