In [1]:
import pandas as pd
import numpy as np
import torch
from tdc.single_pred.adme import ADME
from tdc import Evaluator
from tqdm.notebook import tqdm, trange
from torch.utils.data import TensorDataset, DataLoader
from rdkit import Chem
from rdkit.Chem import AllChem
from matplotlib import pyplot as plt
from IPython import display

from typing import List, Tuple


class Featurizer:
    def __init__(self, y_column, smiles_col='Drug', **kwargs):
        self.y_column = y_column
        self.smiles_col = smiles_col
        self.__dict__.update(kwargs)
    
    def __call__(self, df):
        raise NotImplementedError()

In [2]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise ValueError("input {0} not in allowable set{1}:".format(
            x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))


from rdkit.Chem import rdMolDescriptors

class GraphFeaturizer(Featurizer):
    def __call__(self, df, getRepresentation):
        graphs = []
        labels = []
        for i, row in df.iterrows():
            y = row[self.y_column]
            smiles = row[self.smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            
            edges = []
            for bond in mol.GetBonds():
                begin = bond.GetBeginAtomIdx()
                end = bond.GetEndAtomIdx()
                edges.append((begin, end))  # TODO: Add edges in both directions
            edges = np.array(edges)
            
            nodes = []
            for atom in mol.GetAtoms():
                # print(atom.GetAtomicNum(), atom.GetNumImplicitHs(), atom.GetTotalNumHs(), atom.GetSymbol(), atom.GetNumExplicitHs(), atom.GetTotalValence())
                results = getRepresentation(atom)
                # print(results)
                nodes.append(results)
            nodes = np.array(nodes)
            
            graphs.append((nodes, edges.T))
            labels.append(y)
        labels = np.array(labels)
        return [Data(
            x=torch.FloatTensor(x), 
            edge_index=torch.LongTensor(edge_index), 
            y=torch.FloatTensor([y])
        ) for ((x, edge_index), y) in zip(graphs, labels)]

In [3]:
def defaultRepresentation(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(11)) + one_of_k_encoding(
                    atom.GetDegree(), range(11)
                ) + one_of_k_encoding_unk(
                    atom.GetImplicitValence(), range(11)
                ) + [atom.GetIsAromatic()] + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(11)
                ) + [atom.GetNumImplicitHs(), atom.GetFormalCharge(), atom.GetNumRadicalElectrons(), atom.IsInRing()] # TODO: Add atom features as a list, you can use one_of_k_encodings defined above

def representation1(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)
                ) + [atom.GetFormalCharge(), atom.IsInRing(), atom.GetIsAromatic()]

def representation10(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)
                ) + [atom.IsInRing(), atom.GetIsAromatic()]

def representationAll(atom):
    return one_of_k_encoding_unk(atom.GetAtomicNum(), range(12)) + one_of_k_encoding_unk(
                    atom.GetDegree(), range(6)) + one_of_k_encoding_unk(
                    atom.GetTotalNumHs(), range(5)) + one_of_k_encoding_unk(
                    atom.GetImplicitValence(), range(6))  + one_of_k_encoding_unk(
                    atom.GetHybridization(),
                    [
                        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
                        Chem.rdchem.HybridizationType.SP3D2
                    ]
                ) + [atom.GetFormalCharge(), atom.IsInRing(), atom.GetIsAromatic()
                     ] + [atom.GetNumRadicalElectrons()]

def printProperties(atom):
    print("=========")
    print("GetDegree", atom.GetDegree())
    print("GetImplicitValence", atom.GetImplicitValence())
    print("GetAtomicNum", atom.GetAtomicNum())
    print("GetTotalNumHs", atom.GetTotalNumHs())
    print("GetNumImplicitHs", atom.GetNumImplicitHs())
    print("GetNeighbors", atom.GetNeighbors())
    print("GetNumExplicitHs", atom.GetNumExplicitHs())
    print("GetTotalDegree", atom.GetTotalDegree())
    print("GetTotalNumHs", atom.GetTotalNumHs())
    print("GetTotalValence", atom.GetTotalValence())

In [4]:
# from torch.utils.data import Dataset

# class QM9Dataset(Dataset):

#     def __init__(self, file_name):
#         file_out = pd.read_csv(file_name)
#         x = file_out.iloc[:, 5]
#         y = file_out.iloc[:, 10]

#         self.X_train = torch.tensor(x)
#         self.y_train = torch.tensor(y)

#     def __len__(self):
#         return len(self.y_train)
    
#     def __getitem__(self, idx):
#         return self.X_train[idx], self.y_train[idx]

dataset = pd.read_csv("./datasets/qm9.csv")

# Normalize targets to mean = 0 and std = 1.
mean = dataset['g298'].mean()
std = dataset['g298'].std()
dataset['g298'] = (dataset['g298'] - mean) / std
# mean, std = mean[:, 'g298'].item(), std[:, 'g298'].item()

