In [1]:
!pip install svgpath2mpl
!pip install gym



In [None]:
from collections import namedtuple
import torch
from networks.ppo_net import PPONet
from torch.distributions import Categorical
import torch.nn.functional as F
import numpy as np
from environments.shared_wildfire_gym import SharedWildFireGym

In [None]:
# --- Network input/output dimensions (all four are passed to PPONet) ---
n_actions = 4                  # one logit per branch of the action decoding in the rollout loop
height, width = 100, 100
channels = 2

# --- Return / advantage estimation ---
GAMMA = 0.95                   # multiplies the next-state value in the advantage recursion
TAU = 1.0                      # extra decay applied to the running advantage each step

# --- PPO update ---
EPSILON = 0.1                  # probability ratio is clamped to [1 - EPSILON, 1 + EPSILON]
BETA = 0.01                    # weight of the entropy term in the total loss
EPOCHS = 10                    # passes over the collected episode per update
BATCH_SIZE = 64                # minibatch size used when splitting the shuffled indices

# --- Not referenced by the visible cells; kept because other code may read them ---
EPISODES_PER_BATCH = 1
TRAIN_FREQ = 10
SAVE_FREQ = 20
LAMDA = 0.95
UPDATES = 10000
NUM_PROCESSES = 1

In [None]:
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the actor-critic network and move its parameters to the chosen device.
model = PPONet(device, channels, height, width, n_actions)
model = model.to(device)

# A single Adam optimizer updates every parameter of the network.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# One environment step as stored during a rollout.
Transition = namedtuple(
    'Transition',
    ['belief_map', 'state_vector', 'action', 'reward', 'value', 'log_policy'],
)


class ReplayMemory:
  """Append-only list of Transition records collected during a rollout."""

  def __init__(self):
    self._memory = list()

  def push(self, *args):
    """Wrap the positional fields in a Transition and append it to the buffer."""
    transition = Transition(*args)
    self._memory.append(transition)

  @property
  def get_batch(self):
    """The stored transitions, in insertion order (the internal list itself, not a copy)."""
    return self._memory

  def __len__(self):
    """Number of transitions currently stored."""
    return len(self._memory)

  def clear(self):
    """Empty the buffer in place."""
    del self._memory[:]

In [None]:
# Bookkeeping for the training run.
loss, done = None, False
i_episode, steps = 1, 0

SAVE_MODEL = 10   # the training loop writes a checkpoint every SAVE_MODEL episodes
N_DRONES = 2      # not referenced by the visible cells

In [None]:
# PPO training cell: roll out one episode with the current policy, then run
# EPOCHS passes of clipped-surrogate updates over that episode. The loop has
# no exit condition and runs until the kernel is interrupted.

# Added to the advantage std so normalisation never divides by zero.
ADVANTAGE_EPS = 1e-8

# Sampled action index -> command list handed to the environment.
# Any index not listed falls back to (1, 1).
ACTION_VECTORS = {0: (0, 0), 1: (0, 1), 2: (1, 0)}


def _compute_gae(values, rewards, bootstrap_value, gamma, tau):
    """Backward advantage recursion over one episode.

    Args:
        values: 1-D tensor of detached value estimates, one per step.
        rewards: sequence of scalar rewards, same length as ``values``.
        bootstrap_value: value estimate used as the successor of the last step.
        gamma: discount applied to the successor value.
        tau: additional decay applied to the running advantage.

    Returns:
        ``(advantages, returns)`` as 1-D tensors, where
        ``returns = advantages + values``.

    Raises:
        RuntimeError: from ``torch.stack`` if the episode is empty.
    """
    advantages = [None] * len(rewards)
    gae = 0.0
    next_value = bootstrap_value
    for t in range(len(rewards) - 1, -1, -1):
        # float() keeps the arithmetic in torch even if the env returns numpy scalars.
        delta = float(rewards[t]) + gamma * next_value - values[t]
        gae = delta + gamma * tau * gae
        advantages[t] = gae
        next_value = values[t]
    advantages = torch.stack(advantages)
    return advantages, advantages + values


loss = None
i_episode = 1
done = False
wildFireGym = SharedWildFireGym()
observation = wildFireGym.reset()

