# Actor-Critic Agent(s) for Resource Allocation
- Test the agent on a Gym environment first
- https://pytorch.org/docs/stable/tensorboard.html

In [69]:
import random

import gym
import gym.spaces as spaces
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

## Actor-Critic
- https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html

In [70]:
# Width of every hidden layer in the actor and critic MLPs.
NUM_HIDDEN_UNITS: int = 32

In [71]:
class CategoricalWrapper(torch.distributions.Categorical):
    """Categorical distribution whose actions carry a trailing singleton dim.

    ``sample`` returns actions shaped ``(*sample_shape, *batch_shape, 1)``.
    ``log_probs`` accepts actions in that same trailing-1 layout and returns
    log-probabilities summed per leading row, shaped ``(actions.size(0), 1)``.
    """

    def log_probs(self, actions):
        """Return summed log-probabilities of ``actions``, shaped ``(N, 1)``.

        Args:
            actions: integer action tensor whose last dimension is the
                singleton added by :meth:`sample`.

        Returns:
            Tensor of shape ``(actions.size(0), 1)``.
        """
        # ``super()`` (not the bare builtin ``super``) is required to reach
        # ``Categorical.log_prob``; the trailing singleton dim is dropped first.
        per_action = super().log_prob(actions.squeeze(-1))
        # Flatten everything after the leading row dim, sum it, and restore a
        # trailing dim so the result lines up with value tensors of shape (N, 1).
        return per_action.view(actions.size(0), -1).sum(-1).unsqueeze(-1)

    def sample(self, sample_shape=torch.Size()):
        """Draw actions and append a trailing singleton dimension.

        Args:
            sample_shape: forwarded to ``Categorical.sample``; defaults to a
                single draw, matching the previous zero-argument behaviour.
        """
        return super().sample(sample_shape).unsqueeze(-1)


In [82]:
class ActorCritic(nn.Module):
    """Actor-critic model with separate actor (policy) and critic (value) MLPs.

    Each head is a ``Linear-Tanh-Linear-Tanh-Linear`` MLP over a flat input of
    size ``obs_shape[0]``. The actor emits ``action_space.n`` logits when
    ``discrete`` is true, otherwise ``action_space.shape[0]`` outputs. The
    critic emits one value per input row.

    NOTE(review): ``act`` and ``evaluate_action`` always wrap the actor output
    in a categorical distribution, so only the ``discrete=True`` path is
    meaningful for sampling at the moment.
    """

    def __init__(self, obs_shape, action_space, discrete=True, hidden_units=None):
        """Build the actor and critic networks.

        Args:
            obs_shape: observation shape; only ``obs_shape[0]`` is used as the
                network input width.
            action_space: Gym-style space; ``.n`` is read when ``discrete`` is
                true, otherwise ``.shape[0]``.
            discrete: whether the action space is discrete.
            hidden_units: width of each hidden layer; when ``None`` the
                module-level ``NUM_HIDDEN_UNITS`` is used.
        """
        super().__init__()

        self.obs_shape = obs_shape
        self.action_space = action_space

        self._is_discrete = discrete
        self._num_inputs = obs_shape[0]
        self._n_hidden = NUM_HIDDEN_UNITS if hidden_units is None else hidden_units
        self._num_outputs = action_space.n if discrete else action_space.shape[0]

        # Actor is built before critic so parameter initialisation consumes
        # the RNG in the same order as before.
        self.actor = self._mlp(self._num_outputs)
        self.critic = self._mlp(1)

    def _mlp(self, num_outputs):
        """Return a two-hidden-layer tanh MLP from the input width to ``num_outputs``."""
        return nn.Sequential(
            nn.Linear(self._num_inputs, self._n_hidden),
            nn.Tanh(),
            nn.Linear(self._n_hidden, self._n_hidden),
            nn.Tanh(),
            nn.Linear(self._n_hidden, num_outputs),
        )

    def forward(self, x):
        """Run both heads on ``x``.

        Returns:
            ``(actor_outputs, value)`` where ``value`` has a last dim of 1.
        """
        actions = self.actor(x)
        critique = self.critic(x)
        return actions, critique

    def act(self, x):
        """Sample an action for ``x``.

        Returns:
            ``(value, action, log_prob)`` where ``action`` comes from
            ``CategoricalWrapper.sample`` and ``log_prob`` from its
            ``log_probs``.
        """
        actions, value = self(x)

        dist = CategoricalWrapper(logits=actions)
        # ``dist`` is a local; there is no ``self.dist`` attribute.
        action = dist.sample()

        log_prob = dist.log_probs(action)
        return value, action, log_prob

    def evaluate_action(self, x, action):
        """Score previously taken ``action``s under the current policy.

        Returns:
            ``(value, log_prob)`` for the given observations and actions.
        """
        actions, value = self(x)
        dist = CategoricalWrapper(logits=actions)
        log_prob = dist.log_probs(action)
        return value, log_prob

    def get_value(self, x):
        """Return only the critic's value estimate for ``x``."""
        _, value = self(x)
        return value


