In [None]:
import os
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from btr.Agent import Agent

def agent_action(filtered_keys):
    sprint = filtered_keys.get("B", False)
    move_right = filtered_keys.get("Right", False)
    move_left = filtered_keys.get("Left", False)

    jump = filtered_keys.get("A", False)
    crouch = filtered_keys.get("Down", False)
    airborne = filtered_keys.get("One", False)
    sprint_left = move_left and sprint
    sprint_right = move_right and sprint
    jump_left = (move_left or sprint_left) and jump
    jump_right = (move_right or sprint_right) and jump
    stand_still = not any([move_right, move_left, jump, crouch, airborne, sprint_left, sprint_right, jump_left, jump_right])

    remove_some = True
    if remove_some:
        if any([crouch, airborne, sprint_left, jump]):
            stand_still = True
            crouch = airborne = sprint_left = jump = False
    return {
        # "jump": jump,
        # "crouch": crouch,
        # "airborne": airborne,
        # "sprint_left": sprint_left,
        "sprint_right": sprint_right,
        "jump_left": jump_left,
        "jump_right": jump_right,
        "move_right": move_right,
        "move_left": move_left,
        "none": stand_still,
    }

ACTION_KEYS = [
                # "crouch",
                # "airborne",
                # "sprint_left",
                "sprint_right",
                "jump_right",
                "jump_left",
                
                "move_right",
                "move_left", 
                # "jump",
                "none"]
ACTION_TO_INDEX = {action: idx for idx, action in enumerate(ACTION_KEYS)}
NUM_ACTIONS = len(ACTION_KEYS)

