In [16]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

import pandas as pd

In [2]:
from preprocess import get_graph

In [3]:
# Read the ZINC CSV shard x001.csv into a DataFrame.
# (Plain string literal: the path has no placeholders, so no f-prefix is needed.)
data = pd.read_csv('/scratch/arihanth.srikar/data/zinc/x001.csv')

In [26]:
class GraphDataset(Dataset):
    """Map-style dataset that converts SMILES strings into dense graph tensors.

    Each item is built by calling ``get_graph`` on one SMILES string and turning
    the returned dict (keys ``node_feats``, ``positions``, ``edge_index``,
    ``edge_attr``) into node-indexed tensors that ``collate_fn`` can pad to the
    largest graph in a batch.

    Args:
        data_dir: Root directory; the table is read from
            ``{data_dir}/data/zinc/x001.csv`` when ``df`` is not supplied.
        split: Value matched against the ``set`` column. The filter is applied
            only when the table actually has a ``set`` column.
        df: Optional pre-loaded DataFrame with a ``smiles`` column. Passing it
            avoids re-reading the CSV and removes any reliance on notebook
            globals.
    """

    def __init__(self, data_dir: str='/scratch/arihanth.srikar', split: str='train',
                 df: 'pd.DataFrame | None'=None) -> None:
        # Load the table ourselves instead of depending on a global `data`
        # variable that happens to exist in the notebook namespace.
        if df is None:
            df = pd.read_csv(f'{data_dir}/data/zinc/x001.csv')

        # Honour the requested split when the table carries split labels.
        if 'set' in df.columns:
            df = df[df['set'] == split]

        self.split = split
        self.smiles = df['smiles'].tolist()

    def __len__(self):
        """Number of SMILES strings in this dataset."""
        return len(self.smiles)

    def __getitem__(self, idx):
        """Build the tensors for the molecule at position ``idx``.

        Returns:
            Tuple ``(node_feats, positions, mask, adj_mat, dense_edges_feats)``:
            node_feats        -- int64, one row per node, shifted by +1
            positions         -- float64, as returned by ``get_graph``
            mask              -- bool, all True, length = number of nodes
            adj_mat           -- float, N x N, symmetric with self-loops
            dense_edges_feats -- int64, N x N x (edge feature width), shifted by +1
        """
        # get graph from smiles
        graph = get_graph(self.smiles[idx])

        # node features, positions, edge indices, edge features
        # (shape notes N*9 / N*3 / 2*E / E*3 are the original author's)
        node_feats = torch.tensor(graph['node_feats'], dtype=torch.int64)  # N*9
        positions  = torch.tensor(graph['positions'], dtype=torch.float64) # N*3
        edge_list  = torch.tensor(graph['edge_index'], dtype=torch.int64)  # 2*E
        edge_feats = torch.tensor(graph['edge_attr'], dtype=torch.int64)   # E*3

        # shift features by one so that index 0 is free to mean "padding"
        node_feats = node_feats + 1
        edge_feats = edge_feats + 1

        num_nodes = node_feats.size(0)
        mask = torch.ones(num_nodes, dtype=torch.bool)

        # construct the adjacency matrix; it is indexed by node, so it must be
        # N x N (sizing it by the number of edges yields the wrong shape and
        # marks non-existent nodes as self-connected)
        row, col = edge_list
        adj_mat = torch.zeros(num_nodes, num_nodes)
        adj_mat[row, col] = 1
        adj_mat[col, row] = 1
        adj_mat.fill_diagonal_(1)  # self-loops

        # construct N x N x F dense edge features (zeros where there is no edge)
        dense_edges_feats = torch.zeros(
            (num_nodes, num_nodes, edge_feats.size(1)), dtype=torch.int64
        )
        dense_edges_feats[row, col] = edge_feats

        return node_feats, positions, mask, adj_mat, dense_edges_feats

    def collate_fn(self, data):
        """Pad a list of ``__getitem__`` outputs to a common size and batch them.

        Every per-graph tensor is zero-padded (``False`` for the mask) up to the
        node count of the largest graph in the batch, then stacked along a new
        leading batch dimension.
        """
        # unpack the input data
        node_feats, positions, mask, adj_mat, dense_edges_feats = zip(*data)

        # find the largest graph in the batch
        max_nodes = max(feats.size(0) for feats in node_feats)

        # pad each adjacency matrix on the right/bottom with 0s, then batch
        adj_mat = torch.stack([
            F.pad(mat, (0, max_nodes - mat.size(0), 0, max_nodes - mat.size(0)), "constant", 0)
            for mat in adj_mat
        ])

        # pad node features and positions with 0s, mask with False
        node_feats = pad_sequence(list(node_feats), batch_first=True, padding_value=0)
        positions = pad_sequence(list(positions), batch_first=True, padding_value=0)
        mask = pad_sequence(list(mask), batch_first=True, padding_value=False)

        # pad the two node axes of each dense edge tensor with 0s; the trailing
        # feature axis is left untouched (first pair of the pad tuple)
        dense_edges_feats = torch.stack([
            F.pad(mat, (0, 0, 0, max_nodes - mat.size(0), 0, max_nodes - mat.size(0)), "constant", 0)
            for mat in dense_edges_feats
        ])

        return node_feats, positions, mask, adj_mat, dense_edges_feats

