In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import random
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pettingzoo.mpe import simple_tag_v2
from tqdm import tqdm

In [2]:
env = simple_tag_v2.env(
            num_good=1,
            num_adversaries=3,
            num_obstacles=2,
            max_cycles=10000,
            continuous_actions=False,
            # render_mode='human'
        )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Let's define an experience replay memory that can be used to store new transitions and sample mini-batches of previous transitions. 

In [4]:
class ReplayMemory(object):
    """Experience Replay Memory"""
    
    def __init__(self, capacity):
        #self.size = size
        self.memory = deque(maxlen=capacity)
    
    def add(self, *args):
        """Add experience to memory."""
        self.memory.append([*args])
    
    def sample(self, batch_size):
        """Sample batch of experiences from memory with replacement."""
        return random.sample(self.memory, batch_size)
    
    def count(self):
        return len(self.memory)

The Q-network is very similar to the one we have seen previously, but we add the possibility to update the parameters, so the same class can also be used as a target network.  

In [5]:
class DQN_prey(nn.Module):
    """Deep Q-network with target network"""
    
    def __init__(self, n_inputs, n_outputs, learning_rate):
        super(DQN_prey, self).__init__()
        # network
        self.out = nn.Linear(n_inputs, n_outputs)
        # training
        self.optimizer = optim.SGD(self.parameters(), lr=learning_rate)
    
    def forward(self, x):
        # x = x.to(device)
        return self.out(x)
    
    def loss(self, q_outputs, q_targets):
        return torch.sum(torch.pow(q_targets - q_outputs, 2))
    
    def update_params(self, new_params, tau):
        params = self.state_dict()
        for k in params.keys():
            params[k] = (1-tau) * params[k] + tau * new_params[k]
        self.load_state_dict(params)

Before training, we create a policy network and copy its weight parameters to a target network, so they are initially the same. 
We also set up a replay memory and prefill it with random transitions sampled from the environment. 

In [48]:
num_episodes = 300 # training loops
episode_limit = 200 # 
batch_size = 64
learning_rate = 0.01
gamma = 0.99 # discount rate
tau = 0.01 # target network update rate
replay_memory_capacity = 5000
prefill_memory = True
val_freq = 100 # validation frequency

# n_inputs = env.observation_space.n
# n_outputs = env.action_space.n
n_inputs = 14
n_outputs = 5

# initialize DQN and replay memory
policy_dqn = DQN_prey(n_inputs, n_outputs, learning_rate)
target_dqn = DQN_prey(n_inputs, n_outputs, learning_rate)
# .to(device)
target_dqn.load_state_dict(policy_dqn.state_dict())

replay_memory = ReplayMemory(replay_memory_capacity)

# prefill replay memory with random actionss
env.reset()
if prefill_memory:
    print('prefill replay memory ...')
    
    s = None
    s1 = None
    r = None

    count = 1
    env.reset()
    for agent in env.agent_iter():
        if count == 4:
            s = env.last()[0]
        
        if replay_memory.count() >= replay_memory_capacity:
            break

        s1, reward, termination, truncation, _ = env.last()
        action = env.action_space(agent).sample() 
        env.step(action)

        if count % 4 == 0:
            replay_memory.add(s, action, reward, s1, termination)
            # print(reward)

        if not (termination or truncation):
            if count % 4 == 0:
                s = s1
        else:
            env.reset()
            count = 1

        count = count + 1

    env.close()

    print('prefill replay memory done')

# test 1
# print(one_hot([0,1,2], 14))
# policy_dqn(torch.from_numpy(one_hot([0, 1, 2] ,14)).float())

prefill replay memory ...
prefill replay memory done


In [49]:
# training loop
env.reset()

