In [1]:
# Re-import edited local modules (npz_loader, atari_dataset, tokenizer, ...)
# automatically before each cell executes, so .py changes are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

def get_device():
    """Choose the torch device for this notebook and report the choice.

    Returns:
        torch.device: ``cuda`` when torch can see a CUDA GPU, otherwise ``cpu``.
        The name of GPU 0 is printed in the CUDA case.
    """
    if not torch.cuda.is_available():
        print("Using CPU")
        return torch.device("cpu")

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    return torch.device("cuda")


device = get_device()

Using GPU: NVIDIA GeForce RTX 4090


In [3]:
from npz_loader import discover_games

# List every game directory discovered under the dataset root, one per line.
game_paths = discover_games(Path("dataset/"))
print(*game_paths, sep="\n")

dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4
dataset\BreakoutNoFrameskip-v4\BreakoutNoFrameskip-v4
dataset\EnduroNoFrameskip-v4\EnduroNoFrameskip-v4
dataset\MsPacmanNoFrameskip-v4\MsPacmanNoFrameskip-v4
dataset\PongNoFrameskip-v4\PongNoFrameskip-v4
dataset\QbertNoFrameskip-v4\QbertNoFrameskip-v4
dataset\SeaquestNoFrameskip-v4\SeaquestNoFrameskip-v4
dataset\SpaceInvadersNoFrameskip-v4\SpaceInvadersNoFrameskip-v4


In [4]:
from npz_loader import discover_game_npz_paths, get_sequences_by_game
from pprint import pprint

# Games used for training. Directories are assembled with `/` so they resolve
# on any OS; a backslash inside a string is only a path separator on Windows.
DATASET_DIR = Path("dataset")
TRAIN_GAMES = ["BeamRiderNoFrameskip-v4", "BreakoutNoFrameskip-v4"]

# Recordings for a game live under dataset/<game>/<game>/ (matches the
# discover_games listing above).
train_game_dirs = [DATASET_DIR / game / game for game in TRAIN_GAMES]

npz_paths_by_game = discover_game_npz_paths(train_game_dirs)
sequences_by_game = get_sequences_by_game(npz_paths_by_game)

print("Loaded games:")
for game, sequences in sequences_by_game.items():
    print(f"\t{game}")
    print(f"\t\tNumber of sequences: {len(sequences)}")

# Peek at the first recorded sequence of the first loaded game and describe
# every array it contains.
one_game_key = next(iter(sequences_by_game))
one_game_seq_list = sequences_by_game[one_game_key]
one_game_seq = one_game_seq_list[0]

print(f"Keys in one_game_seq ({type(one_game_seq)}):")

# Row sampled from per-step arrays; single-row arrays (e.g. episode_returns,
# shape (1,)) are sampled at index 0 instead.
row_to_print = 10

# One "key: value" sample per field. Not printed here; inspect interactively.
one_row = []
for key in one_game_seq.keys():
    arr = one_game_seq[key]
    print(f"\t{key}:")
    print(f"\t\tshape={arr.shape}")
    print(f"\t\tdtype={arr.dtype}")

    if len(arr) == 0:
        continue  # nothing to sample from an empty array
    # Clamp so arrays with fewer than row_to_print + 1 rows don't raise IndexError;
    # a length-1 array still yields index 0.
    sample_idx = min(row_to_print, len(arr) - 1)
    one_row.append(f"{key}: {arr[sample_idx]}")

Loaded games:
	dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4
		Number of sequences: 200
	dataset\BreakoutNoFrameskip-v4\BreakoutNoFrameskip-v4
		Number of sequences: 200
Keys in one_game_seq (<class 'dict'>):
	model selected actions:
		shape=(548, 1)
		dtype=int64
	taken actions:
		shape=(548, 1)
		dtype=int64
	obs:
		shape=(548,)
		dtype=<U193
	rewards:
		shape=(548,)
		dtype=float64
	episode_returns:
		shape=(1,)
		dtype=float64
	episode_starts:
		shape=(548,)
		dtype=bool
	repeated:
		shape=(548,)
		dtype=bool


In [5]:
from atari_dataset import AtariDataset

# Build a long-window dataset (200 steps per sample) purely to eyeball what a
# single sample looks like; the next cell builds its own dataset with context_len=32.
atari_dataset = AtariDataset(sequences_by_game, context_len=200)

N_SAMPLES_TO_SHOW = 1    # how many dataset windows to print
MAX_PATHS_TO_SHOW = 10   # cap the per-window path listing so output stays readable

