In [1]:
import torch

In [2]:
# Switch on autograd's anomaly-detection mode for the whole session (debugging aid).
# NOTE(review): per the PyTorch docs this mode adds overhead to every
# forward/backward pass -- presumably it should be turned off once the
# notebook is stable; confirm with the author.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1181194c0>

In [49]:
# -----------------------------------------------------------------------
# Gradient-based selection of the first action of a plan.
#
# `a` (one unconstrained value per target) is squashed by a sigmoid and
# scaled to an angle in degrees; every target is then moved by
# `mvt_amplitude` along its angle.  For each rollout a user goal is sampled,
# a noisy user movement `y` is simulated, a belief over targets is revised
# by a few Adam steps, and the quantity printed as "EFE" (presumably the
# expected free energy -- TODO confirm) is accumulated and minimised
# with respect to `a`.
# -----------------------------------------------------------------------

# ---- Hyper-parameters ---------------------------------------------------
belief_updt_n_epochs = 4      # Adam steps per belief revision
belief_updt_lr = 0.1
action_slc_n_epochs = 5       # Adam steps on the action parameter `a`
action_slc_lr = 0.1
n_target = 2
mvt_amplitude = 3             # length of one target movement
max_coord = torch.tensor([300., 600.])   # upper bound per coordinate
n_rollout = 3
n_step_per_rollout = 3
user_sigma = 5                # std of the Gaussian noise on the user movement
decay_factor = 0.9            # per-step discount applied to the EFE terms

# ---- Initial state ------------------------------------------------------
x = torch.tensor([[0.25, 0.25], [0.75, 0.75]]) * max_coord   # (n_target, 2) positions
b = torch.ones(n_target) / n_target   # belief; used as logits (softmax is applied below)

a = torch.nn.Parameter(torch.zeros(n_target))
a_opt = torch.optim.Adam([a, ], lr=action_slc_lr)

for epoch in range(action_slc_n_epochs):

    old_a = a.detach().clone()   # only needed by the disabled convergence check below
    a_opt.zero_grad()

    # -----------------------------------
    # Build action plans
    # -----------------------------------

    # First step of every rollout is the learnable action; later steps are
    # random.  Everything is kept in [0, 1) and converted to degrees exactly
    # once (the random part used to be multiplied by 360 twice).
    first_action = torch.sigmoid(a).expand(n_rollout, 1, n_target)
    random_actions = torch.rand((n_rollout, n_step_per_rollout - 1, n_target))
    action_plan = torch.cat([first_action, random_actions], dim=1) * 360  # Convert in degrees

    total_efe = 0
    for rol in range(n_rollout):

        print("Rollout", rol)

        efe_rollout = 0

        # Sample the user goal --------------------------

        q = torch.softmax(b, dim=0)
        goal = torch.multinomial(q, 1)[0]

        # -----------------------------------------------

        x_rol = x.clone()
        b_rol = b.clone()

        for step in range(n_step_per_rollout):

            print("STEP", step)

            action = action_plan[rol, step]   # (n_target,) angles in degrees

            # ---- Update positions based on action ---------------------------------------------

            x_rol_prev = x_rol

            # A move of length `mvt_amplitude` along `angle` is
            # amplitude * (cos, sin).  This is the same direction the former
            # tan()/normalisation code produced, but it stays differentiable
            # in both coordinates and is periodic in the angle.
            heading = torch.deg2rad(action)
            movement = mvt_amplitude * torch.stack(
                (torch.cos(heading), torch.sin(heading)), dim=1)   # (n_target, 2)

            # Cap every coordinate at its upper bound (broadcast over targets).
            x_rol = torch.min(x_rol_prev + movement, max_coord)

            # ------------------------------------------------------------------------------
            # Evaluate epistemic value -----------------------------------------------------
            # ------------------------------------------------------------------------------

            # Simulate action based on goal ----------------------------------------

            delta = x_rol[goal] - x_rol_prev[goal]
            noise = torch.randn(2) * user_sigma
            y = delta + noise

            print("y", y.detach().numpy())

            # Compute log probability of user action given a specific goal in mind -------
            # Independent Gaussian per coordinate, centred on each target's displacement.
            displacement = x_rol - x_rol_prev   # (n_target, 2)
            logp_y = torch.distributions.Normal(displacement, user_sigma).log_prob(y).sum(dim=1)

            # Prior over goals: the current belief, treated as a constant.
            b_prior = b_rol.detach()
            logq = torch.log_softmax(b_prior - b_prior.max(), dim=0)
            logp_yq = logq + logp_y

            # Revise belief -------------------

            print("b_rol", b_prior.numpy())

            # Fresh leaf tensor so that the belief optimiser owns its own parameter.
            b_rol = b_prior.clone().requires_grad_(True)
            b_opt = torch.optim.Adam([b_rol, ], lr=belief_updt_lr)

            for b_epoch in range(belief_updt_n_epochs):

                old_b = b_rol.detach().clone()
                b_opt.zero_grad()
                q_rol = torch.softmax(b_rol - b_rol.detach().max(), dim=0)
                print("q_rol", q_rol.detach().numpy())
                kl_div = torch.sum(q_rol * (q_rol.log() - logp_yq))
                # Differentiate w.r.t. the belief only: a plain .backward() here
                # would also accumulate into `a.grad` through `logp_yq` and
                # pollute the outer action update.  The graph is retained
                # because `kl_div` is reused in the EFE below.
                (b_rol.grad,) = torch.autograd.grad(kl_div, b_rol, retain_graph=True)
                b_opt.step()

                if torch.isclose(old_b, b_rol).all():
                    print("converged")
                    break

            print("b_rol after", b_rol.detach().numpy())

            # NOTE: `kl_div` and `q_rol` come from the last inner iteration, i.e.
            # they were evaluated just before the final optimiser step.
            epistemic_value = kl_div

            # --------------------------------------
            # Compute extrinsic value
            # --------------------------------------

            extrinsic_value = (q_rol * q_rol.log()).sum()  # minus entropy

            # --------------------------------------
            # Compute loss
            # ---------------------------------------
            print("extrinsic_value", extrinsic_value)
            print("epistemic_value", epistemic_value)
            efe_step = - epistemic_value - extrinsic_value

            efe_rollout += decay_factor**step * efe_step

        total_efe += efe_rollout

    total_efe = total_efe / n_rollout
    print("big update, EFE=", total_efe)
    total_efe.backward()
    a_opt.step()

    print("new a", a.detach().numpy())

