## Imports

In [1]:
from pettingzoo.atari import surround_v2
import supersuit

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np

from tqdm import tqdm
import pickle
import random

from collections import deque

import matplotlib.pyplot as plt
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

  from .autonotebook import tqdm as notebook_tqdm


## Making the environment

In [2]:
def plot(env, agent, action, idx):
    """Show the environment's current rendered frame with matplotlib.

    Args:
        env: environment whose ``render()`` result is passed to ``plt.imshow``.
        agent: currently unused (only referenced by the disabled title below).
        action: currently unused (only referenced by the disabled title below).
        idx: currently unused (only referenced by the disabled title below).
    """
    plt.imshow(env.render())
    plt.show()
    # Disabled title, kept for reference:
    # plt.title(f"{idx}, {agent}, {action}")

def get_env(render_mode="human"):
    """Build the Surround environment wrapped with the preprocessing pipeline.

    Wrappers applied, in order: max over the last 2 observations, frame skip
    of 4, and resize of each observation to 84x84.

    Args:
        render_mode: forwarded unchanged to ``surround_v2.env``. Defaults to
            ``"human"``, the value that used to be hard-coded, so existing
            ``get_env()`` calls behave exactly as before.
            NOTE(review): ``plot`` feeds ``env.render()`` to ``plt.imshow``,
            which presumably needs ``"rgb_array"`` here -- confirm against
            the PettingZoo docs.

    Returns:
        The wrapped environment.
    """
    env = surround_v2.env(render_mode=render_mode)
    env = supersuit.max_observation_v0(env, 2)
    # Disabled: env = supersuit.sticky_actions_v0(env, repeat_action_probability=0.25)
    env = supersuit.frame_skip_v0(env, 4)
    env = supersuit.resize_v1(env, 84, 84)
    # Disabled: env = supersuit.frame_stack_v1(env, 4)
    return env


# display = Display(visible=0, size=(400, 300))
# display.start()

## Agents

In [3]:
class Agent:
    """Base class for players; every hook is a no-op that subclasses override.

    Attributes set in ``__init__``:
        action_space: result of ``env.action_space(agent)``.
        observation_space: result of ``env.observation_space(agent)``.
        agent: the agent identifier passed in.
    """

    def __init__(self, env, agent):
        """Look up and store this agent's spaces and identifier."""
        self.action_space = env.action_space(agent)
        self.observation_space = env.observation_space(agent)
        self.agent = agent

    def get_action(self, obs):
        """Choose an action for ``obs``; the base class returns ``None``."""
        return None

    def remember(self, state, action, reward, next_state, done):
        """Store one transition; the base class stores nothing."""
        return None

    def replay(self, batch_size):
        """Learn from stored transitions; the base class does nothing."""
        return None

    def canReplay(self, batch_size):
        """Report whether ``replay`` can run; the base class never can."""
        return False

### RandomAgent

In [4]:
class RandomAgent(Agent):
    """Agent that ignores observations and samples actions from its action space."""

    def __init__(self, env, agent):
        """Delegate all setup to ``Agent.__init__``."""
        super().__init__(env, agent)

    def get_action(self, obs):
        """Return ``self.action_space.sample()``; ``obs`` is not used."""
        sampled = self.action_space.sample()
        return sampled

### Deep Reinforcement Learning Agent

#### - Model