from sklearn.utils import shuffle
dataset = shuffle(dataset)
# train_dataset = dataset[20000:70000]
# val_dataset = dataset[:10000]
# test_dataset = dataset[10000:20000]
train_dataset = dataset[:5000]
val_dataset = dataset[5000:7000]
test_dataset = dataset[7000:20000]

In [5]:
class ECFPFeaturizer(Featurizer):
    def __init__(self, y_column, radius=2, length=1024, **kwargs):
        self.radius = radius
        self.length = length
        super().__init__(y_column, **kwargs)
    
    def __call__(self, df):
        fingerprints = []
        labels = []
        for i, row in df.iterrows():
            y = row[self.y_column]
            smiles = row[self.smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, self.radius, nBits=self.length)
            fingerprints.append(fp)
            labels.append(y)
        fingerprints = np.array(fingerprints)
        labels = np.array(labels)
        return fingerprints, labels

rmse = Evaluator(name = 'MAE')

featurizer = ECFPFeaturizer(y_column='g298', smiles_col="smiles")
X_train, y_train = featurizer(train_dataset)
X_valid, y_valid = featurizer(val_dataset)
X_test, y_test = featurizer(test_dataset)

featurizer = GraphFeaturizer('g298', smiles_col="smiles")

graph = featurizer(test_dataset.iloc[:1], defaultRepresentation)[0]

In [6]:
from torch_geometric.loader import DataLoader as GraphDataLoader


# prepare data loaders
batch_size = 64

train_loader1 = GraphDataLoader(featurizer(train_dataset, representation1), batch_size=batch_size, shuffle=True)
valid_loader1 = GraphDataLoader(featurizer(val_dataset, representation1), batch_size=batch_size)
test_loader1 = GraphDataLoader(featurizer(test_dataset, representation1), batch_size=batch_size)

train_loader10 = GraphDataLoader(featurizer(train_dataset, representation10), batch_size=batch_size, shuffle=True)
valid_loader10 = GraphDataLoader(featurizer(val_dataset, representation10), batch_size=batch_size)
test_loader10 = GraphDataLoader(featurizer(test_dataset, representation10), batch_size=batch_size)

train_loader = GraphDataLoader(featurizer(train_dataset, representationAll), batch_size=batch_size, shuffle=True)
valid_loader = GraphDataLoader(featurizer(val_dataset, representationAll), batch_size=batch_size)
test_loader = GraphDataLoader(featurizer(test_dataset, representationAll), batch_size=batch_size)

In [7]:
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool as gap

In [8]:
#warstwa attention pooling
class MyAttentionModule3(torch.nn.Module): # zakladamy ze atom ma 49 featerow
    def __init__(self, groupFeatures=1):
        super().__init__()
        self.groupFeatures = groupFeatures
        self.gates = torch.nn.ModuleDict({ # do wyliczenia atencji dla kazdej grupy cech - jest ich 9
            'AtomicNum': GCNConv(12, 1),
            'Degree': GCNConv(6, 1),
            'TotalNumHs': GCNConv(5, 1),
            'ImplicitValence': GCNConv(6, 1),
            'Hybridization': GCNConv(5, 1),
            'FormalCharge': GCNConv(1, 1),
            'IsInRing': GCNConv(1, 1),
            'IsAromatic': GCNConv(1, 1),
            'NumRadicalElectrons': GCNConv(1, 1)
        })
        
        self.feats = torch.nn.ModuleDict({ # do transformacji grupy cech w wektor, na razie dziala tylko dla groupFeatures=1
            'AtomicNum': torch.nn.Linear(12, groupFeatures),
            'Degree': torch.nn.Linear(6, groupFeatures),
            'TotalNumHs': torch.nn.Linear(5, groupFeatures),
            'ImplicitValence': torch.nn.Linear(6, groupFeatures),
            'Hybridization': torch.nn.Linear(5, groupFeatures),
            'FormalCharge': torch.nn.Linear(1, groupFeatures),
            'IsInRing': torch.nn.Linear(1, groupFeatures),
            'IsAromatic': torch.nn.Linear(1, groupFeatures),
            'NumRadicalElectrons': torch.nn.Linear(1, groupFeatures)
        })

    def forward(self, x, edge_index, batch):
        gates = []
        gates.append(self.gates['AtomicNum'](x[:,0:12], edge_index))
        gates.append(self.gates['Degree'](x[:,12:18], edge_index))
        gates.append(self.gates['TotalNumHs'](x[:,18:23], edge_index))
        gates.append(self.gates['ImplicitValence'](x[:,23:29], edge_index))
        gates.append(self.gates['Hybridization'](x[:,29:34], edge_index))
        gates.append(self.gates['FormalCharge'](x[:,34:35], edge_index))
        gates.append(self.gates['IsInRing'](x[:,35:36], edge_index))
        gates.append(self.gates['IsAromatic'](x[:,36:37], edge_index))
        gates.append(self.gates['NumRadicalElectrons'](x[:,37:38], edge_index))
        logits = torch.cat(gates, dim=-1)
        attention = torch.softmax(logits, dim=-1).unsqueeze(-1)
        
        subgroups = []
        subgroups.append(self.feats['AtomicNum'](x[:,0:12]) * attention[:,0])
        subgroups.append(self.feats['Degree'](x[:,12:18]) * attention[:,1])
        subgroups.append(self.feats['TotalNumHs'](x[:,18:23]) * attention[:,2])
        subgroups.append(self.feats['ImplicitValence'](x[:,23:29]) * attention[:,3])
        subgroups.append(self.feats['Hybridization'](x[:,29:34]) * attention[:,4])
        subgroups.append(self.feats['FormalCharge'](x[:,34:35]) * attention[:,5])
        subgroups.append(self.feats['IsInRing'](x[:,35:36]) * attention[:,6])
        subgroups.append(self.feats['IsAromatic'](x[:,36:37]) * attention[:,7])
        subgroups.append(self.feats['NumRadicalElectrons'](x[:,37:38]) * attention[:,8])
        x = torch.stack(subgroups, dim=-2)
        x = torch.sum(x, dim=-2)
        
        return x, attention

