In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("Using CPU")
    return device

device = get_device()

Using GPU: NVIDIA GeForce RTX 4090


# Find Game Paths

In [3]:
from npz_loader import discover_games

game_paths = discover_games(Path("dataset/"))
print("\n".join(str(p) for p in game_paths))

dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4
dataset\BreakoutNoFrameskip-v4\BreakoutNoFrameskip-v4
dataset\EnduroNoFrameskip-v4\EnduroNoFrameskip-v4
dataset\MsPacmanNoFrameskip-v4\MsPacmanNoFrameskip-v4
dataset\PongNoFrameskip-v4\PongNoFrameskip-v4
dataset\QbertNoFrameskip-v4\QbertNoFrameskip-v4
dataset\SeaquestNoFrameskip-v4\SeaquestNoFrameskip-v4
dataset\SpaceInvadersNoFrameskip-v4\SpaceInvadersNoFrameskip-v4


# Load NPZ Files in to Dict

In [4]:
from npz_loader import discover_game_npz_paths, get_sequences_by_game, fix_obs_paths
from pprint import pprint

train_game_dirs = [
    Path(r"dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4"),
    Path(r"dataset\BreakoutNoFrameskip-v4\BreakoutNoFrameskip-v4"),
]

npz_paths_by_game = discover_game_npz_paths(train_game_dirs)
game_to_sequences = get_sequences_by_game(npz_paths_by_game)
sequences_by_game = fix_obs_paths(game_to_sequences, dataset_root="dataset")

print(f"Loaded games:")
for game in sequences_by_game.keys():
    print(f"\t{game}")
    print(f"\t\tNumber of sequences: {len(sequences_by_game[game])}")

one_game_key = list(sequences_by_game.keys())[0]
one_game_seq_list = sequences_by_game[one_game_key]
one_game_seq = one_game_seq_list[0]

print(f"Keys in one_game_seq ({type(one_game_seq)}):")

one_row = []
for key in one_game_seq.keys():
    arr = one_game_seq[key]
    print(f"\t{key}:")
    print(f"\t\tshape={arr.shape}")
    print(f"\t\tdtype={arr.dtype}")

    row_to_print = 10
    if len(one_game_seq[key]) > 1:
        one_row.append(f"{key}: {one_game_seq[key][row_to_print]}")
    else:
        one_row.append(f"{key}: {one_game_seq[key][0]}")

Loaded games:
	dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4
		Number of sequences: 200
	dataset\BreakoutNoFrameskip-v4\BreakoutNoFrameskip-v4
		Number of sequences: 200
Keys in one_game_seq (<class 'dict'>):
	model selected actions:
		shape=(548, 1)
		dtype=int64
	taken actions:
		shape=(548, 1)
		dtype=int64
	obs:
		shape=(548,)
		dtype=<U193
	rewards:
		shape=(548,)
		dtype=float64
	episode_returns:
		shape=(1,)
		dtype=float64
	episode_starts:
		shape=(548,)
		dtype=bool
	repeated:
		shape=(548,)
		dtype=bool


# Test Dataset

In [5]:
from atari_dataset import AtariDataset

atari_dataset = AtariDataset(sequences_by_game, context_len=200)

for i in range(0, 1):
    print("-" * 100)
    frames, actions, rtg = atari_dataset[i]

    print("=== Frames (images) ===")
    print(f"frames shape: {frames.shape}")  # (T, C, H, W)
    print(f"frames min/max: {frames.min().item():.4f} / {frames.max().item():.4f}")

    print("\n=== Action Sequence ===")
    print(actions.tolist())

    print("\n=== RTG ===")
    print(rtg)

----------------------------------------------------------------------------------------------------
=== Frames (images) ===
frames shape: torch.Size([200, 1, 84, 84])
frames min/max: 0.0000 / 0.9255

=== Action Sequence ===
[8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 1, 1, 1, 1, 8, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 7, 7, 8, 7, 8, 7, 7, 8, 8, 7, 8, 7, 8, 7, 8, 7, 7, 8, 8, 2, 7, 7, 7, 7, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 8, 8, 1, 7, 7, 8, 8, 8, 7, 1, 1, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 1, 1, 1, 2, 1, 8, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 7, 7, 1, 1, 7, 7, 1, 7, 7, 7, 8, 8, 8, 8, 1, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

=== RTG ===
tensor([7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
        7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7., 7.,
 

# Test Dataloader

In [6]:
from torch.utils.data import DataLoader
from baseline_train import collate_fn

from atari_dataset import AtariDataset

# -------------------------------
# 1. Build dataset and dataloader
# -------------------------------
context_len = 32  # or whatever you want for testing

dataset = AtariDataset(sequences_by_game, context_len=context_len)
print("dataset length:", len(dataset))

if len(dataset) == 0:
    raise RuntimeError("Dataset is empty â€” check sequences_by_game and context_len.")

loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn,  # Use the collate_fn from baseline_train
)

# -------------------------------
# 2. Get a single batch
# -------------------------------
frames_batch, actions_batch, rtg_batch = next(iter(loader))

print("\n=== Frames batch ===")
print("frames_batch shape:", frames_batch.shape)  # (B, T, C, H, W)
print("frames_batch min/max:", frames_batch.min().item(), frames_batch.max().item())

print("\n=== Actions / RTG shapes ===")
print("actions_batch shape:", actions_batch.shape)  # (B, T)
print("rtg_batch shape:", rtg_batch.shape)  # (B, T)

print("\nFirst few actions of first seq:", actions_batch[0][:10].tolist())
print("First few RTG of first seq:", rtg_batch[0][:10])


dataset length: 160587

=== Frames batch ===
frames_batch shape: torch.Size([2, 32, 1, 84, 84])
frames_batch min/max: 0.0 0.9254902005195618

=== Actions / RTG shapes ===
actions_batch shape: torch.Size([2, 32])
rtg_batch shape: torch.Size([2, 32])

First few actions of first seq: [8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
First few RTG of first seq: tensor([7., 7., 7., 7., 7., 7., 7., 7., 7., 7.])
