In [1]:
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-geometric

Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html
Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu111.html


In [2]:
import json
import pandas as pd
import time
import networkx as nx
from torch_geometric.utils import from_networkx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCN
from torch.utils.data import Dataset, DataLoader

In [3]:
nx_graph = nx.read_gml('graph_with_features.gml')
G = from_networkx(nx_graph, group_node_attrs=['out_degree', 'in_degree', 'category_multi_hot'], group_edge_attrs=['tf_idf', 'num_link_clicked'])

path_data = pd.read_csv('data_by_index.tsv', sep='\t', header=None)

In [4]:
class CustomPathDataset(Dataset):
    def __init__(self, path_data):
        self.x = path_data[0].apply(json.loads)
        self.labels = path_data[1]
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        x = torch.LongTensor(self.x[idx])
        label = self.labels[idx]
        sample = {"indices": x, "label": label}
        return sample

In [86]:
class Model(torch.nn.Module):
    def __init__(self, graph, gnn_hidden_size=128, node_embed_size=64, lstm_hidden_size=32):
        super().__init__()
        self.graph = graph
        
        #yeet = { edge: idx for idx, edge in enumerate(graph.edge_index.T) }
        #print(list(yeet.keys())[:5])
        #print(torch.tensor((0, 530)))
        #print(yeet[torch.tensor((0, 530))])


        self.edge_feat = dict(zip(graph.edge_index.T, graph.edge_attr))
        print(graph.edge_attr.shape)
        self.edge_feat[torch.tensor([-1, -1])] = torch.zeros(graph.edge_attr.shape[1])
        self.gcn = GCN(in_channels=self.graph.x.shape[1], 
                       hidden_channels=gnn_hidden_size, 
                       num_layers=3, 
                       out_channels=node_embed_size, 
                       dropout=0.1)
        self.lstm_input_size = node_embed_size # TODO: replace with line below when adding edge features
        # self.lstm_input_size = node_embed_size + 2
        self.lstm = nn.LSTM(input_size=self.lstm_input_size,
                            hidden_size=lstm_hidden_size,
                            batch_first=True)
        self.pred_head = nn.Linear(lstm_hidden_size, self.graph.x.shape[0])

    def forward(self, indices):
        node_emb = self.gcn(self.graph.x, self.graph.edge_index)
        node_emb_with_padding = torch.cat([node_emb, torch.zeros((1, self.lstm_input_size))])
        paths = node_emb_with_padding[indices] # TODO: need to append edge features to data before passing into LSTM        

        edges = indices.unfold(step=1, size=2, dimension=1)
        edges = torch.cat([torch.full(size=(indices.shape[0], 1, 2), fill_value=-1), edges], dim=1)
        print(edges.shape)
        edges.apply_(lambda x: self.edge_feat[x])
        print(edges)
        return None

        # TODO: remember to pad zeros at beginning
        out, _ = self.lstm(paths)
        predictions = self.pred_head(torch.sum(out, dim=1))
        return F.log_softmax(predictions, dim=1)

In [87]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(G).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=5e-4)

dataset = CustomPathDataset(path_data)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

batch_size = 3
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

torch.Size([119882, 2])


In [88]:
model.train()
for epoch in range(200):  # loop over the dataset multiple times
    epoch_loss = 0
    start_time = time.time()
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data['indices'], data['label']

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = F.nll_loss(outputs, labels)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()

    # print statistics
    print('Epoch:', epoch+1)
    print('Loss:', epoch_loss / batch_size)
    print('Time:', time.time() - start_time)
    print()

torch.Size([3, 32, 2])
tensor([[[  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,  573],
         [ 573, 2156],
         [2156, 3095],
         [3095,  437],
         [ 437, 4372]],

        [[  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         [  -1,   -1],
         

KeyError: ignored

In [None]:
test_inputs = None
test_labels = None

testloader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                          shuffle=True, num_workers=2)

model.eval()
num_correct = 0
for i, data in enumerate(testloader, 0):
      # get the inputs; data is a list of [inputs, labels]
      inputs, labels = data['indices'], data['label']

      outputs = model(inputs)

      pred = model(outputs).argmax(dim=1)
      correct = (pred == label).sum()
      num_correct += correct

acc = int(correct) / int(len(pred))
print(f'Accuracy: {acc:.4f}')

      # print statistics
      print('Epoch:', epoch)
      print('Loss:', loss)

pred = model(test_inputs).argmax(dim=1)
correct = (pred == torch.zeros_like(pred)).sum()
acc = int(correct) / int(len(pred))
print(f'Accuracy: {acc:.4f}')