# Import

In [None]:
# Standard libraries
import sys
import os
import copy
import glob
import pickle
from datetime import datetime
from collections import namedtuple, deque

# Data processing
import numpy as np
import pandas as pd
import cv2
from PIL import Image
import scipy

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Visualization
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Project-specific imports
sys.path.insert(0, './atarihead')
sys.path.insert(0, '../src')
from replay_buffer import HDF5ReplayBufferRAM
from models import Autoencoder, Gaze_predictor_pool, Motor_predictor_fwd, CTR_Attention_dil, CTR_Attention_SA
from utils import get_lr
from training import Attention_training_class

# External analysis functions
import FUNCTIONS_Analysis as func_a

# Global settings
# Print DataFrame floats with three significant digits.
pd.options.display.float_format = '{:.3g}'.format
torch.set_default_dtype(torch.float32)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Inline-backend detection
_backend_name = matplotlib.get_backend()
is_ipython = 'inline' in _backend_name
if is_ipython:
    # NOTE: this rebinds `display` to the IPython.display *module*, replacing
    # the `display` function imported at the top of the file.
    import IPython.display as display

# Game definitions (names, action counts, Atari-HEAD name mapping)

In [None]:
# One row per supported game:
#   (atari-agents name, Atari-HEAD name, number of agent actions)
# The three lookup structures below are all derived from this table so the
# per-game facts live in exactly one place.
_GAME_TABLE = (
    ('Alien', 'alien', 18),
    ('Asterix', 'asterix', 9),
    ('BankHeist', 'bank_heist', 18),
    ('Berzerk', 'berzerk', 18),
    ('Breakout', 'breakout', 4),
    ('Centipede', 'centipede', 18),
    ('DemonAttack', 'demon_attack', 6),
    ('Enduro', 'enduro', 9),
    ('Freeway', 'freeway', 3),
    ('Frostbite', 'frostbite', 18),
    ('Hero', 'hero', 18),
    ('MontezumaRevenge', 'montezuma_revenge', 18),
    ('MsPacman', 'ms_pacman', 9),
    ('NameThisGame', 'name_this_game', 6),
    ('Phoenix', 'phoenix', 8),
    ('Riverraid', 'riverraid', 18),
    ('RoadRunner', 'road_runner', 18),
    ('Seaquest', 'seaquest', 18),
    ('SpaceInvaders', 'space_invaders', 6),
    ('Venture', 'venture', 18),
)

# Game names in table order.
games = [aa_name for aa_name, _, _ in _GAME_TABLE]

# Game name -> size of the agent's action set.
n_actions = {aa_name: count for aa_name, _, count in _GAME_TABLE}

# Game name -> snake_case name used in the Atari-HEAD buffer file names.
AA_to_AH = {aa_name: ah_name for aa_name, ah_name, _ in _GAME_TABLE}

# Configuration, replay buffers, pretrained models, and trainer setup


In [None]:
# Training configuration
# `game_name` must be one of the keys of `n_actions` / `AA_to_AH`; it selects
# the replay-buffer and checkpoint files used further down.
config = dict(
    game_name='Freeway',         # edit to train on a different game
    train_val_split=0.8,         # forwarded to HDF5ReplayBufferRAM.load
    CTR_type="dil",              # "SA" = self-attention, "dil" = dilated CNN
    psi_blend_dropout=0.25,
    train_additional_MPs=True,   # also train the plain and GP motor predictors concurrently
    # NOTE(review): the original inline note on n_epochs just said "higher" —
    # presumably a reminder to raise it for full runs; confirm.
    n_epochs=300,
    n_epochs_CTR=250,
)

# Replay buffers
def _open_buffer(h5_path):
    """Load one replay buffer with file_rw_option='r' and the configured split."""
    return HDF5ReplayBufferRAM.load(
        file_path=h5_path,
        file_rw_option='r',
        train_val_split=config["train_val_split"],
        RAM_ratio=1/1,
    )


