In [11]:
!pip install svgpath2mpl

[0m

In [12]:
from collections import namedtuple
import numpy as np
import torch
import torch.nn as nn
import imageio
from svgpath2mpl import parse_path
import matplotlib.pyplot as plt
import matplotlib
from scipy.ndimage import rotate, shift
from matplotlib.animation import FuncAnimation
from probabilistic_fire_env import ProbabilisticFireEnv
from drone_env import DronesEnv
from networks.ppo_net import PPONet
from torch.distributions import MultivariateNormal, Categorical

In [13]:
# --- Simulation timing ---
DT = 0.5   # Time between wildfire updates
DTI = 0.1  # Time between aircraft decisions

# --- Dimensions passed to the environments and to PPONet ---
height = 100
width = 100
channels = 2
n_actions = 2

# --- Not referenced by the cells shown in this notebook excerpt ---
# Presumably carried over from the PPO training setup — TODO confirm before removing.
EPISODES_PER_BATCH = 1
TRAIN_FREQ = 10
SAVE_FREQ = 10
GAMMA = 0.95
CLIP = 0.2
BATCH_SIZE = 64

In [14]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the policy network and move its parameters to the selected device.
model = PPONet(device, channels, height, width, n_actions).to(device)

# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved from a CUDA run cannot be deserialized on a CPU-only machine.
model.load_state_dict(torch.load('ppo_weights.pt', map_location=device))

<All keys matched successfully>

In [15]:
# Roll out the loaded policy greedily with two drones over a simulated wildfire,
# render one frame per wildfire update, and save the frames as an animated GIF.
fireEnv = ProbabilisticFireEnv(height, width)
dronesEnv = DronesEnv(height, width, DT, DTI)
loss = None      # not used in this cell; kept in case other cells read it
i_episode = 1    # not used in this cell; kept in case other cells read it
images = []
observation = fireEnv.reset()
dronesEnv.reset(observation)

while True:

  # Advance the wildfire by one update.
  observation = fireEnv.step()

  # Pure inference: disable autograd so no computation graph is built or retained.
  with torch.no_grad():
    # The drones decide int(DT/DTI) times per wildfire update.
    for _ in range(int(DT / DTI)):
      actions = []
      # Exactly two drones are driven, matching the two-action list passed to step().
      for drone in (dronesEnv.drones[0], dronesEnv.drones[1]):
        # Re-read each drone's current map and state vector right before querying the model.
        drone_map = torch.tensor(drone.observation, device=device, dtype=torch.float)
        state_vector = torch.tensor(drone.state, device=device, dtype=torch.float)
        policy, _value = model(drone_map, state_vector)
        # Greedy action: index of the largest policy output.
        actions.append(torch.argmax(policy).item())

      dronesEnv.step(actions, observation)

  # Stop (without rendering this step) once fire_in_range(6) reports False.
  if not fireEnv.fire_in_range(6):
    break

  # Render the five diagnostic panels for this wildfire update.
  fig, ax = plt.subplots(1, 5, figsize=(28, 4))
  dronesEnv.plot_drones(fig, ax[0])
  dronesEnv.plot_belief_map(fig, ax[1])
  dronesEnv.plot_time_elapsed(fig, ax[2])
  dronesEnv.plot_trajectory(fig, ax[3])
  fireEnv.plot_heat_map(fig, ax[4])
  fig.canvas.draw()

  # Copy the rendered canvas into an (H, W, 3) uint8 array for the GIF.
  # NOTE(review): tostring_rgb is deprecated in recent Matplotlib releases —
  # consider switching to buffer_rgba() when upgrading.
  image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
  image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
  images.append(image)

  # Release the figure: one is created per iteration, and leaving them all open
  # grows memory and triggers Matplotlib's "too many open figures" warning.
  plt.close(fig)

kwargs_write = {'fps':5.0, 'quantizer':'nq'}  # not passed to mimsave below; kept as in the original
imageio.mimsave('./ppo_example.gif', images, fps=5)


  fig, ax = plt.subplots(1, 5, figsize=(28, 4))
