In [None]:
# Standard libraries
import sys
import io
import copy
from datetime import datetime
from collections import namedtuple, deque
from itertools import count

# Data processing
import numpy as np
import pandas as pd
import cv2
from PIL import Image, ImageDraw, ImageFont
from scipy.ndimage import gaussian_filter
from scipy.signal import convolve2d

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Visualization
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Project-specific imports
sys.path.insert(0, './atarihead')
sys.path.insert(0, '../src')
from replay_buffer import HDF5ReplayBufferRAM

# Configuration
# Show pandas floats with 3 significant digits.
pd.options.display.float_format = '{:.3g}'.format
# Use the GPU when available (this global is passed as `device=` in the export cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)

# IPython display setup
# True when the active matplotlib backend name contains 'inline'.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    # NOTE(review): this rebinds `display` from the function imported above
    # (`from IPython.display import display`) to the `IPython.display` module.
    # Afterwards `display(obj)` raises "'module' object is not callable" while
    # `display.display(obj)` works. Pick one calling convention and drop the
    # other import.
    from IPython import display

# Load the Atari-HEAD replay buffer

In [None]:
# Atari-HEAD game identifiers (snake_case), one per line, alphabetical.
# These are the names that appear in the replay-buffer file names.
games = [
    'alien',
    'asterix',
    'bank_heist',
    'berzerk',
    'breakout',
    'centipede',
    'demon_attack',
    'enduro',
    'freeway',
    'frostbite',
    'hero',
    'montezuma_revenge',
    'ms_pacman',
    'name_this_game',
    'phoenix',
    'riverraid',
    'road_runner',
    'seaquest',
    'space_invaders',
    'venture',
]

In [None]:
# Configuration
# TODO: replace the '...' placeholder with an Atari-HEAD game identifier
# (see `games`, e.g. 'riverraid'); it is interpolated into the buffer file path.
game_name = '...'
flag_load_encoded_grid_tensor = True
train_val_split = 0.8

# Fail early with an actionable message instead of failing later inside the loader.
if game_name == '...':
    raise ValueError(
        "Set `game_name` to an Atari-HEAD game identifier (see `games`), e.g. 'riverraid'."
    )

# Load replay buffer
memory = HDF5ReplayBufferRAM.load(
    file_path=f'replay_buffers/atarihead/{game_name}_atarihead_buffer_all_4f_84gray.h5',
    file_rw_option='r',
    train_val_split=train_val_split,
    RAM_ratio=1/1
)

game_res = (84, 84, 1)

In [None]:
def rescale_gaze_positions(gaze_pos, src_res=(160, 210), dst_res=(84, 84)):
    """Rescale gaze positions from ``src_res`` to ``dst_res`` resolution.

    By default this maps the 160x210 (width x height) screen to 84x84.

    Args:
        gaze_pos: array whose last axis holds ``(x, y)``. Any leading shape is
            accepted, e.g. ``(n_states, k, 2)`` or ``(n, 2)``. NaNs are preserved.
        src_res: ``(width, height)`` of the source coordinate system.
        dst_res: ``(width, height)`` of the target coordinate system.

    Returns:
        A new array; the input is not modified. Floating inputs keep their
        dtype. Integer inputs are promoted to float64 so the in-place scaling
        cannot fail with a casting error.
    """
    gaze_pos_ = np.array(gaze_pos, copy=True)
    if not np.issubdtype(gaze_pos_.dtype, np.floating):
        gaze_pos_ = gaze_pos_.astype(np.float64)
    gaze_pos_[..., 0] *= dst_res[0] / src_res[0]  # Scale x-coordinate
    gaze_pos_[..., 1] *= dst_res[1] / src_res[1]  # Scale y-coordinate
    return gaze_pos_

## Point merging functions

In [None]:
# Point merging functions for gaze target construction
# Grid extent used for gaze validity checks: first two entries of `game_res`.
height, width = game_res[:2]