In [83]:
# Adapted from Ikostrikov's A2C Rollout Storage
class RolloutStorage(object):
    """Fixed-size on-policy rollout buffer laid out as ``(time, process, ...)``.

    ``obs``, ``value_preds``, ``returns`` and ``masks`` hold ``num_steps + 1``
    time entries so the slot after the final step is available for
    bootstrapping. ``rewards``, ``actions`` and ``action_log_probs`` hold
    ``num_steps``.
    """

    def __init__(self, num_steps, obs_shape, action_space, num_processes=1):
        """Allocate zero-filled buffers (masks start at one).

        Args:
            num_steps: rollout length between updates.
            obs_shape: per-observation shape.
            action_space: space object; a class named ``'Discrete'`` gets a
                single ``long`` action slot, otherwise ``shape[0]`` float slots.
            num_processes: number of parallel environments.
        """
        self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape)

        self.rewards = torch.zeros(num_steps, num_processes, 1)
        self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)

        self.returns = torch.zeros(num_steps + 1, num_processes, 1)
        self.action_log_probs = torch.zeros(num_steps, num_processes, 1)

        is_discrete = action_space.__class__.__name__ == 'Discrete'
        action_shape = 1 if is_discrete else action_space.shape[0]

        self.actions = torch.zeros(num_steps, num_processes, action_shape)
        if is_discrete:
            # Discrete actions are indices, so store them as integers.
            self.actions = self.actions.long()

        self.masks = torch.ones(num_steps + 1, num_processes, 1)

        self.num_steps = num_steps
        self.step = 0

    def to(self, device):
        """Move every buffer to ``device`` in place; returns ``self`` for chaining."""
        self.obs = self.obs.to(device)

        self.rewards = self.rewards.to(device)
        self.value_preds = self.value_preds.to(device)
        self.returns = self.returns.to(device)

        self.action_log_probs = self.action_log_probs.to(device)
        self.actions = self.actions.to(device)
        self.masks = self.masks.to(device)
        return self

    def insert(self, obs, actions, action_log_probs, value_preds, rewards, masks, bad_masks=None):
        """Record one transition at the current step and advance (wrapping) the cursor.

        ``obs`` and ``masks`` describe the state *after* the step, so they are
        written to index ``step + 1``. ``bad_masks`` is accepted for
        compatibility with the upstream API but is not stored.
        """
        self.obs[self.step + 1].copy_(obs)

        self.actions[self.step].copy_(actions)
        self.action_log_probs[self.step].copy_(action_log_probs)
        self.value_preds[self.step].copy_(value_preds)

        self.rewards[self.step].copy_(rewards)
        self.masks[self.step + 1].copy_(masks)

        self.step = (self.step + 1) % self.num_steps

    def after_update(self):
        """Carry the final observation and mask over as the start of the next rollout."""
        self.obs[0].copy_(self.obs[-1])
        self.masks[0].copy_(self.masks[-1])

    def compute_returns(self, next_value, gamma):
        """Fill ``returns`` with discounted returns bootstrapped from ``next_value``.

        A zero in ``masks[step + 1]`` stops the future return from flowing back
        past ``step``.
        """
        self.returns[-1] = next_value

        for step in reversed(range(self.rewards.size(0))):
            self.returns[step] = (
                self.returns[step + 1] * gamma * self.masks[step + 1] + self.rewards[step]
            )


