In [None]:
import torch
from torch.functional import F
from torch.nn import Linear, Parameter, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import MessagePassing, GCNConv, GATv2Conv, GINConv, GINEConv, global_mean_pool
from torch_geometric.utils import add_self_loops, remove_self_loops, degree, softmax

In [None]:
import sys

import pandas as pd
import torch
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader

# Make the repository root (and its Libs folder) importable so the
# project-local splitting helper can be loaded from this notebook.
sys.path.append('../../')
sys.path.append('../../Libs')
from Libs.splitting import scaffold_split

# BACE dataset from MoleculeNet, stored under the shared data folder.
dataset = MoleculeNet(root='../../Data/MoleculeNet', name='BACE')

# SMILES strings come from the 'mol' column of the raw CSV and are handed to
# the splitter below.
# NOTE(review): presumably row order matches `dataset` order - confirm.
smiles_list = pd.read_csv('../../Data/MoleculeNet/bace/raw/bace.csv')['mol'].tolist()

# 80 / 10 / 10 scaffold split into train / validation / test subsets.
train_dataset, valid_dataset, test_dataset = scaffold_split(
    dataset,
    smiles_list,
    task_idx=None,
    null_value=0,
    frac_train=0.8,
    frac_valid=0.1,
    frac_test=0.1,
)

# Mini-batches of 32 graphs, served in dataset order (no shuffling).
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)

In [None]:
class GATModel(torch.nn.Module):
    """Five-layer GATv2 graph network with a multi-layer readout.

    Node features pass through five ``GATv2Conv`` layers, each followed by a
    ReLU. The node embeddings of *every* layer are mean-pooled per graph,
    concatenated, and fed to a two-layer MLP head.

    Args:
        input_dim: Number of input features per node.
        dim_h: Hidden channels per attention head.
        final_dim: Output size of the final linear layer.
        num_heads: Number of attention heads in every GATv2 layer.
        edge_dim: Number of edge features passed to the GATv2 layers.
        **kwargs: Extra keyword arguments forwarded to every ``GATv2Conv``.
            Must not contain ``concat``: the last layer already passes
            ``concat=False`` explicitly, and the layer widths assume the
            first four layers concatenate their heads.

    Note:
        Construction calls ``torch.manual_seed(42)``, which re-seeds torch's
        global RNG as a side effect.
    """

    def __init__(self, input_dim, dim_h, final_dim, num_heads, edge_dim, **kwargs):
        super().__init__()
        # Fixed seed so that weight initialisation is repeatable.
        torch.manual_seed(42)

        # Arguments shared by all attention layers.
        conv_args = dict(edge_dim=edge_dim, heads=num_heads, **kwargs)
        # Width of a layer output when the heads are concatenated.
        hidden_width = dim_h * num_heads

        # Attention layers. conv1-conv4 emit dim_h * num_heads features per
        # node; conv5 is built with concat=False and emits dim_h features.
        self.conv1 = GATv2Conv(input_dim, dim_h, **conv_args)
        self.conv2 = GATv2Conv(hidden_width, dim_h, **conv_args)
        self.conv3 = GATv2Conv(hidden_width, dim_h, **conv_args)
        self.conv4 = GATv2Conv(hidden_width, dim_h, **conv_args)
        self.conv5 = GATv2Conv(hidden_width, dim_h, concat=False, **conv_args)

        # The readout concatenates five pooled vectors of width dim_h each.
        self.lin1 = Linear(dim_h * 5, dim_h * 5)

        # Output head.
        self.lin2 = Linear(dim_h * 5, final_dim)

    @staticmethod
    def _mean_over_heads(h, heads, channels):
        """Average a head-concatenated embedding over its heads."""
        return h.view(-1, heads, channels).mean(dim=1)

    def forward(self, x, edge_index, edge_attr, batch):
        """Run the network on a (batched) graph.

        Args:
            x: Node feature tensor.
            edge_index: Graph connectivity passed to each attention layer.
            edge_attr: Edge feature tensor passed to each attention layer.
            batch: Graph assignment of every node, used for mean pooling.

        Returns:
            The output of the final linear layer, flattened to 1-D.
        """
        # Node embeddings: each attention layer is followed by a ReLU.
        h1 = self.conv1(x, edge_index, edge_attr).relu()
        h2 = self.conv2(h1, edge_index, edge_attr).relu()
        h3 = self.conv3(h2, edge_index, edge_attr).relu()
        h4 = self.conv4(h3, edge_index, edge_attr).relu()
        h5 = self.conv5(h4, edge_index, edge_attr).relu()

        # Recover the layout from the tensors themselves: the last layer has
        # per-head width, the one before it is heads * per-head width.
        channels = h5.shape[-1]
        heads = h4.shape[-1] // channels

        # Graph-level readout. Head-concatenated layers are first averaged
        # over heads so every pooled vector has the same width.
        pooled = [
            global_mean_pool(self._mean_over_heads(h, heads, channels), batch)
            for h in (h1, h2, h3, h4)
        ]
        pooled.append(global_mean_pool(h5, batch))
        h = torch.cat(pooled, dim=1)

        # MLP head; dropout is only active in training mode.
        h = self.lin1(h).relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)

        return h.flatten()

In [None]:
# Peek at the first mini-batch only, to inspect the batch layout.
try:
    batch = next(iter(train_loader))
except StopIteration:
    pass  # empty loader: nothing to show
else:
    print(batch)

In [None]:
# Instantiate the network; keyword arguments make each size explicit.
model = GATModel(
    input_dim=9,
    dim_h=64,
    final_dim=1,
    num_heads=16,
    edge_dim=3,
)
model

In [None]:
# Smoke test: push the first mini-batch through the model and show the
# shape of the result. Features are cast to float before the call.
try:
    batch = next(iter(train_loader))
except StopIteration:
    pass  # empty loader: nothing to run
else:
    print(
        model(
            batch.x.to(torch.float),
            batch.edge_index,
            batch.edge_attr.to(torch.float),
            batch.batch,
        ).shape
    )