In [1]:
%load_ext autoreload
%autoreload 2

import joblib
import random
import numpy as np
import pandas as pd
from datetime import datetime, timedelta

import os
import multiprocessing as mp
import gym

import tensorflow as tf
from tensorflow.keras.models import clone_model
from tensorflow.keras.losses import Huber, CategoricalCrossentropy, KLDivergence

import sys
sys.path.append('../script')
from utils import ( 
    preprocess_frame_v4, choose_action, get_lin_anneal_eps, sample_ran_action, 
    play_episode, EpisodeLogger, frame_max_pooling, FRAME_CROP_SETTINGS, 
    run_saliency_map, animate_episode, animate_episode_sal
)
from atari_model import (
    atari_model, atari_model_dueling, atari_model_dueling, atari_model_distr,
    fit_batch_DQNn_PER, fit_batch_DDQNn_PER, fit_batch_DDQNn_PER_DS, train_on_batch
)
from replay_memory import PrioritizedReplayMemory

tf.__version__

'2.7.0'

In [2]:
# Set configuration: all tunable hyperparameters and model-variant switches used by the cells below
total_train_len = 50_000      # Total no. of episodes to train (upper bound of the episode loop)
max_episode_len = None       # Max no. of frames agent is allowed to see per episode [CURRENTLY UNUSED]
state_len = 4                # No. of stacked frames that comprise a state
train_interval = 4           # A gradient descent step is performed every x frames (after replay init)
tgt_update_interval = 10_000  # Interval in terms of no. of frames after which we update target model weights
eps_init = 1                 # Initial eps in eps-greedy exploration
eps_final = 0.1              # Final eps in eps-greedy exploration
eps_final_frame = 1_000_000    # No. of frames over which eps is linearly annealed to final eps
replay_init_sz = 50_000        # Replay mem. initialization size: eps is held at 1 for this many frames, training starts after
replay_mem_sz = 1_000_000      # Max no. of frames cached in replay memory

batch_sz = 32                # No. of training cases (sample from replay mem.) for each SGD update
disc_rate = 0.99             # Q-learning discount factor (gamma)
n_step = 3                   # Determines multi-step learning (n=1 is simply single step learning)
# lr = 0.0000625               # Learning rate of CNN
lr = 0.00025                 # Learning rate handed to the model constructors

per_alpha = 0.5              # Exponent of priority probabilities (passed to PrioritizedReplayMemory)
# per_beta_rng = [0.4, 1]      # TODO: amend code to handle this
per_beta = 0              # Exponent of importance sampling weights (passed to PrioritizedReplayMemory)
init_tds = False           # Whether to compute td-errors for initial replay memories (only used by commented-out code)
crop_frame = True          # Whether frames are cropped with FRAME_CROP_SETTINGS[game] during preprocessing

# Model variants (Rainbow)
large_net = False
double_learn = True
dueling_net = False
noisy_net = False
distr_net = True

# With noisy nets the eps-greedy schedule is switched off (eps annealed from 0 to 0)
if noisy_net:
    eps_init = eps_final = 0
    
# Distributional settings; these names are only defined when distr_net is enabled
if distr_net:
    N = 51  # No. of atoms for our discretized distr.
    V_min, V_max = -10, 10  # Min and max of distribution support
    Z = np.linspace(V_min, V_max, N)  # Value distribution (i.e. the atoms)
    dZ = (V_max - V_min) / (N - 1)  # Spacing between two adjacent atoms
    Z_repN = np.repeat([Z], N, axis=0)  # Utility matrix (Z stacked N times, shape (N, N)) to avoid recomputing later on
    tgt_zeroing = False  # Toggles if we set y_tgt to 0 or y_pred for actions that were not taken
    # loss = CategoricalCrossentropy()
    loss = KLDivergence()  # Keras loss object handed to atari_model_distr
else:
    Z = None  # Still passed to choose_action, so it must exist in the non-distributional case

In [3]:
# Initialize Atari environment
env = gym.make('SpaceInvadersDeterministic-v4')

# Set seeds: one shared seed for Python's random module, NumPy, TensorFlow,
# the environment and the environment's action-space sampler
seed = 1234
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
env.seed(seed)
env.action_space.seed(seed)  # Last expression of the cell, so its return value is displayed

[1234]

In [4]:
# Initialize logging and storage
# All artefacts of this run (episode log, recordings, model checkpoints) live under model_dir
game = 'space_invaders'
# model_name = 'DQN_cp_4f_50ri_3n_pr_cf_tf7_cc_dd_pp3_ds'
model_name = 'test1_5'
model_dir = f'../{game}/model/{model_name}'

