In [8]:
!pip install svgpath2mpl
!pip install gym

[0m

In [9]:
from collections import namedtuple
import torch
from networks.lstm_ppo_net import LSTMPPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np
from environments.wildfire_gym import WildFireGym

In [10]:

n_actions = 2
height = width = 100
channels = 2
EPISODES_PER_BATCH = 1
TRAIN_FREQ  = 10
SAVE_FREQ = 20
GAMMA = 0.95

BATCH_SIZE = 128


LAMDA = 0.95
UPDATES = 10000
EPSILON = 0.1

EPOCHS = 10
NUM_PROCESSES = 1
BETA = 0.01
TAU = 1.0

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LSTMPPONet(device,  channels, height, width, n_actions).to(device)
optimizer= torch.optim.Adam(params=model.parameters(), lr=0.0001)


In [12]:
Transition = namedtuple('Transition', ('belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'))

class ReplayMemory:

  def __init__(self):
    self._memory = []
    
  def push(self, *args):
    """Save a transition"""
    self._memory.append(Transition(*args))

  @property
  def get_batch(self):
    return self._memory

  def __len__(self):
    return len(self._memory)

  def clear(self):
    self._memory.clear()

In [13]:
loss = None
i_episode = 1
SAVE_MODEL = 10
N_DRONES = 2
steps = 0
done = False

In [14]:
loss = None
i_episode = 1
done = False
wildFireGym = WildFireGym()
observation = wildFireGym.reset()

memory = [ReplayMemory(), ReplayMemory()]
maps = [None, None]
state_vectors = [None, None]
reward_total = 0
while True:
    model.eval()
    maps[0] = torch.tensor(observation[0]['belief_map'], device=device, dtype=torch.float)
    maps[1] = torch.tensor(observation[1]['belief_map'], device=device, dtype=torch.float)

    state_vectors[0] = torch.tensor(observation[0]['state_vector'], device=device, dtype=torch.float)
    state_vectors[1] = torch.tensor(observation[1]['state_vector'], device=device, dtype=torch.float)

    logits_1, value_1 = model(maps[0], state_vectors[0])
    policy_1 = F.softmax(logits_1, dim=1)

    old_m_1 = Categorical(policy_1)
    
    action_1 = old_m_1.sample()

    old_log_policy_1 = old_m_1.log_prob(action_1)

    logits_2, value_2 = model(maps[1], state_vectors[1])
    
    policy_2 = F.softmax(logits_2, dim=1)

    old_m_2 = Categorical(policy_2)
    
    action_2 = old_m_2.sample()

    old_log_policy_2 = old_m_2.log_prob(action_2)

    next_observation, rewards, done, _ = wildFireGym.step([action_1.item(), action_2.item()])
    memory[0].push(maps[0], state_vectors[0], action_1, rewards[0], value_1.squeeze(), old_log_policy_1)
    memory[1].push(maps[1], state_vectors[1], action_2, rewards[1], value_2.squeeze(), old_log_policy_2)
    reward_total += sum(rewards)
    observation = next_observation
    
    if done:
        print("Episode: {}. Total reward: {}".format(i_episode, reward_total))
        reward_total = 0
        if i_episode % SAVE_MODEL == 0:
            file_path = f'./ppo_weights.pt'
            torch.save(model.state_dict(), file_path)
            print('saved')

        for d in range(N_DRONES):
            model.eval()
            _, next_value = model(maps[d], state_vectors[d])

            next_value = next_value.squeeze()
            
            batch  = Transition(*zip(*memory[d].get_batch))


            old_log_policies_batch = torch.cat(batch.log_policy).detach()
        

            actions_batch = torch.cat(batch.action)


            value_batch = torch.stack(batch.value).detach()
        

            belief_map_batch = torch.cat(batch.belief_map)


            state_vector_batch = torch.cat(batch.state_vector)


            reward_batch = batch.reward 
            
            gae = 0
            R = []

            for value, reward in list(zip(value_batch, reward_batch))[::-1]:
                gae = gae * GAMMA * TAU
                gae = gae + reward + GAMMA * next_value.detach() - value.detach()
                next_value = value
                R.append(gae + value)

            R = R[::-1]
            R = torch.stack(R).detach()
                
            advantages = R - value_batch
            model.train()
            indices = torch.randperm(len(memory[d]))
            batch_indices = torch.split(indices, BATCH_SIZE)
            for e_i in range(EPOCHS):
                for batch_indice in  batch_indices:
                    logits, values = model(belief_map_batch[batch_indice], state_vector_batch[batch_indice])
                    new_policy = F.softmax(logits, dim=1)
                    new_m = Categorical(new_policy)
                    new_log_policy = new_m.log_prob(actions_batch[batch_indice])
                    ratio = torch.exp(new_log_policy - old_log_policies_batch[batch_indice])
                    
                    actor_loss = -torch.mean(
                        torch.min(ratio * advantages[batch_indice],
                            torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) *
                            advantages[batch_indice]
                        )
                    )

                    critic_loss = F.smooth_l1_loss(R[batch_indice], values.squeeze())
          
                    entropy_loss = torch.mean(new_m.entropy())
               
                    total_loss = actor_loss + critic_loss - BETA * entropy_loss
                    optimizer.zero_grad()
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                    optimizer.step()

            print("Episode: {}. Total loss: {}".format(i_episode, total_loss))
            memory[d].clear()

        i_episode += 1
        observation = wildFireGym.reset()

