In [1]:
import numpy as np
from scipy.signal import savgol_filter
from scipy import linalg
import torch 

In [2]:
from config import config_categ
from utils.reproducibility import set_seed
from agents.DQN import DQNAgent
from agents.LaplaceDQN import LaplaceDQNAgent
from envs.four_state_mdp import Simple4MDP

import gym

In [3]:
game = 'CartPole-v1'
log_dir = 'logs/tmp'

env = gym.make(game, render_mode=None)
config = config_categ[game](env, log_dir)

In [4]:
config.num_gamma

5

In [5]:
print(config.BATCH_SIZE)
print(config.num_gamma)
print(config.num_sensitivities)
print(config.action_dim)
print(config.num_gamma_to_tau)

128
5
500
2
100


In [6]:
batch_size, num_actions, num_gamma_to_tau, num_sensitivities = 1, config.action_dim, config.num_gamma_to_tau, config.num_sensitivities
Q_gamma = torch.randn(size=(batch_size, num_actions, num_gamma_to_tau, num_sensitivities))
# Q_gamma = np.random.normal(size=(batch_size, num_actions, num_gamma_to_tau, num_sensitivities))
print(Q_gamma.shape)

torch.Size([1, 2, 100, 500])


In [9]:
from utils.Inverse_Laplace import SVD_approximation_inverse_Laplace, SVD_approximation_inverse_Laplace_iterative

In [8]:
tau_space = SVD_approximation_inverse_Laplace(config, Q_gamma,)

  gammas_to_tau_tensor = torch.tensor(gammas_to_tau)


In [12]:
tau_space_iter = SVD_approximation_inverse_Laplace_iterative(config, Q_gamma.numpy())

In [29]:
torch.allclose(tau_space_iter, tau_space, atol=1e-5)

True

In [30]:
torch.abs(tau_space_iter-tau_space).sum()

tensor(0.1269)

In [None]:
print(tau_space.shape)

torch.Size([1, 2, 200, 498])


In [165]:
# def SVD_approximation_inverse_Laplace(config, Q_gamma):
"""
SVD-based approximation of the inverse Laplace transform 
Input:
config: configuration object from config.py
Q_gamma: Q-values for different gamma values and sensitivities, 
            shape (batch_size, num_gamma, num_sensitivities)
"""

alpha_reg = config.alpha_reg
K = config.K
delta_t = config.delta_t

#batch_size = config.BATCH_SIZE # NOTE maybe not at the beginning of the training - then assume change in time horizon takes place after that
batch_size = Q_gamma.shape[0]
num_sensitivities = config.num_sensitivities
num_actions = config.action_dim
num_gamma_to_tau = config.num_gamma_to_tau
gamma_to_tau_min = config.gamma_to_tau_min
gamma_to_tau_max = config.gamma_to_tau_max
start = 1 / np.log(gamma_to_tau_min) 
end = 1 / np.log(gamma_to_tau_max)   
gammas_to_tau = torch.exp(torch.true_divide(1, torch.linspace(start, end, num_gamma_to_tau)))

assert Q_gamma.shape == (batch_size, num_actions, num_gamma_to_tau, num_sensitivities), "Q_gamma shape does not match (num_gamma, num_sensitivities)"
# finish extending to actions

In [166]:
#define matrix F:
F=torch.zeros((len(gammas_to_tau),K))
for i_g in range(0,len(gammas_to_tau)):
    for i_t in range(0,K):
        F[i_g,i_t]=gammas_to_tau[i_g]**(i_t*delta_t)

In [167]:
# MATRIX OPERATIONS

#define matrix F:
F_parallel = torch.zeros((len(gammas_to_tau), K))
gammas_to_tau_tensor = torch.tensor(gammas_to_tau)
delta_t_tensor = torch.tensor(delta_t)

# Perform matrix operations instead of loops
F_parallel = gammas_to_tau_tensor.unsqueeze(1).pow(torch.arange(K).float() * delta_t_tensor)

