In [1]:
import warnings
from pathlib import Path
import tensorflow as tf
warnings.filterwarnings("ignore")
tf.get_logger().setLevel('ERROR')

import numpy as np
import pandas as pd
import deepchem as dc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from sklearn import metrics




Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
class Net(nn.Module):
    """Graph-level regressor: four GCN layers, global max pooling, then a 3-layer MLP.

    Node inputs are 30-dimensional (``GCNConv(30, 256)``) and the model emits one
    value per graph in the batch. The attribute names ``conv1``..``conv4`` and
    ``fc1``..``fc3`` are the ``state_dict`` keys that pretrained checkpoints are
    loaded into, so they must stay stable.
    """

    def __init__(self):
        super().__init__()
        # Graph-convolution trunk (construction order is kept stable on purpose:
        # it fixes both parameter-init RNG consumption and parameter order).
        self.conv1 = GCNConv(30, 256)
        self.conv2 = GCNConv(256, 256)
        self.conv3 = GCNConv(256, 256)
        self.conv4 = GCNConv(256, 256)
        # Fully connected regression head: 256 -> 256 -> 128 -> 1.
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        # Light dropout after the first convolution, heavier dropout in the head.
        self.dropout1 = nn.Dropout(p=0.15)
        self.dropout2 = nn.Dropout(p=0.3)

    def forward(self, data):
        """Run a PyG batch through the network.

        Reads ``data.x``, ``data.edge_index`` and ``data.batch``. Returns the
        output of ``fc3``, i.e. one scalar per graph after max-pooling by
        ``data.batch``.
        """
        h = data.x
        convs = (self.conv1, self.conv2, self.conv3, self.conv4)
        for depth, conv in enumerate(convs):
            h = F.relu(conv(h, data.edge_index))
            if depth == 0:
                # Dropout is applied only after the first convolution.
                h = self.dropout1(h)
        pooled = global_max_pool(h, data.batch)
        z = self.dropout2(F.relu(self.fc1(pooled)))
        z = F.relu(self.fc2(z))
        return self.fc3(z)

In [3]:
def custom_collate(batch):
    """Collate ``(graph, target)`` pairs into one PyG ``Batch`` plus stacked targets.

    Parameters
    ----------
    batch : sequence of (Data, torch.Tensor) pairs
        Items as produced by a dataset built with ``zip(graphs, targets)``.

    Returns
    -------
    tuple
        ``(Batch.from_data_list(graphs), torch.stack(targets))``.

    Note: ``torch_geometric.loader.DataLoader`` discards any ``collate_fn`` it is
    given, so this function only takes effect with ``torch.utils.data.DataLoader``.
    """
    graphs, targets = zip(*batch)
    merged_graphs = Batch.from_data_list(graphs)
    stacked_targets = torch.stack(targets)
    return merged_graphs, stacked_targets

In [4]:
def _summarize_runs(group, metric):
    """Spread one metric's per-run values into columns and append summary statistics.

    Parameters
    ----------
    group : pd.DataFrame
        One group from ``groupby(['source', 'target'])``; must contain ``metric``.
    metric : str
        Column to summarize, e.g. ``'r2_test'`` or ``'rmse_test'``.

    Returns
    -------
    pd.Series
        ``run0 .. run{k-1}`` (one entry per row, in row order, independent of the
        group's index labels) followed by ``{metric}_mean``, ``{metric}_max``,
        ``{metric}_min`` and ``{metric}_std``.
    """
    values = group[metric]
    per_run = {f'run{i}': value for i, value in enumerate(values)}
    return pd.Series({
        **per_run,
        f'{metric}_mean': np.mean(values),
        f'{metric}_max': np.max(values),
        f'{metric}_min': np.min(values),
        # Population standard deviation (ddof=0), not the sample std.
        f'{metric}_std': np.std(values, ddof=0),
    })


def calculate_statistics(group):
    """Per-run ``r2_test`` values plus their mean/max/min/std for one group."""
    return _summarize_runs(group, 'r2_test')


def calculate_statistics2(group):
    """Per-run ``rmse_test`` values plus their mean/max/min/std for one group."""
    return _summarize_runs(group, 'rmse_test')

In [5]:
# Fine-tune the head of each pretrained source model on the real data and score it
# on the held-out half, for N_SPLITS different random 50/50 splits.
torch.manual_seed(0)

# Fine-tuning hyperparameters.
epochs = 120
lr = 9e-3
wd = 2e-6

