To apply a structure-aware transformer to FS-Mol, we first need to extract the k-hop subgraph around every node of a single graph instance. This notebook works out how to do that.

In [1]:
import os
import sys

# Resolve the checkout root (the parent of the starting working directory) only
# once per kernel. '../' is resolved against the *current* working directory,
# which this cell changes below, so recomputing it on a re-run would climb one
# directory higher each time. Restart the kernel to re-resolve the path.
if 'FS_MOL_CHECKOUT_PATH' not in globals():
    FS_MOL_CHECKOUT_PATH = os.path.abspath('../')

# Work from the checkout root and give it top import priority.
os.chdir(FS_MOL_CHECKOUT_PATH)
if sys.path[:1] != [FS_MOL_CHECKOUT_PATH]:
    # Skip the insert when it is already first so re-runs do not stack duplicates.
    sys.path.insert(0, FS_MOL_CHECKOUT_PATH)

In [2]:
from fs_mol.data import FSMolDataset, DataFold, FSMolTask

# Root directory of the dataset. Set the FS_MOL_DATASET_PATH environment
# variable to point somewhere else; an unset or empty variable falls back to the
# original hard-coded location. `os` is imported in the first cell.
FS_MOL_DATASET_PATH = os.environ.get('FS_MOL_DATASET_PATH') or '/FS-MOL/datasets/fs-mol/'

# Build the dataset handle from that directory, passing num_workers=0.
dataset = FSMolDataset.from_directory(FS_MOL_DATASET_PATH, num_workers=0)



In [3]:
def task_reader(paths, idx):
    """Load the task stored at ``paths[0]`` and return it in a one-element list.

    Args:
        paths: Indexable collection of task file locations. The whole
            collection is printed, but only the first entry is loaded.
        idx: Not used by this function.

    Returns:
        A list containing the single ``FSMolTask`` loaded with
        ``output_type='s_attn'``.
    """
    # Echo what we were handed before touching any file.
    print(paths)

    first_path = paths[0]
    return [FSMolTask.load_from_file(first_path, output_type='s_attn')]

# Ask the dataset for an iterable over the TRAIN fold, passing task_reader as
# the callback that turns file paths into tasks.
iterator = dataset.get_task_reading_iterable(DataFold.TRAIN, task_reader)

In [4]:
# Rebind `iterator` to an iterator over the iterable built in the previous cell,
# so items can be pulled one at a time with next().
iterator = iter(iterator)

In [5]:
# Pull the first item produced by the iterator.
example = next(iterator)

# Take the graph attached to that item's first sample.
graph = example.samples[0].graph

# Bare trailing expression so the notebook displays the graph.
graph

[/FS-MOL/datasets/fs-mol/train/CHEMBL659000.jsonl.gz]
torch.Size([2, 43])
torch.Size([2, 49])
torch.Size([2, 40])
torch.Size([2, 53])
torch.Size([2, 42])
torch.Size([2, 44])
torch.Size([2, 48])
torch.Size([2, 48])
torch.Size([2, 27])
torch.Size([2, 43])
torch.Size([2, 55])
torch.Size([2, 49])
torch.Size([2, 27])
torch.Size([2, 30])
torch.Size([2, 27])
torch.Size([2, 39])
torch.Size([2, 42])
torch.Size([2, 45])
torch.Size([2, 46])
torch.Size([2, 55])
torch.Size([2, 53])
torch.Size([2, 61])
torch.Size([2, 32])
torch.Size([2, 35])
torch.Size([2, 40])
torch.Size([2, 33])
torch.Size([2, 51])
torch.Size([2, 29])
torch.Size([2, 35])
torch.Size([2, 38])
torch.Size([2, 52])
torch.Size([2, 42])
torch.Size([2, 46])
torch.Size([2, 47])
torch.Size([2, 51])
torch.Size([2, 28])
torch.Size([2, 47])
torch.Size([2, 47])
torch.Size([2, 41])
torch.Size([2, 42])
torch.Size([2, 34])
torch.Size([2, 40])
torch.Size([2, 45])
torch.Size([2, 26])
torch.Size([2, 32])