# F_numpy = F.numpy()  # Convert the tensor back to a numpy array if needed

  gammas_to_tau_tensor = torch.tensor(gammas_to_tau)


In [193]:
torch.allclose(F_parallel, F, atol=1e-21)

True

In [169]:
U_arr, lam_arr, V_arr = linalg.svd(F)

print(U_arr.shape)
print(lam_arr.shape)
print(V_arr.shape)
print(F.shape)

(100, 100)
(100,)
(200, 200)
torch.Size([100, 200])


In [170]:
# U, lam, V = linalg.svd(F) #SVD decomposition of F
U, lam, V = torch.linalg.svd(F) #SVD decomposition of F
print(U.shape)
print(lam.shape)
print(V.shape)

# set up gamma-space:
Z=Q_gamma[:,:,:,0:-2]-Q_gamma[:,:,:,1:-1]

# smooth gamma-space (it might not be necessary, it helps if the input is *very* noisy):
# for h in range(0,num_h-2):
#     Z[:,h]=savgol_filter(Z[:,h], 5, 1)

torch.Size([100, 100])
torch.Size([100])
torch.Size([200, 200])


In [171]:

# Linearly recover tau-space from eigenspace of F:
tau_space=torch.zeros((batch_size, num_actions, K, num_sensitivities-2))
# do for several batches and actions-> parallelize TODO
for batch in range(batch_size):
    for act in range(num_actions):
        for h in range(0,num_sensitivities-2):
            term=torch.zeros((1,K))
            for i in range(0,len(lam)):
                fi=lam[i]**2/(alpha_reg**2+lam[i]**2)
                new=fi*(((U[:,i]@Z[batch,act,:,h])*V[i,:] )/lam[i])
                term=term+new
            tau_space[batch,act,:,h]=term

In [172]:
print(V[i,:].shape)
print((U[:,i]@Z[batch,act,:,h]).shape)

torch.Size([200])
torch.Size([])


In [173]:
print(term.shape)
print(new.shape)
print(tau_space.shape)

torch.Size([1, 200])
torch.Size([200])
torch.Size([1, 2, 200, 498])


In [174]:

# Linearly recover tau-space from eigenspace of F:
tau_space_paralell=torch.zeros((batch_size, num_actions, K, num_sensitivities-2))
# do for several batches and actions-> parallelize TODO
for batch in range(batch_size):
    for act in range(num_actions):
        for h in range(0,num_sensitivities-2):
            term=torch.zeros((1,K))
            for i in range(0,len(lam)):
                fi=lam[i]**2/(alpha_reg**2+lam[i]**2)
                new=fi*(((U[:,i] @ Z[batch,act,:,h]) * V[i,:] )/lam[i])
                term=term+new
            
            tau_space_paralell[batch,act,:,h]=term

In [142]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Compute fi values
fi = lam**2 / (alpha_reg**2 + lam**2)
fi = fi.reshape(1,1,-1,1)
print(fi.shape)


torch.Size([1, 1, 100, 1])


In [175]:
fi = lam**2 / (alpha_reg**2 + lam**2)
# Expand tensors for broadcasting
fi = fi.reshape(1, 1, -1, 1)  # Shape: (1, 1, len(lam), 1)
U_expanded = U.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, K, len(lam))
V_expanded = V.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, len(lam), K)

tau_space_paralell = torch.zeros((batch_size, num_actions, K, num_sensitivities - 2), device=device)
for h in range(num_sensitivities - 2):
    Z_expanded = Z[:, :, :, h].unsqueeze(2)

    tmp = (Z_expanded @ U_expanded).permute(0,1,3,2)
    V_lam = V_expanded[:,:,:len(lam),:] 
    term = (fi * (tmp * V_lam) / lam.reshape(1, 1, -1, 1)).sum(dim=2)
    tau_space_paralell[:, :, :, h] = term

In [189]:
print(tau_space_paralell.shape)
print(tau_space.shape)

torch.Size([1, 2, 200, 498])
torch.Size([1, 2, 200, 498])


