In [1]:
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import combinations
from scipy.sparse import csr_matrix

In [2]:
# NOTE(review): K is not referenced in the cells visible here — confirm its meaning where it is used.
K = 10_000
# Directory holding the dataset files read below (train-remapped.csv, hierarchy.txt).
basedir = "./lshtc"

In [3]:
def load_dataset(filename="train-remapped.csv", nmax=1_000_000_000_000):
    """Parse a sparse multi-label text file into labels and feature dicts.

    The first line of the file is skipped as a header. Every following
    non-blank line is split on whitespace; tokens containing ":" are read as
    ``feature_id:feature_value`` integer pairs, and every other token is read
    as one or more comma-separated integer class labels.

    Args:
        filename: Path of the file to read.
        nmax: Highest line index (header is index 0) that is still parsed,
            so at most ``nmax`` data lines are read.

    Returns:
        A tuple ``(class_set, features, labels)`` where ``class_set`` is the
        set of all label ids seen, ``features`` is a list of
        ``{feature_id: feature_value}`` dicts and ``labels`` is a list of
        label-id lists, both with one entry per parsed line.

    Raises:
        ValueError: If a label, feature id or feature value is not an integer.
    """
    with open(filename, "r") as f:
        lines = f.readlines()

    class_set = set()
    labels = []
    features = []
    # tqdm sees every line (header included), so the total is len(lines).
    for line_no, line in tqdm(enumerate(lines), total=len(lines)):
        if line_no > nmax:
            break
        if line_no == 0:
            continue  # header line
        # split() without arguments tolerates repeated spaces and tabs.
        tokens = line.split()
        if not tokens:
            continue  # blank line: nothing to parse
        label = []
        feature = {}
        for token in tokens:
            if ":" in token:
                feature_id, feature_value = token.split(":")[:2]
                feature[int(feature_id)] = int(feature_value)
            else:
                # Split on commas instead of deleting them, so "3,4" is read
                # as labels 3 and 4 rather than being fused into 34.
                for part in token.split(","):
                    if part:
                        label.append(int(part))
        class_set.update(label)
        labels.append(label)
        features.append(feature)
    return class_set, features, labels

def filter_dataset(classes, X, Y, f):
    """Keep only the (x, y) pairs for which the predicate ``f(x, y)`` is truthy.

    Args:
        classes: Accepted for interface compatibility; not read by this function.
        X: Sequence of samples, paired positionally with ``Y``.
        Y: Sequence of label iterables, one per sample.
        f: Callable taking ``(x, y)`` and returning a truthy value to keep the pair.

    Returns:
        A tuple ``(classes_new, Xnew, Ynew)``: the set of labels occurring in
        the kept label lists, the kept samples, and the kept label lists.
    """
    kept_pairs = [(sample, target) for sample, target in zip(X, Y) if f(sample, target)]
    Xnew = [sample for sample, _ in kept_pairs]
    Ynew = [target for _, target in kept_pairs]

    # Recompute the class set from the surviving label lists only.
    classes_new = set()
    for target in Ynew:
        classes_new.update(target)
    return classes_new, Xnew, Ynew


In [4]:
# Parse the whole training file (no nmax cap); the recorded run below spent over a minute here before being interrupted.
classes, X, Y = load_dataset(f"{basedir}/train-remapped.csv")

 84%|████████▍ | 1984489/2365436 [01:06<00:12, 29929.25it/s]


KeyboardInterrupt: 

# Build the class hierarchy graph

In [None]:
def get_graph(hierarchy_file="hierarchy.txt", class_set=None):
    """Build an undirected graph from a whitespace-separated edge-list file.

    Each non-blank line must hold exactly two integer node ids. An edge is
    added only when at least one of its two endpoints is in ``class_set``.

    Args:
        hierarchy_file: Path of the edge-list file to read.
        class_set: Container of node ids used to filter edges. When ``None``
            the notebook-level ``classes`` variable is used, which keeps the
            original call ``get_graph(path)`` working unchanged.

    Returns:
        An ``nx.Graph`` containing the retained edges and their endpoints.

    Raises:
        ValueError: If a non-blank line does not hold exactly two integers.
        NameError: If ``class_set`` is ``None`` and ``classes`` is undefined.
    """
    if class_set is None:
        # Fall back to the notebook global produced by load_dataset.
        class_set = classes

    with open(hierarchy_file, "r") as f:
        lines = f.readlines()

    G = nx.Graph()
    # No header is skipped here, so the progress total is the full line count.
    for line in tqdm(lines, total=len(lines)):
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines, e.g. a trailing newline
        a, b = (int(field) for field in fields)
        if a in class_set or b in class_set:
            # add_edge creates missing endpoints, so no add_node is needed.
            G.add_edge(a, b)
    return G

# Build the graph, then keep the connected component with the most nodes.
G = get_graph(f"{basedir}/hierarchy.txt")
G_components = [G.subgraph(component) for component in nx.connected_components(G)]
# len() of a graph is its node count; max() returns the first largest component.
G_ours = max(G_components, key=len)
len(G_ours)

# Graph summarization

In [None]:
G