In [1]:
import warnings
import tensorflow as tf
warnings.filterwarnings("ignore")
tf.get_logger().setLevel('ERROR')

import numpy as np
import pandas as pd
import deepchem as dc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader
from sklearn import metrics




Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
class Net(nn.Module):
    """Four GCN layers, global max pooling, then a three-layer MLP head.

    ``forward`` takes a PyG batch exposing ``x`` (30 input features per node,
    as fixed by ``conv1``), ``edge_index`` and ``batch``, and returns one
    scalar per graph, shape (num_graphs, 1). Dropout (p=0.5) is applied only
    after the first convolution.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys of the saved checkpoints, and
        # the creation order decides how the global RNG is consumed during
        # initialisation, so neither should change.
        self.conv1 = GCNConv(30, 256)
        self.conv2 = GCNConv(256, 256)
        self.conv3 = GCNConv(256, 256)
        self.conv4 = GCNConv(256, 256)
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.dropout(F.relu(self.conv1(x, edge_index)))
        for conv in (self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x, edge_index))
        pooled = global_max_pool(x, data.batch)
        hidden = F.relu(self.fc2(F.relu(self.fc1(pooled))))
        return self.fc3(hidden)

In [3]:
def custom_collate(batch):
    """Collate a sequence of ``(graph, target)`` pairs into one batch.

    Returns a tuple ``(Batch, Tensor)``: the graphs merged with
    ``Batch.from_data_list`` and the per-sample targets combined with
    ``torch.stack``.
    """
    graphs, targets = zip(*batch)
    return Batch.from_data_list(graphs), torch.stack(targets)

In [4]:
def calculate_statistics(group, metric='r2_test'):
    """Summarise one (source, target) group of per-run scores as a flat Series.

    Args:
        group: DataFrame for one group; ``group[metric]`` holds one score per
            row, and rows are labelled ``run0``, ``run1``, ... in order.
        metric: name of the score column to summarise. Defaults to
            ``'r2_test'``, which is what ``groupby(...).apply`` uses.

    Returns:
        pd.Series with the individual scores (``run0`` .. ``run{k-1}``),
        followed by ``{metric}_mean``, ``{metric}_max``, ``{metric}_min``
        and ``{metric}_std`` (population std, ddof=0).
    """
    scores = group[metric]
    per_run = {f'run{i}': score for i, score in enumerate(scores)}
    return pd.Series({
        **per_run,
        f'{metric}_mean': np.mean(scores),
        f'{metric}_max': np.max(scores),
        f'{metric}_min': np.min(scores),
        f'{metric}_std': np.std(scores, ddof=0),
    })


def calculate_statistics2(group):
    """Same summary as ``calculate_statistics``, for the ``'rmse_test'`` column."""
    return calculate_statistics(group, metric='rmse_test')

In [5]:
# Fine-tune pretrained GCNs on the real "Yield_CO_s" data.
# For every split seed and every pretrained source model: load the source
# weights, freeze the four GCN layers, replace the output layer, train the MLP
# head full-batch, and record test-set R² and RMSE on the original target scale.
epochs = 100
lr = 1e-2
wd = 4e-4
SEED = 0
N_SPLITS = 10
SOURCE_MODELS = ["abcgg", "aatsc3d", "atsc3d", "kappa2", "peoevsa6", "bertzct", "ggi10", "vsaestate3",
                 "atsc4i", "bcutp1l", "kappa3", "estatevsa3", "kier3", "aats8p", "kier2", "frnh0"]
TARGETS = ["Yield_CO_s"]
device = torch.device('cpu')


def _to_pyg_data(graph_data, row):
    """Convert one featurized molecule into a torch_geometric ``Data`` object.

    Args:
        graph_data: one element of ``MolGraphConvFeaturizer.featurize`` output,
            expected to expose ``node_features``, ``edge_index`` and ``edge_features``.
        row: position of the molecule in the input CSV, used in the error message.

    Raises:
        ValueError: if the element is not a featurized graph (presumably the
            featurizer could not process that SMILES; check deepchem's log).
    """
    if not hasattr(graph_data, 'node_features'):
        raise ValueError(f"row {row}: expected a featurized graph, got {type(graph_data).__name__}; "
                         "featurization may have failed for this SMILES")
    return Data(
        x=torch.tensor(graph_data.node_features, dtype=torch.float32),
        edge_index=torch.tensor(graph_data.edge_index, dtype=torch.long),
        edge_attr=torch.tensor(graph_data.edge_features, dtype=torch.float32),
    )


def _predict(model, loader):
    """Put ``model`` in eval mode and return its (scaled) predictions over ``loader`` as an (n, 1) array."""
    model.eval()
    outputs = []
    with torch.no_grad():
        for batch, _ in loader:
            outputs.append(model(batch.to(device)).cpu().numpy())
    return np.concatenate(outputs)


# Loop-invariant: read and featurize the data once instead of once per run.
df = pd.read_csv('data_Real/data_real.csv')
smiles = df["SMILES"]
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
X = featurizer.featurize(smiles)
graphs = [_to_pyg_data(graph_data, row) for row, graph_data in enumerate(X)]

results_r2 = []
results_rmse = []
for random_state in range(N_SPLITS):
    for source in SOURCE_MODELS:
        for target_col in TARGETS:
            y = df[target_col]
            # The split permutation depends only on len(graphs) and random_state.
            graphs_train, graphs_test, y_train, y_test = train_test_split(
                graphs, y, test_size=0.5, random_state=random_state)

            # Standardize the target with statistics from the training split only.
            scaler = StandardScaler()
            y_train_scaled = torch.tensor(
                scaler.fit_transform(y_train.values.reshape(-1, 1)).flatten(), dtype=torch.float32)
            y_test_scaled = torch.tensor(
                scaler.transform(y_test.values.reshape(-1, 1)).flatten(), dtype=torch.float32)

            # Full-batch loaders. torch_geometric's DataLoader batches (Data, target)
            # pairs with its own Collater and discards a user-supplied collate_fn,
            # so none is passed.
            train_loader = DataLoader(list(zip(graphs_train, y_train_scaled)), batch_size=len(graphs_train))
            test_loader = DataLoader(list(zip(graphs_test, y_test_scaled)), batch_size=len(graphs_test))

            # Seeding right before construction fixes the new head's initialisation
            # and the dropout masks drawn during fine-tuning.
            torch.manual_seed(SEED)
            model = Net()
            model.load_state_dict(torch.load(f'data_Random/model_{source}_sc.pth', map_location=device))
            model.fc3 = nn.Linear(128, 1)  # fresh output layer for the new target
            model.to(device)

            # Freeze the GCN feature extractor before building the optimizer, so
            # Adam only ever sees the trainable MLP-head parameters.
            for conv in (model.conv1, model.conv2, model.conv3, model.conv4):
                for param in conv.parameters():
                    param.requires_grad = False
            optimizer = torch.optim.Adam(
                [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=wd)
            criterion = nn.MSELoss()

            model.train()
            for _ in range(epochs):
                for batch, target in train_loader:
                    batch = batch.to(device)
                    target = target.to(device)
                    optimizer.zero_grad()
                    loss = criterion(model(batch), target.view(-1, 1))
                    loss.backward()
                    optimizer.step()

            # Score on the original target scale against the untouched test values
            # (not values round-tripped through float32 scaling).
            pred_test = scaler.inverse_transform(_predict(model, test_loader)).flatten()
            y_test_true = y_test.to_numpy()
            r2_test_score = metrics.r2_score(y_test_true, pred_test)
            rmse_test_score = metrics.root_mean_squared_error(y_test_true, pred_test)
            results_r2.append({'source': source, 'target': target_col, 'r2_test': r2_test_score})
            results_rmse.append({'source': source, 'target': target_col, 'rmse_test': rmse_test_score})

# Select the metric column before apply so the grouping keys are not handed to
# the summary function (deprecated pandas behaviour).
results_df = pd.DataFrame(results_r2)
gen_results = results_df.groupby(['source', 'target'])[['r2_test']].apply(calculate_statistics).reset_index()
results_df2 = pd.DataFrame(results_rmse)
gen_results2 = results_df2.groupby(['source', 'target'])[['rmse_test']].apply(calculate_statistics2).reset_index()

In [6]:
# One column per source model; header=False drops the integer column labels.
r2_table = gen_results.T
r2_table.to_csv('result/result_yield_s_r2.csv', header=False)
r2_table

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s
run0,0.737883,0.824573,0.766486,0.65414,0.648354,0.666878,0.812871,0.708539,0.780604,0.742882,0.790022,0.749074,0.784251,0.745903,0.672942,0.656078
run1,0.644403,0.42253,0.677805,0.636759,0.570834,0.380538,0.752107,0.519363,0.661028,0.495514,0.635522,0.730394,0.66172,0.591124,0.577949,0.183978
run2,0.445168,0.697523,0.630127,0.575098,0.650616,0.544755,0.722849,0.604855,0.474561,0.51054,0.61503,0.503681,0.670035,0.605824,0.47144,0.509625
run3,0.874689,0.842926,0.906072,0.883518,0.784121,0.445768,0.916402,0.749341,0.906964,0.901188,0.909007,0.888024,0.896391,0.892204,0.877925,0.837692
run4,0.56524,0.532852,0.70516,0.633393,0.586576,0.380654,0.730177,0.585977,0.715552,0.472933,0.687633,0.693222,0.729277,0.745968,0.627897,0.515997
run5,0.718633,0.721736,0.809933,0.741582,0.664792,0.690733,0.825206,0.718724,0.778386,0.761045,0.755107,0.770087,0.787244,0.797806,0.674449,0.668439
run6,0.739372,0.648539,0.727015,0.640586,0.692964,0.417343,0.77381,0.71522,0.703401,0.696358,0.780072,0.835987,0.766005,0.800934,0.705403,0.61859
run7,0.808368,0.675549,0.773156,0.736143,0.725398,0.686231,0.806671,0.73755,0.746487,0.754852,0.7346,0.798533,0.778601,0.80472,0.729948,0.656917


In [7]:
# One column per source model; header=False drops the integer column labels.
rmse_table = gen_results2.T
rmse_table.to_csv('result/result_yield_s_rmse.csv', header=False)
rmse_table

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s,Yield_CO_s
run0,15.805715,12.930481,14.918426,18.155878,18.307098,17.818388,13.354806,16.666983,14.460419,15.654291,14.146656,15.464634,14.339745,15.562039,17.65547,18.104918
run1,11.805571,15.044322,11.237439,11.931779,12.969425,15.581717,9.856895,13.725124,11.526301,14.061518,11.952082,10.279514,11.514524,12.659127,12.861456,17.883759
run2,24.525801,18.108761,20.024841,21.462845,19.462299,22.21596,17.334078,20.697651,23.867332,23.035694,20.429434,23.196527,18.913696,20.672243,23.938095,23.057203
run3,10.415208,11.66072,9.017159,10.041602,13.670305,21.903774,8.506882,14.730402,8.974275,9.248629,8.875171,9.845469,9.470477,9.659947,10.279847,11.853412
run4,21.18898,21.964054,17.449308,19.457445,20.662504,25.290188,16.692625,20.677465,17.139063,23.330189,17.960478,17.799082,16.720442,16.196817,19.602741,22.356787
run5,14.852406,14.770286,12.207129,14.233829,16.211283,15.57139,11.706421,14.850018,13.181308,13.687331,13.856328,13.425875,12.915214,12.590552,15.976061,16.122864
run6,14.382274,16.701502,14.719275,16.889418,15.610317,21.504208,13.39844,15.033917,15.342709,15.523799,13.211658,11.409214,13.627624,12.569417,15.290827,17.398552
run7,12.856173,16.728304,13.987518,15.085563,15.389658,16.450613,12.912967,15.045295,14.786914,14.540909,15.129618,13.181927,13.818642,12.977942,15.261646,17.201935
