# In [None]:
import json
import os
import re

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Paths
image_folder = "/mnt/data/screenshots"
json_path = "/mnt/data/movements.json"

# Load the per-frame movement records (a JSON object mapping frame key -> record).
# JSON is UTF-8 by spec; be explicit so the platform locale cannot change decoding.
with open(json_path, "r", encoding="utf-8") as f:
    movements = json.load(f)

# Action space (binary encoding): each key contributes one element of the
# action vector, in this order.
ACTION_KEYS = ["left", "right", "jump", "sprint"]

# Per-frame image transform: resize to 84x84 and convert to a float CHW tensor.
# ToTensor only scales pixel values to [0, 1]; no mean/std normalization is
# applied, and images are kept in color (RGB), not converted to grayscale.
transform = transforms.Compose([
    transforms.Resize((84, 84)),
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
])

class MarioDataset(Dataset):
    """Frame-stack dataset pairing recent screenshots with a recorded action.

    Each item is built from the ``frame_window`` consecutive screenshots that
    precede a target frame, plus the action vector and reward recorded for that
    target frame in ``movements``.

    Args:
        image_folder: Directory holding ``<frame_key>.png`` screenshots.
        movements: Mapping of frame key -> record. Each record is read as
            ``record["state"][k]`` for every action key and as
            ``record["state"]["reward"]``.
        frame_window: Number of preceding frames stacked into each state;
            must be >= 1.
        image_transform: Callable applied to each RGB ``PIL.Image``. Defaults
            to the module-level ``transform`` (looked up when an item loads).
        action_keys: Ordered keys forming the action vector. Defaults to the
            module-level ``ACTION_KEYS`` (looked up when an item loads).

    Raises:
        ValueError: If ``frame_window`` is less than 1.
    """

    def __init__(self, image_folder, movements, frame_window=4,
                 image_transform=None, action_keys=None):
        if frame_window < 1:
            raise ValueError(f"frame_window must be >= 1, got {frame_window}")
        self.image_folder = image_folder
        self.movements = movements
        self.frame_window = frame_window
        self.image_transform = image_transform
        self.action_keys = action_keys
        # Keys containing ".png" are skipped; the extension is appended when
        # building image paths. Frames are ordered naturally so "frame_10"
        # sorts after "frame_9": plain lexicographic sorting would scramble
        # unpadded frame numbers and silently corrupt the temporal stack.
        # (For zero-padded names the order is the same as a plain sort.)
        self.frames = sorted(
            (key for key in movements if ".png" not in key),
            key=self._natural_sort_key,
        )
        # An item needs `frame_window` earlier frames, so the first usable
        # target index is `frame_window`.
        self.valid_indices = list(range(self.frame_window, len(self.frames)))

    @staticmethod
    def _natural_sort_key(name):
        """Return a sort key that compares embedded digit runs numerically."""
        # re.split with a capturing group alternates text (even positions) and
        # digit runs (odd positions), so positions always compare like types.
        parts = re.split(r"(\d+)", name)
        natural = [int(part) if i % 2 else part for i, part in enumerate(parts)]
        # Fall back to the raw name so e.g. "f01" and "f1" still order stably.
        return natural, name

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(state, action, reward)`` for dataset position ``idx``.

        ``state`` stacks the transformed frames ``frames[t - frame_window:t]``
        along a new leading dimension, where ``t`` is the target frame. With
        the module-level transform this is ``(frame_window, 3, 84, 84)``.
        ``action`` is a float32 vector in action-key order and ``reward`` a
        float32 scalar, both read from ``movements[frames[t]]["state"]``.
        """
        image_transform = (
            self.image_transform if self.image_transform is not None else transform
        )
        action_keys = self.action_keys if self.action_keys is not None else ACTION_KEYS

        target = self.valid_indices[idx]
        imgs = []
        # The state holds the frame_window frames strictly BEFORE the target
        # frame; the target frame itself is not included.
        # NOTE(review): confirm this offset is intended.
        for frame_key in self.frames[target - self.frame_window:target]:
            img_path = os.path.join(self.image_folder, f"{frame_key}.png")
            # Close the underlying file promptly instead of leaking one handle
            # per frame until garbage collection.
            with Image.open(img_path) as img:
                imgs.append(image_transform(img.convert("RGB")))

        state = torch.stack(imgs, dim=0)

        # Action (as a binary float vector) and reward come from the target frame.
        record = self.movements[self.frames[target]]["state"]
        action = torch.tensor([float(record[k]) for k in action_keys], dtype=torch.float32)
        reward = torch.tensor(record["reward"], dtype=torch.float32)

        return state, action, reward

# Build the dataset and a shuffled loader over it.
dataset = MarioDataset(image_folder, movements)
if len(dataset) == 0:
    # Fail with an actionable message instead of an opaque sampler error.
    raise RuntimeError(
        f"MarioDataset is empty: need more than {dataset.frame_window} usable "
        f"frames in {json_path!r}, found {len(dataset.frames)}."
    )
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Pull one batch to confirm tensor shapes.
batch = next(iter(dataloader))
state_batch, action_batch, reward_batch = batch

import pandas as pd

# Build a real DataFrame for the `dataframe=` argument (the original passed a
# plain dict).
shape_info = pd.DataFrame({
    "State Shape": [tuple(state_batch.shape)],
    "Action Shape": [tuple(action_batch.shape)],
    "Reward Shape": [tuple(reward_batch.shape)],
})

try:
    # ace_tools is not a standard package; fall back to printing without it.
    import ace_tools as tools
except ImportError:
    print(shape_info.to_string(index=False))
else:
    tools.display_dataframe_to_user(name="Sample Batch Info", dataframe=shape_info)
