In [1]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import SAGEConv
import os
from torch_geometric.utils import to_networkx
import networkx as nx
# importing matplotlib.pyplot
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Path to the ogbn-products dataset. Set the OGBN_PRODUCTS_ROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
root = os.environ.get(
    'OGBN_PRODUCTS_ROOT',
    '/Users/sachin/Desktop/arangodb/scripts/ArangoML/graph_embeddings/products',
)

In [3]:
# Download (on first use) and load the ogbn-products node-property dataset under `root`.
dataset = PygNodePropPredDataset(name='ogbn-products', root=root)

In [4]:
# The single product graph, its train/valid/test split (based on sales ranking),
# and the matching OGB evaluator.
data = dataset[0]
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-products')

In [5]:
# Node indices of the held-out test split.
test_idx = split_idx["test"]


In [6]:
# Neighbours sampled per hop; three hops, one per SAGE layer.
NUM_NEIGHBORS = [15, 10, 5]

# One test node per batch, in a fixed order, with its sampled multi-hop neighbourhood.
test_loader = NeighborSampler(
    data.edge_index,
    node_idx=test_idx,
    sizes=NUM_NEIGHBORS,
    batch_size=1,
    shuffle=False,
    num_workers=12,
)

In [7]:
# Pick one test node and its sampled computation graph to use as example input
# for tracing. The node itself is deterministic (the loader uses shuffle=False);
# neighbour sampling is not seeded here, so the sampled subgraph can vary per run.
from itertools import islice

SAMPLE_BATCH_IDX = 1  # take the second batch, as before

sampled_batch = next(islice(test_loader, SAMPLE_BATCH_IDX, None), None)
if sampled_batch is None:
    raise RuntimeError(
        f"test_loader yielded fewer than {SAMPLE_BATCH_IDX + 1} batches"
    )

_, sample_n_id, sample_adjs = sampled_batch
# Kept as one-element lists because later cells index them with [0].
dummy_n_ids = [sample_n_id]
dummy_adjs = [sample_adjs]

In [8]:
# Number of nodes in the sampled computation graph of the chosen test node.
dummy_n_ids[0].size()

torch.Size([718])

In [9]:
# One sampled bipartite adjacency per hop.
num_hops = len(dummy_adjs[0])
num_hops

3

In [10]:
# Split each sampled hop into the two tensors the traced model takes per layer:
# the bipartite `edge_index` and the `size` of that bipartite graph as a tensor.
# The `e_id` entry (ids of the original edges) is not an input of the forward
# pass, so it is dropped.
# `edge_adjs` is kept for compatibility; no cell in this notebook fills or reads it.
edge_adjs = []

# Explicit int64 so the dtype does not depend on NumPy's platform-default integer.
edge_lists = [
    [adj[0], torch.tensor(adj[2], dtype=torch.long)]
    for adj in dummy_adjs[0]
]
if len(edge_lists) != 3:
    raise ValueError(f"expected 3 sampled hops, got {len(edge_lists)}")

edge_list_0, edge_list_1, edge_list_2 = edge_lists

In [11]:
# Unpack each [edge_index, size] pair into the separate named inputs of the model.
edge_index_0, edge_size_0 = edge_list_0[0], edge_list_0[1]
edge_index_1, edge_size_1 = edge_list_1[0], edge_list_1[1]
edge_index_2, edge_size_2 = edge_list_2[0], edge_list_2[1]

In [12]:
# (number of source nodes, number of target nodes) of the first-hop bipartite graph
edge_size_0

tensor([718, 153])

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Keep the node features on the CPU. Neither the model nor the sampled edge
# tensors are moved to `device` in this notebook, so placing x on CUDA would
# make tracing fail with a device mismatch on GPU machines.
x = data.x.cpu()

In [14]:
# Feature matrix of every node in the sampled computation graph:
# (num_sampled_nodes, feature_dim).
x[dummy_n_ids[0], :].shape

torch.Size([718, 100])

In [15]:
# Gather the features of the sampled nodes: the unpadded example input for tracing.
sampled_ids = dummy_n_ids[0]
dummy_x = x[sampled_ids, :]
print(dummy_x.shape)

torch.Size([718, 100])


In [16]:
# Zero-pad the node features to a fixed number of rows so that the traced model
# always receives an input with `max_nodes` rows.
max_nodes = 1000
total_nodes = dummy_x.size(0)
nodes_padded = max_nodes - total_nodes
if nodes_padded < 0:
    # F.pad with a negative amount trims rows instead of padding. That would
    # silently drop sampled nodes that the edge indices still reference.
    raise ValueError(
        f"sampled {total_nodes} nodes, which exceeds max_nodes={max_nodes}"
    )
dummy_x_pad = F.pad(input=dummy_x, pad=(0, 0, 0, nodes_padded), mode='constant', value=0)
print(dummy_x_pad.shape)

torch.Size([1000, 100])


