In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from vi_ncrp import NCRP, NCRPFit, TreeNode, softmax

In [3]:
def make_test_tree(sizes):
    """Build a nested ``TreeNode`` structure from a compact shape description.

    Parameters
    ----------
    sizes : int or sequence
        An int ``n`` is shorthand for ``n`` leaf children. A sequence
        produces one child per element, each element being itself an int or
        a sequence that is expanded recursively. An empty sequence produces
        a leaf (a node whose ``children`` list is empty).

    Returns
    -------
    TreeNode
        The root node of the constructed tree.
    """
    # Normalise the int shorthand into an explicit list of leaf specs.
    child_specs = [[] for _ in range(sizes)] if isinstance(sizes, int) else sizes
    if len(child_specs) == 0:
        return TreeNode(children = [])
    return TreeNode(children = list(map(make_test_tree, child_specs)))

In [4]:
# Test tree with two top-level children; each has three children of its own,
# which in turn get 4/1/2 and 1/2/3 leaf children respectively.
tree = make_test_tree([[4,1,2],[1,2,3]])

In [5]:
# List every path in the tree as a tuple of child indices; per the output this
# includes the root (), the inner nodes, and the full-depth leaf paths.
list(tree.inner_and_full_paths())

[(),
 (0,),
 (0, 0),
 (0, 0, 0),
 (0, 0, 1),
 (0, 0, 2),
 (0, 0, 3),
 (0, 1),
 (0, 1, 0),
 (0, 2),
 (0, 2, 0),
 (0, 2, 1),
 (1,),
 (1, 0),
 (1, 0, 0),
 (1, 1),
 (1, 1, 0),
 (1, 1, 1),
 (1, 2),
 (1, 2, 0),
 (1, 2, 1),
 (1, 2, 2)]

In [6]:
# Same enumeration restricted to the subtree at path (1, 2); passing
# prefix_so_far=(1, 2) makes the reported paths start with (1, 2), as the output shows.
list(tree.lookup_path((1,2)).inner_and_full_paths(prefix_so_far=(1,2)))

[(1, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2)]

In [7]:
# Depth of the test tree; the output (3) matches the length of its longest paths above.
tree.depth()

3

In [8]:
# Build the model, overriding only progress_bar; every other setting shown in
# the repr below is therefore the constructor's default.
ncrp = NCRP(progress_bar = 'notebook')
ncrp

NCRP(alphaTheta=[ 1.  1.  1.], alphaV=0.1, alphaW=0.1, iterations=100, progress_bar=notebook, depth=2)

In [9]:
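# Disabled placeholder input (a 10x20 array of zeros); the notebook loads real data in the following cell instead.
#data = np.zeros((10, 20))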

In [10]:
import os
from pathlib import Path

from load_data import load_data

# Location of the AP data file. The default is ~/Downloads/ap/ap.dat, built
# from the current user's home directory instead of a hard-coded
# /Users/<name>/... prefix, so the notebook is not tied to one account.
# Set the AP_DATA_PATH environment variable to load the file from elsewhere.
AP_DATA_PATH = Path(
    os.environ.get("AP_DATA_PATH", Path.home() / "Downloads" / "ap" / "ap.dat")
)

# load_data is still handed a plain string path, as before.
data = load_data(str(AP_DATA_PATH))





In [17]:
# Keep only the first 100 rows and first 100 columns of the loaded array.
# NOTE(review): presumably done to keep the fit below quick; this overwrites
# `data`, so re-run the load cell to get the full array back.
data = data[:100,:100]

In [18]:
# Check the dimensions after truncation.
data.shape

(100, 100)

In [19]:
%%time
# Fit the model to the truncated data; `f` is the result inspected in the cells below.
# NOTE(review): presumably an NCRPFit instance (imported above) — confirm.
f = ncrp.fit(data = data)

CPU times: user 5min 14s, sys: 1.24 s, total: 5min 15s
Wall time: 5min 15s


In [20]:
# Inspect the fitted `alphaTheta_var` array; the output prints three columns per row.
# NOTE(review): presumably the fitted counterpart of the alphaTheta=[1, 1, 1]
# setting shown in the NCRP repr — confirm.
f.alphaTheta_var

array([[ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.9993005 ,  1.99926545,  1.99926541],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.00000146,  1.00000157,  1.00000157],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        ,  1.   

In [21]:
# Inspect the fitted `phi_var` array; the output also prints three columns per row.
f.phi_var

array([[  3.26097349e-049,   2.64295848e-049,   2.64186881e-049],
       [  1.51113196e-108,   9.34508265e-109,   9.33723775e-109],
       [  4.17323382e-098,   2.55120113e-098,   2.54877143e-098],
       [  4.47857331e-077,   3.18197439e-077,   3.17980953e-077],
       [  1.17219459e-035,   9.68043251e-036,   9.67636678e-036],
       [  4.91249137e-175,   2.54594386e-175,   2.54276815e-175],
       [  3.14469480e-072,   2.61532362e-072,   2.61434586e-072],
       [  1.19807153e-028,   1.15129565e-028,   1.15117398e-028],
       [  9.99306504e-001,   9.99271654e-001,   9.99271612e-001],
       [  4.37474348e-047,   3.66240763e-047,   3.66091627e-047],
       [  1.30489850e-026,   1.16193280e-026,   1.16163468e-026],
       [  1.94487147e-018,   1.82496514e-018,   1.82460224e-018],
       [  4.32515045e-073,   2.88503584e-073,   2.88270232e-073],
       [  1.44207585e-006,   1.55473562e-006,   1.55487295e-006],
       [  1.49322575e-088,   6.20688202e-089,   6.19330275e-089],
       [  

In [22]:
# Inspect `tree_var_params`: per the output, a dict whose visible key () has the
# same form as the root path above and maps to 'alphaV_var' and 'alphaW_var' entries.
f.tree_var_params

{(): {'alphaV_var': 1.0,
  'alphaW_var': array([  49.5       ,   61.5       ,   69.50001218,   41.5       ,
           46.        ,   30.        ,   26.        ,   33.50006912,
           32.5       ,   29.5       ,   30.50017017,   20.        ,
           23.        ,   29.5       ,  124.35369245,   28.5       ,
           31.50006912,   16.        ,   23.        ,   17.5       ,
           20.5       ,   15.        ,   15.5       ,   21.        ,
           16.50008509,   15.5       ,   17.6454717 ,   14.5       ,
            7.5       ,   14.        ,    6.5       ,   17.        ,
           18.00034404,   14.5       ,   14.50006912,   16.5       ,
           20.        ,   16.        ,    8.        ,   25.00008509,
           18.5       ,    9.5       ,   11.        ,   12.5       ,
           18.5       ,    9.5       ,   16.        ,   11.50000004,
           13.5       ,    9.        ,   18.        ,    6.5       ,
            8.5       ,   11.64545948,    9.5       ,   16.00000

In [23]:
# Inspect `tree_stats`: per the output, a dict whose visible key () maps to
# 'Eqln1_V', 'EqlnV' and 'EqlnW' entries.
f.tree_stats

{(): {'Eqln1_V': -1.0000000000000002,
  'EqlnV': -1.0000000000000002,
  'EqlnW': array([-3.61235931, -3.3933119 , -3.27008153, -3.79060011, -3.68646451,
         -4.11975883, -4.26545445, -4.00764906, -4.03842037, -4.1368516 ,
         -4.10294764, -4.53367297, -4.39059937, -4.1368516 , -2.69319435,
         -4.17193931, -4.07016419, -4.76318363, -4.39059937, -4.67083953,
         -4.50836056, -4.8298503 , -4.79596172, -4.48367297, -4.7314403 ,
         -4.79596172, -4.66236178, -4.86492723, -5.55743948, -4.90127887,
         -5.71128563, -4.70068363, -4.64184053, -4.86492723, -4.86492232,
         -4.73144559, -4.53367297, -4.76318363, -5.48855548, -4.30545099,
         -4.61369667, -5.30645908, -5.15244437, -5.01900131, -4.61369667,
         -5.30645908, -4.76318363, -5.10595783, -4.93900131, -5.36355548,
         -4.6418601 , -5.71128563, -5.42410614, -5.0928922 , -5.30645908,
         -4.76318363, -4.97820195, -4.90127887, -5.06153528, -4.73144559,
         -4.90127887, -4.97820195