In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from vi_ncrp import NCRP, NCRPFit, TreeNode, softmax

In [3]:
def make_test_tree(sizes):
    """Recursively build a TreeNode hierarchy from a nested size spec.

    Parameters
    ----------
    sizes : int or sequence
        An int ``n`` is shorthand for ``n`` leaf children (``n <= 0`` gives a
        leaf). A sequence supplies one spec per child, each itself an int or a
        sequence; an empty sequence gives a leaf.

    Returns
    -------
    TreeNode
        Root of the constructed tree; leaves are built with ``children=[]``.
    """
    if isinstance(sizes, int):
        # Shorthand expansion: n leaves is the same as n empty child specs.
        return make_test_tree([[] for _ in range(sizes)])
    # len() is evaluated first so non-sized inputs fail the same way as before.
    subtrees = [] if len(sizes) == 0 else list(map(make_test_tree, sizes))
    return TreeNode(children=subtrees)

In [4]:
# Root -> two inner nodes -> three inner nodes each, which own
# 4, 1, 2 leaves (first branch) and 1, 2, 3 leaves (second branch).
tree = make_test_tree([
    [4, 1, 2],
    [1, 2, 3],
])

In [5]:
# Every path in the tree as an index tuple, root `()` first, in depth-first
# pre-order; inner nodes and leaves are both included.
list(tree.inner_and_full_paths())

[(),
 (0,),
 (0, 0),
 (0, 0, 0),
 (0, 0, 1),
 (0, 0, 2),
 (0, 0, 3),
 (0, 1),
 (0, 1, 0),
 (0, 2),
 (0, 2, 0),
 (0, 2, 1),
 (1,),
 (1, 0),
 (1, 0, 0),
 (1, 1),
 (1, 1, 0),
 (1, 1, 1),
 (1, 2),
 (1, 2, 0),
 (1, 2, 1),
 (1, 2, 2)]

In [6]:
# List only the subtree rooted at `subtree_path`. The same tuple is passed as
# `prefix_so_far` so the yielded paths are absolute (they match the full-tree
# listing above); binding it once keeps the lookup and the prefix in sync.
subtree_path = (1, 2)
list(tree.lookup_path(subtree_path).inner_and_full_paths(prefix_so_far=subtree_path))

[(1, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2)]

In [7]:
# Reports 3 here, which appears to be the length of the longest path,
# e.g. (0, 0, 0) in the listing above — confirm semantics in vi_ncrp.
tree.depth()

3

In [8]:
# Only progress_bar is set explicitly; the repr shows the remaining
# hyperparameters at their default values.
ncrp = NCRP(progress_bar = 'notebook')
ncrp

NCRP(alphaTheta=[ 1.  1.  1.], alphaV=0.1, alphaW=0.1, iterations=100, progress_bar=notebook, depth=2)

In [9]:
%%time
# Smoke-test fit on a degenerate all-zeros (10, 20) matrix. Judging by the
# outputs below, rows appear to index observations (alphaTheta_var and phi_var
# have 10 rows) and columns appear to index features (alphaW_var has 20
# entries) — TODO confirm against vi_ncrp.
f = ncrp.fit(data = np.zeros((10, 20)))


CPU times: user 789 ms, sys: 23.2 ms, total: 812 ms
Wall time: 828 ms


In [10]:
# (10, 3) array: one row per data row; the 3 columns presumably align with
# the 3 entries of NCRP.alphaTheta — confirm in vi_ncrp.
f.alphaTheta_var

array([[ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1],
       [ 1.1,  1.1,  1.1]])

In [11]:
# Also a (10, 3) array; every entry is 0.1 after this all-zeros fit.
f.phi_var

array([[ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1],
       [ 0.1,  0.1,  0.1]])

In [12]:
# Dict keyed by path tuple (root is ()); each node holds scalar alphaV_var and
# betaV_var plus a length-20 alphaW_var (presumably one entry per data column).
f.tree_var_params

{(): {'alphaV_var': 1.0,
  'alphaW_var': array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.]),
  'betaV_var': 1.0},
 (0,): {'alphaV_var': 7.1606911472098496,
  'alphaW_var': array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.]),
  'betaV_var': 3.6803455736049275},
 (0, 0): {'alphaV_var': 2.0003516173988505,
  'alphaW_var': array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.]),
  'betaV_var': 1.1003516173988506}}

In [13]:
# Dict of per-node statistics keyed by path tuple, with entries such as
# EqlnV, Eqln1_V, EqlnW, logS, path_prob, j0 and z0.
f.tree_stats

{(): {'Eqln1_V': -1.0000000000000002,
  'EqlnV': -1.0000000000000002,
  'EqlnW': array([-3.54773966, -3.54773966, -3.54773966, -3.54773966, -3.54773966,
         -3.54773966, -3.54773966, -3.54773966, -3.54773966, -3.54773966,
         -3.54773966, -3.54773966, -3.54773966, -3.54773966, -3.54773966,
         -3.54773966, -3.54773966, -3.54773966, -3.54773966, -3.54773966]),
  'j0': 1,
  'logS': array([ 8.51770165,  8.51770165,  8.51770165,  8.51770165,  8.51770165,
          8.51770165,  8.51770165,  8.51770165,  8.51770165,  8.51770165]),
  'path_prob': array([ 0.74193028,  0.74193028,  0.74193028,  0.74193028,  0.74193028,
          0.74193028,  0.74193028,  0.74193028,  0.74193028,  0.74193028]),
  'z0': array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])},
 (0,): {'Eqln1_V': -1.1754677064399226,
  'EqlnV': -0.43935035380361254,
  'EqlnW': array([-3.54773966, -3.54773966, -3.54773966, -3.54773966, -3.54773966,
         -3.54773966, -3.54773966, -3.54773966, -3.54773966, -3.547