# Imports

In [None]:
# Standard libraries
import sys
import pickle
from datetime import datetime
from collections import namedtuple, deque
from pathlib import Path

# Data processing
import numpy as np
import pandas as pd
import cv2
from PIL import Image

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Visualization
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Project-specific imports
sys.path.insert(0, './atarihead')
sys.path.insert(0, '../src')
from replay_buffer import HDF5ReplayBufferRAM
from models import Autoencoder
from utils import get_lr

# Configuration
pd.options.display.float_format = '{:.3g}'.format
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float32)

# Seed the global NumPy / PyTorch RNGs (weight init, np.random.* calls) so runs are repeatable.
# NOTE(review): the replay buffers may sample with their own RNG — confirm whether they need seeding too.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# True when matplotlib uses an inline backend, i.e. we are running inside a notebook.
# `display` / `clear_output` are the *functions* imported from IPython.display above. Do not
# `from IPython import display` here: that rebinds `display` to the module, and any later
# `display(obj)` call then raises "TypeError: 'module' object is not callable".
is_ipython = 'inline' in matplotlib.get_backend()


# Load the Atari-Agents and Atari-HEAD replay buffers

In [3]:
# Single source of truth for the supported games, one row per game:
#   (Atari-Agents name, number of actions, Atari-HEAD name)
_GAME_TABLE = [
    ('Alien',            18, 'alien'),
    ('Asterix',           9, 'asterix'),
    ('BankHeist',        18, 'bank_heist'),
    ('Berzerk',          18, 'berzerk'),
    ('Breakout',          4, 'breakout'),
    ('Centipede',        18, 'centipede'),
    ('DemonAttack',       6, 'demon_attack'),
    ('Enduro',            9, 'enduro'),
    ('Freeway',           3, 'freeway'),
    ('Frostbite',        18, 'frostbite'),
    ('Hero',             18, 'hero'),
    ('MontezumaRevenge', 18, 'montezuma_revenge'),
    ('MsPacman',          9, 'ms_pacman'),
    ('NameThisGame',      6, 'name_this_game'),
    ('Phoenix',           8, 'phoenix'),
    ('Riverraid',        18, 'riverraid'),
    ('RoadRunner',       18, 'road_runner'),
    ('Seaquest',         18, 'seaquest'),
    ('SpaceInvaders',     6, 'space_invaders'),
    ('Venture',          18, 'venture'),
]

# Derived views, all keyed/ordered by the Atari-Agents (CamelCase) name.
games = [aa_name for aa_name, _n, _ah in _GAME_TABLE]
n_actions = {aa_name: n for aa_name, n, _ah in _GAME_TABLE}
AA_to_AH = {aa_name: ah_name for aa_name, _n, ah_name in _GAME_TABLE}

In [None]:
# TODO(author): original note reads "Make game max reward!!!" — intent unclear; confirm.

game_name = 'Seaquest'  # Atari-Agents (CamelCase) name; AA_to_AH gives the Atari-HEAD name
train_val_split = 0.8   # presumably the fraction of samples flagged as training — check HDF5ReplayBufferRAM

# Loader options shared by both buffers (read-only, same split, same RAM_ratio).
shared_buffer_kwargs = dict(file_rw_option='r', train_val_split=train_val_split, RAM_ratio=1/4)

memoryAA = HDF5ReplayBufferRAM.load(
    file_path=f'replay_buffers/atari_agents/{game_name}_atari_agents_buffer_4f_84gray.h5',
    **shared_buffer_kwargs,
)
memoryAH = HDF5ReplayBufferRAM.load(
    file_path=f'replay_buffers/atarihead/{AA_to_AH[game_name]}_atarihead_buffer_all_4f_84gray.h5',
    **shared_buffer_kwargs,
)


# Define the autoencoder (AE) training class


