In [1]:
import pandas as pd
import numpy as np
import torch
from tdc.single_pred.adme import ADME
from tdc import Evaluator
from tqdm.notebook import tqdm, trange
from torch.utils.data import TensorDataset, DataLoader
from rdkit import Chem
from rdkit.Chem import AllChem
from matplotlib import pyplot as plt
from IPython import display

from typing import List, Tuple


class Featurizer:
    def __init__(self, y_column, smiles_col='Drug', **kwargs):
        self.y_column = y_column
        self.smiles_col = smiles_col
        self.__dict__.update(kwargs)
    
    def __call__(self, df):
        raise NotImplementedError()

In [2]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise ValueError("input {0} not in allowable set{1}:".format(
            x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))


from rdkit.Chem import rdMolDescriptors

class GraphFeaturizer(Featurizer):
    def __call__(self, df, getRepresentation):
        graphs = []
        labels = []
        for i, row in df.iterrows():
            y = row[self.y_column]
            smiles = row[self.smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            
            edges = []
            for bond in mol.GetBonds():
                begin = bond.GetBeginAtomIdx()
                end = bond.GetEndAtomIdx()
                edges.append((begin, end))  # TODO: Add edges in both directions
            edges = np.array(edges)
            
            nodes = []
            for atom in mol.GetAtoms():
                # print(atom.GetAtomicNum(), atom.GetNumImplicitHs(), atom.GetTotalNumHs(), atom.GetSymbol(), atom.GetNumExplicitHs(), atom.GetTotalValence())
                results = getRepresentation(atom)
                # print(results)
                nodes.append(results)
            nodes = np.array(nodes)
            
            graphs.append((nodes, edges.T))
            labels.append(y)
        labels = np.array(labels)
        return [Data(
            x=torch.FloatTensor(x), 
            edge_index=torch.LongTensor(edge_index), 
            y=torch.FloatTensor([y])
        ) for ((x, edge_index), y) in zip(graphs, labels)]

In [3]:
def defaultRepresentation(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(11)) + one_of_k_encoding(
                    atom.GetDegree(), range(11)
                ) + one_of_k_encoding_unk(
                    atom.GetImplicitValence(), range(11)
                ) + [atom.GetIsAromatic()] + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(11)
                ) + [atom.GetNumImplicitHs(), atom.GetFormalCharge(), atom.GetNumRadicalElectrons(), atom.IsInRing()] # TODO: Add atom features as a list, you can use one_of_k_encodings defined above

def representation1(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)
                ) + [atom.GetFormalCharge(), atom.IsInRing(), atom.GetIsAromatic()]

def representation10(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)
                ) + [atom.IsInRing(), atom.GetIsAromatic()]

def representationAll(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)) + one_of_k_encoding_unk(
                    atom.GetImplicitValence(), range(6))  + one_of_k_encoding_unk(
                    atom.GetHybridization(),
                    [
                        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
                        Chem.rdchem.HybridizationType.SP3D2
                    ]
                ) + [atom.GetFormalCharge(), atom.IsInRing(), atom.GetIsAromatic()
                     ] + [atom.GetNumRadicalElectrons()]

def printProperties(atom):
    print("=========")
    print("GetDegree", atom.GetDegree())
    print("GetImplicitValence", atom.GetImplicitValence())
    print("GetAtomicNum", atom.GetAtomicNum())
    print("GetTotalNumHs", atom.GetTotalNumHs())
    print("GetNumImplicitHs", atom.GetNumImplicitHs())
    print("GetNeighbors", atom.GetNeighbors())
    print("GetNumExplicitHs", atom.GetNumExplicitHs())
    print("GetTotalDegree", atom.GetTotalDegree())
    print("GetTotalNumHs", atom.GetTotalNumHs())
    print("GetTotalValence", atom.GetTotalValence())

In [4]:
class ECFPFeaturizer(Featurizer):
    def __init__(self, y_column, radius=2, length=1024, **kwargs):
        self.radius = radius
        self.length = length
        super().__init__(y_column, **kwargs)
    
    def __call__(self, df):
        fingerprints = []
        labels = []
        for i, row in df.iterrows():
            y = row[self.y_column]
            smiles = row[self.smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, self.radius, nBits=self.length)
            fingerprints.append(fp)
            labels.append(y)
        fingerprints = np.array(fingerprints)
        labels = np.array(labels)
        return fingerprints, labels

dataset = pd.read_csv("./datasets/esol.csv")

from sklearn.utils import shuffle
dataset = shuffle(dataset)
train_dataset = dataset[:500]
val_dataset = dataset[500:800]
test_dataset = dataset[800:1100]

In [5]:
rmse = Evaluator(name = 'RMSE')

featurizer = ECFPFeaturizer(y_column='measured log solubility in mols per litre', smiles_col="smiles")
X_train, y_train = featurizer(train_dataset)
X_valid, y_valid = featurizer(val_dataset)
X_test, y_test = featurizer(test_dataset)

