In [1]:
import os
import subprocess
import sys
import autograd.numpy as np
import itertools
import json
from tqdm import tqdm
from IPython.display import display, HTML
import time
import copy
from collections import namedtuple
from functools import partial
import matplotlib.pyplot as plt
top_level_dir = '/'.join( os.getcwd().split( '/' )[ :-2 ] )
if top_level_dir not in sys.path:
    sys.path.append( top_level_dir )
%load_ext autoreload
%autoreload 2
display(HTML("<style>.container { width:100% !important; }</style>"))

# This notebook will look at the convergence rate for different kinds of graphs
\*Even though gibbs isn't the best choice for getting a good marginal value, will include the plot of the marginal values that are sampled anyway

In [2]:
from GenModels.GM.Distributions import Categorical, Dirichlet, TensorTransition, TensorTransitionDirichletPrior
from GenModels.GM.Models.DiscreteGraphModels import *
from GenModels.GM.States.GraphicalMessagePassing import *

In [3]:
graphs = []
for _ in range( 50 ):
    line = DataGraph()
    for i in range( 1, 30 ):
        line.addEdge( parents=[ i-1 ], children=[ i ] )
    graphs.append( ( line, np.array( [] ) ) )

In [4]:
graphs = [ ( graph1(), np.array( [] ) ),
           ( graph2(), np.array( [] ) ),
           ( graph3(), np.array( [] ) ),
           ( graph4(), np.array( [] ) ),
           ( graph5(), np.array( [] ) ),
           ( graph6(), np.array( [] ) ),
           ( graph7(), np.array( [] ) ),
           cycleGraph1(),
           cycleGraph2(),
           cycleGraph3(),
           cycleGraph7(),
           cycleGraph8(),
           cycleGraph10(),
           cycleGraph11(),
           cycleGraph12() ]
graphs = graphs*3

In [5]:
total_nodes = sum( [ len( g.nodes ) for g, _ in graphs ] )
print( total_nodes )

348


In [6]:
initial_shape, transition_shapes, emission_shape = GHMM.parameterShapes( graphs, d_latent=3, d_obs=4 )
initial_priors = np.ones( initial_shape )
transition_priors = [ np.ones( shape ) for shape in transition_shapes ]
emission_prior = np.ones( emission_shape )

In [7]:
true_model = GHMM( priors=( initial_priors, transition_priors, emission_prior ), method='EM' )
for i, ( graph, fbs ) in enumerate( graphs ):
    true_model.setGraphs( [ ( graph, fbs ) ] )
    _, data = true_model.sampleStates()
    graph.setNodeData( data.keys(), data.values() )

In [8]:
# print( true_model.params.initial_dist.pi )
# print( [ dist.pi for dist in true_model.params.transition_dists ] )
# print( true_model.params.emission_dist.pi )

# Test 1 - Deep graphs without cycles

In [9]:
# em_model    = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='EM' )
# gibbs_model = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='Gibbs' )
# cavi_model  = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='CAVI' )
# svi_model   = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='SVI', step_size=0.1, minibatch_size=1 )

In [10]:
# values = []
# it = np.arange( 10 )
# for _ in it:
#     em_marginal    = em_model.fitStep()
#     gibbs_marginal = gibbs_model.fitStep( return_marginal=True )
#     elbo_cavi      = cavi_model.fitStep()
#     elbo_svi       = svi_model.fitStep()
#     values.append( [ em_marginal, gibbs_marginal, elbo_cavi, elbo_svi ] )

#     print( values[ -1 ] )

In [11]:
# y1, y2, y3, y4 = zip( *values )

In [12]:
# plt.plot( it[ :10 ], y1[ :10 ], color='red', label='em' )
# plt.plot( it[ :10 ], y2[ :10 ], color='blue', label='gibbs' )
# plt.plot( it[ :10 ], y3[ :10 ], color='green', label='elbo' )
# plt.plot( it[ :10 ], y4[ :10 ], color='purple', label='svi' )
# plt.show()

In [13]:
svae = GSVAE( graphs[ :3 ], priors=( initial_priors, transition_priors ), d_obs=4 )

In [14]:
# losses = svae.fit()

In [15]:
# assert 0

In [16]:
# plt.plot( np.arange( len( losses ) ), np.array( [ l._value for l in losses ] ) )

In [17]:
y = np.random.choice( 5, 5 )

In [18]:
one_hot = np.zeros( ( y.shape[ 0 ], 5 ) )
one_hot[ np.arange( 5 ), y ] = 1

