# Imports

In [None]:
# Standard libraries
import random
import os
import sys

# Data processing
import numpy as np
import pandas as pd
from PIL import Image

# Deep learning
import torch

# Visualization
import matplotlib
from IPython.display import display

# Project-specific imports
sys.path.insert(0, './atarihead')
sys.path.insert(0, '../src')
import data_reader as ahdr
from replay_buffer import HDF5ReplayBufferRAM as HDF5ReplayBuffer

# Configuration
pd.options.display.float_format = '{:.3g}'.format
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_dtype(torch.float32)

# Seed every RNG the notebook may touch so Restart & Run All reproduces
# the same shuffles / random initialisations.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# IPython display setup.
# `display` stays bound to the IPython.display.display *function* imported
# above; it is deliberately not re-bound to the IPython.display module here,
# since that would make `display(obj)` fail.
is_ipython = 'inline' in matplotlib.get_backend()

# Load and pre-process the Atari-HEAD dataset

In [None]:
def load_images_as_numpy(folder_path, frameid_list, progress_every=1000):
    """Load ``<frameid>.png`` images from a folder into one stacked uint8 array.

    Args:
        folder_path: Directory containing one PNG file per frame id.
        frameid_list: Frame ids to load, in order; each is read from
            ``os.path.join(folder_path, f"{frameid}.png")``.
        progress_every: Print a progress line every this many images
            (default 1000). Pass 0 or None to disable progress output.

    Returns:
        (images, image_ids): ``images`` is the per-image uint8 arrays stacked
        along a new leading axis, in ``frameid_list`` order; ``image_ids`` is a
        new list of the frame ids in that same order.

    Raises:
        ValueError: if ``frameid_list`` is empty (there is nothing to stack).
    """
    # Materialise once so generators work and emptiness can be checked safely.
    frame_ids = list(frameid_list)
    if not frame_ids:
        raise ValueError(f"No frame ids given; nothing to load from {folder_path!r}")

    images = []
    for idx, frameid in enumerate(frame_ids):
        file_path = os.path.join(folder_path, f"{frameid}.png")
        # The context manager releases the file handle immediately instead of
        # leaving it to Pillow/GC; a run opens tens of thousands of files.
        with Image.open(file_path) as img:
            images.append(np.array(img, dtype=np.uint8))

        if progress_every and idx % progress_every == 0:
            # idx is 0-based, so idx + 1 images have been loaded at this point.
            print(f"Loaded {idx + 1} images")

    return np.stack(images, axis=0), frame_ids


In [None]:
game_name = '...'
FRAMESKIP = 4  # very important: raw frames per stored transition (used for down-sampling and the buffer's gaze_shape)
SHUFFLE_SEED = 0  # fixes the trial order below so re-runs are reproducible

game_folder_path = os.path.join("atarihead", "records", game_name)

# One entry per trial sub-directory; the "highscore" directory is skipped.
file_names = sorted(
    f for f in os.listdir(game_folder_path)
    if f != "highscore" and os.path.isdir(os.path.join(game_folder_path, f))
)

# os.listdir order is filesystem-dependent, so sort first and then shuffle with
# a dedicated seeded RNG to get the same trial order on every run.
random.Random(SHUFFLE_SEED).shuffle(file_names)
print(file_names)

['430_RZ_2946939_Jul-12-10-57-06', '373_RZ_2176102_Jul-03-12-49-03', '397_RZ_2529116_Jul-07-14-52-39', '438_RZ_3133937_Jul-14-14-53-28', '362_RZ_1998636_Jul-01-11-40-24', '342_RZ_1413379_Jun-24-16-58-50', '411_RZ_2693470_Jul-09-12-32-07', '513_RZ_3993183_Jul-24-13-33-44', '417_RZ_2789701_Jul-10-15-15-49', '470_RZ_3392231_Jul-17-14-40-04', '380_RZ_2261928_Jul-04-12-39-29', '403_RZ_2605063_Jul-08-12-00-13', '456_RZ_3229544_Jul-15-17-27-11', '368_RZ_2084683_Jul-02-11-26-18', '426_RZ_2883723_Jul-11-17-22-42', '386_RZ_2346273_Jul-05-12-06-05', '344_RZ_1489894_Jun-25-14-12-49', '509_RZ_3924539_Jul-23-18-29-39', '505_RZ_3917332_Jul-23-16-30-53', '355_RZ_1924894_Jun-30-15-02-13', '376_RZ_2180835_Jul-03-14-07-55', '506_RZ_3918510_Jul-23-16-49-15', '350_RZ_1740637_Jun-28-11-54-19', '347_RZ_1568679_Jun-26-12-06-00']


