In [1]:
import os
import warnings
import tensorflow as tf
warnings.filterwarnings("ignore")
tf.get_logger().setLevel('ERROR')

import numpy as np
import pandas as pd
import deepchem as dc
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool
from torch_geometric.data import Batch, Data
from torch_geometric.loader import DataLoader




Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
class Net(nn.Module):
    """Graph regressor: four GCN layers, global max pooling, three-layer MLP head.

    The first convolution expects 30 features per node; all hidden graph
    layers are 256 wide and the head maps 256 -> 256 -> 128 -> 1.
    """

    def __init__(self):
        super().__init__()
        node_dim, width, head_dim = 30, 256, 128

        # Attribute names and creation order are kept stable: they define the
        # state_dict keys used by saved checkpoints and the order in which
        # random initial weights are drawn.
        self.conv1 = GCNConv(node_dim, width)
        self.conv2 = GCNConv(width, width)
        self.conv3 = GCNConv(width, width)
        self.conv4 = GCNConv(width, width)

        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, head_dim)
        self.fc3 = nn.Linear(head_dim, 1)

        self.dropout1 = nn.Dropout(p=0.15)
        self.dropout2 = nn.Dropout(p=0.3)

    def forward(self, data):
        """Return the output of the final linear layer for a batched graph object.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``.
        """
        edges = data.edge_index

        # Dropout is applied only after the first convolution.
        hidden = self.dropout1(F.relu(self.conv1(data.x, edges)))
        for conv in (self.conv2, self.conv3, self.conv4):
            hidden = F.relu(conv(hidden, edges))

        # Collapse node embeddings to one vector per graph.
        pooled = global_max_pool(hidden, data.batch)

        pooled = self.dropout2(F.relu(self.fc1(pooled)))
        pooled = F.relu(self.fc2(pooled))
        return self.fc3(pooled)

In [3]:
def custom_collate(batch):
    """Collate ``(graph, target)`` pairs into a batched graph and a target tensor.

    Returns a ``(Batch, Tensor)`` tuple: the graphs merged with
    ``Batch.from_data_list`` and the targets stacked along a new first axis.
    """
    # Transpose the list of pairs into one column of graphs and one of targets.
    graphs, targets = (list(column) for column in zip(*batch))
    return Batch.from_data_list(graphs), torch.stack(targets)

In [4]:
def _summarize_metric(group, column):
    """Summarise one metric column of a grouped results frame.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows belonging to a single group; must contain ``column``.
    column : str
        Name of the metric column to summarise.

    Returns
    -------
    pandas.Series
        One ``run{i}`` entry per row (in row order) followed by
        ``{column}_mean``, ``{column}_max``, ``{column}_min`` and
        ``{column}_std`` (population standard deviation, ``ddof=0``).
    """
    values = group[column]
    per_run = {f'run{i}': value for i, value in enumerate(values)}
    return pd.Series({
        **per_run,
        f'{column}_mean': np.mean(values),
        f'{column}_max': np.max(values),
        f'{column}_min': np.min(values),
        f'{column}_std': np.std(values, ddof=0),
    })


def calculate_statistics(group):
    """Per-run values and mean/max/min/std of the ``r2_test`` column of ``group``."""
    return _summarize_metric(group, 'r2_test')


def calculate_statistics2(group):
    """Per-run values and mean/max/min/std of the ``rmse_test`` column of ``group``."""
    return _summarize_metric(group, 'rmse_test')

In [5]:
# Fine-tune each pretrained source model on the real-data target and collect
# test-set R2 / RMSE over ten different 50/50 train/test splits.

torch.manual_seed(0)

first_epochs = 40      # forward-only passes run before fc1/fc2/fc3 are unfrozen (see NOTE below)
second_epochs = 180    # epochs of gradient updates on fc1/fc2/fc3
second_lr = 3e-3       # Adam learning rate for the second phase
second_wd = 7e-4       # Adam weight decay for the second phase


def _to_pyg_graphs(graphs):
    """Convert featurized molecule graphs into a list of torch_geometric ``Data`` objects."""
    return [
        Data(
            x=torch.tensor(graph.node_features, dtype=torch.float32),
            edge_index=torch.tensor(graph.edge_index, dtype=torch.long),
            edge_attr=torch.tensor(graph.edge_features, dtype=torch.float32),
        )
        for graph in graphs
    ]


def _predict(model, loader, device):
    """Run ``model`` over every batch of ``loader`` without gradients; return stacked numpy outputs."""
    outputs = []
    for data, _ in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
        outputs.append(out.cpu().numpy())
    return np.concatenate(outputs)