In [222]:
torch.arange(3,device='cpu')

tensor([0, 1, 2])

In [224]:
torch.exp(torch.true_divide(1, torch.linspace(start, end, num_gamma_to_tau))).to(device)

tensor([0.0100, 0.4406, 0.6377, 0.7334, 0.7894, 0.8260, 0.8518, 0.8709, 0.8857,
        0.8975, 0.9070, 0.9149, 0.9216, 0.9273, 0.9323, 0.9366, 0.9404, 0.9437,
        0.9467, 0.9494, 0.9519, 0.9541, 0.9561, 0.9580, 0.9597, 0.9612, 0.9627,
        0.9640, 0.9653, 0.9664, 0.9675, 0.9686, 0.9695, 0.9704, 0.9713, 0.9721,
        0.9728, 0.9736, 0.9742, 0.9749, 0.9755, 0.9761, 0.9767, 0.9772, 0.9777,
        0.9782, 0.9787, 0.9791, 0.9795, 0.9799, 0.9803, 0.9807, 0.9811, 0.9814,
        0.9818, 0.9821, 0.9824, 0.9827, 0.9830, 0.9833, 0.9836, 0.9838, 0.9841,
        0.9844, 0.9846, 0.9848, 0.9851, 0.9853, 0.9855, 0.9857, 0.9859, 0.9861,
        0.9863, 0.9865, 0.9867, 0.9868, 0.9870, 0.9872, 0.9873, 0.9875, 0.9876,
        0.9878, 0.9879, 0.9881, 0.9882, 0.9884, 0.9885, 0.9886, 0.9888, 0.9889,
        0.9890, 0.9891, 0.9892, 0.9894, 0.9895, 0.9896, 0.9897, 0.9898, 0.9899,
        0.9900])

In [183]:
torch.sum(~torch.allclose(tau_space_paralell, tau_space))

TypeError: sum(): argument 'input' (position 1) must be Tensor, not int

In [188]:
torch.allclose(tau_space_paralell, tau_space, atol=1e-6)

True

In [182]:
tau_space_paralell[0,0,2,:]