In [9]:
def attSequential(n_feats):
    return torch.nn.Sequential(torch.nn.Linear(n_feats, 1),
                       torch.nn.BatchNorm1d(1), torch.nn.ReLU(),
                       torch.nn.Linear(1, 1), torch.nn.ReLU())

In [10]:
#warstwa attention pooling
class MyAttentionModule4(torch.nn.Module): # zakladamy ze atom ma 49 featerow
    def __init__(self, groupFeatures=1):
        super().__init__()
        self.groupFeatures = groupFeatures
        self.gates = torch.nn.ModuleDict({ # do wyliczenia atencji dla kazdej grupy cech - jest ich 9
            'AtomicNum': GINConv(attSequential(12), train_eps=True),
            'Degree': GINConv(attSequential(6), train_eps=True),
            'TotalNumHs': GINConv(attSequential(5), train_eps=True),
            'ImplicitValence': GINConv(attSequential(6), train_eps=True),
            'Hybridization': GINConv(attSequential(5), train_eps=True),
            'FormalCharge': GINConv(attSequential(1), train_eps=True),
            'IsInRing': GINConv(attSequential(1), train_eps=True),
            'IsAromatic': GINConv(attSequential(1), train_eps=True),
            'NumRadicalElectrons': GINConv(attSequential(1), train_eps=True)
        })
        
        self.feats = torch.nn.ModuleDict({ # do transformacji grupy cech w wektor, na razie dziala tylko dla groupFeatures=1
            'AtomicNum': torch.nn.Linear(12, groupFeatures),
            'Degree': torch.nn.Linear(6, groupFeatures),
            'TotalNumHs': torch.nn.Linear(5, groupFeatures),
            'ImplicitValence': torch.nn.Linear(6, groupFeatures),
            'Hybridization': torch.nn.Linear(5, groupFeatures),
            'FormalCharge': torch.nn.Linear(1, groupFeatures),
            'IsInRing': torch.nn.Linear(1, groupFeatures),
            'IsAromatic': torch.nn.Linear(1, groupFeatures),
            'NumRadicalElectrons': torch.nn.Linear(1, groupFeatures)
        })

    def forward(self, x, edge_index, batch):
        gates = []
        gates.append(self.gates['AtomicNum'](x[:,0:12], edge_index))
        gates.append(self.gates['Degree'](x[:,12:18], edge_index))
        gates.append(self.gates['TotalNumHs'](x[:,18:23], edge_index))
        gates.append(self.gates['ImplicitValence'](x[:,23:29], edge_index))
        gates.append(self.gates['Hybridization'](x[:,29:34], edge_index))
        gates.append(self.gates['FormalCharge'](x[:,34:35], edge_index))
        gates.append(self.gates['IsInRing'](x[:,35:36], edge_index))
        gates.append(self.gates['IsAromatic'](x[:,36:37], edge_index))
        gates.append(self.gates['NumRadicalElectrons'](x[:,37:38], edge_index))
        logits = torch.cat(gates, dim=-1)
        attention = torch.softmax(logits, dim=-1).unsqueeze(-1)
        
        subgroups = []
        subgroups.append(self.feats['AtomicNum'](x[:,0:12]) * attention[:,0])
        subgroups.append(self.feats['Degree'](x[:,12:18]) * attention[:,1])
        subgroups.append(self.feats['TotalNumHs'](x[:,18:23]) * attention[:,2])
        subgroups.append(self.feats['ImplicitValence'](x[:,23:29]) * attention[:,3])
        subgroups.append(self.feats['Hybridization'](x[:,29:34]) * attention[:,4])
        subgroups.append(self.feats['FormalCharge'](x[:,34:35]) * attention[:,5])
        subgroups.append(self.feats['IsInRing'](x[:,35:36]) * attention[:,6])
        subgroups.append(self.feats['IsAromatic'](x[:,36:37]) * attention[:,7])
        subgroups.append(self.feats['NumRadicalElectrons'](x[:,37:38]) * attention[:,8])
        x = torch.stack(subgroups, dim=-2)
        x = torch.sum(x, dim=-2)
        
        return x, attention

