In [None]:


from pathlib import Path
from npz_loader import discover_game_npz_paths, get_sequences_by_game, fix_obs_paths
from atari_preprocess import preprocess_one_sequence, find_image_folders
from torch.utils.data import Dataset, DataLoader



# 1. Game root directories to preprocess. Uncomment entries to add more games;
#    each entry follows the dataset/<GameId>/<GameId> layout.
train_game_dirs = [
    Path("dataset/BeamRiderNoFrameskip-v4/BeamRiderNoFrameskip-v4"),
    #Path("dataset/BreakoutNoFrameskip-v4/BreakoutNoFrameskip-v4"),
    #Path("dataset/EnduroNoFrameskip-v4/EnduroNoFrameskip-v4"),
    #Path("dataset/MsPacmanNoFrameskip-v4/MsPacmanNoFrameskip-v4"),
    #Path("dataset/QbertNoFrameskip-v4/QbertNoFrameskip-v4"),
    #Path("dataset/SeaquestNoFrameskip-v4/SeaquestNoFrameskip-v4"),
    #Path("dataset/SpaceInvadersNoFrameskip-v4/SpaceInvadersNoFrameskip-v4"),
]

# One entry per game (same order as train_game_dirs); each entry is the list of
# per-episode token lists returned by preprocess_one_sequence.
game_tokens = []
for game_root in train_game_dirs:

    # 2. Load ALL NPZ files (already sorted by discover_game_npz_paths)
    npz_paths_by_game = discover_game_npz_paths([game_root])
    game_to_sequences = get_sequences_by_game(npz_paths_by_game)
    sequences_by_game = fix_obs_paths(game_to_sequences, dataset_root="dataset")

    if not sequences_by_game:
        raise ValueError(f"No NPZ sequences found under {game_root}")
    # Only one game root is passed in, so presumably there is a single key;
    # only the first one is used.
    game_key = next(iter(sequences_by_game.keys()))
    npz_sequences = sequences_by_game[game_key]

    print(f"Found {len(npz_sequences)} NPZ sequences.")

    # 3. Load ALL image folders for this game
    image_folders = find_image_folders(game_root)
    print(f"Found {len(image_folders)} image folders.")

    if len(image_folders) != len(npz_sequences):
        print("WARNING: NPZ count and image folder count differ!")
        print("Index pairing may need manual correction!")

    # 4. Preprocess ALL sequences into token lists
    all_token_lists = []
    last_idx = len(npz_sequences) - 1

    for idx, npz_seq in enumerate(npz_sequences):
        # Episodes are paired with image folders purely by index. In this cell the
        # folder is only used for logging, so a missing folder must not abort the run.
        folder_name = (
            image_folders[idx].name if idx < len(image_folders) else "<missing image folder>"
        )
        print(f"\nProcessing episode {idx}/{last_idx}")
        print(f"NPZ ↔ {folder_name}")

        tokens = preprocess_one_sequence(
            npz_seq=npz_seq,
            game_root=game_root,
            seq_index=idx,  # episode index
        )

        print(f" → Tokens created: {len(tokens)}")
        all_token_lists.append(tokens)

    print("\n======== PREPROCESSING COMPLETE ========")
    print(f"Total episodes processed: {len(all_token_lists)}")
    if all_token_lists:
        print(f"Example episode #0 token count: {len(all_token_lists[0])}")
    game_tokens.append(all_token_lists)
    


Found 200 NPZ sequences.
Found 200 image folders.

Processing episode 0/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-0
 → Tokens created: 21372

Processing episode 1/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-1
 → Tokens created: 14274

Processing episode 2/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-2
 → Tokens created: 21021

Processing episode 3/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-3
 → Tokens created: 17043

Processing episode 4/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-4
 → Tokens created: 20787

Processing episode 5/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-5
 → Tokens created: 18837

Processing episode 6/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-6
 → Tokens created: 20124

Processing episode 7/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-7
 → Tokens created: 32643

Processing episode 8/199
NPZ ↔ BeamRiderNoFrameskip-v4-recorded_images-8
 → Tokens created: 18954

Processing episode 9/199
NPZ ↔ BeamRiderNoFrameskip-v4-rec

In [None]:
import pickle
from pathlib import Path
from npz_loader import discover_game_npz_paths, get_sequences_by_game, fix_obs_paths
from atari_preprocess import preprocess_one_sequence, find_image_folders
import numpy as np

