In [1]:
import pandas as pd
import numpy as np
import torch
from tdc.single_pred.adme import ADME
from tdc import Evaluator
from tqdm.notebook import tqdm, trange
from torch.utils.data import TensorDataset, DataLoader
from rdkit import Chem
from rdkit.Chem import AllChem
from matplotlib import pyplot as plt
from IPython import display

from typing import List, Tuple


class Featurizer:
    def __init__(self, y_column, smiles_col='Drug', **kwargs):
        self.y_column = y_column
        self.smiles_col = smiles_col
        self.__dict__.update(kwargs)
    
    def __call__(self, df):
        raise NotImplementedError()

In [2]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise ValueError("input {0} not in allowable set{1}:".format(
            x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))


from rdkit.Chem import rdMolDescriptors

class GraphFeaturizer(Featurizer):
    def __call__(self, df, getRepresentation):
        graphs = []
        labels = []
        for i, row in df.iterrows():
            y = row[self.y_column]
            smiles = row[self.smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            
            edges = []
            for bond in mol.GetBonds():
                begin = bond.GetBeginAtomIdx()
                end = bond.GetEndAtomIdx()
                edges.append((begin, end))  # TODO: Add edges in both directions
            edges = np.array(edges)
            
            nodes = []
            for atom in mol.GetAtoms():
                # print(atom.GetAtomicNum(), atom.GetNumImplicitHs(), atom.GetTotalNumHs(), atom.GetSymbol(), atom.GetNumExplicitHs(), atom.GetTotalValence())
                results = getRepresentation(atom)
                # print(results)
                nodes.append(results)
            nodes = np.array(nodes)
            
            graphs.append((nodes, edges.T))
            labels.append(y)
        labels = np.array(labels)
        return [Data(
            x=torch.FloatTensor(x), 
            edge_index=torch.LongTensor(edge_index), 
            y=torch.FloatTensor([y])
        ) for ((x, edge_index), y) in zip(graphs, labels)]

In [3]:
def defaultRepresentation(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(11)) + one_of_k_encoding(
                    atom.GetDegree(), range(11)
                ) + one_of_k_encoding_unk(
                    atom.GetImplicitValence(), range(11)
                ) + [atom.GetIsAromatic()] + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(11)
                ) + [atom.GetNumImplicitHs(), atom.GetFormalCharge(), atom.GetNumRadicalElectrons(), atom.IsInRing()] # TODO: Add atom features as a list, you can use one_of_k_encodings defined above

def representation1(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)
                ) + [atom.GetFormalCharge(), atom.IsInRing(), atom.GetIsAromatic()]

def representation10(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)
                ) + [atom.IsInRing(), atom.GetIsAromatic()]

def representationAll(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)) + one_of_k_encoding_unk(
                    atom.GetImplicitValence(), range(6))  + one_of_k_encoding_unk(
                    atom.GetHybridization(),
                    [
                        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
                        Chem.rdchem.HybridizationType.SP3D2
                    ]
                ) + [atom.GetFormalCharge(), atom.IsInRing(), atom.GetIsAromatic()
                     ] + [atom.GetNumRadicalElectrons()]

def printProperties(atom):
    print("=========")
    print("GetDegree", atom.GetDegree())
    print("GetImplicitValence", atom.GetImplicitValence())
    print("GetAtomicNum", atom.GetAtomicNum())
    print("GetTotalNumHs", atom.GetTotalNumHs())
    print("GetNumImplicitHs", atom.GetNumImplicitHs())
    print("GetNeighbors", atom.GetNeighbors())
    print("GetNumExplicitHs", atom.GetNumExplicitHs())
    print("GetTotalDegree", atom.GetTotalDegree())
    print("GetTotalNumHs", atom.GetTotalNumHs())
    print("GetTotalValence", atom.GetTotalValence())

In [4]:
class ECFPFeaturizer(Featurizer):
    def __init__(self, y_column, radius=2, length=1024, **kwargs):
        self.radius = radius
        self.length = length
        super().__init__(y_column, **kwargs)
    
    def __call__(self, df):
        fingerprints = []
        labels = []
        for i, row in df.iterrows():
            y = row[self.y_column]
            smiles = row[self.smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, self.radius, nBits=self.length)
            fingerprints.append(fp)
            labels.append(y)
        fingerprints = np.array(fingerprints)
        labels = np.array(labels)
        return fingerprints, labels

