In [1]:
from magent2.environments import battle_v4, adversarial_pursuit_v4, tiger_deer_v4
import sys, os
from pettingzoo.utils import random_demo
import pygame
import torch
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from dqn_basic import DQN_Basic
from dqn_lstm import DQN_Basic_LSTM
from collections import namedtuple, deque
import random
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Run configuration: compute device, hyperparameters and the Tiger-Deer
# environment shared by the cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# NOTE(review): ent_coef, vf_coef, clip_coef, gamma, stack_size and frame_size
# are not read by any code visible in this notebook (the optimisation cell
# uses the upper-case GAMMA instead) - confirm whether they are still needed.
ent_coef = 0.1
vf_coef = 0.1
clip_coef = 0.1
gamma = 0.99
# Minibatch size read by optimize_model (the upper-case BATCH_SIZE defined in a
# later cell is a separate, unused value).
batch_size = 32
stack_size = 4
frame_size = (128, 128)
# Forwarded to the environment constructor as max_cycles.
max_cycles = 100
# Number of episodes the rollout loop runs.
total_episodes = 100
# Forwarded to the environment constructor as map_size.
map_size = 60

# NOTE(review): render_mode='human' presumably opens a live window for every
# step, which slows long runs - confirm this is intended outside of debugging.
env = tiger_deer_v4.env(map_size=map_size, minimap_mode=False, render_mode='human', tiger_step_recover=-0.1, deer_attacked=-0.1, max_cycles=max_cycles, extra_features=False)

# random_demo(env, render=False , episodes=1)
# 'rgb_array'
# pygame.quit()

cuda


In [3]:
class Tiger_DQN_Basic(nn.Module):
    """Three-layer fully connected Q-network for tiger agents.

    Args:
        n_observations: Size of the last dimension of the input tensor.
        n_actions: Number of discrete actions, i.e. the width of the output.

    Attributes:
        n_observations: The input width passed to the constructor.
        n_actions: The output width passed to the constructor.
    """

    def __init__(self, n_observations, n_actions):
        super(Tiger_DQN_Basic, self).__init__()
        # The policy helpers in this notebook read `model.n_actions` when they
        # sample a random action, so the sizes are kept on the instance.
        self.n_observations = n_observations
        self.n_actions = n_actions
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return one Q-value per action for each input row.

        Works for a single element (action selection) or a batch
        (optimisation); the last dimension of `x` must be `n_observations`
        and the last dimension of the result is `n_actions`.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [4]:
class Deer_DQN_Basic(nn.Module):
    """Three-layer fully connected Q-network for deer agents.

    Args:
        n_observations: Size of the last dimension of the input tensor.
        n_actions: Number of discrete actions, i.e. the width of the output.

    Attributes:
        n_observations: The input width passed to the constructor.
        n_actions: The output width passed to the constructor.
    """

    def __init__(self, n_observations, n_actions):
        super(Deer_DQN_Basic, self).__init__()
        # The policy helpers in this notebook read `model.n_actions` when they
        # sample a random action, so the sizes are kept on the instance.
        self.n_observations = n_observations
        self.n_actions = n_actions
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return one Q-value per action for each input row.

        Works for a single element (action selection) or a batch
        (optimisation); the last dimension of `x` must be `n_observations`
        and the last dimension of the result is `n_actions`.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [5]:
def deer_policy(model, observation, prev_observation, no_previous=True):
    """Choose a deer action from the current and previous observation.

    Args:
        model: Q-network. Must expose `n_actions` and be callable on the
            stacked observation tensor.
        observation: Current observation tensor.
        prev_observation: Observation from the agent's previous turn, or None
            when there is none yet.
        no_previous: True when no previous observation is available.

    Returns:
        A tensor holding the chosen action index: shape (1,) for the random
        fallback, 0-dim for the greedy choice.

    Action indices: 0 up, 1 left, 2 stay still, 3 right, 4 down.
    """
    # Without a previous frame there is nothing to stack, so act randomly.
    # The explicit None check keeps torch.cat from failing if a caller passes
    # no_previous=False before any frame has been stored.
    if no_previous or prev_observation is None:
        return torch.randint(0, model.n_actions, (1,))

    stacked_observations = torch.cat((observation, prev_observation), 1)

    # Action selection never needs gradients; disabling autograd avoids
    # building a graph on every environment step.
    with torch.no_grad():
        q_values = model(stacked_observations)

    # Greedy choice: index of the largest Q-value over the flattened output.
    return torch.argmax(q_values)

