In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [33]:
%autosave 10

Autosaving every 10 seconds


In [2]:
from vi_ncrp import NCRP, NCRPFit, TreeNode, softmax

In [3]:
from collections import Counter

In [4]:
# Seed NumPy's global random number generator so that repeated runs of the
# notebook draw the same random numbers.
np.random.seed(2)

In [5]:
def make_test_tree(sizes):
    """Build a nested TreeNode hierarchy from a size specification.

    `sizes` is either an int, which produces a node with that many childless
    children, or a sized sequence whose elements are themselves size
    specifications (one per child). An empty sequence produces a node with
    no children.
    """
    if isinstance(sizes, int):
        # An integer is shorthand for that many empty child specifications.
        return make_test_tree([[] for _ in range(sizes)])
    if not len(sizes):
        return TreeNode(children = [])
    subtrees = []
    for child_spec in sizes:
        subtrees.append(make_test_tree(child_spec))
    return TreeNode(children = subtrees)

In [6]:
# Two top-level branches with three children each; every integer is the
# number of leaves placed under that child (4, 1, 2 and 1, 2, 3).
tree = make_test_tree([[4,1,2],[1,2,3]])

In [7]:
# List every path the tree reports: the root `()`, the inner nodes and the
# leaves, as shown in the output.
list(tree.inner_and_full_paths())

[(),
 (0,),
 (0, 0),
 (0, 0, 0),
 (0, 0, 1),
 (0, 0, 2),
 (0, 0, 3),
 (0, 1),
 (0, 1, 0),
 (0, 2),
 (0, 2, 0),
 (0, 2, 1),
 (1,),
 (1, 0),
 (1, 0, 0),
 (1, 1),
 (1, 1, 0),
 (1, 1, 1),
 (1, 2),
 (1, 2, 0),
 (1, 2, 1),
 (1, 2, 2)]

In [8]:
# Same enumeration, restricted to the subtree found at path (1, 2).
# Passing `prefix_so_far=(1, 2)` makes the listed paths start with (1, 2),
# as the output shows.
list(tree.lookup_path((1,2)).inner_and_full_paths(prefix_so_far=(1,2)))

[(1, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2)]

In [9]:
# Depth reported for the test tree built above.
tree.depth()

3

In [10]:
# Create the model, overriding only `progress_bar`; the remaining
# hyperparameters are the defaults visible in the repr below.
# NOTE(review): 'terminal' presumably selects a text progress bar -- confirm
# against vi_ncrp.
ncrp = NCRP(progress_bar = 'terminal')
ncrp

NCRP(alphaTheta=array([ 1.,  1.,  1.]), alphaV=0.1, alphaW=0.1, iterations=100, progress_bar='terminal', branch_structure=[5, 5], depth=2)

In [12]:
import os

from load_data import load_data, load_vocab

# Directory containing the AP corpus files (`ap.dat` and `vocab.txt`).
# Set the AP_DATA_DIR environment variable to run the notebook on another
# machine; when it is unset, the original location is used unchanged.
AP_DATA_DIR = os.environ.get("AP_DATA_DIR", "/Users/aleverentz/Downloads/ap")

# Document data and the vocabulary used to map word indices back to words.
data = load_data(os.path.join(AP_DATA_DIR, "ap.dat"))
vocab = load_vocab(os.path.join(AP_DATA_DIR, "vocab.txt"))

Reading lines: 100%|██████████| 2246/2246 [00:00<00:00, 2421.56it/s]
Filling matrix: 100%|██████████| 2246/2246 [00:02<00:00, 899.28it/s]
Loading vocabulary: 100%|██████████| 10473/10473 [00:00<00:00, 1517181.15it/s]


In [13]:
# Shape of the loaded data. The output (2246, 10473) matches the number of
# lines read and vocabulary entries loaded above, so rows are presumably
# documents and columns vocabulary words.
data.shape

(2246, 10473)

In [14]:
# Full slice: selects every row and column, so the shape is unchanged.
# NOTE(review): presumably a placeholder for restricting the data to a subset
# while experimenting -- confirm.
data = data[:,:]

