In [22]:
import sys
sys.path.append('../..')

import warnings
warnings.filterwarnings('ignore')

from preprocess.process_dataset import get_dgl_graph
from preprocess.subgraph_extraction import extract_subgraph
from preprocess.graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, _get_simplices, random_sample
from hodgelaplacians import HodgeLaplacians
from global_parameters import device
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np
import timeit
from torch.utils.data import Dataset, DataLoader

from ignite.contrib.metrics import ROC_AUC

In [None]:
# Alternative configuration ('cat_edge_DAWN', 11 classes). When the notebook is
# run top to bottom, the following cell reassigns both names.
dataset = 'cat_edge_DAWN'
num_classes = 11

In [2]:
# Active configuration: `dataset` is the name handed to get_dgl_graph and
# random_sample; `num_classes` is passed to the sampling helpers and both models.
dataset = 'cat_edge_cooking'
num_classes = 20

In [3]:
# Load the two graph objects for `dataset`; the log lines below show that both
# came from a cache in this run.
graph, nx_graph = get_dgl_graph(dataset)

> loading nx_graph from cache
> loading dgl_graph from cache


# Sampling

> Hyperparameters for graph subsampling: the maximum number of nodes sampled at each hop, with one value per simplex order.

In [4]:
# Node budget per simplex order: max_nodes[order] is passed to extract_subgraph
# as max_nodes_per_hop.
max_nodes = [10, 100, 150, 200]

In [None]:
# One-off sanity check: draw a single random simplex, extract its enclosing
# subgraph, and print the shapes of the resulting Laplacians, boundaries and
# embeddings.
simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=4)
to_remove = frozenset(simplex)
subgraph = extract_subgraph(
    simplex,
    graph,
    h=4,
    enclosing_sub_graph=True,
    max_nodes_per_hop=max_nodes[order],
)

# Drop nodes that have neither incoming nor outgoing edges.
has_no_in_edges = subgraph.in_degrees() == 0
has_no_out_edges = subgraph.out_degrees() == 0
isolated_nodes = (has_no_in_edges & has_no_out_edges).nonzero().squeeze(1)
subgraph.remove_nodes(isolated_nodes)

simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)

# A missing (None) entry is reported as 0 instead of a shape.
for caption, tensors in (
    ('Laplacians:', laplacians),
    ('Boundaries:', boundaries),
    ('Embeddings', embeddings),
):
    print(caption, [0 if t is None else t.shape for t in tensors])
print('Label', label)
print('Order:', order, 'Index:', idx)
print('subgraph : ', subgraph.num_nodes())

# Model Training and Evaluation

In [5]:
from models.model import SimplicialModel1, BaseGNN
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

import os

# Root directory that holds one sub-directory per dataset. Set the
# HKGC_DATA_ROOT environment variable to relocate it; the fallback is the
# location this notebook was originally run from.
data_root = os.environ.get('HKGC_DATA_ROOT', '/home/adarsh/H-KGC/datasets')

# TensorBoard writer that receives the per-step training losses.
writer = SummaryWriter(os.path.join(data_root, dataset, 'logs'))
gs = 0  # global step used as the x-axis when logging scalars

# The two models under comparison, built with identical settings.
cm = SimplicialModel1(num_classes, dim=4, device=device).to(device)
baseGnn = BaseGNN(num_classes, dim=4, device=device).to(device)


