In [8]:
!pip install svgpath2mpl
!pip install gym

[0m

In [9]:
from collections import namedtuple
import torch
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np
from environments.wildfire_gym import WildFireGym

In [10]:

import random

# --- Environment / network shape ---
n_actions = 2          # discrete actions per drone (actor output size)
height = width = 100   # belief-map height/width passed to PPONet
channels = 2           # belief-map channels passed to PPONet

# --- Schedule ---
# NOTE(review): EPISODES_PER_BATCH, TRAIN_FREQ, SAVE_FREQ, UPDATES and NUM_PROCESSES are not
# referenced by the training loop in this notebook (checkpointing uses SAVE_MODEL, set later).
EPISODES_PER_BATCH = 1
TRAIN_FREQ = 10
SAVE_FREQ = 20
UPDATES = 10000
NUM_PROCESSES = 1

# --- PPO hyperparameters ---
GAMMA = 0.95       # discount factor
TAU = 1.0          # used as the GAE lambda in the return computation
LAMDA = 0.95       # NOTE(review): unused; the GAE recursion reads TAU, not LAMDA
EPSILON = 0.1      # PPO clip range: ratio is clamped to [1 - EPSILON, 1 + EPSILON]
BETA = 0.01        # entropy-bonus coefficient
EPOCHS = 10        # optimisation passes over each episode's collected data
BATCH_SIZE = 128   # minibatch size for the PPO update

# --- Reproducibility ---
# Seed every RNG the notebook itself touches so runs are repeatable.
# NOTE(review): WildFireGym may keep its own RNG; seed it too if it exposes a seeding API.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [11]:
# NOTE(review): PPONet is defined in a later cell of this notebook; on a fresh kernel
# (Restart & Run All) this cell raises NameError unless it is moved below that definition.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LEARNING_RATE = 1e-4  # Adam step size

# Keep the explicit .to(device) so every submodule ends up on `device`,
# regardless of where PPONet.__init__ places its layers.
model = PPONet(device, channels, height, width, n_actions).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [12]:
# One stored step: the fields the rollout loop pushes, in positional order.
Transition = namedtuple('Transition', ('belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'))


class ReplayMemory:
    """List-backed buffer of `Transition` records, kept in insertion order."""

    def __init__(self):
        self._transitions = []

    def __len__(self):
        return len(self._transitions)

    def push(self, *args):
        """Build a Transition from positional fields and append it to the buffer."""
        record = Transition(*args)
        self._transitions.append(record)

    @property
    def get_batch(self):
        """Every stored Transition, oldest first (the live list, not a copy)."""
        return self._transitions

    def clear(self):
        """Remove all stored transitions in place."""
        self._transitions.clear()

In [13]:
import torch
import torch.nn as nn
from networks.basedqn import BaseDQN
from torch.distributions import Categorical
import torch.nn.functional as F

class PPONet(BaseDQN):
  """Actor-critic network for PPO over a belief map plus a state vector.

  The belief map goes through a small conv stack and the state vector through an MLP;
  their features are concatenated and fed to a shared 256-unit layer that drives an
  actor head (action logits) and a critic head (state value).

  Args:
    _device: torch device the module is moved to once all layers exist.
    _channels: number of belief-map input channels.
    _height, _width: belief-map size, forwarded to BaseDQN (presumably used by
      ``_get_conv_out`` to size the conv output -- TODO confirm).
    _outputs: number of discrete actions (actor output size).
    state_dim: length of the state vector; defaults to 5, the previously hard-coded size.
  """

  def __init__(self, _device, _channels, _height, _width, _outputs, state_dim=5):

    super().__init__(_device, _channels, _height, _width, _outputs)

    # State-vector MLP: state_dim -> 50 features.
    self.fc1 = nn.Sequential(
      nn.Linear(state_dim, 50),
      nn.ReLU(),
      nn.Linear(50, 50),
      nn.ReLU(),
      nn.Linear(50, 50),
      nn.ReLU(),
    )

    # Belief-map encoder; input channels follow _channels instead of a hard-coded 2.
    self.conv = nn.Sequential(
      nn.Conv2d(_channels, 32, kernel_size=3),
      nn.ReLU(),
      nn.MaxPool2d(2, stride=2),
      nn.Conv2d(32, 32, kernel_size=3),
      nn.ReLU(),
      nn.Conv2d(32, 32, kernel_size=3),
      nn.ReLU(),
      nn.MaxPool2d(2, stride=2)
    )

    # NOTE(review): _get_conv_out comes from BaseDQN (not visible here); assumed to return
    # the flattened size of self.conv's output for one (_channels, _height, _width) map.
    conv_out_size = self._get_conv_out()

    # Shared trunk over [state features (50) | conv features (conv_out_size)].
    self.fc2 = nn.Sequential(
      nn.Linear(conv_out_size + 50, 256),
      nn.ReLU(),
    )

    self.actor = nn.Linear(256, _outputs)

    self.critic = nn.Linear(256, 1)
    self._initialize_weights()
    # Move only after every layer exists; calling .to() before creating them leaves them behind.
    self.to(device=_device)


  def forward(self, belief_map, state_vector):
    """Return (action_logits, state_value) for a batch of observations.

    belief_map: (B, C, H, W) float tensor; an unbatched (C, H, W) map gets a batch dim of 1.
    state_vector: (B, state_dim) float tensor; an unbatched (state_dim,) vector likewise.
    Returns logits of shape (B, _outputs) and values of shape (B, 1).
    """
    if belief_map.dim() == 3:
      belief_map = belief_map.unsqueeze(0)
    if state_vector.dim() == 1:
      state_vector = state_vector.unsqueeze(0)

    state_features = self.fc1(state_vector)
    conv_features = torch.flatten(self.conv(belief_map), 1)
    # fc2 is sized for conv_out_size + 50 inputs, so it must see the concatenation.
    # (Previously it received conv features alone, then self.fc3 -- never defined in
    # this class -- was applied to the result.)
    hidden = self.fc2(torch.cat((state_features, conv_features), dim=1))
    return self.actor(hidden), self.critic(hidden)

  def _initialize_weights(self):
    """Orthogonal init (ReLU gain) and zero bias for every Conv2d / Linear submodule."""
    gain = nn.init.calculate_gain('relu')
    for module in self.modules():
      if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal_(module.weight, gain)
        if module.bias is not None:
          nn.init.constant_(module.bias, 0)

In [14]:
# Run-state and schedule globals (the training cell re-initialises loss/i_episode/done itself).
SAVE_MODEL = 10  # checkpoint the weights every SAVE_MODEL episodes
N_DRONES = 2     # number of agents driven by the single shared `model`
loss, i_episode, steps, done = None, 1, 0, False

In [15]:
# Rollout + PPO training loop for N_DRONES agents sharing one policy network.
# Relies on earlier cells: model, optimizer, device, ReplayMemory, Transition,
# N_DRONES, SAVE_MODEL and the hyperparameters (GAMMA, TAU, EPSILON, BETA, EPOCHS, BATCH_SIZE).
# NOTE(review): `while True` never exits -- interrupt the kernel to stop training.


def _obs_to_tensors(obs):
    """Convert one drone's observation dict into (belief_map, state_vector) float tensors.

    A leading batch dim of 1 is added when the env returns unbatched arrays (assumed
    (C, H, W) and (F,) -- TODO confirm against WildFireGym), so that stored steps can
    later be concatenated along dim 0. torch.tensor copies, so in-place env updates
    cannot alter stored steps.
    """
    belief_map = torch.tensor(obs['belief_map'], device=device, dtype=torch.float)
    state_vector = torch.tensor(obs['state_vector'], device=device, dtype=torch.float)
    if belief_map.dim() == 3:
        belief_map = belief_map.unsqueeze(0)
    if state_vector.dim() == 1:
        state_vector = state_vector.unsqueeze(0)
    return belief_map, state_vector


@torch.no_grad()
def _act(belief_map, state_vector):
    """Sample one action and return (action, value, log_prob) without an autograd graph.

    Running the rollout under no_grad stops the buffered values / log-probs from keeping
    a whole episode's worth of activations alive until the update.
    """
    logits, value = model(belief_map, state_vector)
    dist = Categorical(logits=logits)
    action = dist.sample()
    return action, value.squeeze(), dist.log_prob(action)


def _gae_returns(values, rewards, last_value=0.0):
    """Return GAE targets R_t = A_t + V(s_t), with TAU acting as the GAE lambda.

    values, rewards: 1-D tensors of equal length, in time order.
    last_value: bootstrap V(s_T) after the final step (0 for a terminated episode).
    """
    gae = 0.0
    next_value = last_value
    returns = []
    for t in range(len(rewards) - 1, -1, -1):
        gae = gae * GAMMA * TAU + rewards[t] + GAMMA * next_value - values[t]
        next_value = values[t]
        returns.append(gae + values[t])
    returns.reverse()
    return torch.stack(returns)


def _ppo_update(drone_memory):
    """Run EPOCHS of clipped-PPO minibatch updates on one drone's episode; return the last loss."""
    batch = Transition(*zip(*drone_memory.get_batch))
    belief_map_batch = torch.cat(batch.belief_map)
    state_vector_batch = torch.cat(batch.state_vector)
    actions_batch = torch.cat(batch.action)
    old_log_policies_batch = torch.cat(batch.log_policy)
    value_batch = torch.stack(batch.value)
    reward_batch = torch.tensor(batch.reward, device=device, dtype=torch.float)

    # Updates only run once `done` is set, so the episode has ended and V(s_T) = 0.
    # (The old code bootstrapped from the value of the state *before* the last action.)
    # NOTE(review): if `done` can also mean a time-limit cut-off, bootstrap from
    # V(next_observation) instead of 0.
    returns = _gae_returns(value_batch, reward_batch, last_value=0.0)
    advantages = returns - value_batch
    if advantages.numel() > 1:  # std() of a single element is NaN
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    model.train()
    total_loss = None
    for _ in range(EPOCHS):
        # Reshuffle every epoch so minibatch composition differs between epochs.
        for idx in torch.split(torch.randperm(len(drone_memory)), BATCH_SIZE):
            logits, values = model(belief_map_batch[idx], state_vector_batch[idx])
            new_m = Categorical(logits=logits)
            ratio = torch.exp(new_m.log_prob(actions_batch[idx]) - old_log_policies_batch[idx])
            adv = advantages[idx]

            actor_loss = -torch.min(
                ratio * adv,
                torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) * adv,
            ).mean()
            # view(-1) keeps shape (n,) even for a 1-sample minibatch, matching returns[idx].
            critic_loss = F.smooth_l1_loss(values.view(-1), returns[idx])
            entropy_loss = new_m.entropy().mean()

            total_loss = actor_loss + critic_loss - BETA * entropy_loss
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
    return total_loss.item()