featurizer = GraphFeaturizer(y_column='measured log solubility in mols per litre', smiles_col="smiles")

graph = featurizer(test_dataset.iloc[:1], representationAll)[0]

In [6]:
from torch_geometric.loader import DataLoader as GraphDataLoader


# prepare data loaders
batch_size = 64

train_loader1 = GraphDataLoader(featurizer(train_dataset, representation1), batch_size=batch_size, shuffle=True)
valid_loader1 = GraphDataLoader(featurizer(val_dataset, representation1), batch_size=batch_size)
test_loader1 = GraphDataLoader(featurizer(test_dataset, representation1), batch_size=batch_size)

train_loader10 = GraphDataLoader(featurizer(train_dataset, representation10), batch_size=batch_size, shuffle=True)
valid_loader10 = GraphDataLoader(featurizer(val_dataset, representation10), batch_size=batch_size)
test_loader10 = GraphDataLoader(featurizer(test_dataset, representation10), batch_size=batch_size)

train_loader = GraphDataLoader(featurizer(train_dataset, representationAll), batch_size=batch_size, shuffle=True)
valid_loader = GraphDataLoader(featurizer(val_dataset, representationAll), batch_size=batch_size)
test_loader = GraphDataLoader(featurizer(test_dataset, representationAll), batch_size=batch_size)

In [7]:
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool as gap

In [8]:
#warstwa attention pooling
class MyAttentionModule3(torch.nn.Module): # zakladamy ze atom ma 49 featerow
    def __init__(self, groupFeatures=1):
        super().__init__()
        self.groupFeatures = groupFeatures
        self.gates = torch.nn.ModuleDict({ # do wyliczenia atencji dla kazdej grupy cech - jest ich 9
            'AtomicNum': GCNConv(12, 1),
            'Degree': GCNConv(6, 1),
            'TotalNumHs': GCNConv(5, 1),
            'ImplicitValence': GCNConv(6, 1),
            'Hybridization': GCNConv(5, 1),
            'FormalCharge': GCNConv(1, 1),
            'IsInRing': GCNConv(1, 1),
            'IsAromatic': GCNConv(1, 1),
            'NumRadicalElectrons': GCNConv(1, 1)
        })
        
        self.feats = torch.nn.ModuleDict({ # do transformacji grupy cech w wektor, na razie dziala tylko dla groupFeatures=1
            'AtomicNum': torch.nn.Linear(12, groupFeatures),
            'Degree': torch.nn.Linear(6, groupFeatures),
            'TotalNumHs': torch.nn.Linear(5, groupFeatures),
            'ImplicitValence': torch.nn.Linear(6, groupFeatures),
            'Hybridization': torch.nn.Linear(5, groupFeatures),
            'FormalCharge': torch.nn.Linear(1, groupFeatures),
            'IsInRing': torch.nn.Linear(1, groupFeatures),
            'IsAromatic': torch.nn.Linear(1, groupFeatures),
            'NumRadicalElectrons': torch.nn.Linear(1, groupFeatures)
        })

    def forward(self, x, edge_index, batch):
        gates = []
        gates.append(self.gates['AtomicNum'](x[:,0:12], edge_index))
        gates.append(self.gates['Degree'](x[:,12:18], edge_index))
        gates.append(self.gates['TotalNumHs'](x[:,18:23], edge_index))
        gates.append(self.gates['ImplicitValence'](x[:,23:29], edge_index))
        gates.append(self.gates['Hybridization'](x[:,29:34], edge_index))
        gates.append(self.gates['FormalCharge'](x[:,34:35], edge_index))
        gates.append(self.gates['IsInRing'](x[:,35:36], edge_index))
        gates.append(self.gates['IsAromatic'](x[:,36:37], edge_index))
        gates.append(self.gates['NumRadicalElectrons'](x[:,37:38], edge_index))
        logits = torch.cat(gates, dim=-1)
        attention = torch.softmax(logits, dim=-1).unsqueeze(-1)
        
        subgroups = []
        subgroups.append(self.feats['AtomicNum'](x[:,0:12]) * attention[:,0])
        subgroups.append(self.feats['Degree'](x[:,12:18]) * attention[:,1])
        subgroups.append(self.feats['TotalNumHs'](x[:,18:23]) * attention[:,2])
        subgroups.append(self.feats['ImplicitValence'](x[:,23:29]) * attention[:,3])
        subgroups.append(self.feats['Hybridization'](x[:,29:34]) * attention[:,4])
        subgroups.append(self.feats['FormalCharge'](x[:,34:35]) * attention[:,5])
        subgroups.append(self.feats['IsInRing'](x[:,35:36]) * attention[:,6])
        subgroups.append(self.feats['IsAromatic'](x[:,36:37]) * attention[:,7])
        subgroups.append(self.feats['NumRadicalElectrons'](x[:,37:38]) * attention[:,8])
        x = torch.stack(subgroups, dim=-2)
        x = torch.sum(x, dim=-2)
        
        return x, attention