data = ADME('Solubility_AqSolDB')
split = data.get_split()
rmse = Evaluator(name = 'RMSE')

featurizer = ECFPFeaturizer(y_column='Y')
X_train, y_train = featurizer(split['train'])
X_valid, y_valid = featurizer(split['valid'])
X_test, y_test = featurizer(split['test'])

featurizer = GraphFeaturizer('Y')
graph = featurizer(split['test'].iloc[:1], defaultRepresentation)[0]

Found local copy...
Loading...
Done!


In [5]:
from torch_geometric.loader import DataLoader as GraphDataLoader


# prepare data loaders
batch_size = 64

#dla repr1
train_loader1 = GraphDataLoader(featurizer(split['train'], representation1), batch_size=batch_size, shuffle=True)
valid_loader1 = GraphDataLoader(featurizer(split['valid'], representation1), batch_size=batch_size)
test_loader1 = GraphDataLoader(featurizer(split['test'], representation1), batch_size=batch_size)

#dla repr10
train_loader10 = GraphDataLoader(featurizer(split['train'], representation10), batch_size=batch_size, shuffle=True)
valid_loader10 = GraphDataLoader(featurizer(split['valid'], representation10), batch_size=batch_size)
test_loader10 = GraphDataLoader(featurizer(split['test'], representation10), batch_size=batch_size)

#dla naszej warstwy
train_loader = GraphDataLoader(featurizer(split['train'], representationAll), batch_size=batch_size, shuffle=True)
valid_loader = GraphDataLoader(featurizer(split['valid'], representationAll), batch_size=batch_size)
test_loader = GraphDataLoader(featurizer(split['test'], representationAll), batch_size=batch_size)



In [6]:
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool as gap

In [7]:
#warstwa attention pooling
class MyAttentionModule3(torch.nn.Module): # zakladamy ze atom ma 49 featerow
    def __init__(self, groupFeatures=1):
        super().__init__()
        self.groupFeatures = groupFeatures
        self.gates = torch.nn.ModuleDict({ # do wyliczenia atencji dla kazdej grupy cech - jest ich 9
            'AtomicNum': GCNConv(12, 1),
            'Degree': GCNConv(6, 1),
            'TotalNumHs': GCNConv(5, 1),
            'ImplicitValence': GCNConv(6, 1),
            'Hybridization': GCNConv(5, 1),
            'FormalCharge': GCNConv(1, 1),
            'IsInRing': GCNConv(1, 1),
            'IsAromatic': GCNConv(1, 1),
            'NumRadicalElectrons': GCNConv(1, 1)
        })
        
        self.feats = torch.nn.ModuleDict({ # do transformacji grupy cech w wektor, na razie dziala tylko dla groupFeatures=1
            'AtomicNum': torch.nn.Linear(12, groupFeatures),
            'Degree': torch.nn.Linear(6, groupFeatures),
            'TotalNumHs': torch.nn.Linear(5, groupFeatures),
            'ImplicitValence': torch.nn.Linear(6, groupFeatures),
            'Hybridization': torch.nn.Linear(5, groupFeatures),
            'FormalCharge': torch.nn.Linear(1, groupFeatures),
            'IsInRing': torch.nn.Linear(1, groupFeatures),
            'IsAromatic': torch.nn.Linear(1, groupFeatures),
            'NumRadicalElectrons': torch.nn.Linear(1, groupFeatures)
        })

    def forward(self, x, edge_index, batch):
        gates = []
        gates.append(self.gates['AtomicNum'](x[:,0:12], edge_index))
        gates.append(self.gates['Degree'](x[:,12:18], edge_index))
        gates.append(self.gates['TotalNumHs'](x[:,18:23], edge_index))
        gates.append(self.gates['ImplicitValence'](x[:,23:29], edge_index))
        gates.append(self.gates['Hybridization'](x[:,29:34], edge_index))
        gates.append(self.gates['FormalCharge'](x[:,34:35], edge_index))
        gates.append(self.gates['IsInRing'](x[:,35:36], edge_index))
        gates.append(self.gates['IsAromatic'](x[:,36:37], edge_index))
        gates.append(self.gates['NumRadicalElectrons'](x[:,37:38], edge_index))
        logits = torch.cat(gates, dim=-1)
        attention = torch.softmax(logits, dim=-1).unsqueeze(-1)
        
        subgroups = []
        subgroups.append(self.feats['AtomicNum'](x[:,0:12]) * attention[:,0])
        subgroups.append(self.feats['Degree'](x[:,12:18]) * attention[:,1])
        subgroups.append(self.feats['TotalNumHs'](x[:,18:23]) * attention[:,2])
        subgroups.append(self.feats['ImplicitValence'](x[:,23:29]) * attention[:,3])
        subgroups.append(self.feats['Hybridization'](x[:,29:34]) * attention[:,4])
        subgroups.append(self.feats['FormalCharge'](x[:,34:35]) * attention[:,5])
        subgroups.append(self.feats['IsInRing'](x[:,35:36]) * attention[:,6])
        subgroups.append(self.feats['IsAromatic'](x[:,36:37]) * attention[:,7])
        subgroups.append(self.feats['NumRadicalElectrons'](x[:,37:38]) * attention[:,8])
        x = torch.stack(subgroups, dim=-2)
        x = torch.sum(x, dim=-2)
        
        return x, attention

