In [1]:
import sys
sys.path.append('../..')

import warnings
warnings.filterwarnings('ignore')

from preprocess.process_dataset import get_dgl_graph
from preprocess.subgraph_extraction import extract_subgraph
from preprocess.graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, _get_simplices, random_sample
from hodgelaplacians import HodgeLaplacians
from global_parameters import device
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np
import timeit
from torch.utils.data import Dataset, DataLoader
import random
import torch
import os

def set_seed(seed = 42):
    '''
        For Reproducibility: Sets the seed of the entire notebook.
    '''

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
    # Sets a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(2)

In [2]:
# Author mentions this dataset to have 10 classes. but the labels shows more than 1500 classes with severe class imbalance.
dataset = 'cat_edge_DAWN' # problem with dataset
num_classes = 10
simplex_order=4

In [2]:
# Our model performs best on this dataset. Results are available below
dataset = 'cat_edge_cooking'
num_classes = 20
simplex_order=4
max_nodes = [10, 100, 150, 200] # parameters to control the sparsity of the sampled graph

In [2]:
# AUC > 0.5 for all classes but does not achieve the highest auc. The model with highest AUC has AUC < 0.5 for some classes.
dataset = 'cat_edge_MAG_10'
num_classes = 10
simplex_order=4
max_nodes = [10, 80, 100, 150]

In [39]:
# Our model performs best on this dataset. Results are available below
dataset = 'cat_edge_algebra_questions'
num_classes = 32
simplex_order = 4
max_nodes = [3, 20, 25, 30]

In [3]:
graph, nx_graph = get_dgl_graph(dataset)

> loading cat_edge_MAG_10 dataset
> loading nx_graph from cache
> loading dgl_graph from cache


# Sampling

> Hyperparameters for graph subsampling controlling the number of nodes sampled at each hop corresponding to each simplex order.

In [None]:
simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=simplex_order)
to_remove = frozenset(simplex)
subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=max_nodes[order])
isolated_nodes = ((subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)).nonzero().squeeze(1)
subgraph.remove_nodes(isolated_nodes)
simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=simplex_order)
print('Laplacians:',[ laplacian.shape if laplacian is not None else 0 for laplacian in laplacians])
print('Boundaries:',[ boundary.shape if boundary is not None else 0 for boundary in boundaries])
print('Embeddings',[ embedding.shape if embedding is not None else 0 for embedding in embeddings])
print('Label',label)
print('Order:',order, 'Index:',idx)
print('subgraph : ', subgraph.num_nodes())

# Model Training and Evaluation

In [4]:
from models.model import SimplicialModel1, BaseGNN
import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'/home/adarsh/H-KGC/datasets/{dataset}/logs')
gs = 0

cm = SimplicialModel1(num_classes, dim=simplex_order, device=device).to(device)
baseGnn = BaseGNN(num_classes, dim=simplex_order, device=device).to(device)


In [5]:
class MyDataset(Dataset):

    def __init__(self) -> None:
        super().__init__()

    def __getitem__(self, index):
        b = 1
        while b!=0:
            simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=simplex_order)  # randomly sample 
            to_remove = frozenset(simplex)
            try:
                subgraph = extract_subgraph(simplex, graph, h=4, enclosing_sub_graph=True, max_nodes_per_hop=max_nodes[order])
                isolated_nodes = ((subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)).nonzero().squeeze(1)
                subgraph.remove_nodes(isolated_nodes)
                simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
                embeddings, laplacians, boundaries, idx = get_embeddings(simplex_labels, to_remove, num_classes, dim=simplex_order)
                b = 0
            except:
                pass
        return embeddings, laplacians, boundaries, order, idx, label, subgraph
    
    def __len__(self):
        return 10000

def custom_collate(X):
    return X[0]

dataloader = DataLoader(MyDataset(), batch_size=1, num_workers=12, collate_fn=custom_collate)

In [6]:
optim1 = torch.optim.Adam(cm.parameters(), lr=1e-5, betas=(0.8, 0.9))
optim2 = torch.optim.Adam(baseGnn.parameters(), lr=1e-5, betas=(0.8, 0.9))

In [7]:
gs = 0
timeout = 0
loss1 = 0 
loss2 = 0
ep = 0
with tqdm(dataloader) as tepoch:
    for embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
        label, subgraph = label.to(device), subgraph.to(device)
        embeddings = [ x.to(device) if x is not None else None for x in embeddings]
        laplacians = [ x.to(device) if x is not None else None for x in laplacians]
        boundaries = [ x.to(device) if x is not None else None for x in boundaries]

        try:
            pred = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
            loss1 += torch.nn.functional.binary_cross_entropy_with_logits(pred, label, reduction='sum')
            
            
            subgraph = subgraph.to(device)
            pred = baseGnn(subgraph, embeddings[0], order).squeeze()
            loss2 += torch.nn.functional.binary_cross_entropy_with_logits(pred, label, reduction='sum')

            # if ep%4==0:
            optim1.zero_grad()
            loss1.backward() 
            optim1.step()
            optim2.zero_grad()
            loss2.backward()
            optim2.step()
            writer.add_scalars('Train Loss',{'Simplicial CNN': loss1.item(), 'Vanilla GNN': loss2.item()}, gs)

            gs+=1
            torch.cuda.empty_cache()

            loss1 = 0
            loss2 = 0
        except:
            pass

