In [1]:
import sys
sys.path.append('../..')

import warnings
warnings.filterwarnings('ignore')

from preprocess.process_dataset import get_dgl_graph
from preprocess.subgraph_extraction import extract_subgraph
from preprocess.graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, _get_simplices, random_sample
from hodgelaplacians import HodgeLaplacians
from global_parameters import device
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np
import timeit
from torch.utils.data import Dataset, DataLoader
import random
import torch
import os

def set_seed(seed = 42):
    '''
        For Reproducibility: Sets the seed of the entire notebook.
    '''

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    # Sets a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)

# set_seed(2)

In [None]:
# Author mentions this dataset to have 10 classes. but the labels shows more than 1500 classes with severe class imbalance.
dataset = 'cat_edge_DAWN' # problem with dataset
num_classes = 10
simplex_order=4

In [2]:
# Our model performs best on this dataset. Results are available below
dataset = 'cat_edge_cooking'
num_classes = 20
simplex_order = 4
max_nodes = [10, 100, 150, 200] # parameters to control the sparsity of the sampled graph

In [2]:
# AUC > 0.5 for all classes but does not achieve the highest auc. The model with highest AUC has AUC < 0.5 for some classes.
dataset = 'cat_edge_MAG_10'
num_classes = 10
simplex_order=4
max_nodes = [10, 80, 100, 150]

In [30]:
# Our model performs best on this dataset. Results are available below
dataset = 'cat_edge_algebra_questions'
num_classes = 32
simplex_order = 4
max_nodes = [3, 20, 25, 30]

In [3]:
graph, nx_graph = get_dgl_graph(dataset)

# Sampling

> Hyperparameters for graph subsampling controlling the number of nodes sampled at each hop corresponding to each simplex order.

In [10]:
simplex, order, pos_label, neg_label = random_sample(dataset, num_classes=num_classes, max_dim=simplex_order, dim=1)
to_remove = frozenset(simplex)
subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=max_nodes[order])
isolated_nodes = ((subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)).nonzero().squeeze(1)
subgraph.remove_nodes(isolated_nodes)
pos_simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes, positive=True)
neg_simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes, positive=False)
pos_embeddings, laplacians, boundaries, idx = get_embeddings(pos_simplex_labels, to_remove, num_classes, dim=simplex_order)
neg_embeddings, _, _, _ = get_embeddings(neg_simplex_labels, to_remove, num_classes, dim=simplex_order)
print('Laplacians:',[ laplacian.shape if laplacian is not None else 0 for laplacian in laplacians])
print('Boundaries:',[ boundary.shape if boundary is not None else 0 for boundary in boundaries])
print('Embeddings',[ embedding.shape if embedding is not None else 0 for embedding in pos_embeddings])
print('Embeddings',[ embedding.shape if embedding is not None else 0 for embedding in neg_embeddings])
print('Label',pos_label)
print('Label',neg_label)
print('Order:',order, 'Index:',idx)
print('subgraph : ', subgraph.num_nodes())

Laplacians: [torch.Size([64, 64]), torch.Size([696, 696]), torch.Size([1344, 1344]), torch.Size([916, 916])]
Boundaries: [torch.Size([1, 64]), torch.Size([64, 696]), torch.Size([696, 1344]), torch.Size([1344, 916])]
Embeddings [torch.Size([64, 20]), torch.Size([696, 20]), torch.Size([1344, 20]), torch.Size([916, 20])]
Embeddings [torch.Size([64, 20]), torch.Size([696, 20]), torch.Size([1344, 20]), torch.Size([916, 20])]
Label tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0.])
Label tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])
Order: 1 Index: 50
subgraph :  64


In [8]:
_, neg_labels = _get_simplices(dataset, num_classes, True)

# Model Training and Evaluation

In [4]:
from models.model import SimplicialModel1, BaseGNN
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'/home/adarsh/H-KGC/datasets/{dataset}/logs')
gs = 0

cm = SimplicialModel1(num_classes, dim=simplex_order, device=device).to(device)
baseGnn = BaseGNN(num_classes, dim=simplex_order, device=device).to(device)


