In [None]:
import torch
import torch.nn as nn
from torch.distributions import Categorical

import numpy as np

import gym
import pandemic_simulator as ps
import pickle

# Run on the GPU whenever PyTorch can see one; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class RolloutBuffer:
    """Per-rollout storage for on-policy PPO training.

    ``ActorCritic.act`` appends to ``states``, ``actions`` and ``log_probs``;
    the training loop appends to ``rewards`` and ``terminals``.
    """

    def __init__(self):
        self.actions = []
        self.states = []
        # Canonical name for the log-probabilities of the sampled actions.
        self.log_probs = []
        self.rewards = []
        self.terminals = []

    # The notebook refers to the log-prob list under three spellings
    # (``log_prob``, ``log_probs``, ``log_p``). These aliases make all of them
    # resolve to the same list, so no caller hits an AttributeError.
    @property
    def log_prob(self):
        return self.log_probs

    @log_prob.setter
    def log_prob(self, value):
        self.log_probs = value

    @property
    def log_p(self):
        return self.log_probs

    @log_p.setter
    def log_p(self, value):
        self.log_probs = value

    def clear(self):
        """Empty every list in place, so existing references to them stay valid."""
        del self.actions[:]
        del self.states[:]
        del self.log_probs[:]
        del self.rewards[:]
        del self.terminals[:]

In [None]:
class ActorCritic(nn.Module):
    """Discrete-action actor-critic with two independent MLPs.

    The actor maps a state vector to softmax probabilities over
    ``num_actions`` actions; the critic maps it to a single scalar value.
    """

    def __init__(self, num_states, num_actions):
        super(ActorCritic, self).__init__()
        self.num_actions = num_actions

        # actor: state -> action probabilities
        self.actor = nn.Sequential(
            nn.Linear(num_states, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, num_actions),
            nn.Softmax(dim=-1),
        )

        # critic: state -> scalar value estimate
        self.critic = nn.Sequential(
            nn.Linear(num_states, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1),
        )

    def act(self, state, buffer):
        """Sample an action for ``state`` and record the step in ``buffer``.

        Args:
            state: state tensor fed to the actor (``PPO.select_action`` passes
                a ``(1, num_states)`` float tensor).
            buffer: object with ``states``, ``actions`` and ``log_probs`` lists.

        Returns:
            The sampled action index as a Python int.
        """
        # No gradients are needed while collecting a rollout. The log-prob is
        # recomputed with gradients in ``evaluate`` during the update.
        with torch.no_grad():
            action_probs = self.actor(state)
            dist = Categorical(action_probs)
            action = dist.sample()
            action_logprob = dist.log_prob(action)

        buffer.states.append(state)
        buffer.actions.append(action)
        buffer.log_probs.append(action_logprob)

        return action.item()

    def evaluate(self, state, action):
        """Score stored actions under the current parameters.

        Args:
            state: batch of states, shape ``(batch, num_states)``.
            action: tensor of action indices, one per state.

        Returns:
            ``(log_probs, state_values, entropy)``. ``state_values`` has only
            the critic's trailing size-1 dimension removed, so a batch of one
            stays 1-D.
        """
        action_probs = self.actor(state)
        dist = Categorical(action_probs)
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state).squeeze(-1)

        return action_logprobs, state_values, dist_entropy