In [8]:
def attSequential(n_feats):
    return torch.nn.Sequential(torch.nn.Linear(n_feats, 1),
                       torch.nn.BatchNorm1d(1), torch.nn.ReLU(),
                       torch.nn.Linear(1, 1), torch.nn.ReLU())

In [9]:
#warstwa attention pooling
class MyAttentionModule4(torch.nn.Module): # zakladamy ze atom ma 49 featerow
    def __init__(self, groupFeatures=1):
        super().__init__()
        self.groupFeatures = groupFeatures
        self.gates = torch.nn.ModuleDict({ # do wyliczenia atencji dla kazdej grupy cech - jest ich 9
            'AtomicNum': GINConv(attSequential(12), train_eps=True),
            'Degree': GINConv(attSequential(6), train_eps=True),
            'TotalNumHs': GINConv(attSequential(5), train_eps=True),
            'ImplicitValence': GINConv(attSequential(6), train_eps=True),
            'Hybridization': GINConv(attSequential(5), train_eps=True),
            'FormalCharge': GINConv(attSequential(1), train_eps=True),
            'IsInRing': GINConv(attSequential(1), train_eps=True),
            'IsAromatic': GINConv(attSequential(1), train_eps=True),
            'NumRadicalElectrons': GINConv(attSequential(1), train_eps=True)
        })
        
        self.feats = torch.nn.ModuleDict({ # do transformacji grupy cech w wektor, na razie dziala tylko dla groupFeatures=1
            'AtomicNum': torch.nn.Linear(12, groupFeatures),
            'Degree': torch.nn.Linear(6, groupFeatures),
            'TotalNumHs': torch.nn.Linear(5, groupFeatures),
            'ImplicitValence': torch.nn.Linear(6, groupFeatures),
            'Hybridization': torch.nn.Linear(5, groupFeatures),
            'FormalCharge': torch.nn.Linear(1, groupFeatures),
            'IsInRing': torch.nn.Linear(1, groupFeatures),
            'IsAromatic': torch.nn.Linear(1, groupFeatures),
            'NumRadicalElectrons': torch.nn.Linear(1, groupFeatures)
        })

    def forward(self, x, edge_index, batch):
        gates = []
        gates.append(self.gates['AtomicNum'](x[:,0:12], edge_index))
        gates.append(self.gates['Degree'](x[:,12:18], edge_index))
        gates.append(self.gates['TotalNumHs'](x[:,18:23], edge_index))
        gates.append(self.gates['ImplicitValence'](x[:,23:29], edge_index))
        gates.append(self.gates['Hybridization'](x[:,29:34], edge_index))
        gates.append(self.gates['FormalCharge'](x[:,34:35], edge_index))
        gates.append(self.gates['IsInRing'](x[:,35:36], edge_index))
        gates.append(self.gates['IsAromatic'](x[:,36:37], edge_index))
        gates.append(self.gates['NumRadicalElectrons'](x[:,37:38], edge_index))
        logits = torch.cat(gates, dim=-1)
        attention = torch.softmax(logits, dim=-1).unsqueeze(-1)
        
        subgroups = []
        subgroups.append(self.feats['AtomicNum'](x[:,0:12]) * attention[:,0])
        subgroups.append(self.feats['Degree'](x[:,12:18]) * attention[:,1])
        subgroups.append(self.feats['TotalNumHs'](x[:,18:23]) * attention[:,2])
        subgroups.append(self.feats['ImplicitValence'](x[:,23:29]) * attention[:,3])
        subgroups.append(self.feats['Hybridization'](x[:,29:34]) * attention[:,4])
        subgroups.append(self.feats['FormalCharge'](x[:,34:35]) * attention[:,5])
        subgroups.append(self.feats['IsInRing'](x[:,35:36]) * attention[:,6])
        subgroups.append(self.feats['IsAromatic'](x[:,36:37]) * attention[:,7])
        subgroups.append(self.feats['NumRadicalElectrons'](x[:,37:38]) * attention[:,8])
        x = torch.stack(subgroups, dim=-2)
        x = torch.sum(x, dim=-2)
        
        return x, attention