# Agent buffer (AA) is keyed by the game name as-is; the human buffer (AH)
# uses the snake_case name from AA_to_AH.
memoryAA = _open_buffer(
    f'replay_buffers/atari_agents/{config["game_name"]}_atari_agents_buffer_4f_84gray.h5'
)
memoryAH = _open_buffer(
    f'replay_buffers/atarihead/{AA_to_AH[config["game_name"]]}_atarihead_buffer_all_4f_84gray.h5'
)

# Game constants
game_n_actionsAA = n_actions[config["game_name"]]  # per-game agent action count
game_n_actionsAH = 18                              # fixed action count for the human data
game_action_type = 'discrete'
game_max_reward = 1
game_res = (210, 160, 3)

# Load Autoencoder and Gaze Predictor
def _newest_file(paths, pattern):
    """Return the most recently modified path in *paths*.

    Args:
        paths: candidate file paths (the result of globbing *pattern*).
        pattern: the glob pattern that produced *paths*; only used in the
            error message.

    Raises:
        FileNotFoundError: if *paths* is empty, i.e. nothing matched *pattern*.
    """
    if not paths:
        # Without this guard, max() on an empty list raises a ValueError that
        # says nothing about which checkpoint is missing.
        raise FileNotFoundError(f'No checkpoint file matches {pattern!r}')
    return max(paths, key=os.path.getmtime)


# Autoencoder: pick the newest checkpoint for the configured game.
ae_pattern = f'trained_models/autoencoder/AE_4f_TCDS_BlurPool64_{config["game_name"]}_*.pkl'
ae_files = glob.glob(ae_pattern)
checkpoint_path = _newest_file(ae_files, ae_pattern)
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
autoencoder = Autoencoder(device).to(device)
autoencoder.load_state_dict(checkpoint['state'])
autoencoder.eval()

# Gaze predictor: same selection rule. It is built around its own deep copy of
# the autoencoder, and its weights live under 'model_state' (not 'state').
GP_pattern = f'trained_models/gaze_predictor/gaze_predictor_4f_{config["game_name"]}_*.pkl'
GP_files = glob.glob(GP_pattern)
GP_checkpoint_path = _newest_file(GP_files, GP_pattern)
GP_checkpoint = torch.load(GP_checkpoint_path, weights_only=True, map_location=device)
gaze_predictor = Gaze_predictor_pool(device, copy.deepcopy(autoencoder)).to(device)
gaze_predictor.load_state_dict(GP_checkpoint['model_state'])
gaze_predictor.eval()

# Determine class weights
def _sqrt_inverse_frequency_weights(actions, done, n_classes):
    """Build per-action loss weights from a sample of transitions.

    Actions of non-terminal transitions are counted (and the counts printed).
    Each observed action gets weight ``1 / sqrt(count)``, rescaled so the
    observed weights average to 1. Actions that do not occur in the sample
    keep weight 1.

    Args:
        actions: sampled action ids; used as indices into the weight vector.
            Assumed to be an integer NumPy array — TODO confirm against
            ``sample_stacked``.
        done: terminal flags aligned with ``actions``.
        n_classes: length of the returned weight vector.

    Returns:
        A float32 tensor of length ``n_classes`` on ``device``.
    """
    # Coerce to bool first: `~` is only a logical NOT for boolean arrays; on
    # 0/1 integer flags it would be a bitwise NOT and select the wrong rows.
    not_done = ~np.asarray(done, dtype=bool)
    unique_values, counts = np.unique(actions[not_done], return_counts=True)
    for value, count in zip(unique_values, counts):
        print(f"Value: {value}, Frequency: {count}")

    observed = 1.0 / np.sqrt(counts)     # inverse square root of class frequency
    observed = observed / observed.mean()  # normalise observed weights to mean 1

    weights = np.ones(n_classes)
    for value, weight in zip(unique_values, observed):
        weights[value] = weight
    return torch.tensor(weights, device=device, dtype=torch.float32)


# Agent data first, then human data (same sampling order as before, so the
# buffers' random streams are consumed identically).
_, actions, _, _, _, _, done, _ = memoryAA.sample_stacked(2048)
class_weightsAA = _sqrt_inverse_frequency_weights(actions, done, game_n_actionsAA)

