In [None]:
!pip install svgpath2mpl

In [None]:
from collections import namedtuple, deque
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adamax
import random
import math 
from svgpath2mpl import parse_path
import matplotlib.pyplot as plt
import matplotlib
from scipy.ndimage import rotate, shift
from matplotlib.animation import FuncAnimation
from probabilistic_fire_env import ProbabilisticFireEnv
from drone_env import DronesEnv
from networks.ppo_net import PPONet
from torch.distributions import MultivariateNormal

In [None]:
# --- Simulation timing -------------------------------------------------------
DT = 0.5   # Time between wildfire updates
DTI = 0.1  # Time between aircraft decisions

# --- Grid / network dimensions ----------------------------------------------
height, width = 100, 100  # grid size handed to both environments and to PPONet
channels = 2              # channel count handed to PPONet
n_actions = 2             # size of the actor output and of the policy covariance

# --- Training schedule -------------------------------------------------------
EPISODES_PER_BATCH = 5  # episodes gathered by one call to rollout()
TRAIN_FREQ = 10         # with DT and DTI, sets how many fire updates run between episode-end checks
SAVE_FREQ = 10          # network weights are written every SAVE_FREQ learning iterations
GAMMA = 0.95            # discount factor applied when computing rewards-to-go

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Actor: produces the mean of the action distribution (n_actions outputs).
# Critic: same architecture with a single output used as the state-value estimate.
actor = PPONet(device,  channels, height, width, n_actions).to(device)
critic = PPONet(device, channels, height, width, 1).to(device)

# Separate optimizers so the policy and value losses are stepped independently.
optimizer_actor = torch.optim.Adam(params=actor.parameters(), lr=0.001)
optimizer_critic = torch.optim.Adam(params=critic.parameters(), lr=0.001)

# Fixed diagonal covariance (variance 0.5 per action dimension) for the Gaussian policy.
# It must live on the same device as the actor output, otherwise MultivariateNormal
# raises a device-mismatch error when running on CUDA.
cov_var = torch.full(size=(n_actions,), fill_value=0.5, device=device)
cov_mat = torch.diag(cov_var)

In [None]:
def get_action(belief_map, state_vector, hidden):
    """Sample an action for one drone from the current Gaussian policy.

    Args:
        belief_map: tensor passed to the actor as its first input.
        state_vector: tensor passed to the actor as its second input.
        hidden: recurrent state passed to the actor as its third input.

    Returns:
        tuple: ``(action, log_prob, new_hidden)`` where ``action`` is a NumPy
        array sampled from the policy, ``log_prob`` is the (gradient-free)
        log-probability tensor of that action, and ``new_hidden`` is the
        recurrent state returned by the actor for the next call.
    """
    # Acting never needs gradients; this also keeps the returned hidden state
    # from dragging an ever-growing autograd graph through the rollout.
    with torch.no_grad():
        mean, new_hidden = actor(belief_map, state_vector, hidden)
        # Align the covariance with the actor output so CPU/GPU runs both work.
        dist = MultivariateNormal(mean, cov_mat.to(mean.device))
        action = dist.sample()
        log_prob = dist.log_prob(action)

    # .cpu() is required before .numpy() when the tensor lives on the GPU.
    # Return the *updated* hidden state, not the one that was passed in.
    return action.cpu().numpy(), log_prob, new_hidden

In [None]:
# One decision step of a single drone: what it observed, what it did,
# the log-probability of that action under the acting policy, and the reward.
Transition = namedtuple(
    'Transition',
    ('belief_map', 'state_vector', 'action', 'log_probability', 'reward'),
)


class EpisodeTransitions:
    """Ordered container for the transitions recorded during one episode."""

    def __init__(self):
        self._transitions = []

    def push(self, *args):
        """Append one transition; ``args`` are the ``Transition`` fields in order."""
        self._transitions.append(Transition(*args))

    def __getitem__(self, index):
        """Return the transition (or slice of transitions) at ``index``."""
        return self._transitions[index]

    @property
    def transitions(self):
        """The underlying list of ``Transition`` tuples, oldest first."""
        return self._transitions

    def __len__(self):
        """Number of transitions stored (0 for a fresh episode)."""
        return len(self._transitions)

