In [8]:
!pip install svgpath2mpl



In [9]:
from collections import namedtuple, deque
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adamax
import random
import math 
from svgpath2mpl import parse_path
import matplotlib.pyplot as plt
import matplotlib
from scipy.ndimage import rotate, shift
from matplotlib.animation import FuncAnimation
#from probabilistic_fire_env import ProbabilisticFireEnv
#from drone_env import DronesEnv
from env.wildfire_gym import WildFireGym
from replay_memory import ReplayMemory, Transition
from networks.dqn import DQN

In [10]:
# --- Grid dimensions ---------------------------------------------------------
height = 100
width = 100

# --- Q-learning update -------------------------------------------------------
BATCH_SIZE = 128       # Transitions sampled from replay memory per optimisation step
GAMMA = 0.99           # Discount applied to the bootstrapped next-state value
TARGET_UPDATE = 1000   # Optimisation steps between target-network syncs

# --- Epsilon-greedy exploration ----------------------------------------------
# Epsilon decays exponentially from EPS_START towards EPS_END with time
# constant EPS_DECAY (measured in steps).
EPS_START = 0.9
EPS_END = 0.1
EPS_DECAY = 200000

# --- Schedule ----------------------------------------------------------------
INIT_SIZE = 20000      # Step count that must be reached before optimisation starts
SAVE_POLICY = 100      # NOTE(review): not referenced by the cells visible here
EPISODE_LENGTH = 250   # NOTE(review): not referenced by the cells visible here
TRAIN_FREQ = 10        # Number of samples to generate between trainings (Should be multiple of 10)
PRINT_FREQ = 100       # Frequency of printing (Should be a multiple of 10)

In [11]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Network input/output configuration passed to the DQN constructor.
n_actions = 2
screen_height = screen_width = 100
channels = 2

# Two identically configured networks: the policy net is trained, the target
# net supplies the bootstrapped values and is overwritten periodically.
policy_net, target_net = (
    DQN(device, channels, screen_height, screen_width, n_actions).to(device)
    for _ in range(2)
)

steps = 0
update_counter = 0
policy_file_path = './policy_weights.pt'
target_file_path = './target_weights.pt'

# To resume from a checkpoint, uncomment the loads below.
#policy_net.load_state_dict(torch.load(policy_file_path))
#target_net.load_state_dict(torch.load('target_weights.pt'))

# Start with the target network as an exact copy of the policy network.
target_net.load_state_dict(policy_net.state_dict())

# Replay buffer (70000 is presumably its capacity — confirm against ReplayMemory).
memory = ReplayMemory(70000)
#memory.load()

policy_net.train()
target_net.eval()
optimizer = Adamax(policy_net.parameters(), lr=0.0001)


In [12]:
def select_action(belief_map, state_vector, steps):
  """Pick an action epsilon-greedily from ``policy_net``.

  The exploration probability decays exponentially with ``steps`` from
  ``EPS_START`` towards ``EPS_END`` with time constant ``EPS_DECAY``.

  Args:
    belief_map: first input forwarded to ``policy_net``.
    state_vector: second input forwarded to ``policy_net``.
    steps: step counter driving the epsilon decay.

  Returns:
    A ``(1, 1)`` ``torch.long`` tensor holding the chosen action index.
  """
  eps_threshold = EPS_END + (EPS_START - EPS_END) * \
    math.exp(-1. * steps / EPS_DECAY)

  if random.random() > eps_threshold:
    # Exploit: index of the largest Q-value; no gradients needed for acting.
    with torch.no_grad():
      return policy_net(belief_map, state_vector).max(1)[1].view(1, 1)

  # Explore: draw uniformly over the same action count the networks were
  # built with, instead of a hard-coded 2.
  return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)

