In [1]:
import sys
sys.path.append('../..')

import warnings
warnings.filterwarnings('ignore')

from preprocess.process_dataset import get_dgl_graph
from preprocess.subgraph_extraction import extract_subgraph
from preprocess.graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, device, _get_simplices, random_sample
from hodgelaplacians import HodgeLaplacians
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np

In [2]:
# Active dataset for this run and the number of label classes used throughout the notebook.
# NOTE(review): 20 matches the width of the label/embedding tensors in the saved sampling output.
dataset = 'cat_edge_cooking'
num_classes = 20

In [None]:
# Alternative configuration. This cell has no execution count, i.e. it was not run
# when the saved outputs were produced.
# NOTE(review): on a top-to-bottom "Run All" these assignments override the
# 'cat_edge_cooking' settings defined earlier, so the saved 20-class outputs would
# not be reproduced — run only one of the two configuration cells.
dataset = 'cat_edge_DAWN'
num_classes = 11

In [3]:
# Load the graph for `dataset`; the call returns a pair unpacked into `graph` and `nx_graph`.
# Presumably a DGL graph plus its NetworkX counterpart (going by the names and the cache
# messages printed below) — confirm in preprocess.process_dataset.
graph, nx_graph = get_dgl_graph(dataset)

> loading nx_graph from cache
> loading dgl_graph from cache


# Sampling

In [25]:
# Draw one random sample: the simplex, its order, and its label vector.
simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=4)
to_remove = frozenset(simplex)

# Extract the subgraph around the sampled simplex (h=4, max 30 nodes per hop)
# and derive the simplex labels from it.
subgraph = extract_subgraph(
    simplex,
    graph,
    h=4,
    enclosing_sub_graph=True,
    max_nodes_per_hop=30,
)
simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
embeddings, laplacians, boundaries, idx = get_embeddings(
    simplex_labels, to_remove, num_classes, dim=4
)

# Report the shape of every returned entry; entries that are None are shown as 0.
print('Laplacians:', [0 if lap is None else lap.shape for lap in laplacians])
print('Boundaries:', [0 if bnd is None else bnd.shape for bnd in boundaries])
print('Embeddings', [0 if emb is None else emb.shape for emb in embeddings])
print('Label', label)
print('Order:', order, 'Index:', idx)

Laplacians: [torch.Size([4, 4]), torch.Size([4, 4]), 0, 0]
Boundaries: [torch.Size([1, 4]), torch.Size([4, 4]), 0, 0]
Embeddings [torch.Size([4, 20]), torch.Size([4, 20]), 0, 0]
Label tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 1.], device='cuda:0')
Order: 1 Index: 0


# Model Training and Evaluation

In [30]:
import os

from models.model import SimplicialModel1
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# Root directory that holds one folder per dataset. It can be overridden with the
# HKGC_DATASETS_DIR environment variable so the notebook also runs on machines other
# than the original author's; the default keeps the original location unchanged.
datasets_root = os.environ.get('HKGC_DATASETS_DIR', '/home/adarsh/H-KGC/datasets')

# TensorBoard writer; scalars are logged under <datasets_root>/<dataset>/logs.
writer = SummaryWriter(f'{datasets_root}/{dataset}/logs')
gs = 0  # module-level step counter (train() does not read it)

# Model under training (moved to `device`) and its Adam optimizer.
cm = SimplicialModel1(num_classes, dim=4, device=device).to(device)
optim = torch.optim.Adam(cm.parameters(), lr=1e-4)

In [34]:
def train(num_steps=500, max_nodes_per_hop=50):
    """Run single-sample optimisation steps on the notebook-level model `cm`.

    Every step draws one random simplex, builds its subgraph-derived inputs,
    computes a binary-cross-entropy-with-logits loss against the sampled label,
    applies one optimizer update and logs the loss to TensorBoard under the tag
    'train loss', using the step index as the global step.

    Args:
        num_steps: Number of optimisation steps (one sampled simplex per step).
            Defaults to 500, the previously hard-coded value.
        max_nodes_per_hop: Value forwarded to `extract_subgraph`. Defaults to 50,
            the previously hard-coded value.
            NOTE(review): the sampling and evaluation cells pass 30 — confirm the
            difference is intentional.

    Relies on notebook globals: `dataset`, `num_classes`, `graph`, `nx_graph`,
    `cm`, `optim` and `writer`.
    """
    # Make sure training-only layers (e.g. dropout) are active even if the model
    # was switched to eval mode by an earlier cell.
    cm.train()

    for step in tqdm(range(num_steps)):
        simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=4)
        to_remove = frozenset(simplex)
        subgraph = extract_subgraph(
            simplex,
            graph,
            h=4,
            enclosing_sub_graph=True,
            max_nodes_per_hop=max_nodes_per_hop,
        )
        simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
        embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)

        # Forward pass; squeeze drops singleton dimensions so the logits line up with `label`.
        pred = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # The loop index doubles as the TensorBoard global step.
        writer.add_scalar('train loss', loss.item(), step)
        # Release cached GPU blocks after every step, as the original loop did.
        torch.cuda.empty_cache()

In [35]:
import cProfile
import pstats

# Run the whole training loop under cProfile to see where the time goes.
profile = cProfile.Profile()
profile.runcall(train)
# Build a report from the collected data. No sort key or entry limit is given,
# so print_stats() lists every profiled entry.
ps = pstats.Stats(profile)
ps.print_stats()

 14%|█▍        | 70/500 [00:52<06:28,  1.11it/s]

In [27]:
# Evaluate the current model on 50 freshly drawn random samples.
# `H` collects, per sample, a 0/1 vector marking which classes were predicted correctly.
#
# The model is switched to eval mode for the duration of the loop so that
# training-only layers (dropout, batch-norm statistics updates), if the model has
# any, do not perturb the predictions. The previous mode is restored afterwards so
# that a later training cell is not affected.
was_training = cm.training
cm.eval()
try:
    with torch.no_grad():
        H = []
        for ep in tqdm(range(50)):
            simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=4)
            to_remove = frozenset(simplex)
            subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=30)
            simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
            embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)

            pred = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
            # Threshold the sigmoid probabilities at 0.5 and compare class-wise with the label.
            H.append((torch.round(torch.sigmoid(pred))==label).long())
finally:
    cm.train(was_training)

100%|██████████| 50/50 [00:23<00:00,  2.09it/s]


### Test Accuracy (per class, averaged over the evaluation samples)

In [29]:
# Fraction of evaluation samples on which each class was predicted correctly.
torch.stack(H).sum(dim=0) / len(H)

tensor([0.7000, 0.6800, 0.6600, 0.6800, 0.6800, 0.6200, 0.5000, 0.6200, 0.5800,
        0.7200, 0.5600, 0.7000, 0.6400, 0.6800, 0.5800, 0.6400, 0.6800, 0.5800,
        0.6200, 0.5800], device='cuda:0')