class EpisodeMemoryBuffer:
    """Collection of finished episodes gathered during one rollout.

    Each pushed episode must expose a ``transitions`` list and support
    ``len()``. ``capacity`` is stored for reference only; ``push`` does not
    enforce it.
    """

    def __init__(self, capacity=10):
        self.capacity = capacity
        # Misspelled alias kept so code that read the old attribute name still works.
        self.capicity = capacity
        self.memory = []

    def push(self, episode):
        """Append a finished episode to the buffer."""
        self.memory.append(episode)

    @property
    def get_batch(self):
        """All transitions of all episodes, flattened in insertion order."""
        return [
            transition
            for episode in self.memory
            for transition in episode.transitions
        ]

    def __len__(self):
        """Total number of transitions across all stored episodes."""
        return sum(len(episode) for episode in self.memory)

    @property
    def reward_to_go(self):
        """Discounted reward-to-go for every transition, aligned with ``get_batch``.

        For each episode the return is accumulated from the last transition
        backwards (``G_t = r_t + GAMMA * G_{t+1}``) and restarted at every
        episode boundary.

        Returns:
            torch.Tensor: 1-D float tensor on ``device`` with one entry per
            stored transition.
        """
        rewards_to_go = []
        for episode in self.memory:
            discounted_reward = 0.0
            episode_returns = []
            # Walk the episode backwards so each step sees the return of its successor.
            for transition in reversed(episode.transitions):
                discounted_reward = transition.reward.item() + GAMMA * discounted_reward
                episode_returns.append(discounted_reward)
            # Restore chronological order before appending to the batch.
            rewards_to_go.extend(reversed(episode_returns))

        # torch.tensor (not the torch.Tensor constructor) accepts dtype/device keywords.
        return torch.tensor(rewards_to_go, dtype=torch.float, device=device)





In [None]:
# Simulation environments, created once at notebook level and reused by every rollout() call.
# The fire environment is built from the grid size (height, width).
fireEnv = ProbabilisticFireEnv(height, width)
# The drone environment receives the same grid size plus both time steps (DT, DTI).
dronesEnv = DronesEnv(height, width, DT, DTI) 

def _drone_tensors(drone_index):
    """Return ``(belief_map, state_vector)`` of one drone as float tensors on ``device``."""
    drone = dronesEnv.drones[drone_index]
    state_vector = torch.tensor(drone.state, device=device, dtype=torch.float)
    belief_map = torch.tensor(drone.observation, device=device, dtype=torch.float)
    return belief_map, state_vector