In [13]:
def optimize_model():
    """Run one DQN optimisation step on a batch drawn from replay memory.

    Samples ``BATCH_SIZE`` transitions from ``memory`` and regresses
    ``policy_net``'s Q-value for the taken action towards
    ``reward + GAMMA * max_a target_net(next_state)[a]`` using a Huber
    (smooth L1) loss. Gradients are clipped element-wise to ``[-1, 1]``
    before the ``optimizer`` step. Every ``TARGET_UPDATE`` calls both networks
    are checkpointed to disk and ``target_net`` is overwritten with the
    weights of ``policy_net``.

    Every sampled next state is bootstrapped from; no terminal-state mask is
    applied (the stored transitions carry no ``done`` flag).

    Returns:
        The scalar loss tensor computed for this batch.
    """
    global update_counter
    update_counter += 1

    # Turn a list of Transition tuples into one Transition of per-field tuples.
    batch = Transition(*zip(*memory.sample(BATCH_SIZE)))

    belief_map_batch = torch.cat(batch.belief_map)
    state_vector_batch = torch.cat(batch.state_vector)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    next_belief_map = torch.cat(batch.next_belief_map)
    next_states = torch.cat(batch.next_state_vector)

    # Q(s, a) for the action that was actually taken in each transition.
    state_action_values = policy_net(belief_map_batch, state_vector_batch).gather(1, action_batch)

    # The target network only supplies regression targets, so skip building
    # an autograd graph for it altogether.
    with torch.no_grad():
        next_state_values = target_net(next_belief_map, next_states).max(1)[0]

    # Cast rewards to the Q-value dtype so that e.g. float64 rewards coming
    # from the environment cannot produce a mixed-dtype loss.
    expected_state_action_values = (next_state_values * GAMMA) \
        + reward_batch.to(next_state_values.dtype)

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-1, 1]; unlike touching param.grad
    # directly, this skips parameters that received no gradient.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

    if update_counter % TARGET_UPDATE == 0:
        # Checkpoint both networks, then sync the target with the policy.
        torch.save(policy_net.state_dict(), './policy_weights2.pt')
        torch.save(target_net.state_dict(), './target_weights2.pt')
        print('update target')
        target_net.load_state_dict(policy_net.state_dict())

    return loss

In [14]:
def _agent_tensors(observation, agent_index):
  """Return ``(belief_map, state_vector)`` float tensors for one agent.

  Reads the ``'belief_map'`` and ``'state_vector'`` entries of
  ``observation[agent_index]`` and creates float tensors from them on
  ``device``.
  """
  agent_observation = observation[agent_index]
  belief_map = torch.tensor(agent_observation['belief_map'], device=device, dtype=torch.float)
  state_vector = torch.tensor(agent_observation['state_vector'], device=device, dtype=torch.float)
  return belief_map, state_vector


loss = None
i_episode = 1

wildFireGym = WildFireGym()
observation = wildFireGym.reset()

# Global count of environment steps taken so far. It drives both the epsilon
# decay in select_action and the INIT_SIZE warm-up below, so it must keep
# growing across episodes.
steps = 0

map_1, state_vector_1 = _agent_tensors(observation, 0)
map_2, state_vector_2 = _agent_tensors(observation, 1)

# Runs until the kernel is interrupted; there is no stopping condition.
while True:

  action1 = select_action(map_1, state_vector_1, steps)
  action2 = select_action(map_2, state_vector_2, steps)

  # NOTE(review): the recorded run of this cell failed on this call with
  # "TypeError: step() missing 1 required positional argument: 'action_n'".
  # WildFireGym.step's signature is not visible from this notebook — confirm
  # what it expects before relying on this call.
  next_observation, rewards, done, _ = wildFireGym.step([action1, action2])
  steps += 1

  # Explicit float dtype so rewards always match the dtype of the Q-values.
  reward_1 = torch.tensor([rewards[0]], device=device, dtype=torch.float)
  reward_2 = torch.tensor([rewards[1]], device=device, dtype=torch.float)

  next_map_1, next_state_vector_1 = _agent_tensors(next_observation, 0)
  next_map_2, next_state_vector_2 = _agent_tensors(next_observation, 1)

  # Both agents feed the same replay memory (and share policy_net).
  memory.push(map_1, state_vector_1, action1, next_map_1, next_state_vector_1, reward_1)
  memory.push(map_2, state_vector_2, action2, next_map_2, next_state_vector_2, reward_2)

  observation = next_observation

  map_1, state_vector_1 = next_map_1, next_state_vector_1
  map_2, state_vector_2 = next_map_2, next_state_vector_2

  if done:
    print(f'episode {i_episode} completed')
    i_episode += 1
    observation = wildFireGym.reset()
    # `steps` is deliberately NOT reset here: resetting it every episode
    # would restart the epsilon schedule and keep `steps` below INIT_SIZE
    # forever, so the model would never be optimised.

    map_1, state_vector_1 = _agent_tensors(observation, 0)
    map_2, state_vector_2 = _agent_tensors(observation, 1)

  # Only start optimising once the warm-up number of steps has been collected.
  if steps >= INIT_SIZE:
    loss = optimize_model()

TypeError: step() missing 1 required positional argument: 'action_n'