In [11]:
class GraphNeuralNetwork(torch.nn.Module):  # TODO: assign hyperparameters to attributes and define the forward pass
    def __init__(self, hidden_size, n_convs=3, my_layer=None, features_after_layer=26, n_features=49, dropout=0.2):
        super().__init__()
        self.myAttentionModule = my_layer
        self.dropout = dropout

        convs = torch.nn.ModuleList()
        convs.append(GCNConv(features_after_layer, hidden_size))
        for i in range(1, n_convs):
            convs.append(GCNConv(hidden_size, hidden_size))
        self.convs = convs
        self.linear = torch.nn.Linear(hidden_size, 1)
    
    def forward(self, x, edge_index, batch):
        att = None
        if self.myAttentionModule is not None:
            x, att = self.myAttentionModule(x, edge_index, batch)
        for i in range(0, len(self.convs)-1):
            x = self.convs[i](x, edge_index)
            x = x.relu()
        x = self.convs[-1](x, edge_index)
        
        x = gap(x, batch)
        
        x = F.dropout(x, p=self.dropout, training=self.training)

        out = self.linear(x)

        return out, att

In [12]:
def train(model, train_loader, valid_loader, epochs=20, learning_rate = 0.01):
    model.train()
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    for epoch in trange(1, epochs + 1, leave=False):
        for data in tqdm(train_loader, leave=False):
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))
            loss.backward()
            # print("==============")
            # for par in model.myAttentionModule.parameters():
            #     print(par)
            optimizer.step()
    return model


def predict(model, test_loader):
    # evaluation loop
    preds_batches = []
    with torch.no_grad():
        for data in test_loader:
            x, edge_index, batch = data.x, data.edge_index, data.batch
            
            preds, att = model(x, edge_index, batch)
            preds_batches.append(preds.cpu().detach().numpy())
    preds = np.concatenate(preds_batches)
    return preds, att

In [13]:
from copy import deepcopy

def train_best(model, train_loader, valid_loader, epochs=20, learning_rate = 0.01, saveImg=False, title=""):
    model.train()

    best_state = deepcopy(model.state_dict())
    best_val = 1000000
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    for epoch in trange(1, epochs + 1, leave=False):
        # preds_batches = []
        running_loss = 0.0
        for data in train_loader:
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))

            loss.backward()
            optimizer.step()

        # evaluation loop
        preds_batches = []
        with torch.no_grad():
            for data in valid_loader:
                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
                preds, att = model(x, edge_index, batch)
                loss = loss_fn(preds, y.reshape(-1, 1))
                preds_batches.append(preds.cpu().detach().numpy())
        preds = np.concatenate(preds_batches)
        mae = rmse(y_valid, preds.flatten())
        if mae < best_val:
            best_state = deepcopy(model.state_dict())
            best_val = mae
            print(best_val)

    model.load_state_dict(best_state)
    return model

In [14]:
def visualize(model, train_loader, valid_loader, test_loader, epochs=20, learning_rate = 0.01, saveImg=False, title=""):
    model.train()

    best_state = deepcopy(model.state_dict())
    best_val = 1000000
    
    # training loop
    optimizer = torch.optim.Adam(model.parameters(), learning_rate) # TODO: define an optimizer
    loss_fn = torch.nn.MSELoss()  # TODO: define a loss function
    train_losses = []
    val_losses = []
    train_errors = []
    val_errors = []
    for epoch in trange(1, epochs + 1, leave=False):
        # preds_batches = []
        running_loss = 0.0
        for data in train_loader:
            x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
            model.zero_grad()
            preds, att = model(x, edge_index, batch)
            loss = loss_fn(preds, y.reshape(-1, 1))
            # print(len(train_dataset))

            running_loss += loss.item()
            # preds_batches.append(preds.cpu().detach().numpy())

            loss.backward()
            optimizer.step()
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        # preds = np.concatenate(preds_batches)
        # mae = rmse(y_train, preds.flatten())
        # train_errors.append(mae)

        # evaluation loop
        preds_batches = []
        running_loss = 0.0
        with torch.no_grad():
            for data in valid_loader:
                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y
                preds, att = model(x, edge_index, batch)
                loss = loss_fn(preds, y.reshape(-1, 1))
                # print(len(train_dataset))

                running_loss += loss.item()
                preds_batches.append(preds.cpu().detach().numpy())
        epoch_loss = running_loss / len(valid_loader)
        val_losses.append(epoch_loss)
        preds = np.concatenate(preds_batches)
        mae = rmse(y_valid, preds.flatten())
        if mae < best_val:
            best_state = deepcopy(model.state_dict())
            best_val = mae
            print(best_val)
        val_errors.append(mae)

    model.load_state_dict(best_state)

    ##### visualize ########
    plt.plot(train_losses, label='train_loss')
    plt.plot(val_losses, label='val_loss')
    plt.legend()
    plt.show()
    if saveImg:
        plt.savefig(title + "_loss.png")

    # plt.plot(train_errors,label='train_errors')
    plt.plot(val_errors, label='val_RMSE')
    plt.legend()
    plt.show()
    if saveImg:
        plt.savefig(title + "_val_error.png")
    return model

