# In [55]:
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
import cv2 as cv
import matplotlib.pyplot as plt

# In [57]:
def linearInitialize(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialize a layer's weight and fill its bias with a constant.

    Args:
        layer: Module exposing a ``weight`` tensor and an optional ``bias``
            (e.g. ``nn.Linear`` or ``nn.Conv2d``). Modified in place.
        std: Gain forwarded to ``torch.nn.init.orthogonal_``.
        bias_const: Value written to every element of the bias.

    Returns:
        The same ``layer`` object, so the call can be nested directly inside
        an ``nn.Sequential(...)`` definition.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` have ``bias is None``; skip the fill
    # for them instead of failing inside ``constant_``.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    """Actor-critic network with a shared convolutional feature extractor.

    The extractor takes 4-channel image input and flattens to 64 * 7 * 7
    features before its linear layer, which matches e.g. 84x84 inputs
    (presumably stacked frames -- confirm against the environment wrappers).
    The actor and critic heads both read the resulting 512-dim feature vector.
    """

    def __init__(self, envs):
        """Build the feature extractor and the actor/critic heads.

        Args:
            envs: Object exposing ``single_action_space.n`` (the number of
                discrete actions), which sizes the actor head.
        """
        super().__init__()
        self.featureExtractor = nn.Sequential(
            linearInitialize(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            linearInitialize(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            linearInitialize(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            linearInitialize(nn.Linear(64 * 7 * 7, 512)),
            nn.ReLU(),
        )

        # A small gain keeps the initial logits close to zero, so the starting
        # policy is near-uniform over actions.
        self.actor = linearInitialize(
            nn.Linear(512, envs.single_action_space.n), std=0.01
        )
        self.critic = linearInitialize(nn.Linear(512, 1), std=1)

    def _features(self, x):
        """Scale raw pixel values by 1/255 and run the shared extractor."""
        return self.featureExtractor(x / 255)

    def get_value(self, x):
        """Return the critic's value estimate for observations ``x``.

        ``x`` is scaled by 1/255 here, exactly as in
        ``get_action_and_value``, so both entry points feed the network
        inputs in the same range.
        """
        return self.critic(self._features(x))

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on observations ``x``.

        Args:
            x: Batch of unscaled observations; divided by 255 internally.
            action: Optional actions to evaluate. When ``None``, actions are
                sampled from the current policy.

        Returns:
            Tuple ``(action, log_prob, entropy, value)`` where ``log_prob``
            and ``entropy`` come from the categorical policy distribution
            and ``value`` is the critic output.
        """
        # Run the convolutional stack once and share it between both heads.
        features = self._features(x)
        probs = Categorical(logits=self.actor(features))
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(features)