# Loop-invariant inputs: the CSV and the SMILES featurization do not depend on
# the split seed, the source model or the target column, so they are computed
# once instead of once per (seed, source) pair.
# Assumes the featurizer is deterministic and does not touch torch's RNG — TODO confirm.
df = pd.read_csv('data_Real/data_real.csv')
smiles = df["SMILES"]
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
X = featurizer.featurize(smiles)

device = torch.device('cpu')
criterion = nn.MSELoss()

results_r2 = []
results_rmse = []
for random_state in range(10):
    for dataset in ["abcgg", "aatsc3d", "atsc3d", "kappa2", "peoevsa6", "bertzct", "ggi10", "vsaestate3",
                    "atsc4i", "bcutp1l", "kappa3", "estatevsa3", "kier3", "aats8p", "kier2", "frnh0"]:
        for t in ["Yield_CO_cl"]:
            # Reseed before every run so each (split, source model) pair starts
            # from the same torch RNG state.
            torch.manual_seed(0)
            scaler = StandardScaler()

            y = df[t]
            data_train, data_test, target_train, target_test = train_test_split(
                X, y, test_size=0.5, random_state=random_state)

            # Standardize targets using statistics of the training half only.
            target_train = scaler.fit_transform(target_train.values.reshape(-1, 1)).flatten()
            target_test = scaler.transform(target_test.values.reshape(-1, 1)).flatten()

            target_train = torch.tensor(target_train, dtype=torch.float32)
            target_test = torch.tensor(target_test, dtype=torch.float32)

            data_train_list = _to_pyg_graphs(data_train)
            data_test_list = _to_pyg_graphs(data_test)

            # Full-batch loaders: each epoch is a single optimizer step.
            # NOTE(review): torch_geometric's DataLoader may replace a
            # caller-supplied collate_fn with its own collater — verify that
            # custom_collate is actually used.
            train_loader = DataLoader(list(zip(data_train_list, target_train)),
                                      batch_size=len(data_train_list), collate_fn=custom_collate)
            test_loader = DataLoader(list(zip(data_test_list, target_test)),
                                     batch_size=len(data_test_list), collate_fn=custom_collate)

            model = Net()
            # map_location keeps checkpoint loading working regardless of the
            # device the weights were saved from.
            model.load_state_dict(torch.load(f'data_AI2+Human/model_{dataset}_sc.pth', map_location=device))
            # Replace the output layer with a freshly initialized one.
            model.fc3 = nn.Linear(128, 1)

            for param in model.parameters():
                param.requires_grad = False

            model.train()
            model.to(device)

            # NOTE(review): this "first phase" never updates any weight: every
            # parameter (including the fresh fc3) is frozen above, and the loop
            # only runs forward passes under no_grad with no backward()/step().
            # It is kept unchanged because the passes run in train mode
            # (dropout active) and therefore presumably advance the global RNG
            # that the second phase depends on; removing it, or turning it into
            # a real head warm-up, would change the reported numbers.
            # Confirm the intent.
            for epoch in range(first_epochs):
                for data, target in train_loader:
                    data = data.to(device)
                    target = target.to(device)
                    with torch.no_grad():
                        out = model(data)
                        loss = criterion(out, target.view(-1, 1))

            # Second phase: unfreeze only the MLP head; the GCN layers stay frozen.
            for layer in (model.fc1, model.fc2, model.fc3):
                for param in layer.parameters():
                    param.requires_grad = True

            optimizer = torch.optim.Adam(model.parameters(), lr=second_lr, weight_decay=second_wd)

            for epoch in range(second_epochs):
                for data, target in train_loader:
                    data = data.to(device)
                    target = target.to(device)
                    optimizer.zero_grad()
                    out = model(data)
                    loss = criterion(out, target.view(-1, 1))
                    loss.backward()
                    optimizer.step()

            model.eval()
            pred_train = _predict(model, train_loader, device)
            pred_test = _predict(model, test_loader, device)

            # Map predictions and targets back to the original target scale
            # before scoring.
            pred_train = scaler.inverse_transform(pred_train)
            pred_test = scaler.inverse_transform(pred_test)
            target_train = scaler.inverse_transform(target_train.numpy().reshape(-1, 1)).flatten()
            target_test = scaler.inverse_transform(target_test.numpy().reshape(-1, 1)).flatten()

            r2_test_score = metrics.r2_score(target_test, pred_test)
            rmse_test_score = metrics.root_mean_squared_error(target_test, pred_test)
            results_r2.append({'source': dataset, 'target': t, 'r2_test': r2_test_score})
            results_rmse.append({'source': dataset, 'target': t, 'rmse_test': rmse_test_score})

