In [None]:
import os
import subprocess
import sys
import autograd.numpy as np
import itertools
import json
from tqdm import tqdm
from IPython.display import display, HTML
import time
import copy
from collections import namedtuple
from functools import partial
import matplotlib.pyplot as plt
# Make the repository root (two directories above this notebook's working directory)
# importable. os.path.dirname is used instead of splitting on '/' so the path is also
# computed correctly with Windows path separators (splitting on '/' there produced '',
# which silently appended the current directory to sys.path instead).
top_level_dir = os.path.dirname( os.path.dirname( os.getcwd() ) )
if top_level_dir not in sys.path:
    sys.path.append( top_level_dir )
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Widen the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# This notebook looks at the convergence rate of inference on different kinds of graphs
\*Even though Gibbs sampling isn't the best way to get a good marginal value, the plot of the marginal values it samples is included anyway.

In [None]:
from GenModels.GM.Distributions import Categorical, Dirichlet, TensorTransition, TensorTransitionDirichletPrior
from GenModels.GM.Models.DiscreteGraphModels import *
from GenModels.GM.States.GraphicalMessagePassing import *

In [None]:
# Fifty 30-node chain graphs (0 -> 1 -> ... -> 29), each paired with an empty `fbs`
# array (presumably the feedback set — confirm).
# NOTE(review): the next cell reassigns `graphs`, so these chains are discarded unless
# that cell is skipped.
def _lineGraph( n_nodes ):
    """Return a DataGraph chain in which node k-1 is the only parent of node k."""
    chain = DataGraph()
    for child in range( 1, n_nodes ):
        chain.addEdge( parents=[ child - 1 ], children=[ child ] )
    return chain

graphs = [ ( _lineGraph( 30 ), np.array( [] ) ) for _ in range( 50 ) ]

In [None]:
# `graphs * 3` would only repeat references to the *same* graph objects. The
# data-sampling cell below would then call setNodeData on each object three times
# instead of giving every copy its own sample, so the dataset would hold three aliases
# of each graph. Build independent graph objects for every repeat instead.
N_GRAPH_REPEATS = 3

def _baseGraphs():
    """Return freshly constructed (graph, fbs) entries for the test graphs.

    graph1..graph7 are paired with an empty fbs array here; the cycleGraph* helpers are
    assumed to return their own (graph, fbs) pair (entries are unpacked that way later).
    """
    return [ ( graph1(), np.array( [] ) ),
             ( graph2(), np.array( [] ) ),
             ( graph3(), np.array( [] ) ),
             ( graph4(), np.array( [] ) ),
             ( graph5(), np.array( [] ) ),
             ( graph6(), np.array( [] ) ),
             ( graph7(), np.array( [] ) ),
             cycleGraph1(),
             cycleGraph2(),
             cycleGraph3(),
             cycleGraph7(),
             cycleGraph8(),
             cycleGraph10(),
             cycleGraph11(),
             cycleGraph12() ]

# Same ordering as `base * N_GRAPH_REPEATS`, but every entry is a distinct object.
graphs = [ entry for _ in range( N_GRAPH_REPEATS ) for entry in _baseGraphs() ]

In [None]:
# Total node count across every graph in the dataset (repeats included).
total_nodes = sum( len( graph.nodes ) for graph, _ in graphs )
print( total_nodes )

In [None]:
# Parameter shapes for the single-group GHMM (d_latent=3, d_obs=4). Every prior is
# filled with ones (presumably flat Dirichlet concentrations — confirm).
initial_shape, transition_shapes, emission_shape = GHMM.parameterShapes( graphs, d_latent=3, d_obs=4 )

initial_priors    = np.ones( initial_shape )
transition_priors = list( map( np.ones, transition_shapes ) )
emission_prior    = np.ones( emission_shape )

