In [1]:
from tqdm.notebook import tqdm
import math
import gym
import torch
import torch.optim as optim 
from torch.utils.tensorboard import SummaryWriter
from collections import deque

from active_rl.networks.dqn_atari import ENS_DQN
from active_rl.utils.memory import LabelledReplayMemory
from active_rl.utils.optimization import AMN_optimization_ensemble
from active_rl.environments.atari_wrappers import make_atari, wrap_deepmind
from active_rl.utils.atari_utils import fp, evaluate, ActionSelector
from active_rl.utils.acquisition_functions import ens_BALD

In [2]:
# Build two views of the same Atari game: `env_raw` (the bare emulator wrapper,
# handed to `evaluate` later on) and `env` (the DeepMind-style wrapped
# environment that the data-collection loop steps through).
env_name = 'Seaquest'
env_raw = make_atari(f'{env_name}NoFrameskip-v4')
env = wrap_deepmind(
  env_raw,
  episode_life=True,
  clip_rewards=True,
  frame_stack=False,
)

# Shape of one preprocessed frame as returned by `fp`.
# presumably (channels, height, width) — TODO confirm against `fp`
c, h, w = fp(env.reset()).shape
n_actions = env.action_space.n

In [3]:
# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 64          # mini-batch size for labelling and for each optimisation call
LR = 6.25e-5             # learning rate handed to Adam
GAMMA = 0.99

# --- Epsilon-greedy exploration schedule (consumed by ActionSelector) --------
EPS_START = 1.0
EPS_END = 0.05
EPS_DECAY = 400_000

# --- Run length and replay-memory sizing -------------------------------------
NUM_STEPS = 20_000_000   # total environment steps in the main loop
LABELLED_MEMORY_CAPACITY = 10_000
UNLABELLED_MEMORY_CAPACITY = 10_000   # also the period (in steps) between label/train rounds
BATCH_LABEL_PERCENTAGE = 0.1          # `percentage` argument of memory.label_sample

# Optimisation calls per training round: ten times the labelled capacity,
# measured in batches and rounded down.
TRAINING_ITERATIONS = (10 * LABELLED_MEMORY_CAPACITY) // BATCH_SIZE

# Run identifier; used as the TensorBoard log sub-directory (runs/<NAME>).
# NOTE(review): the tag reads "ENS_DECAY" but the value inserted is EPS_DECAY —
# looks like a typo for "EPS_DECAY"; left as-is so existing run folders keep matching.
NAME = f"AMN_ens_Bald_{env_name}_ENS_DECAY_{EPS_DECAY}"

In [4]:
# Pick the compute device. The experiment was written for the second GPU
# ('cuda:1'); fall back to the first GPU, then to the CPU, so the notebook
# still runs on hosts that do not have two CUDA devices instead of crashing.
n_gpus = torch.cuda.device_count()  # 0 when CUDA is unavailable
if n_gpus > 1:
  device = 'cuda:1'
elif n_gpus == 1:
  device = 'cuda:0'
else:
  device = 'cpu'

# Student ensemble network: freshly constructed, then re-initialised with its
# own `init_weights` hook.
AMN_net = ENS_DQN(n_actions).to(device)

# Pre-trained expert used as the labelling/teacher network.
# NOTE(security): torch.load unpickles the file and can execute arbitrary code;
# only load checkpoints that come from a trusted source.
# NOTE(review): the checkpoint path hard-codes "Seaquest" independently of
# `env_name` — keep the two in sync when switching games.
expert_net = torch.load("models/expert_Seaquest_step17000000", map_location=device)
AMN_net.apply(AMN_net.init_weights)
expert_net.eval()  # the expert is only queried, never trained here

# Only the student's parameters are optimised.
optimizer = optim.Adam(AMN_net.parameters(), lr=LR, eps=1.0e-4)

