In [None]:


from pathlib import Path
from npz_loader import discover_game_npz_paths, get_sequences_by_game, fix_obs_paths
from atari_preprocess import preprocess_one_sequence, find_image_folders
from torch.utils.data import Dataset, DataLoader



# 1. GAME ROOT DIRECTORY (Pick any game)
train_game_dirs = [
    Path("dataset/BeamRiderNoFrameskip-v4/BeamRiderNoFrameskip-v4")
    #Path("dataset/MsPacmanNoFrameskip-v4/BeamRiderNoFrameskip-v4"),
    #Path("dataset/BreakoutNoFrameskip-v4/BreakoutNoFrameskip-v4"),
    #Path("dataset/MsPacmanNoFrameskip-v4/EnduroNoFrameskip-v4"),
    #Path("dataset/BreakoutNoFrameskip-v4/MsPacmanNoFrameskip-v4"),
    #Path("dataset/MsPacmanNoFrameskip-v4/QbertNoFrameskip-v4"),
    #Path("dataset/BreakoutNoFrameskip-v4/SeaquestNoFrameskip-v4"),
    #Path("dataset/BreakoutNoFrameskip-v4/SpaceInvadersNoFrameskip-v4"),
]
game_tokens = []
for game_root in train_game_dirs:

    # 2. Load ALL NPZ files (already sorted by discover_game_npz_paths)
    npz_paths_by_game = discover_game_npz_paths([game_root])
    game_to_sequences = get_sequences_by_game(npz_paths_by_game)
    sequences_by_game = fix_obs_paths(game_to_sequences, dataset_root="dataset")

    game_key = list(sequences_by_game.keys())[0]
    npz_sequences = sequences_by_game[game_key]

    print(f"Found {len(npz_sequences)} NPZ sequences.")

    # 3. Load ALL image folders for this game
    image_folders = find_image_folders(game_root)
    print(f"Found {len(image_folders)} image folders.")

    if len(image_folders) != len(npz_sequences):
        print("WARNING: NPZ count and image folder count differ!")
        print("Index pairing may need manual correction!")

    # 4. Preprocess ALL sequences into token lists
    all_token_lists = []

    for idx, npz_seq in enumerate(npz_sequences):
        img_folder = image_folders[idx]
        print(f"\nProcessing episode {idx}/{len(npz_sequences)-1}")
        print(f"NPZ ↔ {img_folder.name}")

        tokens = preprocess_one_sequence(
            npz_seq=npz_seq,
            game_root=game_root,
            seq_index=idx    # episode index
        )

        print(f" → Tokens created: {len(tokens)}")

        all_token_lists.append(tokens)

    print("\n======== PREPROCESSING COMPLETE ========")
    print(f"Total episodes processed: {len(all_token_lists)}")
    print(f"Example episode #0 token count: {len(all_token_lists[0])}")
    game_tokens.append(all_token_lists)
    


Found 200 NPZ sequences.
Found 200 image folders.

Processing episode 0/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-0
 → Tokens created: 21372

Processing episode 1/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-1
 → Tokens created: 14274

Processing episode 2/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-2
 → Tokens created: 21021

Processing episode 3/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-3
 → Tokens created: 17043

Processing episode 4/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-4
 → Tokens created: 20787

Processing episode 5/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-5
 → Tokens created: 18837

Processing episode 6/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-6
 → Tokens created: 20124

Processing episode 7/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-7
 → Tokens created: 32643

Processing episode 8/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-8
 → Tokens created: 18954

Processing episode 9/199
NPZ ↔ BeamRiderNoFrameskip-v4-rec

In [1]:
import pickle
from pathlib import Path
from npz_loader import discover_game_npz_paths, get_sequences_by_game, fix_obs_paths
from atari_preprocess import preprocess_one_sequence, find_image_folders
import numpy as np

train_game_dirs = [
    Path("dataset/BeamRiderNoFrameskip-v4/BeamRiderNoFrameskip-v4"),
    #Path("dataset/BreakoutNoFrameskip-v4/BreakoutNoFrameskip-v4"),
    #Path("dataset/EnduroNoFrameskip-v4/EnduroNoFrameskip-v4"),
    #Path("dataset/MsPacmanNoFrameskip-v4/MsPacmanNoFrameskip-v4"),
    #Path("dataset/QbertNoFrameskip-v4/QbertNoFrameskip-v4"),
    #Path("dataset/SeaquestNoFrameskip-v4/SeaquestNoFrameskip-v4"),
    #Path("dataset/SpaceInvadersNoFrameskip-v4/SpaceInvadersNoFrameskip-v4"),
]

