In [None]:
import wandb
wandb.login()

In [None]:
!apt-get install -y libosmesa6-dev patchelf
!pip install gymnasium[mujoco] imageio[ffmpeg] pyopengl glfw -qU
!pip install wandb -qU

In [None]:
!pip install stable-baselines3 -qU

In [None]:
import os
os.environ["MUJOCO_GL"] = "egl"

import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.animation import FuncAnimation
from IPython.display import HTML

import math

import random
import numpy as np

import glob
import time
from datetime import datetime

from typing import Callable ###for evaluate
from google.colab import files ###for downloading files

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import MultivariateNormal



import stable_baselines3 as sb3

# Compare the major version numerically: a plain string comparison orders
# versions lexicographically (e.g. "10.0" < "2.0" is True), which would
# wrongly reject newer releases.
_sb3_major_version = int(sb3.__version__.split(".")[0])
if _sb3_major_version < 2:
    raise ValueError(
        """Ongoing migration: run the following command to install the new dependencies:
           poetry run pip install "stable_baselines3==2.0.0a1"
        """
        )
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

In [None]:
from dataclasses import dataclass
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
@dataclass
class Args:
    """Hyperparameters and run settings for the SAC training notebook.

    Every field has a default, so ``Args()`` yields the configuration used by
    ``train``. Each field is followed by a short attribute docstring.
    """

    exp_name: str = "Climber_EXP_3"
    """the name of this experiment"""
    exp_group: str = "Climber_EXP"
    """the Weights and Biases group the run is filed under"""
    seed: int = 3
    """seed of the experiment"""
    torch_deterministic: bool = True
    """value assigned to `torch.backends.cudnn.deterministic`"""
    cuda: bool = True
    """if toggled, cuda is used whenever it is available"""
    track: bool = True
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "FINAL_TL_EXP"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project; None leaves it unset"""
    capture_video: bool = False
    """whether the first environment is created with `render_mode="rgb_array"`"""

    save_model: bool = True
    """whether to write the network weights to disk when training ends"""
    load_model: bool = True
    """whether to initialise the networks from `*.pth` files found in `/content`"""

    # Alternative environment ids that were used with this notebook:
    #   "MountainCarContinuous-v0", "Pendulum-v1", "BipedalWalker-v3",
    #   "LunarLanderContinuous-v2", "InvertedPendulumSwingup-v4",
    #   "InvertedDoublePendulumSwingup-v4", "HalfCheetah-v5", "Hopper-v5",
    #   "Walker2d-v5", "3LegAnt", "5LegAnt", "HopperAnt", "LongShortAnt",
    #   "ShortLongAnt", "ClimberAnt", "GoAroundAnt", "Humanoid-v5",
    #   "InvertedDoublePendulum-v5", "InvertedPendulum-v5", "Reacher-v5",
    #   "Swimmer-v5"
    env_id: str = "Ant-v5"
    """the environment id"""
    total_timesteps: int = 500000
    """number of vectorised environment steps performed by the training loop"""
    num_envs: int = 5
    """the number of parallel game environments"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of samples drawn from the replay memory"""
    # Stored as a real int: the previous default `5e3` was a float despite the
    # `int` annotation.
    learning_starts: int = 5000
    """timestep to start learning"""
    policy_lr: float = 3e-4
    """the learning rate of the policy network optimizer"""
    q_lr: float = 1e-3
    """the learning rate of the Q network optimizer"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    target_network_frequency: int = 1  # Denis Yarats' implementation delays this by 2.
    """the frequency of updates for the target networks"""
    alpha: float = 0.2
    """Entropy regularization coefficient."""
    autotune: bool = True
    """automatic tuning of the entropy coefficient"""


def make_env(env_id, seed, idx, capture_video, run_name, xml=None):
    """Return a zero-argument factory that builds one wrapped environment.

    Args:
        env_id: id passed to ``gym.make``.
        seed: seed applied to the environment's action space.
        idx: index of this sub-environment; only index 0 may render.
        capture_video: when true, environment 0 is created with
            ``render_mode="rgb_array"``.
        run_name: accepted for call compatibility; not used here.
        xml: optional path forwarded to ``gym.make`` as ``xml_file``.

    Returns:
        A callable that creates the environment, wraps it in
        ``RecordEpisodeStatistics`` and seeds its action space.
    """
    def thunk():
        make_kwargs = {}
        if xml is not None:
            make_kwargs["xml_file"] = xml
        # Rendering is only requested for the first sub-environment.
        if capture_video and idx == 0:
            make_kwargs["render_mode"] = "rgb_array"

        wrapped = gym.wrappers.RecordEpisodeStatistics(gym.make(env_id, **make_kwargs))
        wrapped.action_space.seed(seed)
        return wrapped

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place: orthogonal weights, constant bias.

    Args:
        layer: module exposing ``weight`` and ``bias`` tensors.
        std: gain passed to the orthogonal initialiser.
        bias_const: value every bias entry is set to.

    Returns:
        The same ``layer`` object, so the call can wrap a constructor.
    """
    weight, bias = layer.weight, layer.bias
    torch.nn.init.orthogonal_(weight, gain=std)
    torch.nn.init.constant_(bias, val=bias_const)
    return layer

