In [None]:
!pip install imageio

In [None]:
import gymnasium as gym
import imageio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from IPython.display import Image
from torch.distributions import Categorical

from model.models import DiscreteActor
from model.utils import setup_training

In [None]:
# --- Paths -------------------------------------------------------------------
# YAML experiment configuration, parsed by setup_training in the next cell.
config_path = '../config/CartPole.yaml'
# Pre-trained actor state dict; replace this placeholder before running.
weight_path = 'your weight path'

# --- Device ------------------------------------------------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
# Parse the experiment configuration, then build the environment it names.
args = setup_training(config_path)

# "rgb_array" render mode makes env.render() return an image that can be stored as a frame.
env = gym.make(args.env_name, render_mode="rgb_array")

# Size of the observation vector and number of discrete actions.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# Seeded reset so the initial observation is reproducible.
observation, info = env.reset(seed=args.seed)

# Build the policy network. Input/output sizes come from the live environment
# (state_dim / action_dim computed above) rather than from args, so the network
# always matches the observation and action spaces it will be run on.
actor = DiscreteActor(n_wires=args.n_wires,
                      n_blocks=args.n_blocks,
                      input_dim=state_dim,
                      output_dim=action_dim,
                      ini_method=args.ini_method).to(DEVICE)
actor.eval()  # inference only: put any train/eval-dependent layers in eval mode

# Load the pre-trained weights.
# - map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
#   (and vice versa).
# - weights_only restricts unpickling to tensors and plain containers, which is
#   all a state dict needs.
state_dict = torch.load(weight_path, map_location=DEVICE, weights_only=True)
actor.load_state_dict(state_dict)

In [None]:
# Roll out one episode with the loaded policy and record a rendered frame per state.
# Seeding torch and resetting the env here (instead of relying on the reset in the
# setup cell) makes the rollout reproducible. A re-run starts a fresh, identically
# seeded episode rather than stepping on from the previous run's terminal observation.
torch.manual_seed(args.seed)  # action sampling below is stochastic
observation, info = env.reset(seed=args.seed)

frames = [env.render()]  # frame of the initial state
done = False
total_reward = 0.0

while not done:
    # Add a batch dimension before feeding the observation to the actor.
    s = torch.tensor(observation, dtype=torch.float32, device=DEVICE).unsqueeze(0)

    with torch.no_grad():
        # The actor's output is used directly as the action probabilities.
        dist = Categorical(probs=actor(s))
        action = dist.sample().item()

    observation, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    total_reward += reward
    # Render after each step so the final (terminated/truncated) state is also captured.
    frames.append(env.render())

env.close()
print(f"Total Reward: {total_reward}")


In [None]:
# Encode the frames collected during the rollout into an animated GIF.
# NOTE(review): recent imageio releases warn that `fps` is deprecated in favour of
# `duration` (milliseconds per frame) — confirm against the installed version.
gif_path = "env_inference.gif"
imageio.mimsave(
    gif_path,
    frames,
    fps=30,
)

In [None]:
# Display the saved GIF inline. As the cell's last expression, the Image object
# is rendered through Jupyter's rich display (Image is imported in the top import cell).
Image(filename=gif_path)