In [1]:
import torch
import torch.nn as nn
import numpy as np
import scipy.stats

In [2]:
class CRP:
    """Sequential sampler for a Chinese restaurant process.

    Each call to `sample` seats one new customer. An existing table is chosen
    with probability proportional to the number of customers already seated
    there; a new table is opened with probability proportional to `alpha`.

    Attributes:
        alpha: weight given to opening a new table.
        assignement_list: assignement_list[i] is the number of customers at table i.
        n_customers: total number of customers seated so far.
        n_tables: number of tables opened so far.
    """

    def __init__(self, alpha):
        self.alpha = alpha
        self.assignement_list = []
        self.n_customers = 0
        self.n_tables = 0

    def sample(self):
        """Seat one more customer and return the index of the chosen table.

        The very first customer is always seated at table 0 without drawing a
        random number; every later customer is seated with `np.random.choice`.
        """
        if len(self.assignement_list) == 0:
            # Nobody is seated yet: open table 0 deterministically.
            self.assignement_list.append(1)
            self.n_customers = 1
            self.n_tables = 1
            return 0

        # Unnormalised weights: the occupancy of every table, then alpha for a new one.
        weights = np.array(self.assignement_list + [self.alpha], dtype='float32')
        prob_vec = weights / (self.n_customers + self.alpha)
        table = np.random.choice(self.n_tables + 1, p=prob_vec)
        self.n_customers += 1
        if table < self.n_tables:
            # Joined an existing table.
            self.assignement_list[table] += 1
        else:
            # The last index stands for "open a new table".
            self.n_tables += 1
            self.assignement_list.append(1)
        return table
                

In [3]:
def sample_option(crp, policies, policy_prior, termination_distrs,
                  termination_distr_prior, dim_states, dim_actions):
    """Draw an option index from `crp`, creating the option's parameters if it is new.

    When the draw opened a new table (i.e. `policies` holds fewer entries than
    `crp.n_tables`), one policy and one termination distribution per state are
    sampled from the priors and appended, in place, to `policies` and
    `termination_distrs`.

    Args:
        crp: object exposing `sample()` and `n_tables`.
        policies: list with one entry per option; each entry is a list of
            per-state action distributions. Mutated in place.
        policy_prior: frozen distribution whose `rvs()[0]` is a probability
            vector over the `dim_actions` actions.
        termination_distrs: list with one entry per option; each entry is a
            list of per-state Bernoulli distributions. Mutated in place.
        termination_distr_prior: frozen distribution whose `rvs()` is used as
            a Bernoulli success probability.
        dim_states: number of discrete states.
        dim_actions: number of discrete actions.

    Returns:
        The option index returned by `crp.sample()`.
    """
    option = crp.sample()
    if len(policies) >= crp.n_tables:
        # Every existing table already has its parameters.
        return option

    action_support = range(dim_actions)
    state_policies, state_terminations = [], []
    # The two prior draws stay interleaved per state so that the order in which
    # random numbers are consumed is unchanged.
    for _ in range(dim_states):
        action_probs = policy_prior.rvs()[0]
        state_policies.append(
            scipy.stats.rv_discrete(values=(action_support, action_probs)))
        stop_prob = termination_distr_prior.rvs()
        state_terminations.append(scipy.stats.bernoulli(p=stop_prob))
    policies.append(state_policies)
    termination_distrs.append(state_terminations)
    return option