memory = ReplayMemory()
belief_map = None
state_vector = None
total_reward = 0

while True:
    belief_map = torch.tensor(observation['belief_map'], device=device, dtype=torch.float)
    state_vector = torch.tensor(observation['state_vector'], device=device, dtype=torch.float)

    # Rollout outputs are only ever used detached, so skip building autograd
    # graphs that would otherwise be kept alive by the memory for the whole episode.
    with torch.no_grad():
        logits, value = model(belief_map, state_vector)
        policy = F.softmax(logits, dim=1)
        old_m = Categorical(policy)
        action = old_m.sample()
        old_log_policy = old_m.log_prob(action)

    action_item = action.item()
    # Fresh list each step so the environment cannot mutate the lookup table.
    action_vector = list(ACTION_VECTORS.get(action_item, (1, 1)))

    next_observation, rewards, done, _ = wildFireGym.step(action_vector)
    reward = sum(rewards)
    memory.push(belief_map, state_vector, action, reward, value.squeeze(), old_log_policy)
    total_reward += reward
    observation = next_observation

    if done:
        print("Episode: {}. Total reward: {}".format(i_episode, total_reward))
        total_reward = 0

        # NOTE(review): this bootstraps from the value of the final observation
        # even though `done` is set. If `done` means true termination (not a
        # time-limit truncation) the bootstrap should be zero -- confirm
        # against the environment.
        with torch.no_grad():
            _, next_value = model(
                torch.tensor(next_observation['belief_map'], device=device, dtype=torch.float),
                torch.tensor(next_observation['state_vector'], device=device, dtype=torch.float),
            )
        next_value = next_value.squeeze()

        # Transpose the list of transitions into per-field tuples.
        batch = Transition(*zip(*memory.get_batch))
        old_log_policies_batch = torch.cat(batch.log_policy).detach()
        actions_batch = torch.cat(batch.action)
        value_batch = torch.stack(batch.value).detach()
        belief_map_batch = torch.cat(batch.belief_map)
        state_vector_batch = torch.cat(batch.state_vector)
        reward_batch = batch.reward

        advantages, R = _compute_gae(value_batch, reward_batch, next_value, GAMMA, TAU)
        R = R.detach()

        # Normalise advantages. The std of a single element is NaN and the std
        # of identical elements is 0, so guard both cases.
        advantages = advantages - advantages.mean()
        if advantages.numel() > 1:
            advantages = advantages / (advantages.std() + ADVANTAGE_EPS)

        # Indices are shuffled once; every epoch reuses the same minibatch split.
        indices = torch.randperm(len(memory))
        batch_indices = torch.split(indices, BATCH_SIZE)
        for e_i in range(EPOCHS):
            for batch_indice in batch_indices:
                logits, values = model(belief_map_batch[batch_indice], state_vector_batch[batch_indice])
                new_policy = F.softmax(logits, dim=1)
                new_m = Categorical(new_policy)
                new_log_policy = new_m.log_prob(actions_batch[batch_indice])
                ratio = torch.exp(new_log_policy - old_log_policies_batch[batch_indice])

                # Clipped surrogate objective (negated because the optimizer minimises).
                minibatch_adv = advantages[batch_indice]
                actor_loss = -torch.mean(
                    torch.min(
                        ratio * minibatch_adv,
                        torch.clamp(ratio, 1.0 - EPSILON, 1.0 + EPSILON) * minibatch_adv,
                    )
                )

                # reshape(-1) keeps a 1-D shape even for a one-sample minibatch,
                # where squeeze() would collapse to 0-dim and broadcast.
                critic_loss = F.smooth_l1_loss(values.reshape(-1), R[batch_indice])

                entropy_loss = torch.mean(new_m.entropy())

                total_loss = actor_loss + critic_loss - BETA * entropy_loss
                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

        # Log the last minibatch loss as a plain float rather than a tensor repr.
        print("Episode: {}. Total loss: {}".format(i_episode, total_loss.item()))
        memory.clear()

        observation = wildFireGym.reset()

        if i_episode % SAVE_MODEL == 0:
            file_path = './mappo_weights.pt'
            torch.save(model.state_dict(), file_path)
            print('saved')

        i_episode += 1

RuntimeError: mat1 dim 1 must match mat2 dim 0