def convert_coords_to_encoded_row(coords, encoded_grid_tensor, grid_hw=None):
    """Look up the encoded-grid row for each gaze coordinate.

    Point ``(x, y)`` maps to the flat cell index
    ``floor(x) * grid_h + floor(y)``, clamped to ``[0, grid_h * grid_w - 1]``.
    That index selects a row of ``encoded_grid_tensor`` along its first dimension.

    Args:
        coords: tensor of shape ``(N, 2)``; column 0 is x, column 1 is y.
            NaN / infinite entries are allowed and are reported as invalid.
        encoded_grid_tensor: tensor indexed by flat cell index along dim 0
            (presumably ``grid_h * grid_w`` rows -- TODO confirm).
        grid_hw: optional ``(height, width)`` of the grid. Defaults to the
            module-level ``height`` and ``width``, so two-argument calls are
            unaffected.

    Returns:
        ``(rows, valid_mask)``. ``rows`` is ``encoded_grid_tensor[idx]``.
        ``valid_mask`` is a bool tensor of shape ``(N,)`` that is True only
        where ``0 <= x < width`` and ``0 <= y < height``. Rows for invalid
        points are placeholders that the caller is expected to mask out.
    """
    grid_h, grid_w = (height, width) if grid_hw is None else grid_hw
    x, y = coords[:, 0], coords[:, 1]
    # NaN compares False with everything, so NaN points are never valid.
    valid_mask = (x >= 0) & (x < grid_w) & (y >= 0) & (y < grid_h)

    if coords.is_floating_point():
        # Casting NaN, inf or huge floats to int64 is undefined behaviour, so
        # sanitise before `.long()`. Finite values with |v| <= 2**40 are left
        # untouched, so valid points are unaffected.
        coords = torch.nan_to_num(coords, nan=0.0).clamp(-2.0 ** 40, 2.0 ** 40)
    cells = torch.floor(coords).long()
    row_indices = torch.clamp(cells[:, 0] * grid_h + cells[:, 1], 0, grid_h * grid_w - 1)
    return encoded_grid_tensor[row_indices], valid_mask

def local_merge_torch(points, radius):
    """Greedily merge nearby gaze points to reduce noise.

    Points are visited from the last index to the first. Each point whose output
    has not been assigned yet is replaced, together with every non-NaN point
    within ``radius`` of it (normally including itself), by the mean of those
    points' ORIGINAL positions. Distances are measured after doubling column 0,
    so horizontal separation counts twice as much as vertical separation.

    Args:
        points: 2-D float tensor with one point per row. The caller passes
            ``(x, y)`` gaze pairs. Rows containing NaN are passed through
            unchanged.
        radius: merge threshold, compared with strict ``<`` against the
            stretched Euclidean distance.

    Returns:
        A tensor with the same shape and dtype as ``points`` holding the merged
        positions.

    NOTE(review): the neighbour list is not filtered by ``merged_indices``. A
    point already merged into a later cluster is therefore overwritten again
    when an earlier point's neighbourhood reaches it, so the result depends on
    iteration order. Confirm this is intended before changing it, since it
    affects the exported targets.
    """
    points_stretched = points.clone()
    # Doubling column 0 weights horizontal separation 2x in the distance test.
    points_stretched[:, 0] = points_stretched[:, 0] * 2  # Stretch horizontal dimension
    dists = torch.cdist(points_stretched, points_stretched)
    # Adjacency matrix: 1.0 where the stretched distance is below `radius`.
    # Pairs involving a NaN point get NaN distance, and NaN < radius is False.
    mask = (dists < radius).float()
    merged_points = torch.zeros_like(points)
    merged_indices = set()  # indices whose output value has already been assigned

    for i in range(len(points) - 1, -1, -1):  # Reverse iteration
        if i not in merged_indices:
            if torch.isnan(points[i]).any():
                merged_points[i] = points[i]  # Keep NaN as is
                merged_indices.add(i)
                continue

            # Indices of all points within `radius` of point i. squeeze().tolist()
            # yields a bare int when there is exactly one neighbour, hence the wrap.
            neighbors = torch.nonzero(mask[i]).squeeze().tolist()
            if isinstance(neighbors, int):
                neighbors = [neighbors]

            # Defensive: NaN points should already be excluded by `mask`.
            valid_neighbors = [n for n in neighbors if not torch.isnan(points[n]).any()]

            # No neighbours at all (e.g. radius <= 0): keep the point as is.
            if not valid_neighbors:
                merged_points[i] = points[i]
                merged_indices.add(i)
                continue

            # Replace the whole neighbourhood by its centroid; already-merged
            # neighbours are overwritten (see NOTE(review) in the docstring).
            merged_point = torch.mean(points[valid_neighbors], dim=0)
            merged_points[valid_neighbors] = merged_point
            merged_indices.update(valid_neighbors)

    return merged_points