def generate_trajectories(N, T, dim_states, dim_actions, alpha, seed=None):
    """Sample synthetic trajectories whose options are generated by a CRP.

    Transition distributions are drawn once per (state, action) pair from a
    flat Dirichlet. Each trajectory starts in a uniformly drawn state with a
    freshly sampled option; at every later step the current option's
    termination distribution for the new state decides whether the option is
    kept or a new one is sampled through `sample_option`.

    Args:
        N: number of trajectories.
        T: number of time steps per trajectory.
        dim_states: number of discrete states.
        dim_actions: number of discrete actions.
        alpha: parameter passed to `CRP`.
        seed: optional seed. When given, NumPy's global RNG is re-seeded with
            it before anything is sampled, which makes the result replicable.
            Note that this changes global RNG state for the caller as well.
            When None (default) the current global RNG state is used as-is.

    Returns:
        Tuple `(states, actions, options, terminations, n_options)` where the
        first three are int64 arrays of shape [N, T], `terminations` is a
        boolean array of shape [N, T] (its first column is always True), and
        `n_options` is `crp.n_tables` after sampling.
    """
    if seed is not None:
        # CRP draws through np.random.choice and the scipy distributions are
        # built without an explicit random_state, so seeding the global
        # generator is what makes a run replicable.
        np.random.seed(seed)

    states = np.zeros([N, T], dtype='int64')
    actions = np.zeros([N, T], dtype='int64')
    options = np.zeros([N, T], dtype='int64')
    # The builtin `bool` is used because the 'bool8' alias is deprecated and
    # no longer available in recent NumPy releases.
    terminations = np.zeros([N, T], dtype=bool)
    crp = CRP(alpha)
    s0_prior = scipy.stats.rv_discrete(values=(range(dim_states), np.ones(dim_states) / dim_states))
    policy_prior = scipy.stats.dirichlet(np.ones(dim_actions))  # same prior is used for all options and states
    termination_distr_prior = scipy.stats.uniform()
    policies = []
    termination_distrs = []

    # transitions[s][a] is the distribution of the next state given (s, a).
    transition_prior = scipy.stats.dirichlet(np.ones(dim_states))
    transitions = []
    for s in range(dim_states):
        transitions_s = []
        for a in range(dim_actions):
            transition_probs = transition_prior.rvs()
            transitions_s.append(scipy.stats.rv_discrete(values=(range(dim_states), transition_probs[0])))
        transitions.append(transitions_s)

    for i in range(N):
        states[i, 0] = s0_prior.rvs()
        terminations[i, 0] = True  # every trajectory starts by picking an option
        options[i, 0] = sample_option(crp, policies, policy_prior, termination_distrs,
                                      termination_distr_prior, dim_states, dim_actions)
        for t in range(T-1):
            actions[i, t] = policies[options[i, t]][states[i, t]].rvs()
            states[i, t+1] = transitions[states[i, t]][actions[i, t]].rvs()
            # Termination is decided by the *current* option in the *new* state.
            terminations[i, t+1] = termination_distrs[options[i, t]][states[i, t+1]].rvs()
            if terminations[i, t+1] == 0:
                options[i, t+1] = options[i, t]
            else:
                options[i, t+1] = sample_option(crp, policies, policy_prior, termination_distrs,
                                                termination_distr_prior, dim_states, dim_actions)
        # The last state still gets an action, but no further transition is sampled.
        actions[i, -1] = policies[options[i, -1]][states[i, -1]].rvs()

    return states, actions, options, terminations, crp.n_tables