In [19]:
one_hot

array([[0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.]])

In [20]:
y[ None ] + one_hot

array([[4., 2., 4., 1., 3.],
       [4., 2., 5., 1., 2.],
       [4., 2., 4., 1., 3.],
       [4., 3., 4., 1., 2.],
       [4., 2., 5., 1., 2.]])

In [21]:
groups = [ 0, 1, 2 ]
d_latents = dict( zip( groups, [ 2, 3, 4 ] ) )
d_obs = 4

In [22]:
def graphToGroupGraph( graphs, dataPerNode, groupPerNode, with_fbs=False, random_latent_states=False, d_latents=None ):
    assert isinstance( graphs, list )
    group_graphs = []
    for graph in graphs:

        if( with_fbs ):
            if( not isinstance( graph, Graph ) ):
                graph, fbs = graph
            else:
                graph, fbs = graph, np.array( [] )

        data = [ ( node, dataPerNode( node ) ) for node in graph.nodes ]
        group = [ ( node, groupPerNode( node ) ) for node in graph.nodes ]
        group_graph = GroupGraph.fromGraph( graph, data, group )

        if( random_latent_states ):
            assert d_latents is not None
            for node in group_graph.nodes:
                group = group_graph.groups[ node ]
                possible_latent_states = np.array( list( set( np.random.choice( np.arange( d_latents[ group ] ), d_latents[ group ] - 1 ).tolist() ) ) )
                group_graph.setPossibleLatentStates( node, possible_latent_states )

        if( with_fbs ):
            group_graphs.append( ( group_graph, fbs ) )
        else:
            group_graphs.append( group_graph )
    return group_graphs

def dataPerNode( node ):
    return Categorical.generate( D=d_obs, size=1 )
def groupPerNode( node ):
    return Categorical.generate( D=len( groups ) )

group_graphs = graphToGroupGraph( graphs, dataPerNode, groupPerNode, with_fbs=True )

In [23]:
shapes = GroupGHMM.parameterShapes( group_graphs, d_latents={ 0:2, 1:3, 2:4 }, d_obs=4, groups=[ 0, 1, 2 ] )
initial_shapes, transition_shapes, emission_shapes = shapes

In [24]:
initial_priors = dict( ( group, np.ones( shape ) ) for group, shape in initial_shapes.items() )
transition_priors = dict( ( group, [ np.ones( shape ) for shape in shapes ] ) for group, shapes in transition_shapes.items() )
emission_prior = dict( ( group, np.ones( shape ) ) for group, shape in emission_shapes.items() )

In [25]:
# true_model = GroupGHMM( priors=( initial_priors, transition_priors, emission_prior ), method='EM' )

In [26]:
# for i, ( graph, fbs ) in enumerate( group_graphs ):
#     true_model.setGraphs( [ ( graph, fbs ) ] )
#     _, data = true_model.sampleStates()
#     graph.setNodeData( data.keys(), data.values() )

In [27]:
# em_model    = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='EM' )
# gibbs_model = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='Gibbs' )
# cavi_model  = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='CAVI' )
# svi_model   = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='SVI', step_size=0.1, minibatch_size=1 )

In [28]:
# cavi_model.fitStep()

In [29]:
# cavi_model.fitStep()

In [30]:
group_svae = GroupGSVAE( graphs=group_graphs[ :3 ], priors=( initial_priors, transition_priors ), d_obs=d_obs )

In [34]:
group_svae.fit()

  return lambda g: g[idxs]


i 0 loss Autograd ArrayBox with value 666.5478645538861


  return f_raw(*args, **kwargs)
  defvjp(anp.tanh,   lambda ans, x : lambda g: g / anp.cosh(x) **2)


i 25 loss Autograd ArrayBox with value 729.7726157804339
i 50 loss Autograd ArrayBox with value 708.2332843905117
i 75 loss Autograd ArrayBox with value 625.0144292477235
Done!


[<autograd.numpy.numpy_boxes.ArrayBox at 0x7fab883ba7c8>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6f014a88>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6eda37c8>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6eb31408>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6e83e148>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6e635dc8>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6e34fbc8>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6e0dd908>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6de69648>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6dbf8348>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6d906088>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6d687d88>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6d416ac8>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6d1a5808>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6cf35548>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6cc43288>,
 <autograd.numpy.numpy_boxes.ArrayBox at 0x7fab6c9c5f88>,
 <autograd.num