In [1]:
from environment import EnvironmentGoogleSnake
from torch import nn
import torch

from torchvision.transforms import ToTensor

import gymnasium as gym
import numpy as np

import matplotlib.pyplot as plt
import copy


from torchinfo import summary

In [2]:
class PreProcessEnv(gym.Wrapper):
    """Adapter that lets the agent talk to the wrapped environment in torch tensors.

    Observations, rewards and done flags coming out of the wrapped environment
    are converted to tensors with a leading batch dimension of 1; actions going
    in are converted from a single-element tensor to a plain Python number.

    The wrapped environment is used through the 4-tuple step interface
    (``obs, reward, done, info``) and a ``reset()`` that returns only the
    observation.
    """

    def __init__(self, env):
        super().__init__(env)

    def reset(self):
        """Reset the wrapped environment and return the observation as a float tensor with a leading batch axis."""
        observation = self.env.reset()
        return torch.tensor(observation).float().unsqueeze(0)

    def step(self, action: torch.Tensor):
        """Apply ``action`` and return ``(next_state, reward, done, info)`` as tensors.

        Args:
            action: single-element tensor holding the action index.

        Returns:
            next_state: float tensor, observation with a leading batch axis.
            reward: float tensor viewed as shape ``(1, -1)``.
            done: tensor viewed as shape ``(1, -1)`` (dtype follows the env's flag).
            info: passed through from the wrapped environment unchanged.
        """
        observation, reward, done, info = self.env.step(action.item())
        return (
            torch.tensor(observation).float().unsqueeze(0),
            torch.tensor(reward).float().view(1, -1),
            torch.tensor(done).view(1, -1),
            info,
        )
    

In [3]:
# Earlier experiment with the Google Snake environment, kept for reference:
# env = EnvironmentGoogleSnake()
# env = PreProcessEnv(env)
from maze import Maze

# Constants
# GAME_* and NUMBER_OF_TILES are forwarded to Maze below; SCREEN_* and TILE_SIZE
# are not referenced anywhere else in the visible cells.
GAME_HEIGHT = 600
GAME_WIDTH = 600
NUMBER_OF_TILES = 25
SCREEN_HEIGHT = 700
SCREEN_WIDTH = 700
TILE_SIZE = GAME_HEIGHT // NUMBER_OF_TILES

# Maze layout
# 25 rows of 25 characters. Presumably 'X' is a wall, ' ' a free tile and 'P'
# the player's start tile (row 4, column 4, 0-indexed) — confirm against Maze.
level = [
    "XXXXXXXXXXXXXXXXXXXXXXXXX",
    "X XXXXXXXX          XXXXX",
    "X XXXXXXXX  XXXXXX  XXXXX",
    "X      XXX  XXXXXX  XXXXX",
    "X   P  XXX  XXX         X",
    "XXXXXX  XX  XXX        XX",
    "XXXXXX  XX  XXXXXX  XXXXX",
    "XXXXXX  XX  XXXXXX  XXXXX",
    "X  XXX      XXXXXXXXXXXXX",
    "X  XXX  XXXXXXXXXXXXXXXXX",
    "X         XXXXXXXXXXXXXXX",
    "X             XXXXXXXXXXX",
    "XXXXXXXXXXX      XXXXX  X",
    "XXXXXXXXXXXXXXX  XXXXX  X",
    "XXX  XXXXXXXXXX         X",
    "XXX                     X",
    "XXX         XXXXXXXXXXXXX",
    "XXXXXXXXXX  XXXXXXXXXXXXX",
    "XXXXXXXXXX              X",
    "XX   XXXXX              X",
    "XX   XXXXXXXXXXXXX  XXXXX",
    "XX    XXXXXXXXXXXX  XXXXX",
    "XX        XXXX          X",
    "XXXX                    X",
    "XXXXXXXXXXXXXXXXXXXXXXXXX",
]

# NOTE(review): read as (row, col), goal_pos (23, 20) is a free tile; read as
# (x, y) = (column 23, row 20) it lands on an 'X' tile. Confirm which ordering
# Maze expects — an unreachable goal would keep episodes from ever terminating.
env = Maze(
    level,
    goal_pos=(23, 20),
    MAZE_HEIGHT=GAME_HEIGHT,
    MAZE_WIDTH=GAME_WIDTH,
    SIZE=NUMBER_OF_TILES,
)
# Wrap so that reset()/step() exchange torch tensors with the agent.
env = PreProcessEnv(env)
# Size of the discrete action space; matches the 4 outputs of the Q-network.
NO_OF_ACTIONS = 4