In [6]:
def tiger_policy(model, observation, prev_observation, no_previous=True):
    """Choose a tiger action from the current and previous observation.

    Args:
        model: Q-network. Must expose `n_actions` and be callable on the
            stacked observation tensor.
        observation: Current observation tensor.
        prev_observation: Observation from the agent's previous turn, or None
            when there is none yet.
        no_previous: True when no previous observation is available.

    Returns:
        A tensor holding the chosen action index: shape (1,) for the random
        fallback, 0-dim for the greedy choice.

    Action indices: 0 up, 1 left, 2 stay still, 3 right, 4 down,
    5 attack up, 6 attack left, 7 attack right, 8 attack down.
    """
    # Without a previous frame there is nothing to stack, so act randomly.
    # The explicit None check keeps torch.cat from failing if a caller passes
    # no_previous=False before any frame has been stored.
    if no_previous or prev_observation is None:
        return torch.randint(0, model.n_actions, (1,))

    stacked_observations = torch.cat((observation, prev_observation), 1)

    # Action selection never needs gradients; disabling autograd avoids
    # building a graph on every environment step.
    with torch.no_grad():
        q_values = model(stacked_observations)

    # Greedy choice: index of the largest Q-value over the flattened output.
    return torch.argmax(q_values)

In [7]:
# DQN hyperparameters and a standalone policy/target network pair.
# NOTE(review): of these constants only GAMMA (read by optimize_model) and LR
# (read below) are used by code visible in this notebook; BATCH_SIZE, EPS_*
# and TAU are never read here, and optimize_model uses the lower-case
# batch_size from the configuration cell instead - confirm which is intended.
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
# n_actions = env.action_space.n
# Hard-coded instead; 5 matches the action count the deer model is built with
# in the rollout cell.
n_actions = 5
# Get the number of state observations
# state, info = env.reset()
# Hard-coded instead; (3,3,5) matches the shape the deer model is built with in
# the rollout cell.
n_observations = (3,3,5)


# NOTE(review): this instance is constructed and immediately discarded -
# presumably a leftover smoke test of the constructor; confirm and remove.
DQN_Basic((3,3,5), 5)


# NOTE(review): policy_net / target_net are sized like the deer model, but the
# rollout cell builds its own DQN_Basic_LSTM models and never uses these two.
policy_net = DQN_Basic(n_observations, n_actions).to(device) # online network the optimizer below updates
target_net = DQN_Basic(n_observations, n_actions).to(device) # starts as a copy of policy_net (weights loaded on the next line)
target_net.load_state_dict(policy_net.state_dict())

# This optimizer only tracks policy_net's parameters; stepping it does not
# update any other model.
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
# memory = ReplayMemory(10000) # is this line needed?
# NOTE(review): ReplayMemory is not defined or imported in the visible cells.
# Transition is not referenced by any other visible cell.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))