In [4]:
class SmallVarianceOptimizer:
    """Gradient-based optimiser of a relaxed option/termination objective.

    Given observed state and action trajectories, the optimiser jointly learns
      * per-step relaxed option assignments (softmax over K options),
      * per-step relaxed termination indicators (sigmoid),
      * a termination function per (option, state), and
      * a policy per (state, option),
    by running Adam on the negative of the objective built in
    `compute_objective`.

    Args:
        states: integer array of shape [N, T] with state indices.
        dim_states: number of discrete states.
        actions: integer array of shape [N, T] with action indices.
        dim_actions: number of discrete actions.
        K: number of options to fit.
        lambda_vec: sequence of four weights, used for
            [0] the penalty on K,
            [1] the termination term,
            [2] the option-switching term,
            [3] the negative-entropy term of the relaxed variables.
        steps: number of gradient steps performed by `train`.
        lr: Adam learning rate; the weight decay is set to `0.1 * lr`.
        device: torch device on which all tensors are created.
        clip: maximum gradient norm used for gradient clipping.
    """

    def __init__(self, states, dim_states, actions, dim_actions, K, lambda_vec, steps,
                 lr=0.1, device='cuda', clip=5.):
        self.device = device
        self.states = torch.tensor(states).to(self.device)  # [N, T]
        self.actions = torch.tensor(actions).to(self.device)  # [N, T]
        self.K = K
        self.lambda_vec = lambda_vec
        self.lr = lr
        self.N, self.T = states.shape
        self.dim_actions = dim_actions
        self.steps = steps
        # The shapes are taken from the data (self.N, self.T) rather than from
        # module-level N / T, so the class does not depend on notebook globals.
        self.relaxed_options_logits = torch.randn(
            (self.N, self.T, K), requires_grad=True, device=self.device)
        self.relaxed_terminations_logits = torch.randn(
            (self.N, self.T), requires_grad=True, device=self.device)
        self.termination_fn_logits = torch.randn(
            (K, dim_states), requires_grad=True, device=self.device)
        self.policies_logits = torch.randn(
            (dim_states, K, dim_actions), requires_grad=True, device=self.device)
        self.parameter_list = [self.relaxed_options_logits, self.relaxed_terminations_logits,
                               self.termination_fn_logits, self.policies_logits]
        self.optimizer = torch.optim.Adam(self.parameter_list, lr=lr, weight_decay=0.1*lr)
        self.clip = clip

    def compute_objective(self, discretize=False):
        """Return the negative objective divided by N * T (a scalar tensor).

        Args:
            discretize: when True the relaxed terminations are thresholded at
                0.5 and the relaxed options are replaced by the one-hot
                encoding of their argmax; the switching and entropy terms are
                then left out of the objective.
        """
        relaxed_terminations = torch.sigmoid(self.relaxed_terminations_logits)  # [N, T]
        relaxed_options = nn.functional.softmax(self.relaxed_options_logits, dim=-1)  # [N, T, K]
        if discretize:  # in this case relaxed_terminations and relaxed_options are not relaxed
            relaxed_terminations = torch.sign(relaxed_terminations - 0.5) * 0.5 + 0.5
            relaxed_options = nn.functional.one_hot(torch.argmax(relaxed_options, dim=-1), self.K)
        termination_fn = torch.sigmoid(self.termination_fn_logits)  # [K, dim_states]
        policies = nn.functional.softmax(self.policies_logits, dim=2)  # [dim_states, K, dim_actions]

        # term1: constant penalty on the number of options.
        term1 = -self.lambda_vec[0] * self.K

        # term2: termination term. The termination function is looked up at
        # the states visited from step 1 on and mixed with the option
        # assignment of the previous step.
        flat_states = torch.reshape(self.states[:, 1:], (-1,))
        flat_term_fns_trajs = torch.index_select(input=termination_fn, dim=1, index=flat_states)
        term_fns_trajs = torch.transpose(torch.transpose(torch.reshape(
            flat_term_fns_trajs, (self.K, self.N, self.T-1)), 0, 2), 0, 1)  # [N,T-1,K]
        relaxed_term_probs = torch.sum(relaxed_options[:, :-1] * term_fns_trajs, dim=2)  # [N, T-1]
        # Binary cross-entropy between the predicted termination probability
        # and the relaxed termination indicator, written out explicitly.
        relaxed_bce = (-relaxed_terminations[:, 1:] * torch.log(relaxed_term_probs)
                       + (relaxed_terminations[:, 1:] - 1.) * torch.log(1. - relaxed_term_probs))
        term2 = -self.lambda_vec[1] * (relaxed_bce +
                                       torch.log(torch.maximum(relaxed_term_probs, 1. - relaxed_term_probs)))

        # term4: probability of the observed action under the option-mixed
        # policy, minus the largest probability that policy gives any action.
        flat_states_all = torch.reshape(self.states, (-1,))
        flat_policies_traj_states = torch.index_select(input=policies, dim=0, index=flat_states_all)
        policies_traj_states = torch.reshape(flat_policies_traj_states,
                                             (self.N, self.T, self.K, self.dim_actions))
        relaxed_policy_vals = torch.sum(torch.unsqueeze(relaxed_options, 3) * policies_traj_states, dim=2)
        # relaxed_policy vals has shape [N, T, dim_actions]
        # Only the gathered axis is squeezed, so N == 1 or T == 1 keep their axes.
        relaxed_policy_at_actions = torch.squeeze(torch.gather(
            input=relaxed_policy_vals, dim=2, index=torch.unsqueeze(self.actions, 2)), dim=2)  # [N, T]
        term4 = relaxed_policy_at_actions - torch.max(relaxed_policy_vals, dim=2)[0]

        objective = term1 + torch.sum(term2) + torch.sum(term4)
        if not discretize:
            # term3: penalise a change of option between consecutive steps,
            # weighted by how strongly the step is believed not to terminate.
            term3 = -(1. - relaxed_terminations[:, 1:]) * torch.linalg.norm(
                relaxed_options[:, 1:] - relaxed_options[:, :-1], ord=2, dim=2)
            # term5: negative entropies of the relaxed terminations and options.
            term5 = (relaxed_terminations * torch.log(relaxed_terminations)
                     + (1. - relaxed_terminations) * torch.log(1. - relaxed_terminations))
            term5 += torch.sum(relaxed_options * torch.log(relaxed_options), dim=2)
            objective += self.lambda_vec[2] * torch.sum(term3) + self.lambda_vec[3] * torch.sum(term5)
        return -objective / (self.N * self.T)

    def _gradient_step(self):
        """Run one clipped Adam step and return the loss as a NumPy scalar array."""
        self.optimizer.zero_grad()
        negative_objective = self.compute_objective()
        negative_objective.backward()
        nn.utils.clip_grad_norm_(self.parameter_list, self.clip)
        self.optimizer.step()
        return negative_objective.detach().cpu().numpy()

    def train(self, verbose=True):
        """Run `self.steps` gradient steps, then print the discretised loss.

        Args:
            verbose: when True the relaxed loss is printed after every step.
                The final discretised loss is printed regardless.
        """
        for i in range(self.steps):
            negative_objective_np = self._gradient_step()
            if verbose:
                print(f'Finished epoch {i}\twith loss: {negative_objective_np:f}\t')
        # Pure evaluation: no autograd graph is needed here.
        with torch.no_grad():
            discrete_negative_objective = self.compute_objective(discretize=True).item()
        print(f'Finished training with discrete loss: {discrete_negative_objective:f}\t')
        