In [5]:
class MyDataset(Dataset):

    def __init__(self, dim=None) -> None:
        super().__init__()
        self.dim = dim

    def __getitem__(self, index):
        b = 1
        while b!=0:
            simplex, order, pos_label, neg_label = random_sample(dataset, num_classes=num_classes, max_dim=simplex_order, dim=self.dim)  # randomly sample 
            to_remove = frozenset(simplex)
            try:
                subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=max_nodes[order])
                isolated_nodes = ((subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)).nonzero().squeeze(1)
                subgraph.remove_nodes(isolated_nodes)
                pos_simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes, positive=True)
                neg_simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes, positive=False)
                pos_embeddings, laplacians, boundaries, idx = get_embeddings(pos_simplex_labels, to_remove, num_classes, dim=simplex_order)
                neg_embeddings, _, _, _ = get_embeddings(neg_simplex_labels, to_remove, num_classes, dim=simplex_order)
                b = 0
            except:
                pass
        return pos_embeddings, neg_embeddings, laplacians, boundaries, order, idx, pos_label, subgraph
    
    def __len__(self):
        return 10000

def custom_collate(X):
    return X[0]

dataloader = DataLoader(MyDataset(dim=1), batch_size=1, num_workers=16, collate_fn=custom_collate)

In [6]:
optim1 = torch.optim.Adam(cm.parameters())
optim2 = torch.optim.Adam(baseGnn.parameters())

In [7]:
gs = 0
timeout = 0
loss1 = 0 
loss2 = 0
ep = 0
with tqdm(dataloader) as tepoch:
    for pos_embeddings, neg_embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
        label, subgraph = label.to(device), subgraph.to(device)
        pos_embeddings = [ x.to(device) if x is not None else None for x in pos_embeddings]
        neg_embeddings = [ x.to(device) if x is not None else None for x in neg_embeddings]
        laplacians = [ x.to(device) if x is not None else None for x in laplacians]
        boundaries = [ x.to(device) if x is not None else None for x in boundaries]

        pos_pred = cm(pos_embeddings, laplacians, boundaries, order, idx, label).squeeze()
        neg_pred = cm(neg_embeddings, laplacians, boundaries, order, idx, torch.ones_like(pos_embeddings[0][0])).squeeze()
        loss1 += torch.nn.functional.binary_cross_entropy_with_logits(pos_pred, torch.ones_like(pos_pred)) + torch.nn.functional.binary_cross_entropy_with_logits(neg_pred, torch.zeros_like(neg_pred))
        # loss1 += torch.nn.functional.margin_ranking_loss(pos_pred, neg_pred, label, margin=10, reduction='mean')
        
        subgraph = subgraph.to(device)
        pos_pred = baseGnn(subgraph, pos_embeddings[0], order, label).squeeze()
        neg_pred = baseGnn(subgraph, neg_embeddings[0], order, torch.ones_like(pos_embeddings[0][0])).squeeze()
        loss2 += torch.nn.functional.binary_cross_entropy_with_logits(pos_pred, torch.ones_like(pos_pred)) + torch.nn.functional.binary_cross_entropy_with_logits(neg_pred, torch.zeros_like(neg_pred))
        # loss2 += torch.nn.functional.margin_ranking_loss(pos_pred, neg_pred, label, margin=10, reduction='mean')
        ep+=1
        
        if ep%32==0:
            loss1 = loss1 / (2*32)
            loss2 = loss2 / (2*32)
            optim1.zero_grad()
            loss1.backward() 
            optim1.step()
            optim2.zero_grad()
            loss2.backward()
            optim2.step()
            writer.add_scalars('Train Loss',{'Simplicial CNN': loss1.item(), 'Vanilla GNN': loss2.item()}, gs)

            loss1 = 0
            loss2 = 0
            gs+=1
        torch.cuda.empty_cache()


100%|██████████| 10000/10000 [18:43<00:00,  8.90it/s] 


> The following analysis established increasing the density of sampled subgraph increases the autograd function's time. \
> The preprocessing steps are efficient enough

In [None]:
import cProfile
import pstats

# profile = cProfile.Profile()
# profile.runcall(train)
# ps = pstats.Stats(profile)
# ps.print_stats()

