In [1]:
!pip install svgpath2mpl

[0m

In [2]:
from collections import namedtuple
import numpy as np
import torch
from svgpath2mpl import parse_path
import matplotlib.pyplot as plt
from scipy.ndimage import rotate, shift
from matplotlib.animation import FuncAnimation
from probabilistic_fire_env import ProbabilisticFireEnv
from drone_env import DronesEnv
from collections import  deque
from networks.ppo_net import PPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import random
from itertools import islice


In [3]:
DT          = 0.5  # Time between wildfire updates            
DTI         = 0.1  # Time between aircraft decisions
n_actions = 2
height = width = 100
channels = 2
EPISODES_PER_BATCH = 1
TRAIN_FREQ  = 10
SAVE_FREQ = 20
GAMMA = 0.95
CLIP  = 0.2

BATCH_SIZE = 64


LAMDA = 0.99
UPDATES = 10000
EPSILON = 0.2

EPOCHS = 10
NUM_PROCESSES = 1
BETA = 0.1
TAU = 1.0



In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PPONet(device,  channels, height, width, n_actions).to(device)

optimizer = torch.optim.Adamax(params=model.parameters(), lr=0.00025)


In [5]:
Transition = namedtuple('Transition', ('belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy', 'done'))

class ReplayMemory:

  def __init__(self):
    self._memory = []
  def push(self, *args):
    """Save a transition"""
    self._memory.append(Transition(*args))

  @property
  def get_batch(self):
    return self._memory

  def __len__(self):
    return len(self._memory)

  def clear(self):
    self._memory.clear()

In [6]:
DT          = 0.5  # Time between wildfire updates            
DTI         = 0.1  # Time between aircraft decisions
fireEnv = ProbabilisticFireEnv(height, width)
dronesEnv = DronesEnv(height, width, DT, DTI) 
loss = None
i_episode = 1
SAVE_MODEL = 10
N_DRONES = 2
steps = 0
done = False


In [7]:
seed = fireEnv.reset()
dronesEnv.reset(seed)

memory = [ReplayMemory(),ReplayMemory()]
state_vectors = [None]*N_DRONES
maps = [None]*N_DRONES
steps_for_episode = 0 

while True:
    
    for j in range(TRAIN_FREQ//int(2*DT/DTI)):

        observation = fireEnv.step()
        
        for d in range(N_DRONES):

            state_vectors[d] = dronesEnv.drones[d].state
            maps[d] = dronesEnv.drones[d].observation

            state_vectors[d] = torch.tensor(state_vectors[d], device=device, dtype=torch.float)
            maps[d] = torch.tensor(maps[d], device=device, dtype=torch.float)



        for i in range(int(DT/DTI)):
            
            steps += 1

            logits = [None]*N_DRONES
            values = [None]*N_DRONES
            actions = [None]*N_DRONES
            old_m = [None]*N_DRONES
            old_log_policy = [None]*N_DRONES
            rewards = [None]*N_DRONES

            for d in range(N_DRONES):

                logits[d], values[d] = model(maps[d], state_vectors[d])
                
                policy = F.softmax(logits, dim=1)
            
                old_m[d] = Categorical(policy)
    

                actions[d] = old_m[d].sample()

            
                old_log_policy[d] = old_m[d].log_prob(actions[d])



            rewards = dronesEnv.step(actions, observation)
            done = not fireEnv.fire_in_range(6)
            for d in range(N_DRONES):

                state_vectors[d] = dronesEnv.drones[d].state
                maps[d] = dronesEnv.drones[d].observation

                state_vectors[d] = torch.tensor(state_vectors[d], device=device, dtype=torch.float)
                maps[d] = torch.tensor(maps[d], device=device, dtype=torch.float)

                memory[d].push(maps[d], state_vectors[d], actions[d], rewards[d], values[d], old_log_policy[d], done)

            if done:

                if i_episode % SAVE_MODEL == 0:
                    file_path = f'./ppo_weights.pt'
                    torch.save(model.state_dict(), file_path)
                    print('saved')

           

                for d in range(N_DRONES):
                    _, next_value, = model(maps[d], state_vectors[d])

                    next_value = next_value.squeeze()
                    
                    batch  = Transition(*zip(*memory[d].get_batch))


                    old_log_policies_batch = torch.cat(batch.log_policy).detach()
                

                    actions_batch = torch.cat(batch.action)


                    value_batch = torch.cat(batch.value).detach()
                

                    belief_map_batch = torch.cat(batch.belief_map)


                    state_vector_batch = torch.cat(batch.state_vector)


                    reward_batch = batch.reward 


                    done_batch = batch.done

                    
                    gae = 0
                    R = []


                    for value, reward, done in list(zip(value_batch, reward_batch, done_batch))[::-1]:
                        gae = gae * GAMMA * TAU
                        gae = gae + reward + GAMMA * next_value.detach() * (1 - done) - value.detach()
                        next_value = value
                        R.append(gae + value)

                    R = R[::-1]
                    R = torch.cat(R).detach()
                        
                    advantages = R - value_batch.detach().squeeze()
                    
                    for e_i in range(EPOCHS):
                        indice = torch.randperm(len(memory[d]))
                        for b_i in range(len(memory[d])//BATCH_SIZE):
                            batch_indices = indice[b_i*(len(memory[d])//BATCH_SIZE): (b_i+1) *(len(memory[d])//BATCH_SIZE)]
                            logits, value = model(belief_map_batch[batch_indices], state_vector_batch[batch_indices])
                            new_policy = F.softmax(logits, dim=1)
                            new_m = Categorical(new_policy)
                            new_log_policy = new_m.log_prob(actions_batch[batch_indices])
                            ratio = torch.exp(new_log_policy - old_log_policies_batch[batch_indices])
                            
                            actor_loss = -torch.mean(
                                torch.min(ratio * advantages[batch_indices],
                                    torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) *
                                    advantages[batch_indices]
                                )
                            )

                            critic_loss = F.smooth_l1_loss(R[batch_indices], value.squeeze())
                            entropy_loss = torch.mean(new_m.entropy())
                            total_loss = actor_loss + critic_loss - BETA * entropy_loss
                            optimizer.zero_grad()
                            total_loss.backward()
                            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                            optimizer.step()

                    print("Episode: {}. Total loss: {}".format(i_episode, total_loss))
                    memory[d].clear()
                i_episode += 1
                seed = fireEnv.reset()
                dronesEnv.reset(seed)


Episode: 1. Total loss: -1.4205999374389648
Episode: 1. Total loss: -0.18661947548389435
Episode: 2. Total loss: -0.08708422631025314
Episode: 2. Total loss: -0.07153958082199097
Episode: 3. Total loss: 0.3813660144805908
Episode: 3. Total loss: -0.06833426654338837
Episode: 4. Total loss: -3.571653366088867
Episode: 4. Total loss: -0.08647318184375763
Episode: 5. Total loss: -1.200814962387085
Episode: 5. Total loss: -0.08002986013889313
Episode: 6. Total loss: -0.5445626378059387
Episode: 6. Total loss: -0.0426526814699173
Episode: 7. Total loss: -1.1778169870376587
Episode: 7. Total loss: -3.5961239337921143
Episode: 8. Total loss: 0.5251494646072388
Episode: 8. Total loss: -2.2240498065948486
Episode: 9. Total loss: 0.040774472057819366
Episode: 9. Total loss: 1.0079681873321533
saved
Episode: 10. Total loss: -4.983312129974365
Episode: 10. Total loss: -0.07639080286026001
Episode: 11. Total loss: -12.141057014465332
Episode: 11. Total loss: -0.062991201877594
Episode: 12. Total lo

KeyboardInterrupt: 