In [8]:
def optimize_model(episode_data, model, optimizer=None):
    """Run one DQN-style gradient step on `model` from a random minibatch.

    Args:
        episode_data: Sequence of records laid out as
            ``[agent, prev_observation, observation, reward, done, info, action]``.
            Only the observation, reward, done and action fields are used;
            agent names and info dicts are not numeric and are ignored.
        model: Q-network to update. It is called on the batched observations
            and must return one row of Q-values per record.
        optimizer: Optional optimizer over ``model.parameters()``. When
            omitted, an AdamW optimizer (lr=LR, amsgrad=True) is created for
            `model` on first use and reused on later calls.

    Returns:
        None. Returns without updating when fewer than ``batch_size`` records
        are available.
    """
    import weakref

    if len(episode_data) < batch_size:
        return

    if optimizer is None:
        # The notebook-level `optimizer` only tracks policy_net's parameters,
        # so stepping it would never change `model`. Keep one optimizer per
        # model instead; weak keys let discarded models be garbage-collected.
        cache = optimize_model.__dict__.setdefault(
            '_optimizers', weakref.WeakKeyDictionary())
        optimizer = cache.get(model)
        if optimizer is None:
            optimizer = optim.AdamW(model.parameters(), lr=LR, amsgrad=True)
            cache[model] = optimizer

    # Build the batch on the device the model actually lives on; the models in
    # this notebook are not necessarily moved to the global `device`.
    first_param = next(iter(model.parameters()), None)
    model_device = device if first_param is None else first_param.device

    batch = random.sample(episode_data, batch_size)
    # [agent, prev_observation, observation, reward, done, info, action]
    _, prev_obs, obs, rewards, dones, _, actions = zip(*batch)

    def _stack_frames(frames):
        # One leading batch axis over the per-record observation tensors.
        stacked = torch.stack(
            [torch.as_tensor(frame, dtype=torch.float32) for frame in frames])
        return stacked.to(model_device)

    prev_obs_batch = _stack_frames(prev_obs)
    obs_batch = _stack_frames(obs)
    reward_batch = torch.tensor(
        [float(r) for r in rewards], dtype=torch.float32, device=model_device)
    not_done = ~torch.tensor(
        [bool(d) for d in dones], dtype=torch.bool, device=model_device)
    # gather() requires an int64 index tensor.
    action_batch = torch.tensor(
        [int(a) for a in actions], dtype=torch.int64, device=model_device)

    # Q(s, a): Q-values for every action, then pick the column of the action
    # that was actually taken in each record.
    state_action_values = model(obs_batch).gather(1, action_batch.unsqueeze(1))

    # Bootstrap value: max Q over actions, or 0 where the record is terminal.
    # NOTE(review): this bootstraps from prev_observation, i.e. the frame that
    # precedes `observation` in time, exactly as the original code did. A
    # conventional TD target needs the frame that follows the action - confirm
    # and adjust what the rollout loop records before enabling training.
    next_state_values = torch.zeros(batch_size, device=model_device)
    if not_done.any():
        with torch.no_grad():
            next_state_values[not_done] = model(prev_obs_batch[not_done]).max(1)[0]

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1]; parameters that received no
    # gradient are skipped instead of raising.
    torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)
    optimizer.step()

In [9]:
def generate_plots(episode_data):
    """Placeholder for per-episode plots; currently does nothing.

    Planned but not yet implemented: number of deer, number of tigers and
    (possibly) average agent lifetime.
    """
    return None

