In [1]:
import torch
import gym

In [3]:
env = gym.make("CartPole-v1")

In [10]:
state_dim = env.reset().shape[0]
h1_in = 60
h1_out = 60
n_actions = 2

In [11]:
model = torch.nn.Sequential(
            torch.nn.Linear(state_dim, h1_in),
            torch.nn.ReLU(),
            torch.nn.Linear(h1_in, h1_out),
            torch.nn.ReLU(),
            torch.nn.Linear(h1_out, n_actions)
)

In [106]:
mem_size = 100
batch_size = 20
loss_func = torch.nn.MSELoss()

In [129]:
state_batch = torch.empty((batch_size, state_dim))
next_state_batch = torch.empty((batch_size, state_dim))

action_batch = torch.empty((batch_size, 1))
reward_batch = torch.empty((batch_size, 1))
done_batch = torch.empty((batch_size, 1))

for e in range(1000):
    state = torch.from_numpy(env.reset())
    q_vals = model(state)
    action = torch.argmax(q_vals).item()
    
    next_state, reward, done, _ = env.step(action)
    
    if e < batch_size:
        state_batch[e] = state
        next_state_batch[e] = torch.from_numpy(next_state)

        action_batch[e] = action
        reward_batch[e] = reward
        done_batch[e] = done
        
        continue
    
        
    elif e >= mem_size:
        idx = e - (mem_size * (e // mem_size))
        
        state_batch[idx] = state
        next_state_batch[idx] = torch.from_numpy(next_state)

        action_batch[idx] = action
        reward_batch[idx] = reward
        done_batch[idx] = done

    else:
        state_batch = torch.cat((state_batch, torch.unsqueeze(state, dim=0)), dim=0)
        
        next_state = torch.from_numpy(next_state)
        next_state_batch = torch.cat((next_state_batch, torch.unsqueeze(next_state, dim=0)), dim=0)

        action_batch = torch.cat((action_batch, torch.unsqueeze(torch.tensor([action]), dim=0)), dim = 0)
        reward_batch = torch.cat((reward_batch, torch.unsqueeze(torch.Tensor([reward]), dim=0)), dim = 0)
        done_batch = torch.cat((done_batch, torch.unsqueeze(torch.Tensor([done]), dim=0)), dim = 0)
        
    random_indices = torch.randperm(action_batch.size(0))[:batch_size]
    
    q_vals = model(state_batch[random_indices])
    
    with torch.no_grad():
        next_q_vals = model(next_state_batch[random_indices])
    
    y_hat = reward_batch[random_indices].squeeze() + 0.99*((1 - done_batch[random_indices]).squeeze() * torch.max(next_q_vals, dim=1)[0])
    y = q_vals.gather(dim=1, index=action_batch[random_indices].long()).squeeze()
    
    loss = loss_func(y, y_hat)
    

In [130]:
y_hat.shape

torch.Size([20])

In [131]:
y.shape

torch.Size([20])

In [121]:
torch.max(next_q_vals, dim=1)[0]

tensor([-0.1324, -0.1366, -0.1319, -0.1319, -0.1325, -0.1381, -0.1349, -0.1383,
        -0.1351, -0.1315, -0.1316, -0.1354, -0.1343, -0.1341, -0.1326, -0.1342,
        -0.1343, -0.1351, -0.1350, -0.1347])

In [123]:
done_batch[random_indices].squeeze()

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [128]:
reward_batch[random_indices].squeeze()

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])

In [97]:
random_indices = torch.randperm(action_batch.size(0))[:batch_size]

In [99]:
state_batch[random_indices]

tensor([[-0.0416,  0.0145, -0.0046,  0.0005],
        [-0.0493, -0.0073,  0.0488, -0.0389],
        [-0.0007, -0.0101,  0.0003, -0.0284],
        [-0.0149,  0.0474, -0.0225, -0.0311],
        [ 0.0190, -0.0205, -0.0273,  0.0281],
        [-0.0216, -0.0040, -0.0256,  0.0100],
        [-0.0104,  0.0104, -0.0132, -0.0446],
        [-0.0409,  0.0477,  0.0138, -0.0273],
        [ 0.0102, -0.0423,  0.0332,  0.0462],
        [-0.0231,  0.0452, -0.0144,  0.0348],
        [-0.0311,  0.0152,  0.0057, -0.0378],
        [-0.0427,  0.0253,  0.0421,  0.0162],
        [-0.0194,  0.0420, -0.0220, -0.0297],
        [-0.0480,  0.0176,  0.0118,  0.0230],
        [-0.0425,  0.0382,  0.0073,  0.0297],
        [ 0.0384, -0.0376,  0.0464, -0.0426],
        [-0.0360, -0.0114, -0.0359,  0.0060],
        [ 0.0396,  0.0203, -0.0165, -0.0270],
        [ 0.0281, -0.0245, -0.0180, -0.0235],
        [ 0.0045,  0.0113,  0.0145, -0.0107]])

In [None]:
# 1. make empty tensor with shape = (batch_size, x)
# 2. fill tensor by indexing until full. Is full when is filled with batch_size number of rows.
# 3. stack tensor untill shape is (mem_size, x)
# 4. add new rows to tensor by indexing. Every time index == mem_size -> index = 0 