In [5]:
class Model(nn.Module):
    """Small convolutional Q-network: RGB frame(s) in, one value per action out.

    ``dense1`` is sized for 84x84 inputs: the 3x3 and 2x2 convolutions
    (stride 1, no padding) shrink each side 84 -> 82 -> 81, and ``conv2``
    emits 3 channels, giving 81 * 81 * 3 flattened features per sample.
    """

    def __init__(self, action_space: int, observation_space):
        """Create the layers.

        Args:
            action_space: number of discrete actions (width of the output layer).
            observation_space: accepted for interface compatibility; not used,
                because the layer sizes are hard-coded for 3x84x84 inputs.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=2, stride=1)

        self.dense1 = nn.Linear(81 * 81 * 3, 1000)
        self.dense2 = nn.Linear(1000, 300)
        self.dense3 = nn.Linear(300, action_space)

    def forward(self, x):
        """Compute action values.

        Args:
            x: tensor of shape ``(3, 84, 84)`` or ``(B, 3, 84, 84)``; any
                floating dtype (it is cast to the dtype of the weights).

        Returns:
            A 1-D tensor of ``action_space`` values for a single sample
            (unbatched input or ``B == 1``, as before), or a
            ``(B, action_space)`` tensor for ``B > 1``.
        """
        # The weights are float32 by default; feeding e.g. float64 frames
        # would otherwise raise a dtype-mismatch error inside conv1.
        x = x.to(self.conv1.weight.dtype)
        if x.dim() == 3:
            x = x.unsqueeze(0)

        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        # Flatten per sample. A bare flatten() would merge the batch axis
        # into the features and break dense1 for any B > 1.
        x = x.flatten(start_dim=1)
        x = F.relu(self.dense1(x))
        x = F.relu(self.dense2(x))
        x = self.dense3(x)

        # Single-sample calls keep returning a 1-D vector so existing callers
        # (e.g. ``predict(obs).item()``) are unaffected.
        if x.shape[0] == 1:
            return x.squeeze(0)
        return x

    def predict(self, x):
        """Return the index of the highest-valued action.

        Returns a 0-dim tensor for a single sample, or a ``(B,)`` tensor
        for a batch. No autograd graph is built.
        """
        with torch.no_grad():
            values = self.forward(x)
        return torch.argmax(values, dim=-1)

#### - Agent

In [6]:
class AlebaGiogAgent(Agent):
    """Q-learning agent backed by ``Model`` and a bounded replay memory."""

    def __init__(self, env, agent):
        """Create the network, optimiser, replay memory and hyper-parameters.

        Args:
            env: environment providing ``action_space(agent)`` and
                ``observation_space(agent)``.
            agent: identifier of the agent this object controls.
        """
        super().__init__(env, agent)
        self.model = Model(action_space=env.action_space(agent).n,
                           observation_space=env.observation_space(agent))

        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

        # Bounded memory: the deque itself evicts the oldest transition once
        # 1000 entries are stored (previously done with a manual popleft).
        self.memory = deque(maxlen=1000)

        self.gamma = 0.9
        # Exploration settings. Nothing in this class reads them; the
        # training loop keeps its own epsilon.
        self.eps = 1
        self.max_eps = 1.0
        self.min_eps = 0.01
        self.decay_rate = 0.001

    def get_action(self, obs):
        """Return the greedy action (a Python int) for a raw observation."""
        obs = self.preprocess_obs(obs)
        return self.model.predict(obs).item()

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory (observations preprocessed)."""
        self.memory.append({"state": self.preprocess_obs(state),
                            "action": action,
                            "reward": reward,
                            "next_state": self.preprocess_obs(next_state),
                            "done": done})

    def preprocess_obs(self, obs):
        """Convert a raw frame to a float32 tensor of shape ``(1, 3, 84, 84)`` in [0, 1].

        float32 matches the default dtype of the network weights (float64
        input raises a dtype mismatch in the first convolution).
        """
        obs = torch.as_tensor(obs, dtype=torch.float32) / 255
        # Channel-last frames (H, W, 3) need their axes permuted; a plain
        # view() would reinterpret the memory and scramble the image.
        # NOTE(review): assumes the environment yields (84, 84, 3) frames --
        # confirm against the wrappers applied in get_env.
        if obs.dim() == 3 and obs.shape[-1] == 3:
            obs = obs.permute(2, 0, 1)
        return obs.reshape(-1, 3, 84, 84)

    def canReplay(self, batch_size):
        """Return True once at least ``batch_size`` transitions are stored."""
        return len(self.memory) >= batch_size

    def replay(self, batch_size=64):
        """Run one update on a random minibatch drawn from memory.

        Draws ``min(batch_size, len(memory))`` transitions without
        replacement; does nothing when the memory is empty.
        """
        sample_size = min(batch_size, len(self.memory))
        if sample_size == 0:
            return
        minibatch = random.sample(list(self.memory), sample_size)
        self.train(DataLoader(minibatch, batch_size=sample_size))

    def train(self, loader):
        """Take one optimiser step per batch yielded by ``loader``.

        Each batch is a dict of collated tensors with the keys written by
        ``remember``. For every transition the loss is
        ``MSE(Q(s, a), r + gamma * max_a' Q(s', a'))``; the bootstrap term
        is dropped when ``done`` is true.
        """
        for sars in loader:
            batch_len = len(sars["action"])
            self.optimizer.zero_grad()
            transitions = zip(sars["state"], sars["action"], sars["reward"],
                              sars["next_state"], sars["done"])
            # Samples go through the network one at a time, so this works
            # even if the model only accepts single-sample input.
            for state, action, reward, next_state, done in transitions:
                q_values = self.model(state).reshape(-1)
                # Only the value of the action actually taken is trained.
                q_taken = q_values[int(action)]

                target = torch.as_tensor(float(reward), dtype=q_taken.dtype)
                if not bool(done):
                    # The target is treated as a constant: no gradient
                    # flows through the bootstrap estimate.
                    with torch.no_grad():
                        target = target + self.gamma * self.model(next_state).max()

                # Dividing by the batch length makes the accumulated
                # gradient the mean over the minibatch.
                loss = self.criterion(q_taken, target) / batch_len
                loss.backward()
            self.optimizer.step()