In [None]:
class PPO:
    """Clipped-objective PPO agent built on ``ActorCritic``.

    ``policy_old`` collects rollouts via ``select_action``. ``policy`` is
    optimised in ``update``, and its weights are then copied into
    ``policy_old``.
    """

    def __init__(self, num_states, num_actions, lr, betas, gamma, K_epochs, eps_clip,
                 value_coef=0.5, entropy_coef=0.01):
        """
        Args:
            num_states: length of the state vector fed to the networks.
            num_actions: number of discrete actions.
            lr, betas: Adam learning rate and beta coefficients.
            gamma: discount factor for the Monte Carlo returns.
            K_epochs: optimisation passes over each rollout.
            eps_clip: clip range for the probability ratio.
            value_coef: weight of the critic's MSE loss (default 0.5).
            entropy_coef: weight of the entropy bonus (default 0.01).
        """
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.K_epochs = K_epochs
        self.eps_clip = eps_clip
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        # Optional default rollout; ``update`` falls back to it when no buffer
        # is passed explicitly.
        self.buffer = None

        # current policy (trained)
        self.policy = ActorCritic(num_states, num_actions).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        # old policy (collects rollouts)
        self.policy_old = ActorCritic(num_states, num_actions).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.loss_fcn = nn.MSELoss()

    @staticmethod
    def _device(module):
        """Device holding ``module``'s parameters."""
        return next(module.parameters()).device

    @staticmethod
    def _discounted_returns(rewards, terminals, gamma):
        """Per-step discounted return, restarting the sum at every terminal step."""
        returns = []
        running = 0.0
        for reward, terminal in zip(reversed(rewards), reversed(terminals)):
            if terminal:
                running = 0.0
            running = reward + gamma * running
            returns.append(running)
        # Built back-to-front with append (O(n)), then put in time order.
        returns.reverse()
        return returns

    def select_action(self, state, buffer):
        """Sample an action from ``policy_old`` and record it in ``buffer``.

        ``state`` can be anything ``np.asarray`` converts to numbers; it is
        flattened into a single float32 row of shape ``(1, -1)``.
        """
        state = np.asarray(state, dtype=np.float32).reshape(1, -1)
        state = torch.as_tensor(state, device=self._device(self.policy_old))
        return self.policy_old.act(state, buffer)

    def update(self, buffer=None):
        """Run ``K_epochs`` of clipped-PPO optimisation on one rollout.

        Args:
            buffer: rollout with ``states``, ``actions``, ``log_probs``,
                ``rewards`` and ``terminals`` lists. Defaults to
                ``self.buffer``.

        Raises:
            ValueError: if no buffer is passed and ``self.buffer`` is unset.
        """
        if buffer is None:
            buffer = getattr(self, "buffer", None)
        if buffer is None:
            raise ValueError("PPO.update needs a rollout buffer: pass one or set ppo.buffer")
        if not buffer.rewards:
            return  # nothing collected yet

        dev = self._device(self.policy)

        # Monte Carlo estimate of the returns
        returns = torch.tensor(
            self._discounted_returns(buffer.rewards, buffer.terminals, self.gamma),
            dtype=torch.float32, device=dev,
        )
        # Normalise; std() of a single element is NaN, so skip that case.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-5)

        # Stack the stored per-step tensors into batches. reshape keeps the
        # batch dimension even for a one-step rollout (squeeze would drop it).
        num_steps = len(buffer.states)
        old_states = torch.stack(buffer.states).reshape(num_steps, -1).detach().to(dev)
        old_actions = torch.stack(buffer.actions).reshape(-1).detach().to(dev)
        old_logprobs = torch.stack(buffer.log_probs).reshape(-1).detach().to(dev)

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            state_values = state_values.reshape(-1)

            # importance ratio pi_theta(a|s) / pi_theta_old(a|s)
            ratios = torch.exp(logprobs - old_logprobs)
            advantages = returns - state_values.detach()

            # clipped surrogate (actor) loss
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            actor_loss = -torch.min(surr1, surr2)
            # value regression (critic) loss and exploration bonus
            critic_loss = self.value_coef * self.loss_fcn(state_values, returns)
            entropy_bonus = self.entropy_coef * dist_entropy

            loss = (actor_loss + critic_loss - entropy_bonus).mean()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Copy new weights into old policy
        self.policy_old.load_state_dict(self.policy.state_dict())

