In [2]:
import torch
# Anomaly detection makes autograd report the forward operation behind a
# failing backward pass. It is a debugging aid and slows execution, so it can
# be switched off once the notebook runs cleanly.
torch.autograd.set_detect_anomaly(True)

# Fixed seed: the notebook samples random action plans, user goals and
# observation noise, so without it two runs do not produce the same result.
seed = 0
torch.manual_seed(seed)

# --- Optimisation settings -------------------------------------------------
# NOTE(review): the optimisation cell in this notebook hard-codes its own
# iteration caps and learning rates; confirm whether these four values are
# meant to drive it.
belief_updt_n_epochs = 4
belief_updt_lr = 0.1
action_slc_n_epochs = 5
action_slc_lr = 0.1

# --- Environment -----------------------------------------------------------
n_target = 2                            # number of moving targets
mvt_amplitude = 3                       # length of one target movement per step
max_coord = torch.tensor([300., 600.])  # upper bound of each coordinate

# --- Planning --------------------------------------------------------------
n_rollout = 3           # number of simulated rollouts per gradient step
n_step_per_rollout = 3  # number of simulated steps in each rollout
user_sigma = 5          # std of the Gaussian noise on the simulated user action
decay_factor = 0.9      # per-step discount applied to the rollout loss

# --- Initial state ---------------------------------------------------------
# Targets start at 25% and 75% of the coordinate range.
x = torch.tensor([[0.25, 0.25], [0.75, 0.75]]) * max_coord
# Belief logits (passed through a softmax downstream); equal values = uniform.
b = torch.ones(n_target) / n_target

# Unconstrained action parameter, one entry per target; it is squashed with a
# sigmoid and scaled to degrees before being used as a movement angle.
a = torch.nn.Parameter(torch.zeros(n_target))

In [7]:
def update_environment(positions, action, amplitude=None, upper_bound=None):
    """Move every target by a fixed-length step along its own angle.

    The step for target ``i`` is ``amplitude * (cos(angle_i), sin(angle_i))``
    with ``angle_i = action[i]`` in degrees. For angles in [0, 360] this is the
    same direction the previous tan-based construction produced; unlike it,
    this form
      * stays fully differentiable with respect to ``action`` (building the
        direction with ``torch.tensor([...])`` detached the y component), and
      * is periodic, so angles outside [0, 360] are not mirrored.

    Parameters
    ----------
    positions : torch.Tensor
        Target coordinates, one row per target, two columns. Updated IN PLACE.
    action : torch.Tensor
        One angle in degrees per row of ``positions``.
    amplitude : float, optional
        Step length. Defaults to the module-level ``mvt_amplitude``.
    upper_bound : torch.Tensor, optional
        Per-coordinate maximum. Defaults to the module-level ``max_coord``.

    Returns
    -------
    torch.Tensor
        ``positions`` itself, after the move and the upper clamp.
    """
    if amplitude is None:
        amplitude = mvt_amplitude
    if upper_bound is None:
        upper_bound = max_coord

    angles = torch.deg2rad(action)
    direction = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    # Build the result out of place, then copy it back: the caller-visible
    # in-place update is kept without saving `positions` in the autograd graph.
    moved = positions + amplitude * direction

    # Only the upper bound is enforced.
    # NOTE(review): coordinates can still become negative -- confirm whether a
    # lower bound is wanted as well.
    clamped = torch.minimum(moved, upper_bound)

    positions.copy_(clamped)
    return positions
                

def logp_action(positions, action, prev_positions, sigma=None):
    """Log-likelihood of an observed user action under each candidate target.

    For every target, the displacement ``positions - prev_positions`` is used
    as the mean of an independent Gaussian per coordinate (standard deviation
    ``sigma``); the per-coordinate log-densities of ``action`` are summed.

    The computation is vectorised, so it works for any number of targets and
    involves no in-place writes into the result tensor.

    Parameters
    ----------
    positions : torch.Tensor
        Current coordinates, one row per target.
    action : torch.Tensor
        Observed action, one value per coordinate (broadcast over targets).
    prev_positions : torch.Tensor
        Coordinates before the move, same shape as ``positions``.
    sigma : float, optional
        Gaussian standard deviation. Defaults to the module-level
        ``user_sigma``.

    Returns
    -------
    torch.Tensor
        One log-probability per target.
    """
    if sigma is None:
        sigma = user_sigma

    displacement = positions - prev_positions
    per_coord = torch.distributions.Normal(displacement, sigma).log_prob(action)
    return per_coord.sum(dim=-1)


