In [1]:
!pip install svgpath2mpl

Collecting svgpath2mpl
  Downloading svgpath2mpl-1.0.0-py2.py3-none-any.whl (7.8 kB)
Installing collected packages: svgpath2mpl
Successfully installed svgpath2mpl-1.0.0
[0m

In [2]:
from collections import namedtuple
import torch
from env.wildfire_gym import WildFireGym
from networks.ppo_net import PPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np

In [3]:

n_actions = 2
height = width = 100
channels = 2
EPISODES_PER_BATCH = 1
TRAIN_FREQ  = 10
SAVE_FREQ = 20
GAMMA = 0.9

BATCH_SIZE = 64


LAMDA = 0.9
UPDATES = 10000
EPSILON = 0.1

EPOCHS = 10
NUM_PROCESSES = 1
BETA = 0.01
TAU = 1.0

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PPONet(device,  channels, height, width, n_actions)
optimizer= torch.optim.Adam(params=model.parameters(), lr=0.0001)


In [5]:
Transition = namedtuple('Transition', ('belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'))

class ReplayMemory:

  def __init__(self):
    self._memory = []
    
  def push(self, *args):
    """Save a transition"""
    self._memory.append(Transition(*args))

  @property
  def get_batch(self):
    return self._memory

  def __len__(self):
    return len(self._memory)

  def clear(self):
    self._memory.clear()

In [6]:

loss = None
i_episode = 1
SAVE_MODEL = 10
N_DRONES = 2
steps = 0
done = False

In [7]:
loss = None
i_episode = 1

wildFireGym = WildFireGym()
observation = wildFireGym.reset()

memory = [ReplayMemory(), ReplayMemory()]
maps = [None, None]
state_vectors = [None, None]
while True:

    maps[0] = torch.tensor(observation[0]['belief_map'], device=device, dtype=torch.float)
    maps[1] = torch.tensor(observation[1]['belief_map'], device=device, dtype=torch.float)

    state_vectors[0] = torch.tensor(observation[0]['state_vector'], device=device, dtype=torch.float)
    state_vectors[1] = torch.tensor(observation[1]['state_vector'], device=device, dtype=torch.float)

    logits_1, value_1 = model(maps[0], state_vectors[0])
    
    policy_1 = F.softmax(logits_1, dim=1)

    old_m_1 = Categorical(policy_1)
    
    action_1 = old_m_1.sample()

    old_log_policy_1 = old_m_1.log_prob(action_1)

    logits_2, value_2 = model(maps[1], state_vectors[1])
    
    policy_2 = F.softmax(logits_2, dim=1)

    old_m_2 = Categorical(policy_2)
    
    action_2 = old_m_2.sample()

    old_log_policy_2 = old_m_2.log_prob(action_2)

    next_observation, rewards, done, _ = wildFireGym.step([action_1, action_2])

    memory[0].push(maps[0], state_vectors[0], action_1, rewards[0], value_1.squeeze(), old_log_policy_1)
    memory[1].push(maps[1], state_vectors[1], action_2, rewards[1], value_2.squeeze(), old_log_policy_2)




    if done:

        if i_episode % SAVE_MODEL == 0:
            file_path = f'./ppo_weights.pt'
            torch.save(model.state_dict(), file_path)
            print('saved')

        for d in range(N_DRONES):
            _, next_value = model(maps[d], state_vectors[d])

            next_value = next_value.squeeze()
            
            batch  = Transition(*zip(*memory[d].get_batch))


            old_log_policies_batch = torch.cat(batch.log_policy).detach()
        

            actions_batch = torch.cat(batch.action)


            value_batch = torch.stack(batch.value).detach()
        

            belief_map_batch = torch.cat(batch.belief_map)


            state_vector_batch = torch.cat(batch.state_vector)


            reward_batch = batch.reward 
            
            gae = 0
            R = []

            for value, reward in list(zip(value_batch, reward_batch))[::-1]:
                gae = gae * GAMMA * TAU
                gae = gae + reward + GAMMA * next_value.detach() - value.detach()
                next_value = value
                R.append(gae + value)

            R = R[::-1]
            R = torch.stack(R).detach()
                
            advantages = R - value_batch

            for e_i in range(EPOCHS):
                indice = torch.randperm(len(memory[d]))
            
                for b_i in range(len(memory[d])//BATCH_SIZE):
                    batch_indices = indice[b_i*(len(memory[d])//BATCH_SIZE): (b_i+1) *(len(memory[d])//BATCH_SIZE)]
                    logits, values = model(belief_map_batch[batch_indices], state_vector_batch[batch_indices])
                    new_policy = F.softmax(logits, dim=1)
                    new_m = Categorical(new_policy)
                    new_log_policy = new_m.log_prob(actions_batch[batch_indices])
                    ratio = torch.exp(new_log_policy - old_log_policies_batch[batch_indices])
                    
                    actor_loss = -torch.mean(
                        torch.min(ratio * advantages[batch_indices],
                            torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) *
                            advantages[batch_indices]
                        )
                    )

                    critic_loss = F.smooth_l1_loss(R[batch_indices], value.squeeze())
                    entropy_loss = torch.mean(new_m.entropy())
                    total_loss = actor_loss + critic_loss - BETA * entropy_loss
                    optimizer.zero_grad()
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                    optimizer.step()

            print("Episode: {}. Total loss: {}".format(i_episode, total_loss))
            memory[d].clear()
        i_episode += 1
        wildFireGym = WildFireGym()
        observation = wildFireGym.reset()
        memory = [ReplayMemory(), ReplayMemory()]
        maps = [None, None]
        state_vectors = [None, None]
        break

Episode: 1. Total loss: -0.10826052725315094
Episode: 1. Total loss: -5.918087005615234
Episode: 2. Total loss: 0.6450406908988953
Episode: 2. Total loss: 2.426481246948242
Episode: 3. Total loss: 1.1261237859725952
Episode: 3. Total loss: 0.03809733688831329
Episode: 4. Total loss: -2.664047956466675
Episode: 4. Total loss: -0.04057173430919647
Episode: 5. Total loss: -9.497632026672363
Episode: 5. Total loss: 0.018901020288467407
Episode: 6. Total loss: -0.008171562105417252
Episode: 6. Total loss: -0.016402889043092728
Episode: 7. Total loss: -4.856660842895508
Episode: 7. Total loss: -0.026954282075166702
Episode: 8. Total loss: 0.4451802670955658
Episode: 8. Total loss: 0.5058846473693848
Episode: 9. Total loss: 5.345011234283447
Episode: 9. Total loss: -0.012998729944229126
saved
Episode: 10. Total loss: 1.8571988344192505
Episode: 10. Total loss: -0.0074686650186777115
Episode: 11. Total loss: -1.3182380199432373
Episode: 11. Total loss: 3.8942677974700928
Episode: 12. Total los

KeyboardInterrupt: 