In [9]:
def attSequential(n_feats):
    return torch.nn.Sequential(torch.nn.Linear(n_feats, 1),
                       torch.nn.BatchNorm1d(1), torch.nn.ReLU(),
                       torch.nn.Linear(1, 1), torch.nn.ReLU())

In [10]:
#warstwa attention pooling
class MyAttentionModule4(torch.nn.Module): # zakladamy ze atom ma 49 featerow
    def __init__(self, groupFeatures=1):
        super().__init__()
        self.groupFeatures = groupFeatures
        self.gates = torch.nn.ModuleDict({ # do wyliczenia atencji dla kazdej grupy cech - jest ich 9
            'AtomicNum': GINConv(attSequential(12), train_eps=True),
            'Degree': GINConv(attSequential(6), train_eps=True),
            'TotalNumHs': GINConv(attSequential(5), train_eps=True),
            'ImplicitValence': GINConv(attSequential(6), train_eps=True),
            'Hybridization': GINConv(attSequential(5), train_eps=True),
            'FormalCharge': GINConv(attSequential(1), train_eps=True),
            'IsInRing': GINConv(attSequential(1), train_eps=True),
            'IsAromatic': GINConv(attSequential(1), train_eps=True),
            'NumRadicalElectrons': GINConv(attSequential(1), train_eps=True)
        })
        
        self.feats = torch.nn.ModuleDict({ # do transformacji grupy cech w wektor, na razie dziala tylko dla groupFeatures=1
            'AtomicNum': torch.nn.Linear(12, groupFeatures),
            'Degree': torch.nn.Linear(6, groupFeatures),
            'TotalNumHs': torch.nn.Linear(5, groupFeatures),
            'ImplicitValence': torch.nn.Linear(6, groupFeatures),
            'Hybridization': torch.nn.Linear(5, groupFeatures),
            'FormalCharge': torch.nn.Linear(1, groupFeatures),
            'IsInRing': torch.nn.Linear(1, groupFeatures),
            'IsAromatic': torch.nn.Linear(1, groupFeatures),
            'NumRadicalElectrons': torch.nn.Linear(1, groupFeatures)
        })

    def forward(self, x, edge_index, batch):
        gates = []
        gates.append(self.gates['AtomicNum'](x[:,0:12], edge_index))
        gates.append(self.gates['Degree'](x[:,12:18], edge_index))
        gates.append(self.gates['TotalNumHs'](x[:,18:23], edge_index))
        gates.append(self.gates['ImplicitValence'](x[:,23:29], edge_index))
        gates.append(self.gates['Hybridization'](x[:,29:34], edge_index))
        gates.append(self.gates['FormalCharge'](x[:,34:35], edge_index))
        gates.append(self.gates['IsInRing'](x[:,35:36], edge_index))
        gates.append(self.gates['IsAromatic'](x[:,36:37], edge_index))
        gates.append(self.gates['NumRadicalElectrons'](x[:,37:38], edge_index))
        logits = torch.cat(gates, dim=-1)
        attention = torch.softmax(logits, dim=-1).unsqueeze(-1)
        
        subgroups = []
        subgroups.append(self.feats['AtomicNum'](x[:,0:12]) * attention[:,0])
        subgroups.append(self.feats['Degree'](x[:,12:18]) * attention[:,1])
        subgroups.append(self.feats['TotalNumHs'](x[:,18:23]) * attention[:,2])
        subgroups.append(self.feats['ImplicitValence'](x[:,23:29]) * attention[:,3])
        subgroups.append(self.feats['Hybridization'](x[:,29:34]) * attention[:,4])
        subgroups.append(self.feats['FormalCharge'](x[:,34:35]) * attention[:,5])
        subgroups.append(self.feats['IsInRing'](x[:,35:36]) * attention[:,6])
        subgroups.append(self.feats['IsAromatic'](x[:,36:37]) * attention[:,7])
        subgroups.append(self.feats['NumRadicalElectrons'](x[:,37:38]) * attention[:,8])
        x = torch.stack(subgroups, dim=-2)
        x = torch.sum(x, dim=-2)
        
        return x, attention

In [11]:
class GraphNeuralNetwork(torch.nn.Module):  # TODO: assign hyperparameters to attributes and define the forward pass
    def __init__(self, hidden_size, n_convs=3, my_layer=None, features_after_layer=26, n_features=49, dropout=0.2):
        super().__init__()
        self.myAttentionModule = my_layer
        self.dropout = dropout

        convs = torch.nn.ModuleList()
        convs.append(GCNConv(features_after_layer, hidden_size))
        for i in range(1, n_convs):
            convs.append(GCNConv(hidden_size, hidden_size))
        self.convs = convs
        self.linear = torch.nn.Linear(hidden_size, 1)
    
    def forward(self, x, edge_index, batch):
        att = None
        if self.myAttentionModule is not None:
            x, att = self.myAttentionModule(x, edge_index, batch)
        for i in range(0, len(self.convs)-1):
            x = self.convs[i](x, edge_index)
            x = x.relu()
        x = self.convs[-1](x, edge_index)
        
        x = gap(x, batch)
        
        x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.linear(x)

        return out, att