Episode: 1. Total reward: 394
Episode: 1. Total loss: -37.33098220825195
Episode: 1. Total loss: -35.094757080078125
Episode: 2. Total reward: 42
Episode: 2. Total loss: 30.808868408203125
Episode: 2. Total loss: 35.017024993896484
Episode: 3. Total reward: 724
Episode: 3. Total loss: 5.725156784057617
Episode: 3. Total loss: -1.4305964708328247
Episode: 4. Total reward: 100
Episode: 4. Total loss: 11.833915710449219
Episode: 4. Total loss: 5.066467761993408
Episode: 5. Total reward: 341
Episode: 5. Total loss: -30.605926513671875
Episode: 5. Total loss: 15.683409690856934
Episode: 6. Total reward: 676
Episode: 6. Total loss: -9.913215637207031
Episode: 6. Total loss: -8.28534984588623
Episode: 7. Total reward: 347
Episode: 7. Total loss: 28.67243766784668
Episode: 7. Total loss: 4.674426078796387
Episode: 8. Total reward: 710
Episode: 8. Total loss: -42.534549713134766
Episode: 8. Total loss: -54.27208709716797
Episode: 9. Total reward: 546
Episode: 9. Total loss: 31.752147674560547
E

  critic_loss = F.smooth_l1_loss(R[batch_indice], values.squeeze())


Episode: 14. Total loss: -42.81599426269531
Episode: 14. Total loss: -29.397750854492188
Episode: 15. Total reward: 153
Episode: 15. Total loss: 29.112613677978516
Episode: 15. Total loss: -1.8739032745361328
Episode: 16. Total reward: 59
Episode: 16. Total loss: 26.7103214263916
Episode: 16. Total loss: 71.90812683105469
Episode: 17. Total reward: 786
Episode: 17. Total loss: 131.94651794433594
Episode: 17. Total loss: 35.96464920043945
Episode: 18. Total reward: 378
Episode: 18. Total loss: 28.030977249145508
Episode: 18. Total loss: -28.807403564453125
Episode: 19. Total reward: 543
Episode: 19. Total loss: 123.87869262695312
Episode: 19. Total loss: 104.0799331665039
Episode: 20. Total reward: 887
saved
Episode: 20. Total loss: 10.036417007446289
Episode: 20. Total loss: 24.323640823364258
Episode: 21. Total reward: 589
Episode: 21. Total loss: 47.26457214355469
Episode: 21. Total loss: 7.630592346191406
Episode: 22. Total reward: 151
Episode: 22. Total loss: 32.807472229003906
Epi

KeyboardInterrupt: 