In [26]:
from gymnasium import spaces

from environment.agent import Agent, UserInputAgent, run_real_time_match
from environment.environment import WarehouseBrawl

# Shared environment instance; the training cell below calls env.reset() / env.step() on it.
# NOTE(review): the "Obs space" / "Action space" lines in this cell's output presumably
# come from this constructor — confirm against environment.environment.
env = WarehouseBrawl()

from collections import namedtuple, deque
from itertools import count
from typing import Optional
import random
import math
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

import matplotlib
import matplotlib.pyplot as plt

# True when matplotlib is using an inline (notebook) backend; IPython's display
# helper is only imported in that case.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Switch matplotlib to interactive mode.
plt.ion()

# Every tensor and network in this notebook is placed on this device:
# CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Obs space [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, -1, -1] [1, 1, 1, 1, 1, 1, 1, 2, 12, 1, 1, 1, 1, 3, 11, 2, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 12, 1, 1, 1, 1, 3, 11, 2, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 1, 1]
Action space [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [8]:
# Create replay memory
# One recorded environment step: (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Fixed-capacity experience buffer holding ``Transition`` records.

    Once ``capacity`` entries are stored, appending a new transition drops the
    oldest one (behaviour of ``deque`` with ``maxlen``).
    """

    def __init__(self, capacity):
        # An empty bounded deque; eviction of old entries is handled by maxlen.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        stored = self.memory
        return random.sample(stored, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        stored = self.memory
        return len(stored)

In [9]:
# Design Deep Q-network
class DQN_MLP_PFA(nn.Module):
    """Fully connected Q-network: flattened observation history -> one Q-value per action.

    Architecture: ``Linear(input_dim, hidden_dim)`` + ReLU, followed by
    ``num_hidden_layers`` blocks of ``Linear(hidden_dim, hidden_dim)`` + ReLU,
    followed by ``Linear(hidden_dim, num_actions)`` with no final activation.

    The defaults reproduce the sizes that were previously hard-coded
    (6900 -> 1024 x 10 -> 32), and the sub-module names are unchanged, so
    ``DQN_MLP_PFA()`` is still compatible with previously saved state dicts.

    Args:
        input_dim: Length of the flattened input vector.
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of ``Linear + ReLU`` blocks after the input layer.
        num_actions: Number of discrete actions (size of the output vector).
    """

    def __init__(self, input_dim=6900, hidden_dim=1024, num_hidden_layers=10, num_actions=32):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        # Each block is its own nn.Sequential so parameter names stay
        # "hidden_layers.<i>.0.weight" / "hidden_layers.<i>.0.bias".
        self.hidden_layers = nn.Sequential(
            *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
              for _ in range(num_hidden_layers)]
        )
        self.output_layer = nn.Linear(hidden_dim, num_actions)

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for input ``(batch, input_dim)``."""
        x = F.relu(self.input_layer(x))
        x = self.hidden_layers(x)
        return self.output_layer(x)

# NOTE: the function that maps a DQN output index to an environment action vector
# (process_action_rep) is defined in the next cell.

In [16]:
# Define hyperparams and epsilon-greedy policy
BATCH_SIZE = 256   # transitions sampled from replay memory per optimize_model() call
GAMMA = 0.98       # discount applied to the target network's next-state value
EPS_START = 0.9    # exploration threshold when steps_done == 0
EPS_END = 0.05     # value the exploration threshold decays towards
EPS_DECAY = 5000   # decay constant: threshold uses exp(-steps_done / EPS_DECAY)
TAU = 0.001        # soft-update mixing factor for the target network
LR = 3e-4          # learning rate passed to Adam

# The target network starts as an exact copy of the policy network.
policy_net = DQN_MLP_PFA().to(device)
target_net = DQN_MLP_PFA().to(device)
target_net.load_state_dict(policy_net.state_dict())

# Only the policy network's parameters are optimized.
optimizer = optim.Adam(policy_net.parameters(), lr=LR, amsgrad=True)
# Replay buffer capped at 30000 transitions.
memory = ReplayMemory(30000)

# Global count of select_action_rep() calls; drives the epsilon decay.
steps_done = 0

def select_action_rep(state):
    """Choose an action index for ``state`` with an epsilon-greedy rule.

    The exploration threshold decays exponentially from ``EPS_START`` towards
    ``EPS_END`` as the global ``steps_done`` counter grows; the counter is
    incremented on every call.

    Args:
        state: Input tensor forwarded to ``policy_net`` when exploiting.

    Returns:
        A ``(1, 1)`` ``torch.long`` tensor holding an action index: either the
        index of the largest ``policy_net`` output along dim 1, or a uniformly
        random value in ``[0, 32)``.
    """
    global steps_done
    roll = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    # Explore: the random draw did not exceed the current threshold.
    if not roll > eps_threshold:
        return torch.tensor([[random.choice(range(32))]], device=device, dtype=torch.long)

    # Exploit: greedy action from the policy network, no gradient tracking.
    with torch.no_grad():
        q_values = policy_net(state)
        return q_values.max(1).indices.view(1, 1)


def process_action_rep(action_rep):
    """Decode a discrete action index (0..31) into the 10-element action vector.

    The index is treated as three independent fields:

    * ``action_rep % 4``  -> 0: none, 1/2/3: set slot 1/2/3
    * ``action_rep % 16`` -> 0-3: none, 4-7: set slot 7, 8-11: set slot 8,
      12-15: set slot 6
    * ``action_rep >= 16`` -> set slot 4

    Slots 0, 5 and 9 are always 0.

    Args:
        action_rep: Any object exposing ``.item()`` that yields the integer
            action index (e.g. a one-element tensor or a NumPy integer).

    Returns:
        ``np.ndarray`` of length 10 containing 0/1 values.
    """
    # 0, 1%4, 2%4, 3%4, >=16, 0, =(12-15, 28-31), =(4-7, 20-23), =(8-11, 24-27), 0
    action_rep = action_rep.item()
    low_field = action_rep % 4
    mid_field = action_rep % 16
    return np.array([0,
            low_field == 1,
            low_field == 2,
            low_field == 3,
            # Was ``> 16``, which made index 16 an exact duplicate of index 0 and
            # left "slot 4 only" unreachable; the upper half of the range is 16..31.
            action_rep >= 16,
            0,
            12 <= mid_field,
            4 <= mid_field <= 7,
            8 <= mid_field <= 11,
            0
            ])


def process_half_obs(half_obs):
    """Build the per-frame feature vector from one agent's raw observation.

    Selected raw entries are copied through and selected entries are one-hot
    encoded, then everything is concatenated in this order (``o = 32``):

    * ``[0:8]``, one-hot(``[8]``, 13), ``[9:13]``, one-hot(``[14]``, 13)
    * ``[o:o+8]``, one-hot(``[o+8]``, 13), ``[o+9:o+13]``, one-hot(``[o+14]``, 13),
      one-hot(``[o+15]``, 3)
    * ``[28:32]``

    For a 1-D observation long enough to cover those indices the result has
    8+13+4+13 + 8+13+4+13+3 + 4 = 83 elements.
    NOTE(review): the first 32 entries presumably describe the agent itself and
    the entries from ``offset`` on the opponent — confirm against the environment.

    Args:
        half_obs: Array-like observation for one agent.

    Returns:
        1-D tensor on ``device``.

    Raises:
        Exception: Whatever the slicing / one-hot / concatenation raised; the
            offending observation is printed first to help debugging.
    """
    half_obs = torch.tensor(half_obs, device=device)
    offset = 32

    def one_hot_at(index, num_classes):
        # F.one_hot requires integer class indices.
        return F.one_hot(half_obs[index].to(torch.int64), num_classes=num_classes)

    try:
        return torch.cat((half_obs[:8],
                   one_hot_at(8, 13),
                   half_obs[9:13],
                   one_hot_at(14, 13),
                   half_obs[offset:offset+8],
                   one_hot_at(offset+8, 13),
                   half_obs[offset+9:offset+13],
                   one_hot_at(offset+14, 13),
                   one_hot_at(offset+15, 3),
                   half_obs[28:32]
                   ))
    except Exception:
        # Previously ``raise "Error"``: raising a str is itself a TypeError in
        # Python 3 and discarded the real failure. Re-raise the original instead.
        print(half_obs)
        raise

In [17]:
# Define training step
# Batch sample -> run policy -> compute TD error -> optimize
def optimize_model():
    """Run one DQN optimization step on a batch sampled from replay memory.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    Otherwise it samples a batch, computes Q(s, a) with ``policy_net``, builds
    the TD target ``reward + GAMMA * max_a' Q_target(s', a')`` (the next-state
    value is 0 for transitions whose ``next_state`` is ``None``), and minimizes
    the Huber loss between the two with ``optimizer``.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose: list of Transitions -> one Transition of tuples.
    batch = Transition(*zip(*transitions))

    non_final_next = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    # Cast defensively so integer-typed stored states cannot reach nn.Linear.
    state_batch = torch.cat(batch.state).to(torch.float)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q-value of the action that was actually taken in each stored state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat([]) raises, so skip the target-network pass when every sampled
    # transition is terminal; the targets are then just the rewards.
    if non_final_next:
        non_final_next_states = torch.cat(non_final_next).to(torch.float)
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

In [37]:
# Generate episodes
'''
Env is WarehouseBrawl
- opponent
- reward must be tuned
'''

num_episodes = 50

for i_episode in range(num_episodes):
    player_obs_list = []
    opponent_obs_list = []
    obs, info = env.reset()
    # Seed the 60-frame history with copies of the first frame, each paired
    # with a one-hot of action index 0.
    player_obs_list += [torch.cat((process_half_obs(obs[0]), F.one_hot(torch.tensor(0, device=device), num_classes=32)))] * 60
    # NOTE(review): the opponent history is built but not read anywhere in this cell.
    opponent_obs_list += [torch.cat((process_half_obs(obs[1]), F.one_hot(torch.tensor(0, device=device), num_classes=32)))] * 60

    # Network input: the 60 frames flattened into a single row.
    player_state = torch.cat(player_obs_list).unsqueeze(0).to(torch.float32)

    for t in count():
        player_action_rep = select_action_rep(player_state)
        player_action = process_action_rep(player_action_rep)

        # The opponent always plays action index 0.
        opponent_action_rep = torch.tensor(0)
        opponent_action = process_action_rep(opponent_action_rep)

        full_action = {
            0: player_action,
            1: opponent_action
        }

        observation, reward, terminated, truncated, _ = env.step(full_action)
        if reward[0] > 0 or reward[1] > 0:
            print(reward)
        # process reward... (env returns reward for both agents so player reward is reward[0])

        reward = torch.tensor([reward[0]], device=device, dtype=torch.float32)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            # Slide the window: drop the oldest frame, append the new frame
            # tagged with the action that led to it.
            player_obs_list = player_obs_list[1:] + [
                torch.cat((
                    process_half_obs(observation[0]),
                    F.one_hot(player_action_rep.squeeze(), num_classes=32)
                ))
            ]
            # Same dtype as the initial state so stored states can be batched
            # and fed to the network directly.
            next_state = torch.cat(player_obs_list).unsqueeze(0).to(torch.float32)

        # Store the transition in memory
        memory.push(player_state, player_action_rep, next_state, reward)

        # Move to the next state.
        # This used to assign to an unused name (`state`), so `player_state`
        # stayed frozen at the episode's first frame: every action was chosen
        # from, and every transition stored with, that same initial state.
        player_state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            print(f"{i_episode=}, {t=}")
            break

print('Complete')

{0: 0.011944444444444535, 1: 0.0}
{0: 0.011944444444444535, 1: 0.0}
{0: 0.011944444444444535, 1: 0.0}
{0: 0.011944444444444535, 1: 0.0}
{0: 0.011944444444444535, 1: 0.0}
{0: 0.011944444444444535, 1: 0.0}
{0: 0.011944444444444535, 1: 0.0}
{0: 0.12666666666666693, 1: 0.0}
{0: 0.12666666666666693, 1: 0.0}
{0: 0.12666666666666693, 1: 0.0}
{0: 0.12666666666666693, 1: 0.0}
{0: 0.12666666666666693, 1: 0.0}
{0: 0.12666666666666693, 1: 0.0}
{0: 0.12666666666666693, 1: 0.0}
{0: 0.12666666666666648, 1: 0.0}
{0: 0.12666666666666648, 1: 0.0}
{0: 1.0052447763087002e-10, 1: 0.0}
{0: 1.0482770207431713e-10, 1: 0.0}
{0: 1.0913137060697409e-10, 1: 0.0}
{0: 1.1343503913963104e-10, 1: 0.0}
{0: 1.17738707672288e-10, 1: 0.0}
{0: 1.2204237620494496e-10, 1: 0.0}
{0: 1.2634604473760191e-10, 1: 0.0}
{0: 1.3495293771370598e-10, 1: 0.0}
{0: 1.3925660624636294e-10, 1: 0.0}
{0: 0.22499999999999964, 1: 0.0}
{0: 0.22499999999999964, 1: 0.0}
{0: 0.22499999999999964, 1: 0.0}
{0: 0.22499999999999964, 1: 0.0}
{0: 0.22499

In [38]:
# Persist only the policy network's parameters (state dict), not the module object;
# this is the file later passed to SubmittedAgent(file_path=...).
torch.save(policy_net.state_dict(), 'DQN_MLP_PFA2.pth')

In [39]:
class SubmittedAgent(Agent):
    """Greedy agent driven by a ``DQN_MLP_PFA`` network.

    Keeps a sliding window of the last 60 processed frames. Each frame is the
    output of ``process_half_obs`` concatenated with a 32-way one-hot of the
    action chosen on the previous ``predict`` call, mirroring how frames are
    built in the training loop.
    """

    def __init__(self, file_path: Optional[str] = None):
        """Args:
            file_path: Optional path to a saved ``DQN_MLP_PFA`` state dict.
        """
        super().__init__(file_path)
        self.past_obs = []          # sliding window of per-frame feature tensors
        self.last_action_rep = 0    # action index returned by the previous predict()

    def _initialize(self) -> None:
        """Create the network and, if a checkpoint path was given, load its weights."""
        self.model = DQN_MLP_PFA().to(device)
        if self.file_path is not None:
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            self.model.load_state_dict(torch.load(self.file_path, map_location=device))

    def predict(self, observation):
        """Return the environment action vector for ``observation``.

        Args:
            observation: Raw observation for this agent, as accepted by
                ``process_half_obs``.

        Returns:
            The 10-element array produced by ``process_action_rep`` for the
            action with the highest predicted Q-value.
        """
        if len(self.past_obs) == 0:
            # First call: fill the whole window with the first frame / action 0.
            self.past_obs = [torch.cat((process_half_obs(observation), F.one_hot(torch.tensor(0, device=device), num_classes=32)))] * 60
        else:
            self.past_obs = self.past_obs[1:] + [torch.cat((process_half_obs(observation), F.one_hot(torch.tensor(self.last_action_rep, device=device), num_classes=32)))]

        state = torch.cat(self.past_obs).unsqueeze(0).to(torch.float32)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            action_rep = self.model(state).max(1).indices.view(1, 1)
        # Remember the choice so the next frame carries the real previous action;
        # this was never updated before, so the history always encoded action 0.
        self.last_action_rep = action_rep.item()
        return process_action_rep(action_rep)



In [40]:
# Agent backed by the checkpoint written by the torch.save cell above.
agent_4 = SubmittedAgent(file_path="DQN_MLP_PFA2.pth")
# NOTE(review): presumably a human-controlled opponent — UserInputAgent is defined in environment.agent.
agent_1 = UserInputAgent()
# NOTE(review): max_timesteps is assigned but never passed to run_real_time_match below —
# confirm whether the call was meant to receive it.
max_timesteps = 30*90
run_real_time_match(agent_1, agent_4)

Failed to initialize pygame mixer
Obs space [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, 0, -1, -1, -1, -1] [1, 1, 1, 1, 1, 1, 1, 2, 12, 1, 1, 1, 1, 3, 11, 2, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 12, 1, 1, 1, 1, 3, 11, 2, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 3, 1, 1, 1, 1]
Action space [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


  self.model.load_state_dict(torch.load(self.file_path))


Ground is rendered
Ground is rendered
Stage is rendered


100%|██████████| 16/16 [00:00<00:00, 34.79it/s]
100%|██████████| 3/3 [00:00<00:00, 18.05it/s]
100%|██████████| 16/16 [00:00<00:00, 31.90it/s]
100%|██████████| 3/3 [00:00<00:00, 15.05it/s]


MatchStats(match_time=11.7, player1=PlayerStats(damage_taken=0, damage_done=0, lives_left=3), player2=PlayerStats(damage_taken=0, damage_done=0, lives_left=0), player1_result=<Result.WIN: 1>)