In [12]:
def train(model, train_loader, valid_loader, epochs=20, learning_rate = 0.01):
    model.train()
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    for epoch in trange(1, epochs + 1, leave=False):
        for data in tqdm(train_loader, leave=False):
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))
            loss.backward()
            # print("==============")
            # for par in model.myAttentionModule.parameters():
            #     print(par)
            optimizer.step()
    return model


def predict(model, test_loader):
    # evaluation loop
    preds_batches = []
    with torch.no_grad():
        for data in tqdm(test_loader):
            x, edge_index, batch = data.x, data.edge_index, data.batch
            
            preds, att = model(x, edge_index, batch)
            preds_batches.append(preds.cpu().detach().numpy())
    preds = np.concatenate(preds_batches)
    return preds, att

In [13]:
from copy import deepcopy

def train_best(model, train_loader, valid_loader, epochs=20, learning_rate = 0.01, saveImg=False, title=""):
    model.train()

    torch.save(model, "train.pth")
    best_val = 1000000
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    for epoch in trange(1, epochs + 1, leave=False):
        # preds_batches = []
        running_loss = 0.0
        for data in train_loader:
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))

            loss.backward()
            optimizer.step()

        # evaluation loop
        preds_batches = []
        with torch.no_grad():
            for data in valid_loader:
                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
                preds, att = model(x, edge_index, batch)
                loss = loss_fn(preds, y.reshape(-1, 1))
                preds_batches.append(preds.cpu().detach().numpy())
        preds = np.concatenate(preds_batches)
        mae = rmse(y_valid, preds.flatten())
        if mae < best_val:
            torch.save(model, "train.pth")
            best_val = mae
            print(best_val)

    model = torch.load("train.pth")
    model.eval()
    return model

In [14]:
def visualize(model, train_loader, valid_loader, test_loader, epochs=20, learning_rate = 0.01, saveImg=False, title=""):
    model.train()

    best_state = deepcopy(model.state_dict())
    best_val = 1000000
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    train_losses = []
    val_losses = []
    train_errors = []
    val_errors = []
    for epoch in trange(1, epochs + 1, leave=False):
        # preds_batches = []
        running_loss = 0.0
        for data in train_loader:
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))
            # print(len(train_dataset))

            running_loss += loss.item()
            # preds_batches.append(preds.cpu().detach().numpy())

            loss.backward()
            optimizer.step()
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        # preds = np.concatenate(preds_batches)
        # mae = rmse(y_train, preds.flatten())
        # train_errors.append(mae)

        # evaluation loop
        preds_batches = []
        running_loss = 0.0
        with torch.no_grad():
            for data in valid_loader:
                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
                preds, att = model(x, edge_index, batch)
                loss = loss_fn(preds, y.reshape(-1, 1))
                # print(len(train_dataset))

                running_loss += loss.item()
                preds_batches.append(preds.cpu().detach().numpy())
        epoch_loss = running_loss / len(valid_loader)
        val_losses.append(epoch_loss)
        preds = np.concatenate(preds_batches)
        mae = rmse(y_valid, preds.flatten())
        if mae < best_val:
            best_state = deepcopy(model.state_dict())
            best_val = mae
            print(best_val)
        val_errors.append(mae)

    model.load_state_dict(best_state)

    ##### visualize ########
    plt.plot(train_losses, label='train_loss')
    plt.plot(val_losses, label='val_loss')
    plt.legend()
    plt.show()
    if saveImg:
        plt.savefig(title + "_loss.png")

    # plt.plot(train_errors,label='train_errors')
    plt.plot(val_errors, label='val_RMSE')
    plt.legend()
    plt.show()
    if saveImg:
        plt.savefig(title + "_val_error.png")
    return model

In [15]:
df = pd.DataFrame({"Repr 1": [], "Repr 10": [],
                   "transfer learning - size = 3": [], "transfer learning - size = 35": [], "transfer learning - size = 100": [], "transfer learning - size = 35 (big)": []})
pd.set_option("display.precision", 2)

In [16]:
n_times = 1

