In [1]:
# Ask the inline matplotlib backend to emit figures as SVG.
%config InlineBackend.figure_formats = ['svg']

In [2]:
import matplotlib
# Use FreeSans as the sans-serif font family for every figure in this notebook.
matplotlib.rcParams.update({'font.sans-serif': ['FreeSans']})

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
# Sizes passed to the EpisodicBufferO constructor when the test set is loaded.
# Going by the names: state-vector width, number of discrete actions, and
# maximum episode length (confirm against the buffer class).
state_dim, num_actions, horizon = 64, 25, 20

In [5]:
from a5cpu_BCQ.model import BCQ
from a5cpu_BCQ.data import remap_rewards
from a5cpu_BCQ.data import EpisodicBuffer as EpisodicBufferOO
from a7cpu_BCQf.model import BCQf
from a7cpu_BCQf.data import EpisodicBuffer as EpisodicBufferFF
from a7cpu_BCQf.model import all_subactions_vec

In [6]:
from evaluate import (
    EpisodicBufferO, offline_evaluation_O,
    EpisodicBufferF, offline_evaluation_F,
)

In [7]:
from types import SimpleNamespace
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from joblib import Parallel, delayed

In [8]:
# Metadata tables for the selected checkpoint of each model:
# df_1_best describes the BCQ run, df_2_best the BCQf run.
df_1_best, df_2_best = (
    pd.read_csv(meta_csv)
    for meta_csv in ('best_BCQ_meta.csv', 'best_BCQf_meta.csv')
)

In [9]:
# Show the metadata row of the selected BCQ checkpoint.
df_1_best

Unnamed: 0,index,iteration,val_qvalues,val_wis,val_ess,epoch,created_at,version,threshold,seed
0,97,9800.0,11.247806,92.497391,202.913132,5,2021-11-15 19:47:41.309691,17,0.5,1


In [10]:
# Show the metadata row of the selected BCQf checkpoint.
df_2_best

Unnamed: 0,index,iteration,val_qvalues,val_wis,val_ess,epoch,created_at,version,threshold,seed
0,90,9100.0,10.518903,93.28846,216.866272,5,2021-11-15 19:51:43.311177,27,0.5,4


In [11]:
# Sanity check: the run version that gets interpolated into the BCQ checkpoint path.
# .item() requires the column to hold exactly one value.
df_1_best['version'].item()

17

In [12]:
# Restore the selected checkpoint of each model and call .eval() on it.
# Each checkpoint file is addressed as step=<iteration - 1>.ckpt inside the
# version directory recorded in the corresponding metadata table.
ckpt_path_1 = f'../a5cpu_BCQ/logs/mimic_dBCQ/version_{df_1_best["version"].item()}/checkpoints/step={int(df_1_best["iteration"].item()-1)}.ckpt'
model_1 = BCQ.load_from_checkpoint(checkpoint_path=ckpt_path_1, map_location=None)
model_1.eval()

ckpt_path_2 = f'../a7cpu_BCQf/logs/mimic_dBCQf/version_{df_2_best["version"].item()}/checkpoints/step={int(df_2_best["iteration"].item()-1)}.ckpt'
model_2 = BCQf.load_from_checkpoint(checkpoint_path=ckpt_path_2, map_location=None)
model_2.eval()
# Attach the sub-action table imported from a7cpu_BCQf.model to the restored
# model (presumably read during evaluation; confirm where it is consumed).
model_2.all_subactions_vec = all_subactions_vec

In [13]:
# Reward settings handed to remap_rewards. Going by the key names: nothing per
# intermediate step, nothing on death, 100 on discharge (confirm in remap_rewards).
reward_scheme = {'R_immed': 0.0, 'R_death': 0.0, 'R_disch': 100.0}

# Load the test episodes and replace their rewards with the remapped ones.
test_episodes_O = EpisodicBufferO(state_dim, num_actions, horizon)
test_episodes_O.load('../data/episodes+encoded_state+knn_pibs/test_data.pt')
test_episodes_O.reward = remap_rewards(test_episodes_O.reward, SimpleNamespace(**reward_scheme))

# One unshuffled batch as large as the buffer, i.e. the whole test set at once.
tmp_test_episodes_loader_O = DataLoader(
    test_episodes_O,
    batch_size=len(test_episodes_O),
    shuffle=False,
)
test_batch_O = next(iter(tmp_test_episodes_loader_O))

Episodic Buffer loaded with 2894 episides.


