In [None]:
import sys
sys.path.append('../..')

import warnings
warnings.filterwarnings('ignore')

from preprocess.process_dataset import get_dgl_graph
from preprocess.subgraph_extraction import extract_subgraph
from preprocess.graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, _get_simplices, random_sample
from hodgelaplacians import HodgeLaplacians
from global_parameters import device
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np
import timeit
from torch.utils.data import Dataset, DataLoader

In [None]:
# Dataset configuration (DAWN). On a top-to-bottom run the next cell reassigns
# both names, so only one of the two configuration cells is effective.
dataset, num_classes = 'cat_edge_DAWN', 11

In [None]:
# Dataset configuration (cooking). This cell runs after the DAWN cell, so on a
# top-to-bottom run these are the values used by every cell below.
dataset, num_classes = 'cat_edge_cooking', 20

In [None]:
# Load the graph for the configured `dataset`. `graph` is later passed to
# extract_subgraph and get_simplicial_complex; `nx_graph` to get_simplicial_complex.
graph, nx_graph = get_dgl_graph(dataset)

# Sampling

> Hyperparameters for graph subsampling: `max_nodes[k]` is the maximum number of nodes sampled at each hop when the target simplex has order `k`.

In [None]:
# Per-hop node budget for subgraph extraction, indexed by the order of the sampled
# simplex (used as `max_nodes_per_hop=max_nodes[order]` below).
# NOTE(review): only orders 0-3 have an entry; confirm random_sample(..., max_dim=4)
# never returns order 4, otherwise indexing raises IndexError.
max_nodes = [
    10,   # order 0
    100,  # order 1
    150,  # order 2
    200,  # order 3
]

In [None]:
# Sanity check: run the same preprocessing pipeline that MyDataset.__getitem__ uses
# on one randomly drawn example, and print the sizes of what it produces.
simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=4)
# Set of the sampled simplex's elements, handed to get_embeddings (presumably so the
# target simplex is excluded from the inputs — TODO confirm get_embeddings semantics).
to_remove = frozenset(simplex)
# Enclosing subgraph of up to 4 hops; the per-hop node budget depends on the order.
subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=max_nodes[order])
# Drop nodes that have neither incoming nor outgoing edges from `subgraph`.
isolated_nodes = ((subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)).nonzero().squeeze(1)
subgraph.remove_nodes(isolated_nodes)
simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)
# One entry per list element (presumably per simplex dimension); 0 marks a None entry.
print('Laplacians:',[ laplacian.shape if laplacian is not None else 0 for laplacian in laplacians])
print('Boundaries:',[ boundary.shape if boundary is not None else 0 for boundary in boundaries])
print('Embeddings',[ embedding.shape if embedding is not None else 0 for embedding in embeddings])
print('Label',label)
print('Order:',order, 'Index:',idx)
print('subgraph : ', subgraph.num_nodes())

# Model Training and Evaluation

In [None]:
from models.model import SimplicialModel1, BaseGNN
import os
from pathlib import Path
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logs are written to <repo root>/datasets/<dataset>/logs.
# NOTE(review): the repo root defaults to '../..' (the same path appended to
# sys.path in the first cell); set HKGC_ROOT to point somewhere else.
REPO_ROOT = Path(os.environ.get('HKGC_ROOT', '../..'))
LOG_DIR = REPO_ROOT / 'datasets' / dataset / 'logs'
writer = SummaryWriter(str(LOG_DIR))
gs = 0

# Seed torch so model initialisation is reproducible. This does not necessarily
# seed whatever RNG random_sample uses.
SEED = 0
torch.manual_seed(SEED)

# Optimizer hyperparameters shared by both models so the comparison is like-for-like.
LR = 1e-5
BETAS = (0.8, 0.9)

cm = SimplicialModel1(num_classes, dim=4, device=device).to(device)
baseGnn = BaseGNN(num_classes, dim=4, device=device).to(device)
optim1 = torch.optim.Adam(cm.parameters(), lr=LR, betas=BETAS)
optim2 = torch.optim.Adam(baseGnn.parameters(), lr=LR, betas=BETAS)

In [None]:
class MyDataset(Dataset):
    """Map-style dataset that draws a fresh random example on every access.

    The index passed to ``__getitem__`` is ignored. Each call samples a simplex with
    ``random_sample``, extracts its enclosing subgraph, builds the simplicial complex,
    and returns the tensors produced by ``get_embeddings``. The dataset relies on the
    notebook-level globals ``dataset``, ``num_classes``, ``graph``, ``nx_graph`` and
    ``max_nodes``.

    Args:
        length: value reported by ``__len__`` (number of samples per pass).
        max_attempts: how many samples to try when preprocessing raises, before
            giving up with a ``RuntimeError`` chained to the last error.
    """

    def __init__(self, length=5000, max_attempts=1000):
        super().__init__()
        self.length = length
        self.max_attempts = max_attempts

    def _sample(self):
        """Sample one simplex; return (to_remove, order, label, subgraph, simplex_labels)."""
        simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=4)
        to_remove = frozenset(simplex)
        subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=max_nodes[order])
        # Drop nodes that have neither incoming nor outgoing edges.
        isolated_nodes = ((subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)).nonzero().squeeze(1)
        subgraph.remove_nodes(isolated_nodes)
        simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
        return to_remove, order, label, subgraph, simplex_labels

    def __getitem__(self, index):
        """Return (embeddings, laplacians, boundaries, order, idx, label, subgraph)."""
        last_error = None
        for _ in range(self.max_attempts):
            try:
                to_remove, order, label, subgraph, simplex_labels = self._sample()
                break
            except Exception as err:
                # Some random samples fail during preprocessing; retry with a new one,
                # but remember the error so a persistent failure is diagnosable.
                last_error = err
        else:
            raise RuntimeError(
                f'Failed to sample a valid example after {self.max_attempts} attempts'
            ) from last_error
        embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)
        return embeddings, laplacians, boundaries, order, idx, label, subgraph

    def __len__(self):
        return self.length