_, actions, _, _, _, _, done, _ = memoryAH.sample_stacked(2048)
class_weightsAH = _sqrt_inverse_frequency_weights(actions, done, game_n_actionsAH)

if config["game_name"] == 'Enduro':
    # Zero the human-data weight of these action ids for Enduro.
    # NOTE(review): presumably actions that are not meaningful in this game — confirm.
    class_weightsAH[[2, 6, 7, 10, 13, 14, 15, 16, 17]] *= 0

# Build the attention trainer around the pretrained autoencoder and gaze
# predictor; the two action dimensions size the agent (AA) and human (AH) heads.
_trainer_kwargs = dict(
    action_dimAA=game_n_actionsAA,
    action_dimAH=game_n_actionsAH,
    lr=1e-3,
    config=config,
)
Att_network = Attention_training_class(autoencoder, gaze_predictor, **_trainer_kwargs)


# Training setup and execution

In [None]:
# Training setup and execution
# Columns: epoch counter and learning rates, then the same ten metrics for the
# training ('t:') and validation ('v:') splits.
cols = [
    'q-epoch', 'lr_MP', 'lr_CTR',
    't:psi', 't:lam', 't:MP CTR', 't:MPacc CTR', 't:MP CTRf', 't:MPacc CTRf',
    't:MP Plain', 't:MPacc Plain', 't:MP GP', 't:MPacc GP',
    'v:psi', 'v:lam', 'v:MP CTR', 'v:MPacc CTR', 'v:MP CTRf', 'v:MPacc CTRf',
    'v:MP Plain', 'v:MPacc Plain', 'v:MP GP', 'v:MPacc GP',
]

# One log row per epoch, for the agent (AA) and human (AH) networks.
df_AA = pd.DataFrame(columns=cols)
df_AH = pd.DataFrame(columns=cols)

batch_size = 128
n_batches = 100  # training steps per epoch

j_epoch = 0

torch.cuda.empty_cache()


def _print_recent(results_title, accuracy_title, df, n_rows=5):
    """Print the last *n_rows* epochs of *df*: non-accuracy columns, then accuracies."""
    recent = df.tail(n_rows)
    acc_cols = [col for col in df.columns if "MPacc" in col]
    print(results_title)
    print(recent.drop(columns=acc_cols).to_string(index=False))
    print(accuracy_title)
    print(recent[acc_cols].to_string(index=False))


# Main training loop
while j_epoch < config["n_epochs"]:
    # Run one epoch and keep every step's result for averaging.
    all_batch_results = np.array([
        Att_network.train_AE_att(
            j_epoch, memoryAA, memoryAH, batch_size,
            config, class_weightsAA, class_weightsAH,
            GP_comparison=False
        )
        for _ in range(n_batches)
    ])
    # nanmean: due to the varying number of validation samples a step can have
    # none, which yields NaN entries that must not poison the epoch mean.
    tv_results_mean = np.nanmean(all_batch_results, axis=0)

    lr_MP_AA = get_lr(Att_network.MP_AA_optimizer)
    lr_CTR_AA = get_lr(Att_network.CTR_AA_optimizer)
    lr_MP_AH = get_lr(Att_network.MP_AH_optimizer)
    lr_CTR_AH = get_lr(Att_network.CTR_AH_optimizer)

    # First index of tv_results_mean: 0 = agent, 1 = human.
    # Second index: 0 = training metrics, 1 = validation metrics.
    df_AA.loc[j_epoch] = [j_epoch+1, lr_MP_AA, lr_CTR_AA, *tv_results_mean[0, 0], *tv_results_mean[0, 1]]
    df_AH.loc[j_epoch] = [j_epoch+1, lr_MP_AH, lr_CTR_AH, *tv_results_mean[1, 0], *tv_results_mean[1, 1]]

    # wait=True defers the clear until the next output arrives, so the cell
    # does not flicker between epochs.
    clear_output(wait=True)

    _print_recent('Agent CTR results:', '\nAgent CTR Accuracies:', df_AA)
    _print_recent('Human CTR results:', '\nHuman CTR Accuracies:', df_AH)

    j_epoch += 1

    memoryAA.shuffle_RAM()
    memoryAH.shuffle_RAM()
    print('Replay buffer RAM shuffled')