In [27]:
# graph sage
class SAGE(torch.nn.Module):
    """GraphSAGE encoder with a fixed-shape, trace-friendly forward pass.

    ``forward`` takes the three sampled hops as separate ``(edge_index, size)``
    tensor pairs instead of a list of adjacency tuples, so that
    ``torch.jit.trace`` sees only tensors. Each layer's target features are
    zero-padded to ``max_target_nodes`` rows, so every returned embedding
    tensor has ``max_target_nodes`` rows. Only the first ``size[1]`` rows of a
    layer correspond to real target nodes. The remaining rows come from padded
    inputs and are not node embeddings.

    Args:
        in_channels: input feature dimension.
        hidden_channels: width of the hidden layers.
        out_channels: width of the last layer.
        num_layers: number of ``SAGEConv`` layers to build. ``forward`` consumes
            exactly three hops, so it only supports ``num_layers == 3``.
        max_target_nodes: row count every layer's target features are padded
            to (default 500, the value previously hard-coded in ``forward``).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3,
                 max_target_nodes=500):
        super(SAGE, self).__init__()

        self.num_layers = num_layers
        self.max_target_nodes = max_target_nodes

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index_0, edge_size_0, edge_index_1, edge_size_1,
                edge_index_2, edge_size_2):
        """Run the three hops and return the activations of every layer.

        Args:
            x: features of the source nodes of the first hop. Target nodes are
                placed first, so ``x[:size[1]]`` are the targets of each hop.
            edge_index_k: bipartite edge index of hop ``k``.
            edge_size_k: 2-element tensor ``(num_source_nodes, num_target_nodes)``
                of hop ``k``.

        Returns:
            Tuple ``(layer_1, layer_2, layer_3)`` of activations, each with
            ``max_target_nodes`` rows. ReLU is applied to every layer except
            the last.

        Raises:
            ValueError: if ``num_layers`` is not 3, or if a hop has more target
                nodes than ``max_target_nodes``.
        """
        if self.num_layers != 3:
            raise ValueError(
                f"forward consumes exactly 3 hops, but num_layers={self.num_layers}"
            )

        hops = (
            (edge_index_0, edge_size_0),
            (edge_index_1, edge_size_1),
            (edge_index_2, edge_size_2),
        )
        embeddings = []
        for i, (edge_index, size) in enumerate(hops):
            # A negative pad would crop target rows instead of padding them.
            # This is a Python-side check on the example inputs:
            # torch.jit.trace does not record Python control flow, so the
            # traced module will not re-check it.
            if int(size[1]) > self.max_target_nodes:
                raise ValueError(
                    f"hop {i} has {int(size[1])} target nodes, which exceeds "
                    f"max_target_nodes={self.max_target_nodes}"
                )

            x_target = x[:size[1]]  # Target nodes are always placed first.
            x_target = F.pad(input=x_target,
                             pad=(0, 0, 0, self.max_target_nodes - size[1]),
                             mode='constant', value=0)

            x = self.convs[i]((x, x_target), edge_index)
            if i != self.num_layers - 1:
                x = F.relu(x)
            embeddings.append(x)

        return tuple(embeddings)

In [28]:
# Build the GraphSAGE architecture; its trained weights are loaded afterwards.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = SAGE(
    in_channels=dataset.num_features,
    hidden_channels=256,
    out_channels=dataset.num_classes,
)

In [29]:
# Load the trained weights into the model built above, which is never moved off
# the CPU. map_location='cpu' lets a checkpoint saved from a GPU run be restored
# on a CPU-only machine.
# NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
model_checkpoint = './weight_checkpoint.pth.tar'
checkpoint = torch.load(model_checkpoint, map_location='cpu')
model_w = checkpoint["state_dict"]
model.load_state_dict(model_w)

<All keys matched successfully>

# Tracing the model with TorchScript

In [30]:
class PyTorch_to_TorchScript(torch.nn.Module):
    """Thin tracing wrapper that forwards all of its inputs to a wrapped model.

    Args:
        wrapped: module to wrap. If omitted, the notebook-global ``model`` is
            used, so the original zero-argument call
            ``PyTorch_to_TorchScript()`` still works. Passing the module
            explicitly avoids depending on hidden global state.
    """

    def __init__(self, wrapped=None):
        super(PyTorch_to_TorchScript, self).__init__()
        # The global `model` is only read when nothing was passed in.
        self.model = model if wrapped is None else wrapped

    def forward(self, data, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2):
        """Return the wrapped model's output for the same positional inputs."""
        return self.model(data, edge_index_0, edge_size_0, edge_index_1, edge_size_1, edge_index_2, edge_size_2)

In [31]:
# Wrap the loaded model for tracing and switch it to inference mode
# (train(False) is what eval() does; it returns the same module).
pt_model = PyTorch_to_TorchScript().train(False)

In [32]:
# Record the forward pass on the padded example inputs.
# strict=False lets the tracer record mutable containers (list/dict) in outputs.
trace_inputs = (
    dummy_x_pad,
    edge_index_0, edge_size_0,
    edge_index_1, edge_size_1,
    edge_index_2, edge_size_2,
)
traced_script_module = torch.jit.trace(pt_model, trace_inputs, strict=False)

In [33]:
# Serialize the traced module to the current working directory.
TRACED_MODEL_PATH = "./model.pt"
traced_script_module.save(TRACED_MODEL_PATH)

## Loading the traced model

In [34]:
# Reload the serialized TorchScript module from disk.
tr_model = torch.jit.load(f="./model.pt")

In [35]:
# Run the reloaded module on the same example inputs that were used for tracing.
tr_out = tr_model(
    dummy_x_pad,
    edge_index_0, edge_size_0,
    edge_index_1, edge_size_1,
    edge_index_2, edge_size_2,
)

In [36]:
# Shapes of the layer-1, layer-2 and layer-3 outputs of the reloaded traced model.
tr_out[0].shape, tr_out[1].shape, tr_out[2].shape

(torch.Size([500, 256]), torch.Size([500, 256]), torch.Size([500, 47]))

In [37]:
# Layer-1 output of the reloaded traced model. Only the first size[1] rows belong
# to real target nodes; later rows come from zero-padded inputs.
tr_out[0]

tensor([[0.0000, 1.1483, 0.0000,  ..., 0.0866, 1.2719, 0.0000],
        [0.0000, 0.9348, 0.0000,  ..., 0.8225, 1.4933, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 1.6156, 2.7243, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<CatBackward>)