for n_convs in [1, 3, 5]:
    for n_channels in [64, 512]:
        row = []
        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, features_after_layer=26)

            m = train_best(m, train_loader1, valid_loader1, 70)
            predictions, att = predict(m, test_loader1)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, features_after_layer=25)

            m = train_best(m, train_loader10, valid_loader10, 70)
            predictions, att = predict(m, test_loader10)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        #########################
        for vect_size in [3, 35, 100]:
            scores = []
            for _ in range(n_times):
                m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, my_layer=MyAttentionModule4(vect_size), features_after_layer=vect_size)

                m.myAttentionModule.load_state_dict(torch.load('attention_pooling' + str(vect_size) + '.pth'))
                m.eval()
                for par in m.myAttentionModule.parameters():
                        par.requires_grad = False

                m = train_best(m, train_loader, valid_loader, 70)
                predictions, att = predict(m, test_loader)
                rmse_score = rmse(y_test, predictions.flatten())
                scores.append("{:.2f}".format(rmse_score))
            row.append(" | ".join(scores))

        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, my_layer=MyAttentionModule4(35), features_after_layer=35)

            m.myAttentionModule.load_state_dict(torch.load('attention_pooling35_big.pth'))
            m.eval()
            for par in m.myAttentionModule.parameters():
                    par.requires_grad = False

            m = train_best(m, train_loader, valid_loader, 70)
            predictions, att = predict(m, test_loader)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        df.loc[str(n_convs) + " convs, " + str(n_channels) + " channels"] = row

  0%|          | 0/70 [00:00<?, ?it/s]

