In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from model import BTRNetwork



def agent_action(filtered_keys):
    sprint = filtered_keys.get("B", False)
    move_right = filtered_keys.get("Right", False)
    move_left = filtered_keys.get("Left", False)

    jump = filtered_keys.get("A", False)
    crouch = filtered_keys.get("Down", False)
    airborne = filtered_keys.get("One", False)
    sprint_left = move_left and sprint
    sprint_right = move_right and sprint
    jump_left = (move_left or sprint_left) and jump
    jump_right = (move_right or sprint_right) and jump
    stand_still = not any([move_right, move_left, jump, crouch, airborne, sprint_left, sprint_right, jump_left, jump_right])
    return {

        "jump": jump,
        "crouch": crouch,
        "airborne": airborne,
        "sprint_left": sprint_left,
        "sprint_right": sprint_right,
        "jump_left": jump_left,
        "jump_right": jump_right,
        "move_right": move_right,
        "move_left": move_left,
        "none": stand_still,
    }

# --- Constants ---
ACTION_KEYS = ["jump", "crouch", "airborne", "sprint_left", "sprint_right",
               "jump_left", "jump_right", "move_right", "move_left", "none"]



ACTION_TO_INDEX = {action: idx for idx, action in enumerate(ACTION_KEYS)}
NUM_ACTIONS = len(ACTION_KEYS)

# --- Dataset ---
class MarioDataset(Dataset):
    def __init__(self, image_folder, movements, frame_window=4):
        self.image_folder = image_folder
        self.movements = movements
        self.frame_window = frame_window
        self.frames = sorted(movements.keys())
        self.valid_indices = list(range(frame_window, len(self.frames)))


        # self.transform = transforms.Compose([
        #     transforms.Resize((84, 84)),
        #     transforms.ToTensor(),
        # ])

    # use custom process instead of torchvision transforms, this way we dont need to install torchvision
    # for only its transforms. They should effectively do the same thing
    def process_image(self, img, resize_to=(140, 114)):
        img = img.resize(resize_to, Image.Resampling.LANCZOS)
        img_array = np.array(img, dtype=np.float32)
        
        img_tensor = torch.tensor(img_array).unsqueeze(0)  # Add channel dimension (1, H, W)

        return img_tensor # no division (/255) needed because model does it, TODO: maybe move it to procesing instead of model


    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        idx = self.valid_indices[idx]
        imgs = []
        for i in range(idx - self.frame_window, idx):
            frame_key = self.frames[i]
            img_path = os.path.join(self.image_folder, f"d_{frame_key}.png")
            img = Image.open(img_path).convert("L")
            imgs.append(self.process_image(img))



        state = torch.cat(imgs, dim=0)  # [3*4, 84, 84]
        #print("state shape", state.shape)

        frame_key = self.frames[idx]
        keys = {k: v for k, v in self.movements[frame_key].items() if k != "state"}
        actions_dict = agent_action(keys)
        #print("action-dict", actions_dict)

        # TODO check if this is always correct, I think so
        # idea here to get the valid action, we take the last if there are more
        # e.g. jump_right is after jump and move/sprint right
        # latest action counts
        for a, value in reversed(actions_dict.items()):
            if value:
                action = a


        reward = self.movements[frame_key]["state"].get("reward", 0.0)

        # Next frame
        if idx + 1 < len(self.frames):
            next_frame_key = self.frames[idx + 1]
            next_img_path = os.path.join(self.image_folder, f"d_{next_frame_key}.png")
            next_img = Image.open(next_img_path).convert("L")
            next_state = torch.cat([
                *imgs[1:],
                self.process_image(next_img)
            ], dim=0)
        else:
            next_state = state

        return state, action, reward, next_state
    

# -- testing dataset functionality --


# TODO: change the paths here to the ones you have
dataset = MarioDataset(
    image_folder="../experiment/data/screenshots",
    movements=json.load(open("../experiment/data/movements.json")),
    frame_window=4
)

In [2]:
data0 = dataset[0]
print(data0[0].shape)  # [3*4, 84, 84]