Data(x=[42, 32], edge_index=[2, 43], edge_attr=[43], subgraph_node_index=[220], subgraph_edge_index=[2, 178], subgraph_indicator_index=[220], subgraph_edge_attr=[178])

In [6]:
import torch_geometric.utils as utils
from torch_geometric.data import Data
import torch

def add_subgraph_info(graph, k_hops: int):
    """Attach the ``k_hops``-hop subgraph of every node to ``graph``.

    One subgraph per node is extracted with
    ``torch_geometric.utils.k_hop_subgraph`` (``relabel_nodes=True``), and the
    results are concatenated into flat tensors stored on ``graph``:

    * ``subgraph_node_index``: ids (in ``graph``) of the nodes of each subgraph.
    * ``subgraph_edge_index``: the relabelled edges of each subgraph, shifted by
      the number of entries contributed by the preceding subgraphs so that they
      index into ``subgraph_node_index``.
    * ``subgraph_indicator_index``: for every entry of ``subgraph_node_index``,
      the id of the node whose subgraph it belongs to. Same dtype and device as
      the node indices.
    * ``subgraph_edge_attr``: the rows of ``graph.edge_attr`` for the edges kept
      in each subgraph, or ``None`` when ``graph.edge_attr`` is ``None``.

    Args:
        graph: Graph object exposing ``x``, ``edge_index`` and ``edge_attr``.
            It is modified in place.
        k_hops: Number of hops passed to ``k_hop_subgraph``.

    Returns:
        The same ``graph`` object with the four attributes above set.
    """
    num_nodes = graph.x.shape[0]
    edge_attr = graph.edge_attr

    if num_nodes == 0:
        # torch.cat rejects an empty list, so build the empty results directly.
        graph.subgraph_node_index = graph.edge_index.new_empty((0,))
        graph.subgraph_edge_index = graph.edge_index.new_empty((2, 0))
        graph.subgraph_indicator_index = graph.edge_index.new_empty((0,))
        graph.subgraph_edge_attr = None if edge_attr is None else edge_attr[:0]
        return graph

    node_indices = []
    edge_indices = []
    edge_attributes = []
    indicators = []
    # Running offset of the current subgraph inside the concatenated node list.
    edge_index_start = 0

    for node_idx in range(num_nodes):
        sub_nodes, sub_edge_index, _, edge_mask = utils.k_hop_subgraph(
            node_idx,
            k_hops,
            graph.edge_index,
            relabel_nodes=True,
            num_nodes=num_nodes,
        )

        node_indices.append(sub_nodes)
        # Relabelled edges start at 0 for every subgraph; shift them so each
        # subgraph occupies its own range of positions.
        edge_indices.append(sub_edge_index + edge_index_start)
        # full_like keeps the dtype and device of the node indices, so the
        # indicator can be used directly as an index tensor.
        indicators.append(torch.full_like(sub_nodes, node_idx))
        if edge_attr is not None:
            # edge_mask selects the original edges that survive in this subgraph.
            edge_attributes.append(edge_attr[edge_mask])
        edge_index_start += len(sub_nodes)

    graph.subgraph_node_index = torch.cat(node_indices)
    graph.subgraph_edge_index = torch.cat(edge_indices, dim=1)
    graph.subgraph_indicator_index = torch.cat(indicators)
    graph.subgraph_edge_attr = (
        torch.cat(edge_attributes) if edge_attr is not None else None
    )

    return graph
        
# Compute the subgraph attributes with k_hops=4; the call is the cell's last
# expression, so the returned graph is displayed.
# NOTE(review): the graph displayed earlier in this notebook already carries
# subgraph_* attributes, so this call overwrites them in place.
add_subgraph_info(graph, 4)

torch.Size([2, 43])


Data(x=[42, 32], edge_index=[2, 43], edge_attr=[43], subgraph_node_index=[220], subgraph_edge_index=[2, 178], subgraph_indicator_index=[220], subgraph_edge_attr=[178])