1.919210000241129
1.8875742012074346
1.8128823640992133
1.7828375152351428
1.7484528173592107
1.7008084572651492
1.662525493547055
1.5934826422368402
1.545240702663627
1.528743632285316
1.5107668068537372
1.4600473617386434
1.4259611257099267
1.4110027002222325
1.3800354996381161


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8659060461439392
1.8377834976476903
1.7517939977746408
1.717046140096354
1.6918970731926377
1.6256029081830696
1.5986230923735258
1.5579976444057555
1.5503093044875527
1.5170412150073322
1.4875040306594676
1.4761089340868143
1.4444659769972963
1.4440566448902117
1.4201252319780804
1.4110182616482496
1.3989328806615169
1.391021213108861


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.0870910779973766
2.189100186675679
2.1613052160104824
2.0611582423017163
2.060824494013457
2.0546840749533515
2.045551451966513
2.0382254176073697
2.019116103421966
2.017594243367857
2.0043235443066645


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.197210947342556
2.073512992042253
1.9997812959524564
1.874556774003057
1.7851887857834072
1.7562887072848594
1.6784919960029894
1.6554540879391506
1.6257196699498806
1.6218138887553153
1.553740557078893
1.5499615161906348
1.5368622390843156
1.5269646646270914
1.52276136519477
1.4922921860656064
1.4753968317884443
1.4697451035187215
1.4587216806337435
1.4557226985772365
1.4542277757698447
1.4309677551891782


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.0831263170051026
1.9413869876089362
1.9255884780781447
1.8254380949461888
1.8152034643100654
1.7864527781378445
1.7218022699423106
1.6938066174587127
1.6414945389948592
1.6292729284352476
1.582764669193196
1.5627790021533017
1.5297559916545849
1.5238314757289195
1.5009720959573898
1.4421782216151957
1.4159648472485633
1.3979880229801283
1.376301337032023
1.3650116236585563


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9983794288179046
1.801950634485625
1.6015408449727886
1.4941554743011511
1.4598194480549034
1.4084674447719288
1.4042729564309242
1.3863328239727128
1.3749361339234718
1.352853993019618
1.3336179804170916
1.3248869128899554
1.321134010080794
1.3208861898548416
1.31083827683457
1.2888652854116913


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9965476113520741
1.8507389032014334
1.7029651640371517
1.5226383523451907
1.4760979040421711
1.4203167245394293
1.3844797831498241
1.3727112680527331


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9499196082584371
1.6631863642828542
1.6163522539969764
1.5376643012997049
1.5025224538734103
1.4646591184810882
1.4402458055951026
1.4242034314258076
1.3933725649667286
1.3900498345330934


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.111751284885849
2.0899294360782066
2.0388174829742827
2.0227440197976314
2.019444418165573
2.015503376842233
2.01364070426803
2.0098241156596592
2.00657888672797


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9800487935860718
1.7715563350535866
1.737341869026017
1.64335317679222
1.6390311965435982
1.6300105306192727
1.5687549227202198
1.5501777043693596
1.5405869330760116
1.5098845468451867
1.480905292482298
1.454832458947198
1.4486157149612586
1.4419394901123346
1.4172067492658682
1.4121588006671804


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.2450304896164264
1.9448839047213737
1.9007180499102647
1.7610400616217117
1.7085180418521806
1.5882125742941142
1.5410311147513536
1.5103040311099383
1.4708867502273897
1.4500832380419963
1.438689163931888
1.4362738879757364
1.3996028330975163
1.3806520877846145
1.3785544084993315
1.3774329364686395
1.3662455288016355
1.3548658562916267
1.347193543049383


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8907471879809903
1.63814265451715
1.4589414187760394
1.4258487974199883
1.3974273874001535
1.3609753055922746
1.3245909022675704
1.3223722005452718
1.3216503979979561
1.3138720049414039
1.3126323629143253
1.2980252296546704
1.2961175720037676
1.295850216029008
1.295367156953934
1.2857346543375718


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.2593763514266847
1.8665154480468513
1.79528746828003
1.664619323950496
1.626638293038277
1.5349713223314467
1.5140956286606704
1.479857670989559
1.4587641091926336
1.440531751525087
1.4309458788800742
1.4165501445390454
1.3876976029298491
1.3766406339989687
1.3616135086550623


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.235420419580229
1.7915238315559567
1.7059149312084498
1.6140493313784399
1.6042493914177562
1.5487218317628035
1.504146795759281
1.4676365511825207
1.4419061120794072
1.4323486667351146
1.412253036250473
1.4117477741297348
1.4082975888429898
1.4005828652455103
1.3895564635779991


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.1372533100645965
2.118094361911294
2.0836641739629878
2.028119926652294
2.0106115891232386
1.9885724921551862
1.929931415340632
1.8692707106083508
1.8394112126599855
1.7912613547494112
1.7649617540559595
1.7413409703063474
1.7030840978047046
1.68700035387531
1.655183919406697
1.634911486824778
1.5946690757870863
1.5825637903540424
1.568598521214862


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.2597523069082524
1.8653226888457046
1.688940973183897
1.6277652381734977
1.583436945438083
1.5625559868765315
1.5156316725757322
1.4971950730229275
1.4943603040481397
1.426542398252899
1.399335082167342
1.3688127667439713
1.3538264667179474
1.3285128562672532
1.303241736826639
1.2613920400790872


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.073165398402355
1.8633565559631797
1.8360584809211469
1.7744807670540852
1.7533452572158388
1.6893290030577304
1.64617346870535
1.6448740602284886
1.5772078530035696
1.5411588811940125
1.5111351959910087
1.5033026376634173
1.4951482190608436
1.4753263926187374
1.4601855839970486
1.4511003611292197
1.4430840019122553
1.4390566014939268
1.4358011662942254
1.4330182023599904
1.4174990260437796
1.409212379778454
1.4017961223308122


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.2179765564935474
1.7590315023000704
1.6396743706970613
1.4788270663190202
1.4745658319566295
1.4022046256888145
1.3584212368097504
1.3173435310346178
1.3047445639762991
1.297300081496483
1.2863315492697909
1.2747402201434372
1.26409057924734
1.246188850392894


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.4316875280885233
2.3172912034071076
1.9133140330376979
1.8542776950135975
1.7213684427721478
1.6147083415138677
1.543024036718655
1.5283027819089487
1.4806634875392315
1.4512137605153619
1.447334306639332
1.446723198340461
1.425213446676107
1.4236189711076008
1.4181081104945696
1.3903739128020345
1.3772430933265571


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.5820595872063214
2.261405155244928
1.8747783991106608
1.7706074853048557
1.7254856231508149
1.6587565828766566
1.641325679296234
1.5569840344595167
1.5496738752842996
1.5253960705767085
1.5230774349724197
1.4997542773508836
1.486876743795592
1.4859147836618558
1.4847465528586195
1.4530737209307616
1.4452747186472452
1.4276983591969687
1.4224430377999995
1.4183881088608012
1.4129993408241277
1.3842318658334927
1.3782261399887767
1.373629404709079
1.364730715548696


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.660813205083213
2.8547947623026033
2.211020561085547
2.0686209261557047
2.06428035942092
2.0436223533470295
2.030495073349139
2.013399966476248
1.9806371299383245
1.9609046664571885
1.9418349970652011
1.931625141405596
1.9263838248592808
1.9189184183123413
1.9119502975846432
1.898987978361421
1.8422465310493932
1.8056136455912666
1.796267255574078
1.7653903390292134
1.7598031158431318
1.7535204023002873
1.689544634461888
1.641015997920109
1.6304198426832863
1.5977768915129138
1.5591539433843635
1.5336576347434976
1.529298339325055
1.5076116291928192
1.5011019424537506


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.5273247529148835
2.437066372156476
2.131195727532482
1.7822984377331192
1.78073051335549
1.6577664462030444
1.5811509694033707
1.5561265288394484
1.5015998628215053
1.4720524811381899
1.401405248146472
1.3666474317654673
1.3268981640568236
1.308481890107092
1.2880176018958704
1.2624498747380106
1.2534182520712305
1.2428499633578078
1.2190512111937843
1.2162147700760606
1.1902260129306876
1.1798730298769742
1.1764137308493872
1.1570122806152197


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.493041354614386
2.292642869260036
1.9856567196066468
1.8111131843417336
1.750221173951596
1.6909003051901288
1.6425058807666912
1.5934004215455724
1.560052526583636
1.5371002901837605
1.5258197245432652
1.5222571858836027
1.4991028707676528
1.4849567725222796
1.466435486232716
1.4462591041739739
1.4255726645520943


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.4444478542413837
1.936656801380953
1.7446393770120183
1.7382234394856309
1.5808608737840133
1.5367162575037783
1.4406673770237919
1.391457002748705
1.350361342777464
1.2963172877329976
1.286877719330493
1.2852224407981314
1.252366112058997
1.2369944891008129
1.2216730156494902
1.2188178797303442
1.199424374831199


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.255222927836605
1.9876450324148198
1.8783373134016532
1.8706244464730057
1.6180887929057963
1.5688646303327398
1.5139503228644346
1.4921213067496466
1.477165193603976
1.4673668504797597
1.4623291466544879
1.4398939862077922
1.4316138012644604


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.7873189912373346
1.7003609509245043
1.6077988514756734
1.5373729401158325
1.4972386194482148
1.48939393090681
1.4844409873177489
1.4730862648296381
1.4712925910039099
1.4681701551483273
1.4659464736757355
1.4605315199446347
1.4446585985069282
1.4354721156352586
1.427247754784203


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.6625246305974994
2.068815535325529
2.014151487186758
1.9254726697265399
1.8409397406991599
1.7607565795769058
1.7186117551839786
1.7006700785207618
1.6971875660626827
1.647661644223737
1.6035336600316514
1.5739688900584519
1.5243635426981863
1.4879670829815497
1.4247270990815386
1.3998550559012646
1.3582643799321175
1.3580644848824173
1.3503814006249681
1.323787301884579


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.289807157074048
2.1247384160534533
1.7455398660953132
1.6752676452178124
1.6469033335910475
1.619266144400537
1.5692742347047066
1.5551826003379998
1.5277662546571074
1.5065565993532584
1.4977929596812738
1.4930485486535525
1.4921486838523341
1.4613741966554215
1.4485625941252198
1.4482112053432592
1.4442308097543912
1.3900920468660825
1.3833110669406752
1.3734791061042475


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.7217103536346454
2.074520058079494
1.9457252792406003
1.8381846650583884
1.7902194359075991
1.7706548995468507
1.7081602942926108
1.6807635862006935
1.6399871531875048
1.5654322567830377
1.549073469543986
1.5260531679163944
1.5214311644954002
1.5074080821249003
1.4767547793722502
1.4676558745180464
1.4542161959561586
1.4473018270307063
1.4460936973026728
1.4331707156892997


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8759647525480918
1.6626367375481081
1.6441862330494579
1.5635152111483184
1.525127822987958
1.495302614930617
1.4275587357739983
1.377008179685238
1.3589833565394562
1.3490539300100166
1.3359131750122086
1.3237513507226526
1.3044143358580766
1.2776343640104717
1.249636029867538


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.0208624589498023
2.36717136852326
2.079986920668971
1.9104394491581247
1.7735344680091554
1.6665089826289001
1.652842550595414
1.5729794071211844
1.529273673278022
1.4989145718142096
1.4750704117012496
1.4395787453289963
1.429615463344367
1.4174768667614774


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.442296158319279
2.118564033927657
2.0952793837605856
2.072149629117616
2.037780180767393
1.8804171153285456
1.7920204630535168
1.6671171522877208
1.5621589721024833
1.5126522483291291
1.5065821262258972
1.5055541197344733
1.5024678881327664
1.4954308972718553
1.4774602836736046


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.244459892006319
2.2125773948907637
2.145400384150295
2.0922676991476528
2.077041135254518
2.0691910562491005
2.068576829780089
2.0633615266440084
2.0483608675614997


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.510082242692972
2.119262773254372
2.0525628840774868
1.9733432755136824
1.7586608662712742
1.7360814162282383
1.6518055362163402
1.6163332195478923
1.5979231899966195
1.5822042053947605
1.5526236713581536
1.5115239503034232
1.4975539497150132
1.47417930303045
1.4585537302219043
1.44983635612046
1.4386398700086527
1.433152222532711
1.4280607118889714
1.3774978599187329
1.3630514114336267
1.3345326967953315
1.3010809731670991
1.2534080712191635
1.2397631078527314
1.2263408600907213
1.2183339959400694


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.406385465945386
2.2765600734422122
2.089817866215952
1.8932903903933995
1.7528851968668837
1.665645788074467
1.640887820004751
1.6298383614007605
1.6186722420231323
1.609325243630616
1.55501438990352
1.544793616967955
1.5192483263115495
1.4949276566733267
1.4642738185465878
1.4421707881786157
1.4284909224754492
1.4124389185176112
1.4094603957299865
1.4090246612459612


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

