# In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# --- Constants ---
# Discrete action space: one action per controller key, in this order.
ACTION_KEYS = ["left", "right", "jump", "sprint"]
NUM_ACTIONS = len(ACTION_KEYS)

# Reverse lookup from a one-hot key vector (1 for the single pressed key, 0
# for every other key, ordered like ACTION_KEYS) to that key's action index.
# Only single-key vectors are present; any other combination is not a key.
ACTION_TO_INDEX = {
    tuple(1 if other == pressed else 0 for other in ACTION_KEYS): index
    for index, pressed in enumerate(ACTION_KEYS)
}

# --- Dataset ---
class MarioDataset(Dataset):
    """Offline (state, action, reward, next_state) transitions from recorded Mario frames.

    Sample ``idx`` maps to a frame position ``t`` in the sorted frame list:

    * ``state`` stacks the ``frame_window`` frames ending at ``t`` (inclusive),
      giving a tensor of shape ``[3 * frame_window, 84, 84]``.
    * ``action`` and ``reward`` are read from ``movements[frames[t]]["state"]``.
    * ``next_state`` is the same window shifted forward by one frame (ending at
      ``t + 1``), so consecutive frames never have a gap.

    Positions without a full window before them, or without a successor frame,
    are skipped.

    Args:
        image_folder: directory containing ``<frame_key>.png`` screenshots.
        movements: mapping ``frame_key -> {"state": {...}}``. Each ``"state"``
            dict must contain every key in ``ACTION_KEYS`` and may contain
            ``"reward"`` (defaults to 0.0).
        frame_window: number of consecutive RGB frames stacked into one state.

    Raises:
        ValueError: if ``frame_window`` is smaller than 1.
    """

    def __init__(self, image_folder, movements, frame_window=4):
        if frame_window < 1:
            raise ValueError(f"frame_window must be >= 1, got {frame_window}")
        self.image_folder = image_folder
        self.movements = movements
        self.frame_window = frame_window
        # NOTE(review): plain lexicographic sort -- assumes frame keys sort in
        # temporal order (e.g. zero-padded numbers); confirm against the recorder.
        self.frames = sorted(movements.keys())
        # Position t is usable when its window [t - frame_window + 1, t] is in
        # range and a successor frame t + 1 exists for next_state.
        self.valid_indices = list(range(frame_window - 1, len(self.frames) - 1))

        self.transform = transforms.Compose([
            transforms.Resize((84, 84)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.valid_indices)

    def _load_frame(self, frame_key):
        """Load ``<image_folder>/<frame_key>.png`` as a transformed RGB tensor."""
        img_path = os.path.join(self.image_folder, f"{frame_key}.png")
        # The context manager closes the file handle deterministically instead
        # of leaving it to Pillow's lazy-loading / garbage-collection behaviour.
        with Image.open(img_path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        t = self.valid_indices[idx]

        # Load frame_window + 1 consecutive frames once: the first frame_window
        # form the state and the last frame_window form the next state.
        window = [
            self._load_frame(self.frames[i])
            for i in range(t - self.frame_window + 1, t + 2)
        ]
        state = torch.cat(window[:-1], dim=0)       # [3 * frame_window, 84, 84]
        next_state = torch.cat(window[1:], dim=0)   # window shifted by one frame

        frame_state = self.movements[self.frames[t]]["state"]
        action_vec = tuple(int(frame_state[k]) for k in ACTION_KEYS)
        # NOTE(review): ACTION_TO_INDEX only holds single-key one-hot vectors,
        # so key combinations (e.g. right+jump) and "no key pressed" silently
        # fall back to action 0 ("left"). Kept for compatibility; consider
        # expanding the action space or filtering such frames.
        action = ACTION_TO_INDEX.get(action_vec, 0)

        # float() gives a consistent dtype when the DataLoader collates a batch.
        reward = float(frame_state.get("reward", 0.0))

        return state, action, reward, next_state

# --- Simple Q-Network (based on BTR) ---
class BTRNetwork(nn.Module):
    """Convolutional Q-network mapping a stack of frames to one Q-value per action.

    Layout: three conv layers (32x8/4, 64x4/2, 64x3/1, each followed by ReLU),
    then a 512-unit hidden layer and a linear output of width ``num_actions``.

    Args:
        input_channels: channels of the input stack (3 per stacked RGB frame).
        num_actions: number of discrete actions, i.e. the output width.
        input_size: spatial size (height == width) of the input frames. The
            default of 84 reproduces the original 3136-wide flatten layer, so
            existing checkpoints keep loading unchanged.
    """

    def __init__(self, input_channels, num_actions, input_size=84):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Derive the flattened conv output width instead of hard-coding it
        # (64 * 7 * 7 = 3136 for 84x84 inputs) so other frame sizes also work.
        # A dummy forward pass does not consume RNG, so init is unaffected.
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, input_size, input_size)
            flat_dim = self.conv(dummy).numel()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        """Return Q-values of shape ``[batch, num_actions]`` for input ``x``."""
        return self.fc(self.conv(x))

# --- Training Loop ---
def train_q_learning(dataset, model_path, epochs=5, batch_size=32, gamma=0.99,
                     *, target_update_steps=1000):
    """Fit a Q-network to offline transitions with one-step Q-learning (DQN-style).

    The loss is the MSE between Q(s, a) for the recorded action and the
    bootstrapped target ``r + gamma * max_a' Q_target(s', a')``. The target
    network is refreshed from the online network every ``target_update_steps``
    optimizer steps and at the end of every epoch.

    Args:
        dataset: yields ``(state, action, reward, next_state)``. States are
            stacked RGB frames (3 channels per frame). If the dataset exposes
            ``frame_window``, the input width is derived from it; otherwise a
            4-frame (12-channel) stack is assumed.
        model_path: file the trained ``state_dict`` is written to via ``torch.save``.
        epochs: number of passes over ``dataset``.
        batch_size: minibatch size for the ``DataLoader``.
        gamma: discount factor for the bootstrapped target.
        target_update_steps: sync period in optimizer steps for the target
            network. ``0``/``None`` disables step-based syncing; epoch-end
            syncing still happens.
    """
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_channels = 3 * getattr(dataset, "frame_window", 4)
    model = BTRNetwork(input_channels=input_channels, num_actions=NUM_ACTIONS).to(device)
    target_model = BTRNetwork(input_channels=input_channels, num_actions=NUM_ACTIONS).to(device)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()

    step = 0
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for state, action, reward, next_state in dataloader:
            state = state.to(device)
            next_state = next_state.to(device)
            action = action.to(device=device, dtype=torch.long)
            # Collated Python floats arrive as float64; match the network dtype.
            reward = reward.to(device=device, dtype=torch.float32)

            # Q(s, a) for the action actually taken in each transition.
            q_taken = model(state).gather(1, action.unsqueeze(1)).squeeze(1)

            # no_grad keeps the target out of the autograd graph, so no
            # gradients are computed for (or accumulated on) target_model.
            with torch.no_grad():
                max_next_q = target_model(next_state).max(dim=1).values
                targets = reward + gamma * max_next_q

            loss = loss_fn(q_taken, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            step += 1
            if target_update_steps and step % target_update_steps == 0:
                target_model.load_state_dict(model.state_dict())

        # Epoch-end sync so short runs (< target_update_steps) still refresh
        # the target instead of bootstrapping from the random initial network.
        target_model.load_state_dict(model.state_dict())
        print(f"Epoch {epoch+1}, Loss: {total_loss / max(num_batches, 1):.4f}")

    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

# --- Pseudo-code for online usage ---
# Reference recipe for running the trained model inside an emulator loop. This
# is a plain string and is never executed; the emulator helpers are placeholders.
ONLINE_USAGE_PSEUDO = """
# Load model (map_location lets a checkpoint saved on GPU load on a CPU-only machine)
model = BTRNetwork(input_channels=12, num_actions=NUM_ACTIONS)
model.load_state_dict(torch.load("btr_q_model.pth", map_location="cpu"))
model.eval()

# Inside emulator loop:
obs = get_4frame_stack_from_emulator()  # shape [12, 84, 84]; resized to 84x84 and scaled to [0, 1] like MarioDataset.transform
with torch.no_grad():  # inference only; do not build an autograd graph
    action_values = model(obs.unsqueeze(0))
action = action_values.argmax(dim=1).item()  # index into ACTION_KEYS

# Send action to emulator
send_action_to_emulator(action)

# Collect new (s, a, r, s') and store in new buffer
# Optionally fine-tune with Q-learning
"""

# --- Main driver ---
if __name__ == "__main__":
    # Expects movements.json (frame_key -> {"state": {...}}) and a "screenshots"
    # folder of <frame_key>.png files in the current working directory.
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open("movements.json", "r", encoding="utf-8") as f:
        movements = json.load(f)
    dataset = MarioDataset("screenshots", movements)
    train_q_learning(dataset, "btr_q_model.pth")