In [None]:
class AE_training:
    """Jointly trains one autoencoder on Atari-Agents (AA) and Atari-HEAD (AH) frames.

    Both buffers go through the same network; the reconstruction loss is the sum of the
    per-buffer MSE losses.
    """

    def __init__(self, lr):
        """Build the autoencoder on the notebook-level `device` plus an AdamW optimizer.

        Args:
            lr: AdamW learning rate (weight decay is fixed at 1e-2).
        """
        # Not read by the methods below; kept so existing code touching them still works.
        self.n_samples = 8
        self.flag_load_ae = False

        self.autoencoder = Autoencoder(device).to(device)
        self.ae_optimizer = optim.AdamW(self.autoencoder.parameters(), lr=lr, weight_decay=1e-2)

    @staticmethod
    def _to_unit_range(frames_np, target_device):
        """Convert frames to a float32 tensor mapped from [0, 255] to [-1, 1] (assumes 0-255 pixels)."""
        return torch.tensor(frames_np, dtype=torch.float32, device=target_device) / 255.0 * 2 - 1

    def train_AE_att(self, replay_bufferAA, replay_bufferAH, batch_size, split=None):
        """Do one optimizer step on the train split and evaluate on the val split.

        Args:
            replay_bufferAA, replay_bufferAH: buffers whose ``sample_stacked_fwd(n)`` returns a
                7-tuple: stacked frames first, boolean train mask last.
            batch_size: roughly the number of training frames wanted per buffer; the method
                draws ``int(batch_size / split)`` frames and splits them by the returned mask.
            split: train fraction used to size the draw. Defaults to the notebook-level
                ``train_val_split`` (the value the buffers were loaded with).

        Returns:
            ``[train_loss, val_loss]`` as Python floats. If a split comes back empty for either
            buffer, its entry is ``nan`` and, for the train split, no optimizer step is taken.
        """
        if split is None:
            split = train_val_split
        n_draw = int(batch_size / split)
        model_device = next(self.autoencoder.parameters()).device

        stateAA_np_all, _, _, _, _, _, train_flagsAA = replay_bufferAA.sample_stacked_fwd(n_draw)
        stateAA_all = self._to_unit_range(stateAA_np_all, model_device)
        stateAH_np_all, _, _, _, _, _, train_flagsAH = replay_bufferAH.sample_stacked_fwd(n_draw)
        stateAH_all = self._to_unit_range(stateAH_np_all, model_device)

        criterion = nn.MSELoss()
        results = [float('nan'), float('nan')]
        phases = (
            (0, True, stateAA_all[train_flagsAA], stateAH_all[train_flagsAH]),
            (1, False, stateAA_all[~train_flagsAA], stateAH_all[~train_flagsAH]),
        )
        try:
            for idx, is_train, stateAA, stateAH in phases:
                if len(stateAA) == 0 or len(stateAH) == 0:
                    continue  # MSE of an empty batch is NaN; report NaN, never step on it
                self.autoencoder.train(is_train)
                # Validation needs no autograd graph: saves memory and compute.
                with torch.set_grad_enabled(is_train):
                    loss_auto = (criterion(self.autoencoder(stateAA), stateAA)
                                 + criterion(self.autoencoder(stateAH), stateAH))
                if is_train:
                    self.ae_optimizer.zero_grad()
                    loss_auto.backward()
                    torch.nn.utils.clip_grad_norm_(self.autoencoder.parameters(), 1.0)
                    self.ae_optimizer.step()
                results[idx] = loss_auto.item()
        finally:
            # Always hand the model back in training mode, even if validation raised.
            self.autoencoder.train()
        return results

# Training

In [None]:
# Loop hyper-parameters. One "q-epoch" = `n_batches` calls to `train_AE_att`.
batch_size = 64   # ~training frames per buffer per step (the trainer draws batch_size / train_val_split)
n_qepochs = 800   # total q-epochs to run
n_batches = 100   # optimizer steps per q-epoch
ratio_val = 4     # not read by any visible cell

AE_trainer = AE_training(lr=1e-3)
AE_trainer.train_AE = True  # not read by AE_training's visible methods

# The training cell steps the scheduler once per q-epoch, so the LR drops x0.1 every 600 q-epochs.
scheduler = torch.optim.lr_scheduler.StepLR(AE_trainer.ae_optimizer, step_size=600, gamma=0.1)

# Per-q-epoch loss log, filled one row at a time by the training cell.
loss_df = pd.DataFrame(columns=['q-epoch', 'lr', 't:ae', 'v:ae'])

# Initialised here, not in the training cell, so re-running that cell resumes instead of restarting.
j_epoch = 0

In [None]:
# Main training loop. One q-epoch = `n_batches` calls to `train_AE_att`. After each q-epoch:
# log the mean train/val loss to `loss_df`, step the LR scheduler, and have both replay
# buffers reshuffle their RAM cache. `j_epoch` is initialised in the setup cell, so re-running
# this cell after an interrupt resumes from the last completed q-epoch.
stop_training = False  # manual early-exit hook; nothing in this cell sets it to True

