In [1]:
from tqdm.notebook import tqdm
import math
import gym
import torch
import torch.optim as optim 
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import numpy as np

from networks.dqn_atari import DQN
from active_rl.utils.memory import ReplayMemory
from active_rl.utils.optimization import standard_optimization
from active_rl.environments.atari_wrappers import make_atari, wrap_deepmind
from active_rl.utils.atari_utils import fp, ActionSelector, evaluate
from active_rl.utils.acquisition_functions import ens_BALD
from active_rl.statistics.q_n2s import compute_q_n2s_ratio

In [2]:
# Build the Atari environment: raw NoFrameskip-v4 env wrapped with the
# DeepMind-style preprocessing (no frame stacking here; the training loop
# keeps its own frame window).
env_name = 'Breakout'
env_raw = make_atari(f'{env_name}NoFrameskip-v4')
env = wrap_deepmind(env_raw, frame_stack=False, episode_life=True, clip_rewards=True)

# Shape of a single preprocessed frame (channels, height, width) and the size
# of the discrete action space. The original line assigned `c,h,w` twice in a
# chained assignment, which was a harmless but confusing typo.
c, h, w = fp(env.reset()).shape
n_actions = env.action_space.n

In [3]:
# --- Run identity -----------------------------------------------------------
NAME = 'hypothesis_q_val_convergence'  # TensorBoard run directory: runs/<NAME>

# --- Optimisation -----------------------------------------------------------
BATCH_SIZE = 64      # transitions per optimisation step / per N2S evaluation sample
LR = 6.25e-5         # Adam learning rate for every ensemble member
GAMMA = 0.99         # discount factor (not referenced by the other visible cells)
NUM_NETS = 5         # number of (policy, target, optimizer) triples in the ensemble

# --- Exploration (handed to ActionSelector; start and end are equal) --------
EPS_START = 0.01
EPS_END = 0.01
EPS_DECAY = 1

# --- Schedule, all measured in environment steps ----------------------------
NUM_STEPS = 20_000_000     # total iterations of the main loop
INITIAL_STEPS = 1_000      # steps collected before any update/evaluation happens
POLICY_UPDATE = 4          # optimise the policy networks every N steps
TARGET_UPDATE = 4_000      # copy policy weights into the target networks every N steps
EVAL_N2S_RATIO = 5_000     # log the Q-value noise-to-signal statistics every N steps

# --- Replay buffer ----------------------------------------------------------
MEMORY_CAPACITY = 500_000

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensemble of NUM_NETS independent learners. Each member has its own policy
# network, a target network initialised to the same weights, and its own
# Adam optimizer.
policy_nets = []
target_nets = []
optimizers = []
for i in range(NUM_NETS):
  policy_net = DQN(n_actions).to(device)
  target_net = DQN(n_actions).to(device)
  policy_net.apply(policy_net.init_weights)
  # Target starts as an exact copy of the freshly initialised policy net and
  # is kept in eval mode; it is only ever updated by copying weights.
  target_net.load_state_dict(policy_net.state_dict())
  target_net.eval()
  optimizer = optim.Adam(policy_net.parameters(), lr=LR, eps=1.5e-4)
  policy_nets.append(policy_net)
  target_nets.append(target_net)
  optimizers.append(optimizer)

# Expert network handed to the ActionSelector to pick actions.
# map_location=device lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine (without it torch.load tries to restore the tensors onto the
# device they were saved from and fails when CUDA is unavailable).
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# NOTE(review): this loads a whole pickled module rather than a state_dict;
# recent PyTorch releases may need an explicit weights_only=False for that --
# confirm against the installed version.
pretrained_net = torch.load("models/dqn_expert_breakout_model", map_location=device).to(device)

In [5]:
# Replay buffer. The shape argument is [5, h, w]: the training loop pushes the
# concatenation of its 5-frame window as a single entry.
memory = ReplayMemory(
    MEMORY_CAPACITY,
    [5, h, w],
    n_actions,
    device,
)