100%|██████████| 10000/10000 [26:30<00:00,  6.29it/s] 


> The following analysis established increasing the density of sampled subgraph increases the autograd function's time. \
> The preprocessing steps are efficient enough

In [None]:
import cProfile
import pstats

# profile = cProfile.Profile()
# profile.runcall(train)
# ps = pstats.Stats(profile)
# ps.print_stats()

In [8]:
with torch.no_grad():
    ep = 0
    H0 = []
    H1 = []
    H2 = []
    H3 = []
    labels0 = []
    labels1 = []
    with tqdm(dataloader) as tepoch:
        for embeddings, laplacians, boundaries, order, idx, label, subgraph in tepoch:
            label, subgraph = label.to(device), subgraph.to(device)
            embeddings = [ x.to(device) if x is not None else None for x in embeddings]
            laplacians = [ x.to(device) if x is not None else None for x in laplacians]
            boundaries = [ x.to(device) if x is not None else None for x in boundaries]
            try:
                if order > 0: # makes no sense to perform node classification since node embedding will be 0.
                    pred0 = (torch.sum(embeddings[0][:order+1], dim=0)!=0).long().squeeze()
                    pred1 = (torch.prod(embeddings[0][:order+1], dim=0)!=0).long().squeeze()
                    H0.append((pred0==label).long())
                    H1.append((pred1==label).long())
                    labels0.append(label)
                pred2 = baseGnn(subgraph, embeddings[0], order).squeeze()
                pred3 = cm(embeddings, laplacians, boundaries, order, idx).squeeze()
                # H1.append((torch.round(torch.sigmoid(pred1))==label).long())
                # H2.append((torch.round(torch.sigmoid(pred2))==label).long())
                H2.append(torch.sigmoid(pred2))
                H3.append(torch.sigmoid(pred3))
                labels1.append(label)
                ep += 1
                torch.cuda.empty_cache()
                if ep>5000:
                    break
            except:
                pass

 50%|█████     | 5000/10000 [13:07<13:07,  6.35it/s]  


In [9]:
H0, H1, H2, H3, labels0, labels1 = torch.stack(H0), torch.stack(H1), torch.stack(H2), torch.stack(H3), torch.stack(labels0), torch.stack(labels1)

## Test Accuracy

In [10]:
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score

### AUC is better with Simplicial GNN

* Cooking dataset

In [12]:
# AUC with union operation
A = roc_auc_score(labels0.cpu().numpy(), H0.cpu().numpy(), average=None)
A

array([0.51911765, 0.51687063, 0.64657611, 0.56582912, 0.53607287,
       0.79921146, 0.89523652, 0.63794053, 0.8654661 , 0.53434365,
       0.70567372, 0.57837271, 0.7574628 , 0.6710991 , 0.80998623,
       0.62403343, 0.68920112, 0.72598578, 0.59171389, 0.89688298])

In [13]:
# AUC with intersection operation
B = roc_auc_score(labels0.cpu().numpy(), H1.cpu().numpy(), average=None)
B

array([0.04292717, 0.03817016, 0.08655175, 0.06787192, 0.04947368,
       0.18579349, 0.33087085, 0.07921127, 0.26067347, 0.04367146,
       0.12316499, 0.06941253, 0.16130686, 0.10876274, 0.19959445,
       0.07345928, 0.10307504, 0.14082343, 0.05486939, 0.3230642 ])

In [14]:
# AUC with vanilla GNN
C = roc_auc_score(labels1.cpu().numpy(), H2.cpu().numpy(), average=None)
C

array([0.37811031, 0.36809629, 0.42570043, 0.3550341 , 0.37381868,
       0.45743081, 0.55753017, 0.3365722 , 0.5992221 , 0.45996608,
       0.44736236, 0.29196856, 0.61854158, 0.39790765, 0.49849615,
       0.42312672, 0.36396334, 0.43439689, 0.4093961 , 0.68208262])

In [15]:
# AUC with Simplicial Relational CNN with attention
D = roc_auc_score(labels1.cpu().numpy(), H3.cpu().numpy(), average=None)
D