def sim_act(positions, prev_positions, goal):
    """Simulate a noisy user action that follows the goal target.

    The action is the displacement of the target indexed by ``goal`` between
    ``prev_positions`` and ``positions``, plus two-component Gaussian noise
    scaled by the module-level ``user_sigma``.
    """
    goal_displacement = positions[goal] - prev_positions[goal]
    return goal_displacement + torch.randn(2) * user_sigma



# ---------------------------------------------------------------------------
# Action selection: gradient descent on `a` to minimise the loss averaged
# over simulated rollouts.
# ---------------------------------------------------------------------------

# NOTE(review): iteration caps and learning rates are hard-coded here; the
# `action_slc_*` / `belief_updt_*` settings defined with the other constants
# are not read -- confirm which values are intended.
max_action_iters = 213
max_belief_iters = 132

a = torch.nn.Parameter(a)
a_opt = torch.optim.Adam([a, ], lr=0.01)

for _ in range(max_action_iters):

    old_a = a.clone()
    a_opt.zero_grad()

    # -----------------------------------
    # Build action plans
    # -----------------------------------

    # The first step of every rollout is the action being optimised; the
    # remaining steps are random. Everything is built as a fraction of a full
    # turn in [0, 1] and converted to degrees exactly once below (the random
    # steps used to be scaled by 360 twice).
    first_action = torch.sigmoid(a)

    action_plan = torch.zeros((n_rollout, n_step_per_rollout, n_target))
    action_plan[:, 0] = first_action
    if n_step_per_rollout > 1:
        action_plan[:, 1:, :] = torch.rand((n_rollout, n_step_per_rollout - 1, n_target))

    action_plan *= 360  # Convert in degrees

    total_efe = 0
    for rol in range(n_rollout):

        efe_rollout = 0

        # Sample the user goal from the current belief ----------------------

        q = torch.softmax(b - b.max(), dim=0)
        goal = torch.multinomial(q, 1)[0]

        # Each rollout starts from the real positions and belief ------------

        x_rol = x.clone()
        b_rol = b.clone()

        for step in range(n_step_per_rollout):

            action = action_plan[rol, step]

            # ---- Update positions based on action -------------------------

            x_rol_prev = x_rol.clone()

            x_rol = update_environment(positions=x_rol, action=action)

            # ---- Simulate the user action for the sampled goal ------------

            y = sim_act(positions=x_rol, goal=goal, prev_positions=x_rol_prev)

            # ---- Log probability of that action under every goal ----------

            logp_y = logp_action(positions=x_rol, action=y, prev_positions=x_rol_prev)

            # The belief carried over from the previous step acts as a fixed
            # prior; only `logp_y` links this quantity to `a`.
            b_prior = b_rol.detach()
            logq = torch.log_softmax(b_prior - b_prior.max(), dim=0)
            logp_yq = logq + logp_y

            # ---- Revise belief --------------------------------------------

            b_rol = torch.nn.Parameter(b_prior.clone())
            b_opt = torch.optim.Adam([b_rol, ], lr=0.01)

            # Fit against a detached target: back-propagating the fitting
            # loss through `logp_yq` would add one spurious contribution to
            # `a.grad` per inner iteration.
            fit_target = logp_yq.detach()

            for _inner in range(max_belief_iters):

                old_b = b_rol.detach().clone()
                b_opt.zero_grad()
                q_fit = torch.softmax(b_rol - b_rol.detach().max(), dim=0)
                fit_loss = torch.sum(q_fit * (q_fit.log() - fit_target))
                fit_loss.backward()
                b_opt.step()

                if torch.isclose(old_b, b_rol).all():
                    break

            # Evaluate the fitted belief once more, this time keeping the
            # graph through `logp_yq` so that the loss can reach `a`.
            b_fit = b_rol.detach()
            q_rol = torch.softmax(b_fit - b_fit.max(), dim=0)
            kl_div = torch.sum(q_rol * (q_rol.log() - logp_yq))

            epistemic_value = kl_div

            # --------------------------------------
            # Compute extrinsic value
            # --------------------------------------

            extrinsic_value = (q_rol * q_rol.log()).sum()  # minus entropy

            # --------------------------------------
            # Compute loss
            # ---------------------------------------
            efe_step = - epistemic_value - extrinsic_value
            efe_rollout += decay_factor ** step * efe_step

        total_efe += efe_rollout

    total_efe /= n_rollout
    total_efe.backward()
    a_opt.step()

    # Stop as soon as an optimiser step no longer moves `a`.
    if torch.isclose(old_a, a).all():
        break