In [14]:
# Point estimates on the full test batch. Both models are scored with the same
# settings: BCQ through its own method, BCQf through the helper from `evaluate`.
wis_settings = dict(weighted=True, eps=0.01)
test_wis_1, test_ess_1 = model_1.offline_evaluation(test_batch_O, **wis_settings)
test_wis_2, test_ess_2 = offline_evaluation_O(model_2, test_batch_O, **wis_settings)

In [15]:
# Point-estimate summary. The "Observed Test" row reuses the WIS/ESS labels for
# the logged data: its value is the mean of the axis-1 reward sums, and its
# "ESS" is the length of the reward tensor's first axis.
print(f'Observed Test \t WIS: {test_episodes_O.reward.sum(axis=1).mean():.2f} \t ESS: {test_episodes_O.reward.shape[0]:.2f}')
print(f'Baseline BCQ \t WIS: {test_wis_1:.2f} \t ESS: {test_ess_1:.2f}')
print(f'Factored BCQ \t WIS: {test_wis_2:.2f} \t ESS: {test_ess_2:.2f}')

Observed Test 	 WIS: 90.29 	 ESS: 2894.00
Baseline BCQ 	 WIS: 90.44 	 ESS: 178.32
Factored BCQ 	 WIS: 91.62 	 ESS: 178.32


In [16]:
def bootstrap_test(i, buffer):
    """Return the mean axis-1 sum of batch item 2 for one bootstrap resample.

    Parameters
    ----------
    i : int
        Seed for the resampling generator, which makes replicate ``i``
        reproducible.
    buffer
        Episode container supporting ``len()`` and indexing with an integer
        array. Item 2 of the indexed batch is presumably the reward tensor
        (confirm against the buffer class).

    Returns
    -------
    The mean, over the resampled entries, of item 2 summed along axis 1.
    """
    episode_count = len(buffer)
    rng = np.random.default_rng(seed=i)
    # Draw as many entries as the buffer holds, with replacement.
    resampled = rng.choice(episode_count, size=episode_count, replace=True)
    rewards = buffer[resampled][2]
    return rewards.sum(axis=1).mean()

In [17]:
def bootstrap_WIS_1(i, buffer):
    """Evaluate the notebook-level ``model_1`` on one bootstrap resample.

    Parameters
    ----------
    i : int
        Seed for the resampling generator, which makes replicate ``i``
        reproducible.
    buffer
        Episode container supporting ``len()`` and indexing with an integer
        array.

    Returns
    -------
    Whatever ``model_1.offline_evaluation`` returns for the resampled batch;
    the notebook unpacks it as a (WIS, ESS) pair.
    """
    episode_count = len(buffer)
    rng = np.random.default_rng(seed=i)
    # Draw as many entries as the buffer holds, with replacement.
    resampled = rng.choice(episode_count, size=episode_count, replace=True)
    resampled_batch = buffer[resampled]
    return model_1.offline_evaluation(resampled_batch, weighted=True, eps=0.01)

In [18]:
def bootstrap_WIS_2(i, buffer):
    """Evaluate the notebook-level ``model_2`` on one bootstrap resample.

    Parameters
    ----------
    i : int
        Seed for the resampling generator, which makes replicate ``i``
        reproducible.
    buffer
        Episode container supporting ``len()`` and indexing with an integer
        array.

    Returns
    -------
    Whatever ``offline_evaluation_O`` returns for ``model_2`` on the resampled
    batch; the notebook unpacks it as a (WIS, ESS) pair.
    """
    episode_count = len(buffer)
    rng = np.random.default_rng(seed=i)
    # Draw as many entries as the buffer holds, with replacement.
    resampled = rng.choice(episode_count, size=episode_count, replace=True)
    resampled_batch = buffer[resampled]
    return offline_evaluation_O(model_2, resampled_batch, weighted=True, eps=0.01)

## 100 bootstraps

In [19]:
# 100 bootstrap replicates (seeds 0..99) of the observed return, spread over
# 6 joblib workers. tqdm wraps the task generator, so the bar advances as
# tasks are handed to joblib.
workers = Parallel(n_jobs=6)
boot_test_reward = workers(
    delayed(bootstrap_test)(seed, test_episodes_O) for seed in tqdm(range(100))
)

100%|██████████| 100/100 [00:06<00:00, 14.77it/s]