def custom_collate(X):
    """Unwrap a batch of size 1 into its single sample.

    The DataLoader below uses batch_size=1; any other elements would be dropped.
    Presumably this avoids default collation of the DGL graph and None entries.
    """
    return X[0]

# batch_size must stay 1 because custom_collate keeps only the first sample.
dataloader = DataLoader(MyDataset(), batch_size=1, num_workers=16, collate_fn=custom_collate)

In [None]:
# Train both models side by side for one pass over `dataloader` (one example per
# step), logging both losses to TensorBoard at the shared step counter `gs`.
gs = 0
cm.train()
baseGnn.train()
with tqdm(dataloader) as tepoch:
    for embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
        label, subgraph = label.to(device), subgraph.to(device)
        # Entries may be None; those stay None.
        embeddings = [x.to(device) if x is not None else None for x in embeddings]
        laplacians = [x.to(device) if x is not None else None for x in laplacians]
        boundaries = [x.to(device) if x is not None else None for x in boundaries]

        try:
            pred = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
            loss1 = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)

            # The baseline GNN receives only the subgraph and embeddings[0].
            pred = baseGnn(subgraph, embeddings[0], order).squeeze()
            loss2 = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)

            optim1.zero_grad()
            loss1.backward()
            optim1.step()
            optim2.zero_grad()
            loss2.backward()
            optim2.step()
            writer.add_scalars('Train Loss', {'Simplicial CNN': loss1.item(), 'Vanilla GNN': loss2.item()}, gs)

            gs += 1
            torch.cuda.empty_cache()
        except Exception as err:
            # Stop at the first failing step, but say why: a bare `except:` hid the
            # cause and also swallowed KeyboardInterrupt.
            print(f'Training stopped at step {gs}: {err!r}')
            break
# Make sure every logged scalar reaches disk before inspecting TensorBoard.
writer.flush()

> The following analysis established that increasing the density of the sampled subgraph increases the time spent in autograd. \
> The preprocessing steps are efficient enough.

In [None]:
import cProfile
import pstats

# Profiling harness for the training step (currently disabled).
# NOTE(review): no function named `train` is defined in the visible notebook; wrap
# the training loop in a `train()` function before uncommenting these lines.
# profile = cProfile.Profile()
# profile.runcall(train)
# ps = pstats.Stats(profile)
# ps.print_stats()

In [None]:
# Evaluate both models on N_EVAL freshly drawn samples. For each sample, H1 / H2 get
# a 0/1 tensor marking where the rounded sigmoid prediction of SimplicialModel1 (`cm`)
# / BaseGNN (`baseGnn`) equals the label.
# NOTE(review): this iterates the same `dataloader` as training, i.e. the same
# random_sample(...) distribution; unless random_sample draws from a held-out split,
# these are not test-set numbers — confirm.
N_EVAL = 1000
cm_was_training, base_was_training = cm.training, baseGnn.training
# Disable training-only behaviour (e.g. dropout) while evaluating.
cm.eval()
baseGnn.eval()
skipped = 0
last_eval_error = None
try:
    with torch.no_grad():
        ep = 0
        H1 = []
        H2 = []
        with tqdm(dataloader) as tepoch:
            for embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
                label, subgraph = label.to(device), subgraph.to(device)
                embeddings = [x.to(device) if x is not None else None for x in embeddings]
                laplacians = [x.to(device) if x is not None else None for x in laplacians]
                boundaries = [x.to(device) if x is not None else None for x in boundaries]
                try:
                    pred1 = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
                    pred2 = baseGnn(subgraph, embeddings[0], order).squeeze()
                    correct1 = (torch.round(torch.sigmoid(pred1)) == label).long()
                    correct2 = (torch.round(torch.sigmoid(pred2)) == label).long()
                    # Append only after both succeed so H1 and H2 stay aligned.
                    H1.append(correct1)
                    H2.append(correct2)
                    ep += 1
                    torch.cuda.empty_cache()
                    if ep >= N_EVAL:
                        break
                except Exception as err:
                    # Best-effort: skip samples a model cannot process, but count them
                    # (a bare `except: pass` hid them and swallowed Ctrl-C too).
                    skipped += 1
                    last_eval_error = err
finally:
    # Restore whatever train/eval mode the models were in before this cell.
    cm.train(cm_was_training)
    baseGnn.train(base_was_training)
if skipped:
    print(f'Skipped {skipped} evaluation samples; last error: {last_eval_error!r}')

### Test Accuracy

In [None]:
# BaseGNN accuracy: fraction of evaluation samples whose rounded prediction matched
# the label, computed separately for each element of the prediction vector.
assert H2, 'H2 is empty: the evaluation loop collected no BaseGNN predictions.'
A = torch.sum(torch.stack(H2), dim=0) / len(H2)

In [None]:
# SimplicialModel1 accuracy: fraction of evaluation samples whose rounded prediction
# matched the label, computed separately for each element of the prediction vector.
assert H1, 'H1 is empty: the evaluation loop collected no SimplicialModel1 predictions.'
B = torch.sum(torch.stack(H1), dim=0) / len(H1)

In [None]:
# Element-wise accuracy gap: positive entries favour SimplicialModel1 (B) over BaseGNN (A).
torch.sub(B, A)