In [15]:
m =  GraphNeuralNetwork(512, n_convs=3, features_after_layer=25)
predictions, att = predict(m, test_loader10)
rmse_score = rmse(y_test, predictions.flatten())
print("{:.2f}".format(rmse_score))

0.78


In [16]:
######################## wizualizacje #################################################
m = GraphNeuralNetwork(512, my_layer=MyAttentionModule4(3), features_after_layer=3)
m = visualize(m, train_loader, valid_loader, test_loader, epochs=100, saveImg=True, title="MyAttentionModule4(3)_qm9")
torch.save(m.myAttentionModule.state_dict(), "attention_pooling3.pth")

m = GraphNeuralNetwork(512, my_layer=MyAttentionModule4(35), features_after_layer=35)
m = visualize(m, train_loader, valid_loader, test_loader, epochs=100, saveImg=True, title="MyAttentionModule4(35)_qm9")
torch.save(m.myAttentionModule.state_dict(), "attention_pooling35.pth")

m = GraphNeuralNetwork(512, my_layer=MyAttentionModule4(100), features_after_layer=100)
m = visualize(m, train_loader, valid_loader, test_loader, epochs=100, saveImg=True, title="MyAttentionModule4(100)_qm9")
torch.save(m.myAttentionModule.state_dict(), "attention_pooling100.pth")

  0%|          | 0/100 [00:00<?, ?it/s]

0.6141783793969974
0.5633457925429605
0.5601623081095508
0.5495416510668535


KeyboardInterrupt: 

In [55]:
######################## tabelka ##########################################################
df = pd.DataFrame({"Repr 1": [], "Repr 10": [],
                   "Atention Pooling v2 - size = 3": [], "Atention Pooling v2 - size = 35": [], "Atention Pooling v2 - size = 100": []})
pd.set_option("display.precision", 2)
n_times = 1

for n_convs in [1, 3, 5]:
    for n_channels in [64, 512]:
        row = []

        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, features_after_layer=26)
            m = train_best(m, train_loader1, valid_loader1, epochs=70)
            predictions, att = predict(m, test_loader1)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        #########################
        scores = []
        for _ in range(n_times):
            m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, features_after_layer=25)
            m = train_best(m, train_loader10, valid_loader10, epochs= 70)
            predictions, att = predict(m, test_loader10)
            rmse_score = rmse(y_test, predictions.flatten())
            scores.append("{:.2f}".format(rmse_score))
        row.append(" | ".join(scores))

        #########################
        for vect_size in [3, 35, 100]:
            scores = []
            for _ in range(n_times):
                m =  GraphNeuralNetwork(n_channels, n_convs=n_convs, my_layer=MyAttentionModule4(vect_size), features_after_layer=vect_size)
                m = train_best(m, train_loader, valid_loader, epochs=70)
                predictions, att = predict(m, test_loader)
                rmse_score = rmse(y_test, predictions.flatten())
                scores.append("{:.2f}".format(rmse_score))
            row.append(" | ".join(scores))

        df.loc[str(n_convs) + " convs, " + str(n_channels) + " channels"] = row

df.to_csv("qm9_out.csv")

  0%|          | 0/70 [00:00<?, ?it/s]

0.5501999533014232
0.5040687140639524
0.49571092890067847
0.4846211159043977
0.48320055516095733
0.4798495588766813
0.47862055974973894
0.46959427808645865
0.4692307116427061


  0%|          | 0/70 [00:00<?, ?it/s]

0.5118019648357927
0.5108146163040406
0.4840445827762349
0.4742861989983277
0.47330477481478156
0.4681885657518951


  0%|          | 0/70 [00:00<?, ?it/s]

0.561036359623823
0.4769177162116058
0.4744529705127069
0.40883556785632447
0.3919386160856415
0.38347617034513065
0.37588342552404896


  0%|          | 0/70 [00:00<?, ?it/s]