In [15]:
# Confirm the shape after the slicing cell above.
data.shape

(2246, 10473)

In [16]:
# Keep only the rows whose entries sum to a positive value, i.e. drop rows
# that are entirely zero. `np.where(...)[0]` gives the indices of the rows
# to keep.
data = data[np.where(data.sum(axis=1) > 0)[0], :]

In [17]:
# Shape after removing all-zero rows; an unchanged row count means that no
# row was empty.
data.shape

(2246, 10473)

In [18]:
%%time
# Fit the model to the data and keep the result in `f`. This is the
# expensive step of the notebook: the recorded run took about 17 minutes of
# wall time for 100 iterations.
f = ncrp.fit(data = data)

Optimizing via coordinate ascent: 100%|██████████| 100/100 [17:43<00:00, 11.52s/it]

CPU times: user 6min 54s, sys: 10min 40s, total: 17min 34s
Wall time: 17min 43s





In [19]:
# Count how many times each path appears among the most likely paths
# returned by the fit.
Counter(f.get_most_likely_paths())

Counter({(0, 4): 12, (1,): 1192, (1, 4): 1042})

In [20]:
# For every node path, fetch the top 10 words. Each value is a
# (words, vocabulary indices, alpha values) triple, as the output shows.
top_words = f.get_top_words_per_node(k = 10, vocab = vocab)
top_words

OrderedDict([((),
              (['i',
                'people',
                'two',
                'police',
                'new',
                'years',
                'last',
                'state',
                'year',
                'million'],
               array([ 0,  3,  5, 12,  1, 10,  8, 13,  4,  6]),
               array([ 1247.24924033,   989.86455339,   894.76433169,   847.20637263,
                        829.21875603,   736.2024551 ,   601.93400923,   598.42339873,
                        578.78484991,   541.30220339]))),
             ((0,),
              (['index',
                'stock',
                'million',
                'points',
                'exchange',
                'yen',
                'shares',
                'share',
                'trading',
                'close'],
               array([336,  87,   6, 316, 246, 387, 439, 213, 215, 223]),
               array([ 21.09244052,  20.0933307 ,  20.09296462,  18.09359742,
             

In [31]:
# Testing to make sure top_words_per_node behaves as expected.
# Every entry of `top_words` is checked against `vocab` and against the
# fitted `f.alphaW_var`; the first inconsistency raises an AssertionError.
for path, (words, indices, alphas) in top_words.items():
    # Check that `words` and `indices` agree with `vocab`
    words_from_indices = [vocab[i] for i in indices]
    if words != words_from_indices:
        print(words)
        print(words_from_indices)
        raise AssertionError(
            "words and indices disagree with vocab at path {}".format(path))
    # Check that `alphas` and `indices` agree with `f.alphaW_var`
    path_index = f.path_to_index[path]
    alphas_from_indices = f.alphaW_var[path_index, indices]
    if not np.array_equal(alphas, alphas_from_indices):
        print(alphas)
        print(alphas_from_indices)
        raise AssertionError(
            "alphas disagree with f.alphaW_var at path {}".format(path))
    # Check that whenever alphaW_var > min(alphas), word_index is in indices
    min_alpha = min(alphas)
    # Built once per node so the membership test below does not rescan
    # `indices` for every word in the vocabulary.
    index_set = set(indices)
    for word_index, alpha in enumerate(f.alphaW_var[path_index, :]):
        if alpha > min_alpha and word_index not in index_set:
            # The message is formatted *before* printing; calling .format()
            # on the result of print() fails on Python 3 because print()
            # returns None.
            message = "At word index {}, {} > {} but {} is not in {}".format(
                word_index, alpha, min_alpha, word_index, indices)
            print(message)
            raise AssertionError(message)
    # Check that alphas are sorted
    if list(alphas) != sorted(alphas, reverse=True):
        message = "Alphas is not sorted in descending order: {}".format(alphas)
        print(message)
        raise AssertionError(message)
print("Done")

Done