In [6]:
# Empty placeholders; the real values are built in "Create state, action, reward tensors".
states, action_vals_np, reward_vals_np = [], [], []
action_vals = reward_vals = 0

In [None]:

# Read every trial: per-frame annotations from the .txt file plus the
# down-sampled frames from the trial's image folder.
frameid2pos_all, frameid2action_all, frameid2duration_all, frameid2unclipped_reward_all, frameid2episode_all, frameid2score_all = {}, {}, {}, {}, {}, {}
frameid2pos_long = {}
frameid_list_all = []
frameid_list_long_all = []

framestack_ids_all = []
done_vals = []
frame_chunks = []  # per-trial frame arrays, concatenated once after the loop
n_frames_total = 0

for idf, file_name in enumerate(file_names):
    print(f'File {idf + 1} out of {len(file_names)}')
    file_path = os.path.join(game_folder_path, file_name + '.txt')
    folder_path = os.path.join(game_folder_path, file_name)

    (frameid2pos, frameid2action, frameid2duration, frameid2unclipped_reward,
     frameid2episode, frameid2score, frameid_list) = ahdr.read_gaze_data_csv_file(file_path)

    # Keep frames FRAMESKIP-1, 2*FRAMESKIP-1, ...; exclude the last frame as
    # there is no reward or action information for it.
    frameid_list_shortened = frameid_list[FRAMESKIP - 1:-1:FRAMESKIP]
    # The raw frames covered by the kept ones: exactly FRAMESKIP per kept frame,
    # with kept frame i being the last id of window i.
    frameid_list_long = frameid_list[:len(frameid_list_shortened) * FRAMESKIP]

    frame_npstack, framestack_ids = load_images_as_numpy(folder_path, frameid_list_shortened)

    frameid2pos_long.update({k: frameid2pos[k] for k in frameid_list_long if k in frameid2pos})

    per_frame_maps = (
        (frameid2pos, frameid2pos_all),
        (frameid2action, frameid2action_all),
        (frameid2duration, frameid2duration_all),
        (frameid2unclipped_reward, frameid2unclipped_reward_all),
        (frameid2episode, frameid2episode_all),
        (frameid2score, frameid2score_all),
    )
    for src, dst in per_frame_maps:
        dst.update({k: src[k] for k in frameid_list_shortened if k in src})

    # extend() appends in place; `a = a + b` re-copied the whole list each trial.
    frameid_list_all.extend(frameid_list_shortened)
    frameid_list_long_all.extend(frameid_list_long)
    framestack_ids_all.extend(framestack_ids)

    # Only the last transition of each trial is flagged as done.
    done_vals.extend([False] * (len(frame_npstack) - 1) + [True])

    frame_chunks.append(frame_npstack)
    n_frames_total += len(frame_npstack)
    print('Total: ' + str(n_frames_total))

# A single concatenation instead of a vstack per trial, which re-copied every
# frame loaded so far on each iteration (quadratic in the number of trials).
frame_npstack_all = np.concatenate(frame_chunks, axis=0)
del frame_chunks


File 1 out of 24
0
1000
2000
3000
4000
Total: 4448
File 2 out of 24
0
1000
2000
3000
4000
Total: 8851
File 3 out of 24
0
1000
2000
3000
4000
Total: 13244
File 4 out of 24
0
1000
2000
3000
4000
Total: 17677
File 5 out of 24
0
1000
2000
3000
4000
Total: 22083
File 6 out of 24
0
1000
2000
3000
4000
Total: 26432
File 7 out of 24
0
1000
2000
3000
4000
Total: 30795
File 8 out of 24
0
1000
2000
Total: 33699
File 9 out of 24
0
1000
2000
3000
4000
Total: 38116
File 10 out of 24
0
1000
2000
3000
4000
Total: 42540
File 11 out of 24
0
1000
2000
3000
4000
Total: 46952
File 12 out of 24
0
1000
2000
3000
4000
Total: 51367
File 13 out of 24
0
1000
2000
3000
4000
Total: 55792
File 14 out of 24
0
1000
2000
3000
4000
Total: 60179
File 15 out of 24
0
1000
2000
3000
4000
Total: 64589
File 16 out of 24
0
1000
2000
3000
4000
Total: 69011
File 17 out of 24
0
1000
2000
3000
4000
Total: 73411
File 18 out of 24
0
1000
2000
3000
4000
Total: 77639
File 19 out of 24
0
1000
2000
3000
4000
Total: 82063
File 20 out of