In [10]:
class GraphNeuralNetwork(torch.nn.Module):  # TODO: assign hyperparameters to attributes and define the forward pass
    def __init__(self, hidden_size, n_convs=3, my_layer=None, features_after_layer=26, n_features=49, dropout=0.2):
        super().__init__()
        self.myAttentionModule = my_layer
        self.dropout = dropout

        convs = torch.nn.ModuleList()
        convs.append(GCNConv(features_after_layer, hidden_size))
        for i in range(1, n_convs):
            convs.append(GCNConv(hidden_size, hidden_size))
        self.convs = convs
        self.linear = torch.nn.Linear(hidden_size, 1)
    
    def forward(self, x, edge_index, batch):
        att = None
        if self.myAttentionModule is not None:
            x, att = self.myAttentionModule(x, edge_index, batch)
        for i in range(0, len(self.convs)-1):
            x = self.convs[i](x, edge_index)
            x = x.relu()
        x = self.convs[-1](x, edge_index)
        
        x = gap(x, batch)
        
        x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.linear(x)

        return out, att

In [11]:
def train(model, train_loader, valid_loader, epochs=20, learning_rate = 0.01):
    model.train()
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    for epoch in trange(1, epochs + 1, leave=False):
        for data in tqdm(train_loader, leave=False):
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))
            loss.backward()
            # print("==============")
            # for par in model.myAttentionModule.parameters():
            #     print(par)
            optimizer.step()
    return model


def predict(model, test_loader):
    # evaluation loop
    preds_batches = []
    with torch.no_grad():
        for data in tqdm(test_loader):
            x, edge_index, batch = data.x, data.edge_index, data.batch
            
            preds, att = model(x, edge_index, batch)
            preds_batches.append(preds.cpu().detach().numpy())
    preds = np.concatenate(preds_batches)
    return preds, att

In [12]:
from copy import deepcopy

def train_best(model, train_loader, valid_loader, epochs=20, learning_rate = 0.01, saveImg=False, title=""):
    model.train()

    best_state = deepcopy(model.state_dict())
    best_val = 1000000
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    for epoch in trange(1, epochs + 1, leave=False):
        # preds_batches = []
        running_loss = 0.0
        for data in train_loader:
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))

            loss.backward()
            optimizer.step()

        # evaluation loop
        preds_batches = []
        with torch.no_grad():
            for data in valid_loader:
                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
                preds, att = model(x, edge_index, batch)
                loss = loss_fn(preds, y.reshape(-1, 1))
                preds_batches.append(preds.cpu().detach().numpy())
        preds = np.concatenate(preds_batches)
        mae = rmse(y_valid, preds.flatten())
        if mae < best_val:
            best_state = deepcopy(model.state_dict())
            best_val = mae
            print(best_val)

    model.load_state_dict(best_state)
    return model