torch.Size([4, 114, 140])


In [3]:
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


def train(dataset, epochs=1, batch_size=32, learning_rate=0.0001, gamma=0.99):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize the model
    model = BTRNetwork(num_actions=NUM_ACTIONS).to(device)
    model.train()

    target_model = BTRNetwork(num_actions=NUM_ACTIONS).to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    losses = []
    for epoch in range(epochs):
        for i, (states, actions, rewards, next_states) in enumerate(dataloader):

            #print(type(states), type(actions), type(rewards), type(next_states))
                
    
            actions_discrete = [ACTION_TO_INDEX[a] for a in actions]
            actions = torch.tensor(actions_discrete, dtype=torch.long).to(device)

            states = states.float().to(device)
            #print("states shape", states.shape)
            rewards = rewards.float().to(device)
            next_states = next_states.float().to(device)

            # Configs for Munchausen RL
            alpha = 0.9  # log-policy scaling
            tau = 0.03   # softmax temperature
            clip_min = -1.0  # munchausen term clipped minimum

            # Get current Q-values
            q_values = model(states)
            #print("q_values shape", q_values.shape)
            q_values_mean = q_values.mean(dim=1)
            #print("q_values_mean shape", q_values_mean.shape)
            q_taken = q_values_mean.gather(1, actions.unsqueeze(1)).squeeze(1)
            #print("q_taken shape", q_taken.shape)

            with torch.no_grad():
                next_q = target_model(next_states)
                #print("next_q shape", next_q.shape)
                next_q_mean = next_q.mean(dim=1)
                #print("next_q_mean shape", next_q_mean.shape)
                next_probs = F.softmax(next_q_mean / tau, dim=1)       # π(a|s')
                #print("next_probs shape", next_probs.shape)
                next_v = (next_q_mean * next_probs).sum(dim=1)  # Shape: [32]
                #print("next_v shape", next_v.shape)

                # Softmax over current Q-values (for log π(a|s))
                current_probs = F.softmax(q_values_mean / tau, dim=1)
                log_policy = torch.log(current_probs + 1e-8)       # [B, A]
                log_pi_a = log_policy.gather(1, actions.unsqueeze(1)).squeeze(1)  # log π(a|s)
                munchausen_term = alpha * torch.clamp(log_pi_a, min=clip_min)

                # Munchausen target
                target_q = rewards + munchausen_term + gamma * next_v

            loss = loss_fn(q_taken, target_q)
            losses.append(loss.item())

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        # After each epoch
        target_model.load_state_dict(model.state_dict())

    return model, losses

Using device: cuda


In [4]:
(model, losses) = train(dataset, epochs=5, batch_size=32, learning_rate=0.0004, gamma=0.99)

In [5]:
print(losses)

[808.0996704101562, 122.07223510742188, 132.61505126953125, 88.66462707519531, 1052.9410400390625, 57.21360778808594, 367.99700927734375, 732.253662109375, 111.39972686767578, 231.74403381347656, 90.65411376953125, 80.9020004272461, 484.17669677734375, 1977.562744140625, 61.17024230957031, 296.0118713378906, 330.04791259765625, 225.67422485351562, 61.29290771484375, 348.9710388183594, 20.270885467529297, 56.10218811035156, 55.152156829833984, 66.62965393066406, 321.12969970703125, 700.02734375, 130.2519989013672, 452.03082275390625, 252.7568359375, 58.86053466796875, 708.6806030273438, 121.07814025878906, 28.510343551635742, 557.5345458984375, 403.33447265625, 124.31763458251953, 66.4805679321289, 65.92271423339844, 24.800804138183594, 103.87236022949219, 46.56050109863281, 853.0126953125, 33.31560516357422, 44.231170654296875, 67.92382049560547, 111.99029541015625, 120.53428649902344, 270.84564208984375, 65.20271301269531, 120.75869750976562, 510.658203125, 103.84513092041016, 53.8133

In [None]:
# save model weights
#torch.save(model.state_dict(), "btr_model.pth")
