In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import gymnasium as gym
from distrl.worker import Agent
from gymnasium.utils.save_video import save_video

In [4]:
class Policy(nn.Module):
    """Small feed-forward policy network producing action probabilities.

    Architecture: Linear(4 -> 128) -> Dropout(p=0.6) -> ReLU -> Linear(128 -> 2)
    -> softmax over dimension 1.

    The attribute names ``affine1``, ``dropout`` and ``affine2`` determine the
    keys of the module's ``state_dict``, so they must stay as they are for
    previously saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.affine1 = nn.Linear(in_features=4, out_features=128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(in_features=128, out_features=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-row action probabilities for a batch of inputs.

        Args:
            x: Tensor whose last dimension has 4 features. Because the softmax
                is taken over ``dim=1``, the input needs at least two
                dimensions (e.g. ``(batch, 4)``).

        Returns:
            Tensor with the same leading shape as ``x`` and 2 entries in the
            last dimension, normalised along dimension 1.
        """
        # Dropout is applied before the ReLU; it is only active in training
        # mode and is a no-op after ``.eval()``.
        hidden = F.relu(self.dropout(self.affine1(x)))
        logits = self.affine2(hidden)
        return F.softmax(logits, dim=1)

In [40]:
# Roll out a single CartPole-v1 episode with a saved policy and write the
# rendered frames to a video file.

# Selects which checkpoint file is loaded below (presumably the number of
# workers the policy was trained with -- confirm against the training script).
workers = 40
# Hard upper bound on environment steps for this rollout.
max_steps = 1000

policy = Policy()
# map_location="cpu": the policy is constructed on the CPU, so make sure the
# checkpoint loads even if its tensors were saved from a GPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
policy.load_state_dict(
    torch.load(
        f"../models/archive/CartPole-v1_{workers}workers_policy.pt",
        map_location="cpu",
    )
)
# Inference mode: disables the dropout layer inside Policy.
policy.eval()

agent = Agent(policy)
# "rgb_array_list" makes env.render() return the list of frames collected
# since the last reset, which is what save_video consumes.
env = gym.make('CartPole-v1', render_mode="rgb_array_list")
try:
    state, _ = env.reset()

    for _ in range(max_steps):
        # The second value returned by act() is not needed for playback.
        action, _ = agent.act(state)
        state, _, terminated, truncated, _ = env.step(action)

        # An episode ends either by termination (failure state) or by
        # truncation (time limit). Stepping past either is undefined.
        done = terminated or truncated
        if done:
            break

    save_video(env.render(), "archive/videos", fps=env.metadata["render_fps"])
finally:
    # Release the environment even if the rollout or video export fails.
    env.close()

Moviepy - Building video /home/jlcg/projects/DistRL/notebooks/videos/rl-video-episode-0.mp4.
Moviepy - Writing video /home/jlcg/projects/DistRL/notebooks/videos/rl-video-episode-0.mp4


                                                                

Moviepy - Done !
Moviepy - video ready /home/jlcg/projects/DistRL/notebooks/videos/rl-video-episode-0.mp4
