In [None]:
import sys
sys.path.append('../..')

import warnings
warnings.filterwarnings('ignore')

from preprocess.process_dataset import get_dgl_graph
from preprocess.subgraph_extraction import extract_subgraph
from preprocess.graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, device, _get_simplices
from hodgelaplacians import HodgeLaplacians
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np

In [None]:
# Dataset selection: cat_edge_cooking with 20 classes.
# Note: the following cell assigns `dataset` / `num_classes` again, so on a
# top-to-bottom run this selection is overridden.
dataset, num_classes = 'cat_edge_cooking', 20

In [None]:
# Dataset selection: cat_edge_DAWN with 11 classes.
# This cell comes after the cat_edge_cooking one, so on a top-to-bottom run
# these are the values the rest of the notebook sees.
dataset, num_classes = 'cat_edge_DAWN', 11

In [None]:
# Load the graph for the selected dataset. `graph` is the object sampled from
# below (edge ids, endpoint lookup); `nx_graph` is only forwarded to
# get_simplicial_complex — presumably a NetworkX view of the same graph, TODO confirm.
(graph, nx_graph) = get_dgl_graph(dataset)

# Sampling

In [None]:
# Draw one edge id at random and look up its two endpoints.
eid = np.random.choice(graph.edges('eid'))
u, v = graph.find_edges(eid)

# The sampled edge itself, as an unordered node pair, is handed to
# get_embeddings as `to_remove`.
to_remove = frozenset((u.item(), v.item()))

# Enclosing subgraph around the two endpoints (h=4, capped at 60 nodes per hop).
subgraph = extract_subgraph(
    [u.item(), v.item()],
    graph,
    h=4,
    enclosing_sub_graph=True,
    max_nodes_per_hop=60,
)

# Turn the subgraph into a labelled simplicial complex, then into model inputs.
simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
embeddings, laplacians, boundaries, idx, label = get_embeddings(
    simplex_labels, to_remove, num_classes, dim=4
)

# Shape summary; entries that are None are reported as 0.
print('Laplacians:', [0 if laplacian is None else laplacian.shape for laplacian in laplacians])
print('Boundaries:', [0 if boundary is None else boundary.shape for boundary in boundaries])
print('Embeddings', [0 if embedding is None else embedding.shape for embedding in embeddings])
print('Label', label.shape)
print('Edge', idx)

# Model Training and Evaluation

In [None]:
from models.model import SimplicialModel1
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer for the training-loss curve. The log directory follows the
# selected `dataset`, so runs on different datasets do not end up in the same
# folder (it was previously pinned to the cat_edge_cooking directory even when
# another dataset was selected).
# NOTE(review): the base path is machine-specific — adjust it for your checkout.
writer = SummaryWriter(f'/home/adarsh/H-KGC/datasets/{dataset}/logs')

# Global step used for the TensorBoard scalars. It is initialised here rather
# than in the training cell, so re-running the training cell continues the
# existing curve instead of restarting at step 0.
gs = 0

# Model and optimiser; `dim=4` matches the `dim=4` passed to get_embeddings.
cm = SimplicialModel1(num_classes, dim=4, device=device).to(device)
optim = torch.optim.Adam(cm.parameters(), lr=1e-4)

In [None]:
# Put the model in training mode explicitly, so that re-running this cell after
# an evaluation pass that switched it to eval mode still trains correctly.
cm.train()

# 500 optimisation steps, each on a single freshly sampled edge.
for ep in tqdm(range(500)):
    # Sample a random edge and build the model inputs from its enclosing
    # subgraph (h=4, at most 60 nodes per hop).
    eid = np.random.choice(graph.edges('eid'))
    u, v = graph.find_edges(eid)
    to_remove = frozenset((u.item(), v.item()))
    subgraph = extract_subgraph([u.item(), v.item()], graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=60)
    simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
    embeddings, laplacians, boundaries, idx, label = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)

    # The literal `1` and `idx` are forwarded to the model as-is — presumably
    # the simplex order and the index of the sampled edge; TODO confirm against
    # SimplicialModel1.forward.
    pred = cm(embeddings, laplacians, boundaries, 1, idx).squeeze()
    loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)

    optim.zero_grad()
    loss.backward()
    optim.step()

    writer.add_scalar('train loss', loss.item(), gs)
    gs += 1

    # Release cached, unused GPU memory between samples.
    torch.cuda.empty_cache()

# The writer is never closed in this notebook; flush so the last buffered
# scalars are written to the event file right away.
writer.flush()

In [None]:
# Evaluate the model on 50 randomly sampled edges and collect, per sample, a
# 0/1 tensor marking where the thresholded prediction equals the label.
# NOTE(review): edges are drawn from the same `graph` that the training cell
# samples from, so this is not a held-out estimate.

# Switch to inference behaviour (dropout off, normalisation layers using their
# running statistics) for the duration of the evaluation, then restore the
# previous mode so a later training run is unaffected.
was_training = cm.training
cm.eval()
try:
    with torch.no_grad():
        H = []
        for ep in tqdm(range(50)):
            eid = np.random.choice(graph.edges('eid'))
            u, v = graph.find_edges(eid)
            to_remove = frozenset((u.item(), v.item()))
            subgraph = extract_subgraph([u.item(), v.item()], graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=60)
            simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
            embeddings, laplacians, boundaries, idx, label = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)

            pred = cm(embeddings, laplacians, boundaries, 1, idx).squeeze()
            # sigmoid -> probability, round -> hard 0/1 decision, compared element-wise to the label.
            H.append((torch.round(torch.sigmoid(pred))==label).long())
finally:
    cm.train(was_training)

### Test Accuracy

In [None]:
# Stack the per-sample 0/1 correctness tensors along a new leading axis, then
# average over that axis: the fraction of evaluated samples that were correct
# at each output position.
torch.stack(H).sum(dim=0) / len(H)