In [13]:
def visualize(model, train_loader, valid_loader, test_loader, epochs=20, learning_rate = 0.01, saveImg=False, title=""):
    model.train()

    torch.save(model, "train.pth")
    best_val = 1000000
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    train_losses = []
    val_losses = []
    train_errors = []
    val_errors = []
    for epoch in trange(1, epochs + 1, leave=False):
        # preds_batches = []
        running_loss = 0.0
        for data in train_loader:
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))
            # print(len(train_dataset))

            running_loss += loss.item()
            # preds_batches.append(preds.cpu().detach().numpy())

            loss.backward()
            optimizer.step()
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        # preds = np.concatenate(preds_batches)
        # mae = rmse(y_train, preds.flatten())
        # train_errors.append(mae)

        # evaluation loop
        preds_batches = []
        running_loss = 0.0
        with torch.no_grad():
            for data in valid_loader:
                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
                preds, att = model(x, edge_index, batch)
                loss = loss_fn(preds, y.reshape(-1, 1))
                # print(len(train_dataset))

                running_loss += loss.item()
                preds_batches.append(preds.cpu().detach().numpy())
        epoch_loss = running_loss / len(valid_loader)
        val_losses.append(epoch_loss)
        preds = np.concatenate(preds_batches)
        mae = rmse(y_valid, preds.flatten())
        if mae < best_val:
            torch.save(model, "train.pth")
            best_val = mae
            print(best_val)
        val_errors.append(mae)

    model = torch.load("train.pth")
    model.eval()

    ##### visualize ########
    plt.plot(train_losses, label='train_loss')
    plt.plot(val_losses, label='val_loss')
    plt.legend()
    plt.show()
    if saveImg:
        plt.savefig(title + "_loss.png")

    # plt.plot(train_errors,label='train_errors')
    plt.plot(val_errors, label='val_RMSE')
    plt.legend()
    plt.show()
    if saveImg:
        plt.savefig(title + "_val_error.png")
    return model

In [17]:
df = pd.DataFrame({"Repr 1": [], "Repr 10": [],
                   "transfer learning - size = 3": [], "transfer learning - size = 35": [], "transfer learning - size = 100": [], "transfer learning - size = 35 (big)": []})
pd.set_option("display.precision", 2)

In [18]:
n_times = 1

for n_convs in [1, 3, 5]:
    for n_channels in [64, 512]:
        row = []
        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, features_after_layer=26)

            m = train_best(m, train_loader1, valid_loader1, 70)
            predictions, att = predict(m, test_loader1)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, features_after_layer=25)

            m = train_best(m, train_loader10, valid_loader10, 70)
            predictions, att = predict(m, test_loader10)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        #########################
        for vect_size in [3, 35, 100]:
            scores = []
            for _ in range(n_times):
                m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, my_layer=MyAttentionModule4(vect_size), features_after_layer=vect_size)

                m.myAttentionModule.load_state_dict(torch.load('attention_pooling' + str(vect_size) + '.pth'))
                m.eval()
                for par in m.myAttentionModule.parameters():
                        par.requires_grad = False

                m = train_best(m, train_loader, valid_loader, 70)
                predictions, att = predict(m, test_loader)
                rmse_score = rmse(y_test, predictions.flatten())
                scores.append("{:.2f}".format(rmse_score))
            row.append(" | ".join(scores))

        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, my_layer=MyAttentionModule4(35), features_after_layer=35)

            m.myAttentionModule.load_state_dict(torch.load('attention_pooling35_big.pth'))
            m.eval()
            for par in m.myAttentionModule.parameters():
                    par.requires_grad = False

            m = train_best(m, train_loader, valid_loader, 70)
            predictions, att = predict(m, test_loader)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        df.loc[str(n_convs) + " convs, " + str(n_channels) + " channels"] = row

  0%|          | 0/70 [00:00<?, ?it/s]

