In [None]:
# Standard libraries
import os
import sys
import glob
import copy
import time

# Data processing
import numpy as np
import pandas as pd
from PIL import Image

# Deep learning
import torch

# Reinforcement learning
import gymnasium as gym
import ale_py
from stable_baselines3 import PPO, DQN
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecEnvWrapper, VecMonitor, VecVideoRecorder
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed

# Visualization
import matplotlib.pyplot as plt

# Logging
from wandb.integration.sb3 import WandbCallback
import wandb

# Project-specific imports
sys.path.insert(0, '../src')
from replay_buffer import HDF5ReplayBufferRAM
from models import Autoencoder, CTR_Attention_dil
from rl_wrappers import RawRewardTracker
from rl_callbacks import RawRewardLoggingCallback, WandbModelCheckpointCallback
from rl_networks import AutoencoderFeatureExtractor
from rl_policies import MyDQNPolicy
from rl_buffers import MyReplayBuffer, PrioritizedReplayBuffer
from rl_models import MyDQNModel

# Runtime configuration: use the CUDA device when one is available, otherwise
# fall back to the CPU, and set float32 as torch's default floating-point dtype.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
torch.set_default_dtype(torch.float32)


In [None]:
# Run-level switches
use_wandb = True
game_name = "Freeway"

# Per-game lambda values for the 16% attention target
lam_set_dict = dict(
    Enduro=np.float64(0.15636971314004802),
    Freeway=np.float64(0.1899595360666595),
    MsPacman=np.float64(0.17380204890481837),
    Seaquest=np.float64(0.22773814757646266),
    SpaceInvaders=np.float64(0.16151139168983497),
    Riverraid=np.float64(0.15986161207261196),
)

# Game mappings: the insertion order of AA_to_AH is also the list of games
AA_to_AH = dict(
    Enduro='enduro',
    Freeway='freeway',
    MsPacman='ms_pacman',
    Riverraid='riverraid',
    Seaquest='seaquest',
    SpaceInvaders='space_invaders',
)
games = list(AA_to_AH)

# Training loop over seeds and inverse_psi settings
#
# For every (seed, inverse_psi) pair this cell builds a config, opens a WandB
# run (when use_wandb is set), loads the newest autoencoder / CTR checkpoints,
# builds the Atari vec-env and the DQN model, and trains it. The env is closed
# and the WandB run is finished even when a run fails, so that one failing run
# does not leak emulators/video recorders or leave a dangling WandB run.


def _latest_checkpoint(pattern):
    """Return the most recently modified file matching ``pattern``.

    Args:
        pattern: glob pattern of candidate checkpoint files.

    Returns:
        Path of the matching file with the newest modification time.

    Raises:
        FileNotFoundError: if no file matches ``pattern`` (instead of the
            opaque ``ValueError`` that ``max()`` raises on an empty list).
    """
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No checkpoint found matching {pattern!r}")
    return max(matches, key=os.path.getmtime)


def _build_lam_scheduler(cfg):
    """Build the ``step -> lambda`` schedule selected by ``cfg["lambda_strategy"]``.

    The returned callable closes over ``cfg`` itself (not over the loop
    variable), so each run keeps the schedule of its own config dict. Values
    are still read from ``cfg`` at call time, as in the original lambdas.

    Args:
        cfg: run configuration dict with the ``lambda_*`` / ``lin_sched_*`` keys.

    Returns:
        Callable mapping a step (scalar or array) to the lambda value(s).

    Raises:
        ValueError: if ``cfg["lambda_strategy"]`` is not a known strategy.
    """
    strategy = cfg["lambda_strategy"]

    if strategy == "fixed":
        def fixed(step):
            return cfg["lambda_fixed_val"] * np.ones_like(step)
        return fixed

    if strategy == "linear_scheduler":
        def ramp_up(step):
            lam_min, lam_max = cfg["lin_sched_lam_min"], cfg["lin_sched_lam_max"]
            return np.clip(
                lam_min + (lam_max - lam_min) * step / cfg["lin_sched_max_timesteps"],
                lam_min,
                lam_max,
            )
        return ramp_up

    if strategy == "max_to_min":
        def ramp_down(step):
            lam_min, lam_max = cfg["lin_sched_lam_min"], cfg["lin_sched_lam_max"]
            return np.clip(
                lam_max - (lam_max - lam_min) * step / cfg["lin_sched_max_timesteps"],
                lam_min,
                lam_max,
            )
        return ramp_down

    raise ValueError(
        f"Unknown lambda_strategy {strategy!r}; expected 'fixed', "
        "'linear_scheduler' or 'max_to_min'"
    )