In [27]:
# Instantiate the dataset with the default data directory, requesting the 'val' split.
Zinc = GraphDataset(split='val')

In [28]:
# Small, single-process loader for quick inspection of one batch.
# An earlier variant used batch_size=1 with shuffle=True, num_workers=8,
# pin_memory=True and prefetch_factor=4.
train_loader = DataLoader(
    Zinc,
    batch_size=12,
    collate_fn=Zinc.collate_fn,
)

In [29]:
# Pull the first collated batch from the loader and unpack its five tensors.
node_feats, positions, mask, adj_mat, dense_edges_feats = next(iter(train_loader))

In [30]:
# Display the shapes of the batched tensors as a sanity check.
# NOTE(review): the recorded output lists positions as [12, 0], i.e. no
# coordinates in this batch — confirm what get_graph returns under 'positions'.
node_feats.shape, positions.shape, mask.shape, adj_mat.shape, dense_edges_feats.shape

(torch.Size([12, 30, 9]),
 torch.Size([12, 0]),
 torch.Size([12, 30]),
 torch.Size([12, 30, 30]),
 torch.Size([12, 30, 30, 3]))

In [14]:
import sys
# Put the graph_transformer_pytorch source directory on the import path.
# The path is relative, so this only works from the expected working directory.
sys.path.append('../../graph_transformer_pytorch/graph_transformer_pytorch/')
from graph_transformer_pytorch import GraphTransformer

In [31]:
# Linear projections that lift the 9 node features and the 3 edge features
# to the transformer width of 64.
node_emb = nn.Linear(9, 64)
edge_emb = nn.Linear(3, 64)

model = GraphTransformer(
    dim=64,
    depth=2,
    edge_dim=64,                   # optional; when omitted the edge width is taken to equal `dim`
    with_feedforwards=True,        # feedforward after every attention layer (recommended in the literature)
    gated_residual=True,           # gated residual connections, to counter over-smoothing
    rel_pos_emb=True,              # meant for ordered nodes; the library default is False
    accept_adjacency_matrix=True,  # required because an adjacency matrix is passed at call time
)

In [32]:
# Embed the integer node/edge features (cast to float for the linear layers)
# and run a single forward pass through the graph transformer.
nodes, edges = model(
    node_emb(node_feats.float()),
    edge_emb(dense_edges_feats.float()),
    adj_mat=adj_mat,
    mask=mask,
)

In [33]:
# Display the shapes of the transformer outputs (node tensor and edge tensor).
nodes.shape, edges.shape

(torch.Size([12, 30, 64]), torch.Size([12, 30, 30, 64]))