In [1]:
!pip install svgpath2mpl
!pip install gym

[0m

In [2]:
from collections import namedtuple
import torch
from networks.ppo_net import PPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np
from environments.shared_wildfire_gym import SharedWildFireGym

In [3]:
n_actions = 4
height = width = 100
channels = 2
EPISODES_PER_BATCH = 1
TRAIN_FREQ  = 10
SAVE_FREQ = 20
GAMMA = 0.9

BATCH_SIZE = 64


LAMDA = 0.9
UPDATES = 10000
EPSILON = 0.1

EPOCHS = 10
NUM_PROCESSES = 1
BETA = 0.01
TAU = 1.0

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PPONet(device,  channels, height, width, n_actions).to(device)
optimizer= torch.optim.Adam(params=model.parameters(), lr=0.0001)


In [5]:
Transition = namedtuple('Transition', ('belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'))

class ReplayMemory:

  def __init__(self):
    self._memory = []
    
  def push(self, *args):
    """Save a transition"""
    self._memory.append(Transition(*args))

  @property
  def get_batch(self):
    return self._memory

  def __len__(self):
    return len(self._memory)

  def clear(self):
    self._memory.clear()

In [6]:
loss = None
i_episode = 1
SAVE_MODEL = 10
N_DRONES = 2
steps = 0
done = False

In [7]:
loss = None
i_episode = 1
done = False
wildFireGym = SharedWildFireGym()
observation = wildFireGym.reset()

memory = ReplayMemory()
belief_map = None
state_vector = None

while True:

    belief_map = torch.tensor(observation['belief_map'], device=device, dtype=torch.float)

    state_vector = torch.tensor(observation['state_vector'], device=device, dtype=torch.float)

    logits, value = model(belief_map, state_vector)
    
    policy = F.softmax(logits, dim=1)

    old_m = Categorical(policy)
    
    action = old_m.sample()

    old_log_policy = old_m.log_prob(action)
    action_item = action.item()
    print(action_item)
    action_vector = None
    if action_item  == 0:
        action_vector = [0, 0]
    elif action_item == 1:
        action_vector = [0, 1]
    elif action_item == 2:
        action_vector = [1, 0]
    else:
        action_vector = [1, 1]

    next_observation, rewards, done, _ = wildFireGym.step(action_vector)

    memory.push(belief_map, state_vector, action, sum(rewards), value.squeeze(), old_log_policy)

    observation = next_observation

    if done:



        _, next_value = model(
            torch.tensor(next_observation['belief_map'], device=device, dtype=torch.float), 
            torch.tensor(next_observation['state_vector'], device=device, dtype=torch.float)
        )

        next_value = next_value.squeeze()
        
        batch  = Transition(*zip(*memory.get_batch))


        old_log_policies_batch = torch.cat(batch.log_policy).detach()
    

        actions_batch = torch.cat(batch.action)


        value_batch = torch.stack(batch.value).detach()
    

        belief_map_batch = torch.cat(batch.belief_map)


        state_vector_batch = torch.cat(batch.state_vector)


        reward_batch = batch.reward 
        
        gae = 0
        R = []

        for value, reward in list(zip(value_batch, reward_batch))[::-1]:
            gae = gae * GAMMA * TAU
            gae = gae + reward + GAMMA * next_value.detach() - value.detach()
            next_value = value
            R.append(gae + value)

        R = R[::-1]
        R = torch.stack(R).detach()
            
        advantages = R - value_batch

        for e_i in range(EPOCHS):
            indice = torch.randperm(len(memory))
            for b_i in range(len(memory)//BATCH_SIZE):
                batch_indices = indice[b_i*BATCH_SIZE:(b_i+1)*BATCH_SIZE]
                logits, values = model(belief_map_batch[batch_indices], state_vector_batch[batch_indices])
                new_policy = F.softmax(logits, dim=1)
                new_m = Categorical(new_policy)
                new_log_policy = new_m.log_prob(actions_batch[batch_indices])
                ratio = torch.exp(new_log_policy - old_log_policies_batch[batch_indices])
                
                actor_loss = -torch.mean(
                    torch.min(ratio * advantages[batch_indices],
                        torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) *
                        advantages[batch_indices]
                    )
                )

                critic_loss = F.smooth_l1_loss(R[batch_indices], values.squeeze())
        
                entropy_loss = torch.mean(new_m.entropy())
            
                total_loss = actor_loss + critic_loss - BETA * entropy_loss
                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

        print("Episode: {}. Total loss: {}".format(i_episode, total_loss))
        memory.clear()

        
        observation = wildFireGym.reset()
        
        if i_episode % SAVE_MODEL == 0:
            file_path = f'./mappo_weights.pt'
            torch.save(model.state_dict(), file_path)
            print('saved')

        i_episode += 1

Episode: 1. Total loss: -4.450505256652832
Episode: 2. Total loss: 27.435470581054688
Episode: 3. Total loss: -5.590559959411621
Episode: 4. Total loss: -6.030116558074951
Episode: 5. Total loss: -3.8009424209594727
Episode: 6. Total loss: -2.2183988094329834
Episode: 7. Total loss: 2.401672124862671
Episode: 8. Total loss: -1.7253459692001343
Episode: 9. Total loss: -7.204023838043213
Episode: 10. Total loss: 18.406944274902344
saved
Episode: 11. Total loss: -11.15910530090332
Episode: 12. Total loss: -2.7526776790618896
Episode: 13. Total loss: -0.20381072163581848
Episode: 14. Total loss: -7.052924633026123
Episode: 15. Total loss: 7.898496627807617
Episode: 16. Total loss: 3.3340587615966797
Episode: 17. Total loss: -2.245131254196167
Episode: 18. Total loss: 12.95170783996582
Episode: 19. Total loss: 2.1903254985809326
Episode: 20. Total loss: -6.578614711761475
saved
Episode: 21. Total loss: 14.027261734008789
Episode: 22. Total loss: -4.65286922454834
Episode: 23. Total loss: -2

: 