class MarioDataset(Dataset):
    def __init__(self, image_folder, movements, frame_window=4):
        self.image_folder = image_folder
        self.movements = movements
        self.frame_window = frame_window
        self.frames = sorted(movements.keys())

        def is_valid_action(frame_key):
            keys = {k: v for k, v in movements[frame_key].items() if k != "state"}
            actions = agent_action(keys)
            return any(actions.get(k, False) for k in ACTION_KEYS if k != "none")

        # Only keep valid indices
        self.valid_indices = [
            i for i in range(frame_window, len(self.frames))
            if is_valid_action(self.frames[i])
        ]

    def process_image(self, img, resize_to=(140, 114)):
        img = img.resize(resize_to, Image.Resampling.LANCZOS)
        img_array = np.array(img, dtype=np.float32)
        img_tensor = torch.tensor(img_array).unsqueeze(0)
        return img_tensor

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        idx = self.valid_indices[idx]
        imgs = []
        for i in range(idx - self.frame_window, idx):
            frame_key = self.frames[i]
            img_path = os.path.join(self.image_folder, f"d_{frame_key}.png")
            img = Image.open(img_path).convert("L")
            imgs.append(self.process_image(img))

        state = torch.cat(imgs, dim=0)

        frame_key = self.frames[idx]
        keys = {k: v for k, v in self.movements[frame_key].items() if k != "state"}
        actions_dict = agent_action(keys)
        valid_actions = {a: v for a, v in actions_dict.items() if a in ACTION_KEYS}
        action = next((a for a, v in reversed(valid_actions.items()) if v), "none")

        reward = self.movements[frame_key]["state"].get("reward", 0.0)

        if idx + 1 < len(self.frames):
            next_frame_key = self.frames[idx + 1]
            next_img_path = os.path.join(self.image_folder, f"d_{next_frame_key}.png")
            next_img = Image.open(next_img_path).convert("L")
            next_state = torch.cat([*imgs[1:], self.process_image(next_img)], dim=0)
        else:
            next_state = state

        return state, action, reward, next_state

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def train(dataset, epochs=1, batch_size=32, learning_rate=0.0001):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    agent = Agent(
        n_actions=NUM_ACTIONS,
        input_dims=(4, 140, 114),
        device=device,
        num_envs=1,
        agent_name="offline_agent",
        total_frames=50000,
        testing=False,
        batch_size=batch_size,
        imagex=114,
        imagey=140,
    )

    for epoch in range(epochs):
        for states, actions, rewards, next_states in dataloader:
            actions_discrete = [ACTION_TO_INDEX[a] for a in actions]
            actions = torch.tensor(actions_discrete, dtype=torch.long)

            for s, a, r, ns in zip(states, actions, rewards, next_states):
                agent.store_transition(
                    state=s.numpy(),
                    action=a.item(),
                    reward=r.item(),
                    next_state=ns.numpy(),
                    done=False,
                    stream=0,
                    prio=True
                )

            for _ in range(agent.batch_size // batch_size):
                agent.learn_call()

    agent.save_model()
    print("Agent model saved.")
    return agent

# Paths - adjust to your setup
image_folder = "../data/screenshots"
movement_json = "../data/movements.json"

with open(movement_json, "r") as f:
    movements = json.load(f)

Using device: cuda


In [10]:
dataset = MarioDataset(image_folder=image_folder, movements=movements)
train(dataset, epochs=10, batch_size=32, learning_rate=0.0004)

Agent model saved.


<btr.Agent.Agent at 0x288869e27d0>

In [11]:
import os
import torch
from PIL import Image
import numpy as np
import json
from btr.Agent import Agent

# --- Configuration ---
NUM_ACTIONS = len(ACTION_KEYS)
INPUT_DIMS = (4, 114, 140)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FRAME_DIR = "../data/screenshots"      # Path to your screenshots
MOVEMENT_FILE = "../data/movements.json"  # Path to movements.json
MODEL_NAME = "offline_agent_0M.model"          # Saved model name

# --- Load movements ---
with open(MOVEMENT_FILE, "r") as f:
    movements = json.load(f)

# --- Load Agent ---
agent = Agent(
    n_actions=NUM_ACTIONS,
    input_dims=INPUT_DIMS,
    device=DEVICE,
    num_envs=1,
    agent_name="offline_agent",
    total_frames=50000,
    testing=True
)

agent.load_models(MODEL_NAME)
print("Model loaded.")

# --- Helper: Load and preprocess 1 image ---
def load_and_process_image(path):
    img = Image.open(path).convert("L")
    img = img.resize((140, 114), Image.Resampling.LANCZOS)
    img_array = np.array(img, dtype=np.float32)
    tensor = torch.tensor(img_array).unsqueeze(0)  # shape: (1, H, W)
    return tensor

# --- Prepare frames ---
all_frames = sorted([f for f in os.listdir(FRAME_DIR) if f.endswith(".png")])
if len(all_frames) < 4:
    raise ValueError("Not enough frames to make a prediction (need at least 4).")

# Iterate over all sequences of 4 consecutive frames
correct_predictions = 0
total_predictions = 0
for i in range(len(all_frames) - 3):
    frame_keys = [f.replace("d_", "").replace(".png", "") for f in all_frames[i:i+4]]
    frame_paths = [os.path.join(FRAME_DIR, f"d_{k}.png") for k in frame_keys]
    frames = [load_and_process_image(p) for p in frame_paths]
    state = torch.cat(frames, dim=0).unsqueeze(0).to(DEVICE)  # shape: (1, 4, 114, 140)

    action_tensor = agent.choose_action(state)
    action_index = action_tensor.item()
    predicted_action = ACTION_KEYS[action_index]

    # Get true action from movements.json (using last frame)
    last_frame_key = frame_keys[-1]
    if last_frame_key in movements:
        keys = {k: v for k, v in movements[last_frame_key].items() if k != "state"}
        true_actions = [action for action, is_pressed in agent_action(keys).items() if is_pressed]
        true_action = next((action for action in ACTION_KEYS if action in true_actions), "none")
    else:
        true_action = "UNKNOWN"

    print(f"Frames {frame_paths[0]} to {frame_paths[3]} => Predicted: {predicted_action} | True: {true_action}")
    total_predictions += 1
    if predicted_action == true_action:
        correct_predictions += 1


Model loaded.
Frames ../data/screenshots\d_2025-05-05_17-59_frame_0.png to ../data/screenshots\d_2025-05-05_17-59_frame_1004.png => Predicted: jump_left | True: sprint_right
Frames ../data/screenshots\d_2025-05-05_17-59_frame_100.png to ../data/screenshots\d_2025-05-05_17-59_frame_1008.png => Predicted: jump_right | True: sprint_right
Frames ../data/screenshots\d_2025-05-05_17-59_frame_1000.png to ../data/screenshots\d_2025-05-05_17-59_frame_1012.png => Predicted: move_right | True: sprint_right
Frames ../data/screenshots\d_2025-05-05_17-59_frame_1004.png to ../data/screenshots\d_2025-05-05_17-59_frame_1016.png => Predicted: move_right | True: sprint_right
Frames ../data/screenshots\d_2025-05-05_17-59_frame_1008.png to ../data/screenshots\d_2025-05-05_17-59_frame_1020.png => Predicted: jump_left | True: sprint_right
Frames ../data/screenshots\d_2025-05-05_17-59_frame_1012.png to ../data/screenshots\d_2025-05-05_17-59_frame_1024.png => Predicted: move_right | True: sprint_right
Frames .

In [12]:
accuracy = correct_predictions / total_predictions * 100 if total_predictions > 0 else 0
print(f"Accuracy: {accuracy:.2f}%")

Accuracy: 0.75%