In [20]:
# 100 bootstrap replicates (seeds 0..99) of the BCQ evaluation on 6 workers.
workers = Parallel(n_jobs=6)
replicates_1 = workers(
    delayed(bootstrap_WIS_1)(seed, test_episodes_O) for seed in tqdm(range(100))
)
# Transpose the list of (WIS, ESS) pairs into one tuple per metric.
boot_test_wis_1, boot_test_ess_1 = zip(*replicates_1)

100%|██████████| 100/100 [06:18<00:00,  3.78s/it]


In [21]:
# 100 bootstrap replicates (seeds 0..99) of the BCQf evaluation on 6 workers.
workers = Parallel(n_jobs=6)
replicates_2 = workers(
    delayed(bootstrap_WIS_2)(seed, test_episodes_O) for seed in tqdm(range(100))
)
# Transpose the list of (WIS, ESS) pairs into one tuple per metric.
boot_test_wis_2, boot_test_ess_2 = zip(*replicates_2)

100%|██████████| 100/100 [05:51<00:00,  3.52s/it]


In [22]:
# Point estimates followed, in parentheses, by the 2.5% and 97.5% quantiles of
# the bootstrap replicates (a 95% percentile interval). The "Observed Test" row
# has no interval for its ESS column, which is just the first-axis length of
# the reward tensor.
print(f'Observed Test \t ' + 
      f'WIS: {test_episodes_O.reward.sum(axis=1).mean():.2f} ({np.quantile(boot_test_reward, 0.025):.2f}, {np.quantile(boot_test_reward, 0.975):.2f}) \t ' +
      f'ESS: {test_episodes_O.reward.shape[0]:.2f}')
print(f'Baseline BCQ \t ' + 
      f'WIS: {test_wis_1:.2f} ({np.quantile(boot_test_wis_1, 0.025):.2f}, {np.quantile(boot_test_wis_1, 0.975):.2f}) \t ' +
      f'ESS: {test_ess_1:.2f} ({np.quantile(boot_test_ess_1, 0.025):.2f}, {np.quantile(boot_test_ess_1, 0.975):.2f})')
print(f'Factored BCQ \t ' + 
      f'WIS: {test_wis_2:.2f} ({np.quantile(boot_test_wis_2, 0.025):.2f}, {np.quantile(boot_test_wis_2, 0.975):.2f}) \t ' +
      f'ESS: {test_ess_2:.2f} ({np.quantile(boot_test_ess_2, 0.025):.2f}, {np.quantile(boot_test_ess_2, 0.975):.2f})')

Observed Test 	 WIS: 90.29 (89.09, 91.05) 	 ESS: 2894.00
Baseline BCQ 	 WIS: 90.44 (85.95, 94.12) 	 ESS: 178.32 (157.02, 199.87)
Factored BCQ 	 WIS: 91.62 (86.79, 95.04) 	 ESS: 178.32 (155.87, 198.88)


In [24]:
# Point estimates followed by ± the standard deviation of the bootstrap
# replicates. np.std is called with its defaults, i.e. the population form
# (ddof=0).
print(f'Observed Test \t ' + 
      f'WIS: {test_episodes_O.reward.sum(axis=1).mean():.2f} ± {np.std(boot_test_reward):.2f} \t ' +
      f'ESS: {test_episodes_O.reward.shape[0]:.2f}')
print(f'Baseline BCQ \t ' + 
      f'WIS: {test_wis_1:.2f} ± {np.std(boot_test_wis_1):.2f} \t ' +
      f'ESS: {test_ess_1:.2f} ± {np.std(boot_test_ess_1):.2f}')
print(f'Factored BCQ \t ' + 
      f'WIS: {test_wis_2:.2f} ± {np.std(boot_test_wis_2):.2f} \t ' +
      f'ESS: {test_ess_2:.2f} ± {np.std(boot_test_ess_2):.2f}')

Observed Test 	 WIS: 90.29 ± 0.51 	 ESS: 2894.00
Baseline BCQ 	 WIS: 90.44 ± 2.44 	 ESS: 178.32 ± 11.42
Factored BCQ 	 WIS: 91.62 ± 2.12 	 ESS: 178.32 ± 11.96


In [25]:
# Fraction of bootstrap replicates in which Factored BCQ's WIS exceeds
# Baseline BCQ's. Replicates are paired: both runs resample with seeds 0..99.
# boot_test_wis_1 / boot_test_wis_2 are tuples (they come from zip), and
# `tuple < tuple` is a single lexicographic comparison decided by the first
# differing pair, so convert to arrays to compare replicate by replicate.
(np.asarray(boot_test_wis_1) < np.asarray(boot_test_wis_2)).mean()

1.0