In [2]:
from tqdm.notebook import tqdm
import math
import gym
import torch
import torch.optim as optim 
from torch.utils.tensorboard import SummaryWriter
from collections import deque

from active_rl.networks.dqn_atari import ENS_DQN
from active_rl.utils.memory import RankedReplayMemory, LabeledReplayMemory
from active_rl.utils.optimization import AMN_optimization_ensemble
from active_rl.environments.atari_wrappers import make_atari, wrap_deepmind
from active_rl.utils.atari_utils import fp, ActionSelector, evaluate
from active_rl.utils.acquisition_functions import ens_BALD

In [6]:
# Build the Breakout environment. `env_raw` is the bare make_atari env (it is
# also what evaluate() receives in the evaluation loop); `env` is the same env
# wrapped by wrap_deepmind with reward clipping on and frame stacking /
# episodic-life handling off.
env_name = 'Breakout'
env_raw = make_atari('{}NoFrameskip-v4'.format(env_name))
env = wrap_deepmind(env_raw, frame_stack=False, episode_life=False, clip_rewards=True)
# Shape of one observation after fp(); the names suggest
# (channels, height, width) — TODO confirm against fp().
c, h, w = fp(env.reset()).shape
# Size of the discrete action space, used to size the network output.
n_actions = env.action_space.n

In [7]:
# --- Run configuration -------------------------------------------------------
# Optimisation settings.
BATCH_SIZE = 64
LR = 6.25e-5
GAMMA = 0.99
EPS = 0.05

# Step budget and buffer sizes.
NUM_STEPS = 10_000_000
NOT_LABELLED_CAPACITY = 10_000
LABELLED_CAPACITY = 200_000
# The initial step count is tied to the size of the unlabelled buffer.
INITIAL_STEPS = NOT_LABELLED_CAPACITY

# TRAINING_ITER is the product of the three factors below, truncated to an int
# (4000 with the current values).
TRAINING_PER_LABEL = 80.0
PERCENTAGE = 0.005
TRAINING_ITER = int(TRAINING_PER_LABEL * NOT_LABELLED_CAPACITY * PERCENTAGE)

# Run name; it becomes the sub-directory under runs/ for the TensorBoard writer.
NAME = 'Evaluate expert'

In [8]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# AMN_net = MC_DQN(n_actions).to(device)
AMN_net = ENS_DQN(n_actions).to(device)
# Load the expert as a fully pickled module. map_location=device remaps the
# stored tensors, so a checkpoint saved on a GPU still loads when `device`
# is the CPU.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# from a trusted source. On PyTorch >= 2.6 the default weights_only=True
# rejects whole-module pickles; pass weights_only=False there if this fails.
expert_net = torch.load("models/dqn_expert_breakout_model", map_location=device).to(device)
# Re-initialise every sub-module of AMN_net through the network's own
# init_weights hook.
AMN_net.apply(AMN_net.init_weights)
# Put the expert into eval mode.
expert_net.eval()
optimizer = optim.Adam(AMN_net.parameters(), lr=LR, eps=1.5e-4)

In [11]:
# Step counter, initialised to zero (not updated by the cells shown here).
steps_done = 0
# TensorBoard writer; events for this run are written under runs/<NAME>.
writer = SummaryWriter(log_dir=f'runs/{NAME}')

In [12]:
# Bookkeeping variables initialised ahead of the main loop.
# `q` is a bounded deque: once it holds 5 items, appending drops the oldest.
q = deque(maxlen=5)
done = True
# All three counters start at zero.
eps = episode_len = num_labels = 0

In [None]:
# Repeatedly evaluate the fixed expert network and log each result to
# TensorBoard under the 'reward' tag, using the loop index as the global step.
progressive = tqdm(range(NUM_STEPS), total=NUM_STEPS, ncols=400, leave=False, unit='b')
try:
  for step in progressive:
    # Use the config constant EPS (0.05) rather than repeating the literal;
    # this is not the lowercase `eps` bookkeeping variable, which is 0.
    evaluated_reward = evaluate(step, expert_net, device, env_raw, n_actions, eps=EPS, num_episode=15)
    writer.add_scalar('reward', evaluated_reward, step)
finally:
  # The loop is long enough that it will usually be interrupted by hand;
  # flush so the scalars logged so far reach disk either way.
  writer.flush()

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=10000000.0), HTML(value='')), layout=Layo…