In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

In [2]:
import scipy.cluster.hierarchy as sch

In [3]:
import context
from hier_clust import linkage_util, HierClust
from hier_clust.tree_util import Tree

In [4]:
from gen_data_util import gen_data, plot_tree_overlay

In [5]:
# Fix NumPy's global RNG so the shuffle performed later in this notebook is repeatable.
np.random.seed(seed=1)

In [6]:
# Problem size: `depth` and `n_dim` are passed to gen_data below;
# `n_obs` is the matching number of observations (2 ** depth).
depth, n_dim = 16, 10
n_obs = 2 ** depth

In [7]:
%%time
# Generate the synthetic data set: `x` holds the observations, `y` one label per row.
# The recorded output shows shapes (n_obs, n_dim) and (n_obs,) for these settings.
x, y = gen_data(
    depth=depth,
    n_dim=n_dim,
    depth_labels=3,
)

CPU times: user 2.29 s, sys: 13.2 ms, total: 2.31 s
Wall time: 2.31 s


In [8]:
# Shuffle rows with a single shared permutation so each observation in `x`
# stays aligned with its label in `y`.
# NOTE(review): re-running this cell re-shuffles the already shuffled arrays.
indices = np.arange(n_obs)
np.random.shuffle(indices)
x, y = x[indices], y[indices]

In [9]:
# Sanity check: show the shapes of the observations and of the labels, one per line.
print("{}\n{}".format(x.shape, y.shape))

(65536, 10)
(65536,)


In [10]:
# Clusterer under test, configured with n_neighbors=10 and the 'balltree'
# neighbor-graph strategy.
# To benchmark the other strategy instead, pass neighbor_graph_strategy='rptree'.
hc = HierClust(
    n_neighbors=10,
    neighbor_graph_strategy='balltree',
)

In [11]:
# Force a fresh import of the local `time_util` module so that edits made to it
# are picked up without restarting the kernel, then pull in TimingRegistry.
try:
    # Python 3.4+: `reload` is no longer a builtin and lives in importlib.
    from importlib import reload
except ImportError:
    # Python 2: importlib has no `reload`; the builtin is used instead.
    pass

import time_util
reload(time_util)
del time_util  # only the reloaded TimingRegistry should stay in the namespace
from time_util import TimingRegistry

In [12]:
# Registry that collects the timings in this notebook: cells below wrap each step in
# `reg.timer(label)` and later read the collected values back from `reg.registry`.
reg = TimingRegistry()

In [None]:
from scipy.sparse import dia_matrix

In [None]:
# Benchmark cell. Each `with reg.timer(label)` block times the statements inside it
# under `label`; the recorded output prints one "Elapsed time" line per label.
# Three groups are timed: HierClust's individual internal steps, the end-to-end
# `hc.fit`, and SciPy complete linkage on the same data.

# --- HierClust internals, one step at a time (private helpers called directly) ---
with reg.timer("distances"):
    dist = hc._get_distances(x)
with reg.timer("connected components"):
    components = hc._get_connected_components(dist)
with reg.timer("similarity"):
    similarity = hc._get_similarity(dist)
with reg.timer("compute laplacian"):
    # Laplacian = D - W, where W is `similarity` and D is the diagonal matrix
    # holding the column sums of W.
    degrees = similarity.sum(axis=0)
    # Take the shape from `similarity` itself rather than from the global `n_obs`,
    # so this step stays correct even if `n_obs` is stale relative to `x`.
    diag = dia_matrix((degrees, [0]), shape=similarity.shape).tocsr()
    laplacian = diag - similarity
with reg.timer("get_fiedler_vector"):
    fiedler_vector = hc._get_fiedler_vector(laplacian)

# --- End-to-end HierClust fit ---
with reg.timer("fit"):
    fit_result = hc.fit(x)

# --- SciPy baseline ---
# NOTE(review): given raw observations, sch.linkage works from all pairwise
# distances; with n_obs = 2 ** 16 that is about 2.1e9 values (~17 GB as float64).
# The recorded output contains no timing line for this step -- confirm that it
# can actually complete on the target machine.
with reg.timer("complete linkage"):
    link = sch.linkage(x, method='complete')
with reg.timer("convert linkage to tree"):
    t = linkage_util.linkage_to_tree(link)

Elapsed time: 4.083 seconds (distances)
Elapsed time: 1.188 seconds (connected components)
Elapsed time: 0.152 seconds (similarity)
Elapsed time: 0.010 seconds (compute laplacian)
Elapsed time: 0.530 seconds (get_fiedler_vector)
Elapsed time: 178.043 seconds (fit)


In [None]:
# Header row of a tab-separated table: "n_obs" followed by every timer label.
timer_labels = "\t".join(reg.registry.keys())
print("n_obs\t" + timer_labels)

In [None]:
# Data row matching the header: the observation count, then every recorded value
# multiplied by 1000 (presumably seconds -> milliseconds; confirm against time_util).
print(
    str(n_obs)
    + "\t"
    + "\t".join(str(elapsed * 1000) for elapsed in reg.registry.values())
)

##### Cleanup

In [None]:
# Close the current matplotlib figure (if any) so no figure state lingers in the kernel.
plt.close()