# Plotting and analysis

In [None]:
# Plot settings
ravg = 1     # rolling-mean window (1 = no smoothing)
last = 1300  # plot at most this many of the most recent epochs

mpa_variants = ['MP CTR', 'MP Plain', 'MP GP']
colors = ["#2c81bd", '#ff7f0e', '#2ca02c', '#d62728']  # Blue, Orange, Green, Red
datasets = [('AA', 'b'), ('AH', 'r')]  # (name, color) pairs

fig, ax = plt.subplots(3, 1, figsize=(8, 10))


def _curve(frame, column):
    """Rolling mean of one logged column, truncated to the last `last` rows."""
    return frame[column].rolling(ravg).mean()[-last:]


_frames = {'AA': df_AA, 'AH': df_AH}

# Top two panels: one per dataset, each variant in its own colour with
# solid = training and dashed = validation.
for panel, (name, _) in zip(ax[:2], datasets):
    frame = _frames[name]
    for variant, shade in zip(mpa_variants, colors):
        panel.plot(_curve(frame, f't:{variant}'),
                   label=f'Train ({variant})', c=shade, alpha=0.8)
        panel.plot(_curve(frame, f'v:{variant}'),
                   label=f'Val ({variant})', c=shade, linestyle='--', alpha=0.8)
    panel.set_title(f'{name}: MPa Variants')
    panel.legend(bbox_to_anchor=(1.05, 1))
    panel.grid()

# Bottom panel: psi for both datasets, coloured by dataset.
psi_panel = ax[2]
for name, color in datasets:
    frame = _frames[name]
    psi_panel.plot(_curve(frame, 't:psi'), label=f'{name}: Train', c=color)
    psi_panel.plot(_curve(frame, 'v:psi'), label=f'{name}: Val', c=color, linestyle='--')
psi_panel.set_title('Psi Comparison')
psi_panel.legend()
psi_panel.grid()

plt.tight_layout()
plt.show()


# Testing and visualization

In [None]:
# Visualise the agent (AA) and human (AH) attention maps on frames sampled
# from the human replay buffer.

# Explicit submodule import: a bare `import scipy` does not load
# `scipy.ndimage` on older SciPy releases.
import scipy.ndimage

n_frames = 512
# Index of the single frame shown in detail. Drawn from n_frames (the size of
# the sample below) rather than the training batch_size, so every sampled
# frame is eligible and the index can never exceed the sample.
i_print = np.random.randint(n_frames)

Att_network.CTR_AA.eval()
Att_network.CTR_AH.eval()
try:
    with torch.no_grad():
        state_np, _, _, _, _, _, _, _ = memoryAH.sample_stacked(n_frames)
        # Rescale pixel values from 0..255 to -1..1.
        state = torch.tensor(state_np, device=device, dtype=torch.float32)/255*2-1
        # Keep the first three axes and flatten everything after them into one.
        state_img = state_np.reshape(state_np.shape[0], state_np.shape[1], state_np.shape[2], -1)

        lam = 0.02*torch.ones(n_frames, 1, device=device)

        psiAA = Att_network.CTR_AA.psi(state, lam)
        psiAH = Att_network.CTR_AH.psi(state, lam)

        kernel = func_a.make_gaussian_kernel(size=11, sigma=.7, device=device)
        psiAA_blur = func_a.apply_gaussian_filter(psiAA, kernel)
        psiAH_blur = func_a.apply_gaussian_filter(psiAH, kernel)

        # Average the blurred maps over the sampled frames.
        AA_heat = psiAA_blur.mean(axis=0).squeeze(0).detach().cpu().numpy()
        AH_heat = psiAH_blur.mean(axis=0).squeeze(0).detach().cpu().numpy()

        fig, ax = plt.subplots(2, 2, figsize=(8, 8))
        # Show the last slice along the final axis of the chosen frame.
        ax[0, 0].imshow(state_img[i_print][..., -1])
        ax[0, 0].set_title('Image')

        for idp, psi in enumerate([psiAA, psiAH]):
            # Move the leading axis to the end so imshow gets a channels-last array.
            psi_np = psi[i_print].permute(1, 2, 0).cpu().detach().numpy()
            # sigma=0.0 leaves the map unchanged; raise it to smooth the display.
            psi_np_blur = scipy.ndimage.gaussian_filter(psi_np, sigma=0.0)
            im = ax[1, idp].imshow(psi_np_blur)
            fig.colorbar(im, ax=ax[1, idp], label='Attention score', fraction=0.046, pad=0.04)

        # NOTE(review): the 'Gaze Predictor' panel is titled but nothing is drawn into it.
        ax[0, 1].set_title('Gaze Predictor')
        ax[1, 0].set_title('Agent CTR')
        ax[1, 1].set_title('Human CTR')

        plt.tight_layout()

    figh, axh = plt.subplots(1, 3, figsize=(8, 8))
    axh[0].imshow(AA_heat)
    axh[1].imshow(AH_heat)
    axh[0].set_title('Agent CTR Heatmap')
    axh[1].set_title('Human CTR Heatmap')
    # NOTE(review): the 'Gaze Heatmap' panel is titled but nothing is drawn into it.
    axh[2].set_title('Gaze Heatmap')
