In [16]:
import numpy as np
import pandas as pd
from collections import deque
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import torch.distributions as distributions

# Gym
import gym


In [17]:
#Using a neural network to learn our actor (policy) parameters

class Actor(nn.Module):
    def __init__(self, s_size, a_size, h_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(s_size, h_size, dtype=torch.float32)
        self.fc2 = nn.Linear(h_size, a_size, dtype=torch.float32)
        self.double()

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = torch.tanh(self.fc2(x))  # Apply tanh activation for bounded output in [-1, 1]
        return x

    def act(self, state):
        # print('here here')
        # print(state)
        mean = self.forward(state)
        # print('mean')
        # print(mean)
        std_dev = torch.ones_like(mean)  # You might want to learn a parameter for the standard deviation as well
        # print(std_dev)
        distribution = distributions.Normal(mean, std_dev)
        # print(distribution)

        action = distribution.sample()
        # print('action')
        # print(action)
        action = torch.clamp(action, -1.0, 1.0)  # Ensure actions are within the valid range

        # print('log_prob')
        log_prob = distribution.log_prob(action)
        # print(log_prob)

        # print(log_prob)
        return torch.tensor(action), log_prob

In [18]:
#Using a neural network to learn state value
class Critic(nn.Module):

    #Takes in state
    def __init__(self, s_size, h_size):
        super(Critic, self).__init__()

        # two fully connected layers
        # add code here
        self.input_layer = nn.Linear(s_size, h_size)
        # add code here
        self.output_layer = nn.Linear(h_size, 1)
        self.double()

    def forward(self, x):

        #input layer
        x = self.input_layer(x)

        #activiation relu
        x = F.relu(x)

        #get state value
        state_value = self.output_layer(x)

        return state_value


In [19]:
import torch
env_id = 'Humanoid-v4'
# Load the entire model
actor_path = './'+ env_id+ '/actor.pth'
actor = torch.load(actor_path)
critic_path = './'+ env_id+ '/critic.pth'
critic = torch.load(critic_path)

In [20]:
import os
import imageio
import gym
from base64 import b64encode
from IPython.display import HTML

def generate_trajectory_with_frames(actor, critic, env, max_t):
    saved_log_probs = []
    rewards = []
    state_values = []

    state, _ = env.reset()
    frames = []  # List to store frames

    for t in range(max_t):
        state = torch.from_numpy(state)

        action, log_prob = actor.act(state)

        # get the state value from the critic network
        state_val = critic(state)

        next_state, reward, done, _, _ = env.step(action.numpy().astype(np.float32))
        frames.append(env.render())

        # add the obtained results to their relative lists ==> saved_log_probs, rewards, state_values
        saved_log_probs.append(log_prob)
        rewards.append(reward)
        state_values.append(state_val)

        state = next_state

        if done:
            break

    return saved_log_probs, rewards, state_values, frames


try:
    os.environ["DISPLAY"]
except:
    os.environ["SDL_VIDEODRIVER"] = "dummy"

env = gym.make(env_id, render_mode="rgb_array")
_, _, _, images = generate_trajectory_with_frames(actor, critic, env.unwrapped, 1000)
print(len(images))

def record_video(images, out_directory, fps=10):
    imageio.mimsave(out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps)

# generate the video
for i in range(100):
    _, _, _, images = generate_trajectory_with_frames(actor, critic, env.unwrapped, 1000)
    print(len(images))
    video_path = "./"+env_id+"/"+str(len(images))+".mp4"
    record_video(images, video_path, 10)


  return torch.tensor(action), log_prob


15
15
15
15
15
16
21
15
16
16
16
16
16
15
15
22
16
29
15
16
17
15
15
16
28
15
16
16
14
15
15
15
15
15
15
15
16
16
19
15
25
15
14
15
17
15
15
15
15
15
15
16
19
15
14
16
16
16
16
16
16
30
16
15
16
16
16
15
17
19
17
15
21
16
16
16
15
16
16
15
16
16
15
15
28
16
28
16
16
16
15
16
15
18
26
15
16
14
16
15
16