array([0.84475226, 0.85036659, 0.85169009, 0.82380886, 0.86718511,
       0.81179283, 0.79583433, 0.8389978 , 0.80167683, 0.85957148,
       0.84274947, 0.83867689, 0.82096279, 0.84412427, 0.83087325,
       0.84676633, 0.85862406, 0.84563099, 0.84680219, 0.78288959])

* MAG 10 dataset

In [11]:
# AUC with union operation
A = roc_auc_score(labels0.cpu().numpy(), H0.cpu().numpy(), average=None)
A

array([0.52428818, 0.51135579, 0.49754948, 0.55378302, 0.48126553,
       0.44210296, 0.59513961, 0.56304752, 0.53522168, 0.52181259])

In [12]:
# AUC with intersection operation
B = roc_auc_score(labels0.cpu().numpy(), H1.cpu().numpy(), average=None)
B

array([0.20325682, 0.22952984, 0.15315415, 0.19937588, 0.24867593,
       0.18684313, 0.28170267, 0.16078714, 0.18261795, 0.18116036])

In [13]:
# AUC with vanilla GNN
C = roc_auc_score(labels1.cpu().numpy(), H2.cpu().numpy(), average=None)
C

array([0.3858733 , 0.80293204, 0.55857195, 0.66615752, 0.38909283,
       0.66376762, 0.90069397, 0.86292019, 0.76499934, 0.8100246 ])

In [14]:
# AUC with Simplicial Relational CNN with attention
D = roc_auc_score(labels1.cpu().numpy(), H3.cpu().numpy(), average=None)
D

array([0.58232999, 0.56074748, 0.632423  , 0.6086886 , 0.60977494,
       0.68004316, 0.5343723 , 0.57903281, 0.56640076, 0.56211353])

* Algebra dataset

In [48]:
mask = (labels0.sum(dim=0)!=0)
labels0 = labels0[:,mask]
labels1 = labels1[:,mask]
H0 = H0[:,mask]
H1 = H1[:,mask]
H2 = H2[:,mask]
H3 = H3[:,mask]

In [53]:
mask

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False, False,  True, False,
        False,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True,  True], device='cuda:0')

In [49]:
# AUC with union operation
A = roc_auc_score(labels0.cpu().numpy(), H0.cpu().numpy(), average=None)
A

array([0.60316548, 0.87834481, 0.83421509, 0.68688603, 0.48682088,
       0.27459993, 0.56340304, 0.48334466, 0.48322785, 0.7470315 ,
       0.41616334, 0.37599125, 0.45169791, 0.23253115, 0.47801609,
       0.50146823, 0.3800952 , 0.39444916, 0.36623563, 0.48744842,
       0.41693383, 0.31519287, 0.50080043, 0.41800005, 0.33560242,
       0.46415809, 0.339969  ])

In [50]:
# AUC with intersection operation
B = roc_auc_score(labels0.cpu().numpy(), H1.cpu().numpy(), average=None)
B

array([1.43371069e-01, 3.58126787e-01, 3.16738202e-01, 1.78085657e-01,
       5.96591312e-02, 1.33761370e-04, 1.24590954e-01, 1.05375915e-01,
       1.11379784e-01, 1.58971410e-01, 5.98372883e-02, 6.00560569e-02,
       1.09904333e-01, 5.00000000e-02, 7.65059132e-02, 3.33333333e-01,
       1.00000000e-01, 6.29124278e-02, 3.47515758e-02, 9.18181941e-02,
       0.00000000e+00, 3.25270263e-02, 5.00000000e-01, 2.23404255e-01,
       0.00000000e+00, 1.22641509e-01, 2.73972603e-02])

In [51]:
# AUC with vanilla GNN
C = roc_auc_score(labels1.cpu().numpy(), H2.cpu().numpy(), average=None)
C

array([0.3367712 , 0.67368304, 0.30531657, 0.60858596, 0.31999774,
       0.68652722, 0.34995078, 0.45771285, 0.72481168, 0.2977402 ,
       0.51995944, 0.58102021, 0.6737459 , 0.26598403, 0.32966697,
       0.25858755, 0.46942063, 0.61100428, 0.47171285, 0.57529979,
       0.44915865, 0.3705732 , 0.35779975, 0.69518344, 0.32587163,
       0.81403198, 0.61005244])

In [52]:
# AUC with Simplicial Relational CNN with attention
D = roc_auc_score(labels1.cpu().numpy(), H3.cpu().numpy(), average=None)
D

array([0.67236035, 0.65286182, 0.67391915, 0.66077263, 0.64479844,
       0.81709554, 0.65022366, 0.70133256, 0.63881042, 0.68741815,
       0.74258188, 0.77596711, 0.60744694, 0.84673106, 0.76292101,
       0.84137966, 0.77096748, 0.72950116, 0.76455381, 0.78663134,
       0.82129677, 0.79102923, 0.84795534, 0.70298919, 0.875418  ,
       0.57521282, 0.69547062])