In [1]:
import os, sys

import torch as tn
import torch.optim 
from torch.nn import ModuleDict 
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero, Linear, GATConv, HeteroConv
from torch_geometric.data import HeteroData

# Root of the kg_uq checkout. Override with the KG_UQ_ROOT environment variable
# to run the notebook on another machine; the default is the original location.
PROJECT_ROOT = os.environ.get('KG_UQ_ROOT', '/Users/walder2/kg_uq/')

# Put the project root first on the import path. Any existing occurrence is
# dropped first so that re-running this cell does not pile up duplicates.
sys.path = [PROJECT_ROOT] + [p for p in sys.path if p != PROJECT_ROOT]

# Directory passed to kg_to_hetero as data_dir.
path_to_data = os.path.join(PROJECT_ROOT, 'recipe_data')

from kg_extraction import kg_to_hetero

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Element-wise log-sigmoid; the training loop applies it to pairwise dot products.
log_sig = torch.nn.LogSigmoid()

class HeteroGNN(tn.nn.Module):
    """Heterogeneous GAT encoder that reduces a whole graph to one vector.

    The network is a stack of ``HeteroConv`` layers (one ``GATConv`` per edge
    type, summed per destination node type), each followed by a ReLU. The
    resulting node features are mean-pooled per node type, projected by a
    bias-free per-type ``Linear`` layer, and the projections are summed.

    Args:
        hidden_channels: Output width of every ``GATConv``.
        emb_dim: Length of the returned graph embedding.
        num_layers: Number of ``HeteroConv`` layers.
        edge_types: Iterable of hashable edge-type keys; duplicates are ignored.
        node_types: Iterable of hashable node-type keys; duplicates are ignored.
    """

    def __init__(self, hidden_channels, emb_dim, num_layers, edge_types, node_types):
        super().__init__()

        self.emb_dim = emb_dim

        # Order-preserving de-duplication: callers may pass the concatenation of
        # the types of many graphs, and building a layer per duplicate only to
        # have the dict discard it is wasted work.
        edge_types = tuple(dict.fromkeys(edge_types))
        node_types = tuple(dict.fromkeys(node_types))

        self.convs = tn.nn.ModuleList()
        for _ in range(num_layers):
            # (-1, -1) leaves the source/destination input sizes to be
            # inferred on the first forward pass.
            conv = HeteroConv(
                {key: GATConv((-1, -1), hidden_channels, add_self_loops=False) for key in edge_types},
                aggr='sum',
            )
            self.convs.append(conv)

        self.lin_dict = ModuleDict({x: Linear(hidden_channels, emb_dim, bias=False) for x in node_types})

    def forward(self, x_dict, edge_index_dict):
        """Return the graph embedding, a 1-D tensor of length ``emb_dim``.

        Args:
            x_dict: Mapping from node type to that type's node-feature tensor.
            edge_index_dict: Mapping from edge type to its edge-index tensor.

        Returns:
            The sum over node types of ``lin_dict[type](mean of node features)``.
            Node types with zero nodes are skipped, because their mean would be
            NaN. If nothing contributes, a zero vector is returned.
        """
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        # may want to change the aggregation layer here
        # Accumulate out of place so the result lives on whatever device/dtype
        # the projections produce, rather than on a CPU tensor allocated up front.
        rv = None
        for key, x in x_dict.items():
            if x.size(0) == 0:
                continue
            contribution = self.lin_dict[key](x.mean(dim=0))
            rv = contribution if rv is None else rv + contribution

        if rv is None:
            # Nothing to pool: fall back to zeros, placed next to the
            # projection weights when there are any.
            ref = next(self.lin_dict.parameters(), None)
            if ref is None:
                return tn.zeros(self.emb_dim)
            return tn.zeros(self.emb_dim, device=ref.device, dtype=ref.dtype)
        return rv
        

In [8]:
# Load the data set through the project helper. The result is unpacked into
# the sequence of training graphs plus two auxiliary pairs.
train, (kg, kg_idx), (node_map, emb_map) = kg_to_hetero(
    data_dir=path_to_data,
    embeddings_model=None,
    undirected=True,
)

In [9]:
# Length of the per-graph embedding produced by the encoder.
emb_dim = 10

# Distinct edge/node types over all training graphs. dict.fromkeys removes
# duplicates while keeping first-seen order, so the encoder's sub-modules are
# registered in the same order on every run (set() ordering is not stable
# across interpreter sessions) and node types are de-duplicated like edge types.
edge_types = tuple(dict.fromkeys(y for x in train for y in x.edge_types))
node_types = tuple(dict.fromkeys(y for x in train for y in x.node_types))

encoder = HeteroGNN(
    hidden_channels=16,
    num_layers=2,
    emb_dim=emb_dim,
    edge_types=edge_types,
    node_types=node_types,
)

In [10]:
epochs = 100
ntrain = len(train)
if ntrain == 0:
    # Without graphs the loss has no graph attached and backward() would fail
    # with an unrelated autograd error; fail early with a clear message instead.
    raise ValueError('train is empty: nothing to fit the encoder on')

# The encoder's GATConv layers are built with lazily inferred input sizes.
# Run every graph through once without gradients so that all layers are
# materialized before the optimizer collects the parameters.
with tn.no_grad():
    for graph in train:
        encoder(graph.x_dict, graph.edge_index_dict)

optimizer = torch.optim.Adam(encoder.parameters(), lr=0.01)
# Multiply the learning rate by 0.90 every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.90)

for epoch in range(epochs):
    optimizer.zero_grad()

    # One embedding per training graph, stacked into an (ntrain, emb_dim)
    # matrix. Stacking keeps the encoder's device/dtype instead of copying
    # into a tensor pre-allocated on the CPU.
    rv = tn.stack([encoder(graph.x_dict, graph.edge_index_dict) for graph in train])

    # Pairwise similarities: row i is scored against every graph, and the
    # diagonal entry (graph i with itself) is treated as the positive.
    # NOTE(review): the recorded run printed the same loss (11.3430) at every
    # checkpoint. log-sigmoid maps every similarity into (-inf, 0] and its
    # gradient vanishes for large positive dot products, which may be why the
    # loss does not move -- consider passing the raw similarities to
    # logsumexp. TODO confirm before changing the objective.
    d_ij = log_sig(tn.matmul(rv, rv.T))
    loss = -tn.sum(tn.diag(d_ij) - tn.logsumexp(d_ij, dim=1))

    loss.backward()
    optimizer.step()
    scheduler.step()

    if (epoch+1) % 10 == 0:
        print('Epoch %s, loss %s' % (repr(epoch+1), repr(loss.detach())))
        

Epoch 10, loss tensor(11.3430)
Epoch 20, loss tensor(11.3430)
Epoch 30, loss tensor(11.3430)
Epoch 40, loss tensor(11.3430)
Epoch 50, loss tensor(11.3430)
Epoch 60, loss tensor(11.3430)
Epoch 70, loss tensor(11.3430)
Epoch 80, loss tensor(11.3430)
Epoch 90, loss tensor(11.3430)
Epoch 100, loss tensor(11.3430)
