In [1]:
!pip install svgpath2mpl

Collecting svgpath2mpl
  Downloading svgpath2mpl-1.0.0-py2.py3-none-any.whl (7.8 kB)
Installing collected packages: svgpath2mpl
Successfully installed svgpath2mpl-1.0.0
[0m

In [2]:
from collections import namedtuple, deque
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adamax
import random
import math 
from svgpath2mpl import parse_path
import matplotlib.pyplot as plt
import matplotlib
from scipy.ndimage import rotate, shift
from matplotlib.animation import FuncAnimation
from probabilistic_fire_env import ProbabilisticFireEnv
from drone_env import DronesEnv
from replay_memory import ReplayMemory, Transition
from models.dqn import DQN

In [3]:
height = width = 100
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.1
EPS_DECAY = 200000
INIT_SIZE = 20000
TARGET_UPDATE = 1000
SAVE_POLICY = 100
EPISODE_LENGTH = 250
TRAIN_FREQ  = 10   # Number of samples to generate between trainings (Should be multiple of 10)
PRINT_FREQ  = 100  # Frequency of printing (Should be a multiple of 10)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_actions = 2
screen_height = screen_width = 100
channels = 2
policy_net = DQN(channels, screen_height, screen_width, n_actions).to(device)
target_net = DQN(channels, screen_height, screen_width, n_actions).to(device)
steps = 0
policy_file_path = f'./policy_weights.pt'
target_file_path = f'./target_weights.pt'

#policy_net.load_state_dict(torch.load(policy_file_path))
#target_net.load_state_dict(torch.load('target_weights.pt'))
target_net.load_state_dict(policy_net.state_dict())
memory = ReplayMemory(70000)
#memory.load()
policy_net.train()
target_net.eval()
update_counter = 0
optimizer = Adamax(policy_net.parameters(), lr=0.0001)


In [5]:
def select_action(belief_map, state_vector, steps):
  sample = random.random()
  eps_threshold = EPS_END + (EPS_START - EPS_END) * \
    math.exp(-1. * steps / EPS_DECAY)

  if sample > eps_threshold:
    with torch.no_grad():
      output = policy_net(belief_map, state_vector).max(1)[1].view(1, 1)
      return output
  else:
    return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)

In [6]:
def optimize_model():
    
    global update_counter
    update_counter += 1
    transitions = memory.sample(BATCH_SIZE)

    batch = Transition(*zip(*transitions))

    next_states = torch.cat(batch.next_state_vector)
    next_belief_map = torch.cat(batch.next_belief_map)

    belief_map_batch = torch.cat(batch.belief_map)
    state_vector_batch = torch.cat(batch.state_vector)
    
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    state_action_values = policy_net(belief_map_batch, state_vector_batch).gather(1, action_batch)
    next_state_values = target_net(next_belief_map, next_states).max(1)[0].detach()

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss().to(device)
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

    if update_counter % TARGET_UPDATE == 0:
        policy_file_path = f'./policy_weights2.pt'
        target_file_path = f'./target_weights2.pt'
        torch.save(policy_net.state_dict(), policy_file_path)
        torch.save(target_net.state_dict(), target_file_path)
        print('update target')
        target_net.load_state_dict(policy_net.state_dict())

    return loss

In [7]:
DT          = 0.5  # Time between wildfire updates            
DTI         = 0.1  # Time between aircraft decisions
fireEnv = ProbabilisticFireEnv(height, width)
dronesEnv = DronesEnv(height, width, DT, DTI) 
loss = None
i_episode = 1

observation = fireEnv.reset()
dronesEnv.reset(observation)

while True:
  # Initialize the environment and state
  #env.reset()
  for j in range(TRAIN_FREQ//int(2*DT/DTI)):

    observation = fireEnv.step()

    state_vector_1 = dronesEnv.drones[0].state
    map_1 = dronesEnv.drones[0].observation
    state_vector_1 = torch.tensor(state_vector_1, device=device, dtype=torch.float)
    map_1 = torch.tensor(map_1, device=device, dtype=torch.float)

    state_vector_2 = dronesEnv.drones[1].state
    map_2 = dronesEnv.drones[1].observation
    state_vector_2 = torch.tensor(state_vector_2, device=device, dtype=torch.float)
    map_2 = torch.tensor(map_2, device=device, dtype=torch.float)

    for i in range(int(DT/DTI)):
      action1 = select_action(map_1, state_vector_1, steps)
      action2 = select_action(map_2, state_vector_2, steps)
      steps += 2
      reward_1, reward_2 = dronesEnv.step([action1.item(), action2.item()], observation)

      next_state_vector_1 = dronesEnv.drones[0].state
      next_map_1 = dronesEnv.drones[0].observation

      next_state_vector_1 = torch.tensor(next_state_vector_1, device=device, dtype=torch.float)
      next_map_1 = torch.tensor(next_map_1, device=device, dtype=torch.float)

      next_state_vector_2 = dronesEnv.drones[1].state
      next_map_2 = dronesEnv.drones[1].observation

      next_state_vector_2 = torch.tensor(next_state_vector_2, device=device, dtype=torch.float)
      next_map_2 = torch.tensor(next_map_2, device=device, dtype=torch.float)

      reward_1 = torch.tensor([reward_1], device=device)
      reward_2 = torch.tensor([reward_2], device=device)  

      memory.push(map_1, state_vector_1, action1, next_map_1, next_state_vector_1, reward_1)
      memory.push(map_2, state_vector_2, action2, next_map_2, next_state_vector_2, reward_2)

      state_vector_1 = next_state_vector_1
      state_vector_2 = next_state_vector_2

      map_1 = next_map_1
      map_2 = next_map_2

    if not fireEnv.fire_in_range(6):
      observation = fireEnv.reset()
      dronesEnv.reset(observation)

      i_episode +=1
      if (i_episode+1) % 5 == 0:
        print(f'{i_episode+1} episodes completed')
        print(f'loss {loss}')
        print(f'steps done {steps}')
      

  if steps>=INIT_SIZE:
    loss = optimize_model()

5 episodes completed
loss None
steps done 1710
10 episodes completed
loss None
steps done 5170
15 episodes completed
loss None
steps done 9670
20 episodes completed
loss None
steps done 14510
25 episodes completed
loss None
steps done 17980
30 episodes completed
loss 0.6751432418823242
steps done 21230
35 episodes completed
loss 0.8153688311576843
steps done 25140
40 episodes completed
loss 0.6013771891593933
steps done 28610
update target
45 episodes completed
loss 0.4485732614994049
steps done 32520
50 episodes completed
loss 0.47188690304756165
steps done 36010
55 episodes completed
loss 0.48938125371932983
steps done 39190
update target
60 episodes completed
loss 0.4722789227962494
steps done 43100
65 episodes completed
loss 0.33767175674438477
steps done 47770
update target
70 episodes completed
loss 1.1815670728683472
steps done 51600
75 episodes completed
loss 0.6511623859405518
steps done 55650
80 episodes completed
loss 0.8221272826194763
steps done 59090
update target
85 epis

: 