In [84]:
class Optimizer:
    """RMSprop A2C update for a model exposing ``evaluate_action(obs, actions)``."""

    def __init__(self, model, lr=7e-4, eps=1e-5, alpha=0.99, max_grad_norm=0.5):
        """
        Args:
            model: actor-critic ``nn.Module`` whose ``evaluate_action`` returns
                ``(values, action_log_probs)``.
            lr, eps, alpha: forwarded to ``torch.optim.RMSprop``.
            max_grad_norm: global gradient-norm clip applied before each step.
        """
        # The update step needs the model itself, not just its parameters.
        self.actor_critic = model
        self.max_grad_norm = max_grad_norm
        self.optimizer = optim.RMSprop(model.parameters(), lr, eps=eps, alpha=alpha)

    # adapted from Ikostrikov's A2C
    def update(self, rollouts):
        """Run one gradient step on a filled rollout buffer.

        ``rollouts.returns`` is expected to be already filled (e.g. by
        ``compute_returns``).

        Returns:
            ``(value_loss, action_loss)`` as Python floats.
        """
        obs_shape = rollouts.obs.size()[2:]
        action_shape = rollouts.actions.size()[-1]
        num_steps, num_processes, _ = rollouts.rewards.size()

        # Flatten time and process dims into one batch; the last obs slot is
        # the bootstrap state and has no associated action.
        values, action_log_probs = self.actor_critic.evaluate_action(
            rollouts.obs[:-1].view(-1, *obs_shape),
            rollouts.actions.view(-1, action_shape),
        )

        values = values.view(num_steps, num_processes, 1)
        action_log_probs = action_log_probs.view(num_steps, num_processes, 1)

        advantages = rollouts.returns[:-1] - values
        value_loss = advantages.pow(2).mean()

        # Advantages are detached so the policy loss does not train the critic.
        action_loss = -(advantages.detach() * action_log_probs).mean()
        self.optimizer.zero_grad()
        (value_loss + action_loss).backward()

        nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)

        self.optimizer.step()

        return value_loss.item(), action_loss.item()


In [85]:
# Smoke test: a discrete actor-critic with 5 actions and its optimizer.
# Only obs_shape[0] (3) sets the network input width.
a = ActorCritic(obs_shape=(3, 1), action_space=spaces.Discrete(5))
print(a)
o = Optimizer(a, lr=0.01)

ActorCritic(
  (actor): Sequential(
    (0): Linear(in_features=3, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=5, bias=True)
  )
  (critic): Sequential(
    (0): Linear(in_features=3, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=1, bias=True)
  )
)


## Gym Environment Test

## Environment

In [78]:
class SustainableEnvironment(gym.Env):
    """Toy land-allocation environment (work in progress).

    The state is a two-element integer list:
      * ``state[0]`` lies in ``[0, 2 * max_land_resources]``; ``step`` treats
        it as the population to feed.
      * ``state[1]`` is the remaining land, in ``[0, max_land_resources]``.

    Actions are ``Discrete(11)``. Action ``k`` uses ``k / 10`` of the
    remaining land.
    """

    def __init__(self, max_land_resources=100):
        self.action_space = spaces.Discrete(11)
        self.max_land_resources = max_land_resources
        self.high = np.array(
            [2 * self.max_land_resources, self.max_land_resources],
            dtype=np.int32,
        )
        self.low = np.array([0, 0], dtype=np.int32)
        self.observation_space = spaces.Box(self.low, self.high, dtype=np.int32)

    def reset(self):
        """Draw a random starting state and return it as an ``int32`` array.

        Uses the stdlib ``random`` module (imported at the top of the file), so
        results are not controlled by numpy/torch seeding.
        """
        self.state = [
            random.randint(2, 2 * self.max_land_resources),
            random.randint(0, self.max_land_resources),
        ]
        return np.array(self.state, dtype=np.int32)

    def step(self, action):
        """Consume ``action / 10`` of the remaining land.

        NOTE(review): unfinished stub. It updates ``self.state[1]`` but returns
        ``None`` rather than the Gym ``(obs, reward, done, info)`` tuple, so it
        cannot yet drive a training loop. ``population_left_to_feed`` is
        computed but unused; presumably it is meant to feed the reward. TODO.
        """
        percent_land_use = action / 10
        land_utilized = int(self.state[1] * percent_land_use)
        self.state[1] = self.state[1] - land_utilized
        population_left_to_feed = max(0, self.state[0] - land_utilized)  # noqa: F841 (TODO: reward)
        return None

    def render(self):
        """No-op: rendering is not implemented."""
        pass

In [80]:
# Build the environment with its default land budget and show one reset observation.
env = SustainableEnvironment(max_land_resources=100)
print(env.reset())

[112  90]