# Create state, action, reward tensors

In [None]:
# Row of each loaded frame id in frame_npstack_all, built once instead of a
# linear list.index() search per frame (quadratic over ~100k frames).
# setdefault keeps the first occurrence, matching list.index semantics.
frame_row = {}
for row, fid in enumerate(framestack_ids_all):
    frame_row.setdefault(fid, row)

state_rows = np.fromiter((frame_row[key] for key in frameid_list_all), dtype=np.intp, count=len(frameid_list_all))
states = frame_npstack_all[state_rows].astype(np.uint8, copy=False)

action_vals_np = [frameid2action_all[key] for key in frameid_list_all if key in frameid2action_all]
# A frame without an action would silently shift every later action against `states`.
assert len(action_vals_np) == len(states), "some frames have no recorded action"
action_vals = torch.tensor(action_vals_np, dtype=torch.int32).unsqueeze(1)

# Missing (None) rewards are stored as 0.
reward_vals_np = [
    0 if frameid2unclipped_reward_all[key] is None else frameid2unclipped_reward_all[key]
    for key in frameid_list_all if key in frameid2unclipped_reward_all
]
assert len(reward_vals_np) == len(states), "some frames have no recorded reward"
reward_vals = torch.tensor(reward_vals_np, dtype=torch.float32).unsqueeze(1)

# Create replay buffer

In [10]:
# Capacity for exactly the transitions pushed in the next cell.
size_buffer = len(action_vals)

memory = HDF5ReplayBuffer(
    # File name records the frame skip actually used (e.g. "..._4f.h5" for FRAMESKIP=4).
    file_path=f'atarihead/replay_buffers/{game_name}_atarihead_buffer_all_{FRAMESKIP}f.h5',
    initial_capacity=size_buffer,
    # Per-frame shape taken from the loaded data (stacked along axis 0).
    state_shape=tuple(states.shape[1:]),
    action_shape=(1,),
    # One (x, y) gaze point per raw frame in the FRAMESKIP window.
    gaze_shape=(FRAMESKIP, 2),
)

In [None]:
# Push every down-sampled transition into the replay buffer together with the
# last gaze entry of each of the FRAMESKIP raw frames it covers.
# frameid_list_long_all holds exactly FRAMESKIP raw frame ids per entry of
# frameid_list_all, so transition i covers raw ids [i*FRAMESKIP, (i+1)*FRAMESKIP).
assert len(frameid_list_long_all) >= len(action_vals) * FRAMESKIP, "gaze window ids do not line up with transitions"

# Raw frames with no gaze entry get NaN coordinates, an explicit "missing"
# marker that downstream code can detect and mask.
MISSING_GAZE = np.full(2, np.nan)

gaze_records = {}
try:
    for i in range(len(action_vals)):
        state = np.array(states[i], dtype=np.uint8)
        # The very last transition has no successor; reuse its own state.
        next_state = states[i + 1] if i + 1 < len(states) else states[i]
        done = done_vals[i]
        frame_id = frameid_list_all[i]

        gaze_points = []
        for raw_id in frameid_list_long_all[i * FRAMESKIP:(i + 1) * FRAMESKIP]:
            positions = frameid2pos_long.get(raw_id)
            # Last entry of the raw frame's gaze-position list, if any.
            gaze_points.append(positions[-1] if positions else MISSING_GAZE)
        fgp_long = np.vstack(gaze_points)

        gaze_records[frame_id] = fgp_long

        memory.push(state=state, action=action_vals[i], next_state=next_state,
                    reward=reward_vals[i], frame_id=frame_id, gaze_pos=fgp_long, done=done)
        if i % 10000 == 0:
            print(i)
finally:
    # Always release the HDF5 file, even if a push fails part-way.
    memory.close()

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