for i_seed in [0, 1, 2]:
    for inv_psi in [False, True]:

        # Training configuration
        config = {
            "game_name": game_name,
            "use_CTR": True,
            "env_seed": [42, 1337, 30890][i_seed],
            "total_timesteps": 2_500_000,  # Frameskip=4, so baseline is 2.5M for 10M Frames
            "learning_starts": 25_000,  # 10k-25k is baseline
            "train_freq": 4,  # 4 is baseline
            "gamma": 0.99,  # Standard is 0.99, Human-like is <=0.9
            "use_PER": True,  # Prioritized Experience Replay
            "inverse_psi": inv_psi,
            "lambda_strategy": "fixed",
            "lambda_fixed_val": lam_set_dict[game_name],
            "beta_res_attention": 0.0,  # How much of non-attended features to retain
            "lin_sched_lam_min": lam_set_dict[game_name],
            "lin_sched_lam_max": 1.0,
            "lin_sched_max_timesteps": 2_000_000,
            "model_save_freq": 100_000,
        }

        # Lambda scheduler setup; fails fast on an unknown strategy, before a
        # WandB run is opened.
        lam_scheduler = _build_lam_scheduler(config)

        # Initialize WandB
        if use_wandb:
            run = wandb.init(
                entity="uthenrik-the-university-of-tokyo",
                project="Covert Attention Agents",
                config=config,
                sync_tensorboard=True,
                monitor_gym=True,
            )

        env = None
        run_failed = True
        try:
            # Load autoencoder (newest matching checkpoint)
            ae_pattern = f'trained_models/autoencoder/AE_4f_TCDS_BlurPool64_{game_name}_*.pkl'
            checkpoint_path = _latest_checkpoint(ae_pattern)
            checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
            autoencoder = Autoencoder(device).to(device)
            autoencoder.load_state_dict(checkpoint['state'])
            # Clone taken BEFORE the freeze below, so the clone's parameters
            # keep their original requires_grad flags.
            autoencoder_copy = copy.deepcopy(autoencoder)

            for param in autoencoder.parameters():
                param.requires_grad = False

            # Load CTR attention models (newest matching checkpoint)
            CTR_version = 'V15'
            ctr_pattern = f'trained_models/CTR_att/{CTR_version}/nCTR_AA_AH_{CTR_version}_{game_name}_*.pkl'
            checkpoint_path = _latest_checkpoint(ctr_pattern)
            checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)

            # NOTE(review): CTR_AA is loaded but only CTR_AH is handed to the
            # feature extractor below; confirm CTR_AA is still needed here.
            CTR_AA = CTR_Attention_dil(repr_shape=(1, 32, 21, 21), autoencoder=copy.deepcopy(autoencoder), device=device).to(device)
            CTR_AA.load_state_dict(checkpoint['CTR_AA_state'], strict=False)
            CTR_AA.eval()

            CTR_AH = CTR_Attention_dil(repr_shape=(1, 32, 21, 21), autoencoder=copy.deepcopy(autoencoder), device=device).to(device)
            CTR_AH.load_state_dict(checkpoint['CTR_AH_state'], strict=False)
            CTR_AH.eval()

            # Environment setup
            SEED = config['env_seed']

            def make_env():
                """Return a thunk that builds one seeded, Atari-wrapped env."""
                def _init():
                    env = gym.make(f"ALE/{game_name}-v5", frameskip=1, render_mode="rgb_array")
                    env = RawRewardTracker(env)
                    # frameskip=1 above: frame skipping is done by AtariWrapper.
                    env = AtariWrapper(env, frame_skip=4, terminal_on_life_loss=False)
                    env.reset(seed=SEED)
                    env.action_space.seed(SEED)
                    return env
                return _init

            set_random_seed(SEED)

            env = DummyVecEnv([make_env()])
            env.seed(SEED)
            env = VecMonitor(env)

            if use_wandb:
                env = VecVideoRecorder(
                    env,
                    f"videos/{run.id}",
                    record_video_trigger=lambda x: x % 50000 == 0,
                    video_length=300,
                )

            env = VecFrameStack(env, n_stack=4)

            # Model creation
            def make_dqn_model(env, use_CTR):
                """Build the DQN model for ``env`` from the current run's config.

                ``use_CTR`` is accepted for call-site compatibility but is not
                read here; the CTR strategy is applied after construction via
                ``policy.set_CTR_strategy``.
                """
                policy_kwargs = dict(
                    features_extractor_class=AutoencoderFeatureExtractor,
                    features_extractor_kwargs=dict(
                        autoencoder=autoencoder,
                        lam_scheduler=lam_scheduler,
                        device=device,
                        CTR=CTR_AH,
                        config=config
                    ),
                )

                if config["use_PER"]:
                    replay_buffer_class = PrioritizedReplayBuffer
                else:
                    replay_buffer_class = MyReplayBuffer

                dqn_kwargs = dict(
                    env=env,
                    policy=MyDQNPolicy,
                    policy_kwargs=policy_kwargs,
                    replay_buffer_class=replay_buffer_class,
                    replay_buffer_kwargs={
                        "handle_timeout_termination": False,
                        "config": config,
                    },
                    verbose=1,
                    device=device,
                    learning_rate=2.5e-4,
                    buffer_size=500_000,
                    batch_size=32,
                    gamma=config["gamma"],
                    exploration_initial_eps=1.0,
                    exploration_final_eps=0.1,
                    exploration_fraction=0.1,
                    target_update_interval=10_000,
                    train_freq=config["train_freq"],
                    gradient_steps=1,
                    learning_starts=config["learning_starts"],
                    optimize_memory_usage=True,
                )

                if use_wandb:
                    dqn_kwargs["tensorboard_log"] = f"runs/{run.id}"

                return MyDQNModel(**dqn_kwargs)

            # Create and configure model
            model_DDQN = make_dqn_model(env, use_CTR=config['use_CTR'])
            model_DDQN.config = config
            model_DDQN.policy.step_counter = 0
            model_DDQN.policy.set_CTR_strategy(config['use_CTR'])

            # Fix autoencoder references: give the online and target networks
            # their own independent copies, both in eval mode.
            # NOTE(review): these copies come from the pre-freeze clone, so
            # their parameters are not frozen by the loop above; confirm this
            # is intended.
            model_DDQN.q_net.features_extractor.autoencoder = copy.deepcopy(autoencoder_copy)
            model_DDQN.q_net.features_extractor.autoencoder.eval()
            model_DDQN.q_net_target.features_extractor.autoencoder = copy.deepcopy(autoencoder_copy)
            model_DDQN.q_net_target.features_extractor.autoencoder.eval()

            # Training
            if use_wandb:
                wandb.watch(model_DDQN.policy.q_net, log="all", log_freq=1000)

            if use_wandb:
                callbacks = CallbackList([
                    WandbCallback(),
                    RawRewardLoggingCallback(),
                    WandbModelCheckpointCallback(model_save_freq=config["model_save_freq"]),
                ])
                model_DDQN.learn(total_timesteps=config["total_timesteps"], reset_num_timesteps=False, callback=callbacks)
            else:
                model_DDQN.learn(total_timesteps=config["total_timesteps"], reset_num_timesteps=False)

            run_failed = False
        finally:
            # Close the env first so the video recorder is shut down before the
            # WandB run ends; always finish the run, marking it failed when the
            # body above did not complete.
            try:
                if env is not None:
                    env.close()
            finally:
                wandb.finish(exit_code=1 if run_failed else None)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33muthenrik[0m ([33muthenrik-the-university-of-tokyo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


A.L.E: Arcade Learning Environment (version 0.11.1+2750686)
[Powered by Stella]


Using cuda device
Wrapping the env in a VecTransposeImage.
Logging to runs/3v9mi3o4/DQN_0
Saving video to /workspace/TCDS/videos/3v9mi3o4/rl-video-step-0-to-step-300.mp4
MoviePy - Building video /workspace/TCDS/videos/3v9mi3o4/rl-video-step-0-to-step-300.mp4.
MoviePy - Writing video /workspace/TCDS/videos/3v9mi3o4/rl-video-step-0-to-step-300.mp4



                                                                                                                                                                                                     

MoviePy - Done !
MoviePy - video ready /workspace/TCDS/videos/3v9mi3o4/rl-video-step-0-to-step-300.mp4
----------------------------------
| rollout/            |          |
|    current_PER_beta | 0.4      |
|    current_lambda   | 0.19     |
|    ep_len_mean      | 2.04e+03 |
|    ep_rew_mean      | 0        |
|    exploration_rate | 0.971    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 277      |
|    time_elapsed     | 29       |
|    total_timesteps  | 8174     |
----------------------------------
----------------------------------
| rollout/            |          |
|    current_PER_beta | 0.4      |
|    current_lambda   | 0.19     |
|    ep_len_mean      | 2.04e+03 |
|    ep_rew_mean      | 0        |
|    exploration_rate | 0.941    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 285      |
|    time_elapsed     | 57       |
|    total_timesteps  | 16355    |
----------------------

                                                                                                                                                                                                     

MoviePy - Done !
MoviePy - video ready /workspace/TCDS/videos/3v9mi3o4/rl-video-step-50000-to-step-50300.mp4
--------------------------------------
| rollout/            |              |
|    current_PER_beta | 0.408        |
|    current_lambda   | 0.19         |
|    ep_len_mean      | 2.04e+03     |
|    ep_rew_mean      | 0            |
|    exploration_rate | 0.794        |
| time/               |              |
|    episodes         | 28           |
|    fps              | 85           |
|    time_elapsed     | 673          |
|    total_timesteps  | 57256        |
| train/              |              |
|    learning_rate    | 0.00025      |
|    loss             | 5.02e-06     |
|    n_updates        | 8063         |
|    q_values/mean    | -0.047702737 |
|    q_values/std     | 0.0008837336 |
--------------------------------------
--------------------------------------
| rollout/            |              |
|    current_PER_beta | 0.41         |
|    current_lambda   | 0.19     

                                                                                                                                                                                                     

MoviePy - Done !
MoviePy - video ready /workspace/TCDS/videos/3v9mi3o4/rl-video-step-100000-to-step-100300.mp4