## Training loop

In [7]:
# Build the environment once; reset() must run before env.agents is read.
env = get_env()
env.reset()

# Map each agent id to its controller: the first agent learns, the second
# plays random moves.
agents = {}
agents[env.agents[0]] = AlebaGiogAgent(env, env.agents[0])
agents[env.agents[1]] = RandomAgent(env, env.agents[1])

In [8]:
def train(env, episodes, batch_size = 64):
    """Play ``episodes`` games, letting every agent learn from its own transitions.

    Uses the module-level ``agents`` mapping (agent id -> controller).
    Epsilon decreases linearly by ``1 / episodes`` at the start of every
    episode; with probability epsilon the acting agent samples a random
    action instead of calling ``get_action``.

    Args:
        env: turn-based environment exposing ``reset``, ``agent_iter``,
            ``last`` and ``step``.
        episodes: number of episodes to play.
        batch_size: minimum memory size before ``replay`` is called, and the
            batch size passed to it.
    """
    epsilon = 1
    for _ in tqdm(range(episodes)):
        epsilon -= 1 / episodes
        env.reset()

        # agent id -> (state, action) of a move whose outcome is not yet known.
        # env.last() always describes the agent whose turn it is NOW, so
        # calling it right after env.step() would read the next agent's
        # data. An agent's own reward and next observation only become
        # available when its turn comes round again.
        pending = {}

        for agent_id in env.agent_iter():
            agent = agents[agent_id]

            # 1. Observe the outcome of this agent's previous move (if any).
            state, reward, terminated, truncated, _ = env.last()
            done = terminated or truncated

            # 2. Complete and store the pending transition, then learn.
            if agent_id in pending:
                prev_state, prev_action = pending.pop(agent_id)
                agent.remember(prev_state, prev_action, reward, state, done)
                if agent.canReplay(batch_size):
                    agent.replay(batch_size)

            # The episode stops at the first finished agent, as before.
            # Breaking (rather than stepping dead agents out) leaves
            # env.agents populated for code that runs after training.
            # Moves still pending for other agents are dropped.
            if done:
                break

            # 3. Epsilon-greedy move.
            if random.uniform(0, 1) > epsilon:
                action = agent.get_action(state)
            else:
                action = agent.action_space.sample()
            env.step(action)
            pending[agent_id] = (state, action)
            

# for agent_id in env.agent_iter():
#     agent = agents[agent_id]
#     observation, reward, done, _, _ = env.last()
#     if done:
#         break

#     action = agent.get_action(observation)
#     env.step(action)
    
    # if idx % 10 == 0:
    #     plot(env, agent, action, idx)
# plot(env, agent, action, idx)


In [None]:
# Train for 100 episodes with replay batches of 64, then release the environment.
train(env, episodes=100, batch_size=64)
env.close()
# Persist only the learning agent's network weights.
torch.save(obj=agents[env.agents[0]].model.state_dict(), f='modelTournament.pth')