finally:
    # Always return both networks to training mode, even if sampling or
    # plotting failed, so a later training cell is not run in eval mode.
    Att_network.CTR_AA.train()
    Att_network.CTR_AH.train()

# Drop the large intermediates and release cached GPU memory.
del psiAA, psiAH, psiAA_blur, psiAH_blur, AA_heat, AH_heat
torch.cuda.empty_cache()


# Save model checkpoint

In [None]:
# Save all network and optimizer states plus the per-epoch training logs.
# Every output file name carries the model description, game name, and a
# minute-resolution timestamp.
model_descr = 'V15'
now = datetime.now().strftime("%Y-%m-%d-%H-%M")

# Create the output directory up front so a finished training run is not lost
# to a missing folder at save time.
os.makedirs('trained_models/CTR_att', exist_ok=True)

# Each network is stored next to its optimizer state under '<name>_state' /
# '<name>_opt_state'.
# NOTE(review): the plain and GP motor predictors are saved unconditionally —
# presumably they only exist when config["train_additional_MPs"] is True; confirm.
CTR_checkpoint = {
    'CTR_AA_state': Att_network.CTR_AA.state_dict(),
    'CTR_AA_opt_state': Att_network.CTR_AA_optimizer.state_dict(),
    'CTR_AH_state': Att_network.CTR_AH.state_dict(),
    'CTR_AH_opt_state': Att_network.CTR_AH_optimizer.state_dict(),
    'MP_AA_state': Att_network.MP_AA.state_dict(),
    'MP_AA_opt_state': Att_network.MP_AA_optimizer.state_dict(),
    'MP_AH_state': Att_network.MP_AH.state_dict(),
    'MP_AH_opt_state': Att_network.MP_AH_optimizer.state_dict(),
    'MP_AA_plain_state': Att_network.MP_AA_plain.state_dict(),
    'MP_AA_plain_opt_state': Att_network.MP_AA_plain_optimizer.state_dict(),
    'MP_AH_plain_state': Att_network.MP_AH_plain.state_dict(),
    'MP_AH_plain_opt_state': Att_network.MP_AH_plain_optimizer.state_dict(),
    'MP_AH_GP_state': Att_network.MP_AH_GP.state_dict(),
    'MP_AH_GP_opt_state': Att_network.MP_AH_GP_optimizer.state_dict(),
}

with open(f'trained_models/CTR_att/nCTR_AA_AH_{model_descr}_{config["game_name"]}_{now}.pkl', 'wb') as f:
    torch.save(CTR_checkpoint, f)

# Training logs, one CSV per dataset, without the DataFrame index column.
df_AA.to_csv(f'trained_models/CTR_att/nCTR_df_AA_{model_descr}_{config["game_name"]}_{now}.csv', index=False)
df_AH.to_csv(f'trained_models/CTR_att/nCTR_df_AH_{model_descr}_{config["game_name"]}_{now}.csv', index=False)