In [5]:
# Replay memory with an unlabelled and a labelled pool. Each stored entry is a
# stack of 5 frames of size (h, w); `ens_BALD` together with `AMN_net` is passed
# in as the acquisition function / model pair.
# presumably these drive which samples get labelled — verify in LabelledReplayMemory
memory = LabelledReplayMemory(
  UNLABELLED_MEMORY_CAPACITY,
  LABELLED_MEMORY_CAPACITY,
  [5, h, w],
  n_actions,
  ens_BALD,
  AMN_net,
  device=device,
)

# Action selection for data collection, configured with the epsilon schedule
# constants and the student network.
action_selector = ActionSelector(
  EPS_START,
  EPS_END,
  AMN_net,
  EPS_DECAY,
  n_actions,
  device,
)

In [6]:
# Bookkeeping for the run: both counters start at zero. `num_labels` is
# incremented by the training loop each time samples are labelled.
steps_done = num_labels = 0

# TensorBoard event writer; one sub-directory per run name.
writer = SummaryWriter(log_dir=f'runs/{NAME}')

In [None]:
# Main data-collection / training loop.
#
# `q` is a rolling window of the 5 most recent preprocessed frames. The newest
# 4 frames form the state fed to the action selector, and the full 5-frame
# stack is what gets pushed into replay memory after each step.
q = deque(maxlen=5)
done = True  # forces an environment reset on the very first iteration
progressive = tqdm(range(NUM_STEPS), total=NUM_STEPS, ncols=400, leave=False, unit='b')

try:
  for step in progressive:
    if done:
      env.reset()
      # Action 1 is issued once after every reset. This was carried over from a
      # Breakout setup; it is kept unchanged here.
      # NOTE(review): confirm that sending action 1 is intended for this game.
      env.step(1)
      # 10 no-op steps; since the window holds 5 frames, this leaves it filled
      # with the 5 most recent frames of the new episode.
      for _ in range(10):
        n_frame, _, _, _ = env.step(0)
        n_frame = fp(n_frame)
        q.append(n_frame)

    # Select and perform an action using the 4 most recent frames as state.
    state = torch.cat(list(q))[1:].unsqueeze(0)
    action, eps = action_selector.select_action(state)
    n_frame, reward, done, info = env.step(action)
    n_frame = fp(n_frame)

    # Store the 5-frame stack (4 state frames + the resulting next frame).
    q.append(n_frame)
    memory.push(torch.cat(list(q)).unsqueeze(0), action, reward, done)

    # Every UNLABELLED_MEMORY_CAPACITY steps (skipping step 0): label a fraction
    # of the collected samples, train the student, then evaluate and log.
    if step % UNLABELLED_MEMORY_CAPACITY == 0 and step > 0:
      num_labels += memory.label_sample(percentage=BATCH_LABEL_PERCENTAGE, batch_size=BATCH_SIZE)

      # Mean loss over this training round.
      loss = 0
      for _ in range(TRAINING_ITERATIONS):
        loss += AMN_optimization_ensemble(AMN_net, expert_net, optimizer, memory, batch_size=BATCH_SIZE, device=device)
      loss /= TRAINING_ITERATIONS
      writer.add_scalar('Performance/loss', loss, step)

      # Student performance, logged against both environment steps and the
      # number of labels consumed so far.
      evaluated_reward_AMN = evaluate(step, AMN_net, device, env_raw, n_actions, eps=0.05, num_episode=20)
      writer.add_scalar('Performance/reward_vs_step', evaluated_reward_AMN, step)
      writer.add_scalar('Performance/reward_vs_label', evaluated_reward_AMN, num_labels)

      # Expert performance under the same evaluation settings, for reference.
      evaluated_reward_expert = evaluate(step, expert_net, device, env_raw, n_actions, eps=0.05, num_episode=20)
      writer.add_scalar('Performance/reward_expert_vs_step', evaluated_reward_expert, step)
finally:
  # Always flush and close the event file and the progress bar, including when
  # the cell is interrupted, so already-logged scalars are not lost.
  writer.close()
  progressive.close()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, layout=Layout(flex='2'), max=20000000.0), HTML(value='…