def xav_layer_init(layer, gain=1.0):
    """Initialise ``layer`` in place: Xavier-uniform weights, zero bias.

    Args:
        layer: module exposing ``weight`` and ``bias`` tensors.
        gain: scaling factor forwarded to ``xavier_uniform_``.

    Returns:
        The same ``layer`` object.
    """
    initializers = torch.nn.init
    initializers.xavier_uniform_(layer.weight, gain=gain)
    initializers.constant_(layer.bias, 0)
    return layer

def kaiming_layer_init(layer):
    """Initialise ``layer`` in place: Kaiming-uniform (ReLU) weights, zero bias.

    Args:
        layer: module exposing ``weight`` and ``bias`` tensors.

    Returns:
        The same ``layer`` object.
    """
    # Use the fully-qualified path: the bare name `init` is never imported in
    # this notebook, so the previous call raised NameError.
    torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
    torch.nn.init.constant_(layer.bias, 0)
    return layer

# ALGO LOGIC: initialize agent here:
class SoftQNetwork(nn.Module):
    """Q-function: maps a concatenated (observation, action) pair to one value.

    The input width is the flattened observation size plus the flattened
    action size, both read from ``env.single_observation_space`` and
    ``env.single_action_space``.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_dim = np.prod(env.single_action_space.shape)
        self.fc1 = nn.Linear(obs_dim + act_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, a):
        """Return the Q-value for observations ``x`` and actions ``a``.

        Both inputs are concatenated along dimension 1, so they must share
        their leading dimension.
        """
        joint = torch.cat([x, a], 1)
        hidden = F.relu(self.fc1(joint))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Bounds that the policy's log standard deviation is squashed into.
LOG_STD_MAX = 2
LOG_STD_MIN = -5


class Actor(nn.Module):
    """Tanh-squashed Gaussian policy.

    Two hidden layers feed separate heads for the mean and the log standard
    deviation. Actions are rescaled from (-1, 1) to the environment's bounds
    using the ``action_scale`` / ``action_bias`` buffers.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        act_dim = np.prod(env.single_action_space.shape)
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, act_dim)
        self.fc_logstd = nn.Linear(256, act_dim)

        # Action rescaling: half-range and midpoint of the action bounds,
        # stored as buffers so they follow the module across devices and are
        # part of the state dict.
        act_space = env.single_action_space
        half_range = (act_space.high - act_space.low) / 2.0
        midpoint = (act_space.high + act_space.low) / 2.0
        self.register_buffer("action_scale", torch.tensor(half_range, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor(midpoint, dtype=torch.float32))

    def forward(self, x):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for ``x``."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        mean = self.fc_mean(hidden)
        squashed = torch.tanh(self.fc_logstd(hidden))
        # Map tanh output from [-1, 1] onto [LOG_STD_MIN, LOG_STD_MAX]
        # (from SpinUp / Denis Yarats).
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (squashed + 1)
        return mean, log_std

    def get_action(self, x):
        """Sample an action for ``x``.

        Returns:
            ``(action, log_prob, mean_action)`` where ``action`` is a
            reparameterised sample rescaled to the action bounds,
            ``log_prob`` is its log-probability summed over dimension 1
            (kept as a column), and ``mean_action`` is the squashed,
            rescaled distribution mean.
        """
        mean, log_std = self(x)
        gaussian = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = gaussian.rsample()  # reparameterization trick (mean + std * N(0,1))
        squashed = torch.tanh(pre_tanh)
        action = squashed * self.action_scale + self.action_bias
        # Change-of-variables correction that enforces the action bound.
        bound_correction = torch.log(self.action_scale * (1 - squashed.pow(2)) + 1e-6)
        log_prob = (gaussian.log_prob(pre_tanh) - bound_correction).sum(1, keepdim=True)
        mean_action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean_action


def evaluate(
    model_path: str,
    make_env: Callable,
    env_id: str,
    eval_episodes: int,
    run_name: str,
    Model: torch.nn.Module,
    device: torch.device = torch.device("cpu"),
    capture_video: bool = True,
    gamma: float = 0.99,
    ):
    """Roll out a saved policy in a single environment and collect returns.

    Args:
        model_path: path to a state dict loadable into ``Model``.
        make_env: environment factory with this notebook's signature
            ``make_env(env_id, seed, idx, capture_video, run_name, xml=None)``.
        env_id: environment id forwarded to ``make_env``.
        eval_episodes: number of finished episodes to collect.
        run_name: forwarded to ``make_env``.
        Model: policy class, constructed as ``Model(envs)``. It must expose
            ``get_action`` (as ``Actor`` does) or ``get_action_and_value``;
            in both cases the first returned element is used as the action.
        device: device the policy is loaded onto.
        capture_video: forwarded to ``make_env``.
        gamma: unused; kept so existing callers do not break.

    Returns:
        A list with one float32 array of shape ``(1,)`` per finished episode,
        holding that episode's undiscounted return.
    """
    # Positional order follows this notebook's make_env: seed=0, idx=0. The
    # earlier call passed capture_video/run_name/gamma one slot too early.
    envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
    try:
        agent = Model(envs).to(device)
        agent.load_state_dict(torch.load(model_path, map_location=device))
        agent.eval()

        # Actor only defines get_action; keep supporting policies that expose
        # get_action_and_value instead.
        if hasattr(agent, "get_action_and_value"):
            select_action = agent.get_action_and_value
        else:
            select_action = agent.get_action

        obs, _ = envs.reset()

        episode_return = np.zeros(1, dtype=np.float32)
        episodic_returns = []
        while len(episodic_returns) < eval_episodes:
            # No gradients are needed for pure inference.
            with torch.no_grad():
                actions = select_action(torch.Tensor(obs).to(device))[0]
            obs, rewards, done, truncated, _ = envs.step(actions.cpu().numpy())
            episode_return += rewards
            if np.logical_or(done, truncated).any():
                # Store a copy and start from a fresh accumulator so every
                # entry keeps the same type and is never mutated afterwards.
                episodic_returns.append(episode_return.copy())
                episode_return = np.zeros(1, dtype=np.float32)
    finally:
        envs.close()

    return episodic_returns

In [None]:
from gymnasium.envs.mujoco.ant_v5 import AntEnv
from gymnasium.spaces import Box
from gymnasium.envs.registration import register

class ThreeLegAntEnv(AntEnv):
    """Ant variant whose action space is replaced by a 6-dimensional box.

    NOTE(review): presumably meant to be used with a three-legged model
    supplied through the ``xml_file`` keyword -- confirm against the XML.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Override the action space that AntEnv set up during construction.
        self.action_space = Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)


# Make the variant available to gym.make under the id "3LegAnt".
_three_leg_spec = dict(
    id="3LegAnt",
    entry_point=ThreeLegAntEnv,
    max_episode_steps=1000,
    reward_threshold=6000.0,
)
register(**_three_leg_spec)

class FiveLegAntEnv(AntEnv):
    """Ant variant whose action space is replaced by a 10-dimensional box.

    NOTE(review): presumably meant to be used with a five-legged model
    supplied through the ``xml_file`` keyword -- confirm against the XML.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Override the action space that AntEnv set up during construction.
        self.action_space = Box(low=-1.0, high=1.0, shape=(10,), dtype=np.float32)


# Make the variant available to gym.make under the id "5LegAnt".
_five_leg_spec = dict(
    id="5LegAnt",
    entry_point=FiveLegAntEnv,
    max_episode_steps=1000,
    reward_threshold=6000.0,
)
register(**_five_leg_spec)

class HopperAntEnv(AntEnv):
    """Ant subclass registered as "HopperAnt"; adds no behaviour of its own.

    All keyword arguments are forwarded unchanged to ``AntEnv``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


# Make the variant available to gym.make under the id "HopperAnt".
_hopper_ant_spec = dict(
    id="HopperAnt",
    entry_point=HopperAntEnv,
    max_episode_steps=1000,
    reward_threshold=6000.0,
)
register(**_hopper_ant_spec)

class LongShortAntEnv(AntEnv):
    """Ant subclass registered as "LongShortAnt"; adds no behaviour of its own.

    All keyword arguments are forwarded unchanged to ``AntEnv``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


# Make the variant available to gym.make under the id "LongShortAnt".
_long_short_ant_spec = dict(
    id="LongShortAnt",
    entry_point=LongShortAntEnv,
    max_episode_steps=1000,
    reward_threshold=6000.0,
)
register(**_long_short_ant_spec)

class ShortLongAntEnv(AntEnv):
    """Ant subclass registered as "ShortLongAnt"; adds no behaviour of its own.

    All keyword arguments are forwarded unchanged to ``AntEnv``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


# Make the variant available to gym.make under the id "ShortLongAnt".
_short_long_ant_spec = dict(
    id="ShortLongAnt",
    entry_point=ShortLongAntEnv,
    max_episode_steps=1000,
    reward_threshold=6000.0,
)
register(**_short_long_ant_spec)

class ClimberAntEnv(AntEnv):
    """Ant variant with a reshaped forward reward and a custom healthy z-range.

    After the base constructor runs, the forward reward weight is set to 0.3
    and the healthy z-range to (0.2, 1.6). The reward adds a flat bonus of 2
    whenever the x-velocity exceeds 0.3.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._forward_reward_weight = 0.3
        self._healthy_z_range = (0.2, 1.6)

    def _get_rew(self, x_velocity: float, action):
        """Return ``(reward, reward_info)`` for one step.

        ``reward`` is forward reward plus healthy reward minus control and
        contact costs; ``reward_info`` lists the four signed components.
        """
        # Flat bonus for moving forward faster than the velocity threshold.
        speed_bonus = 2 if x_velocity > 0.3 else 0
        forward_reward = (x_velocity * self._forward_reward_weight) + speed_bonus
        healthy_reward = self.healthy_reward

        ctrl_cost = self.control_cost(action)
        contact_cost = self.contact_cost

        reward = (forward_reward + healthy_reward) - (ctrl_cost + contact_cost)
        reward_info = dict(
            reward_forward=forward_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            reward_survive=healthy_reward,
        )
        return reward, reward_info


# Make the variant available to gym.make under the id "ClimberAnt".
_climber_ant_spec = dict(
    id="ClimberAnt",
    entry_point=ClimberAntEnv,
    max_episode_steps=1200,
    reward_threshold=5000.0,
)
register(**_climber_ant_spec)

class GoAroundAntEnv(AntEnv):
    """Ant variant with a reshaped forward reward.

    After the base constructor runs, the forward reward weight is set to 0.3.
    The reward adds a flat bonus of 2 whenever the x-velocity exceeds 0.3.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._forward_reward_weight = 0.3

    def _get_rew(self, x_velocity: float, action):
        """Return ``(reward, reward_info)`` for one step.

        ``reward`` is forward reward plus healthy reward minus control and
        contact costs; ``reward_info`` lists the four signed components.
        """
        # Flat bonus for moving forward faster than the velocity threshold.
        speed_bonus = 2 if x_velocity > 0.3 else 0
        forward_reward = (x_velocity * self._forward_reward_weight) + speed_bonus
        healthy_reward = self.healthy_reward

        ctrl_cost = self.control_cost(action)
        contact_cost = self.contact_cost

        reward = (forward_reward + healthy_reward) - (ctrl_cost + contact_cost)
        reward_info = dict(
            reward_forward=forward_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            reward_survive=healthy_reward,
        )
        return reward, reward_info


# Make the variant available to gym.make under the id "GoAroundAnt".
_go_around_ant_spec = dict(
    id="GoAroundAnt",
    entry_point=GoAroundAntEnv,
    max_episode_steps=1200,
    reward_threshold=5000.0,
)
register(**_go_around_ant_spec)

In [None]:
from gymnasium.vector.vector_env import AutoresetMode
from collections import deque

In [None]:
def train(seed=None):
    """Train a SAC agent with the settings in ``Args`` and optionally save it.

    The function builds ``Args()``, optionally overrides its seed, sets up
    Weights & Biases / TensorBoard logging, creates ``args.num_envs``
    asynchronous environments, optionally warm-starts the networks from
    ``*.pth`` files in ``/content``, runs the SAC loop for
    ``args.total_timesteps`` vectorised steps, and finally writes the network
    weights under ``runs/<run_name>/`` when ``args.save_model`` is set.

    Args:
        seed: optional seed overriding ``Args.seed``. When ``None`` the
            default from ``Args`` is used.
    """
    args = Args()
    if seed is not None:
        args.seed = seed
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            group=args.exp_group, ###
            monitor_gym=True,
            save_code=True,
            reinit="return_previous" ###
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    ################## get xml path ##################
    # Every environment other than the stock Ant is given a model file from /content.
    xml_path = None
    if args.env_id != "Ant-v5":
        xml_path = f'/content/{args.env_id}.xml'
    ##################################################
    # env setup
    envs = gym.vector.AsyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, xml=xml_path) for i in range(args.num_envs)],
        autoreset_mode=AutoresetMode.SAME_STEP
    )
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr)

    ######## load pretrained model #########
    # Any file in /content whose name ends with one of the known suffixes is
    # loaded into the matching network.
    if args.load_model:
        for file in os.listdir('/content'):
            if file.endswith("_actor.pth"):
                filepath = os.path.join('/content', file)
                actor.load_state_dict(torch.load(filepath, map_location=device))
            elif file.endswith("_qf1.pth"):
                filepath = os.path.join('/content', file)
                qf1.load_state_dict(torch.load(filepath, map_location=device))
            elif file.endswith("_qf2.pth"):
                filepath = os.path.join('/content', file)
                qf2.load_state_dict(torch.load(filepath, map_location=device))
            elif file.endswith("_qf1_target.pth"):
                filepath = os.path.join('/content', file)
                qf1_target.load_state_dict(torch.load(filepath, map_location=device))
            elif file.endswith("_qf2_target.pth"):
                filepath = os.path.join('/content', file)
                qf2_target.load_state_dict(torch.load(filepath, map_location=device))
    ########################################

    ########################## log model ######################
    if args.track:
        wandb.watch(actor, log="all", log_freq=1024)
        model_architecture = str(actor)
        wandb.log({f"model/architecture-exp.:{args.exp_name}-actor": wandb.Html(f"<pre>{model_architecture}</pre>")})
        wandb.watch(qf1, log="all", log_freq=1024)
        model_architecture = str(qf1)
        wandb.log({f"model/architecture-exp.:{args.exp_name}-SoftQnet": wandb.Html(f"<pre>{model_architecture}</pre>")})
        wandb.watch(qf2, log="all", log_freq=1024)
        wandb.watch(qf1_target, log="all", log_freq=1024)
        wandb.watch(qf2_target, log="all", log_freq=1024)
    ############################################################

    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr)
    else:
        alpha = args.alpha

    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        n_envs=args.num_envs,
        handle_timeout_termination=False,
    )
    start_time = time.time()

    ######## for logging reward components #############
    reward_components = ["reward_forward", "reward_ctrl", "reward_contact", "reward_survive"]
    reward_sums = {comp: np.zeros(args.num_envs, dtype=np.float32) for comp in reward_components}
    smoothed_returns = deque(maxlen=50)
    ####################################################

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    for global_step in range(args.total_timesteps):
        # Each loop iteration steps all environments once, so logged x-axis
        # values are global_step * num_envs.
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        ###### calc reward comp ####
        for comp in reward_components:
            if comp in infos:
                reward_sums[comp] += infos[comp]
        ############################

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            episode_infos = infos["final_info"]["episode"]
            episode_mask = infos["final_info"]["_episode"]
            episode_returns = episode_infos["r"][episode_mask]
            episode_lengths = episode_infos["l"][episode_mask]
            avg_epi_return = np.mean(episode_returns)
            avg_epi_length = np.mean(episode_lengths)
            # Append to smoothing buffer
            smoothed_returns.extend(episode_returns)
            smoothed_avg_return = np.mean(smoothed_returns)
            ####
            print(f"global_step: {global_step*args.num_envs}, avg_episodic_return: {avg_epi_return} avg_episodic_length: {avg_epi_length}")
            writer.add_scalar("charts/envs_finished", len(episode_returns), (global_step*args.num_envs))
            writer.add_scalar("charts/avg_episodic_return", avg_epi_return,(global_step*args.num_envs))
            writer.add_scalar("charts/smoothed_avg_episodic_return", smoothed_avg_return, global_step * args.num_envs)
            writer.add_scalar("charts/avg_episodic_length", avg_epi_length, (global_step*args.num_envs))
            for comp in reward_components:
                avg_reward_comp = np.mean(reward_sums[comp][episode_mask])
                writer.add_scalar(f"costs_reward/{comp}", avg_reward_comp, (global_step*args.num_envs))
                # Restart the per-component sums for the environments that just finished.
                reward_sums[comp][episode_mask] = 0.0

        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_obs`
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_obs"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)
            with torch.no_grad():
                next_state_actions, next_state_log_pi, _ = actor.get_action(data.next_observations)
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations, data.actions).view(-1)
            qf2_a_values = qf2(data.observations, data.actions).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
            qf_loss = qf1_loss + qf2_loss

            # optimize the model
            q_optimizer.zero_grad()
            qf_loss.backward()
            q_optimizer.step()

            if global_step % args.policy_frequency == 0:  # TD 3 Delayed update support
                for _ in range(
                    args.policy_frequency
                ):  # compensate for the delay by doing 'actor_update_interval' instead of 1
                    pi, log_pi, _ = actor.get_action(data.observations)
                    if args.track:
                        entropy = -log_pi.mean().item()
                        writer.add_scalar("policy/entropy", entropy, global_step * args.num_envs)
                    qf1_pi = qf1(data.observations, pi)
                    qf2_pi = qf2(data.observations, pi)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi)
                    actor_loss = ((alpha * log_pi) - min_qf_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    actor_optimizer.step()

                    if args.autotune:
                        with torch.no_grad():
                            _, log_pi, _ = actor.get_action(data.observations)
                        alpha_loss = (-log_alpha.exp() * (log_pi + target_entropy)).mean()

                        a_optimizer.zero_grad()
                        alpha_loss.backward()
                        a_optimizer.step()
                        alpha = log_alpha.exp().item()

            # update the target networks
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            if global_step % 100 == 0:
                writer.add_scalar("charts/train_minutes", ((time.time() - start_time)//60), (global_step*args.num_envs)) ########
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), (global_step*args.num_envs))
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), (global_step*args.num_envs))
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), (global_step*args.num_envs))
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), (global_step*args.num_envs))
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, (global_step*args.num_envs))
                writer.add_scalar("losses/actor_loss", actor_loss.item(), (global_step*args.num_envs))
                writer.add_scalar("losses/alpha", alpha, (global_step*args.num_envs))
                print("SPS:", int((global_step*args.num_envs) / (time.time() - start_time)), " total_run_time:", ((time.time()-start_time)//60)," minutes")
                writer.add_scalar(
                    "charts/SPS",
                    int((global_step*args.num_envs) / (time.time() - start_time)),
                    (global_step*args.num_envs),
                )
                if args.autotune:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), (global_step*args.num_envs))
    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}-seed-{args.seed}"
        os.makedirs(model_path, exist_ok=True)
        # Name files after the seed actually used (args.seed). The `seed`
        # parameter is None when train() is called without arguments, which
        # previously produced files called "None_<exp>_actor.pth".
        file_prefix = f"{model_path}/{args.seed}_{args.exp_name}"
        torch.save(actor.state_dict(), f"{file_prefix}_actor.pth")
        torch.save(qf1.state_dict(), f"{file_prefix}_qf1.pth")
        torch.save(qf2.state_dict(), f"{file_prefix}_qf2.pth")
        torch.save(qf1_target.state_dict(), f"{file_prefix}_qf1_target.pth")
        torch.save(qf2_target.state_dict(), f"{file_prefix}_qf2_target.pth")
        if args.autotune:
            torch.save(log_alpha, f"{file_prefix}_log_alpha.pth")
        print(f"model saved to {model_path}")

    envs.close()
    writer.close()

In [None]:
# Multi-seed sweep (disabled): uncomment to call train once per listed seed.
# for seed in [1,2,3,4,5]:
#   train(seed)

# Single training run; with no argument train() keeps the seed set in Args.
train()

In [None]:
# Finish the Weights & Biases run started by train() (wandb.init is called there when Args.track is set).
wandb.finish()

In [None]:
#download models
# Walk runs/<run>/<model dir>/ and download every saved checkpoint file,
# pausing between downloads.
_checkpoint_suffixes = tuple(
    f"{data}.pth" for data in ["actor", "qf1", "qf2", "qf1_target", "qf2_target", "log_alpha"]
)
for folder in os.listdir(f"runs"):
    run_dir = os.path.join(f"runs", folder)
    if not os.path.isdir(run_dir):
        continue
    for subfolder in os.listdir(run_dir):
        model_dir = os.path.join(run_dir, subfolder)
        if not os.path.isdir(model_dir):
            continue
        for file in os.listdir(model_dir):
            # str.endswith accepts a tuple; a file name can end with at most
            # one of these suffixes, so each file is downloaded once.
            if file.endswith(_checkpoint_suffixes):
                filepath = os.path.join(model_dir, file)
                files.download(filepath)
                time.sleep(10)

In [None]:
#get animation
# Roll out a saved actor in a single rendering environment and collect the
# frames for display_video.
plt.rcParams['animation.embed_limit'] = 1500 * 1024 * 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# True: look for an uploaded "*actor.pth" directly inside runs/.
# False: look inside runs/<run>/<model dir>/ as written by train().
pretrained = False

filepath = None
if pretrained:
    for file in os.listdir(f"runs"):
        if file.endswith("actor.pth"):
            filepath = os.path.join(f"runs", file)
else:
    for folder in os.listdir(f"runs"):
        if os.path.isdir(os.path.join(f"runs", folder)):
            for subfolder in os.listdir(os.path.join(f"runs", folder)):
                if os.path.isdir(os.path.join(f"runs", folder, subfolder)):
                    for file in os.listdir(os.path.join(f"runs", folder, subfolder)):
                        if file.endswith("actor.pth"):
                            filepath = os.path.join(f"runs", folder, subfolder, file)

# Fail with a clear message instead of passing None to torch.load.
if filepath is None:
    raise FileNotFoundError(
        "no checkpoint ending in 'actor.pth' was found under 'runs'; "
        "train with save_model=True or place one there"
    )

args = Args()

xml_path = None
if args.env_id != "Ant-v5":
    xml_path = f'/content/{args.env_id}.xml'

envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, True, None, xml=xml_path)])
try:
    actor = Actor(envs).to(device)
    actor.load_state_dict(torch.load(filepath, map_location=device))
    actor.eval()

    obs, _ = envs.reset(seed=args.seed)
    frames = []
    for _ in range(800):
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            action, _, _ = actor.get_action(torch.Tensor(obs).to(device))
        action = action.cpu().numpy()
        obs, rewards, terminations, truncations, infos = envs.step(action)
        frames.append(envs.render()[0])
        if terminations[0] or truncations[0]:
            display(infos)
            break
finally:
    # Release the environment even if loading or stepping fails.
    envs.close()


def display_video(frames, fps=24, out_path=None):
    """Encode ``frames`` as an mp4 with ffmpeg and trigger a Colab download.

    Args:
        frames: non-empty sequence of image arrays accepted by ``imshow``.
        fps: frame rate used both for the animation interval and for the
            saved file.
        out_path: destination file. Defaults to ``runs/<env_id>.mp4`` using
            the module-level ``args``.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("display_video needs at least one frame")
    if out_path is None:
        out_path = f"runs/{args.env_id}.mp4"

    fig, ax = plt.subplots()
    try:
        ax.axis("off")
        img = ax.imshow(frames[0])

        def update(frame):
            img.set_array(frame)
            return [img]

        ani = FuncAnimation(fig, update, frames=frames, blit=True, interval=1000 // fps)
        # Honour the fps argument; it was previously hard-coded to 24 here.
        ani.save(out_path, writer='ffmpeg', fps=fps)
        files.download(out_path)
    finally:
        # Always release the figure, even when encoding fails.
        plt.close(fig)

display_video(frames)