# makedirs also creates missing parent directories (e.g. on a fresh checkout). It still
# raises FileExistsError when the run directory already exists, so an earlier run with the
# same model_name is never silently written into.
os.makedirs(model_dir)
os.makedirs(model_dir + '/record')  # Episode recordings (.mp4)
os.makedirs(model_dir + '/model')   # Model checkpoints (.keras)

ep_log = EpisodeLogger(model_dir + '/episode_log')

In [5]:
# Build the online/target networks and the replay memory, and reset the training counters

# The model selection below builds exactly one network type (dueling takes precedence), while
# the training cell switches to the distributional update whenever distr_net is set. Enabling
# both flags would therefore pair a dueling model with the distributional training step, so
# reject that combination up front.
if dueling_net and distr_net:
    raise ValueError(
        'dueling_net and distr_net cannot both be enabled: the dueling model would be '
        'trained with the distributional update'
    )

action_space = range(6)  # Use default action space https://www.gymlibrary.dev/environments/atari/space_invaders/
M = len(action_space)  # No. of actions
kernel_init = 'he_normal'

# Set frame crop configuration
# A single step is taken to obtain a sample frame, from which the preprocessed frame shape
# (and thereby the network input shape) is derived
env.reset()
frame = env.step(1)[0]
crop_lims = FRAME_CROP_SETTINGS[game] if crop_frame else None
frame_shape = preprocess_frame_v4(frame, crop_lims).shape
state_shape = (*frame_shape, state_len)  # State = state_len frames stacked along the last axis

# Initialize online network (also used to select actions during training)
if dueling_net:
    model = atari_model_dueling(M, lr, state_shape, kernel_init, noisy_net, large_net)
elif distr_net:
    # NOTE: large_net is not forwarded to atari_model_distr
    model = atari_model_distr(N, M, loss, lr, state_shape, kernel_init, noisy_net)
else:
    model = atari_model(M, lr, state_shape, kernel_init, noisy_net, large_net)

model_tgt = clone_model(model)  # Target network
model_tgt.set_weights(model.get_weights())  # clone_model does not copy weights, so sync them explicitly

# Initialize replay memory
replay_mem = PrioritizedReplayMemory(replay_mem_sz, state_len, n_step, per_alpha, per_beta)
frame_num = 0            # Total no. of transitions stored so far
max_episode_reward = 0   # Highest episode reward for which a new-max recording was made
episode_start = 0        # First episode number of the training loop
init_done = False        # Becomes True once frame_num reaches replay_init_sz

In [6]:
# Show the layer-by-layer architecture of the online network built in the previous cell
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_frames (InputLayer)      [(None, 95, 65, 4)]  0           []                               
                                                                                                  
 lambda (Lambda)                (None, 95, 65, 4)    0           ['input_frames[0][0]']           
                                                                                                  
 conv1 (Conv2D)                 (None, 22, 15, 16)   4112        ['lambda[0][0]']                 
                                                                                                  
 conv2 (Conv2D)                 (None, 10, 6, 32)    8224        ['conv1[0][0]']                  
                                                                                              

In [7]:
# Optional: continue from an existing stored run (uncomment the lines below and adjust the values)
# dir = 'DQN_cp_4f_50ri_3n_pr_cf_tf7_cc_dl_dd_pp2'

# replay_mem = joblib.load(f'../{game}/model/{dir}/replay_mem.pkl')

# model = tf.keras.models.load_model(f'../{game}/model/{dir}/model/model_latest.keras')
# model_tgt = clone_model(model)
# model_tgt.set_weights(model.get_weights())

# frame_num = 11393496
# max_episode_reward = 0
# episode_start = 9215
# init_done = True

In [8]:
# Training loop.
# Per episode: play until game over with an eps-greedy policy, store every transition in the
# prioritized replay memory and, once frame_num has reached replay_init_sz, train the online
# model every train_interval frames and sync the target model every tgt_update_interval frames.
# After each episode the statistics are logged, and recordings/backups are written periodically.
mp_pool = mp.Pool(1)  # Single worker process that runs the animate_episode_sal jobs
t_lastmax = datetime.now()  # Earliest time at which the next new-max recording may be made
i_animation = 0  # Used to toggle between saliency types