def rollout():
    """Collect ``EPISODES_PER_BATCH`` episodes of experience for both drones.

    The fire environment is stepped once per ``DT``; between two fire updates
    each drone takes ``int(DT / DTI)`` decisions with the current policy. An
    episode ends when ``fireEnv.fire_in_range(6)`` becomes false, at which
    point both drones' transition lists are stored and the environments reset.

    Returns:
        EpisodeMemoryBuffer: buffer holding two ``EpisodeTransitions`` per
        finished episode (drone 0 first, then drone 1).
    """
    memory_buffer = EpisodeMemoryBuffer()

    episode_transitions_1 = EpisodeTransitions()
    episode_transitions_2 = EpisodeTransitions()

    # Recurrent state per drone, reset at every episode start.
    # NOTE(review): assumes PPONet accepts hidden=None as the initial state
    # (evaluate() calls the networks without a hidden argument) - confirm.
    hidden_1 = hidden_2 = None

    # Loop bounds are constants; compute them once. max(1, ...) prevents the
    # while-loop below from spinning forever if the quotient ever becomes 0.
    fire_updates_per_check = max(1, TRAIN_FREQ // int(2 * DT / DTI))
    decisions_per_fire_update = int(DT / DTI)

    episode_i = 0

    observation = fireEnv.reset()
    dronesEnv.reset(observation)

    while episode_i < EPISODES_PER_BATCH:

        for _ in range(fire_updates_per_check):

            observation = fireEnv.step()

            map_1, state_vector_1 = _drone_tensors(0)
            map_2, state_vector_2 = _drone_tensors(1)

            for _ in range(decisions_per_fire_update):

                action1, log_probability_1, hidden_1 = get_action(map_1, state_vector_1, hidden_1)
                action2, log_probability_2, hidden_2 = get_action(map_2, state_vector_2, hidden_2)

                # NOTE(review): .item() only works for single-element actions, yet the
                # actor is built with n_actions outputs - confirm what DronesEnv.step expects.
                reward_1, reward_2 = dronesEnv.step([action1.item(), action2.item()], observation)

                next_map_1, next_state_vector_1 = _drone_tensors(0)
                next_map_2, next_state_vector_2 = _drone_tensors(1)

                reward_1 = torch.tensor([reward_1], device=device)
                reward_2 = torch.tensor([reward_2], device=device)

                # Store actions as tensors so the training step can concatenate them
                # together with the other (tensor) fields of the transition.
                episode_transitions_1.push(
                    map_1, state_vector_1, torch.as_tensor(action1, device=device),
                    log_probability_1, reward_1)
                episode_transitions_2.push(
                    map_2, state_vector_2, torch.as_tensor(action2, device=device),
                    log_probability_2, reward_2)

                map_1, state_vector_1 = next_map_1, next_state_vector_1
                map_2, state_vector_2 = next_map_2, next_state_vector_2

            if not fireEnv.fire_in_range(6):
                memory_buffer.push(episode_transitions_1)
                memory_buffer.push(episode_transitions_2)
                episode_i += 1

                # Start the next episode with fresh containers and recurrent state;
                # otherwise the same (ever-growing) objects would be pushed again.
                episode_transitions_1 = EpisodeTransitions()
                episode_transitions_2 = EpisodeTransitions()
                hidden_1 = hidden_2 = None

                observation = fireEnv.reset()
                dronesEnv.reset(observation)

                if episode_i >= EPISODES_PER_BATCH:
                    break

    return memory_buffer



In [None]:
def evaluate(belief_maps, state_vectors, actions):
    """Score a batch of observations and actions with the current networks.

    Args:
        belief_maps: batched belief-map tensor fed to both networks.
        state_vectors: batched state-vector tensor fed to both networks.
        actions: batched actions whose log-probability is evaluated.

    Returns:
        tuple: ``(values, log_probs)`` - the critic's value estimates (trailing
        singleton dimension removed) and the log-probabilities of ``actions``
        under the current actor.
    """
    # Both networks return (output, hidden) - see get_action - so unpack the
    # tuple first; calling .squeeze() on the tuple itself raises AttributeError.
    values, _ = critic(belief_maps, state_vectors)
    # Drop only the trailing output dimension so a batch of one stays 1-D.
    values = values.squeeze(-1)

    mean, _ = actor(belief_maps, state_vectors)
    # Align the covariance with the actor output so CPU/GPU runs both work.
    dist = MultivariateNormal(mean, cov_mat.to(mean.device))
    log_probs = dist.log_prob(actions)

    return values, log_probs

In [None]:
def learn(self, total_timesteps, clip=0.2):
    """Train the actor and critic with clipped-surrogate PPO.

    Args:
        self: unused leftover of a class-based version, kept so existing calls
            keep working; may be ``None``. If it has a ``clip`` attribute,
            that value takes precedence over the ``clip`` argument.
        total_timesteps: keep collecting rollouts and updating until at least
            this many transitions have been gathered.
        clip: PPO clipping range for the probability ratio.

    Side effects:
        Updates the notebook-level ``actor`` / ``critic`` through their
        optimizers and writes their weights to ``./ppo_actor.pth`` and
        ``./ppo_critic.pth`` every ``SAVE_FREQ`` iterations.
    """
    clip = getattr(self, 'clip', clip)

    t_so_far = 0  # Timesteps simulated so far
    i_so_far = 0  # Iterations ran so far

    while t_so_far < total_timesteps:
        memory_buffer = rollout()

        t_so_far += len(memory_buffer)
        i_so_far += 1

        batch = Transition(*zip(*memory_buffer.get_batch))

        # NOTE(review): torch.cat assumes every stored map / state tensor already
        # carries a leading batch dimension - confirm against the drone env.
        belief_map_batch = torch.cat(batch.belief_map)
        state_vector_batch = torch.cat(batch.state_vector)
        # Actions may have been stored as NumPy arrays or tensors, with or without
        # a batch dimension; normalise each one to shape (1, n_actions).
        action_batch = torch.cat([
            torch.as_tensor(action, dtype=torch.float, device=device).reshape(-1, n_actions)
            for action in batch.action
        ])
        # reshape(-1) lets 0-d log-probabilities be concatenated as well.
        log_probs_batch = torch.cat([
            log_prob.reshape(-1) for log_prob in batch.log_probability
        ])

        reward_to_go = memory_buffer.reward_to_go

        # The advantage baseline is a constant w.r.t. the update; skip autograd.
        with torch.no_grad():
            values, _ = evaluate(belief_map_batch, state_vector_batch, action_batch)
        A_k = reward_to_go - values
        # Normalise advantages; the epsilon guards against a zero standard deviation.
        A_k = (A_k - A_k.mean()) / (A_k.std() + 1e-10)

        for _ in range(5):
            values, current_log_probs = evaluate(belief_map_batch, state_vector_batch, action_batch)

            # Ratio between the current policy and the policy that collected the data.
            ratios = torch.exp(current_log_probs - log_probs_batch)
            surr1 = ratios * A_k
            surr2 = torch.clamp(ratios, 1 - clip, 1 + clip) * A_k

            actor_loss = (-torch.min(surr1, surr2)).mean()
            critic_loss = nn.MSELoss()(values, reward_to_go)

            # Actor and critic are separate networks whose graphs are rebuilt by
            # evaluate() on every pass, so no graph needs to be retained.
            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()

            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

        if i_so_far % SAVE_FREQ == 0:
            # Save the notebook-level networks: they are the ones being optimised.
            torch.save(actor.state_dict(), './ppo_actor.pth')
            torch.save(critic.state_dict(), './ppo_critic.pth')