# Load encoded grid tensor
# Same location as the export cell uses ('other/'); the bare filename in the
# working directory is inconsistent with it. weights_only=True restricts
# unpickling to tensors and plain containers.
encoded_grid_tensor = torch.load('other/encoded_grid_tensor.pt', weights_only=True)


# Construct gaze targets and export them to disk

In [None]:
# Target construction parameters
t_TIW = 3                             # time-integration window (TIW) length; presumably seconds -- TODO confirm
FRAMESKIP = 4                         # gaze samples per replay-buffer state; also the window stride
n_TIW = int(t_TIW * 20)               # TIW length in gaze samples (20 samples per unit of t_TIW)
n_TIWs = int(t_TIW * 20 / FRAMESKIP)  # Length of TIW on state level
n_batch = 1000                        # states processed per chunk

# Games to process (Atari Agents style), listed in processing order
games = [
    'Enduro',
    'Freeway',
    'MsPacman',
    'Riverraid',
    'Seaquest',
    'SpaceInvaders',
    'Alien',
    'Asterix',
    'BankHeist',
    'Berzerk',
    'Breakout',
    'Centipede',
    'DemonAttack',
    'Frostbite',
    'Hero',
    'MontezumaRevenge',
    'NameThisGame',
    'Phoenix',
    'RoadRunner',
    'Venture',
]

# Atari Agents (CamelCase) name -> Atari-HEAD (snake_case) name used in buffer file paths.
AA_to_AH = dict(
    Alien='alien',
    Asterix='asterix',
    BankHeist='bank_heist',
    Berzerk='berzerk',
    Breakout='breakout',
    Centipede='centipede',
    DemonAttack='demon_attack',
    Enduro='enduro',
    Freeway='freeway',
    Frostbite='frostbite',
    Hero='hero',
    MontezumaRevenge='montezuma_revenge',
    MsPacman='ms_pacman',
    NameThisGame='name_this_game',
    Phoenix='phoenix',
    Riverraid='riverraid',
    RoadRunner='road_runner',
    Seaquest='seaquest',
    SpaceInvaders='space_invaders',
    Venture='venture',
)

# Loader settings; these repeat the values from the single-game configuration cell.
train_val_split = 0.8
flag_load_encoded_grid_tensor = True

# Load encoded grid tensor (weights_only=True restricts unpickling to tensors
# and plain containers).
encoded_grid_tensor = torch.load(
    'other/encoded_grid_tensor.pt',
    weights_only=True,
)