0.5742107988669302
0.4906674313564381
0.434849694428749
0.4063036310486408
0.38088680098015626
0.3791061764394474
0.375671778923798


  0%|          | 0/70 [00:00<?, ?it/s]

0.5648204590068437
0.4204737493766412
0.40843622716617833
0.39246518593037233
0.3902598677958982
0.37078907308215087
0.3595757179397779
0.35346201845336206


  0%|          | 0/70 [00:00<?, ?it/s]

0.5520287189781918
0.48559857960289954
0.47923544684714
0.47886685720290395
0.471933309270446
0.46607840339390455
0.46480330948283693


  0%|          | 0/70 [00:00<?, ?it/s]

0.5110998178408238
0.4860796685755312
0.47270051578897415
0.46767733701163205


  0%|          | 0/70 [00:00<?, ?it/s]

0.580995371184952
0.500531178153558
0.48313098006812244
0.4717063011754981
0.43893967279513935
0.43649096040999275
0.42205237154106046
0.4111093730560148
0.4087300721747237
0.39159351770083184
0.3885609372455406
0.3858369447002459
0.3776937232135942


  0%|          | 0/70 [00:00<?, ?it/s]

0.5115080770676937
0.4839151094237457
0.47667585911141314
0.4610168447248598
0.45971754702866247
0.4449372632458583
0.43868778504774103
0.4380334812758879
0.42161911148377895
0.42040963100370005
0.4177185036045646
0.4162841444610132
0.4100220688922407


  0%|          | 0/70 [00:00<?, ?it/s]

0.4971449339566476
0.46939684863810033
0.3999995704549018
0.38963417552717733
0.37576055820742543
0.37317306945242273
0.37162927551413566
0.3705623131299011
0.369619645832106


  0%|          | 0/70 [00:00<?, ?it/s]

0.5939083538699361
0.5615732030683832
0.5558171326819229
0.5047088203492285
0.4996396830875631
0.46113020371174956
0.45055473600158247
0.4311138086792339
0.4308139026254159
0.42418769599034023
0.41468181470194376
0.4108372316430756
0.40117436921112204
0.394027935025768


  0%|          | 0/70 [00:00<?, ?it/s]

0.6024437365507039
0.5874033679263448
0.5599956640625363
0.548706039783326
0.5459461065742708
0.5351095562370498
0.532330340531174
0.5073445092172706
0.4780029155425512
0.4705713139339443
0.46925587375914457
0.4622008044281099
0.4380974568998195
0.4255670204326471
0.4172136271366522
0.4131157312551266
0.412631130078639
0.4057599442814407
0.4017204481952684
0.4003492571599886
0.3942722603266621


  0%|          | 0/70 [00:00<?, ?it/s]

0.6402671247952043
0.6142171073451764
0.5876454278940789
0.5532898057821921
0.5531518375183435
0.5384364251581455
0.5334754907506167
0.5327773013030714
0.5316496719074199
0.5195250010786147
0.5172499747546381
0.5155615694760607
0.4996217550690369
0.482203989301975
0.47989741909303657
0.4673173154863694
0.4589998236036089


  0%|          | 0/70 [00:00<?, ?it/s]

0.5990910163986459
0.560650603171747
0.5329261031420295
0.5192210622399327
0.5039859453464626
0.4796484833384004
0.47194156123973896
0.40369999642342863
0.39981257435271494
0.39177203012138206
0.375800431581988
0.37310950642631874


  0%|          | 0/70 [00:00<?, ?it/s]

0.5961680073045709
0.5464317482454847
0.5360795353333796
0.5241916930285079
0.5215370248711109
0.5135582275801003
0.51198900564209
0.4805410595677199
0.4427752716958865
0.4242470715817855
0.4138198931521222
0.3983854882043393
0.39687253401010253
0.3810191248581649
0.3789726740280708


  0%|          | 0/70 [00:00<?, ?it/s]

0.6013737373952691
0.5771409082717582
0.5696834634191016
0.5389842464076748
0.49568055165774105
0.4788574731403034
0.4633257464534358
0.4620427472834565
0.4455838321973563
0.431746147987735
0.4303277774400391
0.4252157013912176
0.42393361418152226
0.4234115601704591
0.4174226800693867
0.41124543255903345
0.4085237609230134
0.40821851432693457


  0%|          | 0/70 [00:00<?, ?it/s]

0.5855291139264885
0.5762779694885737
0.5623906308164883
0.5324334075134749
0.5082324020746836
0.49596007457916674
0.49238105003498
0.46853329893231593
0.42858921839848313
0.41299982583538164
0.4027639917916573
0.39330557101694147


  0%|          | 0/70 [00:00<?, ?it/s]

