In [1]:
import sys 
sys.path.append('/Users/aoife/git/PoliwhiRL/mgba/build/python/lib.macosx-11.1-arm64-cpython-310')
import mgba.core
import mgba.image

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
from itertools import count

import pytesseract
from PIL import Image
import cv2
import io
import os
from PIL import Image as PImage

import time
from tqdm import tqdm



def create_emulator_view(rom_path='Pokemon - Crystal Version.gbc', width=160, height=144):
    """Load a ROM into an mGBA core and attach a frame buffer to it.

    Args:
        rom_path: Path of the ROM to load. Defaults to the Pokemon Crystal
            ROM looked up relative to the working directory.
        width: Width in pixels of the frame buffer handed to the core.
        height: Height in pixels of the frame buffer handed to the core.

    Returns:
        A tuple ``(emulator, viewer)``: the core after ``reset()`` and the
        image object registered as its video buffer.

    Raises:
        RuntimeError: If no core could be created for ``rom_path``.
    """
    emulator = mgba.core.load_path(rom_path)
    if emulator is None:
        # NOTE(review): load_path is expected to return None when the file is
        # missing or unsupported - confirm against the installed bindings.
        # Failing here names the path instead of raising AttributeError below.
        raise RuntimeError(f"Could not load ROM '{rom_path}'")

    viewer = mgba.image.Image(width, height)
    emulator.set_video_buffer(viewer)
    emulator.reset()
    return emulator, viewer


# Module-level emulator core and frame buffer; the helper functions and the
# training loop below read these globals directly.
emulator, viewer = create_emulator_view()

In [2]:
# When True, the training loop writes every observed frame to
# images/<episode>/<step>.png.
SAVE_IMG = True

def extract_text_from_image(image):
    """Run OCR over a frame and return the recognised text.

    The frame is converted from BGR to grayscale, binarised with a fixed
    threshold of 127 (pixels above it become 255) and passed to Tesseract
    with ``--psm 6``. The threshold may need tuning for better results.

    Args:
        image: Colour frame accepted by ``cv2.cvtColor`` with
            ``COLOR_BGR2GRAY``.

    Returns:
        The string produced by ``pytesseract.image_to_string``.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    binarised = cv2.threshold(grayscale, 127, 255, cv2.THRESH_BINARY)[1]
    return pytesseract.image_to_string(binarised, config='--psm 6')

# Action table: the index chosen by the network maps to the key at that
# position, so the order here must stay stable across checkpoints.
movements = [
    emulator.KEY_A,
    emulator.KEY_B,
    emulator.KEY_UP,
    emulator.KEY_DOWN,
    emulator.KEY_LEFT,
    emulator.KEY_RIGHT,
    emulator.KEY_START,
    emulator.KEY_SELECT,
]

# Create the images folder if it doesn't exist.
# exist_ok=True does this in one call, with no gap between check and create.
os.makedirs('images', exist_ok=True)

# Function to run the emulator for a specified number of seconds
def wait_seconds(seconds):
    """Advance the emulator by ``int(60 * seconds)`` frames.

    Args:
        seconds: Emulated time to run, converted at 60 frames per second
            and truncated to a whole number of frames. Values that yield
            zero or fewer frames run nothing.
    """
    frames_left = int(60 * seconds)
    while frames_left > 0:
        emulator.run_frame()
        frames_left -= 1

# Function to simulate a key press
def key_press(key, hold_seconds=1):
    """Hold ``key`` down for ``hold_seconds`` of emulated time, then release it.

    Args:
        key: Key constant passed to ``emulator.set_keys`` and
            ``emulator.clear_keys``.
        hold_seconds: How long to keep the key held, in the units accepted
            by ``wait_seconds``. Defaults to 1.
    """
    emulator.set_keys(key)
    try:
        wait_seconds(hold_seconds)
    finally:
        # Release the key even if emulation raises, so a failed step does
        # not leave it held for whatever runs next.
        emulator.clear_keys(key)


In [3]:

# Define the neural network
class DQN(nn.Module):
    """Convolutional Q-network: three conv + batch-norm stages, then one linear layer.

    Args:
        h: Input height used to probe the flattened conv output size.
        w: Input width used to probe the flattened conv output size.
        outputs: Number of values produced per input (one per action).

    NOTE(review): ``fc`` is built with a hard-coded input width of 8160 and
    does not use ``self._to_linear``. 8160 == 32 * 15 * 17, which is what the
    conv stack produces for a 144x160 (or 160x144) input; any other input
    size will fail inside ``forward``. Switching ``fc`` to ``self._to_linear``
    requires the caller to pass the true frame size as ``h`` and ``w``.
    """

    def __init__(self, h, w, outputs):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        # Flattened size of the conv output for an (h, w) input; recorded
        # on the instance but not used to size ``fc`` (see class note).
        self._to_linear = None
        self._compute_conv_output_size(h, w)

        self.fc = nn.Linear(8160, outputs)

    def _conv_stages(self):
        """Return the (conv, batch-norm) pairs in forward order."""
        return (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )

    def _compute_conv_output_size(self, h, w):
        """Push a random 1x3xhxw tensor through the convs and record its flat size."""
        probe = torch.rand(1, 3, h, w)
        for conv, _ in self._conv_stages():
            probe = conv(probe)
        self._to_linear = probe.view(1, -1).size(1)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to a (N, outputs) tensor."""
        for conv, norm in self._conv_stages():
            x = torch.relu(norm(conv(x)))
        flat = x.reshape(x.size(0), -1)
        return self.fc(flat)