try:
    for episode_num in range(episode_start, total_train_len):
        # Start a new game (episode)
        env.reset()
        new_life = True
        game_over = False

        # Keep track of episode figures
        episode_reward = 0
        episode_train_cnt = 0
        episode_frames = []
        episode_states = []
        episode_actions = []
        episode_a_israns = []
        episode_Qs = []
        episode_pZs = []
        episode_losses = []
        episode_tderrs = []

        # SI specific: skip first 20 frames when new game (40 in total with new-life line)
        for _ in range(20):
            env.step(0)

        # Play episode until game over
        while not game_over:
            if new_life:
                # SI specific: skip first 20 frames when new life
                for _ in range(20):
                    env.step(0)

                for _ in range(random.randint(1, 15)):
                    # Random initialization, to reduce overfitting
                    frame, _, game_over, info = env.step(sample_ran_action(action_space))

                lives = info['ale.lives']
                frame_pp = preprocess_frame_v4(frame, crop_lims)  # Maxpooling not needed for first frame
                state = np.stack(state_len * [frame_pp], axis=2)  # Initial state: first frame repeated state_len times

            # Select action (eps is forced to 1 until the replay memory is initialized)
            eps = get_lin_anneal_eps(frame_num - replay_init_sz, eps_init, eps_final, eps_final_frame)
            eps = eps if init_done else 1
            action, a_isran, Q, pZ = choose_action(model, state, action_space, eps, distr_net, Z, ret_stats=True)

            # Store state and actions variables
            episode_frames.append(frame)
            episode_states.append(state)
            episode_actions.append(action)
            episode_a_israns.append(a_isran)
            episode_Qs.append(Q)
            episode_pZs.append(pZ)

            # Take action and observe transition
            prev_frame = frame  # Keep previous frame for max pooling step
            frame, reward, game_over, info = env.step(action)

            # Process env response: drop the oldest frame of the state and append the new one
            frame_pp = frame_max_pooling([prev_frame, frame])
            frame_pp = preprocess_frame_v4(frame_pp, crop_lims)
            state = np.append(state[:, :, 1:], frame_pp[:, :, None], axis=2)
            # reward = clip_reward(reward)
            new_life = info['ale.lives'] < lives
            lives = info['ale.lives']

            # Add new transition to replay memory
            # (third entry is set when the game is over or a life was lost)
            transition = (action, reward, game_over or new_life, new_life, frame_pp)  # TODO: inspect game + life
            replay_mem.store_memory(transition)

            # Increase transition counters
            frame_num += 1
            episode_reward += reward
            init_done = frame_num >= replay_init_sz  # Is replay initializing done

            # After init period start replay transitions and train model
            if init_done:

                # # Initialize td-errors of init replay mems
                # if init_tds and frame_num == replay_init_sz:
                #     replay_init_idxs = range(replay_init_sz)
                #     replay_init_ps = replay_mem.get_priorities(replay_init_idxs)
                #     replay_init_idxs = np.flatnonzero(replay_init_ps)
                #     replay_init_mems = replay_mem.get_memories(replay_init_idxs, n_step)[:-1]
                #     replay_init_td_errs = td_error(model, model_tgt, action_space, disc_rate, *replay_init_mems)
                #     replay_mem.update_priorities(replay_init_idxs, replay_init_td_errs)

                # Train model every train_interval
                if frame_num % train_interval == 0:
                    mini_batch = replay_mem.get_sample(batch_sz)
                    batch_idxs, mini_batch = mini_batch[-1], mini_batch[:-1]  # Last entry holds the sampled indices
                    w_imps = replay_mem.get_imps_weights(batch_idxs)

                    # The training loss is bound to batch_loss (not loss) so that the Keras loss
                    # object `loss` from the configuration cell is not overwritten by a number
                    if distr_net:
                        batch_td_errs, batch_loss = fit_batch_DDQNn_PER_DS(
                            model, model_tgt, action_space, disc_rate, *mini_batch, w_imps,
                            Z, Z_repN, dZ, (V_min, V_max), tgt_zeroing, noisy_net, double_learn
                        )
                        # print(batch_loss)
                    else:
                        batch_td_errs, batch_loss = fit_batch_DDQNn_PER(
                            model, model_tgt, action_space, disc_rate, *mini_batch,
                            w_imps, noisy_net, double_learn
                        )
                    replay_mem.update_priorities(batch_idxs, batch_td_errs)
                    # is there balance between updating ps and having new ps set to max? will updated ps stand a chance?
                    episode_train_cnt += 1
                    episode_losses.append(batch_loss)
                    episode_tderrs.append(np.mean(batch_td_errs))

                # Update target model
                if frame_num % tgt_update_interval == 0:
                    model_tgt.set_weights(model.get_weights())

        # Log episode statistics
        ep_log.append(
            episode_num, episode_train_cnt, frame_num, episode_reward,
            episode_actions, episode_a_israns, episode_Qs, episode_losses,
            episode_tderrs
        )

        # Decide whether this episode gets a recording; record_path stays None if not
        record_path = None

        # Output episode animation video every 1000 episodes
        if episode_num % 1000 == 0:
            # Save model
            opath = f'../{game}/model/{model_name}/model/model_{episode_num}.keras'
            model.save(opath)

            # Save train and test recording
            record_path = f'../{game}/model/{model_name}/record/record_{episode_num}_train_{episode_reward}.mp4'
            # eval_frame, eval_reward = play_episode(model, env, action_space, state_len)
            # frames_to_mp4(f'../{game}/model/{model_name}/record/record_{episode_num}_eval_{eval_reward}.mp4', eval_frame)

        # Also output episode video if new max score was attained, and time-delta has been met
        elif episode_reward > max_episode_reward and datetime.now() > t_lastmax:
            # Store new max reward results; the next new-max recording is allowed in 30 minutes at the earliest
            record_path = f'../{game}/model/{model_name}/record/record_{episode_num}_train_x_{episode_reward}.mp4'
            t_lastmax = datetime.now() + timedelta(minutes=30)
            max_episode_reward = episode_reward

        if record_path is not None:
            # Hand the animation job to the worker process; alternate the saliency type per recording
            sal_type = 'gcam' if i_animation % 2 else 'sal'
            mp_pool.apply_async(animate_episode_sal, args=(
                model, episode_states, episode_frames, episode_actions,
                episode_Qs, action_space, record_path, dueling_net, distr_net, (Z, episode_pZs), sal_type)
            )
            i_animation += 1

        # Backup model and replay memory every 250 episodes
        # Note: takes a minute or two, so should not be run frequently
        if episode_num % 250 == 0:
            # Store replay memory
            opath = f'../{game}/model/{model_name}/replay_mem.pkl'
            joblib.dump(replay_mem, opath, compress=3)

            # Store current model
            opath = f'../{game}/model/{model_name}/model/model_latest.keras'
            model.save(opath)