1.9239051927224744
1.8824044650857024
1.8708188823695409
1.8544374099999812
1.8514084692328474
1.8505132278028962
1.8445121074406807


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9202650505641574
1.899382670135747
1.8779640098929185
1.8764937760804201
1.8734432236981806
1.8459894030322437
1.8440723721801575


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.2111081134628714
2.1993403799033966
2.198601756081361
2.196553355901469
2.191896473762494
2.188236792649506
2.1867627697859713
2.184694445927528


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.006275511074213
1.9686244163744882
1.9251508140100526
1.923495775654457
1.9055809401847608
1.8875762879330271
1.883742406027809
1.883458562269263
1.8710887540855654


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.0883162272902998
2.026305533327386
1.9498733497591907
1.9203216628288267
1.916607533693913
1.891485258106759
1.8524325079111086
1.8392726819319016
1.8325868467407926
1.826979008309253
1.8259048853623259
1.8103511658036004


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8689479907031568
1.8490449941565912
1.8432180988148414
1.8319006974358878
1.8237140759264738
1.821234907804118
1.8206668762305838
1.8080845297523491
1.801659534220267
1.7993488850243853
1.7959151707860204


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8988952916050632
1.8854509774977821
1.850625755708985
1.8365761139196664


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9090209076090874
1.8502024077249433
1.8382101219101776


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.2093757941969154
2.1913067330062987
2.1875101048045082
2.1815202434928493


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.970370867360132
1.950321733365964
1.9314371930781502
1.9242463064462412
1.91024940179593
1.9011162618334152
1.8805715890360317
1.87994124342532
1.8755903687203994
1.8727136330028433


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

