In [22]:
!pip install svgpath2mpl
!pip install gym

[0m

In [23]:
from collections import namedtuple
import torch
from networks.lstm_ppo_net import LSTMPPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np
import random
from environments.wildfire_gym import WildFireGym

In [24]:

n_actions = 2
height = width = 100
channels = 2
EPISODES_PER_BATCH = 1
TRAIN_FREQ  = 10
SAVE_FREQ = 20
GAMMA = 0.95

BATCH_SIZE = 64


LAMDA = 0.95
UPDATES = 10000
EPSILON = 0.1

EPOCHS = 5
NUM_PROCESSES = 1
BETA = 0.01
TAU = 1.0

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LSTMPPONet(device,  channels, height, width, n_actions).to(device)
optimizer= torch.optim.Adam(params=model.parameters(), lr=0.0001)


In [26]:
Transition = namedtuple('Transition', ('belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'))

class ReplayMemory:

  def __init__(self):
    self._memory = []
    
  def push(self, *args):
    """Save a transition"""
    self._memory.append(Transition(*args))

  @property
  def get_batch(self):
    return self._memory

  def __len__(self):
    return len(self._memory)

  def clear(self):
    self._memory.clear()

In [27]:
loss = None
i_episode = 1
SAVE_MODEL = 10
N_DRONES = 2
steps = 0
done = False

In [28]:
loss = None
i_episode = 1
done = False
wildFireGym = WildFireGym()
observation = wildFireGym.reset()

memory = [ReplayMemory(), ReplayMemory()]
maps = [None, None]
state_vectors = [None, None]
hidden = [None, None]
reward_total = 0
while True:

    maps[0] = torch.tensor(observation[0]['belief_map'], device=device, dtype=torch.float)
    maps[1] = torch.tensor(observation[1]['belief_map'], device=device, dtype=torch.float)

    state_vectors[0] = torch.tensor(observation[0]['state_vector'], device=device, dtype=torch.float)
    state_vectors[1] = torch.tensor(observation[1]['state_vector'], device=device, dtype=torch.float)

    logits_1, value_1, hidden[0] = model(maps[0], state_vectors[0], hidden[0])
    policy_1 = F.softmax(logits_1, dim=1)

    old_m_1 = Categorical(policy_1)
    
    action_1 = old_m_1.sample()

    old_log_policy_1 = old_m_1.log_prob(action_1)

    logits_2, value_2, hidden[1] = model(maps[1], state_vectors[1], hidden[1])
    
    policy_2 = F.softmax(logits_2, dim=1)

    old_m_2 = Categorical(policy_2)
    
    action_2 = old_m_2.sample()

    old_log_policy_2 = old_m_2.log_prob(action_2)

    next_observation, rewards, done, _ = wildFireGym.step([action_1.item(), action_2.item()])
    memory[0].push(maps[0], state_vectors[0], action_1, torch.tensor([rewards[0]], device=device, dtype=torch.float), value_1.squeeze(), old_log_policy_1)
    memory[1].push(maps[1], state_vectors[1], action_2, torch.tensor([rewards[1]], device=device, dtype=torch.float), value_2.squeeze(), old_log_policy_2)

    observation = next_observation
    reward_total += sum(rewards)
    if done:
        print("Episode: {}. Total reward: {}".format(i_episode, reward_total))
        if i_episode % SAVE_MODEL == 0:
            file_path = f'./lstm_ppo_weights.pt'
            torch.save(model.state_dict(), file_path)
            print('saved')

        d = random.randint(0,1)
        
        _, next_value, _ = model(maps[d], state_vectors[d], hidden[d])

        next_value = next_value.squeeze()
        
        batch  = Transition(*zip(*memory[d].get_batch))


        old_log_policies_batch = torch.cat(batch.log_policy).detach()
    

        actions_batch = torch.cat(batch.action)


        value_batch = torch.stack(batch.value).detach()
    

        belief_map_batch = torch.cat(batch.belief_map)


        state_vector_batch = torch.cat(batch.state_vector)


        reward_batch = torch.cat(batch.reward) 
        
        
        gae = 0
        R = []

        for value, reward in list(zip(value_batch, reward_batch))[::-1]:
            gae = gae * GAMMA * TAU
            gae = gae + reward.detach() + GAMMA * next_value.detach() - value.detach()
            next_value = value
            R.append(gae + value)

        R = R[::-1]
        R = torch.stack(R).detach()
        advantages = R - value_batch


        batch_indices = torch.split(torch.arange(len(memory[d])), BATCH_SIZE)
        batch_indices = list(batch_indices)

        for e_i in range(EPOCHS):
            random.shuffle(batch_indices)
            for mini_batch in batch_indices:
                logits, values, _ = model(belief_map_batch[mini_batch], state_vector_batch[mini_batch])
                new_policy = F.softmax(logits, dim=1)
                new_m = Categorical(new_policy)
                new_log_policy = new_m.log_prob(actions_batch[mini_batch])
                ratio = torch.exp(new_log_policy - old_log_policies_batch[mini_batch])
                
                actor_loss = -torch.mean(
                    torch.min(ratio * advantages[mini_batch],
                        torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) *
                        advantages[mini_batch]
                    )
                )

                critic_loss = F.smooth_l1_loss(R[mini_batch], values.squeeze())
        
                entropy_loss = torch.mean(new_m.entropy())
            
                total_loss = actor_loss + critic_loss - BETA * entropy_loss
                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

        print("Episode: {}. Total loss: {}".format(i_episode, total_loss))
            
        hidden = [None, None]
        memory[0].clear()
        memory[1].clear()
        reward_total = 0
        i_episode += 1
        observation = wildFireGym.reset()

Episode: 1. Total reward: 328
Episode: 1. Total loss: -0.2834034562110901
Episode: 2. Total reward: 337
Episode: 2. Total loss: 0.0022644558921456337
Episode: 3. Total reward: 516


  critic_loss = F.smooth_l1_loss(R[mini_batch], values.squeeze())


Episode: 3. Total loss: -1.1762022972106934
Episode: 4. Total reward: 115
Episode: 4. Total loss: 0.5873063802719116
Episode: 5. Total reward: 1180
Episode: 5. Total loss: -0.8767033219337463
Episode: 6. Total reward: 586
Episode: 6. Total loss: -4.551400184631348
Episode: 7. Total reward: 647
Episode: 7. Total loss: -6.446238994598389
Episode: 8. Total reward: 1051
Episode: 8. Total loss: -4.248243808746338
Episode: 9. Total reward: 398
Episode: 9. Total loss: 16.09334373474121
Episode: 10. Total reward: 0
saved
Episode: 10. Total loss: 5.948836803436279
Episode: 11. Total reward: 647
Episode: 11. Total loss: -5.102609157562256
Episode: 12. Total reward: 931
Episode: 12. Total loss: -3.046053886413574
Episode: 13. Total reward: 0
Episode: 13. Total loss: 5.770016670227051
Episode: 14. Total reward: 476
Episode: 14. Total loss: 0.4853808283805847
Episode: 15. Total reward: 960
Episode: 15. Total loss: -3.7595412731170654
Episode: 16. Total reward: 184
Episode: 16. Total loss: 2.6871366

KeyboardInterrupt: 