In [1]:
import warnings
warnings.filterwarnings('ignore')

from process_dataset import get_dgl_graph
from subgraph_extraction import extract_subgraph
from graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, device
from hodgelaplacians import HodgeLaplacians
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np
from model import CookingModel

In [2]:
graph, nx_graph = get_dgl_graph()

> loading nx_graph from cache
> loading dgl_graph from cache


# Sampling

In [3]:
eid = np.random.choice(graph.edges('eid'))
u, v = graph.find_edges(eid)
to_remove = frozenset((u.item(), v.item()))
subgraph = extract_subgraph([u.item(), v.item()], graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=60)
simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph)
embeddings, laplacians, boundaries, idx, label = get_embeddings(simplex_labels, to_remove, dim=4)
print('Laplacians:',[ laplacian.shape if laplacian is not None else 0 for laplacian in laplacians])
print('Boundaries:',[ boundary.shape if boundary is not None else 0 for boundary in boundaries])
print('Embeddings',[ embedding.shape if embedding is not None else 0 for embedding in embeddings])
print('Label',label.shape)
print('Edge',idx)

Laplacians: [torch.Size([7, 7]), torch.Size([10, 10]), torch.Size([1, 1]), 0]
Boundaries: [torch.Size([1, 7]), torch.Size([7, 10]), torch.Size([10, 1]), 0]
Embeddings [torch.Size([7, 20]), torch.Size([10, 20]), torch.Size([1, 20]), 0]
Label torch.Size([20])
Edge 4


# Model Training and Evaluation

In [4]:
from model import CookingModel
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('/home/adarsh/H-KGC/datasets/cat_edge_cooking/logs')
gs = 0

cm = CookingModel(dim=4, device=device).to(device)
optim = torch.optim.Adam(cm.parameters(), lr=1e-4)

In [5]:
for ep in tqdm(range(500)):
    eid = np.random.choice(graph.edges('eid'))
    u, v = graph.find_edges(eid)
    to_remove = frozenset((u.item(), v.item()))
    subgraph = extract_subgraph([u.item(), v.item()], graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=50)
    simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph)
    embeddings, laplacians, boundaries, idx, label = get_embeddings(simplex_labels, to_remove, dim=4)

    pred = cm(embeddings, laplacians, boundaries, 1, idx).squeeze()
    loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)
    optim.zero_grad()
    loss.backward()
    optim.step()
    writer.add_scalar('train loss',loss.item(), gs)
    gs+=1
    torch.cuda.empty_cache()

100%|██████████| 500/500 [03:19<00:00,  2.51it/s]


In [6]:
with torch.no_grad():
    H = []
    for ep in tqdm(range(50)):
        eid = np.random.choice(graph.edges('eid'))
        u, v = graph.find_edges(eid)
        to_remove = frozenset((u.item(), v.item()))
        subgraph = extract_subgraph([u.item(), v.item()], graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=50)
        simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph)
        embeddings, laplacians, boundaries, idx, label = get_embeddings(simplex_labels, to_remove, dim=4)

        pred = cm(embeddings, laplacians, boundaries, 1, idx).squeeze()
        H.append((torch.round(torch.sigmoid(pred))==label).long())

100%|██████████| 50/50 [00:18<00:00,  2.67it/s]


### Test Accuracy

In [7]:
torch.sum(torch.stack(H),dim=0)/len(H)

tensor([0.8600, 0.0600, 0.9400, 0.8800, 0.2000, 0.8400, 0.8400, 0.8400, 0.7200,
        0.8200, 0.8000, 0.9400, 0.8400, 0.8200, 0.2200, 0.8400, 0.8400, 0.8400,
        0.8600, 0.3600], device='cuda:0')