In [1]:
!pip install svgpath2mpl
!pip install gym

Collecting svgpath2mpl
  Downloading svgpath2mpl-1.0.0-py2.py3-none-any.whl (7.8 kB)
Installing collected packages: svgpath2mpl
Successfully installed svgpath2mpl-1.0.0
[0mCollecting gym
  Downloading gym-0.26.2.tar.gz (721 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m721.7/721.7 kB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting gym-notices>=0.0.4
  Downloading gym_notices-0.0.8-py3-none-any.whl (3.0 kB)
Building wheels for collected packages: gym
  Building wheel for gym (pyproject.toml) ... [?25ldone
[?25h  Created wheel for gym: filename=gym-0.26.2-py3-none-any.whl size=827629 sha256=898f5f7c0b4d4e960422c23c1f2eec8de1b4e7398cc96f5f7cfa0ef5fd686161
  Stored in directory: /root/.cache/pip/wheels/af/2b/30/5e78b8b9599f2a2286a582b8da80594f654bf0e18d825a4405
Successfully built gym
Installi

In [2]:
from collections import namedtuple
import torch
from networks.ppo_net import PPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np
from environments.shared_wildfire_gym import SharedWildFireGym

In [3]:

n_actions = 4
height = width = 100
channels = 2
EPISODES_PER_BATCH = 1
TRAIN_FREQ  = 10
SAVE_FREQ = 20
GAMMA = 0.95

BATCH_SIZE = 32


LAMDA = 0.9
UPDATES = 10000
EPSILON = 0.1

EPOCHS = 5
NUM_PROCESSES = 1
BETA = 0.01
TAU = 1.0

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PPONet(device,  channels, height, width, n_actions).to(device)
optimizer= torch.optim.Adam(params=model.parameters(), lr=0.0001)


In [5]:
Transition = namedtuple('Transition', ('belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'))

class ReplayMemory:

  def __init__(self):
    self._memory = []
    
  def push(self, *args):
    """Save a transition"""
    self._memory.append(Transition(*args))

  @property
  def get_batch(self):
    return self._memory

  def __len__(self):
    return len(self._memory)

  def clear(self):
    self._memory.clear()

In [6]:
loss = None
i_episode = 1
SAVE_MODEL = 10
N_DRONES = 2
steps = 0
done = False

In [7]:
loss = None
i_episode = 1
done = False
wildFireGym = SharedWildFireGym()
observation = wildFireGym.reset()

memory = ReplayMemory()
belief_map = None
state_vector = None

while True:

    belief_map = torch.tensor(observation['belief_map'], device=device, dtype=torch.float)

    state_vector = torch.tensor(observation['state_vector'], device=device, dtype=torch.float)

    logits, value = model(belief_map, state_vector)
    
    policy = F.softmax(logits, dim=1)

    old_m = Categorical(policy)
    
    action = old_m.sample()

    old_log_policy = old_m.log_prob(action)

    action_vector = None
    if action == 0:
        action_vector = [0, 0]
    elif action == 1:
        action_vector = [0, 1]
    elif action == 2:
        action_vector = [1, 0]
    else:
        action_vector = [1, 1]

    next_observation, rewards, done, _ = wildFireGym.step(action_vector)

    memory.push(belief_map, state_vector, action, sum(rewards), value.squeeze(), old_log_policy)

    observation = next_observation

    if done:

        if i_episode % SAVE_MODEL == 0:
            file_path = f'./ppo_weights.pt'
            torch.save(model.state_dict(), file_path)
            print('saved')


        _, next_value = model(belief_map, state_vector)

        next_value = next_value.squeeze()
        
        batch  = Transition(*zip(*memory.get_batch))


        old_log_policies_batch = torch.cat(batch.log_policy).detach()
    

        actions_batch = torch.cat(batch.action)


        value_batch = torch.stack(batch.value).detach()
    

        belief_map_batch = torch.cat(batch.belief_map)


        state_vector_batch = torch.cat(batch.state_vector)


        reward_batch = batch.reward 
        
        gae = 0
        R = []

        for value, reward in list(zip(value_batch, reward_batch))[::-1]:
            gae = gae * GAMMA * TAU
            gae = gae + reward + GAMMA * next_value.detach() - value.detach()
            next_value = value
            R.append(gae + value)

        R = R[::-1]
        R = torch.stack(R).detach()
            
        advantages = R - value_batch

        for e_i in range(EPOCHS):
            indice = torch.randperm(len(memory))
            for b_i in range(len(memory)//BATCH_SIZE):
                batch_indices = indice[b_i*BATCH_SIZE:(b_i+1)*BATCH_SIZE]
                logits, values = model(belief_map_batch[batch_indices], state_vector_batch[batch_indices])
                new_policy = F.softmax(logits, dim=1)
                new_m = Categorical(new_policy)
                new_log_policy = new_m.log_prob(actions_batch[batch_indices])
                ratio = torch.exp(new_log_policy - old_log_policies_batch[batch_indices])
                
                actor_loss = -torch.mean(
                    torch.min(ratio * advantages[batch_indices],
                        torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) *
                        advantages[batch_indices]
                    )
                )

                critic_loss = F.smooth_l1_loss(R[batch_indices], values.squeeze())
        
                entropy_loss = torch.mean(new_m.entropy())
            
                total_loss = actor_loss + critic_loss - BETA * entropy_loss
                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

        print("Episode: {}. Total loss: {}".format(i_episode, total_loss))
        memory.clear()

        i_episode += 1
        observation = wildFireGym.reset()

Episode: 1. Total loss: -15.40317153930664
Episode: 1. Total loss: -7.225956916809082
Episode: 2. Total loss: -10.1463623046875
Episode: 2. Total loss: -0.47821322083473206
Episode: 3. Total loss: -0.8540241718292236
Episode: 3. Total loss: -1.7486659288406372
Episode: 4. Total loss: -1.1140602827072144
Episode: 4. Total loss: 2.073686361312866
Episode: 5. Total loss: 0.6448144912719727
Episode: 5. Total loss: -3.082087993621826
Episode: 6. Total loss: 14.189920425415039
Episode: 6. Total loss: 9.565095901489258
Episode: 7. Total loss: 1.7347664833068848
Episode: 7. Total loss: 3.026738166809082
Episode: 8. Total loss: 2.0213546752929688
Episode: 8. Total loss: -0.6277065873146057
Episode: 9. Total loss: 4.976499557495117
Episode: 9. Total loss: -1.101016640663147
saved
Episode: 10. Total loss: 9.048123359680176
Episode: 10. Total loss: 6.9652838706970215
Episode: 11. Total loss: 1.0205944776535034
Episode: 11. Total loss: 0.9046235084533691
Episode: 12. Total loss: 0.3523760735988617


: 