In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from vi_ncrp import NCRP, NCRPFit, TreeNode, softmax

In [3]:
# Seed NumPy's global random state so repeated runs of the notebook draw the
# same random numbers.
np.random.seed(1)

In [4]:
def make_test_tree(sizes):
    """Build a nested ``TreeNode`` structure from a compact size specification.

    Parameters
    ----------
    sizes : int or sequence
        An int ``n`` is shorthand for a node with ``n`` childless children.
        A sequence describes one child per entry, each entry being itself a
        specification (int or sequence). An empty sequence yields a node
        with no children.

    Returns
    -------
    TreeNode
        The root of the constructed tree.
    """
    if isinstance(sizes, int):
        # Expand the integer shorthand into one empty specification per child.
        sizes = [[] for _ in range(sizes)]

    subtrees = []
    if len(sizes) > 0:
        # Recursively build one subtree per child specification, in order.
        for child_spec in sizes:
            subtrees.append(make_test_tree(child_spec))
    return TreeNode(children = subtrees)

In [5]:
# Example tree: the root has two children; each of those has three children,
# whose numbers of childless children are given by the integers (4, 1, 2 and 1, 2, 3).
tree = make_test_tree([[4,1,2],[1,2,3]])

In [6]:
# List every node of the tree as a tuple of child indices, starting from the
# root (the empty tuple) and including both inner nodes and leaves.
list(tree.inner_and_full_paths())

[(),
 (0,),
 (0, 0),
 (0, 0, 0),
 (0, 0, 1),
 (0, 0, 2),
 (0, 0, 3),
 (0, 1),
 (0, 1, 0),
 (0, 2),
 (0, 2, 0),
 (0, 2, 1),
 (1,),
 (1, 0),
 (1, 0, 0),
 (1, 1),
 (1, 1, 0),
 (1, 1, 1),
 (1, 2),
 (1, 2, 0),
 (1, 2, 1),
 (1, 2, 2)]

In [7]:
# Same enumeration, restricted to the subtree found at path (1, 2).
# Passing prefix_so_far=(1, 2) makes the reported paths start with (1, 2),
# i.e. they are expressed relative to the root of the full tree.
list(tree.lookup_path((1,2)).inner_and_full_paths(prefix_so_far=(1,2)))

[(1, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2)]

In [8]:
# Depth of the example tree; this matches the longest path listed above,
# which has three indices.
tree.depth()

3

In [9]:
# Create the NCRP model, overriding only the progress-bar style; all other
# constructor arguments are left at their defaults. The bare expression on
# the last line displays the model's repr, including those default values.
ncrp = NCRP(progress_bar = 'terminal')
ncrp

NCRP(alphaTheta=array([ 1.,  1.,  1.]), alphaV=0.1, alphaW=0.1, iterations=100, progress_bar='terminal', branch_structure=[5, 5], depth=2)

In [10]:
#data = np.zeros((10, 20))

In [11]:
import os

from load_data import load_data

# Location of the `ap.dat` input file. The original author's local path is
# kept as the fallback so existing behaviour is unchanged, but it can be
# overridden without editing the notebook by setting the AP_DATA_PATH
# environment variable before starting the kernel.
AP_DATA_PATH = os.environ.get(
    "AP_DATA_PATH", "/Users/aleverentz/Downloads/ap/ap.dat"
)

# Read the file into a matrix (see the shape printed in the next cell).
data = load_data(AP_DATA_PATH)

Reading lines: 100%|██████████| 2246/2246 [00:00<00:00, 2578.56it/s]
Filling matrix: 100%|██████████| 2246/2246 [00:02<00:00, 972.51it/s] 


In [12]:
# Dimensions of the full loaded matrix (rows, columns).
data.shape

(2246, 10473)

In [13]:
# Keep only the first 100 rows and the first 1000 columns of the matrix
# (presumably to keep this experiment small and fast -- confirm intent).
# NOTE(review): this rebinds `data`, so the full matrix is no longer reachable
# without re-running the loading cell.
data = data[:100,:1000]

In [14]:
# Confirm the truncated dimensions.
data.shape

(100, 1000)

In [15]:
# Drop every row whose entries sum to zero: np.where(...)[0] yields the
# indices of rows with a positive row sum, and only those rows are kept.
# Such rows can appear after the column truncation above removes all of a
# row's non-zero entries.
data = data[np.where(data.sum(axis=1) > 0)[0], :]

In [16]:
# Dimensions after filtering; an unchanged row count means no row was empty.
data.shape

(100, 1000)

In [17]:
%%time
# Fit the model to the reduced matrix and keep the returned fit object in `f`
# for inspection in the cells below.
f = ncrp.fit(data = data)

Optimizing via coordinate ascent: 100%|██████████| 100/100 [00:02<00:00, 36.68it/s]

CPU times: user 2.72 s, sys: 32.9 ms, total: 2.75 s
Wall time: 2.75 s





In [18]:
# For each row of f.path_prob, take the column index with the largest value
# (presumably one row per data row and one column per candidate path -- confirm).
most_likely_paths = np.argmax(f.path_prob, axis=1)
# Show which distinct path indices were selected, translated into path tuples
# through f.index_to_path.
{k: f.index_to_path[k] for k in np.unique(most_likely_paths)}

{7: (1,), 10: (1, 2)}

In [19]:
# Ask the fit object directly for the most likely paths; the result is an
# object array of path tuples, containing the same two paths found above.
f.get_most_likely_paths()

array([(1,), (1,), (1,), (1,), (1,), (1,), (1, 2), (1,), (1,), (1,), (1,),
       (1, 2), (1,), (1,), (1, 2), (1, 2), (1,), (1,), (1, 2), (1,), (1,),
       (1,), (1,), (1,), (1, 2), (1,), (1,), (1, 2), (1,), (1,), (1, 2),
       (1,), (1,), (1,), (1, 2), (1,), (1,), (1,), (1,), (1, 2), (1, 2),
       (1, 2), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,),
       (1, 2), (1,), (1,), (1,), (1,), (1,), (1, 2), (1,), (1,), (1, 2),
       (1,), (1, 2), (1,), (1, 2), (1,), (1, 2), (1, 2), (1,), (1,), (1,),
       (1,), (1,), (1,), (1,), (1, 2), (1,), (1,), (1,), (1, 2), (1,),
       (1,), (1,), (1,), (1,), (1,), (1, 2), (1, 2), (1, 2), (1,), (1,),
       (1,), (1,), (1,), (1, 2), (1,), (1,), (1,), (1, 2), (1,)], dtype=object)