In [None]:
import argparse
import dill
import autograd.numpy as np
from autograd import grad 
import random
import types
import warnings
from copy import deepcopy
from tqdm.notebook import tqdm

In [None]:
# Seed BOTH random number generators used in this notebook: NumPy's global
# generator drives the parameter initialisation, while the standard-library
# `random` module drives the trajectory subsampling (random.sample).
# Seeding only NumPy leaves the selected trajectories different on every run.
np.random.seed(41)
random.seed(41)

In [None]:
# Load defined models and trajectories
'''
These are synthetic examples for an optimal stopping problem for
clinician diagnosis of disease
'''
# NOTE(security): dill.load can execute arbitrary code embedded in the file.
# Only load .obj files that come from a trusted source.
with open('data/model0.obj', 'rb') as f:
    model = dill.load(f)
with open('data/model0_trajs.obj', 'rb') as f:
    all_trajs = dill.load(f)

# Subsample after the file has been closed. The sample size is capped at the
# population size because random.sample raises ValueError when asked for more
# items than are available.
N_TRAJS = 100
trajs = random.sample(all_trajs, min(N_TRAJS, len(all_trajs)))

In [None]:
'''
Set the dimensions of the spaces
'''

# Short aliases for the model's space sizes (original notes give the values
# as 2, 3 and 3). Presumably states / actions / observations — confirm
# against the model class.
s_size, a_size, z_size = model.S, model.A, model.Z

In [None]:
# Random initialisation of params
# The draw order (b0, then O, then mu) is kept fixed so that the values
# obtained for a given NumPy seed do not change.
model.b0 = np.random.dirichlet(alpha=[1] * model.S)
model.O = np.random.dirichlet(alpha=[1] * model.Z, size=(model.A, model.S))
model.mu = np.random.normal(loc=.5, scale=.5, size=(model.A, model.S))

# T keeps its loaded values; it is only cast to float.
model.T = model.T.astype(float)

In [None]:
"""
Softmax functions to parameterise distributions so that we don't have to 
do constrained optimisations
"""


def softmax_O(x):
    """Softmax over the last axis of ``x``.

    Each slice ``x[..., :]`` is mapped to a positive vector that sums to one.
    The per-slice maximum is subtracted before exponentiating so that large
    entries do not overflow.

    ``keepdims=True`` replaces the former reshape to ``(a_size, s_size, 1)``,
    so the function no longer depends on module-level size variables and
    works for any leading shape.

    Args:
        x: array of unnormalised scores.

    Returns:
        Array of the same shape as ``x`` whose last axis sums to one.
    """
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))

    return e_x / np.sum(e_x, axis=-1, keepdims=True)