for sample_idx in range(N_SAMPLES_TO_SHOW):
    print("-" * 100)
    obs_paths, actions, rtg = atari_dataset[sample_idx]

    print("=== Observation Sequence ===")
    for step_idx, obs_path in enumerate(obs_paths[:MAX_PATHS_TO_SHOW]):
        print(f"{step_idx}: {obs_path}")
    n_hidden = len(obs_paths) - MAX_PATHS_TO_SHOW
    if n_hidden > 0:
        print(f"... ({n_hidden} more)")

    print("\n=== Action Sequence ===")
    print(actions.tolist())

    print("\n=== RTG ===")
    print(rtg)

----------------------------------------------------------------------------------------------------
=== Observation Sequence ===
0: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\0.png
1: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\1.png
2: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\2.png
3: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\3.png
4: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\4.png
5: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\5.png
6: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\6.png
7: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\7.png
8: dataset\Bea

In [6]:
from pathlib import Path

from torch.utils.data import DataLoader

from atari_dataset import AtariDataset
from image_io import AtariImageLoader

# -------------------------------
# 1. Build dataset and dataloader
# -------------------------------
context_len = 32  # steps per sample window; kept small for quick testing

dataset = AtariDataset(sequences_by_game, context_len=context_len)
print("dataset length:", len(dataset))

if len(dataset) == 0:
    raise RuntimeError("Dataset is empty -- check sequences_by_game and context_len.")


def collate_unzip(batch):
    """Transpose a batch of ``(obs_paths, actions, rtg)`` samples.

    Returns a 3-tuple ``(obs_paths_batch, actions_batch, rtg_batch)``. Each
    element is a tuple with one entry per sample. Items are left unstacked
    because the observations are file-path strings, not tensors.
    """
    return tuple(zip(*batch))


loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_unzip,
)

# -------------------------------
# 2. Get a single batch
# -------------------------------
obs_paths_batch, actions_batch, rtg_batch = next(iter(loader))

print("\n=== Raw obs paths sample ===")
print("type(obs_paths_batch):", type(obs_paths_batch))
print("len(obs_paths_batch):", len(obs_paths_batch))

first_seq = obs_paths_batch[0]
print("len(first_seq):", len(first_seq))
print("first few paths in first_seq:")
for i, p in enumerate(first_seq[:5]):
    print(f"  {i}: {p}")

# Check that the first path resolves to a real file; the image loader below
# reads these paths, which are relative to the current working directory.
first_path_str = first_seq[0]
first_path = Path(first_path_str)
print("\nFirst path check:")
print("  raw:", first_path_str)
print("  resolved:", first_path.resolve())
print("  exists?:", first_path.exists())

# -------------------------------
# 3. Load images with AtariImageLoader
# -------------------------------
img_loader = AtariImageLoader(img_size=84, grayscale=True)

frames = img_loader.load_batch(obs_paths_batch)  # (B, T, C, H, W)

print("\n=== Loaded frames ===")
print("frames shape:", frames.shape)             # expect (B, T, 1, 84, 84)
print("frames min/max:", frames.min().item(), frames.max().item())

print("\n=== Actions / RTG shapes ===")
print("actions[0].shape:", actions_batch[0].shape)
print("rtg[0].shape:", rtg_batch[0].shape)

print("\nFirst few actions of first seq:", actions_batch[0][:10].tolist())
print("First few RTG of first seq:", rtg_batch[0][:10])


dataset length: 160587

=== Raw obs paths sample ===
type(obs_paths_batch): <class 'tuple'>
len(obs_paths_batch): 2
len(first_seq): 32
first few paths in first_seq:
  0: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\0.png
  1: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\1.png
  2: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\2.png
  3: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\3.png
  4: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\4.png

First path check:
  raw: dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\0.png
  resolved: C:\Users\idanc\local\projects\AtariDeepLearning\dataset\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4\BeamRiderNoFrameskip-v4-recorded_images-0\0.

In [7]:
from baseline_encoder import AtariPatchEncoder

# Standalone construction of the patch encoder for inspection: 84x84
# single-channel frames split into 14x14 patches, embedded at width 128.
# A later cell re-creates an identical encoder before tokenizing.
patch_encoder = AtariPatchEncoder(
    d_model=128,      # smaller than the encoder's default width, for a lighter model
    in_channels=1,
    patch_size=14,
    img_size=84,
)

In [8]:
import torch

from baseline_encoder import AtariPatchEncoder
from tokenizer import MGDTTokenizer

# ------------------------------------------------
# 1. Prepare batch tensors from DataLoader output
# ------------------------------------------------
# obs_paths_batch, actions_batch, rtg_batch come from the DataLoader cell above:
#   each is a tuple of length B holding one length-T sequence per sample.

B = len(actions_batch)
T = actions_batch[0].shape[0]

actions_tensor = torch.stack(actions_batch, dim=0)  # (B,T)
rtg_tensor = torch.stack(rtg_batch, dim=0)          # (B,T)