3.5155766862477744
2.1118106381910358
1.9877283938899954
1.7271361089985928
1.5755474307741035
1.4366467727098433
1.4152346389548829
1.368905245166827
1.3431280102996752
1.3341132648412444
1.3293872287073156
1.3146112287867233
1.300832595021333
1.29992024256105
1.283199266417951


  0%|          | 0/5 [00:00<?, ?it/s]

In [17]:
df

Unnamed: 0,Repr 1,Repr 10,transfer learning - size = 3,transfer learning - size = 35,transfer learning - size = 100,transfer learning - size = 35 (big)
"1 convs, 64 channels",1.49,1.48,2.33,1.64,1.56,1.35
"1 convs, 512 channels",1.48,1.48,2.35,1.62,1.53,1.35
"3 convs, 64 channels",1.37,1.4,1.71,1.58,1.52,1.3
"3 convs, 512 channels",1.44,1.45,2.14,1.46,1.55,1.17
"5 convs, 64 channels",1.46,1.46,1.4,1.52,1.51,1.25
"5 convs, 512 channels",1.52,1.56,2.34,1.29,1.45,1.28


In [19]:
m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, my_layer=MyAttentionModule4(35), features_after_layer=35)

m.myAttentionModule.load_state_dict(torch.load('attention_pooling35_big.pth'))
for par in m.myAttentionModule.parameters():
        par.requires_grad = False