# Replay Memory
class ReplayMemory:
    """Bounded buffer of transitions; once full, the oldest entry is dropped."""

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` entries."""
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store the positional arguments together as one tuple."""
        transition = args
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored entries chosen at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                entries (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.memory)



# Path to save or load the checkpoint
checkpoint_path = "pokemon_rl_checkpoint.pth"


# Function to save a checkpoint
def save_checkpoint(state, filename=checkpoint_path):
    """Serialise ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: Object to store.
        filename: Destination; defaults to the value ``checkpoint_path``
            had when this function was defined.
    """
    torch.save(obj=state, f=filename)

# Function to load a checkpoint
def load_checkpoint():
    """Restore the module-level ``model`` and ``optimizer`` from ``checkpoint_path``.

    Both globals must already exist when a checkpoint file is present,
    because their ``load_state_dict`` methods are called here.

    Returns:
        ``(start_episode, epsilon)`` taken from the checkpoint's ``'epoch'``
        and ``'epsilon'`` entries, or the defaults ``(0, 0.9)`` when no file
        exists at ``checkpoint_path``.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at '{checkpoint_path}'")
        return 0, 0.9  # Defaults for start_episode and epsilon

    print(f"Loading checkpoint '{checkpoint_path}'")
    saved = torch.load(checkpoint_path)
    model.load_state_dict(saved['state_dict'])
    optimizer.load_state_dict(saved['optimizer'])
    return saved['epoch'], saved['epsilon']



# Initialize the model, optimizer and replay memory.
# These must exist before load_checkpoint() runs: it restores state into the
# module-level `model` and `optimizer`, so calling it first raises NameError
# as soon as a checkpoint file is present.
num_actions = len(movements)  # One network output per entry in `movements`
# (height, width, channels). The frame buffer is created as 160 wide by 144
# high, and image_to_tensor moves the last axis to the channel position.
# NOTE(review): assumes viewer.get_cv2_image() returns a (144, 160, 3) array -
# confirm. DQN currently hard-codes its linear input size, so this value only
# feeds its size probe.
input_shape = (144, 160, 3)
model = DQN(input_shape[0], input_shape[1], num_actions)
optimizer = optim.Adam(model.parameters(), lr=0.001)
memory = ReplayMemory(10000)

# Load the checkpoint if it exists (falls back to episode 0, epsilon 0.9)
start_episode, epsilon = load_checkpoint()

# Function to choose an action
def select_action(state):
    """Pick an action index for ``state`` with an epsilon-greedy rule.

    With probability ``1 - epsilon`` the action with the highest model
    output is chosen; otherwise an action is drawn uniformly at random.

    ``epsilon`` is the module-level exploration rate returned by
    ``load_checkpoint()`` and decayed by the training loop. It is not
    re-assigned next to this function, so a run resumed from a checkpoint
    keeps its saved exploration rate instead of restarting at 0.9.

    Args:
        state: Input batch for ``model``; only read on the greedy branch.

    Returns:
        A ``torch.long`` tensor of shape (1, 1) holding the action index.
    """
    if random.random() > epsilon:
        with torch.no_grad():
            # max(1) returns (values, indices); the indices are the actions.
            return model(state).max(1)[1].view(1, 1)
    return torch.tensor([[random.randrange(num_actions)]], dtype=torch.long)

# Function to convert the emulator image to a tensor
def image_to_tensor(image):
    """Convert an image array into a normalised float tensor with a batch axis.

    Args:
        image: Rank-3 array-like whose last axis is moved to the channel
            position, i.e. laid out as (height, width, channels).

    Returns:
        A ``torch.float32`` tensor of shape (1, channels, height, width)
        with every value divided by 255.
    """
    pixels = np.array(image)  # Copy, so the tensor never aliases the caller's data
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
    batched = channels_first.unsqueeze(0)
    return batched.to(torch.float32) / 255  # Normalize the input

# Optimization function
def optimize_model(batch_size=128, gamma=0.99):
    """Run one optimisation step on a random minibatch from the replay memory.

    Reads the module-level ``memory``, ``model`` and ``optimizer``. Does
    nothing until ``memory`` holds at least ``batch_size`` transitions.

    Each transition is ``(state, action, reward, next_state)``, where
    ``next_state`` is ``None`` for a terminal step. The target for each
    sample is ``reward + gamma * max_a Q(next_state, a)``, with the second
    term fixed at 0 for terminal steps. The loss is the Huber (smooth L1)
    loss between Q(state, action) and that target.

    Args:
        batch_size: Number of transitions sampled per step.
        gamma: Discount factor applied to the next-state value.
    """
    if len(memory) < batch_size:
        return

    transitions = memory.sample(batch_size)
    # Transpose the list of transitions into per-field tuples.
    states, actions, rewards, next_states = tuple(zip(*transitions))[:4]

    state_batch = torch.cat(states)
    action_batch = torch.cat(actions)
    reward_batch = torch.cat(rewards)

    # Q(s_t, a) for the action that was actually taken.
    state_action_values = model(state_batch).gather(1, action_batch)

    # V(s_{t+1}); stays 0 for terminal transitions.
    non_final_mask = torch.tensor([s is not None for s in next_states], dtype=torch.bool)
    next_state_values = torch.zeros(batch_size)
    if non_final_mask.any():
        # Guarded because torch.cat raises on an empty list, which happens
        # when every sampled transition is terminal.
        non_final_next_states = torch.cat([s for s in next_states if s is not None])
        with torch.no_grad():  # Targets are constants; no graph is needed
            next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]

    expected_state_action_values = reward_batch + gamma * next_state_values

    # Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-1, 1]; parameters without a gradient
    # are skipped.
    nn.utils.clip_grad_value_(model.parameters(), 1)
    optimizer.step()


num_episodes = 100
# Start from the episode returned by load_checkpoint() (0 on a fresh run) so
# a resumed run continues instead of repeating finished episodes.
# `initial`/`total` keep the progress bar aligned with the absolute episode
# number.
for i_episode in tqdm(range(start_episode, num_episodes),
                      initial=start_episode, total=num_episodes):
    # Initialize the environment and state
    emulator.reset()
    state = image_to_tensor(viewer.get_cv2_image())

    # NOTE(review): an episode only ends when the goal text is detected;
    # there is no step limit.
    for t in count():
        # Select and perform an action
        action = select_action(state)
        key_press(movements[action.item()])

        # Fetch the resulting frame once; it serves the goal check, the
        # optional image dump and the next state.
        img = viewer.get_cv2_image()

        # Goal: the OCR'd frame contains the word 'neighbor'
        done = 'neighbor' in extract_text_from_image(img)
        if done:
            reward = torch.tensor([1.0], dtype=torch.float32)
        else:
            # Small negative reward for each step
            reward = torch.tensor([-0.01], dtype=torch.float32)

        if SAVE_IMG:
            # One folder per episode, one PNG per step
            os.makedirs(f'images/{i_episode}', exist_ok=True)
            cv2.imwrite(f'images/{i_episode}/{t}.png', img)

        # Terminal transitions are stored with next_state=None
        next_state = image_to_tensor(img) if not done else None

        # Store the transition in memory
        memory.push(state, action, reward, next_state)

        # Perform optimization step
        optimize_model()

        # Move to the next state
        state = next_state

        if done:
            print(f"Episode {i_episode} finished after {t+1} steps")
            break

    # Decrease epsilon, with a floor of 0.05
    epsilon = max(epsilon * 0.99, 0.05)

    # Save checkpoint every 10 episodes; 'epoch' is the next episode to run
    if i_episode % 10 == 0:
        save_checkpoint({
            'epoch': i_episode + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epsilon': epsilon
        })

# Save the model
torch.save(model.state_dict(), "pokemon_rl_model_final.pth")


No checkpoint found at 'pokemon_rl_checkpoint.pth'


  0%|          | 0/100 [00:00<?, ?it/s]

GB I/O: Writing to unknown register FF56:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown register FF15:00
GB I/O: Writing to unknown register FF1F:00
GB I/O: Writing to unknown regis