In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class EpisodicBufferO(Dataset):
    """Dataset of fixed-horizon episodes; one item is one whole episode.

    Indexing returns the tuple
    ``(state, action, reward, not_done, pibs, estm_pibs)`` for a single
    episode, each tensor having the time step as its leading axis.
    """

    def __init__(self, state_dim, num_actions, horizon, buffer_size=0):
        """Allocate zero-filled storage for ``buffer_size`` episodes.

        Args:
            state_dim: size of the last axis of ``state``.
            num_actions: size of the last axis of ``pibs`` and ``estm_pibs``.
            horizon: number of time steps stored per episode.
            buffer_size: number of episodes to allocate. The default of 0
                creates an empty buffer that is meant to be filled by
                :meth:`load`.
        """
        self.max_size = int(buffer_size)
        self.horizon = horizon
        self.state = torch.zeros((self.max_size, horizon, state_dim))
        # Actions are stored as integer indices with a trailing singleton axis.
        self.action = torch.zeros((self.max_size, horizon, 1), dtype=torch.long)
        self.reward = torch.zeros((self.max_size, horizon, 1))
        self.not_done = torch.zeros((self.max_size, horizon, 1))
        self.pibs = torch.zeros((self.max_size, horizon, num_actions))
        self.estm_pibs = torch.zeros((self.max_size, horizon, num_actions))

    def __len__(self):
        """Return the number of episodes currently held."""
        return len(self.state)

    def __getitem__(self, idx):
        """Return the six per-episode tensors stored at index ``idx``."""
        return (
            self.state[idx],
            self.action[idx],
            self.reward[idx],
            self.not_done[idx],
            self.pibs[idx],
            self.estm_pibs[idx],
        )

    def load(self, filename):
        """Replace the buffer contents with the episodes stored in ``filename``.

        The file must hold a mapping with the keys ``'statevecs'``,
        ``'actions'``, ``'rewards'``, ``'notdones'``, ``'pibs'`` and
        ``'estm_pibs'``. States and both policy tensors drop their last time
        step, while actions, rewards and not-done flags drop their first, so
        index ``t`` of a state lines up with index ``t + 1`` of the stored
        actions/rewards/flags.

        Args:
            filename: path or file-like object accepted by ``torch.load``.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted data.
        data = torch.load(filename)
        self.state = data['statevecs'][:, :-1, :]
        self.action = data['actions'][:, 1:].unsqueeze(-1)  # Need to offset by 1 so that we predict actions that have not yet occurred
        self.reward = data['rewards'][:, 1:].unsqueeze(-1)  # Need to offset by 1
        self.not_done = data['notdones'][:, 1:].unsqueeze(-1)
        self.pibs = data['pibs'][:, :-1, :]
        self.estm_pibs = data['estm_pibs'][:, :-1, :]
        # Keep the size metadata in step with the tensors that were just
        # swapped in; otherwise they still describe the pre-load allocation.
        self.max_size = len(self.state)
        self.horizon = self.state.shape[1]
        print(f"Episodic Buffer loaded with {len(self)} episodes.")


In [3]:
# Sizes handed to EpisodicBufferO when the test buffer is constructed.
horizon = 20      # time steps allocated per episode
num_actions = 25  # width of the pibs / estm_pibs action axis
state_dim = 64    # width of each state vector

In [7]:
from types import SimpleNamespace

def remap_rewards(R, args):
    """Replace the raw reward codes 0, -1 and 1 with configured values.

    Entries equal to 0, -1 and 1 become ``args.R_immed``, ``args.R_death``
    and ``args.R_disch`` respectively; every other entry is kept as is. All
    three comparisons are made against the original values, so a remapped
    value is never remapped a second time.

    Args:
        R: rewards as a torch tensor, NumPy array or (nested) sequence.
        args: object exposing ``R_immed``, ``R_death`` and ``R_disch``.

    Returns:
        A new CPU ``torch.Tensor`` with the same shape as ``R``. Its dtype is
        whatever NumPy's promotion of the three values and ``R`` yields.
    """
    # Convert up front instead of leaning on NumPy's implicit tensor
    # conversion, which raises for grad-tracking or non-CPU tensors and
    # silently misbehaves for plain lists (``R == 0`` would be a scalar).
    if isinstance(R, torch.Tensor):
        R = R.detach().cpu().numpy()
    else:
        R = np.asarray(R)
    R = np.select([R == 0, R == -1, R == 1], [args.R_immed, args.R_death, args.R_disch], R)
    return torch.tensor(R)

In [8]:
# Build an empty buffer and fill it from the saved test split.
test_episodes_O = EpisodicBufferO(state_dim, num_actions, horizon)
test_episodes_O.load('../data/episodes+encoded_state+knn_pibs/test_data.pt')

# Remap the raw reward codes: 0 -> 0.0, -1 -> 0.0, 1 -> 100.0.
test_episodes_O.reward = remap_rewards(
    test_episodes_O.reward,
    SimpleNamespace(R_immed=0.0, R_death=0.0, R_disch=100.0),
)

# One unshuffled batch as large as the dataset, i.e. every episode at once.
tmp_test_episodes_loader_O = DataLoader(
    test_episodes_O,
    shuffle=False,
    batch_size=len(test_episodes_O),
)
test_batch_O = next(iter(tmp_test_episodes_loader_O))

Episodic Buffer loaded with 2894 episides.


In [None]:
# get knn highest probability action index, check agreement with the observed actions

In [11]:
# Split the full test batch into its six per-episode tensors.
(states, actions, rewards, not_dones, pibs, estm_pibs) = test_batch_O
# Number of episodes and time steps; this rebinds the earlier `horizon`.
n, horizon, _ = states.shape
# Drop the trailing singleton axis and hand the rewards over to NumPy.
rewards = rewards[:, :, 0].cpu().numpy()
# Per-step weights for a discount factor of 1.0, so rewards stay undiscounted.
discounted_rewards = rewards * np.power(1.0, np.arange(horizon))

In [15]:
# Shape left after taking the argmax over the last (action) axis of estm_pibs.
estm_pibs.argmax(dim=-1).shape

torch.Size([2894, 20])

In [16]:
# Shape of the observed-action tensor, to compare against the argmax shape.
actions.size()

torch.Size([2894, 20, 1])

In [19]:
# rough estimate: share of (episode, timestep) slots where the observed action
# equals the argmax of estm_pibs. No episode-length masking is applied here,
# so every timestep slot of every episode is counted.
# squeeze(-1) removes only the trailing singleton axis; a bare squeeze() would
# also drop the episode or time axis whenever that axis has size 1.
(estm_pibs.argmax(dim=2) == actions.squeeze(-1)).to(float).mean()

tensor(0.7347, dtype=torch.float64)

In [21]:
# top 1 knn action
# Share of within-episode timesteps whose observed action is the argmax of estm_pibs.
cnt_match, cnt_all = 0.0, 0.0
for idx in range(n):
    # all but the final transition has notdone==1, so the sum plus one is the length
    lng = not_dones[idx, :, 0].sum().item() + 1
    a_obs = actions[idx, :lng, 0]
    a_prd = estm_pibs[idx, :lng].argmax(dim=-1)
    # count the agreeing timesteps, then cast so the running total is float64
    cnt_match += (a_prd == a_obs).sum().to(float)
    cnt_all += lng

print(cnt_match / cnt_all)

tensor(0.5716, dtype=torch.float64)


In [33]:
# top 2 knn actions
# Share of within-episode timesteps whose observed action is among the two
# highest-ranked entries of estm_pibs.
cnt_match, cnt_all = 0.0, 0.0
for idx in range(n):
    # all but the final transition has notdone==1, so the sum plus one is the length
    lng = not_dones[idx, :, 0].sum().item() + 1
    a_obs = actions[idx, :lng, 0]
    # action indices ordered from highest to lowest estimated probability
    a_prd = estm_pibs[idx, :lng].argsort(dim=-1, descending=True)
    # a ranking lists each action once, so a row can match at most one column
    cnt_match += (a_prd[:, :2] == a_obs[:, None]).any(dim=1).sum().to(float)
    cnt_all += lng

print(cnt_match / cnt_all)

tensor(0.7445, dtype=torch.float64)


In [32]:
# top 5 knn actions
# Share of within-episode timesteps whose observed action is among the five
# highest-ranked entries of estm_pibs.
cnt_match, cnt_all = 0.0, 0.0
for idx in range(n):
    # all but the final transition has notdone==1, so the sum plus one is the length
    lng = not_dones[idx, :, 0].sum().item() + 1
    a_obs = actions[idx, :lng, 0]
    # action indices ordered from highest to lowest estimated probability
    a_prd = estm_pibs[idx, :lng].argsort(dim=-1, descending=True)
    # a ranking lists each action once, so a row can match at most one column
    cnt_match += (a_prd[:, :5] == a_obs[:, None]).any(dim=1).sum().to(float)
    cnt_all += lng

print(cnt_match / cnt_all)

tensor(0.8752, dtype=torch.float64)


In [40]:
# top 10 knn actions
# Share of within-episode timesteps whose observed action is among the ten
# highest-ranked entries of estm_pibs.
cnt_match, cnt_all = 0.0, 0.0
for idx in range(n):
    # all but the final transition has notdone==1, so the sum plus one is the length
    lng = not_dones[idx, :, 0].sum().item() + 1
    a_obs = actions[idx, :lng, 0]
    # action indices ordered from highest to lowest estimated probability
    a_prd = estm_pibs[idx, :lng].argsort(dim=-1, descending=True)
    # a ranking lists each action once, so a row can match at most one column
    cnt_match += (a_prd[:, :10] == a_obs[:, None]).any(dim=1).sum().to(float)
    cnt_all += lng

print(cnt_match / cnt_all)

tensor(0.9349, dtype=torch.float64)


In [22]:
# top 1 action under pibs
# Share of within-episode timesteps whose observed action is the argmax of pibs.
cnt_match, cnt_all = 0.0, 0.0
for idx in range(n):
    # all but the final transition has notdone==1, so the sum plus one is the length
    lng = not_dones[idx, :, 0].sum().item() + 1
    a_obs = actions[idx, :lng, 0]
    a_prd = pibs[idx, :lng].argmax(dim=-1)
    # count the agreeing timesteps, then cast so the running total is float64
    cnt_match += (a_prd == a_obs).sum().to(float)
    cnt_all += lng

print(cnt_match / cnt_all)

tensor(0.5783, dtype=torch.float64)


In [None]:
# # reference of WIS loop
# for idx in range(n):
#     lng = (not_dones[idx, :, 0].sum() + 1).item()  # all but the final transition has notdone==1

#     # Predict Q-values and Imitation probabilities
#     q, _, i = self.Q(states[idx])
#     imt = F.log_softmax(i.reshape(-1, 2, 5), dim=-1).exp()
#     imt = (imt / imt.max(axis=-1, keepdim=True).values > self.threshold).float()

#     # Factored action remapping
#     q = q @ self.all_subactions_vec.T
#     imt = torch.einsum('bi,bj->bji', (imt[:,0,:], imt[:,1,:])).reshape(-1, 25)

#     # Use large negative number to mask actions from argmax
#     a_id = (imt * q + (1. - imt) * torch.finfo().min).argmax(axis=1).cpu().numpy()
#     pie_soft = np.zeros((horizon, 25))
#     pie_soft += eps * estm_pibs[idx].cpu().numpy() # Soften using training behavior policy
#     pie_soft[range(horizon), a_id] += (1.0 - eps)

#     # Compute importance sampling ratios
#     a_obs = actions[idx, :, 0]
#     ir[idx, :lng] = pie_soft[range(lng), a_obs[:lng].cpu().numpy()] / pibs[idx, range(lng), a_obs[:lng]].cpu().numpy()
#     ir[idx, lng:] = 1  # Mask out the padded timesteps