# env.unwrapped.start()
# Initial observation as a float tensor with a leading batch axis.
state = env.reset()



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Smoke test: take one uniformly random action (index in [0, 4)) and inspect
# the tensors the wrapper hands back. Note that this advances the environment.
next_state, reward, done, info = env.step(torch.randint(4, (1,1)))

print(f"Next State: {next_state}, Reward: {reward}, Done: {done}, Info: {info}")

Next State: tensor([[4., 4.]]), Reward: tensor([[-1.]]), Done: tensor([[False]]), Info: {}


In [5]:
class DeepRLModel(torch.nn.Module):
    """Fully connected Q-network mapping a state vector to one value per action.

    The defaults reproduce the original fixed architecture
    (2 -> 128 -> 64 -> 4 with ReLU between the linear layers), so existing
    ``DeepRLModel()`` calls and saved ``state_dict``s keep working.

    Args:
        state_dim: number of features in the state vector.
        n_actions: number of discrete actions (size of the output layer).
        hidden_sizes: widths of the hidden layers, in order.
    """

    def __init__(self, state_dim: int = 2, n_actions: int = 4, hidden_sizes=(128, 64)):
        super().__init__()
        layers = []
        in_features = state_dim
        # Linear + ReLU per hidden layer; created in order so that the
        # Sequential indices (0, 2, 4, ...) match the original layout.
        for width in hidden_sizes:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        # Output layer is left linear: Q-values are unbounded.
        layers.append(nn.Linear(in_features, n_actions))
        self.classifier = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        """Return the action values for ``x`` (last dimension must be ``state_dim``)."""
        return self.classifier(x)

In [6]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# Online Q-network: the one that is trained and used for action selection.
q_network = DeepRLModel().to(device)
# Layer-by-layer summary for a dummy input of shape (3, 2): batch of 3 states, 2 features each.
summary(q_network, (3, 2))


Layer (type:depth-idx)                   Output Shape              Param #
DeepRLModel                              [3, 4]                    --
├─Sequential: 1-1                        [3, 4]                    --
│    └─Linear: 2-1                       [3, 128]                  384
│    └─ReLU: 2-2                         [3, 128]                  --
│    └─Linear: 2-3                       [3, 64]                   8,256
│    └─ReLU: 2-4                         [3, 64]                   --
│    └─Linear: 2-5                       [3, 4]                    260
Total params: 8,900
Trainable params: 8,900
Non-trainable params: 0
Total mult-adds (M): 0.03
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.04
Estimated Total Size (MB): 0.04

In [7]:
# Target network: an independent copy of the online network's weights, used to
# compute bootstrap targets; it is only refreshed via load_state_dict during training.
target_q_network = copy.deepcopy(q_network).to(device)
# eval() switches the module to inference mode; it does not disable gradient tracking.
target_q_network.eval()

DeepRLModel(
  (classifier): Sequential(
    (0): Linear(in_features=2, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=4, bias=True)
  )
)

In [8]:
# Sanity check: a single state should yield one value per action, i.e. shape (1, 4).
q_network(next_state.to(device))