In [10]:
def mask_observations(agent, agents_prev_data, observation):
    """Zero out part of an agent's view, oriented by its previous action.

    The base mask keeps the top two thirds of the grid rows and hides the
    bottom third (9x9 for tigers, 3x3 for deer). It is rotated by a number of
    quarter turns looked up from the agent's previous action, then applied to
    every channel.

    Args:
        agent: Agent name; names containing 'tiger' use the tiger mask, all
            others use the deer mask.
        agents_prev_data: Mapping from agent name to its previous-step record;
            index 5 of the record is the previous action, or None.
        observation: Rank-4 tensor indexed as (batch, row, column, channel).

    Returns:
        A new tensor with the same shape and dtype as `observation`. The input
        tensor is left unmodified.
    """
    if 'tiger' in agent:
        side = 9
        quarter_turns_by_action = [0, 1, 0, 3, 2, 0, 1, 3, 2]
    else:
        side = 3
        quarter_turns_by_action = [0, 1, 0, 3, 2]

    # Visible rows are 1, the bottom third of the rows is 0. Built with the
    # observation's dtype and device so the product needs no conversion.
    mask = torch.ones(side, side, dtype=observation.dtype,
                      device=observation.device)
    mask[side - side // 3:, :] = 0

    prev_action = agents_prev_data[agent][5]
    quarter_turns = 0 if prev_action is None else quarter_turns_by_action[prev_action]
    mask = torch.rot90(mask, quarter_turns, [0, 1])

    # (rows, cols) -> (rows, cols, 1) so the mask broadcasts over the batch
    # and channel axes. Multiplying out of place avoids overwriting the
    # caller's observation.
    return observation * mask.unsqueeze(-1)

In [11]:
def get_agent_type_from_agent_name(agent):
    """Map an agent name to its species key.

    Names containing the substring 'tiger' map to 'tiger'; every other name
    maps to 'deer'.
    """
    return 'tiger' if 'tiger' in agent else 'deer'

In [12]:
# Roll out `total_episodes` episodes, choosing actions with one model per
# species and recording per-step data for later optimisation.
tiger_model = DQN_Basic_LSTM((9,9,5), 9)
deer_model = DQN_Basic_LSTM((3,3,5), 5)
# tiger_model.to(device=device)
# deer_model.to(device=device)
# Per-agent record of the previous turn:
# [observation, reward, termination, truncation, info, action]
agents_prev_data = {}
try:
    for episodes in range(total_episodes):
        env.reset(seed=None)
        # Start every episode with no per-agent history. Otherwise the first
        # step of an episode would be stacked and recorded together with the
        # last frame of the previous episode.
        agents_prev_data.clear()
        episode_data = {'deer': [], 'tiger': []}
        for agent in env.agent_iter():

            observation, reward, termination, truncation, info = env.last()
            # Add a leading batch axis of size 1.
            observation = torch.unsqueeze(torch.from_numpy(observation), 0)

            # First turn of this agent in this episode: nothing to stack yet.
            no_previous = agent not in agents_prev_data
            if no_previous:
                agents_prev_data[agent] = [None] * 6

            prev_observation = agents_prev_data[agent][0]

            # set agent type to tiger or deer based on agentName
            agentType = get_agent_type_from_agent_name(agent)
            # print(observation[:,:,:,0])
            # print(mask_observations(agent,agents_prev_data,observation)[:,:,:,0])
            # input()
            # catch if agent is dead
            done = termination or truncation

            if done:
                # A terminated or truncated agent is stepped with None.
                action = None
            else:
                if agentType == 'tiger':
                    action = tiger_policy(tiger_model, observation, prev_observation, no_previous=no_previous)
                else:
                    action = deer_policy(deer_model, observation, prev_observation, no_previous=no_previous)
                if not no_previous:
                    episode_data[agentType].append([agent, prev_observation, observation, reward, done, info, action])

            # The environment is stepped with a plain Python int, not a tensor.
            if isinstance(action, torch.Tensor):
                action = action.item()
            # previous recorded data
            agents_prev_data[agent] = [observation, reward, termination, truncation, info, action]
            env.step(action)

        #########################
        # plots for data
        # generate_plots(episode_data)
        #########################
        # optimize_model(episode_data['deer'], deer_model)
        # optimize_model(episode_data['tiger'], tiger_model)
finally:
    # Release the environment even when the run is interrupted or a step
    # raises.
    env.close()

(9, 9, 5)
(3, 3, 5)
tensor([2])
tensor([6])
tensor([4])
tensor([7])
tensor([7])
tensor([4])
tensor([8])
tensor([5])
tensor([7])
tensor([1])
tensor([6])
tensor([8])
tensor([5])
tensor([8])
tensor([7])
tensor([4])
tensor([0])
tensor([8])
tensor([7])
tensor([0])
tensor([2])
tensor([3])
tensor([2])
tensor([0])
tensor([5])
tensor([0])
tensor([0])
tensor([8])
tensor([6])
tensor([3])
tensor([4])
tensor([0])
tensor([4])
tensor([7])
tensor([1])
tensor([5])
tensor(0)
tensor(0)
tensor(1)
tensor(3)
tensor(6)
tensor(1)
tensor(1)
tensor(8)
tensor(1)
tensor(8)
tensor(0)
tensor(8)
tensor(8)
tensor(8)
tensor(8)
tensor(1)
tensor(8)
tensor(0)
tensor(5)
tensor(8)
tensor(3)
tensor(8)
tensor(1)
tensor(1)
tensor(8)
tensor(0)
tensor(1)
tensor(3)
tensor(7)
tensor(8)
tensor(3)
tensor(0)
tensor(6)
tensor(7)
tensor(8)
tensor(8)
tensor(8)
tensor(0)
tensor(1)
tensor(3)
tensor(8)
tensor(8)
tensor(1)
tensor(1)
tensor(5)
tensor(0)
tensor(1)
tensor(8)
tensor(8)
tensor(5)
tensor(8)
tensor(1)
tensor(8)
tensor(5)
tensor(5

KeyboardInterrupt: 

In [None]:
# Inspect the environment's action-space and agent metadata.
# NOTE(review): the rollout cell above calls env.close() before this cell runs
# in a top-to-bottom execution - confirm these attributes are still valid then.
# NOTE(review): in PettingZoo's AEC API action_space is presumably a method
# that takes an agent name, so this prints the bound method rather than a
# space - confirm, and call it with an agent if the space itself is wanted.
print(env.action_space)
print(type(env))
print(env.num_agents)
print(env.possible_agents)
print(env.action_spaces)