6.6280118934304175
3.970491568325811
2.2687763661780735
2.1948151260074464
2.162550729127905
2.06865168470698
2.05740753144236
1.9460267186743494


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9325431860851146
1.8653189326206043
1.844039909564728
1.83749829831167
1.8336068994056027
1.8233891175558552
1.8043976619234836
1.8032968224597137
1.789958707763793


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8135539298562566
1.6993360878639885
1.6885876697918238
1.601957954766087
1.5858535669793297
1.5645577292669195
1.5502527651410325
1.5428647062993208
1.5211438253774008
1.5202397624620274
1.5067218538558627
1.4849105412565649
1.4603146348714047
1.4566700959288572
1.4246726958126805
1.4244575319883075
1.4104772663651108
1.4026701419119965
1.3984617187106438


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8870590871032598
1.7450174320543508
1.6783174575772293
1.6255218454440337
1.5553863122463154
1.5326127397416378
1.4915281034202208
1.4886664975425312
1.487519923635393
1.4577398251955764
1.4501861137896384
1.426146831878994


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.4562389586878655
2.0401166836517257
1.9692780801315766
1.9138323444742749
1.89181664030575
1.8617221888077689
1.835579179253183
1.8043014216946491
1.7823326193501703
1.7431856030289545
1.7422829361715828
1.7138130360074986
1.7055674275233266
1.6961986024054339
1.6717113428028303
1.6659618119448998
1.6463086137953833


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9426676686155255
1.75723625561017
1.7198558775396031
1.694935809797026
1.6876341525739398
1.5993223362028088
1.593086700769341
1.5752532193197215
1.5597870882572984
1.545270637221411
1.5431074660433017
1.5031868879115222
1.494317095980469
1.4900655417052573
1.4873994935700299
1.4794641153552224
1.467616902129646
1.4671442647662178
1.4605251442989984
1.442238944487148
1.4412246612600559


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.9711632287931635
1.84641310424506
1.7343743059991
1.721062344735718
1.68118665363238
1.6534378300037589
1.6219635998079953
1.5994969227695182
1.5820379921100873
1.5729723834465144
1.5467106443421113
1.5152639646234154
1.5128128982045634
1.4888630112262697
1.4883967018727462
1.4725466723903282
1.45521455362106


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.7192836897023978
1.6883662289845778
1.6112009138366723
1.6012756760824105
1.5214404698565986
1.5091839785209995
1.4918499506992586
1.4705950291591636
1.4449996549252466
1.4422145933262327
1.4097203751545624
1.3849929879586094
1.382792694785252
1.3757132042996427
1.3686928932617943
1.3656276844200645
1.3642610852322599
1.3472988044361898


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.7386515185411178
1.7122691873648843
1.6972600994321485
1.5669464182517663
1.5646584156267176
1.5438415417358486
1.5102544345902689
1.5022343525916215
1.4821937003799173
1.4717586172538215
1.461083535010673
1.4443587139822627


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.0484212651355653
1.7624432019545655
1.7359618541963509
1.6498225788741978
1.6090717352451496
1.5923067535190116
1.5544581858539133
1.5439842005401705
1.539299858740487
1.515136254660383
1.5107798426199908
1.5024722687763312
1.4804395006807345
1.460315694969541
1.4545631838505713
1.4351245881840857


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.1856532298948004
2.0872137795617425
1.9504013919724328
1.9128047235374106
1.9126929836073416
1.9052292084965141
1.838107377835481
1.7896863407532195
1.7771637431716132
1.7622401775521694


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.7512769930377114
1.7264509575966038
1.6395988432686879
1.6242661585748468
1.5549478482881267
1.5438167142500043
1.4984894657795043
1.4977881104636748
1.4487460241275996
1.4417491720927849
1.4300594732494316


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8362439321075437
1.8047134412541495
1.7532059623888383
1.7050720349688016
1.6929390577521106
1.679922206440848
1.6766625161810473
1.6690362549816442
1.6686289807975552
1.602742267641253
1.5972219808822494
1.564652769705757
1.5638882923233122
1.560677552483208
1.5588284579819658
1.5244305177436648
1.519306788958806
1.5085519327061232
1.4982027579227692


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.75629808089122
1.6486641637361596
1.6236028173821315
1.5526594292479599
1.4963945314887694
1.4393138267512864
1.4149987039725263
1.4140297400874158
1.378447142298283
1.3642425588599543
1.355407947932431


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.7496708435074497
1.720067526633254
1.6758601563571933
1.6426997636729983
1.6087502440827193
1.5529745324415023
1.5330511885246785
1.5190697452905537
1.498257140019345
1.4833248546241218
1.478494095050053
1.4556834153910758
1.4456846409869661
1.4130737330824672
1.3943385144468634
1.3942283505941149
1.3921790647771597
1.3841391911848444
1.3791118448337563
1.3726702435473803
1.345278514990747


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.841473083239969
1.6784425315229696
1.653437287827916
1.6017856873674035
1.5857702729788092
1.5616181991692804
1.5473871340321381
1.544154978552965
1.5312865205094544
1.5290383479239238
1.5156955591219567
1.501473121203392
1.4920576059603066
1.4822912690265044
1.4622476152630046
1.428134086249175
1.427308864194191
1.4231815358690867
1.4146146698997386


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.076069829012615
1.9422019365821994
1.9061508731731378
1.865284115390599
1.8105948962783305
1.7254081775627776
1.7196977171844017
1.683280356324051
1.648829552395884
1.628170997782659
1.5780329752313798


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8168484098948792
1.69816949199348
1.6579548958102102
1.5648560684426942
1.5640926564483202
1.5082131830561782
1.500638363944748
1.4977893917299003
1.492077989938716
1.4831611156923323
1.455540419316024
1.454935630780983
1.4405598649285138
1.4378425222455435
1.4255530790854347
1.413663133720055
1.4034268933321985
1.4010448027762437
1.3922676125109394


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.761623383504531
1.8001863845251138
1.7540693766505153
1.733500637559276
1.7058382213062966
1.6430424410326319
1.6210814107985392
1.6013472732531626
1.57334788993754
1.5480634915641736
1.5474392859383128
1.5115309562214152
1.510718536656948
1.498976329114164
1.4895567305638955
1.4654136265153388


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8212545678417087
1.716413478224713
1.6160955660411274
1.5693476552958594
1.5577903579120422
1.5265532732544573
1.5227637738434905
1.5108418994816195
1.4426445488665347
1.3972953867684548
1.3885217456087402
1.371540023212772
1.3303096365320886
1.322358792917672
1.3212027699775033
1.321119733441724
1.315106577651445


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.7742731661431803
1.7488160693516566
1.6627376663102695
1.6384430532143035
1.6116838014219879
1.5881892794801167
1.5611366571218228
1.5607767671289292
1.5550249136722762
1.5193065472484653
1.5164415262288724
1.4975248902732232
1.4915679495352097
1.4877241299701243


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.7854953530003375
1.7637375888492972
1.6173879909226458
1.5839717056144904
1.5590481876321147
1.5387392867246168
1.5378589196149515
1.5348193394208363
1.528878998686646
1.5262852325114382
1.5218132491412848
1.5188419253915946
1.502617151338612
1.4995090606262158


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.324078738585975
2.3049589978459086
2.302687600117142
2.3020991662540684
2.3017426126310543


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8781730843471336
1.7799389993596189
1.6901251063984473
1.675658495170125
1.580310039633441
1.5136312680890247
1.49626872824438
1.4920869368197238
1.4861852105687192
1.4800051928296254
1.4642672823049312
1.4486449835224178
1.421802838662332
1.401142914591916
1.378153882651602


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

