In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import yaml

from agents.soft_actor_critic import SoftActorCritic
from infrastructure.replay_buffer import ReplayBuffer
import env_configs

import os
import time

import gymnasium as gym
from gymnasium import wrappers
import numpy as np
import torch
from infrastructure import pytorch_util as ptu
import tqdm

from infrastructure import utils
from infrastructure.logger import Logger

from scripting_utils import make_logger, make_config

import argparse

## Defining the Arguments

In [None]:
class Args:
  """Notebook stand-in for the command-line arguments of the training script.

  Every option is stored as a plain instance attribute so the rest of the
  notebook can read it as `args.<name>`.
  """

  def __init__(self):
    defaults = {
      # experiment configuration file handed to make_config
      "config_file": "experiments/sac/halfcheetah_reparametrize.yaml",
      # evaluation cadence (in environment steps) and rollout counts
      "eval_interval": 5000,
      "num_eval_trajectories": 10,
      "num_render_trajectories": 0,
      # seed for the numpy / torch random number generators
      "seed": 1,
      # device selection
      "no_gpu": False,
      "which_gpu": 0,
      # how often (in environment steps) training scalars are logged
      "log_interval": 1000,
    }
    for name, value in defaults.items():
      setattr(self, name, value)

args = Args()
        

# create directory for logging
# The prefix is passed to make_logger below; its exact value is relied on by
# the course autograder, so do not rename it.
logdir_prefix = "hw3_sac_"  # keep for autograder

# Build the experiment configuration from the file named in args.config_file.
# Later cells index it like a dict ("make_env", "agent_kwargs", "total_steps", ...).
config = make_config(args.config_file)
# Logger used by the training loop for scalar and video logging.
# NOTE(review): presumably also creates the log directory -- confirm in scripting_utils.
logger = make_logger(logdir_prefix, config)

In [None]:
# set random seeds
# Seeds the global numpy and torch generators from args.seed. The gym
# environments are not seeded in this cell.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Select the compute device: GPU `args.which_gpu` unless args.no_gpu is set.
# NOTE(review): fallback behaviour when no GPU is available lives in pytorch_util -- confirm there.
ptu.init_gpu(use_gpu=not args.no_gpu, gpu_id=args.which_gpu)

In [None]:
# Build three separate environment instances from the config's factory:
# `env` is stepped by the training loop, `eval_env` is used for evaluation
# rollouts, and `render_env` is created with render=True for video rollouts.
env = config["make_env"]()
eval_env = config["make_env"]()
render_env = config["make_env"](render=True)

# Fail fast on discrete action spaces, which this implementation does not handle.
discrete = isinstance(env.action_space, gym.spaces.Discrete)
assert not discrete, "Our actor-critic implementation only supports continuous action spaces. (This isn't a fundamental limitation, just a current implementation decision.)"

In [None]:
# Observation shape (a tuple) and the size of the first action-space axis;
# both are taken from the training environment.
ob_shape = env.observation_space.shape
ac_dim = env.action_space.shape[0]

# initialize agent
# All remaining hyperparameters are forwarded unchanged from the experiment
# config's "agent_kwargs" entry.
agent = SoftActorCritic(
  ob_shape,
  ac_dim,
  **config["agent_kwargs"],
)

In [None]:
# Frame rate for saved videos: the inverse of the simulator timestep when the
# environment exposes a `model` attribute, otherwise the wrapped env's
# advertised render_fps.
fps = (
  1 / env.model.opt.timestep
  if "model" in dir(env)
  else env.env.metadata["render_fps"]
)

# Episode length cap for evaluation rollouts; a falsy config value falls back
# to the environment spec's own limit.
ep_len = config["ep_len"]
if not ep_len:
  ep_len = env.spec.max_episode_steps
batch_size = config["batch_size"]

replay_buffer = ReplayBuffer(config["replay_buffer_capacity"])