# One row per (source, target) pair with per-run scores and summary statistics.
results_df = pd.DataFrame(results_r2)
gen_results = results_df.groupby(['source', 'target']).apply(calculate_statistics).reset_index()
results_df2 = pd.DataFrame(results_rmse)
gen_results2 = results_df2.groupby(['source', 'target']).apply(calculate_statistics2).reset_index()

In [6]:
# Write the transposed R2 summary (one column per source model) to disk and
# display it. The output directory is created first so the cell also works on
# a fresh checkout where 'result/' does not exist yet.
os.makedirs('result', exist_ok=True)
gen_results.T.to_csv('result/result_yield_cl_r2.csv', header=False)
gen_results.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl
run0,0.545778,0.587173,0.643066,0.590701,0.646793,0.529709,0.727486,0.165328,0.461808,0.647928,0.570583,0.671767,0.683162,0.571123,0.662888,0.571227
run1,0.666128,0.427976,0.74132,0.453364,0.578799,-0.518101,0.815835,0.601591,0.597677,0.577772,0.690986,0.586818,0.646759,0.686821,0.435254,0.560342
run2,0.431951,0.648165,0.479322,0.642427,0.643262,0.525887,0.75453,0.36333,0.496568,0.71599,0.525275,0.403748,0.604109,0.609642,0.615333,0.365557
run3,0.768907,0.749634,0.785444,0.765579,0.767375,0.23865,0.803814,0.687035,0.755501,0.777783,0.742359,0.699955,0.756686,0.74993,0.697486,0.642479
run4,0.774995,0.719743,0.785202,0.718991,0.77695,0.471256,0.858685,0.710648,0.791395,0.76879,0.745208,0.732506,0.760757,0.781043,0.651939,0.671009
run5,0.684849,0.603645,0.752261,0.663834,0.608482,0.572966,0.788501,0.70139,0.786117,0.727192,0.725402,0.693713,0.655347,0.689483,0.635461,0.624526
run6,0.727219,0.670547,0.776064,0.64086,0.660463,0.306034,0.810087,0.716174,0.695958,0.629876,0.744266,0.68284,0.702438,0.734632,0.586375,0.618205
run7,0.630865,0.69939,0.737124,0.715236,0.705701,0.499035,0.786214,0.63182,0.728797,0.649046,0.746023,0.683091,0.691754,0.701128,0.723365,0.650433


In [7]:
# Write the transposed RMSE summary (one column per source model) to disk and
# display it. The output directory is created first so the cell also works on
# a fresh checkout where 'result/' does not exist yet.
os.makedirs('result', exist_ok=True)
gen_results2.T.to_csv('result/result_yield_cl_rmse.csv', header=False)
gen_results2.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
source,aats8p,aatsc3d,abcgg,atsc3d,atsc4i,bcutp1l,bertzct,estatevsa3,frnh0,ggi10,kappa2,kappa3,kier2,kier3,peoevsa6,vsaestate3
target,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl,Yield_CO_cl
run0,13.057136,12.447946,11.574654,12.394654,11.514057,13.286083,10.113657,17.699938,14.212898,11.495541,12.695612,11.099539,10.905169,12.687628,11.248665,12.686083
run1,9.395702,12.29833,8.270281,12.022316,10.553191,20.035,6.97819,10.263692,10.313981,10.566038,9.039158,10.452252,9.664388,9.099878,12.219844,10.781932
run2,14.974754,11.785179,14.336774,11.880891,11.867021,13.680676,9.843874,15.853459,14.097342,10.588481,13.689515,15.341991,12.501283,12.41361,12.322793,15.825704
run3,9.298702,9.678697,8.959822,9.365408,9.329473,16.878004,8.567673,10.82124,9.564617,9.118375,9.818304,10.595507,9.541407,9.672964,10.639022,11.56591
run4,9.317103,10.398311,9.103315,10.412255,9.276529,14.282608,7.383789,10.565698,8.971138,9.444695,9.914658,10.158791,9.607366,9.191026,11.588098,11.266185
run5,9.769127,10.955661,8.661514,10.089578,10.888599,11.371759,8.002963,9.509294,8.047937,9.089187,9.118947,9.63076,10.216156,9.697038,10.506744,10.663171
run6,9.382663,10.311367,8.501216,10.765929,10.46798,14.965402,7.828818,9.570744,9.905725,10.929313,9.08476,10.117167,9.799589,9.254309,11.553742,11.10029
run7,11.562319,10.434075,9.757265,10.155355,10.323971,13.469642,8.799174,11.54736,9.910595,11.273979,9.590687,10.713213,10.56577,10.403875,10.009343,11.25168