loss = None
i_episode = 1
done = False
wildFireGym = WildFireGym()
observation = wildFireGym.reset()

memory = [ReplayMemory() for _ in range(N_DRONES)]
reward_total = 0
while True:
    model.eval()
    step_inputs = [_obs_to_tensors(observation[d]) for d in range(N_DRONES)]
    step_outputs = [_act(belief_map, state_vector) for belief_map, state_vector in step_inputs]

    next_observation, rewards, done, _info = wildFireGym.step(
        [action.item() for action, _value, _log_policy in step_outputs])
    for d in range(N_DRONES):
        belief_map, state_vector = step_inputs[d]
        action, value, log_policy = step_outputs[d]
        memory[d].push(belief_map, state_vector, action, rewards[d], value, log_policy)
    reward_total += sum(rewards)
    observation = next_observation

    if done:
        print("Episode: {}. Total reward: {}".format(i_episode, reward_total))
        reward_total = 0
        if i_episode % SAVE_MODEL == 0:
            file_path = './ppo_weights.pt'
            torch.save(model.state_dict(), file_path)
            print('saved')

        for d in range(N_DRONES):
            loss = _ppo_update(memory[d])
            print("Episode: {}. Total loss: {}".format(i_episode, loss))
            memory[d].clear()

        i_episode += 1
        observation = wildFireGym.reset()

AttributeError: 'DronesEnv' object has no attribute 'step'