try:
    print('start training')
    epsilon = 0.5
    rewards, lengths, losses, epsilons = [], [], [], []

    for i in tqdm(range(num_episodes)):
        # init new episode
        for agent in env.agent_iter():
            if agent != 'agent_0':
                # try network only for the agent, sample action for adversary
                env.step(env.action_space(agent).sample())
            else:     
                s, ep_reward, ep_loss = env.last()[0], 0, 0

                for j in range(episode_limit):
                    if env.last()[2] or env.last()[3]: 
                        env.reset()
                        break

                    if (j + 1) % 4 == 0:
                        if np.random.rand() < epsilon:
                            a = env.action_space(agent).sample()
                        else:
                            with torch.no_grad():
                                a = policy_dqn(torch.from_numpy(s).float()).argmax().item()
                        # perform action
                        env.step(a)
                    else:
                        # print(j)
                        env.step(0)
                        continue
                    # if env.last()[0].shape[0] == 14:

                    s1, reward, termination, truncation, _ = env.last()

                    # store experience in replay memory
                    replay_memory.add(s, a, reward, s1, termination)
                    
                    # batch update
                    if replay_memory.count() >= batch_size:
                        # sample batch from replay memory, this is used as to predict the values in the q-table
                        # frozen lake do one hot encoding for states, we directly put the 14-sized vector into the network
                        batch = replay_memory.sample(batch_size)
                        # ss, aa, rr, ss1, dd = batch[:,0], batch[:,1], batch[:,2], batch[:,3], batch[:,4]
                        ss = np.array([list(memory[0]) for memory in batch])
                        aa = np.array([memory[1] for memory in batch])
                        rr = np.array([memory[2] for memory in batch])
                        ss1 = np.array([list(memory[3]) for memory in batch])
                        dd = np.array([memory[4] for memory in batch])

                        # do forward pass of batch
                        policy_dqn.optimizer.zero_grad()
                        Q = policy_dqn(torch.from_numpy(ss).float())

                        # use target network to compute target Q-values
                        with torch.no_grad():
                            Q1 = target_dqn(torch.from_numpy(ss1).float())
                        # compute target for each sampled experience
                        q_targets = Q.clone()
                        for k in range(batch_size):
                            q_targets[k, aa[k]] = rr[k] + gamma * Q1[k].max().item() * (not dd[k])
                        
                        # update network weights
                        loss = policy_dqn.loss(Q, q_targets)
                        loss.backward()
                        policy_dqn.optimizer.step()
                        # update target network parameters from policy network parameters
                        target_dqn.update_params(policy_dqn.state_dict(), tau)

                    else:
                        loss = 0
                    
                    # bookkeeping
                    s = s1
                    ep_reward += reward
                    ep_loss += loss.item()

                    if termination or truncation: 
                        env.reset()
                        break
                
                # bookkeeping
                epsilon *= num_episodes/(i/(num_episodes/20)+num_episodes) # decrease epsilon
                epsilons.append(epsilon); rewards.append(ep_reward); lengths.append(j+1); losses.append(ep_loss)

                if (i+1) % val_freq == 0: print('%5d mean training reward: %5.2f' % (i+1, np.mean(rewards[-val_freq:])))
            break
    print('done')
except KeyboardInterrupt:
    print('interrupt')

start training


 34%|███▎      | 101/300 [00:13<00:30,  6.63it/s]

  100 mean training reward: -36.37


 67%|██████▋   | 201/300 [00:29<00:15,  6.21it/s]

  200 mean training reward: -36.37


100%|██████████| 300/300 [00:42<00:00,  6.99it/s]

  300 mean training reward: -52.42
done





In [None]:
# not working yet
# plot results
def moving_average(a, n=10) :
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret / n

plt.figure(figsize=(16, 9))
plt.subplot(411)
plt.title('training rewards')
plt.plot(range(1, num_episodes+1), rewards)
plt.plot(moving_average(rewards))
plt.xlim([0, num_episodes])
plt.subplot(412)
plt.title('training lengths')
plt.plot(range(1, num_episodes+1), lengths)
plt.plot(range(1, num_episodes+1), moving_average(lengths))
plt.xlim([0, num_episodes])
plt.subplot(413)
plt.title('training loss')
plt.plot(range(1, num_episodes+1), losses)
plt.plot(range(1, num_episodes+1), moving_average(losses))
plt.xlim([0, num_episodes])
plt.subplot(414)
plt.title('epsilon')
plt.plot(range(1, num_episodes+1), epsilons)
plt.xlim([0, num_episodes])
plt.tight_layout(); plt.show()

In [None]:

env = simple_tag_v2.env(
            num_good=1,
            num_adversaries=3,
            num_obstacles=2,
            max_cycles=1000,
            continuous_actions=False,
            render_mode='human'
        )

env.reset()
for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()
    if agent == 'agent_0':
        if termination or truncation:
            env.reset()
            continue

        action = policy_dqn(torch.from_numpy(env.last()[0]).float()).argmax().item()
    else:
        action = None if termination or truncation else env.action_space(agent).sample()  # this is where you would insert your policy
    
    env.step(action)

env.render()
env.close()