In [1]:
!pip install svgpath2mpl

Collecting svgpath2mpl
  Downloading svgpath2mpl-1.0.0-py2.py3-none-any.whl (7.8 kB)
Installing collected packages: svgpath2mpl
Successfully installed svgpath2mpl-1.0.0
[0m

In [2]:
from collections import namedtuple
import torch
from probabilistic_fire_env import ProbabilisticFireEnv
from drone_env import DronesEnv
from networks.ppo_net import PPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np

In [3]:
# --- Simulation timing ------------------------------------------------------
DT = 0.5   # Time between wildfire updates
DTI = 0.1  # Time between aircraft decisions

# --- Environment / network dimensions ---------------------------------------
height, width = 100, 100  # passed to both environments and to the network
channels = 2              # passed to PPONet
n_actions = 2             # passed to PPONet

# --- Rollout scheduling -----------------------------------------------------
TRAIN_FREQ = 10           # sets how many wildfire updates one outer-loop pass performs
# The next four are not referenced by the cells visible in this notebook.
EPISODES_PER_BATCH = 1
SAVE_FREQ = 20
UPDATES = 10_000
NUM_PROCESSES = 1

# --- PPO hyper-parameters ---------------------------------------------------
GAMMA = 0.9      # discount applied to the bootstrap value in the advantage recursion
TAU = 1.0        # multiplied with GAMMA to decay the running advantage
LAMDA = 0.9      # not referenced by the cells visible in this notebook
EPSILON = 0.1    # half-width of the probability-ratio clamp: [1 - EPSILON, 1 + EPSILON]
BETA = 0.01      # weight of the entropy term in the total loss
EPOCHS = 10      # optimisation passes over the collected transitions per update
BATCH_SIZE = 64  # minibatch size constant used when splitting the collected transitions

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared network queried for every drone's action logits and value estimate.
model = PPONet(
    device,
    channels,
    height,
    width,
    n_actions,
).to(device)

# Adam over all network parameters with a small fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [5]:
# One recorded decision step: the inputs given to the network, the sampled
# action, the reward received, and the network's value / log-probability.
Transition = namedtuple(
    'Transition',
    ['belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'],
)


class ReplayMemory:
    """Append-only list of ``Transition`` records that can be emptied."""

    def __init__(self):
        self._memory = []

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self._memory)

    @property
    def get_batch(self):
        """The underlying list of stored transitions (not a copy)."""
        return self._memory

    def push(self, *args):
        """Build a ``Transition`` from ``args`` and append it."""
        record = Transition(*args)
        self._memory.append(record)

    def clear(self):
        """Remove every stored transition, emptying the list in place."""
        self._memory.clear()

In [6]:
# Simulation timing, repeated here so this cell can be run on its own.
DT, DTI = 0.5, 0.1  # wildfire update period, aircraft decision period

# Environments: the fire model first, then the drones that observe it.
fireEnv = ProbabilisticFireEnv(height, width)
dronesEnv = DronesEnv(height, width, DT, DTI)

# Run configuration.
N_DRONES = 2     # number of drones stepped each decision
SAVE_MODEL = 10  # weights are saved when i_episode is a multiple of this

# Mutable bookkeeping used by the training loop.
i_episode = 1
steps = 0
loss = None
done = False

In [7]:
# Storage for collected transitions and per-drone network inputs.
memory = ReplayMemory()
maps = [None for _ in range(N_DRONES)]
state_vectors = [None for _ in range(N_DRONES)]
steps_for_episode = 0  # initialised here; not read by the loop below

# Start a fresh fire and hand the value it returns to the drone environment.
seed = fireEnv.reset()
dronesEnv.reset(seed)

def _refresh_drone_inputs():
    """Reload every drone's state vector and observation as float tensors.

    Writes the results into the notebook-level ``state_vectors`` and ``maps``
    lists in place, on ``device``.
    """
    for d in range(N_DRONES):
        drone = dronesEnv.drones[d]
        state_vectors[d] = torch.tensor(drone.state, device=device, dtype=torch.float)
        maps[d] = torch.tensor(drone.observation, device=device, dtype=torch.float)