finally:
    # Stop accepting new jobs so the worker process exits once its queued jobs are done,
    # also when training is interrupted (e.g. by a KeyboardInterrupt)
    mp_pool.close()

  np.nanmean(np.max(Qs, axis=1)),
2024-01-29 23:47:13.949953: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.


INFO:tensorflow:Assets written to: ram://01d55244-ca61-44c2-a064-e1515e8f8846/assets


  np.nanmean(np.max(Qs, axis=1)),


INFO:tensorflow:Assets written to: ram://a6c13597-d1f0-4a90-8ef6-2f9b7daafcb0/assets


  np.nanmean(np.max(Qs, axis=1)),
  ylim = np.nanmax([p.max() for p in pZ]) * 1.05
  if s != self._text:
  ylim = np.nanmax([p.max() for p in pZ]) * 1.05
  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


INFO:tensorflow:Assets written to: ram://5a564088-88ef-4149-9644-40ce77f46afe/assets


  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


INFO:tensorflow:Assets written to: ram://3f2f2433-df86-489b-9b31-b9dc310a7b78/assets


  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


INFO:tensorflow:Assets written to: ram://17cbe5ea-5e41-430f-bb92-8bc53c47fdc9/assets


  if s != self._text:


INFO:tensorflow:Assets written to: ram://0f94e7a6-17c2-4da7-a65c-43fafd805eed/assets


  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


INFO:tensorflow:Assets written to: ram://1c626784-0a8b-4474-bc79-f0a94d678979/assets


  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


INFO:tensorflow:Assets written to: ram://b3e17e37-5dc2-4110-af4e-833cd058eddd/assets


  if s != self._text:


INFO:tensorflow:Assets written to: ram://d5f49414-5efe-4e3f-b810-e74956b5d79f/assets


  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


INFO:tensorflow:Assets written to: ram://77a0f00d-7e0c-4ead-9aee-854e6bf429f9/assets


  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


INFO:tensorflow:Assets written to: ram://11efec74-dfc8-4c87-844d-4ba083fdcfa5/assets


  if s != self._text:
  return np.array([self.get_memory(idx, n) for idx in idxs]).T


KeyboardInterrupt: 

In [None]:
# 11.5 - 3 compr GB for 4508 episodes