tensor([[ 0.1284, -0.2936, -0.2326,  0.1474]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [9]:
def policy(state, epsilon = 0.05):
    """Epsilon-greedy action selection based on the global ``q_network``.

    With probability ``epsilon`` (one draw for the whole call) uniformly random
    actions are returned; otherwise the greedy actions under ``q_network``.

    Both branches return an int64 tensor on ``device`` with one action per
    state, shaped like ``state`` with its last (feature) dimension replaced by
    1 — e.g. ``(1, 1)`` for a single state of shape ``(1, 2)``. Previously the
    random branch always produced a ``(1, 1)`` CPU tensor regardless of the
    input batch, while the greedy branch produced a per-state tensor on
    ``device``.

    Args:
        state: float tensor whose last dimension holds the state features.
        epsilon: exploration probability in ``[0, 1]``.
    """
    if torch.rand(1) < epsilon:
        # One random action per state, on the same device as the greedy branch.
        return torch.randint(NO_OF_ACTIONS, (*state.shape[:-1], 1), device=device)
    # Action selection needs no gradients; avoid building an autograd graph.
    with torch.no_grad():
        av = q_network(state.to(device))
    return torch.argmax(av, dim = -1, keepdim = True)

In [10]:
from torch.optim import AdamW
from memory import ReplayMemory
from tqdm import tqdm
import torch.nn.functional as F

def deep_sarsa(q_network:DeepRLModel, policy, episodes, alpha=0.001, batch_size=64, gamma=0.99, epsilon=0.5, max_steps=None):
    """Train ``q_network`` with experience replay and a periodically synced target network.

    For every environment step the transition is stored in a replay buffer and,
    once enough transitions are available, one gradient step is taken on the
    MSE between ``Q(s, a)`` and the TD target
    ``r + (1 - done) * gamma * Q_target(s', a')``.

    ``a'`` is the greedy action under the target network, exactly as the
    original code selected it. The original then used that action *index*
    (0..3) as if it were the Q-value; the value is now looked up with
    ``gather``. For strictly on-policy SARSA, ``a'`` would instead be drawn
    from the behaviour policy.

    Relies on the notebook globals ``env``, ``target_q_network`` and ``device``.
    ``ReplayMemory`` comes from the project's ``memory`` module; it is assumed
    to return one batched tensor per stored field, in insertion order — confirm.

    Args:
        q_network: online network being optimised.
        policy: callable ``policy(state, epsilon)`` returning an action tensor.
        episodes: number of episodes to run.
        alpha: learning rate for AdamW.
        batch_size: number of transitions sampled per update.
        gamma: discount factor.
        epsilon: exploration rate forwarded to ``policy``.
        max_steps: optional cap on steps per episode. ``None`` (default) keeps
            the original behaviour of running until the environment reports
            ``done``, which never returns if the goal is not reached.

    Returns:
        dict with ``'MSE Loss'`` (one value per gradient step) and
        ``'Returns'`` (undiscounted sum of rewards per episode).
    """
    optim = AdamW(q_network.parameters(), lr=alpha)
    memory = ReplayMemory(capacity = 1000000)
    stats = {'MSE Loss': [], 'Returns': []}

    for episode in tqdm(range(1, episodes + 1)):
        # env.reset() yields a CPU tensor; move it so every stored state lives
        # on the same device as the later next_state tensors.
        state = env.reset().to(device)
        done = False
        ep_return = 0
        steps = 0
        while not done and (max_steps is None or steps < max_steps):
            # The policy may create random actions on CPU; normalise the device
            # so the replay buffer never has to batch mixed-device tensors.
            action = policy(state, epsilon).to(device)
            next_state, reward, done, _ = env.step(action)
            next_state = next_state.to(device)
            reward = reward.to(device)
            done = done.to(device)
            memory.insert([state, action, reward, done, next_state])

            if memory.can_sample(batch_size):
                state_b, action_b, reward_b, done_b, next_state_b = memory.sample(batch_size)

                state_b = state_b.to(device)
                action_b = action_b.to(device)
                reward_b = reward_b.to(device)
                done_b = done_b.to(device)
                next_state_b = next_state_b.to(device)

                # Q(s, a) for the actions that were actually taken.
                qsa_b = q_network(state_b).gather(1, action_b)

                # The target is a constant w.r.t. the optimised parameters.
                with torch.no_grad():
                    next_q_b = target_q_network(next_state_b)
                    next_action_b = torch.argmax(next_q_b, dim = -1, keepdim = True)
                    # Bootstrap from the Q-VALUE of the next action, not its index.
                    next_qsa_b = next_q_b.gather(1, next_action_b)
                    # ~done_b zeroes the bootstrap term on terminal transitions.
                    target_b = reward_b + ~done_b * gamma * next_qsa_b

                loss = F.mse_loss(qsa_b, target_b)
                optim.zero_grad()
                loss.backward()
                optim.step()

                stats['MSE Loss'].append(loss.item())

            state = next_state
            ep_return += reward.item()
            steps += 1

        stats['Returns'].append(ep_return)
        # Refresh the target network every 100 episodes.
        if episode % 100 == 0:
            target_q_network.load_state_dict(q_network.state_dict())

    return stats

In [11]:
# Train for 100 episodes with the function's default hyper-parameters.
# With the target network synced every 100 episodes, it is refreshed only once, after the final episode.
stats = deep_sarsa(q_network, policy, episodes  = 100)

  0%|          | 0/100 [12:57<?, ?it/s]


KeyboardInterrupt: 