m = train_best(m, train_loader, valid_loader, 70)
predictions, att = predict(m, test_loader)
rmse_score = rmse(y_test, predictions.flatten())
###################### atencja #################################

df_single = pd.DataFrame({"AtomicNum": [], "Degree": [], "TotalNumHs": [], "ImplicitValence": [], "Hybridization": [], "FormalCharge": [],
                          "IsInRing": [], "IsAromatic": [], "NumRadicalElectrons": []})
df_single.style.set_caption("Hello World")

df_batch = pd.DataFrame({"AtomicNum": [], "Degree": [], "TotalNumHs": [], "ImplicitValence": [], "Hybridization": [], "FormalCharge": [],
                          "IsInRing": [], "IsAromatic": [], "NumRadicalElectrons": []})
df_batch.style.set_caption("Hello World")

preds_batches = []
with torch.no_grad():
    for data in test_loader:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        preds, att = m(x, edge_index, batch)
        preds_batches.append(preds.cpu().detach().numpy())
        att = att.squeeze()
        df_single.loc[len(df_single)] = att[0].tolist()
        df_batch.loc[len(df_single)] = torch.mean(gap(att, batch), dim=0).tolist()
preds = np.concatenate(preds_batches)

rmse_score = rmse(y_test, predictions.flatten())

print(f'RMSE = {rmse_score:.2f}')
df_single[:10]

  0%|          | 0/70 [00:00<?, ?it/s]

3.3655222600643584
2.1522534478454642
1.9321602654377317
1.7427049183001957
1.6375673789358671
1.5013309329953834
1.380769958293459
1.3650897134731423
1.3495792890871938
1.3324222014567297
1.3257246624602181
1.296787756954208


  0%|          | 0/5 [00:00<?, ?it/s]

RMSE = 1.37


Unnamed: 0,AtomicNum,Degree,TotalNumHs,ImplicitValence,Hybridization,FormalCharge,IsInRing,IsAromatic,NumRadicalElectrons
0,0.11,0.11,0.11,0.16,0.11,0.11,0.11,0.11,0.11
1,0.09,0.19,0.09,0.13,0.09,0.09,0.16,0.09,0.09
2,0.11,0.11,0.11,0.11,0.11,0.11,0.11,0.11,0.11
3,0.11,0.11,0.11,0.16,0.11,0.11,0.11,0.11,0.11
4,0.11,0.11,0.11,0.16,0.11,0.11,0.11,0.11,0.11


In [20]:
df_batch[:10]

Unnamed: 0,AtomicNum,Degree,TotalNumHs,ImplicitValence,Hybridization,FormalCharge,IsInRing,IsAromatic,NumRadicalElectrons
1,0.09,0.15,0.09,0.2,0.09,0.09,0.12,0.09,0.09
2,0.09,0.15,0.09,0.2,0.09,0.09,0.12,0.09,0.09
3,0.09,0.15,0.09,0.2,0.09,0.09,0.12,0.09,0.09
4,0.09,0.15,0.09,0.2,0.09,0.09,0.13,0.09,0.09
5,0.09,0.15,0.09,0.21,0.09,0.09,0.11,0.09,0.09


In [21]:
df

Unnamed: 0,Repr 1,Repr 10,transfer learning - size = 3,transfer learning - size = 35,transfer learning - size = 100,transfer learning - size = 35 (big)
"1 convs, 64 channels",1.49,1.48,2.33,1.64,1.56,1.35
"1 convs, 512 channels",1.48,1.48,2.35,1.62,1.53,1.35
"3 convs, 64 channels",1.37,1.4,1.71,1.58,1.52,1.3
"3 convs, 512 channels",1.44,1.45,2.14,1.46,1.55,1.17
"5 convs, 64 channels",1.46,1.46,1.4,1.52,1.51,1.25
"5 convs, 512 channels",1.52,1.56,2.34,1.29,1.45,1.28