In [5]:
N = 100  # number of trajectories
T = 4  # number of steps per trajectory (corresponds to T-1 in the manuscript, last state is not sampled)
dim_states = 3  # discrete state space, this is the number of states
dim_actions = 2  # discrete action space, this is the number of actions
alpha = 1.  # CRP parameter: weight of opening a new table (i.e. creating a new option)

# Seed NumPy's global RNG so that "Restart & Run All" regenerates the same data set;
# both the CRP and the scipy.stats distributions draw from this global generator.
np.random.seed(0)
states, actions, options, terminations, true_K = generate_trajectories(N, T, dim_states, dim_actions, alpha)
print(true_K)

10


In [6]:
K = 3  # number of options fitted by the optimiser
# Weights of the objective terms, in order: penalty on K, termination term,
# option-switching term, negative-entropy term.
lambda_vec = (N * T / 40., 20, 5 / np.sqrt(K), 10.)
steps = 150

# Seed torch so that the random initialisation of the logits is reproducible.
torch.manual_seed(0)
# Use the GPU when one is present, otherwise fall back to the CPU instead of failing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
svo = SmallVarianceOptimizer(states, dim_states, actions, dim_actions, K, lambda_vec, steps,
                             device=device)
svo.train()

Finished epoch 0	with loss: 19.176922	
Finished epoch 1	with loss: 18.681585	
Finished epoch 2	with loss: 18.245411	
Finished epoch 3	with loss: 17.908915	
Finished epoch 4	with loss: 17.712057	
Finished epoch 5	with loss: 17.587275	
Finished epoch 6	with loss: 17.480705	
Finished epoch 7	with loss: 17.467039	
Finished epoch 8	with loss: 17.468140	
Finished epoch 9	with loss: 17.508423	
Finished epoch 10	with loss: 17.530903	
Finished epoch 11	with loss: 17.592415	
Finished epoch 12	with loss: 17.572828	
Finished epoch 13	with loss: 17.567417	
Finished epoch 14	with loss: 17.642328	
Finished epoch 15	with loss: 17.695250	
Finished epoch 16	with loss: 17.720999	
Finished epoch 17	with loss: 17.736547	
Finished epoch 18	with loss: 17.726431	
Finished epoch 19	with loss: 17.767076	
Finished epoch 20	with loss: 17.812025	
Finished epoch 21	with loss: 17.801205	
Finished epoch 22	with loss: 17.801735	
Finished epoch 23	with loss: 17.806765	
Finished epoch 24	with loss: 17.868113	
Finished e