In [1]:
import sys
sys.path.append('../..')

import warnings
warnings.filterwarnings('ignore')

from preprocess.process_dataset import get_dgl_graph
from preprocess.subgraph_extraction import extract_subgraph
from preprocess.graph_to_simplicial_complex import get_simplicial_complex, get_embeddings, _get_simplices, random_sample, _get_neg_simplices
from hodgelaplacians import HodgeLaplacians
from global_parameters import device
from layers.simplicial_convolution import SimplicialAttentionLayer, SimplicialConvolution
import numpy as np
import timeit
from torch.utils.data import Dataset, DataLoader
import random
import torch
import os
from tqdm import tqdm

def set_seed(seed = 42):
    '''
        For Reproducibility: Sets the seed of the entire notebook.

        Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and every
        CUDA device), and configures cuDNN to use deterministic kernels.

        Parameters
        ----------
        seed : int, default 42
            Seed applied to every RNG listed above.
    '''

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)

    # cuDNN: request deterministic algorithms AND disable the autotuner.
    # With benchmark=True cuDNN may select different (possibly
    # non-deterministic) kernels between runs, defeating deterministic=True.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Only affects child processes: hash randomization of the *current*
    # interpreter is fixed at start-up, so this does not change this kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(2)

from grail.model.dgl.graph_classifier import GraphClassifier as dgl_model
from grail.subgraph_extraction.graph_sampler import node_label

ModuleNotFoundError: No module named 'utils'

In [2]:
# Dataset / experiment configuration used by every following cell.
# Our model performs best on this dataset. Results are available below
dataset = 'cat_edge_cooking'
num_classes = 20  # length of the binary (multi-hot) label vector built for each simplex
simplex_order=4  # passed as max_dim to random_sample and as dim to get_embeddings
# Indexed by the sampled simplex order (max_nodes_per_hop=max_nodes[order] when extracting the subgraph).
# NOTE(review): only orders 0-3 have an entry here; an order equal to simplex_order (4)
# would raise IndexError -- confirm whether a fifth value is needed.
max_nodes = [10, 100, 150, 200] # parameters to control the sparsity of the sampled graph

# Loads the DGL graph and its networkx counterpart (the output below shows both come from cache).
graph, nx_graph = get_dgl_graph(dataset)

> loading cat_edge_cooking dataset
> loading nx_graph from cache
> loading dgl_graph from cache


In [3]:
@torch.no_grad()
def random_sample(dataset, num_classes, max_dim=4):
    '''
    Sample a random 1-simplex (a pair of vertices taken from one simplex of the
    dataset) and build its binary label vector.

    NOTE: this notebook-local definition shadows the ``random_sample`` imported
    from ``preprocess.graph_to_simplicial_complex`` in the first cell.

    Parameters:
    -----------
    dataset : str, name of the dataset e.g. cat_edge_cooking, cat_edge_DAWN etc.
    num_classes : int, number of classes to which a simplex can be classified into
    max_dim : int, the highest order of simplex allowed. Currently unused: the
        sampled dimension is always 1.

    Returns:
    --------
    simplex: set, a set of dim+1 vertices drawn from a single simplex
    dim: int, dimension of the sampled simplex (always 1 here)
    label: tensor, a binary tensor of size (num_classes,) with a 1 at the label
        of every simplex that contains ``simplex``

    Raises:
    -------
    ValueError: if no simplex has at least dim+1 vertices (the rejection loop
        below would otherwise never terminate).
    '''
    dim = 1
    simplices, labels = _get_simplices(dataset, num_classes)
    if not any(len(sim) > dim for sim in simplices):
        raise ValueError(
            f'no simplex in {dataset!r} has at least {dim + 1} vertices')

    # Draw by index instead of np.random.choice(simplices): choice() converts the
    # whole list into an array on every call, and rejects equal-length list/tuple
    # simplices because numpy then builds a 2-D array out of them.
    simplex = simplices[np.random.randint(len(simplices))]
    while len(simplex) < dim + 1:
        simplex = simplices[np.random.randint(len(simplices))]
    simplex = set(np.random.choice(list(simplex), size=dim + 1, replace=False).tolist())

    label = torch.zeros((num_classes,))
    for sim, lab in zip(simplices, labels):
        # issubset accepts any iterable, so `sim` may be a set, list or tuple.
        if simplex.issubset(sim):
            label[lab] = 1
    return simplex, dim, label

In [4]:
def _shape_summary(items):
    '''Shape of each entry in `items`, with 0 standing in for a missing (None) entry.'''
    return [0 if item is None else item.shape for item in items]


simplex, order, label = random_sample(dataset, num_classes=num_classes, max_dim=simplex_order)
to_remove = frozenset(simplex)

# Enclosing subgraph around the sampled simplex; the per-hop node cap depends on its order.
subgraph = extract_subgraph(
    simplex,
    graph,
    h=4,
    enclosing_sub_graph=True,
    max_nodes_per_hop=max_nodes[order],
)
# Drop nodes that ended up with neither incoming nor outgoing edges.
isolated_nodes = torch.nonzero(
    (subgraph.in_degrees() == 0) & (subgraph.out_degrees() == 0)
).squeeze(1)
subgraph.remove_nodes(isolated_nodes)