N_SPLITS = 10  # random_state values 0..N_SPLITS-1 passed to train_test_split
SOURCE_DATASETS = ["abcgg", "aatsc3d", "atsc3d", "kappa2", "peoevsa6", "bertzct", "ggi10", "vsaestate3",
                   "atsc4i", "bcutp1l", "kappa3", "estatevsa3", "kier3", "aats8p", "kier2", "frnh0"]
TARGETS = ["Yield_CO_l"]
device = torch.device('cpu')


def _to_pyg_data(graph_data):
    """Convert one DeepChem ``GraphData`` into a ``torch_geometric.data.Data``.

    NOTE(review): edge features are carried as ``edge_attr`` but the GCN layers in
    ``Net`` never read them.
    """
    return Data(
        x=torch.tensor(graph_data.node_features, dtype=torch.float32),
        edge_index=torch.tensor(graph_data.edge_index, dtype=torch.long),
        edge_attr=torch.tensor(graph_data.edge_features, dtype=torch.float32),
    )


# Reading and featurizing do not depend on the split or the source model, so do
# them once instead of once per (split, source, target) run.
df = pd.read_csv('data_Real/data_real.csv')
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
X = featurizer.featurize(df["SMILES"])
bad_rows = [i for i, g in enumerate(X) if not hasattr(g, 'node_features')]
assert not bad_rows, f"featurization failed for SMILES rows {bad_rows}"
pyg_graphs = [_to_pyg_data(g) for g in X]