#     if torch.isclose(old_a, a).all():
#         print("converged")
#         break


Rollout 0
STEP 0
y [-6.2871294 -6.0726624]
b_rol [0.5 0.5]
q_rol [0.5 0.5]
converged
b_rol after [0.5 0.5]
extrinsic_value tensor(-0.6931, grad_fn=<SumBackward0>)
epistemic_value tensor(6.0104, grad_fn=<SumBackward0>)
STEP 1
y [ 0.28643608 10.26171   ]
b_rol [0.5 0.5]
q_rol [0.5 0.5]
q_rol [0.549834 0.450166]
q_rol [0.59833586 0.40166414]
q_rol [0.6442675 0.3557325]
b_rol after [0.89200085 0.10799922]
extrinsic_value tensor(-0.6509, grad_fn=<SumBackward0>)
epistemic_value tensor(7.3182, grad_fn=<SumBackward0>)
STEP 2
y [ 1.2001015 -0.2814257]
b_rol [0.89200085 0.10799922]
q_rol [0.6865419  0.31345809]
q_rol [0.72790146 0.27209854]
q_rol [0.71215683 0.2878431 ]
q_rol [0.6902332 0.3097669]
b_rol after [0.8748524  0.12514672]
extrinsic_value tensor(-0.6189, grad_fn=<SumBackward0>)
epistemic_value tensor(5.1732, grad_fn=<SumBackward0>)
Rollout 0
STEP 0
y [-1.9905651  1.5861751]
b_rol [0.5 0.5]
q_rol [0.5 0.5]
converged
b_rol after [0.5 0.5]
extrinsic_value tensor(-0.6931, grad_fn=<SumBackw

q_rol [0.5883162 0.4116838]
q_rol [0.6020702  0.39792985]
b_rol after [0.6964829  0.30351797]
extrinsic_value tensor(-0.6722, grad_fn=<SumBackward0>)
epistemic_value tensor(6.5347, grad_fn=<SumBackward0>)
STEP 1
y [7.4852743 4.4156184]
b_rol [0.6964829  0.30351797]
q_rol [0.59699625 0.40300375]
q_rol [0.6440394 0.3559606]
q_rol [0.60929424 0.3907058 ]
q_rol [0.57457227 0.4254277 ]
b_rol after [0.6222673  0.37773266]
extrinsic_value tensor(-0.6820, grad_fn=<SumBackward0>)
epistemic_value tensor(6.2933, grad_fn=<SumBackward0>)
STEP 2
y [-5.094171 -5.329057]
b_rol [0.6222673  0.37773266]
q_rol [0.56083083 0.43916914]
q_rol [0.609339 0.390661]
q_rol [0.65376174 0.3462382 ]
q_rol [0.6909426  0.30905744]
b_rol after [0.9674953  0.03250469]
extrinsic_value tensor(-0.6183, grad_fn=<SumBackward0>)
epistemic_value tensor(6.8435, grad_fn=<SumBackward0>)
Rollout 0
STEP 0
y [-0.4623363  1.8480606]
b_rol [0.5 0.5]
q_rol [0.5 0.5]
q_rol [0.45016605 0.54983395]
q_rol [0.4741107 0.5258893]
q_rol [0.501

In [47]:
# Sanity check: Shannon entropy, in bits, of a uniform two-outcome distribution.
p = torch.ones(2) / 2
-torch.sum(p * torch.log2(p))

tensor(1.)

In [44]:
# Mean natural log-probability of `p` (which must already be defined by another cell).
# NOTE(review): for p = [0.5, 0.5] this evaluates to about -0.6931, which does
# not match the recorded output below -- presumably the cell was last run with
# a different `p`; re-run the notebook top to bottom to confirm.
torch.mean(torch.log(p))

tensor(-5.7565)