In [5]:
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecMonitor, VecFrameStack, DummyVecEnv

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import CheckpointCallback
import torch


import sys
sys.path.append('../') #This is added so we can import from the source folder
from src.policies import ImpalaCNN

def make_procgen_env(env_name, num_envs=1, start_level=0, num_levels=0, **env_kwargs):
    """Build a ``DummyVecEnv`` holding ``num_envs`` separately created gym environments.

    Args:
        env_name: Environment id passed to ``gym.make``
            (e.g. ``"procgen:procgen-heist-v0"``).
        num_envs: Number of sub-environments to create; must be at least 1.
        start_level: Forwarded to ``gym.make`` as ``start_level``.
        num_levels: Forwarded to ``gym.make`` as ``num_levels``.
        **env_kwargs: Extra keyword arguments forwarded unchanged to ``gym.make``.

    Returns:
        A ``DummyVecEnv`` built from ``num_envs`` environment factories.

    Raises:
        ValueError: If ``num_envs`` is less than 1.
    """
    if num_envs < 1:
        raise ValueError(f"num_envs must be >= 1, got {num_envs}")

    def _init():
        # Create the environment inside the factory: every call then yields a
        # fresh instance instead of handing back one shared (or undefined) `env`.
        return gym.make(
            env_name,
            start_level=start_level,
            num_levels=num_levels,
            **env_kwargs,
        )

    # Reusing the same factory is fine because each invocation builds a new env.
    return DummyVecEnv([_init] * num_envs)

# Id of the environment handed to gym.make.
env_name = "procgen:procgen-heist-v0"

# Drop the render_mode argument to run faster, at the cost of not producing images.
env = gym.make(
    env_name,
    start_level=100,
    num_levels=200,
    render_mode="rgb_array",
    distribution_mode="easy",
)



In [6]:
# Checkpoint file the policy weights are read from.
model_path = '../model_1400_latest.pt'

# The network is sized from the environment: its observation space and the
# number of discrete actions (`.n`).
observation_space, action_space = env.observation_space, env.action_space.n

model = ImpalaCNN(observation_space, action_space)
model.load_from_file(model_path, device="cpu")


In [7]:
@torch.no_grad()
def generate_action(model, observation):
    """Sample one action index from the policy for a single observation.

    The observation is turned into a float32 tensor with an extra leading axis
    of size one, passed through ``model``, and an index is drawn at random
    according to the softmax of the logits exposed by the model's first output.

    Args:
        model: Callable whose return value's first element has a ``logits`` attribute.
        observation: Data accepted by ``torch.tensor``.

    Returns:
        The sampled action index as a Python scalar (via ``.item()``).
    """
    batch = torch.tensor(observation, dtype=torch.float32)[None]

    # Only element 0 of the output is used; the critic part is discarded.
    policy_logits = model(batch)[0].logits

    action_probs = policy_logits.softmax(dim=-1)
    sampled = action_probs.multinomial(num_samples=1)
    return sampled.item()


In [8]:
import torch
import imageio

def run_episode_and_save_as_gif(env, model, filepath='../gifs/run.gif', save_gif=False):
    """Play one episode with ``model`` and optionally write the rendered frames to a GIF.

    Args:
        env: Environment exposing ``reset()``, ``render(mode='rgb_array')`` and
            ``step(action)`` returning ``(observation, reward, done, info)``.
        model: Policy handed to ``generate_action`` to pick each action.
        filepath: Destination of the GIF; only used when ``save_gif`` is true.
            Missing parent directories are created before writing.
        save_gif: When true, one frame is rendered before every step and the
            collected frames are written to ``filepath`` with ``fps=30``.

    Returns:
        ``(total_reward, frames)``: the summed reward of the episode and the list
        of rendered frames (empty when ``save_gif`` is false).
    """
    import os  # function-scope import so the block does not rely on another cell

    frames = []
    observation = env.reset()
    done = False
    total_reward = 0

    while not done:
        if save_gif:
            # Rendered before stepping, so each frame shows the state being acted on.
            frames.append(env.render(mode='rgb_array'))
        action = generate_action(model, observation)
        observation, reward, done, info = env.step(action)
        total_reward += reward

    if save_gif:
        # The file cannot be written into a directory that does not exist
        # (e.g. the default '../gifs/'), so create it first.
        target_dir = os.path.dirname(filepath)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        imageio.mimsave(filepath, frames, fps=30)

    return total_reward, frames


# Set to False to skip rendering and GIF export.
save_gif_option = True

# Play 20 episodes; when saving is enabled each one is written to
# episode_mod_2_<n>.gif with n starting at 1.
for episode in range(20):
    total_reward, _ = run_episode_and_save_as_gif(
        env,
        model,
        filepath=f'episode_mod_2_{episode+1}.gif',
        save_gif=save_gif_option,
    )
    print(f"Episode {episode + 1} finished with total reward: {total_reward}")


Episode 1 finished with total reward: 10.0
Episode 2 finished with total reward: 10.0
Episode 3 finished with total reward: 10.0
Episode 4 finished with total reward: 10.0
Episode 5 finished with total reward: 10.0
Episode 6 finished with total reward: 10.0
Episode 7 finished with total reward: 10.0
Episode 8 finished with total reward: 10.0
Episode 9 finished with total reward: 10.0
Episode 10 finished with total reward: 10.0
Episode 11 finished with total reward: 10.0
Episode 12 finished with total reward: 10.0
Episode 13 finished with total reward: 10.0
Episode 14 finished with total reward: 10.0
Episode 15 finished with total reward: 10.0
Episode 16 finished with total reward: 10.0
Episode 17 finished with total reward: 10.0


KeyboardInterrupt: 