# Getting started with PPO and ProcGen

Here's a bit of code that should help you get started on your projects.

The cell below installs `procgen` and downloads a small `utils.py` script that contains some utility functions. You may want to inspect the file for more details.

In [None]:
!pip install procgen
!wget https://raw.githubusercontent.com/nicklashansen/ppo-procgen-utils/main/utils.py
!wget https://raw.githubusercontent.com/MishaLaskin/rad/1246bfd6e716669126e12c1f02f393801e1692c1/TransformLayer.py

Collecting procgen
[?25l  Downloading https://files.pythonhosted.org/packages/d6/34/0ae32b01ec623cd822752e567962cfa16ae9c6d6ba2208f3445c017a121b/procgen-0.10.4-cp36-cp36m-manylinux2010_x86_64.whl (39.9MB)
[K     |████████████████████████████████| 39.9MB 83kB/s 
Collecting gym3<1.0.0,>=0.3.3
[?25l  Downloading https://files.pythonhosted.org/packages/89/8c/83da801207f50acfd262041e7974f3b42a0e5edd410149d8a70fd4ad2e70/gym3-0.3.3-py3-none-any.whl (50kB)
[K     |████████████████████████████████| 51kB 7.3MB/s 
Collecting moderngl<6.0.0,>=5.5.4
[?25l  Downloading https://files.pythonhosted.org/packages/56/ab/5f72a1b7c5bdbb17160c85e8ba855d48925c74ff93c1e1027d5ad40bf33c/moderngl-5.6.2-cp36-cp36m-manylinux1_x86_64.whl (664kB)
[K     |████████████████████████████████| 665kB 37.8MB/s 
Collecting imageio-ffmpeg<0.4.0,>=0.3.0
[?25l  Downloading https://files.pythonhosted.org/packages/1b/12/01126a2fb737b23461d7dadad3b8abd51ad6210f979ff05c6fa9812dfbbe/imageio_ffmpeg-0.3.0-py3-none-manylinux2010_

Data augmentation code adapted from RAD:
https://github.com/MishaLaskin/rad/blob/1246bfd6e716669126e12c1f02f393801e1692c1/data_augs.py#L296


In [None]:
'''
dataaugs:
https://github.com/MishaLaskin/rad/blob/1246bfd6e716669126e12c1f02f393801e1692c1/data_augs.py#L296
'''
'''
paper:
https://arxiv.org/pdf/2010.10814.pdf
'''
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from TransformLayer import ColorJitterLayer


def random_crop(imgs, out=84):
    """Randomly crop every image in a batch to an ``out`` x ``out`` window.

    Each batch element gets its own independently sampled top-left corner.

    Args:
        imgs: np.ndarray of shape (B, C, H, W).
        out: side length of the square crop; must not exceed H or W.

    Returns:
        np.ndarray of shape (B, C, out, out) with the same dtype as ``imgs``.

    Raises:
        ValueError: if ``out`` is larger than the image height or width.
    """
    n, c, h, w = imgs.shape
    if out > h or out > w:
        raise ValueError(f"crop size {out} exceeds image size {h}x{w}")
    # The horizontal and vertical offset ranges are computed separately so that
    # non-square images are cropped correctly. For square images this draws
    # exactly the same random numbers as before (w offsets first, then h).
    w1 = np.random.randint(0, w - out + 1, n)
    h1 = np.random.randint(0, h - out + 1, n)
    cropped = np.empty((n, c, out, out), dtype=imgs.dtype)
    for i, (img, left, top) in enumerate(zip(imgs, w1, h1)):
        cropped[i] = img[:, top:top + out, left:left + out]
    return cropped


def grayscale(imgs):
    """Convert stacked RGB frames to grayscale, replicated to 3 identical channels.

    Args:
        imgs: tensor of shape (B, C, H, W) with C a multiple of 3 (stacked RGB
            frames); intensities are expected on a 0-255 scale.

    Returns:
        Float tensor of shape (B, C // 3, 3, H, W).
    """
    device = imgs.device
    b, c, h, w = imgs.shape
    frames = c // 3

    imgs = imgs.view([b, frames, 3, h, w])
    # Weighted sum of R, G, B with standard luma weights.
    imgs = imgs[:, :, 0, ...] * 0.2989 + imgs[:, :, 1, ...] * 0.587 + imgs[:, :, 2, ...] * 0.114

    # Quantize back to integer intensities (truncating), as in RAD.
    imgs = imgs.type(torch.uint8).float()
    imgs = imgs[:, :, None, :, :]
    # Broadcast the single luma channel into 3 identical channels.
    imgs = imgs * torch.ones([1, 1, 3, 1, 1], dtype=imgs.dtype).float().to(device)
    return imgs


def random_grayscale(images,p=.3):
    """Randomly replace whole batch elements by their grayscale version.

    Args:
        images: float tensor of shape (B, C, H, W), C a multiple of 3 (stacked
            RGB frames), values expected in [0, 1].
        p: probability that each batch element is converted.

    Returns:
        Tensor with the same shape, dtype and device as ``images``. Elements
        that were not selected are returned unchanged. Selected elements are
        replaced by their grayscale version, quantized to multiples of 1/255.
    """
    device = images.device
    in_type = images.type()
    bs, _, h, w = images.shape
    # grayscale() works on 0-255 intensities; quantize (truncating) as RAD does.
    quantized = (images * 255.).type(torch.uint8)
    gray = grayscale(quantized).view([bs, -1, h, w]).type(in_type) / 255.
    # One Bernoulli(p) decision per batch element (numpy RNG, as before).
    rnd = np.random.uniform(0., 1., size=(bs,))
    selected = torch.from_numpy(rnd <= p).to(device)[:, None, None, None]
    # Unselected elements pass through untouched instead of being round-tripped
    # through uint8, which silently truncated them by up to 1/255.
    return torch.where(selected, gray, images)

# random cutout
# TODO: should mask this 

def random_cutout(imgs, min_cut=10,max_cut=30):
    """Zero out one rectangular patch in every image of a batch.

    Args:
        imgs: np.ndarray of shape (B, C, H, W); it is not modified.
        min_cut: inclusive lower bound on the patch height / width.
        max_cut: exclusive upper bound on the patch height / width
            (``np.random.randint`` semantics).

    Returns:
        A new np.ndarray with the same shape and dtype as ``imgs``.

    NOTE(review): a patch of size (ch, cw) covers rows [ch, 2*ch) and columns
    [cw, 2*cw), so its offset always equals its size. This mirrors the upstream
    RAD implementation; confirm it is intended before relying on it.
    """
    n, c, h, w = imgs.shape
    cut_w = np.random.randint(min_cut, max_cut, n)
    cut_h = np.random.randint(min_cut, max_cut, n)

    result = np.empty((n, c, h, w), dtype=imgs.dtype)
    for idx, ch, cw in zip(range(n), cut_h, cut_w):
        result[idx] = imgs[idx]
        result[idx, :, ch:ch * 2, cw:cw * 2] = 0
    return result

def random_cutout_color(imgs, min_cut=7,max_cut=22):
    """Paint one rectangular patch per image with a random solid colour.

    Args:
        imgs: np.ndarray of shape (B, C, H, W); it is not modified.
        min_cut: inclusive lower bound on the patch height / width.
        max_cut: exclusive upper bound on the patch height / width.

    Returns:
        A new np.ndarray with the same shape and dtype as ``imgs``. Each patch
        holds one value per channel drawn from {0, 1/255, ..., 254/255}.

    NOTE(review): as in ``random_cutout``, a patch of size (ch, cw) sits at
    rows [ch, 2*ch) and columns [cw, 2*cw) (upstream RAD behaviour).
    """
    n, c, h, w = imgs.shape
    cut_w = np.random.randint(min_cut, max_cut, n)
    cut_h = np.random.randint(min_cut, max_cut, n)

    result = np.empty((n, c, h, w), dtype=imgs.dtype)
    colours = np.random.randint(0, 255, size=(n, c)) / 255.
    for idx, ch, cw in zip(range(n), cut_h, cut_w):
        result[idx] = imgs[idx]
        # Broadcast the per-channel colour over the whole patch.
        result[idx, :, ch:ch * 2, cw:cw * 2] = colours[idx][:, None, None]
    return result

# random flip

def random_flip(images,p=.2):
    """Mirror a random subset of a batch along the width axis.

    Args:
        images: tensor of shape (B, C, H, W).
        p: probability that each batch element is flipped.

    Returns:
        Tensor of shape (B, C, H, W) on the same device as ``images``.
    """
    bs, channels, h, w = images.shape
    device = images.device

    # Independent Bernoulli(p) decision per batch element, drawn with numpy.
    draws = np.random.uniform(0., 1., size=(bs,))
    weight = torch.from_numpy(draws <= p).type(images.dtype).to(device)
    weight = weight[:, None, None, None]

    # Arithmetic blend: weight 1 selects the mirrored image, 0 the original.
    blended = weight * images.flip([3]) + (1 - weight) * images
    return blended.view([bs, -1, h, w])

# random rotation

def random_rotation(images,p=.3):
    """Rotate a random subset of a batch by 90, 180 or 270 degrees.

    Args:
        images: tensor of shape (B, C, H, W). Because a quarter turn swaps H
            and W, the blend below requires H == W.
        p: probability that each batch element is rotated; a rotated element
            gets 1, 2 or 3 quarter turns in the (H, W) plane, chosen uniformly.

    Returns:
        Tensor of shape (B, C, H, W) on the same device as ``images``.
    """
    device = images.device
    bs, channels, h, w = images.shape
    count = images.shape[0]

    # Draw order matters for reproducibility: first the apply/skip decision,
    # then the number of quarter turns.
    apply = np.random.uniform(0., 1., size=(count,)) <= p
    turns = np.random.randint(1, 4, size=(count,))
    # 0 means "leave unchanged"; 1-3 are quarter-turn counts.
    choice = torch.from_numpy(turns * apply).to(device)
    weight_dtype = torch.promote_types(choice.dtype, images.dtype)

    out = None
    for k in range(4):
        weight = (choice == k).to(weight_dtype)[:, None, None, None]
        term = weight * images.rot90(k, [2, 3])
        out = term if out is None else out + term

    return out.view([bs, -1, h, w])


# random color

    

def random_convolution(imgs):
    '''
    Random convolution augmentation from "network randomization".

    Each batch element gets its own freshly sampled 3x3 convolution kernel
    (3 -> 3 channels, Xavier-normal weights, no bias, padding 1, so H and W are
    preserved). The kernel is applied to every RGB frame in that element's
    frame stack.

    Args:
        imgs: torch tensor of shape B x (3 * stack) x H x W; it should already
            be normalized.

    Returns:
        Tensor of the same shape as ``imgs``.
    '''
    device = imgs.device
    num_batch, num_stack_channel = imgs.shape[0], imgs.shape[1]
    img_h, img_w = imgs.shape[2], imgs.shape[3]
    if num_batch == 0:
        # Nothing to transform (the old code divided by zero here).
        return imgs.new_zeros((0, num_stack_channel, img_h, img_w))

    outputs = []
    for index in range(num_batch):
        # A fresh plain tensor per element rather than a shared nn.Parameter
        # re-initialised in place via .data. That avoided autograd's version
        # check, so backward used stale weights, and it also built a graph
        # into throwaway parameters.
        weight = torch.empty(3, 3, 3, 3, device=device, dtype=imgs.dtype)
        torch.nn.init.xavier_normal_(weight)
        frames = imgs[index:index + 1].reshape(-1, 3, img_h, img_w)  # (stack, 3, H, W)
        outputs.append(torch.nn.functional.conv2d(frames, weight, padding=1))
    # Concatenate once rather than growing the result inside the loop (quadratic copying).
    return torch.cat(outputs, 0).reshape(-1, num_stack_channel, img_h, img_w)


def random_color_jitter(imgs):
    """Apply RAD's ``ColorJitterLayer`` to a batch of stacked RGB frames.

    Jitter strengths: brightness 0.4, contrast 0.4, saturation 0.4, hue 0.5,
    applied with p=1.0.

    Args:
        imgs: torch tensor of shape (B, C, H, W), C a multiple of 3. A numpy
            array will not work, because ``.view(-1, 3, h, w)`` is a tensor method.

    Returns:
        Tensor of shape (B, C, H, W).

    NOTE(review): the layer is built with ``batch_size=B`` and ``stack_size=1``
    but receives B * C // 3 RGB images. For C > 3, confirm this is what
    ColorJitterLayer expects.
    """
    b, c, h, w = imgs.shape
    rgb = imgs.view(-1, 3, h, w)
    jitter = nn.Sequential(
        ColorJitterLayer(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
            hue=0.5,
            p=1.0,
            batch_size=b,
            stack_size=1,
        )
    )
    jittered = jitter(rgb)
    return jittered.view(b, c, h, w)


def random_translate(imgs, size, return_random_idxs=False, h1s=None, w1s=None):
    """Paste each image at a random offset inside a zero-filled square canvas.

    Args:
        imgs: np.ndarray of shape (B, C, H, W).
        size: canvas side length; must be >= H and >= W.
        return_random_idxs: if True, also return the offsets used.
        h1s, w1s: optional per-image top / left offsets. When omitted they are
            drawn uniformly so that the image fits inside the canvas.

    Returns:
        np.ndarray of shape (B, C, size, size). If ``return_random_idxs`` is
        set, a tuple ``(canvases, {'h1s': ..., 'w1s': ...})`` is returned instead.
    """
    n, c, h, w = imgs.shape
    assert size >= h and size >= w
    canvases = np.zeros((n, c, size, size), dtype=imgs.dtype)
    if h1s is None:
        h1s = np.random.randint(0, size - h + 1, n)
    if w1s is None:
        w1s = np.random.randint(0, size - w + 1, n)
    for canvas, img, top, left in zip(canvases, imgs, h1s, w1s):
        canvas[:, top:top + h, left:left + w] = img
    if not return_random_idxs:
        return canvases
    # Returning the offsets lets the caller apply the same shift to another batch.
    return canvases, dict(h1s=h1s, w1s=w1s)


def no_aug(x):
    """Identity augmentation: return ``x`` itself, unchanged and uncopied."""
    unchanged = x
    return unchanged


# if __name__ == '__main__':
#     import time 
#     from tabulate import tabulate
#     def now():
#         return time.time()
#     def secs(t):
#         s = now() - t
#         tot = round((1e5 * s)/60,1)
#         return round(s,3),tot

#     x = np.load('data_sample.npy',allow_pickle=True)
#     x = np.concatenate([x,x,x],1)
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#     x = torch.from_numpy(x).to(device)
#     x = x.float() / 255.

#     # crop
#     t = now()
#     random_crop(x.cpu().numpy(),64)
#     s1,tot1 = secs(t)
#     # grayscale 
#     t = now()
#     random_grayscale(x,p=.5)
#     s2,tot2 = secs(t)
#     # normal cutout 
#     t = now()
#     random_cutout(x.cpu().numpy(),10,30)
#     s3,tot3 = secs(t)
#     # color cutout 
#     t = now()
#     random_cutout_color(x.cpu().numpy(),10,30)
#     s4,tot4 = secs(t)
#     # flip 
#     t = now()
#     random_flip(x,p=.5)
#     s5,tot5 = secs(t)
#     # rotate 
#     t = now()
#     random_rotation(x,p=.5)
#     s6,tot6 = secs(t)
#     # rand conv 
#     t = now()
#     random_convolution(x)
#     s7,tot7 = secs(t)
#     # rand color jitter 
#     t = now()
#     random_color_jitter(x)
#     s8,tot8 = secs(t)
    
#     print(tabulate([['Crop', s1,tot1], 
#                     ['Grayscale', s2,tot2], 
#                     ['Normal Cutout', s3,tot3], 
#                     ['Color Cutout', s4,tot4], 
#                     ['Flip', s5,tot5], 
#                     ['Rotate', s6,tot6], 
#                     ['Rand Conv', s7,tot7], 
#                     ['Color Jitter', s8,tot8]], 
#                     headers=['Data Aug', 'Time / batch (secs)', 'Time / 100k steps (mins)']))


Hyperparameters. These values should be a good starting point. You can modify them later once you have a working implementation.

In [None]:
# Hyperparameters
# Rollout settings. total_steps stays a float; it is only compared against the
# integer step counter in the training loop.
total_steps = 8e6
num_envs, num_levels, num_steps = 32, 100, 256

# Optimisation settings: passes over each rollout and minibatch size.
num_epochs, batch_size = 3, 512

# Loss settings: PPO clip range (also used for value clipping), gradient-norm
# clip, and the value / entropy loss weights.
eps, grad_eps = .2, .5
value_coef, entropy_coef = .5, .01

In [None]:
# !sudo apt-get install imagemagick

Network definitions. We have defined a policy network for you in advance. It uses the popular `NatureDQN` encoder architecture (see below), while the policy and value functions are linear projections of the encoded features. There is plenty of opportunity to experiment with architectures, so feel free to do that! For example, you could implement the `Impala` encoder from [this paper](https://arxiv.org/pdf/1802.01561.pdf) (possibly without the LSTM).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import make_env, Storage, orthogonal_init


class Flatten(nn.Module):
    """Reshape an (N, *) tensor to (N, prod(*)) as a view (no copy)."""

    def forward(self, x):
        leading = x.shape[0]
        flat = x.view(leading, -1)
        return flat


class Encoder(nn.Module):
  """NatureDQN-style convolutional encoder.

  Three conv + ReLU stages, a flatten, then a linear projection to
  ``feature_dim`` features followed by ReLU. ``in_features=1024`` equals
  64 channels * 4 * 4, which is what a 64x64 input produces; other input
  resolutions need a different value.
  """

  def __init__(self, in_channels, feature_dim):
    super().__init__()
    # (in, out, kernel, stride) for each conv stage, in order.
    conv_stages = (
        (in_channels, 32, 8, 4),
        (32, 64, 4, 2),
        (64, 64, 3, 1),
    )
    modules = []
    for c_in, c_out, kernel, stride in conv_stages:
      modules.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride))
      modules.append(nn.ReLU())
    modules.extend([Flatten(), nn.Linear(in_features=1024, out_features=feature_dim), nn.ReLU()])
    # Module order fixes the Sequential indices, and therefore the state_dict
    # keys (e.g. in checkpoint.pt).
    self.layers = nn.Sequential(*modules)
    self.apply(orthogonal_init)

  def forward(self, x):
    features = self.layers(x)
    return features


class Policy(nn.Module):
  """Actor-critic heads on top of a shared encoder.

  ``forward`` returns a Categorical distribution over ``num_actions`` actions
  and a state-value estimate of shape (B,).
  """

  def __init__(self, encoder, feature_dim, num_actions):
    super().__init__()
    self.encoder = encoder
    self.policy = orthogonal_init(nn.Linear(feature_dim, num_actions), gain=.01)
    self.value = orthogonal_init(nn.Linear(feature_dim, 1), gain=1.)

  def act(self, x):
    """Sample actions without tracking gradients.

    Returns:
        (action, log_prob, value), each moved to the CPU.
    """
    # Send the batch to wherever the model lives instead of hard-coding
    # .cuda(), so this also works on CPU-only machines or other GPUs.
    device = self.value.weight.device
    with torch.no_grad():
      x = x.to(device).contiguous()
      dist, value = self.forward(x)
      action = dist.sample()
      log_prob = dist.log_prob(action)

    return action.cpu(), log_prob.cpu(), value.cpu()

  def forward(self, x):
    x = self.encoder(x)
    logits = self.policy(x)
    value = self.value(x).squeeze(1)
    dist = torch.distributions.Categorical(logits=logits)

    return dist, value


# Define environment (e.g. 'starpilot'; other ProcGen games such as 'bossfight' also work)
# check the utils.py file for info on arguments
env = make_env(num_envs, num_levels=num_levels, env_name='starpilot', use_backgrounds=True)
print('Observation space:', env.observation_space)
print('Action space:', env.action_space.n)

# Define network
encoder = Encoder(3,512)
policy = Policy(encoder, 512, 15)
policy.cuda()

# Define optimizer
# these are reasonable values but probably not optimal
optimizer = torch.optim.Adam(policy.parameters(), lr=5e-4, eps=1e-5)

# Define temporary storage
# we use this to collect transitions during each iteration
storage = Storage(
    env.observation_space.shape,
    num_steps,
    num_envs
)

# Separate environment for evaluation, starting at level `num_levels`.
# Presumably this is disjoint from the training levels, assuming make_env's
# default start_level is 0 (check utils.py).
eval_env = make_env(num_envs, env_name = 'starpilot',start_level=num_levels, num_levels=num_levels, use_backgrounds=True)
eval_obs = eval_env.reset()


from collections import deque
# NOTE(review): these two are not read anywhere in the visible notebook.
eval_info_queue=deque(maxlen=num_steps)
eval_reward_queue=torch.zeros(num_steps, num_envs)

# Run training
obs = env.reset()
step = 0
print("NN setup, Training Starts")
while step < total_steps:

  policy.eval()

  # Per-step eval rewards collected alongside the training rollout.
  total_reward = []

  # Use policy to collect data for num_steps steps
  for _ in range(num_steps):
    # Use policy
    action, log_prob, value = policy.act(obs)

    # Take step in environment
    next_obs, reward, done, info = env.step(action)

    # Store data
    storage.store(obs, action, reward, done, info, log_prob, value)

    # Update current observation
    obs = next_obs

    # Step the evaluation environment with the current policy.
    eval_action, eval_log_prob, eval_value = policy.act(eval_obs)
    eval_obs, eval_reward, eval_done, eval_info = eval_env.step(eval_action)
    total_reward.append(torch.Tensor(eval_reward))

  # Eval reward summed over the last num_steps steps, averaged over eval envs.
  # Episode boundaries (eval_done) are ignored, so this is not a per-episode return.
  total_reward = torch.stack(total_reward).sum(0).mean(0)
  print(f'Step: {step}\tEval reward: {total_reward}')

  # Add the last observation to collected data
  _, _, value = policy.act(obs)
  storage.store_last(obs, value)

  # Compute return and advantage
  storage.compute_return_advantage()

  # Optimize policy
  policy.train()
  for epoch in range(num_epochs):

    # Iterate over batches of transitions
    generator = storage.get_generator(batch_size)
    for batch in generator:
      b_obs, b_action, b_log_prob, b_value, b_returns, b_advantage = batch

      # Get current policy outputs
      new_dist, new_value = policy(b_obs)
      new_log_prob = new_dist.log_prob(b_action)

      # Clipped policy objective
      ratio = torch.exp(new_log_prob - b_log_prob)
      clipped_ratio = ratio.clamp(min=1.0 - eps, max=1.0 + eps)
      policy_reward = torch.min(ratio * b_advantage, clipped_ratio * b_advantage)
      pi_loss = -policy_reward.mean()

      # Clipped value function objective. Keep the new estimate within eps of
      # the rollout-time value and take the larger of the unclipped and clipped
      # squared errors. (The old version compared against the constant
      # (b_value - b_returns)^2, which zeroed the value gradient whenever that
      # constant was larger.)
      clipped_value = b_value + (new_value - b_value).clamp(min=-eps, max=eps)
      value_loss = 0.5 * torch.max((new_value - b_returns) ** 2,
                                   (clipped_value - b_returns) ** 2).mean()

      # Entropy loss
      entropy_loss = new_dist.entropy().mean()

      # Backpropagate losses (entropy is subtracted, i.e. higher entropy is rewarded)
      loss = pi_loss + value_coef * value_loss - entropy_coef * entropy_loss
      loss.backward()

      # Clip gradients
      torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_eps)

      # Update policy
      optimizer.step()
      optimizer.zero_grad()

  # Update stats
  step += num_envs * num_steps
  print(f'Step: {step}\tMean reward: {storage.get_reward()}')


print('Completed training!')
torch.save(policy.state_dict(), 'checkpoint.pt')

Observation space: Box(0.0, 1.0, (3, 64, 64), float32)
Action space: 15
NN setup, Training Starts
Step: 0	Eval reward: 7.715677738189697
Step: 8192	Mean reward: 5.0625
Step: 8192	Eval reward: 6.029409885406494
Step: 16384	Mean reward: 4.1875
Step: 16384	Eval reward: 6.019257545471191
Step: 24576	Mean reward: 6.375
Step: 24576	Eval reward: 6.883467674255371
Step: 32768	Mean reward: 5.9375
Step: 32768	Eval reward: 6.598047256469727
Step: 40960	Mean reward: 6.9375
Step: 40960	Eval reward: 5.562748432159424
Step: 49152	Mean reward: 5.3125
Step: 49152	Eval reward: 5.470124244689941
Step: 57344	Mean reward: 5.65625
Step: 57344	Eval reward: 4.845388889312744
Step: 65536	Mean reward: 6.75
Step: 65536	Eval reward: 6.3315911293029785
Step: 73728	Mean reward: 5.71875
Step: 73728	Eval reward: 5.566393852233887
Step: 81920	Mean reward: 5.625
Step: 81920	Eval reward: 4.9568190574646
Step: 90112	Mean reward: 7.0625
Step: 90112	Eval reward: 5.745906829833984
Step: 98304	Mean reward: 5.84375
Step: 9830

The cell below can be used to evaluate the policy; it also saves an episode as an MP4 file for you to view (its code is currently commented out).

In [None]:
# print(type(total_reward))
# import pickle
# with open('storage1.pkl', 'wb') as f:
#     pickle.dump(storage, f)
# import imageio

# # Make evaluation environment
# eval_env = make_env(num_envs, env_name = 'starpilot',start_level=num_levels, num_levels=num_levels, use_backgrounds=True)
# obs = eval_env.reset()

# frames = []
# total_reward = []

# # Evaluate policy
# policy.eval()#512
# for _ in range(512):

#   # Use policy
#   action, log_prob, value = policy.act(obs)

#   # Take step in environment
#   obs, reward, done, info = eval_env.step(action)
#   total_reward.append(torch.Tensor(reward))

#   # Render environment and store
#   frame = (torch.Tensor(eval_env.render(mode='rgb_array'))*255.).byte()
#   frames.append(frame)

# # Calculate average return
# total_reward = torch.stack(total_reward).sum(0).mean(0)
# print('Average return:', total_reward)



# # Save frames as video
# frames = torch.stack(frames)
# imageio.mimsave('cropvid.mp4', frames, fps=25)



In [None]:
# print(type(storage))

# from google.colab import files
# files.download('cropvid.mp4')

In [None]:
# # !pip install kora
# # from kora import console
# # console.start()
# !pip install procgen
# !wget https://raw.githubusercontent.com/nicklashansen/ppo-procgen-utils/main/utils.py
# !wget https://raw.githubusercontent.com/MishaLaskin/rad/1246bfd6e716669126e12c1f02f393801e1692c1/TransformLayer.py
# # Hyperparameters
# total_steps = 8e6
# num_envs = 32
# num_levels = 100
# num_steps = 256
# num_epochs = 3
# batch_size = 512 #512
# eps = .2
# grad_eps = .5
# value_coef = .5
# entropy_coef = .01

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from utils import make_env, Storage, orthogonal_init
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from utils import make_env, Storage, orthogonal_init
  

# class Flatten(nn.Module):
#     def forward(self, x):
#         return x.view(x.size(0), -1)


# class Encoder(nn.Module):
#   def __init__(self, in_channels, feature_dim):
#     super().__init__()
#     self.layers = nn.Sequential(
#         nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=8, stride=4), nn.ReLU(),
#         nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), nn.ReLU(),
#         nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), nn.ReLU(),
#         Flatten(),
#         nn.Linear(in_features=1024, out_features=feature_dim), nn.ReLU()
#     )
#     self.apply(orthogonal_init)

#   def forward(self, x):
#     return self.layers(x)


# class Policy(nn.Module):
#   def __init__(self, encoder, feature_dim, num_actions):
#     super().__init__()
#     self.encoder = encoder
#     self.policy = orthogonal_init(nn.Linear(feature_dim, num_actions), gain=.01)
#     self.value = orthogonal_init(nn.Linear(feature_dim, 1), gain=1.)

#   def act(self, x):
#     with torch.no_grad():
#       x = x.cuda().contiguous()
#       dist, value = self.forward(x)
#       action = dist.sample()
#       log_prob = dist.log_prob(action)
    
#     return action.cpu(), log_prob.cpu(), value.cpu()

#   def forward(self, x):
#     x = self.encoder(x)
#     logits = self.policy(x)
#     value = self.value(x).squeeze(1)
#     dist = torch.distributions.Categorical(logits=logits)

#     return dist, value


# # Define environmentbossfight
# # check the utils.py file for info on arguments
# eval_env = make_env(num_envs, env_name = 'starpilot',start_level=num_levels, num_levels=num_levels, use_backgrounds=True)
# obs = eval_env.reset()
# total_reward = []

# # Define network
# encoder = Encoder(3,512)
# policy = Policy(encoder, 512, 15)
# policy.cuda()
# policy.load_state_dict(torch.load('checkpoint.pt'))
# policy.eval()



# frames = []
# total_reward = []

# # Evaluate policy
# policy.eval()#512
# for _ in range(512):

#   # Use policy
#   action, log_prob, value = policy.act(obs)

#   # Take step in environment
#   obs, reward, done, info = eval_env.step(action)
#   total_reward.append(torch.Tensor(reward))

#   # Render environment and store
#   frame = (torch.Tensor(eval_env.render(mode='rgb_array'))*255.).byte()
#   frames.append(frame)

# # Calculate average return
# # total_reward = torch.stack(total_reward).sum(0).mean(0)
# # print('Average return:', total_reward)


In [None]:
# print(len(total_reward))
# print(len(total_reward[0]))
# ree=0.0
# reet=[]
# for i in range(len(total_reward)):
#     ree+=total_reward[i].mean()
#     reet.append(total_reward[i].mean())
# print(ree)
# print(reet)