In [8]:
with torch.no_grad():
    ep = 0
    H0 = []
    H1 = []
    H2 = []
    H3 = []
    labels0 = []
    labels1 = []
    with tqdm(dataloader) as tepoch:
        for pos_embeddings, neg_embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
            label, subgraph = label.to(device), subgraph.to(device)
            pos_embeddings = [ x.to(device) if x is not None else None for x in pos_embeddings]
            neg_embeddings = [ x.to(device) if x is not None else None for x in neg_embeddings]
            laplacians = [ x.to(device) if x is not None else None for x in laplacians]
            boundaries = [ x.to(device) if x is not None else None for x in boundaries]
            # try:
            if order > 0: # makes no sense to perform node classification since node embedding will be 0.
                pred0 = (torch.sum(pos_embeddings[0][:order+1], dim=0)!=0).long().squeeze()
                pred1 = (torch.prod(pos_embeddings[0][:order+1], dim=0)!=0).long().squeeze()
                H0.append((pred0==label).long())
                H1.append((pred1==label).long())
                labels0.append(label)
            pred2_pos = baseGnn(subgraph, pos_embeddings[0], order, label)
            pred2_neg = baseGnn(subgraph, neg_embeddings[0], order, label)
            pred3_pos = cm(pos_embeddings, laplacians, boundaries, order, idx, label)
            pred3_neg = cm(neg_embeddings, laplacians, boundaries, order, idx, label)
            # H1.append((torch.round(torch.sigmoid(pred1))==label).long())
            # H2.append((torch.round(torch.sigmoid(pred2))==label).long())
            H2.append(pred2_pos)
            H2.append(pred2_neg)
            H3.append(pred3_pos)
            H3.append(pred3_neg)
            labels1.append(torch.ones_like(pred2_pos))
            labels1.append(torch.zeros_like(pred2_neg))
            ep += 1
            torch.cuda.empty_cache()
            if ep>10000:
                break
            # except:
            #     pass

100%|██████████| 10000/10000 [17:28<00:00,  9.54it/s] 


In [9]:
H0 = torch.stack(H0)

In [10]:
H1 = torch.stack(H1)

In [11]:
labels0 = torch.stack(labels0)

In [12]:
H2 = torch.cat(H2)

In [13]:
H3 = torch.cat(H3)

In [14]:
labels1 = torch.cat(labels1)

In [9]:
H0, H1, H2, H3, labels0, labels1 = torch.stack(H0), torch.stack(H1), torch.stack(H2), torch.stack(H3), torch.stack(labels0), torch.stack(labels1)

RuntimeError: stack expects each tensor to be equal size, but got [1, 1] at entry 0 and [8, 1] at entry 2

## Test Accuracy

In [15]:
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score

### AUC is better with Simplicial GNN

* Cooking dataset

In [16]:
# AUC with union operation
A = roc_auc_score(labels0.cpu().numpy(), H0.cpu().numpy(), average='weighted')
A

0.5231242078102861

In [17]:
# AUC with intersection operation
B = roc_auc_score(labels0.cpu().numpy(), H1.cpu().numpy(),  average='weighted')
B

0.12847889217519948

In [18]:
# AUC with vanilla GNN
C = roc_auc_score(labels1.cpu().numpy(), H2.cpu().numpy(),  average='weighted')
C

0.9864928966053979

In [19]:
# AUC with Simplicial Relational CNN with attention
D = roc_auc_score(labels1.cpu().numpy(), H3.cpu().numpy(),  average='weighted')
D

0.9976887812425005

* MAG 10 dataset

In [25]:
# AUC with union operation
A = roc_auc_score(labels0.cpu().numpy(), H0.cpu().numpy(),  average='weighted')
A

0.5414017346893194

In [26]:
# AUC with intersection operation
B = roc_auc_score(labels0.cpu().numpy(), H1.cpu().numpy(),  average='weighted')
B

0.20451034521941505

In [27]:
# AUC with vanilla GNN
C = roc_auc_score(labels1.cpu().numpy(), H2.cpu().numpy(),  average='weighted')
C

0.47676917645963046

In [28]:
# AUC with Simplicial Relational CNN with attention
D = roc_auc_score(labels1.cpu().numpy(), H3.cpu().numpy(),  average='weighted')
D

0.5840068062314466

* Algebra dataset

In [39]:
mask = (labels0.sum(dim=0)!=0)
labels0 = labels0[:,mask]
labels1 = labels1[:,mask]
H0 = H0[:,mask]
H1 = H1[:,mask]
H2 = H2[:,mask]
H3 = H3[:,mask]

In [40]:
# AUC with union operation
A = roc_auc_score(labels0.cpu().numpy(), H0.cpu().numpy(),  average='weighted')
A

0.6301262673258093

In [41]:
# AUC with intersection operation
B = roc_auc_score(labels0.cpu().numpy(), H1.cpu().numpy(),  average='weighted')
B

0.11706815805412159

In [42]:
# AUC with vanilla GNN
C = roc_auc_score(labels1.cpu().numpy(), H2.cpu().numpy(),  average='weighted')
C

0.4858226434323319

In [43]:
# AUC with Simplicial Relational CNN with attention
D = roc_auc_score(labels1.cpu().numpy(), H3.cpu().numpy(),  average='weighted')
D

0.6857028048070942