In [1]:
import gymnasium as gym
import panda_gym
import numpy as np
from agents.sac import SAC, ReplayBuffer
from envs.panda_utils import generate_video, eval_model, save_plots
from envs.utils import setup_training_dir
import torch
import time

In [2]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [3]:
# Experiment identity and environment construction.
# `env_name` and `version` are also used to build the training/... output paths,
# and `max_episode_steps` is forwarded to gym.make as the episode length limit.
max_episode_steps = 200
version = "v1"
env_name = 'PandaReachJoints-v3'

env = gym.make(env_name, max_episode_steps=max_episode_steps)

# Show both spaces, one per line, as a quick sanity check of the env interface.
print(env.observation_space, env.action_space, sep="\n")

Dict('achieved_goal': Box(-10.0, 10.0, (3,), float32), 'desired_goal': Box(-10.0, 10.0, (3,), float32), 'observation': Box(-10.0, 10.0, (6,), float32))
Box(-1.0, 1.0, (7,), float32)


In [4]:
# Wall-clock budget for the training loop, in hours (the loop compares against 3600 * this value).
max_training_time = 7

# When True, the model state is loaded from `checkpoint` before training starts.
resume_training = True

# Identifier of this run's output directory (training/sac/<env>/<version>/training<N>).
training_number = setup_training_dir(resume_training, "sac", env_name, version)

# Checkpoint to resume from. NOTE(review): the run folder ("training2") and the
# step count ("500000") are hard-coded here — update them when resuming another run.
checkpoint = f"training/sac/{env_name}/{version}/training2/500000.pth"

In [5]:
# The agent's state is the 'observation' vector followed by the 'desired_goal'
# vector, so its size is the sum of the two sub-space sizes.
obs_size = sum(
    env.observation_space[key].shape[0]
    for key in ('observation', 'desired_goal')
)
# One action component per entry of the action space's first dimension.
n_actions = env.action_space.shape[0]
# --- SAC constructor arguments (passed positionally to SAC(...)) ---
buffer_size = 1_000_000   # presumably the replay buffer capacity — confirm in agents.sac
alpha = .2                # presumably the entropy temperature — confirm in agents.sac
gamma = .99               # presumably the discount factor — confirm in agents.sac
tau = .005                # presumably the target-network soft-update rate — confirm in agents.sac
lr = 3e-4                 # presumably the optimizer learning rate — confirm in agents.sac

# --- Training schedule ---
# Upper bound (exclusive) of the training loop's timestep range.
# The previous duplicate `max_timesteps = 1000000` was dead code: it was
# overwritten on the very next line, so 500000 is the value actually in force.
max_timesteps = 500_000
batch_size = 256              # argument given to model.update() on every update
warmup_timesteps = 500        # no model.update() call is made before this timestep
eval_frequency = 2000         # evaluate (and refresh plots) every this many timesteps
n_episodes_eval = 10          # forwarded to eval_model(); presumably episodes per evaluation
checkpoint_frequency = 100_000  # save the model state every this many timesteps
video_frequency = 50_000      # only referenced by the disabled periodic-video code

In [6]:
# Build the SAC agent from the sizes and hyperparameters defined above,
# then move it to the selected device.
model = SAC(
    obs_size, n_actions, buffer_size,
    alpha, gamma, tau, lr,
    device,
)
model = model.to(device)

In [7]:
# Initialise the loop's starting timestep and the history of evaluation returns.
if not resume_training:
    # Fresh run: start from scratch with an empty history.
    timestep_start, avg_returns = 0, []
else:
    # Resumed run: load_state() restores the model from `checkpoint` and hands
    # back the saved (timestep, average-returns history) pair.
    training_vars = model.load_state(checkpoint)
    timestep_start, avg_returns = training_vars

In [8]:
def _flatten_obs(obs):
    """Build the agent's state tensor from a dict observation.

    Concatenates obs["observation"] and obs["desired_goal"] (in that order)
    and wraps the result in a torch tensor.
    """
    return torch.tensor(np.concatenate([obs["observation"], obs["desired_goal"]]))


# Output directory of this run (plots and checkpoints are written here).
run_dir = f"training/sac/{env_name}/{version}/training{training_number}"

obs, info = env.reset()
state = _flatten_obs(obs)
start_time = time.time()

for timestep in range(timestep_start, max_timesteps):

    # Interact with the environment for one step.
    action = model.act(state)
    obs, reward, terminated, truncated, info = env.step(action)
    next_state = _flatten_obs(obs)
    done = terminated or truncated

    # `terminated` (not `done`) is stored, so an episode cut short by the step
    # limit is not recorded as a terminal transition.
    model.save_to_buffer([state, action, reward, next_state, terminated])

    if timestep >= warmup_timesteps:  # Update every timestep after warmup
        model.update(batch_size)

    if done:
        # Start a new episode and rebuild `state` from the reset observation.
        # Previously `state` kept the final observation of the finished episode,
        # so the first action of every new episode (and the transition stored
        # for it) was based on a stale state and a stale goal.
        obs, info = env.reset()
        state = _flatten_obs(obs)
    else:
        state = next_state

    # Periodic evaluation: log the average return and refresh the plots.
    if (timestep + 1) % eval_frequency == 0:
        avg_return = eval_model(model, env_name, max_episode_steps, n_episodes_eval)
        print(f"Average return after {timestep+1} timesteps : {avg_return}")
        avg_returns.append(avg_return)
        save_plots(avg_returns, run_dir, timestep+1, eval_frequency)

    # Periodic checkpoint.
    # NOTE(review): `timestep` (0-based index of the step just completed) is
    # stored while the file is named after timestep+1. If load_state() returns
    # the stored value unchanged, a resumed run repeats that step — confirm
    # against SAC.save_state / SAC.load_state before changing.
    if (timestep + 1) % checkpoint_frequency == 0:
        model.save_state(timestep, avg_returns, f"{run_dir}/{timestep+1}.pth")

    # Periodic video generation is currently disabled.
    # if (timestep + 1) % video_frequency == 0:
    #     generate_video(env_name, model, 1,
    #            deterministic=True,
    #            filename=f"training/sac/{env_name}/{version}/training{training_number}/{timestep+1}.mp4")

    # Stop once the wall-clock budget (max_training_time, in hours) is used up.
    if time.time() - start_time > 3600 * max_training_time:
        # Report timestep + 1, consistent with the evaluation log and file names.
        print(f"Maximum training time of {max_training_time}h exceeded. Interrupting training after {timestep + 1} timesteps.")
        break

Average return after 500000 timesteps : -5.1


In [9]:
# Record a final video of the trained policy into this run's output directory.
generate_video(
    env_name,
    max_episode_steps,
    model,
    50,  # presumably the number of episodes to record — confirm in envs.panda_utils
    deterministic=True,
    filename=f"training/sac/{env_name}/{version}/training{training_number}/final.mp4",
)

In [10]:
# Close the training environment created with gym.make earlier in the notebook.
env.close()