In [None]:
# Main SAC training loop.
#
# Each iteration: pick an action (uniformly random during the warm-up phase,
# from the agent afterwards), step the training environment, store the
# transition, optionally run one agent update, and periodically evaluate.
#
# Seed the environment and its action-space sampler so that, together with the
# numpy/torch seeds set earlier, a fresh "Restart & Run All" is reproducible.
env.action_space.seed(args.seed)
observation, info = env.reset(seed=args.seed)

for step in tqdm.trange(config["total_steps"], dynamic_ncols=True):
  if step < config["random_steps"]:
    # Warm-up: sample uniformly from the action space.
    action = env.action_space.sample()
  else:
    action = agent.get_action(observation)

  # Step the environment and add the data to the replay buffer.
  # Gymnasium returns separate `terminated` and `truncated` flags.
  next_observation, reward, terminated, truncated, info = env.step(action)
  next_observation = np.asarray(next_observation)

  # Keep the truncation flag returned by step(); the legacy info key is only
  # honoured in addition to it. Overwriting the flag with the info lookup
  # would silently disable time-limit resets on environments that do not set
  # that key.
  truncated = truncated or info.get("TimeLimit.truncated", False)

  # Only `terminated` is stored: a time-limit cut-off is not recorded as a
  # terminal transition.
  replay_buffer.insert(
    observation=observation,
    action=action,
    reward=reward,
    next_observation=next_observation,
    terminated=terminated,
  )

  if terminated or truncated:
    # Episode statistics arrive in the info dict of the final step(), so they
    # must be read before reset() replaces `info`.
    if "episode" in info:
      logger.log_scalar(info["episode"]["r"], "train_return", step)
      logger.log_scalar(info["episode"]["l"], "train_ep_len", step)

    observation, info = env.reset()
  else:
    observation = next_observation

  # Train the agent
  if step >= config["training_starts"]:
    # Sample a batch of config["batch_size"] transitions from the replay buffer.
    batch = replay_buffer.sample(config["batch_size"])
    # NOTE(review): presumably converts every array in the batch dict to a
    # torch tensor on the active device -- confirm in pytorch_util.
    batch = ptu.from_numpy(batch)
    update_info = agent.update(
      batch["observations"],
      batch["actions"],
      batch["rewards"],
      batch["next_observations"],
      batch["terminateds"],
      step,
    )

    # Logging
    update_info["actor_lr"] = agent.actor_lr_scheduler.get_last_lr()[0]
    update_info["critic_lr"] = agent.critic_lr_scheduler.get_last_lr()[0]

    if step % args.log_interval == 0:
      for k, v in update_info.items():
        logger.log_scalar(v, k, step)
      logger.flush()

  # Run evaluation
  if step % args.eval_interval == 0:
    trajectories = utils.sample_n_trajectories(
      eval_env,
      policy=agent,
      ntraj=args.num_eval_trajectories,
      max_length=ep_len,
    )
    returns = [t["episode_statistics"]["r"] for t in trajectories]
    ep_lens = [t["episode_statistics"]["l"] for t in trajectories]

    logger.log_scalar(np.mean(returns), "eval_return", step)
    logger.log_scalar(np.mean(ep_lens), "eval_ep_len", step)

    # Spread statistics are only meaningful with more than one rollout.
    if len(returns) > 1:
      logger.log_scalar(np.std(returns), "eval/return_std", step)
      logger.log_scalar(np.max(returns), "eval/return_max", step)
      logger.log_scalar(np.min(returns), "eval/return_min", step)
      logger.log_scalar(np.std(ep_lens), "eval/ep_len_std", step)
      logger.log_scalar(np.max(ep_lens), "eval/ep_len_max", step)
      logger.log_scalar(np.min(ep_lens), "eval/ep_len_min", step)

    if args.num_render_trajectories > 0:
      video_trajectories = utils.sample_n_trajectories(
        render_env,
        agent,
        args.num_render_trajectories,
        ep_len,
        render=True,
      )

      logger.log_paths_as_videos(
        video_trajectories,
        step,
        fps=fps,
        max_videos_to_save=args.num_render_trajectories,
        video_title="eval_rollouts",
      )