In [6]:
import torch
# !pip install torch_geometric
from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, TopKPooling
from torch_geometric.nn import global_max_pool as gmp
from torch_geometric.nn import global_mean_pool as gap
from torch_geometric.nn.norm import BatchNorm

In [115]:
class GAT_model(torch.nn.Module):
    """Graph-level regressor built from three GAT + TopK-pooling blocks.

    Pipeline:
      1. Embed integer node types, then apply Linear + ReLU.
      2. Three blocks of ``GATConv`` -> ReLU -> Linear (projects the
         ``heads * embedding_size`` multi-head output back to
         ``embedding_size``) -> ``TopKPooling``. After each block the
         surviving nodes are read out per graph with global mean and global
         max pooling.
      3. The three readouts (``6 * embedding_size`` features in total) go
         through a two-layer MLP that emits one scalar per graph.

    Args:
        embedding_size: Width of the node embeddings and hidden node features.
        num_embeddings: Vocabulary size of the node-type ``Embedding``. If
            ``None``, a notebook/module-level ``num_embeddings`` variable is
            used (previous behaviour); ``ValueError`` is raised if none exists.
        heads: Number of attention heads in each ``GATConv``.
        dropout: Attention dropout passed to each ``GATConv``.
        pool_ratios: ``ratio`` of the three ``TopKPooling`` layers, in order.

    Raises:
        ValueError: If ``num_embeddings`` cannot be resolved or
            ``pool_ratios`` does not contain exactly three entries.
    """

    def __init__(self, embedding_size=32, num_embeddings=None, heads=3,
                 dropout=0.2, pool_ratios=(0.6, 0.5, 0.3)):
        super().__init__()

        if num_embeddings is None:
            # Fall back to a notebook-level variable so existing
            # ``GAT_model()`` calls keep working when it is defined.
            num_embeddings = globals().get("num_embeddings")
            if num_embeddings is None:
                raise ValueError(
                    "num_embeddings must be given (size of the node-type vocabulary)")
        if len(pool_ratios) != 3:
            raise ValueError("pool_ratios must contain exactly three ratios")
        ratio1, ratio2, ratio3 = pool_ratios

        self.embedding_size = embedding_size

        # Attribute names and registration order are kept stable so that
        # previously saved state_dicts still load into this module.

        # embed and transform
        self.embed = torch.nn.Embedding(num_embeddings, embedding_size)
        self.first_transform = torch.nn.Linear(embedding_size, embedding_size)
        self.relu = torch.nn.ReLU()

        # first block
        self.conv1 = GATConv(
            embedding_size, embedding_size, heads=heads, dropout=dropout)
        self.head_reshape1 = torch.nn.Linear(embedding_size * heads, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=ratio1)

        # second block
        self.conv2 = GATConv(
            embedding_size, embedding_size, heads=heads, dropout=dropout)
        self.head_reshape2 = torch.nn.Linear(embedding_size * heads, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=ratio2)

        # third block
        self.conv3 = GATConv(
            embedding_size, embedding_size, heads=heads, dropout=dropout)
        self.head_reshape3 = torch.nn.Linear(embedding_size * heads, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=ratio3)

        # final FF block: 3 blocks x (mean, max) x embedding_size input features
        self.linear1 = torch.nn.Linear(6 * embedding_size, embedding_size // 2)
        self.linear2 = torch.nn.Linear(embedding_size // 2, 1)

    def _gat_pool_block(self, x, edge_index, batch_index, conv, head_reshape, pool):
        """Run one conv -> ReLU -> head projection -> TopK pooling block.

        Returns:
            Tuple ``(x, edge_index, batch_index, readout)``: the pooled node
            features and graph structure for the next block, plus the per-graph
            concatenation of global mean and global max of the pooled features.
        """
        x = self.relu(conv(x, edge_index))
        # Project the concatenated attention heads back to embedding_size.
        x = head_reshape(x)
        # The pool returns 6 values; no edge attributes are used, and the
        # permutation / scores are not needed here.
        x, edge_index, _, batch_index, _, _ = pool(x, edge_index, None, batch_index)
        readout = torch.cat([gap(x, batch_index), gmp(x, batch_index)], dim=1)
        return x, edge_index, batch_index, readout

    def forward(self, x, edge_index, batch_index):
        """Predict one scalar per graph.

        Args:
            x: Integer node-type indices for ``self.embed``. Presumably shape
                ``(num_nodes,)`` or ``(num_nodes, 1)`` (e.g. ZINC atom types) --
                TODO confirm: extra feature columns would be flattened into
                extra rows by the reshape below.
            edge_index: Graph connectivity as consumed by ``GATConv`` /
                ``TopKPooling``.
            batch_index: Graph id of every node (presumably ``data.batch``
                from the PyG ``DataLoader``).

        Returns:
            Tensor of shape ``(num_graphs, 1)``.
        """
        x = self.embed(x)
        # Collapse the embedding lookup to (num_nodes, embedding_size).
        x = x.reshape(-1, self.embedding_size)
        x = self.relu(self.first_transform(x))

        # Hierarchical readout: summarise the graph after every pooling stage.
        x, edge_index, batch_index, pooled1 = self._gat_pool_block(
            x, edge_index, batch_index, self.conv1, self.head_reshape1, self.pool1)
        x, edge_index, batch_index, pooled2 = self._gat_pool_block(
            x, edge_index, batch_index, self.conv2, self.head_reshape2, self.pool2)
        x, edge_index, batch_index, pooled3 = self._gat_pool_block(
            x, edge_index, batch_index, self.conv3, self.head_reshape3, self.pool3)

        # final block
        x = torch.cat([pooled1, pooled2, pooled3], dim=1)
        x = self.relu(self.linear1(x))
        return self.linear2(x)