# Main loop: step the fire, let every drone act, record drone 0's transitions,
# and run a PPO update whenever `done` is true.
#
# NOTE(review): `done` is never assigned inside this loop (the only visible
# assignment is `done = False` in the setup cell), so as written the update
# branch below does not run and `memory` keeps growing. TODO: set `done` from
# the environments' termination signal and reset it after each update.
# NOTE(review): only drone 0's transitions are pushed to `memory`; the other
# drones act but their experience is discarded - confirm this is intended.
# NOTE(review): the loop has no exit condition; it runs until interrupted.
while True:

    for j in range(TRAIN_FREQ // int(2 * DT / DTI)):

        observation = fireEnv.step()
        _refresh_drone_inputs()

        # Several aircraft decisions are taken per wildfire update.
        for i in range(int(DT / DTI)):

            steps += 1

            logits = [None] * N_DRONES
            values = [None] * N_DRONES
            actions = [None] * N_DRONES
            old_m = [None] * N_DRONES
            old_log_policy = [None] * N_DRONES

            # Rollout outputs are only ever used detached, so skip building
            # autograd graphs that would otherwise be kept alive by `memory`.
            with torch.no_grad():
                for d in range(N_DRONES):
                    logits[d], values[d] = model(maps[d], state_vectors[d])
                    policy = F.softmax(logits[d], dim=1)
                    old_m[d] = Categorical(policy)
                    actions[d] = old_m[d].sample()
                    old_log_policy[d] = old_m[d].log_prob(actions[d])

            rewards = dronesEnv.step(actions, observation)
            memory.push(maps[0], state_vectors[0], actions[0], rewards[0], values[0], old_log_policy[0])

            _refresh_drone_inputs()

            if done:

                if i_episode % SAVE_MODEL == 0:
                    file_path = './ppo_weights.pt'
                    torch.save(model.state_dict(), file_path)
                    print('saved')

                # Bootstrap value for the state reached after the last stored step.
                with torch.no_grad():
                    _, next_value = model(maps[0], state_vectors[0])
                next_value = next_value.squeeze()

                # Turn the list of transitions into one tuple per field.
                batch = Transition(*zip(*memory.get_batch))

                old_log_policies_batch = torch.cat(batch.log_policy).detach()
                actions_batch = torch.cat(batch.action)
                value_batch = torch.cat(batch.value).detach()
                belief_map_batch = torch.cat(batch.belief_map)
                state_vector_batch = torch.cat(batch.state_vector)
                reward_batch = batch.reward

                # Walk the trajectory backwards, accumulating TD residuals with
                # decay GAMMA * TAU; R holds advantage + value (the return target).
                gae = 0
                R = []
                for value, reward in list(zip(value_batch, reward_batch))[::-1]:
                    gae = gae * GAMMA * TAU
                    gae = gae + reward + GAMMA * next_value.detach() - value.detach()
                    next_value = value
                    R.append(gae + value)

                R = R[::-1]
                R = torch.cat(R).detach()

                advantages = R - value_batch.squeeze()

                n_samples = len(memory)  # >= 1: a transition is pushed before this branch

                for e_i in range(EPOCHS):
                    indice = torch.randperm(n_samples)

                    # Consecutive slices of BATCH_SIZE shuffled indices; the last
                    # slice may be shorter, so every sample is used once per epoch
                    # and at least one minibatch always runs.
                    for start in range(0, n_samples, BATCH_SIZE):
                        batch_indices = indice[start:start + BATCH_SIZE]
                        logits, value = model(belief_map_batch[batch_indices], state_vector_batch[batch_indices])
                        new_policy = F.softmax(logits, dim=1)
                        new_m = Categorical(new_policy)
                        new_log_policy = new_m.log_prob(actions_batch[batch_indices])
                        ratio = torch.exp(new_log_policy - old_log_policies_batch[batch_indices])

                        # Clipped surrogate objective (negated because we minimise).
                        actor_loss = -torch.mean(
                            torch.min(ratio * advantages[batch_indices],
                                torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) *
                                advantages[batch_indices]
                            )
                        )

                        # Flatten the predictions so that a one-sample minibatch
                        # keeps its batch dimension and matches the target shape.
                        critic_loss = F.smooth_l1_loss(value.reshape(-1), R[batch_indices])
                        entropy_loss = torch.mean(new_m.entropy())
                        total_loss = actor_loss + critic_loss - BETA * entropy_loss
                        optimizer.zero_grad()
                        total_loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                        optimizer.step()

                # Reports the loss of the final minibatch of the final epoch.
                print("Episode: {}. Total loss: {}".format(i_episode, total_loss))
                memory.clear()
                i_episode += 1
                seed = fireEnv.reset()
                dronesEnv.reset(seed)
                break

Episode: 1. Total loss: 0.02799837291240692
Episode: 1. Total loss: 0.026823706924915314
Episode: 2. Total loss: -3.6505913734436035
Episode: 2. Total loss: 0.3167976140975952
Episode: 3. Total loss: -1.209546446800232
Episode: 3. Total loss: -1.0565685033798218
Episode: 4. Total loss: -0.0018777013756334782
Episode: 4. Total loss: 0.0006790105253458023
Episode: 5. Total loss: 0.027159474790096283
Episode: 5. Total loss: 0.07939529418945312
Episode: 6. Total loss: 5.440525054931641
Episode: 6. Total loss: 5.041927337646484
Episode: 7. Total loss: 0.029298098757863045
Episode: 7. Total loss: 2.3409292697906494
Episode: 8. Total loss: -2.6780829429626465
Episode: 8. Total loss: 0.23820964992046356
Episode: 9. Total loss: 1.2588509321212769
Episode: 9. Total loss: 0.38812899589538574
saved
Episode: 10. Total loss: 0.2658970057964325
Episode: 10. Total loss: 0.24153049290180206
Episode: 11. Total loss: 2.1258182525634766
Episode: 11. Total loss: -2.454558849334717
Episode: 12. Total loss: 

KeyboardInterrupt: 