In [3]:
import torch

In [6]:
# ---- Configuration -----------------------------------------------------------
seed = 0                     # RNG seed: the rollouts draw random goals, noise and headings
torch.manual_seed(seed)

# Action selection (optimisation of the action parameter)
action_slc_n_epochs = 100    # max optimisation steps
action_slc_lr = 0.1          # Adam learning rate
# Belief revision (inner fit of the revised belief per simulated observation)
belief_updt_n_epochs = 100   # max Adam steps per revision
belief_updt_lr = 0.1         # Adam learning rate
# Environment
n_target = 2                 # number of targets
mvt_amplitude = 3            # distance a target moves in one step
max_coord = torch.tensor([300., 600.])  # per-axis upper bound on target coordinates
# Rollouts
n_rollout = 1                # simulated rollouts averaged per action evaluation
n_step_per_rollout = 1       # steps per rollout (the first uses the candidate action)
n_spl_epist_val = 5          # sampled goals per epistemic-value estimate
user_sigma = 5               # std of the Gaussian noise on simulated user movement
decay_factor = 0.9           # per-step discount applied to each step's contribution

In [13]:
# Initial (x, y) position of each target, given as fractions of `max_coord`.
# The fractions are hard-coded for two targets, so fail loudly if `n_target`
# is changed in the config cell without updating them.
x = torch.tensor([[0.25, 0.25], [0.75, 0.75]]) * max_coord
assert x.shape == (n_target, 2), "expected one (x, y) row per target"
x

tensor([[ 75., 150.],
        [225., 450.]])

In [16]:
# Initial belief logits over the targets. All entries are equal, so the belief
# is uniform once softmax-normalised.
b = torch.full((n_target,), 1.0)
b

tensor([1., 1.])

In [22]:
# Debugging aid: make autograd report the forward op behind a failing backward.
# Adds noticeable overhead to every backward pass; switch off for long runs.
torch.autograd.set_detect_anomaly(mode=True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x12b220a30>

In [28]:
# Action selection: optimise headings to maximise the expected epistemic value
# (i.e. minimise `total_efe = -epistemic_value`). For each candidate action the
# targets are moved, user movements are simulated for goals sampled from the
# current belief, the belief is revised towards the resulting posterior, and
# the revision's divergence is averaged into the epistemic value.


def _move_targets(positions, angles_deg, amplitude, upper):
    """Move each target by `amplitude` along its heading, clamped to [0, upper].

    positions: (n_target, 2) coordinates; angles_deg: (n_target,) headings in
    degrees (0° = +x, 90° = +y). cos/sin keeps the step differentiable w.r.t.
    the angle and well defined at 90°/270°, where a tan-based form blows up.
    """
    theta = torch.deg2rad(angles_deg)
    step_vec = amplitude * torch.stack((torch.cos(theta), torch.sin(theta)), dim=-1)
    moved = positions + step_vec
    return torch.minimum(torch.maximum(moved, torch.zeros_like(upper)), upper)


def _user_log_likelihood(observed, expected, sigma):
    """Return log p(observed | goal = k) for every target k.

    `expected` is (n_target, 2): the displacement of each target. Each
    coordinate of the observation carries independent Normal(., sigma) noise.
    """
    return torch.distributions.Normal(expected, sigma).log_prob(observed).sum(dim=-1)


def _revise_beliefs(b_init, log_target, n_epochs, lr):
    """Fit logits b' minimising KL(softmax(b') || exp(log_target)); return them detached.

    `log_target` is detached so this inner fit only updates its own logits and
    never back-propagates into the action parameter. Stops early once an Adam
    step leaves the logits numerically unchanged.
    """
    b_fit = torch.nn.Parameter(b_init.detach().clone())
    belief_opt = torch.optim.Adam([b_fit], lr=lr)
    log_target = log_target.detach()
    for _ in range(n_epochs):
        b_before = b_fit.detach().clone()
        belief_opt.zero_grad()
        log_q_fit = torch.log_softmax(b_fit, dim=0)  # stable; softmax is shift-invariant
        loss = torch.sum(log_q_fit.exp() * (log_q_fit - log_target))
        loss.backward()
        belief_opt.step()
        if torch.isclose(b_before, b_fit.detach()).all():
            break
    return b_fit.detach()


a = torch.nn.Parameter(torch.zeros(n_target))  # heading of target i = sigmoid(a[i]) * 360°
opt = torch.optim.Adam([a, ], lr=action_slc_lr)

for epoch in range(action_slc_n_epochs):

    old_a = a.detach().clone()
    opt.zero_grad()

    action = torch.sigmoid(a) * 360
    total_efe = 0
    for _ in range(n_rollout):
        efe_rollout = 0

        x_rollout = x.clone()
        b_rollout = b.clone()

        # First step uses the action being optimised; later steps use random headings.
        action_plan = [action] + [
            torch.rand(n_target) * 360 for _ in range(n_step_per_rollout - 1)
        ]

        # `t` (not `step`) so the discount below uses the rollout step index.
        for t, angles in enumerate(action_plan):

            # ---- Update positions based on action -------------------------------
            x_rollout_previous = x_rollout
            x_rollout = _move_targets(x_rollout, angles, mvt_amplitude, max_coord)
            deltas = x_rollout - x_rollout_previous  # (n_target, 2) displacement per target

            # ---- Evaluate epistemic value ---------------------------------------
            logq = torch.log_softmax(b_rollout, dim=0)
            # Sample goals from the assistant's current belief.
            goals = torch.multinomial(logq.exp(), n_spl_epist_val, replacement=True)

            b_prime_spl = torch.zeros((n_spl_epist_val, n_target))
            epistemic_value = 0
            for i, goal in enumerate(goals):
                # Simulated (noisy) user movement if `goal` were the intended target.
                y = deltas[goal] + torch.randn(2) * user_sigma
                logp_yq = logq + _user_log_likelihood(y, deltas, user_sigma)

                # Revise beliefs (own optimiser, so the outer `opt` is untouched).
                b_prime = _revise_beliefs(b_rollout, logp_yq, belief_updt_n_epochs, belief_updt_lr)
                b_prime_spl[i] = b_prime

                # The fitted belief is held fixed; the gradient reaches `a` only
                # through the likelihood term in `logp_yq`.
                # NOTE(review): logp_yq is unnormalised, so this equals
                # KL(q' || posterior) - log p(y); confirm this (rather than
                # KL(q' || q)) is the intended epistemic value.
                log_q_prime = torch.log_softmax(b_prime, dim=0)
                kl_div = torch.sum(log_q_prime.exp() * (log_q_prime - logp_yq))
                epistemic_value = epistemic_value + kl_div

            epistemic_value = epistemic_value / n_spl_epist_val
            b_rollout = b_prime_spl[torch.randint(n_spl_epist_val, (1, ))].squeeze(0)

            efe_step = - epistemic_value  # Need to add intrinsic value
            efe_rollout = efe_rollout + decay_factor ** t * efe_step

        total_efe = total_efe + efe_rollout

    total_efe = total_efe / n_rollout
    total_efe.backward()
    opt.step()

    if torch.isclose(old_a, a.detach()).all():
        break