In [None]:
def train(env, num_states, num_actions, max_episodes, max_timesteps, update_timestep, K_epochs, eps_clip,
          gamma, lr, betas, print_interval=10, viz_config=None, results_path='results.pickle'):
    """Train a PPO agent on ``env`` and pickle a visualisation of the final episode.

    Args:
        env: environment with ``reset()`` and ``step(action)`` returning
            ``(state, reward, terminal, info)``.
        num_states, num_actions: network sizes passed to ``PPO``.
        max_episodes: number of episodes to run.
        max_timesteps: per-episode step cap.
        update_timestep: run a PPO update every this many environment steps.
        K_epochs, eps_clip, gamma, lr, betas: PPO / Adam hyper-parameters.
        print_interval: episodes between progress lines.
        viz_config: config for ``ps.viz.GymViz.from_config``; defaults to the
            notebook-level ``sim_config``.
        results_path: file the final ``GymViz`` object is pickled to.

    Returns:
        The trained ``PPO`` agent.
    """
    if viz_config is None:
        viz_config = sim_config  # notebook-level config from the setup cell

    buffer = RolloutBuffer()
    ppo = PPO(num_states, num_actions, lr, betas, gamma, K_epochs, eps_clip)

    running_reward, total_length, time_step = 0.0, 0, 0
    viz = None

    # training loop
    for episode in range(1, max_episodes + 1):
        # A fresh visualiser each episode, so after the loop `viz` holds only
        # the last episode's trajectory (the one plotted and pickled below).
        viz = ps.viz.GymViz.from_config(sim_config=viz_config)
        state = env.reset()
        episode_length = 0

        for t in range(max_timesteps):
            time_step += 1
            episode_length = t + 1

            # Run old policy.
            # NOTE(review): select_action converts `state` with numpy and
            # flattens it; confirm the env's observation yields `num_states`
            # values.
            action = ppo.select_action(state, buffer)

            state, reward, terminal, _ = env.step(action)

            buffer.rewards.append(reward)
            # Hitting the step cap also ends the episode (the next one starts
            # from env.reset()). Mark it as a boundary so PPO.update's
            # discounted returns do not leak across episodes.
            buffer.terminals.append(terminal or t == max_timesteps - 1)
            viz.record((state, reward))

            if time_step % update_timestep == 0:
                ppo.update(buffer)
                buffer.clear()
                time_step = 0

            running_reward += reward

            if terminal:
                break

        total_length += episode_length

        if episode % print_interval == 0:
            avg_length = total_length / print_interval
            avg_reward = running_reward / print_interval
            print('Episode {} \t Avg length: {:.1f} \t Avg reward: {:.2f}'.format(episode, avg_length, avg_reward))
            running_reward, total_length = 0.0, 0

    if viz is not None:
        viz.plot()
        # NOTE: pickle is fine for writing; only ever load this file if you trust its source.
        with open(results_path, 'wb') as handle:
            pickle.dump(viz, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return ppo

In [None]:
# Seed every RNG the run draws from, so training is reproducible.
SEED = 10
torch.manual_seed(SEED)
np.random.seed(SEED)
ps.init_globals(0)

# select a simulator config
sim_config = ps.sh.small_town_config
# make environment
env = ps.env.PandemicGymEnv.from_config(sim_config, pandemic_regulations=ps.sh.austin_regulations)

# Notebook-level visualiser for interactive use; `train` creates its own GymViz objects.
viz = ps.viz.GymViz.from_config(sim_config=sim_config)

# NOTE(review): these sizes are hard-coded. The derivation from
# env.observation_space.shape[0] / env.action_space.n was disabled in the
# original; confirm they match the environment.
num_states = 12
num_actions = 5

# HYPER PARAMS
max_episodes = 100000
max_timesteps = 1500     # per-episode step cap
update_timestep = 4000   # environment steps between PPO updates
K_epochs = 80            # optimisation epochs per update
eps_clip = 0.2           # PPO ratio clip range
gamma = 0.99             # discount factor
lr = 1e-3                # Adam learning rate
betas = (0.9, 0.99)      # Adam beta coefficients

ppo_agent = train(env, num_states, num_actions, max_episodes, max_timesteps, update_timestep, K_epochs, eps_clip,
                  gamma, lr, betas=betas, print_interval=10)