In [None]:
# Build a "true" GHMM from the priors above, then, one graph at a time, sample from it
# and attach the sampled per-node data to that graph. The first value returned by
# sampleStates is discarded (presumably the latent states — confirm).
true_model = GHMM( priors=( initial_priors, transition_priors, emission_prior ), method='EM' )
for graph, fbs in graphs:
    true_model.setGraphs( [ ( graph, fbs ) ] )
    _, data = true_model.sampleStates()
    graph.setNodeData( data.keys(), data.values() )

In [None]:
# print( true_model.params.initial_dist.pi )
# print( [ dist.pi for dist in true_model.params.transition_dists ] )
# print( true_model.params.emission_dist.pi )

# Test 1 - Deep graphs without cycles

In [None]:
# em_model    = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='EM' )
# gibbs_model = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='Gibbs' )
# cavi_model  = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='CAVI' )
# svi_model   = GHMM( graphs, priors=( initial_priors, transition_priors, emission_prior ), method='SVI', step_size=0.1, minibatch_size=1 )

In [None]:
# values = []
# it = np.arange( 10 )
# for _ in it:
#     em_marginal    = em_model.fitStep()
#     gibbs_marginal = gibbs_model.fitStep( return_marginal=True )
#     elbo_cavi      = cavi_model.fitStep()
#     elbo_svi       = svi_model.fitStep()
#     values.append( [ em_marginal, gibbs_marginal, elbo_cavi, elbo_svi ] )

#     print( values[ -1 ] )

In [None]:
# y1, y2, y3, y4 = zip( *values )

In [None]:
# plt.plot( it[ :10 ], y1[ :10 ], color='red', label='em' )
# plt.plot( it[ :10 ], y2[ :10 ], color='blue', label='gibbs' )
# plt.plot( it[ :10 ], y3[ :10 ], color='green', label='elbo' )
# plt.plot( it[ :10 ], y4[ :10 ], color='purple', label='svi' )
# plt.show()

In [None]:
# GSVAE (presumably a graph structured VAE) on the first three graphs only; d_obs=4
# matches the GHMM parameter shapes above.
svae = GSVAE(
    graphs[ :3 ],
    priors=( initial_priors, transition_priors ),
    d_obs=4,
)

In [None]:
# losses = svae.fit()

In [None]:
# assert 0

In [None]:
# plt.plot( np.arange( len( losses ) ), np.array( [ l._value for l in losses ] ) )

In [None]:
# Scratch: five labels drawn uniformly, with replacement, from {0, ..., 4}.
y = np.random.choice( 5, size=5 )

In [None]:
# One-hot encode `y`: row k gets a single 1 in column y[k]. The row indices follow
# y.shape[ 0 ] (instead of a hard-coded 5) so the encoding stays valid if y's length
# changes. n_classes must cover every label value; y above is drawn from 5 classes.
n_classes = 5
one_hot = np.zeros( ( y.shape[ 0 ], n_classes ) )
one_hot[ np.arange( y.shape[ 0 ] ), y ] = 1

In [None]:
# Display the one-hot matrix built in the previous cell.
one_hot

In [None]:
# Scratch broadcasting check: with the 5-element y drawn above, y[ None ] has shape
# (1, 5) and is added to every row of the (5, 5) one_hot matrix.
y[ None ] + one_hot

In [None]:
# Group-model configuration: group ids, number of latent states per group, and the
# observation dimension shared by all groups.
groups = [ 0, 1, 2 ]
d_latents = { group: d_latent for group, d_latent in zip( groups, [ 2, 3, 4 ] ) }
d_obs = 4