print("actions_tensor shape:", actions_tensor.shape)
print("rtg_tensor shape:", rtg_tensor.shape)
print("frames shape:", frames.shape)  # from image_io.AtariImageLoader.load_batch

# ------------------------------------------------
# 2. Initialize patch encoder + tokenizer
# ------------------------------------------------
n_actions = dataset.n_actions()  # from AtariDataset

patch_encoder = AtariPatchEncoder(
    img_size=84,
    patch_size=14,
    in_channels=1,
    d_model=128,   # choose a manageable model size for now
)

# Count every loaded game (several are loaded above), not a hard-coded 1.
n_games = len(sequences_by_game)

tokenizer = MGDTTokenizer(
    patch_encoder=patch_encoder,
    n_actions=n_actions,
    n_games=n_games,
    # NOTE(review): RTG bounds are hand-picked; confirm they cover the returns
    # of all loaded games.
    rtg_min=-20,
    rtg_max=100,
)

# ------------------------------------------------
# 3. Run tokenizer
# ------------------------------------------------
# NOTE: AtariDataset samples carry no game id, so every sample is labelled 0
# here. That presumably maps to the first loaded game (BeamRider), which matches
# this first unshuffled batch; mixed-game batches need real per-sample ids.
game_ids = torch.zeros(B, dtype=torch.long)

tok_out = tokenizer(
    frames=frames,                # (B,T,1,84,84)
    actions=actions_tensor,       # (B,T)
    rtg=rtg_tensor,               # (B,T)
    game_ids=game_ids,            # (B,) or (B,T)
)

tokens = tok_out.tokens

print("\n=== Tokenizer output ===")
print("tokens shape:", tokens.shape)
print("B, T:", tok_out.B, tok_out.T)
print("num_patches:", tok_out.num_patches)
print("tokens_per_step:", tok_out.tokens_per_step)
print("sequence length L:", tokens.shape[1])


actions_tensor shape: torch.Size([2, 32])
rtg_tensor shape: torch.Size([2, 32])
frames shape: torch.Size([2, 32, 1, 84, 84])

=== Tokenizer output ===
tokens shape: torch.Size([2, 1248, 128])
B, T: 2 32
num_patches: 36
tokens_per_step: 39
sequence length L: 1248


In [9]:
import torch
from baseline_model import MultiGameDecisionTransformer

tokens = tok_out.tokens           # (B, L, d_model)
B, L, D = tokens.shape
T = tok_out.T
S = tok_out.tokens_per_step
n_actions = dataset.n_actions()

print("Transformer input tokens shape:", tokens.shape)
print("T:", T, "tokens_per_step:", S, "L:", L)

# Longest flattened token sequence the model is configured for. L grows as
# T * tokens_per_step (32 * 39 = 1248 here), so a larger context_len can exceed
# it; fail here with a clear message instead of (presumably) deep inside the model.
MAX_SEQ_LEN = 2048
if L > MAX_SEQ_LEN:
    raise ValueError(
        f"Token sequence length L={L} (T={T}, tokens_per_step={S}) exceeds "
        f"max_seq_len={MAX_SEQ_LEN}; lower context_len or raise MAX_SEQ_LEN."
    )

model = MultiGameDecisionTransformer(
    d_model=D,
    n_actions=n_actions,
    n_layers=2,          # keep small for now
    n_heads=4,
    dim_feedforward=4 * D,
    dropout=0.1,
    max_seq_len=MAX_SEQ_LEN,
)

# Shape smoke test only: no gradients are needed, so skip building an autograd graph.
with torch.no_grad():
    out = model(tokens, tokens_per_step=S, T=T)

print("\n=== MGDT model output ===")
print("logits shape:", out.logits.shape)         # expect (B, T, n_actions)
print("hidden shape:", out.hidden.shape)         # (B, L, D)
print("action_positions shape:", out.action_positions.shape)
print("first action_positions row:", out.action_positions[0])


Transformer input tokens shape: torch.Size([2, 1248, 128])
T: 32 tokens_per_step: 39 L: 1248





=== MGDT model output ===
logits shape: torch.Size([2, 32, 9])
hidden shape: torch.Size([2, 1248, 128])
action_positions shape: torch.Size([2, 32])
first action_positions row: tensor([  38,   77,  116,  155,  194,  233,  272,  311,  350,  389,  428,  467,
         506,  545,  584,  623,  662,  701,  740,  779,  818,  857,  896,  935,
         974, 1013, 1052, 1091, 1130, 1169, 1208, 1247])


In [10]:
from baseline_train import train

# Run baseline_train.train on the sequences of all loaded games. This rebinds
# `model`, replacing the smoke-test model from the previous cell with the
# object `train` returns.
model = train(sequences_by_game)

Using device: cuda
Dataset length: 160587
Epoch 1/1


                                                                                                  

Epoch 1 done. Avg loss: 0.0125