# Action picker driven by the pretrained expert network and the EPS_* settings.
action_selector = ActionSelector(
    EPS_START,
    EPS_END,
    pretrained_net,
    EPS_DECAY,
    n_actions,
    device,
)

In [6]:
# TensorBoard event writer; everything logged by the training loop lands
# under runs/<NAME>.
writer = SummaryWriter(f'runs/{NAME}')

# Global step counter (not referenced by the other visible cells).
steps_done = 0

In [7]:
# Mutable state shared with the training loop.

# Rolling window holding the 5 most recent preprocessed frames.
q = deque(maxlen=5)

# Starts as True so that the very first loop iteration resets the environment.
done = True

# Latest epsilon reported by the action selector, and the length of the
# episode currently in progress.
eps, episode_len = 0, 0

In [None]:
# Main loop: the expert-driven action selector collects transitions while the
# ensemble of policy networks is trained off-policy from the replay memory.
# Every EVAL_N2S_RATIO steps the Q-value noise-to-signal statistics of the
# ensemble are logged to TensorBoard.
progressive = tqdm(range(NUM_STEPS), total=NUM_STEPS, ncols=400, leave=False, unit='b')
for step in progressive:
  if done:
    # Start a new episode and refill the frame window.
    env.reset()
    # NOTE(review): sum_reward is reset here but never accumulated in this
    # cell, and episode_len is counted but never logged -- confirm whether
    # per-episode statistics were meant to be recorded.
    sum_reward = 0
    episode_len = 0
    img, _, _, _ = env.step(1) # BREAKOUT specific !!!
    # 10 steps of action 0 overwrite the whole 5-frame window with frames
    # from the new episode.
    for i in range(10): # no-op
      n_frame, _, _, _ = env.step(0)
      n_frame = fp(n_frame)
      q.append(n_frame)

  # Select and perform an action.
  # The state is the window without its oldest frame (the last 4 frames).
  state = torch.cat(list(q))[1:].unsqueeze(0)
  action, eps = action_selector.select_action(state)
  n_frame, reward, done, info = env.step(action)
  n_frame = fp(n_frame)

  # 5 frame as memory
  q.append(n_frame)
  memory.push(torch.cat(list(q)).unsqueeze(0), action, reward, done) # here the n_frame means next frame from the previous time step
  episode_len += 1

  # Train every ensemble member on its own optimisation step and log the
  # average loss across the ensemble.
  if step % POLICY_UPDATE == 0 and step > INITIAL_STEPS:
    total_loss = 0
    for i in range(NUM_NETS):
      loss = standard_optimization(policy_nets[i], target_nets[i], optimizers[i], memory,
                                   batch_size=BATCH_SIZE, device=device)
      total_loss += loss
    avg_loss = total_loss / NUM_NETS
    if writer is not None:
      writer.add_scalar('Performance/loss', avg_loss, step)

  # Periodically sync each target network with its policy network.
  if step % TARGET_UPDATE == 0 and step > INITIAL_STEPS:
    for i in range(NUM_NETS):
      target_nets[i].load_state_dict(policy_nets[i].state_dict())

  # Periodically evaluate the ensemble's Q-value noise-to-signal statistics
  # on a batch sampled from the replay memory.
  if step % EVAL_N2S_RATIO == 0 and step > INITIAL_STEPS:
    states, actions, _, _, _ = memory.sample(BATCH_SIZE)
    action_n2s, avg_not_action_n2s, na_to_a_ratio = compute_q_n2s_ratio(policy_nets, states, actions)
    # Same guard as the loss logging above, so a disabled writer cannot crash
    # the run. The statistics are still computed so the sampling sequence is
    # unchanged.
    if writer is not None:
      writer.add_scalar('Hypothesis/action_n2s', action_n2s, step)
      writer.add_scalar('Hypothesis/not_action_n2s', avg_not_action_n2s, step)
      writer.add_scalar('Hypothesis/not_action_to_action_ratio', na_to_a_ratio, step)

# Push any buffered events to disk so the tail of the run is not lost if the
# kernel is shut down right after training.
if writer is not None:
  writer.flush()

HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=20000000.0), HTML(value='')), layout=Layo…