# Process each game: build one (1, 21, 21) gaze target per replay-buffer state
# and save the stacked targets to other/{game}_gaze_target_all_4f_v6.npy.
for game_name in games:
    print(game_name)
    memory = HDF5ReplayBufferRAM.load(
        file_path=f'replay_buffers/atarihead/{AA_to_AH[game_name]}_atarihead_buffer_all_4f_84gray.h5',
        file_rw_option='r',
        train_val_split=train_val_split,
        RAM_ratio=1/1
    )

    n_states = memory.size
    game_res = (84, 84, 1)
    gaze_target_out = np.zeros([n_states, 1, 21, 21], dtype=np.float16)
    cell_size = 6  # merge radius passed to local_merge_torch

    for j in np.arange(0, n_states, n_batch):
        print(j)
        len_batch = int(min(n_states - j, n_batch))

        # Read this chunk's gaze samples plus those of the (n_TIWs - 1)
        # preceding states, so each state can look back over a full TIW.
        state_np = memory.file['state'][j:j + n_batch]
        gaze_positions = memory.file['gaze_pos'][max(0, j - n_TIWs + 1):j + n_batch]
        gaze_positions = rescale_gaze_positions(gaze_positions)
        gaze_positions = gaze_positions.reshape(-1, 2)

        # Chunks without enough look-back history (e.g. the first one) are
        # left-padded by repeating the earliest gaze sample.
        # np.concatenate works on every NumPy version (np.concat needs >= 2.0).
        n_required = (len(state_np) + n_TIWs - 1) * FRAMESKIP
        if len(gaze_positions) < n_required:
            n_missing = n_required - len(gaze_positions)
            gaze_positions = np.concatenate(
                [gaze_positions[:1].repeat(n_missing, axis=0), gaze_positions], axis=0
            )

        non_nan_mask = ~np.isnan(gaze_positions).any(axis=1)
        gaze_pos = torch.tensor(gaze_positions, dtype=torch.float32)

        # Merge nearby points within each sliding window of n_TIW samples
        # (stride FRAMESKIP, i.e. one window per state).
        # NOTE(review): consecutive windows overlap, so each iteration
        # overwrites most of the previous window's merged samples. A sample
        # ends up holding the merge from the last window that covers it.
        # Confirm this is intended before changing it (it alters the targets).
        gaze_pos_merged = torch.zeros_like(gaze_pos)
        for i in range(len_batch):
            window = slice(i * FRAMESKIP, i * FRAMESKIP + n_TIW)
            gaze_pos_merged[window] = local_merge_torch(gaze_pos[window], cell_size)

        gaze_target, valid_mask = convert_coords_to_encoded_row(gaze_pos_merged, encoded_grid_tensor)

        # A sample is usable only if it is non-NaN and falls inside the grid.
        full_mask = non_nan_mask * valid_mask.cpu().detach().numpy()
        full_mask_unfold = np.array(
            [full_mask[i * FRAMESKIP:i * FRAMESKIP + n_TIW] for i in range(len_batch)],
            dtype=bool,  # plain `bool`: `np.bool` is missing on NumPy 1.24-1.26
        )
        # Sliding windows of encoded rows: (windows, n_TIW, *row_shape).
        gaze_target_unfold = gaze_target.unfold(dimension=0, size=n_TIW, step=FRAMESKIP).permute(0, 4, 1, 2, 3)

        gaze_target_full = torch.zeros(len_batch, *gaze_target.shape[1:])

        for i in range(len_batch):
            n_valid = int(full_mask_unfold[i].sum())
            if n_valid == 0:
                continue  # no usable gaze in this window: target stays all-zero
            TIW_gazes = gaze_target_unfold[i][full_mask_unfold[i]]

            # Weights rise from 0.1 (first valid sample in the window) to 1.0
            # (last). They are created on TIW_gazes' device. Using the global
            # `device` raises a device mismatch whenever that is CUDA but the
            # encoded rows live on the CPU.
            decay = 1 - torch.pow(10, torch.linspace(1, 0, n_valid, device=TIW_gazes.device)) / 10 + 1 / 10

            # Recency-weighted max over the window, then normalise to sum to 1.
            gaze_target_full[i], _ = (TIW_gazes * decay.view(n_valid, 1, 1, 1)).max(dim=0)
            gaze_target_full[i] = gaze_target_full[i] / gaze_target_full[i].sum()

        gaze_target_out[j:j + n_batch] = gaze_target_full.detach().cpu().numpy().astype(np.float16)

    np.save(f'other/{game_name}_gaze_target_all_4f_v6.npy', gaze_target_out)

Riverraid
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000


# Animate the gaze target map overlaid on the game frames

In [None]:
# Animation parameters
fromm = 30000          # first state index shown
until = fromm + 500    # one past the last state index shown
gaze_limit = 0.007     # threshold applied to the upscaled target map
# Pillow's GIF `duration` is the per-frame display time in milliseconds, so
# 1/60 s is 1000/60 ms (a value of 1/60 would mean ~0 ms). GIF stores delays
# in 10 ms steps, and many viewers slow down very short delays.
frame_duration_ms = 1000 / 60

# Create upscaled gaze target visualization: 21x21 targets -> 84x84 boolean mask.
# Cast to float32 first: half-precision bilinear interpolation is not
# implemented on CPU in some PyTorch versions.
gtf_upscale = F.interpolate(
    torch.from_numpy(gaze_target_out[fromm:until]).float(),
    size=(84, 84),
    mode='bilinear',
    align_corners=False
).squeeze(1) > gaze_limit

# Create image overlay: brighten masked pixels by 50. Integer images are summed
# in a wider type and clipped, so bright pixels saturate instead of wrapping
# around (uint8 addition is modular).
img_overlay = memory.file['state'][fromm:until]
boost = 50 * gtf_upscale.detach().cpu().numpy()
if np.issubdtype(img_overlay.dtype, np.integer):
    limits = np.iinfo(img_overlay.dtype)
    img_overlay[...] = np.clip(img_overlay.astype(np.int64) + boost, limits.min, limits.max)
else:
    img_overlay[...] = img_overlay + boost

# Create frames for animation
frames = [Image.fromarray(entry) for entry in img_overlay]
if not frames:
    raise ValueError(f"No states in range [{fromm}, {until}); nothing to animate.")

# Save animation
frames[0].save(
    'Gifs/Gaze_target_construction/gaze_final_target_animation.gif',
    save_all=True,
    append_images=frames[1:],
    duration=frame_duration_ms,
    loop=0  # Infinite loop
)