# Getting started with PPO and ProcGen

Here's a bit of code that should help you get started on your projects.

The cell below installs `procgen` and downloads a small `utils.py` script that contains some utility functions. You may want to inspect the file for more details.

In [1]:
!pip install procgen
!wget https://raw.githubusercontent.com/nicklashansen/ppo-procgen-utils/main/utils.py

Collecting procgen
[?25l  Downloading https://files.pythonhosted.org/packages/d6/34/0ae32b01ec623cd822752e567962cfa16ae9c6d6ba2208f3445c017a121b/procgen-0.10.4-cp36-cp36m-manylinux2010_x86_64.whl (39.9MB)
[K     |████████████████████████████████| 39.9MB 82kB/s 
Collecting gym3<1.0.0,>=0.3.3
[?25l  Downloading https://files.pythonhosted.org/packages/89/8c/83da801207f50acfd262041e7974f3b42a0e5edd410149d8a70fd4ad2e70/gym3-0.3.3-py3-none-any.whl (50kB)
[K     |████████████████████████████████| 51kB 9.0MB/s 
Collecting moderngl<6.0.0,>=5.5.4
[?25l  Downloading https://files.pythonhosted.org/packages/56/ab/5f72a1b7c5bdbb17160c85e8ba855d48925c74ff93c1e1027d5ad40bf33c/moderngl-5.6.2-cp36-cp36m-manylinux1_x86_64.whl (664kB)
[K     |████████████████████████████████| 665kB 56.1MB/s 
[?25hCollecting imageio<3.0.0,>=2.6.0
[?25l  Downloading https://files.pythonhosted.org/packages/6e/57/5d899fae74c1752f52869b613a8210a2480e1a69688e65df6cb26117d45d/imageio-2.9.0-py3-none-any.whl (3.3MB)
[K   

Hyperparameters. These values should be a good starting point. You can modify them later once you have a working implementation.

In [2]:
# Hyperparameters
total_steps = 10e6
num_envs = 32
num_levels = 100
num_steps = 256
num_epochs = 3
batch_size = 512
eps = .2
grad_eps = .5
value_coef = .5
entropy_coef = .01 #.01

#####################################
n_stack = 3 # =1 for default
#####################################

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *# make_env, Storage, orthogonal_init
from google.colab import files

###############################################################################  
def make_env(
	n_envs=32,
	env_name='starpilot',
	start_level=0,
	num_levels=100,
	use_backgrounds=True,
	normalize_obs=False,
	normalize_reward=True,
	seed=0,

	):
	"""Make environment for procgen experiments"""
	set_global_seeds(seed)
	set_global_log_levels(40)
	env = ProcgenEnv(
		num_envs=n_envs,
		env_name=env_name,
		start_level=start_level,
		num_levels=num_levels,
		distribution_mode='easy',
		use_backgrounds=use_backgrounds,
		restrict_themes=not use_backgrounds,
		render_mode='rgb_array',
		rand_seed=seed
	)
    
  
	env = VecExtractDictObs(env, "rgb")
	if n_stack >=2 :
		env = VecFrameStack(env, n_stack)
	env = VecNormalize(env, ob=normalize_obs, ret=normalize_reward)
	env = TransposeFrame(env)
	env = ScaledFloatFrame(env)
	env = TensorEnv(env)
	

	return env
###############################################################################


def xavier_uniform_init(module, gain=1.0):
    if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
        nn.init.xavier_uniform_(module.weight.data, gain)
        nn.init.constant_(module.bias.data, 0)
    return module



class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)



class ResidualBlock(nn.Module):
    def __init__(self,
                 in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out = nn.ReLU()(x)
        out = self.conv1(out)
        out = nn.ReLU()(out)
        out = self.conv2(out)
        return out + x

class ImpalaBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ImpalaBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.res1 = ResidualBlock(out_channels)
        self.res2 = ResidualBlock(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x)
        x = self.res1(x)
        x = self.res2(x)
        return x

class ImpalaModel(nn.Module):
    def __init__(self,
                 in_channels,
                 **kwargs):
        super(ImpalaModel, self).__init__()
        self.block1 = ImpalaBlock(in_channels=in_channels, out_channels=16)
        self.block2 = ImpalaBlock(in_channels=16, out_channels=32)
        self.block3 = ImpalaBlock(in_channels=32, out_channels=32)
        self.fc = nn.Linear(in_features=32 * 8 * 8, out_features=256)

        self.output_dim = 256
        self.apply(xavier_uniform_init)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = nn.ReLU()(x)
        x = Flatten()(x)
        x = self.fc(x)
        x = nn.ReLU()(x)
        return x


class Encoder(nn.Module):
  def __init__(self, in_channels, feature_dim):
    super().__init__()
    self.layers = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=8, stride=4), nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), nn.ReLU(),
        Flatten(),
        nn.Linear(in_features=1024, out_features=feature_dim), nn.ReLU()
    )
    self.apply(orthogonal_init)

  def forward(self, x):
    return self.layers(x)


class Policy(nn.Module):
  def __init__(self, encoder, feature_dim, num_actions):
    super().__init__()
    self.encoder = encoder
    self.policy = orthogonal_init(nn.Linear(feature_dim, num_actions), gain=.01)
    self.value = orthogonal_init(nn.Linear(feature_dim, 1), gain=1.)

  def act(self, x):
    with torch.no_grad():
      x = x.cuda().contiguous()
      dist, value = self.forward(x)
      action = dist.sample()
      log_prob = dist.log_prob(action)
    
    return action.cpu(), log_prob.cpu(), value.cpu()

  def forward(self, x):
    x = self.encoder(x)
    logits = self.policy(x)
    value = self.value(x).squeeze(1)
    dist = torch.distributions.Categorical(logits=logits)

    return dist, value

########################################################################################################################################

print("starting to make env")

# Define environmentbossfight
# check the utils.py file for info on arguments
env = make_env(num_envs, num_levels=num_levels, env_name='starpilot')
print('Observation space:', env.observation_space)
print('Action space:', env.action_space.n)

# Define network
#encoder = Encoder(3,512)
observation_shape = env.observation_space.shape
in_channels = observation_shape[0]
encoder = ImpalaModel(in_channels=in_channels)
policy = Policy(encoder, 256, 15)
policy.load_state_dict(torch.load("checkpoint_Final.pt"))
policy.cuda()
# Define optimizer
# these are reasonable values but probably not optimal
optimizer = torch.optim.Adam(policy.parameters(), lr=5e-4, eps=1e-5) #lr=5e-4 , eps=1e-5

# Define temporary storage
# we use this to collect transitions during each iteration
storage = Storage(
    env.observation_space.shape,
    num_steps,
    num_envs
)


# Run training
obs = env.reset()
step = 5005312
i = 10
print("NN setup, Training Starts")
while step < total_steps:

  # Use policy to collect data for num_steps steps
  policy.eval()
  for _ in range(num_steps):
    # Use policy
    action, log_prob, value = policy.act(obs)
    
    # Take step in environment
    next_obs, reward, done, info = env.step(action)

    # Store data
    storage.store(obs, action, reward, done, info, log_prob, value)
    
    # Update current observation
    obs = next_obs

  # Add the last observation to collected data
  _, _, value = policy.act(obs)
  storage.store_last(obs, value)

  # Compute return and advantage
  storage.compute_return_advantage()

  # Optimize policy
  policy.train()
  for epoch in range(num_epochs):

    # Iterate over batches of transitions
    generator = storage.get_generator(batch_size)
    for batch in generator:
      b_obs, b_action, b_log_prob, b_value, b_returns, b_advantage = batch

      # Get current policy outputs
      new_dist, new_value = policy(b_obs)
      new_log_prob = new_dist.log_prob(b_action)
      # log_prob
      # Clipped policy objective
      #print(str(log_prob.shape) + " " + str(b_log_prob.shape) + " " + str(new_log_prob.shape))
      ratio = torch.exp(new_log_prob - b_log_prob)
      
      clipped_ratio = ratio.clamp(min=1.0 - eps, max=1.0 + eps) 
      policy_reward = torch.min(ratio * b_advantage, clipped_ratio * b_advantage)
      #clip_fraction = (abs((ratio - 1.0)) > clip).to(torch.float).mean()
      pi_loss = -policy_reward.mean()

      # Clipped value function objective
      # clipped_value = new_value + (b_value - new_value).clamp(min=-eps,max=eps)
      # vf_loss=torch.max((b_value-b_returns)**2, (clipped_value-b_returns)**2)
      # value_loss = 0.5 * vf_loss.mean()

      clipped_value = b_value + (new_value - b_value).clamp(min=-eps,max=eps) 

      clipped_value = (new_value - b_value).clamp(min=-eps,max=eps)
      value_loss = 0.5 * torch.max(torch.pow(new_value - b_returns,2), torch.pow(clipped_value - b_returns, 2)).mean()

      # Entropy loss
      entropy_loss = new_dist.entropy().mean()

      # Backpropagate losses
      # loss = torch.mean(pi_loss+value_coef*value_loss+entropy_coef*entropy_loss) #
      loss = pi_loss + value_coef * value_loss - entropy_coef * entropy_loss
      loss.backward()

      # Clip gradients
      torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_eps)

      # Update policy
      optimizer.step()
      optimizer.zero_grad()

  # Update stats
  step += num_envs * num_steps
  print(f'Step: {step}\tMean reward: {storage.get_reward()}')

  if step//503808>=i:
    torch.save(policy.state_dict(), 'checkpoint'+str(i)+'.pt')
    files.download('checkpoint'+str(i)+'.pt')
    i=i+1

print('Completed training!')
torch.save(policy.state_dict(), 'checkpoint_Final_Final.pt')


files.download('checkpoint_Final_Final.pt')

starting to make env
Observation space: Box(0.0, 1.0, (9, 64, 64), float32)
Action space: 15
NN setup, Training Starts
Step: 5013504	Mean reward: 9.5
Step: 5021696	Mean reward: 10.125
Step: 5029888	Mean reward: 10.5
Step: 5038080	Mean reward: 10.84375


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 5046272	Mean reward: 11.25
Step: 5054464	Mean reward: 10.40625
Step: 5062656	Mean reward: 8.8125
Step: 5070848	Mean reward: 8.6875
Step: 5079040	Mean reward: 10.25
Step: 5087232	Mean reward: 9.5
Step: 5095424	Mean reward: 9.5625
Step: 5103616	Mean reward: 10.0
Step: 5111808	Mean reward: 10.53125
Step: 5120000	Mean reward: 9.59375
Step: 5128192	Mean reward: 9.46875
Step: 5136384	Mean reward: 9.90625
Step: 5144576	Mean reward: 9.03125
Step: 5152768	Mean reward: 10.09375
Step: 5160960	Mean reward: 9.78125
Step: 5169152	Mean reward: 10.46875
Step: 5177344	Mean reward: 10.03125
Step: 5185536	Mean reward: 9.21875
Step: 5193728	Mean reward: 10.0
Step: 5201920	Mean reward: 11.03125
Step: 5210112	Mean reward: 9.46875
Step: 5218304	Mean reward: 9.625
Step: 5226496	Mean reward: 9.0
Step: 5234688	Mean reward: 9.46875
Step: 5242880	Mean reward: 10.40625
Step: 5251072	Mean reward: 11.8125
Step: 5259264	Mean reward: 9.8125
Step: 5267456	Mean reward: 9.6875
Step: 5275648	Mean reward: 9.125
Step:

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 5554176	Mean reward: 9.40625
Step: 5562368	Mean reward: 10.21875
Step: 5570560	Mean reward: 10.40625
Step: 5578752	Mean reward: 9.875
Step: 5586944	Mean reward: 9.46875
Step: 5595136	Mean reward: 8.125
Step: 5603328	Mean reward: 8.125
Step: 5611520	Mean reward: 9.34375
Step: 5619712	Mean reward: 9.25
Step: 5627904	Mean reward: 9.3125
Step: 5636096	Mean reward: 8.28125
Step: 5644288	Mean reward: 10.03125
Step: 5652480	Mean reward: 9.375
Step: 5660672	Mean reward: 9.625
Step: 5668864	Mean reward: 8.9375
Step: 5677056	Mean reward: 9.25
Step: 5685248	Mean reward: 10.34375
Step: 5693440	Mean reward: 9.25
Step: 5701632	Mean reward: 9.6875
Step: 5709824	Mean reward: 10.90625
Step: 5718016	Mean reward: 10.15625
Step: 5726208	Mean reward: 9.40625
Step: 5734400	Mean reward: 9.6875
Step: 5742592	Mean reward: 9.46875
Step: 5750784	Mean reward: 10.90625
Step: 5758976	Mean reward: 10.09375
Step: 5767168	Mean reward: 9.71875
Step: 5775360	Mean reward: 9.46875
Step: 5783552	Mean reward: 9.59375


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 6053888	Mean reward: 8.96875
Step: 6062080	Mean reward: 10.40625
Step: 6070272	Mean reward: 9.8125
Step: 6078464	Mean reward: 9.03125
Step: 6086656	Mean reward: 8.5625
Step: 6094848	Mean reward: 8.8125
Step: 6103040	Mean reward: 10.09375
Step: 6111232	Mean reward: 9.59375
Step: 6119424	Mean reward: 10.375
Step: 6127616	Mean reward: 9.75
Step: 6135808	Mean reward: 8.28125
Step: 6144000	Mean reward: 8.9375
Step: 6152192	Mean reward: 8.34375
Step: 6160384	Mean reward: 9.875
Step: 6168576	Mean reward: 9.65625
Step: 6176768	Mean reward: 8.59375
Step: 6184960	Mean reward: 10.5625
Step: 6193152	Mean reward: 9.78125
Step: 6201344	Mean reward: 9.5
Step: 6209536	Mean reward: 10.03125
Step: 6217728	Mean reward: 9.1875
Step: 6225920	Mean reward: 10.0
Step: 6234112	Mean reward: 9.75
Step: 6242304	Mean reward: 9.5
Step: 6250496	Mean reward: 8.78125
Step: 6258688	Mean reward: 11.09375
Step: 6266880	Mean reward: 10.0
Step: 6275072	Mean reward: 10.46875
Step: 6283264	Mean reward: 8.1875
Step: 629

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 6561792	Mean reward: 9.90625
Step: 6569984	Mean reward: 9.9375
Step: 6578176	Mean reward: 9.71875
Step: 6586368	Mean reward: 9.75
Step: 6594560	Mean reward: 10.25
Step: 6602752	Mean reward: 11.5625
Step: 6610944	Mean reward: 10.40625
Step: 6619136	Mean reward: 9.75
Step: 6627328	Mean reward: 11.09375
Step: 6635520	Mean reward: 7.96875
Step: 6643712	Mean reward: 8.90625
Step: 6651904	Mean reward: 10.53125
Step: 6660096	Mean reward: 8.8125
Step: 6668288	Mean reward: 9.9375
Step: 6676480	Mean reward: 9.84375
Step: 6684672	Mean reward: 10.15625
Step: 6692864	Mean reward: 10.71875
Step: 6701056	Mean reward: 10.75
Step: 6709248	Mean reward: 9.6875
Step: 6717440	Mean reward: 8.09375
Step: 6725632	Mean reward: 9.9375
Step: 6733824	Mean reward: 9.5
Step: 6742016	Mean reward: 9.15625
Step: 6750208	Mean reward: 11.09375
Step: 6758400	Mean reward: 10.40625
Step: 6766592	Mean reward: 10.59375
Step: 6774784	Mean reward: 10.03125
Step: 6782976	Mean reward: 11.28125
Step: 6791168	Mean reward: 10

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 7061504	Mean reward: 11.5
Step: 7069696	Mean reward: 8.71875
Step: 7077888	Mean reward: 10.65625
Step: 7086080	Mean reward: 10.0625
Step: 7094272	Mean reward: 10.125
Step: 7102464	Mean reward: 10.21875
Step: 7110656	Mean reward: 9.9375
Step: 7118848	Mean reward: 8.65625
Step: 7127040	Mean reward: 9.09375
Step: 7135232	Mean reward: 10.625
Step: 7143424	Mean reward: 10.375
Step: 7151616	Mean reward: 9.3125
Step: 7159808	Mean reward: 9.84375
Step: 7168000	Mean reward: 9.71875
Step: 7176192	Mean reward: 9.0
Step: 7184384	Mean reward: 10.8125
Step: 7192576	Mean reward: 9.9375
Step: 7200768	Mean reward: 9.90625
Step: 7208960	Mean reward: 9.6875
Step: 7217152	Mean reward: 8.96875
Step: 7225344	Mean reward: 11.375
Step: 7233536	Mean reward: 10.75
Step: 7241728	Mean reward: 10.0
Step: 7249920	Mean reward: 10.90625
Step: 7258112	Mean reward: 9.40625
Step: 7266304	Mean reward: 9.9375
Step: 7274496	Mean reward: 9.5
Step: 7282688	Mean reward: 10.15625
Step: 7290880	Mean reward: 10.46875
Step:

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 7569408	Mean reward: 9.46875
Step: 7577600	Mean reward: 10.0
Step: 7585792	Mean reward: 9.75
Step: 7593984	Mean reward: 8.625
Step: 7602176	Mean reward: 9.875
Step: 7610368	Mean reward: 9.03125
Step: 7618560	Mean reward: 9.375
Step: 7626752	Mean reward: 9.90625
Step: 7634944	Mean reward: 10.1875
Step: 7643136	Mean reward: 9.53125
Step: 7651328	Mean reward: 9.25
Step: 7659520	Mean reward: 9.375
Step: 7667712	Mean reward: 8.96875
Step: 7675904	Mean reward: 9.46875
Step: 7684096	Mean reward: 10.0
Step: 7692288	Mean reward: 9.6875
Step: 7700480	Mean reward: 9.5625
Step: 7708672	Mean reward: 9.28125
Step: 7716864	Mean reward: 10.25
Step: 7725056	Mean reward: 10.84375
Step: 7733248	Mean reward: 9.75
Step: 7741440	Mean reward: 9.96875
Step: 7749632	Mean reward: 10.46875
Step: 7757824	Mean reward: 10.53125
Step: 7766016	Mean reward: 9.34375
Step: 7774208	Mean reward: 10.28125
Step: 7782400	Mean reward: 11.3125
Step: 7790592	Mean reward: 10.90625
Step: 7798784	Mean reward: 9.53125
Step: 7

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 8069120	Mean reward: 8.75
Step: 8077312	Mean reward: 8.40625
Step: 8085504	Mean reward: 10.125
Step: 8093696	Mean reward: 9.5
Step: 8101888	Mean reward: 9.5
Step: 8110080	Mean reward: 9.0
Step: 8118272	Mean reward: 10.25
Step: 8126464	Mean reward: 7.3125
Step: 8134656	Mean reward: 10.0625
Step: 8142848	Mean reward: 10.40625
Step: 8151040	Mean reward: 9.5
Step: 8159232	Mean reward: 8.6875
Step: 8167424	Mean reward: 8.96875
Step: 8175616	Mean reward: 8.375
Step: 8183808	Mean reward: 9.625
Step: 8192000	Mean reward: 8.40625
Step: 8200192	Mean reward: 10.09375
Step: 8208384	Mean reward: 9.9375
Step: 8216576	Mean reward: 9.625
Step: 8224768	Mean reward: 10.6875
Step: 8232960	Mean reward: 9.09375
Step: 8241152	Mean reward: 10.90625
Step: 8249344	Mean reward: 10.4375
Step: 8257536	Mean reward: 8.9375
Step: 8265728	Mean reward: 10.75
Step: 8273920	Mean reward: 11.28125
Step: 8282112	Mean reward: 11.96875
Step: 8290304	Mean reward: 9.75
Step: 8298496	Mean reward: 9.40625
Step: 8306688	Mea

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 8577024	Mean reward: 8.6875
Step: 8585216	Mean reward: 9.03125
Step: 8593408	Mean reward: 11.0625
Step: 8601600	Mean reward: 8.53125
Step: 8609792	Mean reward: 10.09375
Step: 8617984	Mean reward: 10.28125
Step: 8626176	Mean reward: 9.53125
Step: 8634368	Mean reward: 8.96875
Step: 8642560	Mean reward: 9.34375
Step: 8650752	Mean reward: 8.84375
Step: 8658944	Mean reward: 8.71875
Step: 8667136	Mean reward: 9.8125
Step: 8675328	Mean reward: 9.84375
Step: 8683520	Mean reward: 10.15625
Step: 8691712	Mean reward: 9.90625
Step: 8699904	Mean reward: 8.25
Step: 8708096	Mean reward: 8.71875
Step: 8716288	Mean reward: 9.75
Step: 8724480	Mean reward: 9.6875
Step: 8732672	Mean reward: 9.53125
Step: 8740864	Mean reward: 9.09375
Step: 8749056	Mean reward: 8.5625
Step: 8757248	Mean reward: 9.1875
Step: 8765440	Mean reward: 9.03125
Step: 8773632	Mean reward: 8.625
Step: 8781824	Mean reward: 9.03125
Step: 8790016	Mean reward: 10.09375
Step: 8798208	Mean reward: 8.5625
Step: 8806400	Mean reward: 9.7

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 9076736	Mean reward: 9.15625
Step: 9084928	Mean reward: 8.25
Step: 9093120	Mean reward: 8.6875
Step: 9101312	Mean reward: 9.6875
Step: 9109504	Mean reward: 9.25
Step: 9117696	Mean reward: 10.0625
Step: 9125888	Mean reward: 10.25
Step: 9134080	Mean reward: 9.09375
Step: 9142272	Mean reward: 9.34375
Step: 9150464	Mean reward: 8.9375
Step: 9158656	Mean reward: 9.46875
Step: 9166848	Mean reward: 8.84375
Step: 9175040	Mean reward: 9.625
Step: 9183232	Mean reward: 9.0
Step: 9191424	Mean reward: 9.9375
Step: 9199616	Mean reward: 8.75
Step: 9207808	Mean reward: 8.71875
Step: 9216000	Mean reward: 8.8125
Step: 9224192	Mean reward: 9.03125
Step: 9232384	Mean reward: 8.5625
Step: 9240576	Mean reward: 8.5
Step: 9248768	Mean reward: 9.1875
Step: 9256960	Mean reward: 9.09375
Step: 9265152	Mean reward: 9.4375
Step: 9273344	Mean reward: 8.78125
Step: 9281536	Mean reward: 10.46875
Step: 9289728	Mean reward: 10.125
Step: 9297920	Mean reward: 8.84375
Step: 9306112	Mean reward: 9.59375
Step: 9314304	

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Step: 9584640	Mean reward: 10.125
Step: 9592832	Mean reward: 9.65625
Step: 9601024	Mean reward: 9.15625


Network definitions. We have defined a policy network for you in advance. It uses the popular `NatureDQN` encoder architecture (see below), while policy and value functions are linear projections from the encodings. There is plenty of opportunity to experiment with architectures, so feel free to do that! Perhaps implement the `Impala` encoder from [this paper](https://arxiv.org/pdf/1802.01561.pdf) (perhaps minus the LSTM).

Below cell can be used for policy evaluation and saves an episode to mp4 for you to view.

In [None]:
import imageio

# Make evaluation environment
eval_env = make_env(num_envs, env_name = 'starpilot',use_backgrounds=True ,start_level=num_levels, num_levels=num_levels)
obs = eval_env.reset()

frames = []
total_reward = []

# Evaluate policy
policy.eval()
for _ in range(512):

  # Use policy
  action, log_prob, value = policy.act(obs)

  # Take step in environment
  obs, reward, done, info = eval_env.step(action)
  total_reward.append(torch.Tensor(reward))

  # Render environment and store
  frame = (torch.Tensor(eval_env.render(mode='rgb_array'))*255.).byte()
  frames.append(frame)

# Calculate average return
total_reward = torch.stack(total_reward).sum(0).mean(0)
print('Average return:', total_reward)

# Save frames as video
frames = torch.stack(frames).numpy()
imageio.mimsave('vid_stack100.mp4', frames, fps=25)

from google.colab import files
files.download('vid_stack100.mp4')


#
#Step: 10006528	Mean reward: 12.428571701049805
#Completed training!

#Average return: tensor(16.3039)