tensor([ 1.5887e-01,  1.5192e+00, -2.3931e-01, -4.9895e-01, -1.2672e+00,
         1.2690e+00,  5.5180e-01, -7.0137e-01, -4.0276e-01,  1.1063e+00,
        -1.7431e+00,  2.2606e-01,  1.8534e-02,  1.1908e+00,  3.4091e-01,
        -1.1821e+00, -7.8780e-01,  3.5548e-01,  1.0009e+00,  1.0632e-01,
        -3.2831e-01,  6.6244e-01, -1.4574e-01, -1.0425e+00, -5.1515e-01,
         2.0787e+00,  6.6962e-02, -1.4101e+00,  3.5819e-01, -2.0816e-01,
        -1.3312e+00, -4.5513e-01,  4.4870e-01,  4.9707e-01,  2.5160e-01,
         1.5117e+00, -2.7640e+00,  1.7142e+00,  3.4535e-01, -9.5751e-01,
         2.1636e+00, -1.3937e+00,  3.3830e-01, -1.8516e-01,  2.8886e-01,
         2.2719e-01, -6.0110e-01, -3.7749e-02,  7.9825e-01, -2.2316e-01,
         5.4351e-01, -2.3285e+00,  1.2120e+00,  1.3499e+00, -1.7468e+00,
        -1.8926e-01, -4.4488e-01,  1.6370e+00, -1.1455e+00,  5.6134e-02,
         6.1739e-01,  2.5897e-01, -1.1239e-02, -2.6212e-01,  3.1307e-02,
        -3.6046e-01,  1.9612e+00, -9.9431e-01, -1.8

In [181]:
tau_space[0,0,2,:]

tensor([ 1.5887e-01,  1.5192e+00, -2.3931e-01, -4.9895e-01, -1.2672e+00,
         1.2690e+00,  5.5180e-01, -7.0137e-01, -4.0276e-01,  1.1063e+00,
        -1.7431e+00,  2.2606e-01,  1.8534e-02,  1.1908e+00,  3.4091e-01,
        -1.1821e+00, -7.8780e-01,  3.5548e-01,  1.0009e+00,  1.0632e-01,
        -3.2831e-01,  6.6244e-01, -1.4574e-01, -1.0425e+00, -5.1515e-01,
         2.0787e+00,  6.6962e-02, -1.4101e+00,  3.5819e-01, -2.0816e-01,
        -1.3312e+00, -4.5513e-01,  4.4870e-01,  4.9707e-01,  2.5160e-01,
         1.5117e+00, -2.7640e+00,  1.7142e+00,  3.4535e-01, -9.5751e-01,
         2.1636e+00, -1.3937e+00,  3.3830e-01, -1.8516e-01,  2.8886e-01,
         2.2719e-01, -6.0110e-01, -3.7749e-02,  7.9825e-01, -2.2316e-01,
         5.4351e-01, -2.3285e+00,  1.2120e+00,  1.3499e+00, -1.7468e+00,
        -1.8926e-01, -4.4488e-01,  1.6370e+00, -1.1455e+00,  5.6134e-02,
         6.1739e-01,  2.5897e-01, -1.1238e-02, -2.6212e-01,  3.1307e-02,
        -3.6046e-01,  1.9612e+00, -9.9431e-01, -1.8

In [156]:
tau_space_paralell

tensor([[[[-0.0523,  0.0668, -1.1534,  ...,  0.0197,  0.9567, -1.1148],
          [-0.5544, -0.2033,  2.9779,  ..., -1.5160, -2.0545,  4.1092],
          [-0.3861, -0.7471,  1.3966,  ...,  1.7326, -1.7076,  0.4324],
          ...,
          [ 0.1682,  0.1178,  0.3977,  ...,  0.3496, -0.3953,  0.1881],
          [ 0.1698,  0.1175,  0.4022,  ...,  0.3480, -0.3946,  0.1885],
          [ 0.1715,  0.1171,  0.4068,  ...,  0.3464, -0.3939,  0.1888]],

         [[-0.7677,  1.7796, -1.5688,  ..., -0.9761,  1.9610, -2.5358],
          [ 3.4110, -4.5800,  1.1505,  ...,  0.4845, -1.0007,  2.9399],
          [ 1.1078, -1.1725,  0.4343,  ...,  1.4548, -1.1677,  1.0524],
          ...,
          [ 0.7071, -0.3788, -0.1714,  ..., -0.4819,  0.1764,  0.2281],
          [ 0.7112, -0.3798, -0.1743,  ..., -0.4875,  0.1796,  0.2307],
          [ 0.7152, -0.3806, -0.1771,  ..., -0.4931,  0.1829,  0.2332]]]])

In [121]:
(Z_expanded @ U_expanded).shape

torch.Size([1, 2, 1, 100])

In [148]:
tmp = (Z_expanded @ U_expanded).permute(0,1,3,2)
print(tmp.shape)

# take first len(lam) of V
V_lam = V_expanded[:,:,:len(lam),:] 
term = (fi * (tmp * V_lam) / lam.reshape(1, 1, -1, 1)).sum(dim=2)
tau_space_paralell[:, :, :, h] = term

print(term.shape)
print(fi.shape)


torch.Size([1, 2, 100, 1])
torch.Size([1, 2, 200])
torch.Size([1, 1, 100, 1])


In [150]:
tau_space_paralell.shape

torch.Size([1, 2, 200, 498])

In [146]:
term.shape

torch.Size([1, 2, 200])

In [118]:
print(U.shape)
print(V.shape)
print(Z.shape)
print('##############################')

print(fi.shape)
print(U_expanded.shape)
print(Z_expanded.shape)
print(V_expanded.shape)

torch.Size([100, 100])
torch.Size([200, 200])
torch.Size([1, 2, 100, 498])
##############################
torch.Size([100])
torch.Size([1, 1, 100, 100])
torch.Size([1, 2, 1, 100])
torch.Size([1, 1, 200, 200])


In [116]:
U_expanded * Z_expanded @ V_expanded

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 100] but got: [2, 200].

In [117]:
print(K)
print(len(lam))

200
100


In [71]:
print(new.shape)
print(len(lam))

(200,)
100


In [194]:








        
#smooth tau-space (it might not be necessary, use for a smoother visualization):
# for h in range(0,num_h-2):
#     tau_space[:,h]=savgol_filter(tau_space[:,h], 11, 1)

#Normalization (it is not really necessary for this very short temporal horizon T=4):
tau_space[tau_space<0]=0 #make all probabilities positive
# do for several batches and actions -> parallelize TODO
for batch in range(batch_size):
    for a in range(num_actions):
        for i in range(0,K):
            if torch.nansum(tau_space[batch,a,i,:])>0.0:
                tau_space[batch,a,i,:]=tau_space[batch,a,i,:]/torch.nansum(tau_space[batch,a,i,:])
        

In [213]:
torch.allclose(tau_space_paralell, tau_space, atol=1e-6)

False

In [214]:
print(tau_space.shape)

torch.Size([1, 2, 200, 498])


In [219]:
tau_space_paralell_tmp = tau_space_paralell.clone()


tau_space_paralell_tmp[tau_space_paralell_tmp < 0] = 0

sum_tau = torch.nansum(tau_space_paralell_tmp, dim=-1, keepdim=True)
mask = sum_tau > 0
sum_tau[sum_tau == 0] = 1

tau_space_paralell_tmp = tau_space_paralell_tmp / sum_tau

mask_expanded = mask.expand_as(tau_space_paralell_tmp)
tau_space_paralell_tmp[~mask_expanded] = 0

In [220]:
torch.allclose(tau_space_paralell_tmp, tau_space, atol=1e-6)

True

In [216]:
mask.shape

torch.Size([1, 2, 200, 1])

In [217]:
tau_space_paralell_tmp.shape

torch.Size([1, 2, 200, 498])

In [209]:
tau_space_paralell_tmp = tau_space_paralell.clone()

tau_space_paralell_tmp[tau_space_paralell_tmp<0] = 0
sum_tau = torch.nansum(tau_space_paralell_tmp, dim=-1, keepdim=True)
mask = sum_tau > 0
mask_expanded = mask.expand_as(tau_space_paralell_tmp)

tau_space_paralell_tmp[mask_expanded] = tau_space_paralell_tmp[mask_expanded] / sum_tau[mask].expand_as(tau_space_paralell_tmp[mask_expanded])


RuntimeError: The expanded size of the tensor (199200) must match the existing size (400) at non-singleton dimension 0.  Target sizes: [199200].  Tensor sizes: [400]

In [212]:
print(tau_space_paralell_tmp[mask_expanded].shape)
print(sum_tau[mask].expand_as(tau_space_paralell_tmp[mask_expanded]))

torch.Size([199200])


RuntimeError: The expanded size of the tensor (199200) must match the existing size (400) at non-singleton dimension 0.  Target sizes: [199200].  Tensor sizes: [400]

In [207]:
print(sum_tau[mask].shape)
print(tau_space_paralell_tmp[mask].shape)

torch.Size([400])


IndexError: The shape of the mask [1, 2, 200, 1] at index 3 does not match the shape of the indexed tensor [1, 2, 200, 498] at index 3

In [203]:
tau_space_paralell[positv_mask.unsqueeze(-1)].shape

IndexError: The shape of the mask [1, 2, 200, 1] at index 3 does not match the shape of the indexed tensor [1, 2, 200, 498] at index 3

In [200]:
tau_space_paralell.shape

torch.Size([1, 2, 200, 498])

In [201]:
torch.nansum(tau_space_paralell, dim=3).shape

torch.Size([1, 2, 200])

In [202]:
positv_mask.shape

torch.Size([1, 2, 200])

In [204]:
positv_mask.unsqueeze(-1).shape

torch.Size([1, 2, 200, 1])