In [None]:
def graphToGroupGraph( graphs, dataPerNode, groupPerNode, with_fbs=False, random_latent_states=False, d_latents=None ):
    """Convert plain graphs into GroupGraphs with per-node data and group labels.

    Args:
        graphs: list of graphs. When ``with_fbs`` is True, each entry may be either a
            ``Graph`` or a ``(graph, fbs)`` pair; a bare ``Graph`` gets an empty ``fbs``.
        dataPerNode: callable ``node -> data`` used to fill each node's data.
        groupPerNode: callable ``node -> group`` used to assign each node's group.
        with_fbs: if True, return ``(group_graph, fbs)`` pairs instead of bare GroupGraphs.
        random_latent_states: if True, restrict every node to a random, non-empty subset
            of its group's latent states via ``setPossibleLatentStates``.
        d_latents: mapping ``group -> number of latent states``; required when
            ``random_latent_states`` is True.

    Returns:
        list of GroupGraphs, or of ``(group_graph, fbs)`` pairs when ``with_fbs`` is True.

    Raises:
        AssertionError: if ``graphs`` is not a list, or ``random_latent_states`` is set
            without ``d_latents``.
    """
    assert isinstance( graphs, list )
    group_graphs = []
    for graph in graphs:

        if( with_fbs ):
            if( not isinstance( graph, Graph ) ):
                graph, fbs = graph
            else:
                graph, fbs = graph, np.array( [] )

        node_data = [ ( node, dataPerNode( node ) ) for node in graph.nodes ]
        node_groups = [ ( node, groupPerNode( node ) ) for node in graph.nodes ]
        group_graph = GroupGraph.fromGraph( graph, node_data, node_groups )

        if( random_latent_states ):
            assert d_latents is not None
            for node in group_graph.nodes:
                d_latent = d_latents[ group_graph.groups[ node ] ]
                # Draw d_latent - 1 states with replacement and keep the distinct ones:
                # a random non-empty subset that (for d_latent > 1) omits at least one
                # state. A single-state group has nothing to drop, so draw one sample
                # instead of zero (zero draws would leave the node with no states).
                n_draws = max( d_latent - 1, 1 )
                draws = np.random.choice( np.arange( d_latent ), n_draws ).tolist()
                possible_latent_states = np.array( list( set( draws ) ) )
                group_graph.setPossibleLatentStates( node, possible_latent_states )

        if( with_fbs ):
            group_graphs.append( ( group_graph, fbs ) )
        else:
            group_graphs.append( group_graph )
    return group_graphs

def dataPerNode( node ):
    # `node` is ignored: every node gets an independent Categorical.generate draw
    # (presumably one symbol out of d_obs — confirm).
    observation = Categorical.generate( D=d_obs, size=1 )
    return observation

def groupPerNode( node ):
    # `node` is ignored: every node gets an independent draw over len( groups )
    # categories, used as its group id.
    group_index = Categorical.generate( D=len( groups ) )
    return group_index

group_graphs = graphToGroupGraph( graphs, dataPerNode=dataPerNode, groupPerNode=groupPerNode, with_fbs=True )

In [None]:
# Parameter shapes for the group model. The configuration variables (groups, d_latents,
# d_obs) are passed instead of re-typing the same literals, so this cell cannot drift
# out of sync with the configuration cell.
shapes = GroupGHMM.parameterShapes( group_graphs, d_latents=d_latents, d_obs=d_obs, groups=groups )
initial_shapes, transition_shapes, emission_shapes = shapes

In [None]:
# All-ones priors for every group, keyed by group id; each group's transition prior is
# a list with one array per transition shape.
# NOTE: these names replace the single-model GHMM priors defined earlier in the notebook.
initial_priors = { group: np.ones( shape ) for group, shape in initial_shapes.items() }
transition_priors = { group: [ np.ones( shape ) for shape in group_shapes ]
                      for group, group_shapes in transition_shapes.items() }
emission_prior = { group: np.ones( shape ) for group, shape in emission_shapes.items() }

In [None]:
# true_model = GroupGHMM( priors=( initial_priors, transition_priors, emission_prior ), method='EM' )

In [None]:
# for i, ( graph, fbs ) in enumerate( group_graphs ):
#     true_model.setGraphs( [ ( graph, fbs ) ] )
#     _, data = true_model.sampleStates()
#     graph.setNodeData( data.keys(), data.values() )