0.7700663665696368
0.6258430699017373
0.5480813684631621
0.5246136135178515
0.52333965167258
0.5182618883311261
0.5118764002443064
0.5113094588617306
0.4892476630757904
0.48789822612558276
0.4857481720232788
0.4751653413753546
0.46367547329697856
0.45119743766910164
0.4271310554046398
0.41033696731509495
0.40916602722142964
0.3962517001854777
0.38474713037150365
0.3831188286186428


  0%|          | 0/70 [00:00<?, ?it/s]

0.611284880529322
0.5798016921132954
0.565883679141672
0.5413909704483308
0.5325595203770777
0.495290180387102
0.46991538002175937
0.43329269039001633
0.4308263080636379
0.4119763966868199
0.3981382644776737
0.38780293228608625
0.3860493180065903
0.38380125918291036


  0%|          | 0/70 [00:00<?, ?it/s]

0.5984623931680215
0.5850773963589374
0.5459695930287695
0.5339812214641233
0.48530988756258764
0.4232247841968625
0.4131790364496137
0.41304306628359727
0.3975458125707353
0.3964233141411874
0.3884801657224965
0.3869223800678591
0.38329221121805396


  0%|          | 0/70 [00:00<?, ?it/s]

0.6245888747546967
0.6244082215837307
0.6111743954891482
0.5845686551527098
0.5822581866345424
0.5776427382063503
0.5773658302516803
0.5701176173809611
0.5646273334223258
0.5591248831534879
0.558012464156884
0.5500178118270967
0.545211788864557
0.5152461277135829
0.5010758204089854
0.4991763902654609
0.4848839998184014
0.48055700644152055
0.4612927018018304
0.455131507237876
0.4511284425318283
0.44842515259452376
0.44513350927822115


  0%|          | 0/70 [00:00<?, ?it/s]

0.6247117070379029
0.6142644878459667
0.595417212792111
0.5848691892958298
0.5776846867037703
0.577596518169868
0.5574309936994306
0.5566640070829169
0.551759354738294
0.5349935126873452
0.5323323754216144
0.5263573244682348
0.5136803859961836
0.4905481854376052
0.4901522603276309
0.4870854981845266
0.4813231839752734
0.461438423790292
0.46082964199684817
0.4604154504012818
0.45559996216612086
0.44283872965349647


  0%|          | 0/70 [00:00<?, ?it/s]

0.7819804065579496
0.7767797867340568
0.774594113960471
0.7745094114248571


  0%|          | 0/70 [00:00<?, ?it/s]

0.6383309652584185
0.6162205190964987
0.5897108050863753
0.5750126141373477
0.5734331495541883
0.5659146422179062
0.5641528420170541
0.5635419899128277
0.5589319450224152
0.5552536022318175
0.5525612909485301
0.5434841838195471
0.5353358723289614
0.5330759415650079


  0%|          | 0/70 [00:00<?, ?it/s]

0.6418846107341808
0.6242920235492927
0.6215664453454283
0.6002506984826467
0.5914058134757075
0.5791260389258788
0.5765026706176508
0.5717531985585395
0.5703350650309089
0.5632695526054322
0.5574435816960731
0.5548699111877413


  0%|          | 0/70 [00:00<?, ?it/s]

0.7498016981175788
0.6545367418425997
0.6100487365257838
0.607665712160602
0.5905372825891204
0.5752385220197626
0.5571505680618279
0.5502322645900306
0.5402791084952631
0.5334305760689935
0.5331263761525168
0.5215383460462141
0.5110089394279633
0.5049574599888639
0.4937705219566576
0.47757983438555973
0.4614047174119055
0.46007277855447815
0.4519998917961033


  0%|          | 0/70 [00:00<?, ?it/s]

0.6810018414105411
0.6447280782822613
0.6170040729425805
0.6005755082534646
0.5929388970623705
0.584361816641567
0.5826193926539804
0.5812824400878384
0.5634673972709808
0.5631013038562275
0.5620852667190868
0.5583571534747693
0.553008708265958
0.5495337957296331
0.5322209987763775
0.516509154920952
0.5148366165924173
0.49568934573541584
0.4870997986875317
0.45729731998348566
0.45155299109297564
0.44952647752051733


  0%|          | 0/70 [00:00<?, ?it/s]

0.7836814262892451
0.7704716078438715
0.6481954083557926
0.6344387891716563
0.5855917551405102
0.5722695899452512
0.5568676321208823
0.5540270574054211
0.5295733412473905
0.5218281312132133
0.5132848579579204
0.5007006958049004
0.49815280130506356
0.49066865205212534


  0%|          | 0/70 [00:00<?, ?it/s]

0.6616480997705977
0.6578122002416822
0.6189817936162643
0.5940903401806986
0.590288069133182
0.581898738825091
0.5769232249731114
0.5650019510798686
0.5618972307943377
0.5596185188876496
0.5507237517033675
0.546935433049442
0.5460610414174646
0.5273632764572124


  0%|          | 0/70 [00:00<?, ?it/s]

