In [1]:
!pip install svgpath2mpl

Collecting svgpath2mpl
  Downloading svgpath2mpl-1.0.0-py2.py3-none-any.whl (7.8 kB)
Installing collected packages: svgpath2mpl
Successfully installed svgpath2mpl-1.0.0
[0m

In [2]:
from collections import namedtuple, deque
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adamax
import random
from svgpath2mpl import parse_path
import matplotlib.pyplot as plt
import matplotlib
from scipy.ndimage import rotate, shift
from matplotlib.animation import FuncAnimation
from probabilistic_fire_env import ProbabilisticFireEnv
from drone_env import DronesEnv
from replay_memory import Transition
from networks.drqn import DRQN
from episode_buffer import EpisodeBuffer

In [3]:
height = width = 100
BATCH_SIZE = 5
GAMMA = 0.99
INIT_SIZE = 5
TARGET_UPDATE = 1000
SAVE_POLICY = 100
EPISODE_LENGTH = 250
TRAIN_FREQ  = 10   # Number of samples to generate between trainings (Should be multiple of 10)
PRINT_FREQ  = 100  # Frequency of printing (Should be a multiple of 10)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_actions = 2
screen_height = screen_width = 100
channels = 2
policy_net = DRQN(device, channels, screen_height, screen_width, n_actions).to(device)
target_net = DRQN(device, channels, screen_height, screen_width, n_actions).to(device)
steps = 0
target_net.load_state_dict(policy_net.state_dict())
episode_buffer = EpisodeBuffer()
policy_net.train()
target_net.eval()
EPISODE_LENGTH = 128
update_counter = 0
optimizer = Adamax(policy_net.parameters(), lr=0.0001)

In [5]:
def optimize_model():
    
    loss = None
    
    global update_counter

    episode_batch, epiosde_length = episode_buffer.sample(EPISODE_LENGTH)

    update_counter += 1
    batch = Transition(*zip(*episode_batch))

    next_states = torch.cat(batch.next_state_vector)
    next_belief_map = torch.cat(batch.next_belief_map)

    belief_map_batch = torch.cat(batch.belief_map)
    state_vector_batch = torch.cat(batch.state_vector)

    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    hidden_policy = policy_net.init_hidden_state()
    hidden_target = target_net.init_hidden_state()

    policy_output, _ = policy_net(belief_map_batch, state_vector_batch, hidden_policy)
    target_output, _ = target_net(next_belief_map, next_states, hidden_target)

    state_action_values = policy_output.gather(1, action_batch)
    next_state_values = target_output.max(1)[0].detach()

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss().to(device)
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

    if update_counter % TARGET_UPDATE == 0:
        policy_file_path = f'./policy_weights_drqn.pt'
        target_file_path = f'./target_weights_drqn.pt'
        torch.save(policy_net.state_dict(), policy_file_path)
        torch.save(target_net.state_dict(), target_file_path)
        print('update target')
        target_net.load_state_dict(policy_net.state_dict())

    return loss

In [6]:
DT          = 0.5  # Time between wildfire updates            
DTI         = 0.1  # Time between aircraft decisions
fireEnv = ProbabilisticFireEnv(height, width)
dronesEnv = DronesEnv(height, width, DT, DTI) 
loss = None
i_episode = 1

observation = fireEnv.reset()
dronesEnv.reset(observation)

episode_memory_1 = []
episode_memory_2 = []

hidden_1 = policy_net.init_hidden_state()
hidden_2 = policy_net.init_hidden_state()

while True:
  # Initialize the environment and state
  #env.reset()
  for j in range(TRAIN_FREQ//int(2*DT/DTI)):

    observation = fireEnv.step()

    state_vector_1 = dronesEnv.drones[0].state
    map_1 = dronesEnv.drones[0].observation
    state_vector_1 = torch.tensor(state_vector_1, device=device, dtype=torch.float)
    map_1 = torch.tensor(map_1, device=device, dtype=torch.float)

    state_vector_2 = dronesEnv.drones[1].state
    map_2 = dronesEnv.drones[1].observation
    state_vector_2 = torch.tensor(state_vector_2, device=device, dtype=torch.float)
    map_2 = torch.tensor(map_2, device=device, dtype=torch.float)

    for i in range(int(DT/DTI)):

      action1, hidden_1 = policy_net.select_action(map_1, state_vector_1, steps, hidden=hidden_1)
      action2, hidden_2 = policy_net.select_action(map_2, state_vector_2, steps, hidden=hidden_2)
      steps += 2
      reward_1, reward_2 = dronesEnv.step([action1.item(), action2.item()], observation)

      next_state_vector_1 = dronesEnv.drones[0].state
      next_map_1 = dronesEnv.drones[0].observation

      next_state_vector_1 = torch.tensor(next_state_vector_1, device=device, dtype=torch.float)
      next_map_1 = torch.tensor(next_map_1, device=device, dtype=torch.float)

      next_state_vector_2 = dronesEnv.drones[1].state
      next_map_2 = dronesEnv.drones[1].observation

      next_state_vector_2 = torch.tensor(next_state_vector_2, device=device, dtype=torch.float)
      next_map_2 = torch.tensor(next_map_2, device=device, dtype=torch.float)

      reward_1 = torch.tensor([reward_1], device=device)
      reward_2 = torch.tensor([reward_2], device=device)  

      episode_memory_1.append(Transition(map_1, state_vector_1, action1, next_map_1, next_state_vector_1, reward_1))
      episode_memory_2.append(Transition(map_2, state_vector_2, action2, next_map_2, next_state_vector_2, reward_2))

      state_vector_1 = next_state_vector_1
      state_vector_2 = next_state_vector_2

      map_1 = next_map_1
      map_2 = next_map_2

    if i_episode>=INIT_SIZE:
      loss = optimize_model()

    if not fireEnv.fire_in_range(6):
      observation = fireEnv.reset()
      dronesEnv.reset(observation)
      episode_buffer.push(episode_memory_1.copy())
      episode_buffer.push(episode_memory_2.copy())
      episode_memory_1 = []
      episode_memory_2 = []
      hidden_1 = policy_net.init_hidden_state()
      hidden_2 = policy_net.init_hidden_state()
      i_episode +=1
      
      if (i_episode+1) % 5 == 0:
        print(f'{i_episode+1} episodes completed')
        print(f'loss {loss}')
        print(f'steps done {steps}')



5 episodes completed
loss None
steps done 2070
10 episodes completed
loss 0.0020916499197483063
steps done 6010
15 episodes completed
loss 0.17738160490989685
steps done 10290
update target
20 episodes completed
loss 1.1363341808319092
steps done 14060
25 episodes completed
loss 0.00010166229185415432
steps done 17820
30 episodes completed
loss 0.18423466384410858
steps done 21220
update target
35 episodes completed
loss 0.0004083088133484125
steps done 25030
40 episodes completed
loss 1.2048654556274414
steps done 28910
update target
45 episodes completed
loss 2.273698091506958
steps done 33010
50 episodes completed
loss 1.174837350845337
steps done 35880
55 episodes completed
loss 0.00740866269916296
steps done 39610
update target
60 episodes completed
loss 0.00013605505228042603
steps done 43600
65 episodes completed
loss 0.002227463060989976
steps done 47050
70 episodes completed
loss 0.7172749042510986
steps done 51410
update target
75 episodes completed
loss 0.017545830458402634


: 