In [None]:
# em_model    = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='EM' )
# gibbs_model = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='Gibbs' )
# cavi_model  = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='CAVI' )
# svi_model   = GroupGHMM( group_graphs, priors=( initial_priors, transition_priors, emission_prior ), method='SVI', step_size=0.1, minibatch_size=1 )

In [None]:
# cavi_model.fitStep()

In [None]:
# cavi_model.fitStep()

In [None]:
# Group SVAE on the first three group graphs, using the per-group initial and
# transition priors from the cell above.
group_svae = GroupGSVAE(
    graphs=group_graphs[ :3 ],
    priors=( initial_priors, transition_priors ),
    d_obs=d_obs,
)

In [None]:
# Fit the group SVAE; whatever fit() returns is shown as the cell output.
group_svae.fit()

In [16]:
def logsumexp( v, axis=0 ):
    """Numerically stable ``log(sum(exp(v), axis=axis))``.

    Args:
        v: array-like of log-values.
        axis: axis to reduce over (default 0); ``None`` reduces over all elements.

    Returns:
        ``log(sum(exp(v)))`` along ``axis``, with that axis removed.

    The shift is the maximum along ``axis`` (kept as a broadcastable dimension) rather
    than the global maximum. A single global maximum lets any slice whose values sit far
    below it underflow to ``exp(...) == 0``, collapsing that slice's result to ``-inf``.
    Non-finite maxima (e.g. a slice that is entirely ``-inf``) are replaced by 0 so that
    ``v - max_v`` never evaluates ``-inf - -inf = nan``.
    """
    max_v = np.max( v, axis=axis, keepdims=True )
    max_v = np.where( np.isfinite( max_v ), max_v, 0.0 )
    summed = np.sum( np.exp( v - max_v ), axis=axis )
    return np.log( summed ) + np.squeeze( max_v, axis=axis )


In [28]:
from autograd import grad
import autograd.numpy as np
from autograd.misc.optimizers import adam

In [60]:
# Fixed weights (summing to 1) for the toy objective below.
# NOTE: this rebinds `y`, replacing the scratch labels drawn earlier in the notebook.
y = np.array( [ 0.7, 0.1, 0.1, 0.1 ] )

def blah( x, i ):
    """Toy objective: minus the y-weighted sum of sin(0.5 + x)**2; prints every value.

    ``i`` is unused; the ``(x, i)`` signature presumably matches the autograd optimizer
    convention (e.g. ``adam``) imported above — confirm.
    """
    weighted_terms = y * np.sin( 0.5 + x )**2
    ans = -np.sum( weighted_terms )
    print( ans )
    return ans

In [64]:
# Scratch check: np.max accepts a plain tuple and returns its largest element.
np.max( ( 1.0, 0.0, 2.0 ) )

2.0

In [65]:
# Scratch check: with a single int argument, np.random.choice returns one value from
# range(10). No seed is set, so the output changes between runs.
np.random.choice( 10 )

4

In [67]:
def normalize( x, axis=0 ):
    """Normalize log-values so that ``exp(result)`` sums to 1 along ``axis``.

    Args:
        x: array of unnormalized log-values.
        axis: axis to normalize over (default 0, the previous fixed behaviour);
            ``None`` normalizes over all elements.

    Returns:
        ``x - logsumexp(x, axis)``, broadcast back to the shape of ``x``.
    """
    if axis is None:
        return x - logsumexp( x, axis=None )
    # logsumexp drops the reduced axis; re-insert it so the subtraction broadcasts along
    # `axis` for any axis (for axis=0 the values are identical to the old version).
    return x - np.expand_dims( logsumexp( x, axis=axis ), axis )

In [69]:
# Sanity check: exp(normalize(x)) is a softmax of x, so the printed values sum to 1.
x = np.array( [ 1, 3 ] )
print( np.exp( normalize( x ) ) )

[0.11920292 0.88079708]
