In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import gym
import numpy
from gym import wrappers
import os
from typing import Any, Dict, List, Optional
import glob
# from base.rl.ppo import PPO

ModuleNotFoundError: No module named 'base.rl'

In [3]:
class GlobalPolicy(nn.Module):
    """Actor-critic policy that selects one cell of a G x G grid under a frontier mask.

    The actor and the critic are two separate convolutional stacks with the
    same layer layout (no shared weights). Both consume the 8-channel tensor
    produced by :meth:`_get_h12`. The actor ends in a sigmoid, so it emits one
    score in (0, 1) per grid cell; those scores are clamped, masked with
    ``inputs["frontier_mask"]`` and handed to ``FixedCategorical`` as
    (unnormalised) probabilities. The critic ends in a linear layer that
    produces a single value per sample.

    ``Flatten``, ``crop_map`` and ``FixedCategorical`` are expected to be
    defined elsewhere in the project.

    Args:
        G: Side length of the square grid the policy operates on.
        use_data_parallel: If true, wrap actor and critic in ``nn.DataParallel``.
        gpu_ids: Device ids for ``nn.DataParallel``. ``None`` or an empty
            sequence defers to ``nn.DataParallel``'s own defaults (all visible
            devices, output on the first one).
    """

    # Actor scores are clamped into [_PROB_EPS, 1 - _PROB_EPS] before masking.
    _PROB_EPS = 1e-4
    # Score assigned to every cell whose frontier mask entry is not 1. It is
    # tiny but non-zero, so such cells remain (barely) sampleable and their
    # log-probabilities stay finite.
    _MASKED_PROB = 1e-7

    def __init__(
        self,
        G: int = 240,
        use_data_parallel: bool = False,
        gpu_ids: Optional[List[int]] = None,
    ):
        super().__init__()

        self.G = G

        # NOTE: the child indices inside both nn.Sequential containers (and
        # the order in which parameters are created) are kept exactly as
        # before, so existing state_dict checkpoints still load.
        self.actor = nn.Sequential(  # (8, G, G)
            *self._conv_trunk(),  # (1, G, G)
            Flatten(),  # (G*G, ) per the original annotation
            nn.Sigmoid(),  # per-cell score in (0, 1); masked later by frontier_mask
        )

        self.critic = nn.Sequential(  # (8, G, G)
            *self._conv_trunk(),  # (1, G, G)
            Flatten(),
            nn.Linear(self.G * self.G, 1),
        )

        if use_data_parallel:
            # An empty / missing id list used to crash on gpu_ids[0]; passing
            # None instead lets nn.DataParallel pick its documented defaults.
            device_ids = list(gpu_ids) if gpu_ids else None
            output_device = device_ids[0] if device_ids else None
            self.actor = nn.DataParallel(
                self.actor, device_ids=device_ids, output_device=output_device,
            )
            self.critic = nn.DataParallel(
                self.critic, device_ids=device_ids, output_device=output_device,
            )

    @staticmethod
    def _conv_trunk():
        """Return a fresh list of the conv/BN/ReLU layers used by actor and critic.

        Every call creates new modules, so the two heads never share weights.
        All convolutions use "same" padding, so the spatial size is preserved.
        """
        return [
            nn.Conv2d(8, 8, 3, padding=1),  # (8, G, G)
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 4, 3, padding=1),  # (4, G, G)
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Conv2d(4, 4, 5, padding=2),  # (4, G, G)
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Conv2d(4, 2, 5, padding=2),  # (2, G, G)
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Conv2d(2, 1, 5, padding=2),  # (1, G, G)
        ]

    def forward(self, inputs):
        """Not supported; use :meth:`act`, :meth:`get_value` or :meth:`evaluate_actions`."""
        raise NotImplementedError

    def _get_h12(self, inputs):
        """Build the network input from the current map and pose.

        Concatenates, along the channel dimension, a ``crop_map`` crop of
        ``inputs["map_at_t"]`` taken at the first two pose components with an
        adaptive max-pooled (G x G) copy of the whole map.
        """
        pose = inputs["pose_in_map_at_t"]
        full_map = inputs["map_at_t"]

        # presumably a G x G window centred on the agent's (x, y) -- confirm
        # against crop_map's definition.
        local_view = crop_map(full_map, pose[:, :2], self.G)
        global_view = F.adaptive_max_pool2d(full_map, (self.G, self.G))

        # The conv stacks expect 8 input channels in total.
        return torch.cat([local_view, global_view], dim=1)

    @classmethod
    def _mask_action_probs(cls, probs, frontier_mask):
        """Clamp actor scores and suppress every cell outside the frontier.

        Args:
            probs: Sigmoid outputs of the actor.
            frontier_mask: Tensor broadcastable to ``probs``; entries equal to
                1 mark selectable cells.

        Returns:
            Tensor shaped like ``probs`` holding the clamped score where the
            mask is 1 and ``_MASKED_PROB`` everywhere else.
        """
        clamped = torch.clamp(probs, min=cls._PROB_EPS, max=1 - cls._PROB_EPS)
        return torch.where(
            frontier_mask == 1, clamped, torch.full_like(clamped, cls._MASKED_PROB)
        )

    def _frontier_dist(self, h_12, frontier_mask):
        """Run the actor and wrap its frontier-masked scores in a FixedCategorical."""
        masked_probs = self._mask_action_probs(self.actor(h_12), frontier_mask)
        # validate_args=False: the masked scores are not normalised to sum to 1.
        return FixedCategorical(probs=masked_probs, validate_args=False)

    def act(self, inputs, rnn_hxs, prev_actions, masks, deterministic=False):
        """Choose an action for each sample and report its value estimate.

        Note: inputs['pose_in_map_at_t'] must obey the following conventions:
              origin at top-left, downward Y and rightward X in the map coordinate system.

        Args:
            inputs: Dict providing ``"map_at_t"``, ``"pose_in_map_at_t"`` and
                ``"frontier_mask"``.
            rnn_hxs: Passed through unchanged.
            prev_actions: Unused.
            masks: Unused.
            deterministic: Take the distribution's mode instead of sampling.

        Returns:
            Tuple ``(value, action, action_log_probs, rnn_hxs)``.
        """
        h_12 = self._get_h12(inputs)
        dist = self._frontier_dist(h_12, inputs["frontier_mask"])
        value = self.critic(h_12)

        action = dist.mode() if deterministic else dist.sample()
        action_log_probs = dist.log_probs(action)

        return value, action, action_log_probs, rnn_hxs

    def get_value(self, inputs, rnn_hxs, prev_actions, masks):
        """Return the critic's value estimate; ``rnn_hxs``, ``prev_actions`` and ``masks`` are unused."""
        h_12 = self._get_h12(inputs)
        return self.critic(h_12)

    def evaluate_actions(self, inputs, rnn_hxs, prev_actions, masks, action):
        """Score previously taken actions under the current policy.

        Args:
            inputs: Dict providing ``"map_at_t"``, ``"pose_in_map_at_t"`` and
                ``"frontier_mask"``.
            rnn_hxs: Passed through unchanged.
            prev_actions: Unused.
            masks: Unused.
            action: Actions whose log-probabilities are requested.

        Returns:
            Tuple ``(value, action_log_probs, dist_entropy, rnn_hxs)`` where
            ``dist_entropy`` is the mean of the distribution's entropy.
        """
        h_12 = self._get_h12(inputs)
        dist = self._frontier_dist(h_12, inputs["frontier_mask"])
        value = self.critic(h_12)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean()

        return value, action_log_probs, dist_entropy, rnn_hxs