def softmax_b0(x):
    """Softmax of a score vector.

    The global maximum of ``x`` is subtracted before exponentiating, and the
    result is normalised by its sum along axis 0.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    normaliser = weights.sum(axis=0)
    return weights / normaliser

def softmax_T(x):
    """Softmax over the last axis of ``x``.

    Each slice ``x[..., :]`` is mapped to a positive vector that sums to one.
    The per-slice maximum is subtracted before exponentiating so that large
    entries do not overflow.

    ``keepdims=True`` replaces the former reshape to ``(s_size, a_size, 1)``,
    so the function no longer depends on module-level size variables and
    works for any leading shape.

    Args:
        x: array of unnormalised scores.

    Returns:
        Array of the same shape as ``x`` whose last axis sums to one.
    """
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))

    return e_x / np.sum(e_x, axis=-1, keepdims=True)

In [None]:
# Six-slot buffer of likelihood values, all initialised to -inf.
like = np.full(6, -np.inf)

In [None]:
'''
Define the mean-vector policy parameterisation - this is the original DIPOLE
policy, which we use to generate a warm start for our parameters
'''

def policy(mu,eta,b):
    """Action probabilities of the mean-vector policy.

    Each row of ``mu`` is scored by ``-eta * ||b - mu_row||^2`` and the scores
    are passed through a softmax, so rows closer to ``b`` get a higher
    probability.

    The largest score is subtracted before exponentiating. This leaves the
    result mathematically unchanged but avoids every term underflowing to
    zero (and the division producing NaN) when ``eta`` or the distances are
    large.

    Args:
        mu: array with one mean vector per row (last axis matches ``b``).
        eta: scalar sharpness of the distribution.
        b: belief vector.

    Returns:
        Array of probabilities, one per row of ``mu``, summing to one.
    """
    logits = -eta * np.sum((b - mu) ** 2, axis=-1)
    weights = np.exp(logits - np.max(logits))
    return weights / np.sum(weights)

In [None]:
'''
This is just to initialise some of the latent variables
'''            
    
# Forward-backward pass on the freshly initialised parameters. For every
# trajectory this fills in: alp (forward messages), bet (backward messages),
# gmm (per-step posteriors), xi (pairwise posteriors) and b (beliefs).
for traj in trajs:
    # Forward pass: alp[0] is a copy of the initial belief; alp[t+1] folds in
    # the action a[t] and the observation z[t].
    traj.alp = [np.zeros(model.b0.shape) for _ in range(traj.tau+1)]
    traj.alp[0] = model.b0.copy()
    for t in range(traj.tau):
        # `*` and `@` have equal precedence and associate left to right, so
        # this evaluates as (O_column * T^T) @ alp[t].
        traj.alp[t+1] = np.ravel(model.O[traj.a[t],:,traj.z[t],np.newaxis] * model.T[:,traj.a[t],:].T @ traj.alp[t][:,np.newaxis])
    # Backward pass: starts from all ones at the last step and walks back.
    traj.bet = [np.ones(model.b0.shape) for _ in range(traj.tau+1)]

    for t in reversed(range(traj.tau)):
        traj.bet[t] = np.ravel(model.T[:,traj.a[t],:] @ (model.O[traj.a[t],:,traj.z[t],np.newaxis] * traj.bet[t+1][:,np.newaxis]))
    # gmm[t] is alp[t] * bet[t], renormalised to sum to one.
    traj.gmm = [alp * bet for alp, bet in zip(traj.alp, traj.bet)]
    traj.gmm = [gmm / gmm.sum() for gmm in traj.gmm]
    traj.xi = [None] * traj.tau
        
    for t in range(traj.tau):
        # xi[t] combines alp[t], the transition slice for a[t], the
        # observation weights for z[t] and bet[t+1]; normalised to sum to one.
        traj.xi[t] = model.O[traj.a[t],:,traj.z[t],np.newaxis].T * model.T[:,traj.a[t],:] * (traj.alp[t][:,np.newaxis] @ traj.bet[t+1][:,np.newaxis].T)
        traj.xi[t] /= traj.xi[t].sum()
    # Belief at each step: the forward message normalised to sum to one.
    traj.b = [alp / alp.sum() for alp in traj.alp]

# Shift the stored values one slot to the right and start a new total in
# slot 0; like_a separately accumulates the policy term only.
like[1:] = like[:-1]
like[0] = 0
like_a = 0


def _clipped_log(p):
    """Logarithm with the argument clipped below at 1e-100."""
    return np.log(np.clip(p, 1e-100, None))


for traj in trajs:
    horizon = range(traj.tau)
    # Initial-belief term.
    like[0] += np.sum(traj.gmm[0] * _clipped_log(model.b0))
    # Transition terms.
    for t in horizon:
        like[0] += np.sum(traj.xi[t] * _clipped_log(model.T[:, traj.a[t], :]))
    # Observation terms.
    for t in horizon:
        like[0] += np.sum(traj.gmm[t+1] * _clipped_log(model.O[traj.a[t], :, traj.z[t]]))
    # Policy terms, added to both the total and the policy-only accumulator.
    for t in horizon:
        like[0] += _clipped_log(model.pi(traj.b[t])[traj.a[t]])
        like_a += _clipped_log(model.pi(traj.b[t])[traj.a[t]])



In [None]:
'''
Key likelihood function: this is what we want to optimise. It is defined here and optimised later.
'''

def likelihood(params,trajs):
    """Objective for the warm-start stage (mean-vector policy).

    The value is the sum over trajectories of
      * the initial-belief, transition and observation log terms, weighted by
        the posteriors ``traj.gmm`` / ``traj.xi`` already stored on each
        trajectory (they are read here, not recomputed), and
      * the log-probability the policy assigns to each taken action, evaluated
        at beliefs recomputed from ``params``.

    Args:
        params: sequence ``[b0, O, T, mu, eta]``. The first three entries are
            unconstrained scores that are passed through the softmax helpers;
            ``mu`` and ``eta`` are handed to ``policy`` as they are.
        trajs: iterable of trajectory objects providing ``a``, ``z``, ``tau``,
            ``gmm`` and ``xi``.

    Returns:
        Scalar objective value (larger is better).

    Side effects:
        Overwrites ``traj.alp1`` and ``traj.b`` on every trajectory.
    """
    b0  = softmax_b0(params[0])
    O   = softmax_O(params[1])
    T   = softmax_T(params[2])
    mu  = params[3]
    eta = params[4]

    # First we generate beliefs with a forward pass under the candidate parameters.
    for traj in trajs:
        traj.alp1 = [np.zeros(b0.shape) for _ in range(traj.tau+1)]
        traj.alp1[0] = b0
        for t in range(traj.tau):
            traj.alp1[t+1] = np.ravel(O[traj.a[t],:,traj.z[t],np.newaxis] * T[:,traj.a[t],:].T @ traj.alp1[t][:,np.newaxis])

        traj.b = [alp / alp.sum() for alp in traj.alp1]

    # Now calculate the likelihood.
    # The policy is evaluated once per step: the former `like_a` accumulator
    # was never used and only doubled the work done in the traced function.
    likes = 0
    for traj in trajs:
        likes += np.sum(traj.gmm[0] * np.log(b0))
        for t in range(traj.tau):
            likes += np.sum(traj.xi[t] * np.log(T[:,traj.a[t],:]))
        for t in range(traj.tau):
            likes += np.sum(traj.gmm[t+1] * np.log(O[traj.a[t],:,traj.z[t]]))
        for t in range(traj.tau):
            likes += np.log(policy(mu,eta,traj.b[t])[traj.a[t]])
    return likes

In [None]:
'''
Collect parameters to be optimised
'''
# Order matters: likelihood() unpacks the list by position.
par = [
    model.b0,
    model.O,
    model.T,
    model.mu,
    float(model.eta),
]

In [None]:
# Evaluate the objective once on the current parameters; the value is shown
# as the cell output.
likelihood(par,trajs)

In [None]:
'''
Main training loop for the warm start
'''


l_rate = 1e-3

liks = []
param_history = []

grad_p = grad(likelihood)


for itr in tqdm(range(1000)):

    par = [model.b0,model.O,model.T,model.mu,float(model.eta)]
    # Store a snapshot by value. The arrays inside `par` are updated in place
    # further down, so appending `par` itself would leave every history entry
    # pointing at the same (final) arrays.
    param_history.append(deepcopy(par))

    b0  = softmax_b0(model.b0)
    O   = softmax_O(model.O)
    T   = softmax_T(model.T)
    '''
    Forward-Backward algorithm to fix latent variables
    '''
    for traj in trajs:
        traj.alp = [np.zeros(model.b0.shape) for _ in range(traj.tau+1)]
        traj.alp[0] = b0.copy()
        for t in range(traj.tau):
            traj.alp[t+1] = np.ravel(O[traj.a[t],:,traj.z[t],np.newaxis] * T[:,traj.a[t],:].T @ traj.alp[t][:,np.newaxis])

        traj.bet = [np.ones(model.b0.shape) for _ in range(traj.tau+1)]
        for t in reversed(range(traj.tau)):
            traj.bet[t] = np.ravel(T[:,traj.a[t],:] @ (O[traj.a[t],:,traj.z[t],np.newaxis] * traj.bet[t+1][:,np.newaxis]))
        traj.gmm = [alp * bet for alp, bet in zip(traj.alp, traj.bet)]
        traj.gmm = [gmm / gmm.sum() for gmm in traj.gmm]

        traj.xi = [None] * traj.tau
        for t in range(traj.tau):
            traj.xi[t] = O[traj.a[t],:,traj.z[t],np.newaxis].T * T[:,traj.a[t],:] * (traj.alp[t][:,np.newaxis] @ traj.bet[t+1][:,np.newaxis].T)
            traj.xi[t] /= traj.xi[t].sum()

        traj.b = [alp / alp.sum() for alp in traj.alp]
    '''
    Now call gradient and unpack
    '''
    grads = grad_p(par,trajs)

    # Gradient ASCENT: likelihood() returns a value to be maximised.
    model.b0  += l_rate * grads[0]

    model.O   += l_rate * grads[1]

    model.T   += l_rate * grads[2]
    model.mu  += l_rate * grads[3]
    model.eta += l_rate * grads[4]

    # Re-collect the parameters before reporting. `eta` is stored in `par` as
    # a plain float, so the old list would pair the updated arrays with the
    # previous eta.
    par = [model.b0,model.O,model.T,model.mu,float(model.eta)]
    lik = likelihood(par,trajs)
    print(lik)
    liks.append(lik)

In [None]:
import matplotlib.pyplot as plt
# Plot the objective recorded after every warm-start iteration. Labels and a
# title are set so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(liks)
ax.set(title='Warm-start training', xlabel='Iteration', ylabel='likelihood() value')
plt.show()

In [None]:
# Show the parameter list left over from the training loop as the cell output.
par

In [None]:
# Save the model object, as it stands after the warm start, to disk with dill.
with open('warm_model.obj', 'wb') as model_file:
    dill.dump(model, model_file)

In [None]:
'''
Now generate belief-action pairs so we can train the soft tree on them
first as a warm-start
'''

belief_list = []
action_list = []
for traj in trajs:
    # Copies are taken so the training arrays do not share data with the trajectories.
    action_list += deepcopy(traj.a)
    # The last belief is dropped: beliefs are paired with the action taken
    # from them, and the final entry has no action.
    belief_list += deepcopy(traj.b)[:-1]

beliefs = np.array(belief_list)
action_list = np.array(action_list)

# Sizes are derived from the data instead of the former hard-coded 370 pairs
# and 3 actions, which only held for one particular sample of trajectories.
n_pairs = len(action_list)
beliefs_1 = beliefs[:,0].reshape(n_pairs,1)

# One-hot encode the actions: actions[i, k] is 1.0 when action_list[i] == k.
actions = np.zeros((n_pairs,a_size))
for k in range(a_size):
    actions[:,k] = action_list == k

In [None]:
from soft_tree_model import soft_tree

In [None]:
# Soft tree of depth 3 with input size s_size and output size a_size.
policy_tree = soft_tree(
    tree_depth=3,
    xdim=s_size,
    ydim=a_size,
)

In [None]:
# Shrink every tree parameter to a tenth of its current value.
# Note: re-running this cell scales the parameters again.
for idx, tree_param in enumerate(policy_tree.params):
    policy_tree.params[idx] = tree_param * 0.1

In [None]:
'''
Now get warm start parameters
'''
# Fit the soft tree to the belief / one-hot action pairs built earlier.
# NOTE(review): the positional 100 is presumably the number of training
# iterations — confirm against soft_tree.train.
policy_tree.train(beliefs,actions,100,l_rate = 1e-2)

In [None]:
def likelihood_tree(params,trajs):
    """Negated objective for the InterPoLe stage (soft-tree policy).

    Same structure as the warm-start objective, except that the action
    probabilities come from ``policy_tree.forward`` and the result is negated
    so that it can be minimised.

    Args:
        params: sequence ``[b0, O, T, tree_params]``. The first three entries
            are unconstrained scores passed through the softmax helpers;
            ``tree_params`` is handed to ``policy_tree.forward``.
        trajs: iterable of trajectory objects providing ``a``, ``z``, ``tau``,
            ``gmm`` and ``xi`` (the posteriors are read, not recomputed).

    Returns:
        Scalar: minus the summed log terms (smaller is better).

    Side effects:
        Overwrites ``traj.alp1`` and ``traj.b`` on every trajectory.
    """
    b0  = softmax_b0(params[0])
    O   = softmax_O(params[1])
    T   = softmax_T(params[2])

    tree_p = params[3]

    # Forward pass: beliefs under the candidate parameters.
    for traj in trajs:
        traj.alp1 = [np.zeros(b0.shape) for _ in range(traj.tau+1)]
        traj.alp1[0] = b0
        for t in range(traj.tau):
            traj.alp1[t+1] = np.ravel(O[traj.a[t],:,traj.z[t],np.newaxis] * T[:,traj.a[t],:].T @ traj.alp1[t][:,np.newaxis])

        traj.b = [alp / alp.sum() for alp in traj.alp1]

    # The tree is evaluated once per step: the former `like_a` accumulator was
    # never used and only doubled the number of forward passes.
    likes = 0
    for traj in trajs:
        likes += np.sum(traj.gmm[0] * np.log(b0))
        for t in range(traj.tau):
            likes += np.sum(traj.xi[t] * np.log(T[:,traj.a[t],:]))
        for t in range(traj.tau):
            likes += np.sum(traj.gmm[t+1] * np.log(O[traj.a[t],:,traj.z[t]]))
        for t in range(traj.tau):
            # reshape(-1) flattens the single-input output for any number of
            # actions (previously hard-coded to 3).
            action_probs = policy_tree.forward([traj.b[t]],tree_p).reshape(-1)
            likes += np.log(action_probs[traj.a[t]])
    return -likes

In [None]:
# Order matters: likelihood_tree() unpacks the list by position.
par = [
    model.b0,
    model.O,
    model.T,
    policy_tree.params,
]

In [None]:
'''
Main training loop for InterPoLe
'''


l_rate = 1e-3

liks = []
param_history = []

grad_p = grad(likelihood_tree)


for itr in tqdm(range(100)):

    par = [model.b0,model.O,model.T,policy_tree.params]
    # Store a snapshot by value. The arrays inside `par` are updated in place
    # further down, so appending `par` itself would leave every history entry
    # pointing at the same (final) arrays.
    param_history.append(deepcopy(par))

    b0  = softmax_b0(model.b0)
    O   = softmax_O(model.O)
    T   = softmax_T(model.T)

    # Forward-Backward algorithm to fix the latent variables for this iteration.
    for traj in trajs:
        traj.alp = [np.zeros(model.b0.shape) for _ in range(traj.tau+1)]
        traj.alp[0] = b0.copy()
        for t in range(traj.tau):
            traj.alp[t+1] = np.ravel(O[traj.a[t],:,traj.z[t],np.newaxis] * T[:,traj.a[t],:].T @ traj.alp[t][:,np.newaxis])

        traj.bet = [np.ones(model.b0.shape) for _ in range(traj.tau+1)]
        for t in reversed(range(traj.tau)):
            traj.bet[t] = np.ravel(T[:,traj.a[t],:] @ (O[traj.a[t],:,traj.z[t],np.newaxis] * traj.bet[t+1][:,np.newaxis]))
        traj.gmm = [alp * bet for alp, bet in zip(traj.alp, traj.bet)]
        traj.gmm = [gmm / gmm.sum() for gmm in traj.gmm]

        traj.xi = [None] * traj.tau
        for t in range(traj.tau):
            traj.xi[t] = O[traj.a[t],:,traj.z[t],np.newaxis].T * T[:,traj.a[t],:] * (traj.alp[t][:,np.newaxis] @ traj.bet[t+1][:,np.newaxis].T)
            traj.xi[t] /= traj.xi[t].sum()

        traj.b = [alp / alp.sum() for alp in traj.alp]

    grads = grad_p(par,trajs)

    # Gradient DESCENT: likelihood_tree() returns a negated objective.
    model.b0  -= l_rate * grads[0]

    model.O   -= l_rate * grads[1]

    model.T   -= l_rate * grads[2]

    policy_tree.params = policy_tree.update_step(policy_tree.params,grads[3],l_rate)

    # Re-collect the parameters before reporting. `policy_tree.params` has just
    # been rebound, so the old list may still hold the previous tree parameters.
    par = [model.b0,model.O,model.T,policy_tree.params]
    lik = likelihood_tree(par,trajs)
    print(lik)
    liks.append(lik)

In [None]:
# Save the model and the policy tree to separate dill files.
for out_path, payload in (('interpole_model.obj', model), ('interpole_tree.obj', policy_tree)):
    with open(out_path, 'wb') as out_file:
        dill.dump(payload, out_file)