results_r2 = []
results_rmse = []
for random_state in range(N_SPLITS):
    for dataset in SOURCE_DATASETS:
        for t in TARGETS:
            # Split indices depend only on the sample count and random_state, so
            # splitting the converted graphs matches splitting the raw featurization.
            graphs_train, graphs_test, y_train_raw, y_test_raw = train_test_split(
                pyg_graphs, df[t], test_size=0.5, random_state=random_state)

            # Standardize targets with training-split statistics only.
            scaler = StandardScaler()
            y_train = torch.tensor(
                scaler.fit_transform(y_train_raw.to_numpy().reshape(-1, 1)).flatten(), dtype=torch.float32)
            y_test = torch.tensor(
                scaler.transform(y_test_raw.to_numpy().reshape(-1, 1)).flatten(), dtype=torch.float32)

            # Each split is served as a single full-size batch. torch_geometric's
            # DataLoader silently discards collate_fn, so use the plain torch
            # DataLoader to make custom_collate actually apply (PyG's own Collater
            # yields the same (Batch, stacked targets) pair).
            train_loader = torch.utils.data.DataLoader(
                list(zip(graphs_train, y_train)), batch_size=len(graphs_train), collate_fn=custom_collate)
            test_loader = torch.utils.data.DataLoader(
                list(zip(graphs_test, y_test)), batch_size=len(graphs_test), collate_fn=custom_collate)

            # Reseed right before the stochastic part (weight init, dropout) so
            # every run starts from the same RNG state.
            torch.manual_seed(0)
            model = Net()
            # The file holds a state_dict (it goes straight into load_state_dict), so
            # weights_only avoids full pickle deserialization; map_location lets
            # GPU-saved checkpoints load on this CPU device.
            model.load_state_dict(torch.load(f'data_AI2+Human/model_{dataset}_sc.pth',
                                             map_location=device, weights_only=True))
            # Fresh, randomly initialized output layer; conv trunk is frozen below.
            model.fc3 = nn.Linear(128, 1)
            for conv in (model.conv1, model.conv2, model.conv3, model.conv4):
                conv.requires_grad_(False)
            model.to(device)
            model.train()

            # Only trainable parameters are optimized. Adam would skip the frozen
            # ones anyway (their .grad stays None), so the updates are unchanged.
            optimizer = torch.optim.Adam(
                [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=wd)
            criterion = nn.MSELoss()

            for _ in range(epochs):
                for batch, target in train_loader:
                    batch, target = batch.to(device), target.to(device)
                    optimizer.zero_grad()
                    loss = criterion(model(batch), target.view(-1, 1))
                    loss.backward()
                    optimizer.step()

            # Predict the held-out split (dropout off, no autograd) and undo scaling.
            model.eval()
            with torch.no_grad():
                pred_test = np.concatenate(
                    [model(batch.to(device)).cpu().numpy() for batch, _ in test_loader])
            pred_test = scaler.inverse_transform(pred_test).ravel()
            # Score against the original targets, not a float32 round trip of them.
            y_test_true = y_test_raw.to_numpy()

            r2_test_score = metrics.r2_score(y_test_true, pred_test)
            rmse_test_score = metrics.root_mean_squared_error(y_test_true, pred_test)
            results_r2.append({'source': dataset, 'target': t, 'random_state': random_state,
                               'r2_test': r2_test_score})
            results_rmse.append({'source': dataset, 'target': t, 'random_state': random_state,
                                 'rmse_test': rmse_test_score})

# One row per (source, target): per-split scores run0..run{N_SPLITS-1} plus mean/max/min/std.
results_df = pd.DataFrame(results_r2)
gen_results = results_df.groupby(['source', 'target']).apply(calculate_statistics).reset_index()
results_df2 = pd.DataFrame(results_rmse)
gen_results2 = results_df2.groupby(['source', 'target']).apply(calculate_statistics2).reset_index()

In [6]:
# Persist the R^2 summary (one column per (source, target) group) and display it.
# to_csv does not create missing directories, so ensure result/ exists first.
Path('result').mkdir(parents=True, exist_ok=True)
gen_results.T.to_csv('result/result_yield_l_r2.csv', header=False)
gen_results.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l
run0,0.637287,0.653296,0.845425,0.6764,0.723629,0.435271,0.829135,0.607059,0.672737,0.676938,0.716632,0.787286,0.767725,0.789751,0.763653,0.657409
run1,0.674412,0.211156,0.725282,0.379369,0.513075,0.224272,0.783061,0.651383,0.652684,0.646004,0.635994,0.706674,0.510313,0.69842,0.299889,0.684089
run2,0.630658,0.769046,0.817854,0.75149,0.756216,0.695742,0.795368,0.687462,0.74155,0.752553,0.738114,0.599,0.743496,0.805008,0.730239,0.687785
run3,0.809087,0.723004,0.807528,0.769637,0.851373,0.408388,0.843995,0.69516,0.69516,0.793265,0.844733,0.736683,0.767534,0.805801,0.739135,0.643149
run4,0.823816,0.66068,0.831159,0.637113,0.761848,0.435447,0.829068,0.743818,0.781095,0.78437,0.785027,0.744376,0.775453,0.779521,0.648466,0.715204
run5,0.762294,0.638183,0.76905,0.568056,0.572715,0.463539,0.775568,0.744767,0.761068,0.715966,0.689533,0.716955,0.720758,0.748429,0.67418,0.784018
run6,0.798397,0.672664,0.784342,0.534128,0.520019,0.574857,0.768583,0.762494,0.648743,0.559211,0.774469,0.731419,0.7353,0.782109,0.667897,0.754978
run7,0.748862,0.690003,0.846094,0.72676,0.728326,0.539399,0.851925,0.63755,0.788254,0.731105,0.842541,0.743701,0.760445,0.78591,0.763611,0.706812


In [7]:
# Persist the RMSE summary (one column per (source, target) group) and display it.
# to_csv does not create missing directories, so ensure result/ exists first.
Path('result').mkdir(parents=True, exist_ok=True)
gen_results2.T.to_csv('result/result_yield_l_rmse.csv', header=False)
gen_results2.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l,Yield_CO_l
run0,21.814066,21.327246,14.240495,20.604361,19.041523,27.219187,14.97205,22.704863,20.720667,20.587234,19.281038,16.705242,17.456486,16.608164,17.608824,21.200365
run1,18.780668,29.232929,17.251253,25.929476,22.96719,28.98888,15.330137,19.433512,19.3972,19.582867,19.857809,17.825935,23.032246,18.075003,27.53978,18.499458
run2,23.511118,18.591797,16.510796,19.285471,19.101223,21.339262,17.500322,21.627707,19.667397,19.244181,19.79771,24.498013,19.593227,17.083113,20.093151,21.616518
run3,15.414165,18.5669,15.477005,16.932047,13.600402,27.134447,13.933874,19.477743,19.477747,16.040211,13.900884,18.102663,17.009144,15.546282,18.018175,21.073946
run4,15.58166,21.623941,15.253471,22.362274,18.115808,27.892176,15.347631,18.789051,17.368345,17.237911,17.211618,18.768562,17.59071,17.43066,22.009697,19.810562
run5,16.900629,20.85103,16.658735,22.782261,22.659058,25.389376,16.421974,17.51263,16.94418,18.474321,19.314825,18.442123,18.31781,17.386536,19.786623,16.109858
run6,15.55498,19.820681,16.088066,23.645826,24.001226,22.588591,16.665545,16.883362,20.532125,23.000484,16.452204,17.953909,17.823704,16.171154,19.964472,17.148409
run7,18.491541,20.544502,14.475892,19.288094,19.232716,25.042604,14.199021,22.214708,16.979464,19.134102,14.642024,18.680559,18.060083,17.073202,17.940332,19.979759