train_game_dirs = [
    Path("dataset/BeamRiderNoFrameskip-v4/BeamRiderNoFrameskip-v4"),
    #Path("dataset/BreakoutNoFrameskip-v4/BreakoutNoFrameskip-v4"),
    #Path("dataset/EnduroNoFrameskip-v4/EnduroNoFrameskip-v4"),
    #Path("dataset/MsPacmanNoFrameskip-v4/MsPacmanNoFrameskip-v4"),
    #Path("dataset/QbertNoFrameskip-v4/QbertNoFrameskip-v4"),
    #Path("dataset/SeaquestNoFrameskip-v4/SeaquestNoFrameskip-v4"),
    #Path("dataset/SpaceInvadersNoFrameskip-v4/SpaceInvadersNoFrameskip-v4"),
]

# Every game's token lists are pickled into this directory (one file per game).
out_path = Path("token_outputs")
out_path.mkdir(exist_ok=True)

for game_root in train_game_dirs:

    # Load NPZ files
    npz_paths_by_game = discover_game_npz_paths([game_root])
    game_to_sequences = get_sequences_by_game(npz_paths_by_game)
    sequences_by_game = fix_obs_paths(game_to_sequences, dataset_root="dataset")

    if not sequences_by_game:
        raise ValueError(f"No NPZ sequences found under {game_root}")
    # Only one game root is passed in; only the first key is used.
    game_key = next(iter(sequences_by_game.keys()))
    npz_sequences = sequences_by_game[game_key]
    print(f"Found {len(npz_sequences)} NPZ sequences.")

    # Load image folders
    image_folders = find_image_folders(game_root)
    print(f"Found {len(image_folders)} image folders.")

    if len(image_folders) != len(npz_sequences):
        print("WARNING: NPZ count and image folder count differ!")

    # Preprocess & collect tokens for this game
    all_token_lists = []
    last_idx = len(npz_sequences) - 1

    for idx, npz_seq in enumerate(npz_sequences):
        # Pairing is by index and the folder is only used for the log line here,
        # so a missing folder is reported instead of raising IndexError.
        folder_name = (
            image_folders[idx].name if idx < len(image_folders) else "<missing image folder>"
        )
        print(f"\nProcessing episode {idx}/{last_idx} — {folder_name}")

        tokens = preprocess_one_sequence(
            npz_seq=npz_seq,
            game_root=game_root,
            seq_index=idx,
        )
        print(f" → Tokens created: {len(tokens)}")
        all_token_lists.append(tokens)

    print("\n======== PREPROCESSING COMPLETE ========")

    # Save tokens. game_key may contain path separators; flatten both POSIX "/"
    # and Windows "\" so the pickle lands directly inside out_path.
    safe_game_key = str(game_key).replace("\\", "_").replace("/", "_")
    pickle_file = out_path / f"{safe_game_key}_tokens.pkl"

    # NOTE: only unpickle these files in trusted environments (pickle can run code on load).
    with open(pickle_file, "wb") as f:
        pickle.dump(all_token_lists, f)

    print(f"Saved token file → {pickle_file}")
    print("=====================================================\n")


Found 200 NPZ sequences.
Found 200 image folders.

Processing episode 0/199 — BeamRiderNoFrameskip-v4-recorded_images-0
 → Tokens created: 3

Processing episode 1/199 — BeamRiderNoFrameskip-v4-recorded_images-1
 → Tokens created: 3

Processing episode 2/199 — BeamRiderNoFrameskip-v4-recorded_images-2
 → Tokens created: 3

Processing episode 3/199 — BeamRiderNoFrameskip-v4-recorded_images-3
 → Tokens created: 3

Processing episode 4/199 — BeamRiderNoFrameskip-v4-recorded_images-4
 → Tokens created: 3

Processing episode 5/199 — BeamRiderNoFrameskip-v4-recorded_images-5
 → Tokens created: 3

Processing episode 6/199 — BeamRiderNoFrameskip-v4-recorded_images-6
 → Tokens created: 3

Processing episode 7/199 — BeamRiderNoFrameskip-v4-recorded_images-7
 → Tokens created: 3

Processing episode 8/199 — BeamRiderNoFrameskip-v4-recorded_images-8
 → Tokens created: 3

Processing episode 9/199 — BeamRiderNoFrameskip-v4-recorded_images-9
 → Tokens created: 3

Processing episode 10/199 — BeamRiderNo

In [None]:
# Reload all imported modules before each cell runs, so edits to the local
# helper modules (npz_loader, atari_preprocess, EpisodeData) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from transformers.models.decision_transformer.modeling_decision_transformer import DecisionTransformerModel
from transformers.models.decision_transformer.configuration_decision_transformer import DecisionTransformerConfig
import torch
import EpisodeData
from torch.utils.data import Dataset, DataLoader
import random
import torch.optim as optim