for game_root in train_game_dirs:

    # Load NPZ files
    npz_paths_by_game = discover_game_npz_paths([game_root])
    game_to_sequences = get_sequences_by_game(npz_paths_by_game)
    sequences_by_game = fix_obs_paths(game_to_sequences, dataset_root="dataset")

    game_key = list(sequences_by_game.keys())[0]
    npz_sequences = sequences_by_game[game_key]
    print(f"Found {len(npz_sequences)} NPZ sequences.")

    # Load image folders
    image_folders = find_image_folders(game_root)
    print(f"Found {len(image_folders)} image folders.")

    if len(image_folders) != len(npz_sequences):
        print("WARNING: NPZ count and image folder count differ!")
    # Preprocess & collect tokens for this game
    all_token_lists = []

    for idx, npz_seq in enumerate(npz_sequences):
        img_folder = image_folders[idx]
        print(f"\nProcessing episode {idx}/{len(npz_sequences)-1} — {img_folder.name}")

        tokens = preprocess_one_sequence(
            npz_seq=npz_seq,
            game_root=game_root,
            seq_index=idx
        )
        print(f" → Tokens created: {len(tokens)}")
        all_token_lists.append(tokens)

    print("\n======== PREPROCESSING COMPLETE ========")

    # Save tokens
    out_path = Path("token_outputs")
    out_path.mkdir(exist_ok=True)

    safe_game_key = str(game_key).replace("/", "_")   # <-- SIMPLE FIX

    pickle_file = out_path / f"{safe_game_key}_tokens.pkl"

    with open(pickle_file, "wb") as f:
        pickle.dump(all_token_lists, f)

    print(f"Saved token file → {pickle_file}")
    print("=====================================================\n")


Found 200 NPZ sequences.
Found 200 image folders.

Processing episode 0/199 — BeamRiderNoFrameskip-v4-recorded_images-0


  "value": patch_emb(torch.tensor(nn_input, dtype=torch.float32))


 → Tokens created: 1644

Processing episode 1/199 — BeamRiderNoFrameskip-v4-recorded_images-1
 → Tokens created: 1098

Processing episode 2/199 — BeamRiderNoFrameskip-v4-recorded_images-2
 → Tokens created: 1617

Processing episode 3/199 — BeamRiderNoFrameskip-v4-recorded_images-3
 → Tokens created: 1311

Processing episode 4/199 — BeamRiderNoFrameskip-v4-recorded_images-4
 → Tokens created: 1599

Processing episode 5/199 — BeamRiderNoFrameskip-v4-recorded_images-5
 → Tokens created: 1449

Processing episode 6/199 — BeamRiderNoFrameskip-v4-recorded_images-6
 → Tokens created: 1548

Processing episode 7/199 — BeamRiderNoFrameskip-v4-recorded_images-7
 → Tokens created: 2511

Processing episode 8/199 — BeamRiderNoFrameskip-v4-recorded_images-8
 → Tokens created: 1458

Processing episode 9/199 — BeamRiderNoFrameskip-v4-recorded_images-9
 → Tokens created: 1749

Processing episode 10/199 — BeamRiderNoFrameskip-v4-recorded_images-10
 → Tokens created: 2826

Processing episode 11/199 — BeamR

In [None]:
from transformers import DecisionTransformerConfig, DecisionTransformerModel
import torch

IMG_SIZE = 84
PATCH_SIZE = 14
PATCHES_PER_FRAME = (IMG_SIZE // PATCH_SIZE) ** 2    # 36 patches
RETURN_MIN = -20
RETURN_MAX = 100
PATCH_EMB_SIZE = 128

states = [state[0] for state in all_token_lists]
rtg = [rtg[1] for rtg in all_token_lists]
actions = [actions[2] for actions in all_token_lists]
states = [a['value'] for a in states]
rtg = [a['value'] for a in rtg]
actions = [a['value'] for a in actions]
states = [a.flatten() for a in states]
states = torch.stack(states)
#states = states.unsqueeze(0)
#actions = torch.tensor(actions).unsqueeze(0)
#rtg = torch.tensor(rtg).unsqueeze(0)

print(states.shape)

# Initializing a DecisionTransformer configuration
configuration = DecisionTransformerConfig(state_dim=PATCHES_PER_FRAME * PATCH_EMB_SIZE, action_dim=18)

# Initializing a model (with random weights) from the configuration
model = DecisionTransformerModel(configuration)

# Forward pass need to standardize lengths for episode and mask the padded tokens
# Context Length = 20
# After padding, randomly select a subsequence of length 20 from each episode of the batch

num_epochs = 10
for epoch in range(num_epochs):
    output = model(
        states=states,
        actions=actions,
        returns_to_go=rtg
    )

torch.Size([200, 4608])
torch.Size([4608])


TypeError: linear(): argument 'input' (position 1) must be Tensor, not list