while j_epoch < n_qepochs:
    # One [train_loss, val_loss] pair per batch.
    batch_losses = [
        AE_trainer.train_AE_att(memoryAA, memoryAH, batch_size) for _batch in range(n_batches)
    ]
    # nanmean: one NaN batch (e.g. an empty train/val split) must not turn the whole
    # q-epoch average into NaN.
    ae_mean = np.nanmean(np.asarray(batch_losses, dtype=float), axis=0)
    current_lr = get_lr(AE_trainer.ae_optimizer)

    loss_df.loc[j_epoch] = [j_epoch + 1, current_lr, *ae_mean]
    clear_output()
    print(loss_df.tail(10).to_string(index=False))

    j_epoch += 1
    scheduler.step()
    memoryAA.shuffle_RAM()
    memoryAH.shuffle_RAM()
    print('RB RAM shuffled')

    if stop_training:
        break


KeyboardInterrupt: 

## Plot learning progress

In [None]:
# Rolling mean (window `ravg`) of the logged per-q-epoch reconstruction loss, on a log scale.
fig, ax = plt.subplots(figsize=(7, 5))
ravg = 5  # smoothing window, in q-epochs

for col, label in (('t:ae', 'train'), ('v:ae', 'val')):
    ax.plot(loss_df['q-epoch'], loss_df[col].rolling(ravg).mean(), label=label)
ax.set_title('AE loss')
ax.set_xlabel('q-epoch')
ax.set_ylabel(f'MSE loss (rolling mean, window={ravg})')
ax.set_yscale('log')
ax.legend()
plt.tight_layout()


## Reconstruct a single sample image

In [None]:
# Reconstruct one random Atari-Agents sample with the current autoencoder and show the last
# stacked frame next to its reconstruction.
i_print = np.random.randint(batch_size)
state_np, actions, next_state_np, rewards, frame_ids, gaze_positions, done_np, _extra = memoryAA.sample_stacked(batch_size)
state = torch.tensor(state_np, device=device, dtype=torch.float32) / 255 * 2 - 1
# NOTE(review): assumes the frame stack is the last axis, as the [..., -1] indexing below implies.
state_img = state_np.reshape(state_np.shape[0], state_np.shape[1], state_np.shape[2], -1)

next_state_img = next_state_np
gaze_pos = gaze_positions[i_print]

# Inference only: eval mode and no autograd graph. Restore the previous mode afterwards so the
# training cell can continue where it left off.
ae_model = AE_trainer.autoencoder
was_training = ae_model.training
ae_model.eval()
try:
    with torch.no_grad():
        state_recr = ae_model(state)
finally:
    ae_model.train(was_training)

# Map [-1, 1] back to [0, 255]. Clip first so out-of-range outputs don't wrap around in uint8.
state_recr_img = np.clip((state_recr.cpu().numpy()[..., -1] + 1) / 2 * 255, 0, 255).astype(np.uint8)

fig, ax = plt.subplots(1, 2)
# Same grayscale mapping for both panels so they are directly comparable.
imshow_kwargs = dict(interpolation='nearest', cmap='gray', vmin=0, vmax=255)
ax[0].imshow(state_img[i_print][..., -1], **imshow_kwargs)
ax[0].set_title('Image')
ax[1].imshow(state_recr_img[i_print], **imshow_kwargs)
ax[1].set_title('Recreated Image')


# Save the AE checkpoint and loss history

In [None]:
# Save the trained autoencoder and its loss history. Both files share one timestamp so they
# can be matched later.
model_descr = 'BlurPool64'  # free-form tag embedded in the file names

now = datetime.now().strftime("%Y-%m-%d-%H-%M")

save_dir = Path('trained_models')
save_dir.mkdir(parents=True, exist_ok=True)  # saving fails if the folder does not exist yet

AE_checkpoint = {
    'state': AE_trainer.autoencoder.state_dict(),
    'optimizer_state': AE_trainer.ae_optimizer.state_dict(),
    # Extra keys so training can resume with the same LR schedule and q-epoch counter.
    # Code that only reads 'state' / 'optimizer_state' is unaffected.
    'scheduler_state': scheduler.state_dict(),
    'epoch': j_epoch,
}

# Written with torch.save despite the .pkl suffix, so load it with torch.load.
torch.save(AE_checkpoint, save_dir / f'AE_4f_TCDS_{model_descr}_{game_name}_{now}.pkl')

loss_df.to_csv(save_dir / f'AE_4f_TCDS_df_{model_descr}_{game_name}_{now}.csv', index=False)