IMG_SIZE = 84
PATCH_SIZE = 14
PATCHES_PER_FRAME = (IMG_SIZE // PATCH_SIZE) ** 2    # 36 patches
RETURN_MIN = -20
RETURN_MAX = 100
PATCH_EMB_SIZE = 128
NUM_EPOCHS = 10
CONTEXT_LENGTH = 20
BATCH_SIZE = 32
ACT_DIM = 18          # number of discrete actions (one-hot width); 18 = full ALE action set
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs this run touches: model weight init and DataLoader shuffling
# (torch), plus Python's `random`, which EpisodeData may use for window sampling.
random.seed(SEED)
torch.manual_seed(SEED)


# Data Loading and Sampling
# all_token_lists[0] is an episode with states, actions, rewards-to-go
# NOTE: `all_token_lists` is not defined in this cell. It is whatever the
# last-run preprocessing cell left in the kernel, i.e. the episodes of the
# *last* game in train_game_dirs. After a kernel restart, re-run that cell first.
dataset = EpisodeData.EpisodeData(all_token_lists, CONTEXT_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
if len(dataloader) == 0:
    raise ValueError(
        f"Dataset has {len(dataset)} samples, fewer than BATCH_SIZE={BATCH_SIZE}; "
        "with drop_last=True there would be no training batches."
    )

# Decision Transformer over flattened patch embeddings (36 patches x 128 dims per frame).
configuration = DecisionTransformerConfig(
    state_dim=PATCHES_PER_FRAME * PATCH_EMB_SIZE,
    act_dim=ACT_DIM,
    max_length=CONTEXT_LENGTH,
    n_layer=3,
    n_head=1,
)
# Actions are discrete: emit raw logits for cross-entropy instead of tanh-squashed values.
configuration.action_tanh = False

# Initializing a model (with random weights) from the configuration
model = DecisionTransformerModel(configuration).to(DEVICE)
print(model.embed_state)
print(model.embed_timestep)
print(model.embed_action)

optimizer = optim.Adam(model.parameters(), lr=1e-4)


for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for batch_idx, (state, action, rtg) in enumerate(dataloader, start=1):
        states = state.to(device=DEVICE, dtype=torch.float32)
        action_ids = action.to(device=DEVICE, dtype=torch.long)
        # The HF model expects returns_to_go with a trailing feature dim of 1.
        returns_to_go = rtg.to(device=DEVICE, dtype=torch.float32).unsqueeze(-1)

        batch_size, seq_len = states.size(0), states.size(1)
        # NOTE(review): every window gets timesteps 0..seq_len-1 regardless of where
        # it starts inside its episode. Pass the real start offset from EpisodeData
        # if absolute timestep information matters.
        timesteps = torch.arange(seq_len, device=DEVICE).unsqueeze(0).repeat(batch_size, 1)
        # All-ones mask, matching the model's default. It is built explicitly so it
        # lives on DEVICE (the model's implicit default mask is created on CPU).
        # TODO(review): if EpisodeData pads short episodes, zero the padded steps here.
        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device=DEVICE)

        actions_one_hot = torch.nn.functional.one_hot(action_ids, num_classes=ACT_DIM).float()
        output = model(
            states=states,
            actions=actions_one_hot,
            returns_to_go=returns_to_go,
            timesteps=timesteps,
            attention_mask=attention_mask,
        )

        # action_preds are logits (action_tanh=False). The targets are the integer
        # action ids the one-hot inputs were built from.
        loss = torch.nn.functional.cross_entropy(
            output.action_preds.reshape(-1, ACT_DIM),
            action_ids.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {batch_loss:.4f}")

    print(f"Epoch {epoch} Average Loss: {total_loss / len(dataloader):.4f}")

        

Linear(in_features=4608, out_features=128, bias=True)
Embedding(4096, 128)
Linear(in_features=18, out_features=128, bias=True)
Epoch 0 Average Loss: 2.5397
Epoch 1 Average Loss: 2.1611
Epoch 2 Average Loss: 2.0659
Epoch 3 Average Loss: 1.9389
Epoch 4 Average Loss: 1.9744
Epoch 5 Average Loss: 1.9257
Epoch 6 Average Loss: 1.8395
Epoch 7 Average Loss: 1.8218
Epoch 8 Average Loss: 1.7699
Epoch 9 Average Loss: 1.6390