In [6]:
class MyDataset(Dataset):
    """Dataset that produces a freshly sampled simplex on every access.

    The requested ``index`` is ignored: each item is drawn with
    ``random_sample`` and turned into the inputs both models consume. The
    notebook-level names ``dataset``, ``num_classes``, ``graph``, ``nx_graph``
    and ``max_nodes`` are read at access time.
    """

    # Upper bound on consecutive failed sampling attempts for a single item.
    max_retries = 100

    def __getitem__(self, index):
        """Sample one simplex and build its model inputs.

        Sampling / subgraph extraction is retried when it raises, up to
        ``max_retries`` times.

        Returns:
            Tuple ``(embeddings, laplacians, boundaries, order, idx, label, subgraph)``.

        Raises:
            RuntimeError: if every attempt failed; the last underlying error is
                attached as ``__cause__``.
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=4)
                to_remove = frozenset(simplex)
                subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=max_nodes[order])
                # Drop nodes that have neither incoming nor outgoing edges.
                isolated_nodes = ((subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)).nonzero().squeeze(1)
                subgraph.remove_nodes(isolated_nodes)
                simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
                break
            except Exception as err:
                # Only ordinary errors trigger a retry; KeyboardInterrupt and
                # SystemExit are not subclasses of Exception and propagate.
                last_error = err
        else:
            raise RuntimeError(
                f'could not sample a valid simplex after {self.max_retries} attempts'
            ) from last_error

        embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=4)
        return embeddings, laplacians, boundaries, order, idx, label, subgraph

    def __len__(self):
        """Nominal number of samples per pass over the dataset."""
        return 5000

def custom_collate(X):
    """Return the first sample of batch ``X`` unchanged, without any tensor stacking."""
    first_sample = X[0]
    return first_sample

# batch_size=1 combined with custom_collate means every iteration yields one
# raw sample tuple from MyDataset.
dataloader = DataLoader(
    dataset=MyDataset(),
    batch_size=1,
    num_workers=16,
    collate_fn=custom_collate,
)

In [7]:
# Both models are optimised with the same Adam settings.
adam_settings = {'lr': 1e-5, 'betas': (0.8, 0.9)}
optim1 = torch.optim.Adam(cm.parameters(), **adam_settings)       # simplicial model
optim2 = torch.optim.Adam(baseGnn.parameters(), **adam_settings)  # baseline GNN

In [8]:
# Train both models side by side: every sample drawn from `dataloader` is fed to
# the simplicial model and to the baseline GNN, each taking one optimiser step.
gs = 0  # global step for the TensorBoard x-axis
train_error = None  # set when the loop stops early because a step raised

# Put both models in training mode explicitly, in case they were switched to
# eval mode earlier in the session.
cm.train()
baseGnn.train()

with tqdm(dataloader) as tepoch:
    for embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
        label, subgraph = label.to(device), subgraph.to(device)
        embeddings = [ x.to(device) if x is not None else None for x in embeddings]
        laplacians = [ x.to(device) if x is not None else None for x in laplacians]
        boundaries = [ x.to(device) if x is not None else None for x in boundaries]

        try:
            pred = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
            loss1 = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)

            pred = baseGnn(subgraph, embeddings[0], order).squeeze()
            loss2 = torch.nn.functional.binary_cross_entropy_with_logits(pred, label)

            optim1.zero_grad()
            loss1.backward()
            optim1.step()
            optim2.zero_grad()
            loss2.backward()
            optim2.step()
            writer.add_scalars('Train Loss',{'Simplicial CNN': loss1.item(), 'Vanilla GNN': loss2.item()}, gs)

            gs += 1
            torch.cuda.empty_cache()
        except Exception as err:
            # Training still stops at the first failing step, but the reason is
            # kept for the report below. KeyboardInterrupt is not an Exception
            # subclass and propagates normally.
            train_error = err
            break

if train_error is not None:
    print(f'Training stopped early at step {gs}: {type(train_error).__name__}: {train_error}')

100%|██████████| 5000/5000 [18:26<00:00,  4.52it/s]  


> The following analysis established that increasing the density of the sampled subgraph increases the time spent in the autograd function. \
> The preprocessing steps are efficient enough.

In [None]:
import cProfile
import pstats

# profile = cProfile.Profile()
# profile.runcall(train)
# ps = pstats.Stats(profile)
# ps.print_stats()

In [28]:
# Score both models on freshly sampled simplices and collect per-class
# probabilities together with the matching labels.
# NOTE(review): samples come from the same random `dataloader` used for
# training; no held-out split is visible in this notebook — confirm.
num_eval_samples = 1000  # stop once this many samples have been scored

# Switch to inference mode so layers that behave differently during training
# (e.g. dropout, batch norm, if the models use any) are evaluated correctly.
cm.eval()
baseGnn.eval()

with torch.no_grad():
    ep = 0
    num_failed = 0
    first_error = None
    H1 = []
    H2 = []
    labels = []
    with tqdm(dataloader) as tepoch:
        for embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
            label, subgraph = label.to(device), subgraph.to(device)
            embeddings = [ x.to(device) if x is not None else None for x in embeddings]
            laplacians = [ x.to(device) if x is not None else None for x in laplacians]
            boundaries = [ x.to(device) if x is not None else None for x in boundaries]
            try:
                pred1 = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
                pred2 = baseGnn(subgraph, embeddings[0], order).squeeze()
                prob1 = torch.sigmoid(pred1)
                prob2 = torch.sigmoid(pred2)
            except Exception as err:
                # Skip samples a model cannot process, but keep count so the
                # skips are visible in the report below.
                num_failed += 1
                if first_error is None:
                    first_error = err
                continue

            # Append only after both predictions succeeded so that H1, H2 and
            # labels always stay aligned.
            H1.append(prob1)
            H2.append(prob2)
            labels.append(label)
            ep += 1
            torch.cuda.empty_cache()
            if ep >= num_eval_samples:
                break

if num_failed:
    print(f'Skipped {num_failed} samples; first error: {type(first_error).__name__}: {first_error}')

 20%|██        | 1000/5000 [03:47<15:09,  4.40it/s]


In [29]:
# Turn the three per-sample lists into tensors; torch.stack adds a new leading
# dimension that indexes the samples.
H1, H2, labels = map(torch.stack, (H1, H2, labels))

## Test Accuracy

In [64]:
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score

### AUC is better with Simplicial GNN on 19 of the 20 classes (all except class index 6)

In [70]:
# Per-class ROC AUC of the simplicial model; average=None keeps one score per class.
A = roc_auc_score(
    y_true=labels.cpu().numpy(),
    y_score=H1.cpu().numpy(),
    average=None,
)
A

array([0.85362325, 0.83792439, 0.8409061 , 0.83943522, 0.83748793,
       0.82497818, 0.29355887, 0.83479952, 0.8181412 , 0.83709467,
       0.83870028, 0.83820453, 0.84192304, 0.84744291, 0.80389505,
       0.83329562, 0.85313566, 0.82802221, 0.85743802, 0.75849401])

In [71]:
# Per-class ROC AUC of the baseline GNN; average=None keeps one score per class.
B = roc_auc_score(
    y_true=labels.cpu().numpy(),
    y_score=H2.cpu().numpy(),
    average=None,
)
B

array([0.45680373, 0.35913932, 0.41404994, 0.31628856, 0.40753399,
       0.40301593, 0.61055588, 0.35659405, 0.4929639 , 0.42093393,
       0.49317433, 0.3385974 , 0.4762268 , 0.46380069, 0.38009829,
       0.40822307, 0.38902195, 0.50085343, 0.3597665 , 0.68225275])

In [69]:
# Per-class AUC gap: positive entries favour the simplicial model, negative
# entries the baseline GNN.
A - B

array([ 0.39681953,  0.47878507,  0.42685616,  0.52314665,  0.42995393,
        0.42196225, -0.316997  ,  0.47820547,  0.3251773 ,  0.41616074,
        0.34552594,  0.49960712,  0.36569625,  0.38364223,  0.42379676,
        0.42507254,  0.46411371,  0.32716878,  0.49767152,  0.07624126])