0.703158718834372
0.6204553935516193
0.6198256047408206
0.6070166085944769
0.5995730343104876
0.5846234235604006
0.573768634060646
0.5517843223346758
0.5496175779876163
0.5437580601229222
0.5369885673553979
0.5320403355672877
0.5251687081096854


In [56]:
df

Unnamed: 0,Repr 1,Repr 10,Atention Pooling v2 - size = 3,Atention Pooling v2 - size = 35,Atention Pooling v2 - size = 100
"1 convs, 64 channels",0.47,0.47,0.38,0.38,0.35
"1 convs, 512 channels",0.47,0.47,0.38,0.41,0.37
"3 convs, 64 channels",0.39,0.39,0.46,0.37,0.38
"3 convs, 512 channels",0.41,0.39,0.39,0.39,0.38
"5 convs, 64 channels",0.44,0.44,0.77,0.53,0.56
"5 convs, 512 channels",0.45,0.45,0.49,0.53,0.52


In [58]:
m =  GraphNeuralNetwork(512, n_convs=3, my_layer=MyAttentionModule4(35), features_after_layer=35)
m = train_best(m, train_loader, valid_loader, epochs=70)
predictions, att = predict(m, test_loader)
rmse_score = rmse(y_test, predictions.flatten())
print("{:.2f}".format(rmse_score))

  0%|          | 0/70 [00:00<?, ?it/s]

0.6029386290882007
0.5683466222907125
0.5604436025787646
0.5376977818992003
0.5321830417351598
0.4908033227389411
0.44490565618354005
0.43717932747550264
0.41816708144753073
0.4130349799112501
0.40203166803965873
0.3925937594910248
0.39


In [59]:
###################### atencja #################################

df_single = pd.DataFrame({"AtomicNum": [], "Degree": [], "TotalNumHs": [], "ImplicitValence": [], "Hybridization": [], "FormalCharge": [],
                          "IsInRing": [], "IsAromatic": [], "NumRadicalElectrons": []})
df_single.style.set_caption("Hello World")

df_batch = pd.DataFrame({"AtomicNum": [], "Degree": [], "TotalNumHs": [], "ImplicitValence": [], "Hybridization": [], "FormalCharge": [],
                          "IsInRing": [], "IsAromatic": [], "NumRadicalElectrons": []})
df_batch.style.set_caption("Hello World")

preds_batches = []
with torch.no_grad():
    for data in test_loader:
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        preds, att = m(x, edge_index, batch)
        preds_batches.append(preds.cpu().detach().numpy())
        att = att.squeeze()
        df_single.loc[len(df_single)] = att[0].tolist()
        df_batch.loc[len(df_single)] = torch.mean(gap(att, batch), dim=0).tolist()
preds = np.concatenate(preds_batches)

rmse_score = rmse(y_test, predictions.flatten())

print(f'RMSE = {rmse_score:.2f}')
df_single.to_csv("qm9_att_single.csv")
df_batch.to_csv("qm9_att_batch.csv")

RMSE = 0.39


In [62]:
df_single[:10]

Unnamed: 0,AtomicNum,Degree,TotalNumHs,ImplicitValence,Hybridization,FormalCharge,IsInRing,IsAromatic,NumRadicalElectrons
0,0.47,0.06,0.06,0.08,0.06,0.11,0.06,0.06,0.06
1,0.55,0.05,0.06,0.07,0.05,0.09,0.05,0.05,0.05
2,0.48,0.06,0.06,0.08,0.06,0.1,0.06,0.06,0.06
3,0.53,0.05,0.06,0.07,0.05,0.09,0.05,0.05,0.05
4,0.49,0.06,0.06,0.08,0.06,0.1,0.06,0.06,0.06
5,0.48,0.06,0.06,0.08,0.06,0.1,0.06,0.06,0.06
6,0.48,0.06,0.06,0.08,0.06,0.1,0.06,0.06,0.06
7,0.49,0.05,0.06,0.07,0.05,0.1,0.05,0.05,0.05
8,0.49,0.05,0.06,0.07,0.05,0.1,0.05,0.05,0.05
9,0.47,0.06,0.06,0.08,0.06,0.11,0.06,0.06,0.06


In [63]:
df_batch[:10]

Unnamed: 0,AtomicNum,Degree,TotalNumHs,ImplicitValence,Hybridization,FormalCharge,IsInRing,IsAromatic,NumRadicalElectrons
1,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
2,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
3,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
4,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
5,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
6,0.26,0.08,0.08,0.15,0.07,0.14,0.07,0.07,0.07
7,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
8,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
9,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
10,0.26,0.08,0.08,0.14,0.07,0.14,0.07,0.07,0.07