simplex_labels = get_simplicial_complex(subgraph, graph, nx_graph, dataset, num_classes)
pos_embeddings, neg_embeddings, laplacians, boundaries, idx = get_embeddings(
    simplex_labels, to_remove, num_classes, dim=simplex_order
)

print('Laplacians:', _shape_summary(laplacians))
print('Boundaries:', _shape_summary(boundaries))
# The two 'Embeddings' lines are the positive and then the negative embeddings.
print('Embeddings', _shape_summary(pos_embeddings))
print('Embeddings', _shape_summary(neg_embeddings))
print('Label', label)
print('Order:', order, 'Index:', idx)
print('subgraph : ', subgraph.num_nodes())

Laplacians: [torch.Size([26, 26]), torch.Size([98, 98]), torch.Size([100, 100]), torch.Size([34, 34])]
Boundaries: [torch.Size([1, 26]), torch.Size([26, 98]), torch.Size([98, 100]), torch.Size([100, 34])]
Embeddings [torch.Size([26, 20]), torch.Size([98, 20]), torch.Size([100, 20]), torch.Size([34, 20])]
Embeddings [torch.Size([26, 20]), torch.Size([98, 20]), torch.Size([100, 20]), torch.Size([34, 20])]
Label tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0.])
Order: 1 Index: 79
subgraph :  26


In [43]:
# Attach to every graph edge the sorted, de-duplicated class ids of all simplices
# (hyperedges) that contain it, stored as nx_graph[u][v]['label'].
# Removed from the original cell: an unused `nodes = set(graph.nodes())` (costly),
# an unused `visited` set, and `simplex_labels = {}`, which silently discarded the
# simplicial complex computed in the subgraph-extraction cell.
simplices, labels = _get_simplices(dataset, num_classes)
# One (u, v) row per edge of the DGL graph.
edges = torch.stack(graph.edges()).permute(1, 0).numpy()
for u, v in tqdm(edges):
    # Separate buffer name so the sampled simplex's `label` from earlier is not clobbered.
    edge_label = np.zeros(num_classes)
    for simplex_index in nx_graph[u][v]['hyperedge_index']:
        edge_label[labels[simplex_index]] = 1
    nx_graph[u][v]['label'] = np.flatnonzero(edge_label)

100%|██████████| 959868/959868 [00:08<00:00, 119393.90it/s]


In [44]:
import networkx as nx
from sklearn.model_selection import train_test_split

# Output directory for the GRAIL-format triplet files. Previously hard-coded to
# '/home/adarsh/H-KGC/grail/data/cat_edge_cooking'; the dataset component now
# follows `dataset`, and the root can be overridden via the GRAIL_DATA_DIR env var.
grail_data_dir = os.path.join(
    os.environ.get('GRAIL_DATA_DIR', '/home/adarsh/H-KGC/grail/data'), dataset)
os.makedirs(grail_data_dir, exist_ok=True)

# One (head, relation, tail) triplet per (edge, class) pair.
triplets = [(u, r, v) for u, v in tqdm(edges) for r in nx_graph[u][v]['label']]

# NOTE(review): no random_state is passed, so the split depends on the global NumPy
# RNG state when this cell runs (seeded in the first cell). Pass an explicit
# random_state if the split must not depend on execution order.
train, test = train_test_split(triplets)


def _write_triplets(path, rows):
    '''Write (head, relation, tail) rows to `path`, one tab-separated triplet per line.'''
    with open(path, 'w') as f:
        for u, r, v in tqdm(rows):
            f.write(f'{u}\t{r}\t{v}\n')


_write_triplets(os.path.join(grail_data_dir, 'train.txt'), train)
# The held-out split is written as GRAIL's validation file.
_write_triplets(os.path.join(grail_data_dir, 'valid.txt'), test)

100%|██████████| 959868/959868 [00:03<00:00, 275319.51it/s]
100%|██████████| 1029736/1029736 [00:01<00:00, 596798.40it/s]
100%|██████████| 343246/343246 [00:00<00:00, 584606.42it/s]


In [None]:
# Rebind `triplets` to a fresh empty dict; any earlier value is discarded.
triplets = dict()


In [None]:
# Node labelling of the extracted subgraph via grail's node_label on its dense adjacency.
# Bound to dedicated names so this cell neither overwrites `labels` (the simplex class
# ids used by the edge-labelling cell) nor IPython's `_` output-history alias.
# NOTE(review): `node_label` comes from the grail import in the first cell, whose
# recorded output is a ModuleNotFoundError -- fix that import before running this.
# NOTE(review): 3 is presumably node_label's maximum-distance argument -- confirm in grail.
# NOTE(review): `adj(scipy_fmt=...)` is the older DGL API; newer DGL releases expose
# this as `adj_external` -- confirm against the installed version.
node_labels, enclosing_nodes = node_label(subgraph.adj(scipy_fmt='coo').todense(), 3)