2.108788761067722
2.009442597364585
1.7914146399595603
1.7868792554557367
1.708368658133901
1.65581426843549
1.622608035668734
1.6222855537463963
1.5789159718380492
1.5456348207998767
1.5191277286633533
1.499797238049101
1.4945454973051917
1.481274450458955
1.4475974381485155


  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

1.8292562102344363
1.7136875793628301
1.6168999905843453
1.5911928172625052
1.5781093224352765
1.5566501737351732
1.543287926304379
1.5418125006934897
1.5161119736026962
1.5028243355653375
1.4968223055806262
1.4787897703191486
1.4681521182603297
1.4634721189720536


  0%|          | 0/32 [00:00<?, ?it/s]

In [20]:
df

Unnamed: 0,Repr 1,Repr 10,transfer learning - size = 3,transfer learning - size = 35,transfer learning - size = 100,transfer learning - size = 35 (big)
"1 convs, 64 channels",1.9,1.9,2.22,1.92,1.95,1.79
"1 convs, 512 channels",1.92,1.9,2.24,1.93,2.17,1.78
"3 convs, 64 channels",1.4,1.46,1.65,1.42,1.47,1.33
"3 convs, 512 channels",1.43,1.43,1.75,1.41,1.5,1.35
"5 convs, 64 channels",1.39,1.43,1.6,1.36,1.47,1.67
"5 convs, 512 channels",1.5,1.56,2.33,1.38,1.5,1.46


In [None]:
m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, my_layer=MyAttentionModule4(vect_size), features_after_layer=vect_size)

m.myAttentionModule.load_state_dict(torch.load('attention_pooling35_big.pth'))
for par in m.myAttentionModule.parameters():
        par.requires_grad = False

m = train_best(m, train_loader, valid_loader, 70)
predictions, att = predict(m, test_loader)
rmse_score = rmse(y_test, predictions.flatten())
###################### atencja #################################

df_single = pd.DataFrame({"AtomicNum": [], "Degree": [], "TotalNumHs": [], "ImplicitValence": [], "Hybridization": [], "FormalCharge": [],
                          "IsInRing": [], "IsAromatic": [], "NumRadicalElectrons": []})
df_single.style.set_caption("Hello World")

df_batch = pd.DataFrame({"AtomicNum": [], "Degree": [], "TotalNumHs": [], "ImplicitValence": [], "Hybridization": [], "FormalCharge": [],
                          "IsInRing": [], "IsAromatic": [], "NumRadicalElectrons": []})
df_batch.style.set_caption("Hello World")

preds_batches = []
with torch.no_grad():
    for data in test_loader:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        preds, att = m(x, edge_index, batch)
        preds_batches.append(preds.cpu().detach().numpy())
        att = att.squeeze()
        df_single.loc[len(df_single)] = att[0].tolist()
        df_batch.loc[len(df_single)] = torch.mean(gap(att, batch), dim=0).tolist()
preds = np.concatenate(preds_batches)

rmse_score = rmse(y_test, predictions.flatten())

print(f'RMSE = {rmse_score:.2f}')
df_single[:10]

In [None]:
df_batch[:10]