In [1]:
import scipy.signal
import sys
import torch
import torch.nn as nn
import numpy as np

In [2]:
from typing import Dict, List, Optional, Tuple
import gym
from PIL import Image
# from pyvirtualdisplay import Display
# Display().start()
from datetime import datetime
from tqdm import tqdm

In [3]:
import math
import random
from copy import deepcopy
import torch
from torch.optim import Adam
from torch.optim import RMSprop
import gym
import time
from collections import namedtuple, deque
import neptune.new as neptune

In [4]:
import robosuite as suite
from robosuite.controllers import load_controller_config
from robosuite.controllers.controller_factory import reset_controllers
from robosuite.utils import observables
from robosuite.utils.input_utils import *
from robosuite.robots import Bimanual
import imageio
import numpy as np
import robosuite.utils.macros as macros
macros.IMAGE_CONVENTION = "opencv"

In [5]:
nep_log = neptune.init(
    project="xhnfirst/DDPG-robosuite",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiI1NTg5MDI2OS01MTVmLTQ2YjUtODA1Yy02ZWQyNDgxZDcwN2UifQ==",
)

https://app.neptune.ai/xhnfirst/DDPG-robosuite/e/DDPGROB-149
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [6]:
options = {
    'env_name': 'EElab_test2',
    "robots": "UR5e"
}
controller_name = "JOINT_VELOCITY"
options["controller_configs"] = suite.load_controller_config(default_controller=controller_name)

env = suite.make(
    **options,
    has_renderer=False,
    has_offscreen_renderer=True,
    ignore_done=True,
    use_camera_obs=False,
    gripper_types=None,
    renderer = 'mujoco',

)

test_env = suite.make(
    **options,
    has_renderer=False,
    has_offscreen_renderer=True,
    ignore_done=True,
    use_camera_obs=False,
    gripper_types=None,
    renderer = 'mujoco',
)


video_env = suite.make(
    **options,
    gripper_types=None,
    has_renderer=False,
    has_offscreen_renderer=True,
    ignore_done=True,
    use_camera_obs=True,
    use_object_obs=True, 
    camera_names='Labviewer',
    camera_heights=512,
    camera_widths=512,
    # control_freq=200,
    renderer = 'mujoco',
)

frame = []
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device = ', device)

device =  cuda


In [7]:
def mlp(sizes, activation, output_activation=nn.Identity):
    layers = []
    for j in range(len(sizes)-1):
        act = activation if j < len(sizes)-2 else output_activation
        layers += [nn.Linear(sizes[j], sizes[j+1]), act()]
    return nn.Sequential(*layers)


class MLPActor(nn.Module):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim]
        self.pi = mlp(pi_sizes, activation, nn.Tanh)
        self.act_limit = act_limit

    def forward(self, obs):
        # Return output from network scaled to action space limits.
        return self.act_limit * self.pi(obs)

class MLPQFunction(nn.Module):

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        self.q = mlp([obs_dim + act_dim] + list(hidden_sizes) + [1], activation)

    def forward(self, obs, act):
        q = self.q(torch.cat([obs, act], dim=-1))
        return torch.squeeze(q, -1) # Critical to ensure q has right shape.

class MLPActorCritic(nn.Module):

    def __init__(self, hidden_sizes=(256,256),
                 activation=nn.ReLU, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        super().__init__()

        obs_dim = 35
        act_dim = 6
        act_limit = 1

        # build policy and value functions
        self.pi = MLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit).to(device)
        self.q = MLPQFunction(obs_dim, act_dim, hidden_sizes, activation).to(device)

    def act(self, obs):
        with torch.no_grad():
            return self.pi(obs)

In [8]:
Transition = namedtuple('Transition',
                        ('obs', 'act', 'rew', 'next_obs', 'done'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([],maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [9]:

params = {
    "dropout": 0.2,
    "learning_rate": 0.001,
    "optimizer": "Adam",
    "hid": 256,
    "l": 3,
    "seed": 0,
    "steps_per_epoch": 3000,
    "steps_video": 30000,
    "epochs": 1000,
    "replay_size": int(1e8),
    "gamma": 0.99,
    "polyak": 0.995,
    "pi_lr": 1e-4,
    "q_lr": 1e-4,
    "batch_size": 1000,
    "start_steps": 10000, 
    "update_after": 5000,
    "update_every": 100,
    "act_noise": 0.01,
    "num_test_episodes": 5,
    "max_ep_len": 1000,
    "max_video_len": 1000,
    "save_model_len": 10000,
    # "obs_dim": 47,
    # "act_dim": 7,
    # "act_limit": 1
}

ac_kwargs=dict(hidden_sizes=[params["hid"]]*params["l"])

In [10]:
nep_log["parameters"] = params

torch.manual_seed(params["seed"])
np.random.seed(params["seed"])

obs_dim = 35
print('obs_dim = ', obs_dim)
act_dim = 6
print('act_dim = ', act_dim)
# Action limit for clamping: critically, assumes all dimensions share the same bound!
act_limit = 1
print('act_limit = ', act_limit)
# Create actor-critic module and target networks
ac = MLPActorCritic(**ac_kwargs)
ac_targ = deepcopy(ac)

# Freeze target networks with respect to optimizers (only update via polyak averaging)
for p in ac_targ.parameters():
    p.requires_grad = False

memory = ReplayMemory(params["replay_size"])

obs_dim =  35
act_dim =  6
act_limit =  1


In [11]:
# Set up function for computing DDPG Q-loss
def compute_loss_q(data):

    o = torch.cat(data.obs).float()
    a = torch.cat(data.act).float()
    r = torch.cat(data.rew).float()
    o2 =torch.cat(data.next_obs).float()
    d = torch.cat(data.done).float()

    q = ac.q(o,a)


    # Bellman backup for Q function
    with torch.no_grad():
        q_pi_targ = ac_targ.q(o2, ac_targ.pi(o2))
        backup = r + params["gamma"] * (1 - d) * q_pi_targ

    # MSE loss against Bellman backup
    loss_q = ((q - backup)**2).mean()

    return loss_q

# Set up function for computing DDPG pi loss
def compute_loss_pi(data):

    o = torch.cat(data.obs).float()

    q_pi = ac.q(o, ac.pi(o))

    return -q_pi.mean()


In [12]:
pi_optimizer = RMSprop(ac.pi.parameters(), lr=params["pi_lr"])
q_optimizer = RMSprop(ac.q.parameters(), lr=params["q_lr"])

def update(data):
    # First run one gradient descent step for Q.


    q_optimizer.zero_grad()
    loss_q = compute_loss_q(data)

    loss_q.backward()

    q_optimizer.step()


    # Freeze Q-network so you don't waste computational effort 
    # computing gradients for it during the policy learning step.
    for p in ac.q.parameters():
        p.requires_grad = False

    # Next run one gradient descent step for pi.
    pi_optimizer.zero_grad()
    loss_pi = compute_loss_pi(data)
    loss_pi.backward()
    pi_optimizer.step()

    # Unfreeze Q-network so you can optimize it at next DDPG step.
    for p in ac.q.parameters():
        p.requires_grad = True



    # Finally, update target networks by polyak averaging.
    with torch.no_grad():
        for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
            # NB: We use an in-place operations "mul_", "add_" to update target
            # params, as opposed to "mul" and "add", which would make new tensors.
            p_targ.data.mul_(params["polyak"])
            p_targ.data.add_((1 - params["polyak"]) * p.data)


In [13]:



def get_action(o, noise_scale):
    a = ac.act(torch.as_tensor(o, dtype=torch.float32))
    a += noise_scale * torch.randn(act_dim).to(device)
    return torch.clip(a, -act_limit, act_limit)

def test_agent(epoch):
    test_main = 0
    test_step = 0
    for j in range(params["num_test_episodes"]):
        obs, d, test_ep_ret, test_ep_len = test_env.reset(), False, 0, 0
        o = list(obs['robot0_proprio-state']) + list(obs['object-state'])
        o = torch.tensor([o], dtype=torch.float32, device=device)
        while not(d or (test_ep_len == params["max_ep_len"])):
            a_cpu = get_action(o, 0).cpu().data.numpy()
            obs, r, d, _ = test_env.step(a_cpu[0])
            o = list(obs['robot0_proprio-state']) + list(obs['object-state'])
            o = torch.tensor([o], dtype=torch.float32, device=device)
            test_ep_ret += r
            test_ep_len += 1
        test_ep_main = test_ep_ret/test_ep_len
        test_step +=1
        test_main += test_ep_main
    print('test_rew_main = ', float(test_main/test_step))
    nep_log["test/reward"].log(test_main/test_step)
    
def video_agent(epoch):
    obs, d, test_ep_len = video_env.reset(), False, 0
    o = list(obs['robot0_proprio-state']) + list(obs['object-state'])
    o = torch.tensor([o], dtype=torch.float32, device=device)
    now = datetime.now()
    current_time = str(now.isoformat())
    writer = imageio.get_writer(
        "/home/xhnfly/Cosmic_rays_X/X_Robot/robosuite/robosuite/demos/video/DDPG_UR5_%s_ep_%d.mp4" % (current_time, epoch), fps=100)
    frame = obs["Labviewer_image"]
    writer.append_data(frame)

    while not(d or (test_ep_len == params["max_video_len"])):
        a_cpu = get_action(o, 0).cpu().data.numpy()
        obs, _, d, _ = video_env.step(a_cpu[0])
        o = list(obs['robot0_proprio-state']) + list(obs['object-state'])
        o = torch.tensor([o], dtype=torch.float32, device=device)
        frame = obs["Labviewer_image"]
        writer.append_data(frame)
        test_ep_len += 1
    writer.close()
    nep_log['video'] = neptune.types.File('/home/xhnfly/Cosmic_rays_X/X_Robot/robosuite/robosuite/demos/video/DDPG_UR5_%s_ep_%d.mp4' % (current_time, epoch))





In [14]:
# obs = {
#     'robot0_joint_pos_cos': None,
#     'robot0_joint_pos_sin': None,
#     'robot0_joint_vel': None,
#     'robot0_eef_pos': None,
#     'robot0_eef_quat': None,
#     'robot0_gripper_qpos': None,
#     'robot0_gripper_qvel': None,
#     'cubeA_pos': None,
#     'cubeA_quat': None,
#     'cubeB_pos': None,
#     'cubeB_quat': None,
#     'gripper_to_cubeA': None,
#     'gripper_to_cubeB': None,
#     'cubeA_to_cubeB': None,
# }

obs, ep_ret, ep_len = env.reset(), 0, 0

o = list(obs['robot0_proprio-state']) + list(obs['object-state'])

# env.viewer.set_camera(camera_id=0)


# Define neutral value
neutral = np.zeros(7)

# Keep track of done variable to know when to break loop

# Prepare for interaction with environment
total_steps = params["steps_per_epoch"] * params["epochs"]
start_time = time.time()

o = torch.tensor([o], device=device)


start_time_rec = datetime.now()
r_true = 0
total_main = 0
ep_rew_main = 0
reward_dict={}

In [15]:
# Main loop: collect experience in env and update/log each epoch
low, high = env.action_spec

for t in tqdm(range(total_steps)):
    
    # Until start_steps have elapsed, randomly sample actions
    # from a uniform distribution for better exploration. Afterwards, 
    # use the learned policy (with some noise, via act_noise). 
    if t > params["start_steps"]:
        a = get_action(o, params["act_noise"])      # Tensor
    else:
        a = torch.tensor([np.random.uniform(low, high)], dtype=torch.float32, device=device)
        
    a_cpu = a.cpu().data.numpy()
    # Step the env
    obs2, r, d, _ = env.step(a_cpu[0])
    
    o2 = list(obs2['robot0_proprio-state']) + list(obs2['object-state'])
    # print('len(o2) = ', len(o2))

    ep_len += 1
    total_main += r


    # Ignore the "done" signal if it comes from hitting the time
    # horizon (that is, when it's an artificial terminal signal
    # that isn't based on the agent's state)
    d = False if ep_len==params["max_ep_len"] else d

    o2 = torch.tensor([o2], dtype=torch.float32, device=device)
    r = torch.tensor([r], dtype=torch.float32, device=device)
    d = torch.tensor([d], dtype=torch.float32, device=device)

    # Store experience to replay buffer
    memory.push(o, a, r, o2, d)
    nep_log["train/o"].log(o)
    nep_log["train/a"].log(a)
    nep_log["train/r"].log(r)
    nep_log["train/o2"].log(o2)
    nep_log["train/d"].log(d)

    # Super critical, easy to overlook step: make sure to update 
    # most recent observation!
    o=o2
    ep_ret += r
    
    
    # End of trajectory handling
    if d or (ep_len == params["max_ep_len"]):
        ep_rew = ep_ret/ep_len
        obs, ep_ret, ep_len = env.reset(), 0, 0
        o = list(obs['robot0_proprio-state']) + list(obs['object-state'])
        o = torch.tensor([o], device=device)


    # Update handling
    if t >= params["update_after"] and t % params["update_every"] == 0:
        for i in range(params["update_every"]):

            transitions = memory.sample(params["batch_size"])
            # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
            # detailed explanation). This converts batch-array of Transitions
            # to Transition of batch-arrays.
            batch = Transition(*zip(*transitions))
            update(data=batch)

    # End of epoch handling
    if (t+1) % params["steps_per_epoch"] == 0:
        epoch = (t+1) // params["steps_per_epoch"]
        ep_rew_main = ep_rew

        nep_log["train/reward"].log(ep_rew_main)
        print('ep_rew_main = ', ep_rew_main.cpu().data.numpy())
        # Test the performance of the deterministic version of the agent.
        test_agent(epoch)
        

    if (t+1) % params["steps_video"] == 0:
        epoch = (t+1) // params["steps_per_epoch"]
        now = datetime.now()
        current_time = str(now.isoformat())
        print('current_time = ', current_time)
        video_agent(epoch)
        now = datetime.now()
        current_time = str(now.isoformat())
        print('current_time = ', current_time)

    if (t+1) % params["save_model_len"] == 0:
        epoch = (t+1) // params["steps_per_epoch"]
        now = datetime.now()
        current_time = str(now.isoformat())
        torch.save({
                    'model of ac.q': ac.q.state_dict(),
                    'model of ac.pi': ac.pi.state_dict(),
                    'q_optimizer_state_dict': q_optimizer.state_dict(),
                    'pi_optimizer_state_dict': pi_optimizer.state_dict(),
                    
                    }, "model_nn/model_nn_%s%d.pt" % (current_time, epoch))


        

  1%|          | 2990/300000 [00:38<55:01, 89.97it/s]  

ep_rew_main =  [8.116835e-09]


  1%|          | 3009/300000 [02:19<181:43:59,  2.20s/it]

test_rew_main =  7.978789546945197e-08


  2%|▏         | 5996/300000 [03:09<1:00:54, 80.46it/s]  

ep_rew_main =  [7.6302445e-07]


  2%|▏         | 6000/300000 [05:36<482:42:01,  5.91s/it]

test_rew_main =  7.316223477488425e-06


  3%|▎         | 8994/300000 [06:49<1:03:31, 76.35it/s]  

ep_rew_main =  [2.7725366e-07]


  3%|▎         | 9000/300000 [09:07<417:31:35,  5.17s/it]

test_rew_main =  9.925339488618221e-06


  4%|▍         | 11999/300000 [10:29<1:28:25, 54.28it/s] 

ep_rew_main =  [0.00011014]


  4%|▍         | 12000/300000 [12:43<714:57:14,  8.94s/it]

test_rew_main =  0.0005425442013856856


  5%|▍         | 14996/300000 [14:15<1:40:03, 47.47it/s]  

ep_rew_main =  [0.00045168]


  5%|▌         | 15000/300000 [16:39<709:42:39,  8.96s/it]

test_rew_main =  0.0003990717013324079


  6%|▌         | 17999/300000 [18:06<1:18:28, 59.89it/s]  

ep_rew_main =  [7.512274e-05]


  6%|▌         | 18000/300000 [20:29<651:25:37,  8.32s/it]

test_rew_main =  0.00117182467886475


  7%|▋         | 20995/300000 [21:53<1:19:44, 58.31it/s]  

ep_rew_main =  [0.00104657]


  7%|▋         | 21000/300000 [24:24<617:16:15,  7.96s/it]

test_rew_main =  0.0009437874948477826


  8%|▊         | 23997/300000 [25:52<1:16:24, 60.20it/s]  

ep_rew_main =  [5.987002e-06]


  8%|▊         | 24000/300000 [28:02<514:37:26,  6.71s/it]

test_rew_main =  0.0029330048616512975


  9%|▉         | 26997/300000 [29:27<1:10:53, 64.18it/s]  

ep_rew_main =  [0.00086409]


  9%|▉         | 27000/300000 [31:45<533:16:02,  7.03s/it]

test_rew_main =  0.000601662182953737


 10%|▉         | 29996/300000 [33:10<1:17:49, 57.82it/s]  

ep_rew_main =  [0.00878541]


 10%|▉         | 29996/300000 [33:21<1:17:49, 57.82it/s]

test_rew_main =  0.0023408060605306278
current_time =  2022-02-02T14:59:08.923028


 10%|█         | 30000/300000 [36:46<898:55:31, 11.99s/it]

current_time =  2022-02-02T15:00:34.088989


 11%|█         | 32993/300000 [38:11<1:07:54, 65.53it/s]  

ep_rew_main =  [1.4000696e-05]


 11%|█         | 33000/300000 [40:16<390:30:55,  5.27s/it]

test_rew_main =  0.003738326134177021


 12%|█▏        | 35992/300000 [41:39<57:50, 76.07it/s]    

ep_rew_main =  [2.3337474e-05]


 12%|█▏        | 36000/300000 [43:40<306:08:37,  4.17s/it]

test_rew_main =  0.0025846272305685594


 13%|█▎        | 38999/300000 [45:01<57:42, 75.38it/s]    

ep_rew_main =  [4.1424053e-05]


 13%|█▎        | 39000/300000 [46:59<397:46:22,  5.49s/it]

test_rew_main =  0.0019964958228432727


 14%|█▍        | 41997/300000 [48:18<1:06:38, 64.52it/s]  

ep_rew_main =  [1.8854164e-08]


 14%|█▍        | 42000/300000 [50:16<396:20:42,  5.53s/it]

test_rew_main =  1.7876476904404858e-07


 15%|█▍        | 44995/300000 [51:40<1:10:29, 60.29it/s]  

ep_rew_main =  [6.2044787e-06]


 15%|█▌        | 45000/300000 [53:40<398:42:00,  5.63s/it]

test_rew_main =  4.7225908907821963e-07


 16%|█▌        | 47998/300000 [55:00<1:04:33, 65.05it/s]  

ep_rew_main =  [2.9094595e-06]


 16%|█▌        | 48000/300000 [56:52<383:22:02,  5.48s/it]

test_rew_main =  1.1375700755617003e-06


 17%|█▋        | 50999/300000 [58:15<1:12:32, 57.21it/s]  

ep_rew_main =  [0.00030748]


 17%|█▋        | 51000/300000 [1:00:13<475:26:27,  6.87s/it]

test_rew_main =  3.849427039915591e-05


 18%|█▊        | 53999/300000 [1:01:41<1:06:16, 61.87it/s]  

ep_rew_main =  [5.520126e-08]


 18%|█▊        | 54000/300000 [1:03:42<477:58:26,  6.99s/it]

test_rew_main =  0.0004487004564448143


 19%|█▉        | 56995/300000 [1:05:06<1:14:33, 54.33it/s]  

ep_rew_main =  [0.00045868]


 19%|█▉        | 57000/300000 [1:07:04<420:44:30,  6.23s/it]

test_rew_main =  0.0010776260530689816


 20%|█▉        | 59996/300000 [1:08:23<51:34, 77.55it/s]    

ep_rew_main =  [4.6959626e-06]


 20%|█▉        | 59996/300000 [1:08:43<51:34, 77.55it/s]

test_rew_main =  0.000471975196257476
current_time =  2022-02-02T15:34:06.633220


 20%|██        | 60000/300000 [1:11:41<516:09:28,  7.74s/it]

current_time =  2022-02-02T15:35:29.502459


 21%|██        | 62997/300000 [1:13:05<1:00:57, 64.80it/s]  

ep_rew_main =  [7.7933414e-08]


 21%|██        | 63000/300000 [1:15:08<416:39:32,  6.33s/it]

test_rew_main =  0.0016917401008619674


 22%|██▏       | 65995/300000 [1:16:35<1:08:38, 56.81it/s]  

ep_rew_main =  [8.576234e-07]


 22%|██▏       | 66000/300000 [1:18:36<370:12:01,  5.70s/it]

test_rew_main =  0.0022311079103641013


 23%|██▎       | 68999/300000 [1:20:00<1:21:44, 47.10it/s]  

ep_rew_main =  [1.9282169e-07]


 23%|██▎       | 69000/300000 [1:22:02<553:33:27,  8.63s/it]

test_rew_main =  0.0006571817852336156


 24%|██▍       | 71995/300000 [1:23:31<1:11:52, 52.87it/s]  

ep_rew_main =  [0.00308572]


 24%|██▍       | 72000/300000 [1:25:34<412:58:24,  6.52s/it]

test_rew_main =  6.273269783152832e-05


 25%|██▍       | 74995/300000 [1:26:58<1:03:13, 59.32it/s]  

ep_rew_main =  [1.2526738e-05]


 25%|██▌       | 75000/300000 [1:28:52<346:47:23,  5.55s/it]

test_rew_main =  1.1176983070350975e-05


 26%|██▌       | 77994/300000 [1:30:22<1:01:02, 60.62it/s]  

ep_rew_main =  [4.772797e-09]


 26%|██▌       | 78000/300000 [1:32:31<358:30:44,  5.81s/it]

test_rew_main =  1.9329537585758977e-05


 27%|██▋       | 80996/300000 [1:34:00<1:05:59, 55.31it/s]  

ep_rew_main =  [7.540485e-08]


 27%|██▋       | 81000/300000 [1:35:57<394:22:21,  6.48s/it]

test_rew_main =  9.787157588803967e-05


 28%|██▊       | 83994/300000 [1:37:16<47:35, 75.64it/s]    

ep_rew_main =  [3.9871047e-05]


 28%|██▊       | 84000/300000 [1:39:16<255:37:33,  4.26s/it]

test_rew_main =  3.0752502030287896e-07


 29%|██▉       | 86997/300000 [1:40:40<54:34, 65.05it/s]    

ep_rew_main =  [1.2295071e-07]


 29%|██▉       | 87000/300000 [1:42:30<331:04:37,  5.60s/it]

test_rew_main =  4.7072017651216273e-07


 30%|██▉       | 89999/300000 [1:43:53<48:05, 72.77it/s]    

ep_rew_main =  [1.2412696e-07]


 30%|██▉       | 89999/300000 [1:44:05<48:05, 72.77it/s]

test_rew_main =  1.4632563685909261e-06
current_time =  2022-02-02T16:09:48.391076


 30%|███       | 90000/300000 [1:47:25<629:58:44, 10.80s/it]

current_time =  2022-02-02T16:11:13.285017


 31%|███       | 92994/300000 [1:48:52<56:57, 60.58it/s]    

ep_rew_main =  [4.9751856e-07]


 31%|███       | 93000/300000 [1:51:25<395:06:00,  6.87s/it]

test_rew_main =  4.881741144093621e-06


 32%|███▏      | 95994/300000 [1:52:56<1:05:25, 51.97it/s]  

ep_rew_main =  [6.385345e-10]


 32%|███▏      | 96000/300000 [1:55:22<414:40:04,  7.32s/it]

test_rew_main =  7.010779753359887e-06


 33%|███▎      | 98994/300000 [1:56:56<45:31, 73.60it/s]    

ep_rew_main =  [5.719931e-07]


 33%|███▎      | 99000/300000 [1:59:14<297:14:53,  5.32s/it]

test_rew_main =  6.527648566663125e-05


 34%|███▍      | 101993/300000 [2:00:41<44:45, 73.73it/s]   

ep_rew_main =  [1.5104629e-09]


 34%|███▍      | 102000/300000 [2:02:42<240:28:07,  4.37s/it]

test_rew_main =  0.0011057380588087268


 35%|███▍      | 104999/300000 [2:04:10<45:56, 70.74it/s]    

ep_rew_main =  [1.8536637e-08]


 35%|███▌      | 105000/300000 [2:06:15<344:41:10,  6.36s/it]

test_rew_main =  8.587494602738554e-06


 36%|███▌      | 107994/300000 [2:07:44<57:26, 55.70it/s]    

ep_rew_main =  [2.4892672e-09]


 36%|███▌      | 108000/300000 [2:10:00<361:47:59,  6.78s/it]

test_rew_main =  5.4464606776600914e-05


 37%|███▋      | 110996/300000 [2:11:24<59:09, 53.25it/s]    

ep_rew_main =  [1.359622e-09]


 37%|███▋      | 111000/300000 [2:13:26<356:35:01,  6.79s/it]

test_rew_main =  0.0006246761768369291


 38%|███▊      | 113992/300000 [2:14:51<43:10, 71.80it/s]    

ep_rew_main =  [0.01280551]


 38%|███▊      | 114000/300000 [2:16:55<242:14:57,  4.69s/it]

test_rew_main =  0.000947682576958715


 39%|███▉      | 116996/300000 [2:18:20<44:54, 67.91it/s]    

ep_rew_main =  [9.5005795e-11]


 39%|███▉      | 117000/300000 [2:20:14<255:34:12,  5.03s/it]

test_rew_main =  2.645921638664617e-06


 40%|███▉      | 119996/300000 [2:21:40<53:31, 56.05it/s]    

ep_rew_main =  [1.47239145e-08]


 40%|███▉      | 119996/300000 [2:21:57<53:31, 56.05it/s]

test_rew_main =  0.008483549004963963
current_time =  2022-02-02T16:47:17.298362


 40%|████      | 120000/300000 [2:24:51<514:06:41, 10.28s/it]

current_time =  2022-02-02T16:48:39.506357


 41%|████      | 122995/300000 [2:26:16<39:32, 74.61it/s]    

ep_rew_main =  [2.1914155e-11]


 41%|████      | 123000/300000 [2:28:08<215:24:54,  4.38s/it]

test_rew_main =  3.605080667736349e-07


 42%|████▏     | 125998/300000 [2:29:39<1:02:28, 46.41it/s]  

ep_rew_main =  [0.0001249]


 42%|████▏     | 126000/300000 [2:31:46<451:00:38,  9.33s/it]

test_rew_main =  0.0010720599689472649


 43%|████▎     | 128998/300000 [2:33:22<36:22, 78.35it/s]    

ep_rew_main =  [1.3991919e-06]


 43%|████▎     | 129000/300000 [2:35:39<293:23:41,  6.18s/it]

test_rew_main =  0.00034723959488817366


 44%|████▍     | 131999/300000 [2:37:23<1:04:14, 43.59it/s]  

ep_rew_main =  [8.8547175e-05]


 44%|████▍     | 132000/300000 [2:40:07<605:56:28, 12.98s/it]

test_rew_main =  2.1141223838793176e-05


 45%|████▍     | 134998/300000 [2:41:49<51:12, 53.70it/s]    

ep_rew_main =  [4.8935144e-06]


 45%|████▌     | 135000/300000 [2:44:38<484:07:41, 10.56s/it]

test_rew_main =  4.7903627400441474e-05


 46%|████▌     | 137999/300000 [2:46:13<55:48, 48.38it/s]    

ep_rew_main =  [0.00303296]


 46%|████▌     | 138000/300000 [2:48:25<410:03:45,  9.11s/it]

test_rew_main =  0.0005033417920215597


 47%|████▋     | 140998/300000 [2:50:01<38:37, 68.60it/s]    

ep_rew_main =  [6.5210246e-05]


 47%|████▋     | 141000/300000 [2:52:09<250:29:43,  5.67s/it]

test_rew_main =  0.0009403099013782246


 48%|████▊     | 143995/300000 [2:53:42<39:34, 65.70it/s]    

ep_rew_main =  [1.7691495e-07]


 48%|████▊     | 144000/300000 [2:55:34<218:40:18,  5.05s/it]

test_rew_main =  0.0013079903664531126


 49%|████▉     | 146992/300000 [2:57:04<35:18, 72.21it/s]    

ep_rew_main =  [1.1361033e-06]


 49%|████▉     | 147000/300000 [2:58:56<167:36:10,  3.94s/it]

test_rew_main =  0.0008967263500225456


 50%|████▉     | 149997/300000 [3:00:23<44:12, 56.56it/s]    

ep_rew_main =  [0.02896586]


 50%|████▉     | 149997/300000 [3:00:38<44:12, 56.56it/s]

test_rew_main =  0.00019347290393412293
current_time =  2022-02-02T17:25:56.631644


 50%|█████     | 150000/300000 [3:03:27<396:27:38,  9.52s/it]

current_time =  2022-02-02T17:27:15.407865


 51%|█████     | 152998/300000 [3:04:57<43:37, 56.16it/s]    

ep_rew_main =  [0.0009605]


 51%|█████     | 153000/300000 [3:06:48<248:59:58,  6.10s/it]

test_rew_main =  0.0022382576982879604


 52%|█████▏    | 155993/300000 [3:08:18<38:01, 63.13it/s]    

ep_rew_main =  [3.399361e-10]


 52%|█████▏    | 156000/300000 [3:10:06<181:45:33,  4.54s/it]

test_rew_main =  0.00034261907928324675


 53%|█████▎    | 158997/300000 [3:11:36<33:36, 69.92it/s]    

ep_rew_main =  [4.9603436e-06]


 53%|█████▎    | 159000/300000 [3:13:35<207:47:26,  5.31s/it]

test_rew_main =  0.0002773481824996015


 54%|█████▍    | 161992/300000 [3:15:08<31:38, 72.68it/s]    

ep_rew_main =  [0.00338793]


 54%|█████▍    | 162000/300000 [3:16:59<146:50:41,  3.83s/it]

test_rew_main =  0.0011930634627327487


 55%|█████▍    | 164997/300000 [3:18:34<41:27, 54.27it/s]    

ep_rew_main =  [7.9801384e-07]


 55%|█████▌    | 165000/300000 [3:20:36<270:58:08,  7.23s/it]

test_rew_main =  0.00029287242855847946


 56%|█████▌    | 167993/300000 [3:22:06<30:51, 71.31it/s]    

ep_rew_main =  [5.495324e-06]


 56%|█████▌    | 168000/300000 [3:23:58<150:03:28,  4.09s/it]

test_rew_main =  0.0013413282952553145


 57%|█████▋    | 170999/300000 [3:25:34<33:35, 64.01it/s]    

ep_rew_main =  [0.02416068]


 57%|█████▋    | 171000/300000 [3:27:23<224:57:26,  6.28s/it]

test_rew_main =  0.00010509411333248749


 58%|█████▊    | 173997/300000 [3:28:55<32:44, 64.13it/s]    

ep_rew_main =  [0.00010858]


 58%|█████▊    | 174000/300000 [3:30:52<212:40:31,  6.08s/it]

test_rew_main =  0.0003150801817555122


 59%|█████▉    | 176998/300000 [3:32:28<38:29, 53.25it/s]    

ep_rew_main =  [9.934589e-06]


 59%|█████▉    | 177000/300000 [3:34:17<233:20:25,  6.83s/it]

test_rew_main =  0.001726641223381241


 60%|█████▉    | 179991/300000 [3:35:51<26:41, 74.92it/s]    

ep_rew_main =  [1.1027182e-08]


 60%|█████▉    | 179991/300000 [3:36:03<26:41, 74.92it/s]

test_rew_main =  4.929439182862371e-07
current_time =  2022-02-02T18:01:32.422471


 60%|██████    | 180000/300000 [3:39:08<219:55:38,  6.60s/it]

current_time =  2022-02-02T18:02:56.573935


 61%|██████    | 182999/300000 [3:40:41<24:55, 78.23it/s]    

ep_rew_main =  [1.3706995e-07]


 61%|██████    | 183000/300000 [3:42:23<150:29:57,  4.63s/it]

test_rew_main =  0.00041374308743914864


 62%|██████▏   | 185995/300000 [3:43:55<26:44, 71.05it/s]    

ep_rew_main =  [0.00057795]


 62%|██████▏   | 186000/300000 [3:45:43<134:20:05,  4.24s/it]

test_rew_main =  5.3528990944050523e-05


 63%|██████▎   | 188998/300000 [3:47:17<26:18, 70.32it/s]    

ep_rew_main =  [7.566089e-08]


 63%|██████▎   | 189000/300000 [3:49:12<172:02:53,  5.58s/it]

test_rew_main =  0.0006829995703329212


 64%|██████▍   | 191992/300000 [3:50:47<25:19, 71.09it/s]    

ep_rew_main =  [1.6573313e-05]


 64%|██████▍   | 192000/300000 [3:52:43<131:23:08,  4.38s/it]

test_rew_main =  0.0003885311328589034


 65%|██████▍   | 194996/300000 [3:54:18<30:16, 57.79it/s]    

ep_rew_main =  [0.00046225]


 65%|██████▌   | 195000/300000 [3:56:14<172:55:12,  5.93s/it]

test_rew_main =  0.000285948441381867


 66%|██████▌   | 197999/300000 [3:57:51<26:43, 63.62it/s]    

ep_rew_main =  [3.8480035e-08]


 66%|██████▌   | 198000/300000 [3:59:41<173:10:36,  6.11s/it]

test_rew_main =  1.3603662368753689e-05


 67%|██████▋   | 200993/300000 [4:01:17<24:52, 66.36it/s]    

ep_rew_main =  [2.6157126e-08]


 67%|██████▋   | 201000/300000 [4:03:18<135:09:18,  4.91s/it]

test_rew_main =  0.0006011870442138476


 68%|██████▊   | 203999/300000 [4:04:56<24:22, 65.66it/s]    

ep_rew_main =  [0.01687321]


 68%|██████▊   | 204000/300000 [4:06:54<182:55:03,  6.86s/it]

test_rew_main =  0.0004620271644229519


 69%|██████▉   | 206993/300000 [4:08:30<23:11, 66.83it/s]    

ep_rew_main =  [1.9296094e-08]


 69%|██████▉   | 207000/300000 [4:10:31<121:33:02,  4.71s/it]

test_rew_main =  2.3368709944834754e-06


 70%|██████▉   | 209999/300000 [4:12:08<23:50, 62.92it/s]    

ep_rew_main =  [1.8620902e-07]


 70%|██████▉   | 209999/300000 [4:12:26<23:50, 62.92it/s]

test_rew_main =  0.00033867791885156256
current_time =  2022-02-02T18:37:59.016585


 70%|███████   | 210000/300000 [4:15:26<251:56:28, 10.08s/it]

current_time =  2022-02-02T18:39:13.844316


 71%|███████   | 212994/300000 [4:17:01<20:32, 70.57it/s]    

ep_rew_main =  [1.07448095e-05]


 71%|███████   | 213000/300000 [4:18:57<113:57:27,  4.72s/it]

test_rew_main =  0.012309189029353247


 72%|███████▏  | 215998/300000 [4:20:33<19:50, 70.59it/s]    

ep_rew_main =  [0.00022998]


 72%|███████▏  | 216000/300000 [4:22:17<107:00:57,  4.59s/it]

test_rew_main =  0.004815194143161475


 73%|███████▎  | 218994/300000 [4:23:56<18:39, 72.33it/s]    

ep_rew_main =  [0.00515173]


 73%|███████▎  | 219000/300000 [4:25:38<85:12:19,  3.79s/it]

test_rew_main =  0.005874069586473843


 74%|███████▍  | 221991/300000 [4:27:13<18:10, 71.57it/s]   

ep_rew_main =  [1.318877e-06]


 74%|███████▍  | 222000/300000 [4:29:01<78:43:34,  3.63s/it]

test_rew_main =  0.0007533234236343498


 75%|███████▍  | 224997/300000 [4:30:38<18:51, 66.29it/s]   

ep_rew_main =  [0.00618156]


 75%|███████▌  | 225000/300000 [4:32:34<112:43:07,  5.41s/it]

test_rew_main =  0.005685077845413703


 76%|███████▌  | 227991/300000 [4:34:11<17:06, 70.17it/s]    

ep_rew_main =  [0.00202044]


 76%|███████▌  | 228000/300000 [4:36:04<75:29:36,  3.77s/it]

test_rew_main =  0.0014425072222599522


 77%|███████▋  | 230999/300000 [4:37:43<18:46, 61.26it/s]   

ep_rew_main =  [0.01742379]


 77%|███████▋  | 231000/300000 [4:39:32<109:56:35,  5.74s/it]

test_rew_main =  0.021162904476819434


 78%|███████▊  | 233998/300000 [4:41:10<15:37, 70.43it/s]    

ep_rew_main =  [0.00187111]


 78%|███████▊  | 234000/300000 [4:43:01<98:25:58,  5.37s/it]

test_rew_main =  0.00725342371529811


 79%|███████▉  | 236994/300000 [4:44:43<19:46, 53.11it/s]   

ep_rew_main =  [0.00058563]


 79%|███████▉  | 237000/300000 [4:46:33<96:25:24,  5.51s/it]

test_rew_main =  0.0030188235998980266


 80%|███████▉  | 239998/300000 [4:48:24<19:43, 50.69it/s]   

ep_rew_main =  [2.4894283e-07]


 80%|███████▉  | 239998/300000 [4:48:39<19:43, 50.69it/s]

test_rew_main =  0.0027771202352339557
current_time =  2022-02-02T19:14:14.004634


 80%|████████  | 240000/300000 [4:51:48<213:01:30, 12.78s/it]

current_time =  2022-02-02T19:15:36.682750


 81%|████████  | 242999/300000 [4:53:30<14:40, 64.71it/s]    

ep_rew_main =  [0.01504344]


 81%|████████  | 243000/300000 [4:56:15<133:33:09,  8.43s/it]

test_rew_main =  0.0029356768916973626


 82%|████████▏ | 245999/300000 [4:58:04<20:49, 43.23it/s]    

ep_rew_main =  [6.982467e-07]


 82%|████████▏ | 246000/300000 [5:00:33<176:37:11, 11.77s/it]

test_rew_main =  0.0002369745140748383


 83%|████████▎ | 248993/300000 [5:02:26<14:42, 57.83it/s]    

ep_rew_main =  [0.01842669]


 83%|████████▎ | 249000/300000 [5:04:48<86:34:53,  6.11s/it]

test_rew_main =  0.011924124926156004


 84%|████████▍ | 251995/300000 [5:06:42<12:15, 65.25it/s]   

ep_rew_main =  [0.04129611]


 84%|████████▍ | 252000/300000 [5:08:47<71:12:32,  5.34s/it]

test_rew_main =  0.000825411120836584


 85%|████████▍ | 254994/300000 [5:10:32<10:43, 69.98it/s]   

ep_rew_main =  [2.141562e-09]


 85%|████████▌ | 255000/300000 [5:12:13<50:38:14,  4.05s/it]

test_rew_main =  4.823832078304502e-09


 86%|████████▌ | 257996/300000 [5:13:53<09:24, 74.37it/s]   

ep_rew_main =  [9.368295e-10]


 86%|████████▌ | 258000/300000 [5:15:30<45:30:36,  3.90s/it]

test_rew_main =  1.9375096707057615e-05


 87%|████████▋ | 260996/300000 [5:17:08<08:43, 74.52it/s]   

ep_rew_main =  [2.9516248e-10]


 87%|████████▋ | 261000/300000 [5:18:43<41:20:21,  3.82s/it]

test_rew_main =  5.057292778587464e-09


 88%|████████▊ | 263998/300000 [5:20:23<13:15, 45.26it/s]   

ep_rew_main =  [2.750187e-08]


 88%|████████▊ | 264000/300000 [5:22:02<43:08:10,  4.31s/it]

test_rew_main =  5.822935479217962e-05


 89%|████████▉ | 266998/300000 [5:23:44<08:22, 65.62it/s]   

ep_rew_main =  [2.532644e-05]


 89%|████████▉ | 267000/300000 [5:25:16<41:15:52,  4.50s/it]

test_rew_main =  5.515834524671969e-09


 90%|████████▉ | 269991/300000 [5:27:03<07:07, 70.19it/s]   

ep_rew_main =  [1.4897468e-09]


 90%|████████▉ | 269991/300000 [5:27:21<07:07, 70.19it/s]

test_rew_main =  0.0005830196312047751
current_time =  2022-02-02T19:52:43.431666


 90%|█████████ | 270000/300000 [5:30:17<53:51:39,  6.46s/it]

current_time =  2022-02-02T19:54:05.014590


 91%|█████████ | 272993/300000 [5:32:07<06:31, 68.95it/s]   

ep_rew_main =  [7.0386846e-10]


 91%|█████████ | 273000/300000 [5:33:58<32:37:44,  4.35s/it]

test_rew_main =  7.531446111771471e-05


 92%|█████████▏| 275998/300000 [5:35:50<06:12, 64.49it/s]   

ep_rew_main =  [6.857297e-05]


 92%|█████████▏| 276000/300000 [5:37:46<38:00:54,  5.70s/it]

test_rew_main =  0.0161649871882733


 93%|█████████▎| 278996/300000 [5:39:38<05:20, 65.60it/s]   

ep_rew_main =  [0.00092527]


 93%|█████████▎| 279000/300000 [5:41:52<34:17:53,  5.88s/it]

test_rew_main =  0.00019782101688900047


 94%|█████████▍| 281994/300000 [5:43:47<04:39, 64.38it/s]   

ep_rew_main =  [2.0410836e-05]


 94%|█████████▍| 282000/300000 [5:46:05<28:05:14,  5.62s/it]

test_rew_main =  0.0009070309697766928


 95%|█████████▍| 284993/300000 [5:47:59<03:42, 67.55it/s]   

ep_rew_main =  [0.00511883]


 95%|█████████▌| 285000/300000 [5:50:01<19:51:44,  4.77s/it]

test_rew_main =  0.0006900186668521932


 96%|█████████▌| 287992/300000 [5:51:53<02:56, 67.97it/s]   

ep_rew_main =  [0.00690101]


 96%|█████████▌| 288000/300000 [5:54:14<17:37:38,  5.29s/it]

test_rew_main =  0.00031915707391694655


 97%|█████████▋| 290996/300000 [5:56:12<02:17, 65.55it/s]   

ep_rew_main =  [0.00498093]


 97%|█████████▋| 291000/300000 [5:58:34<15:36:57,  6.25s/it]

test_rew_main =  0.0026710750259626498


 98%|█████████▊| 293993/300000 [6:00:31<01:31, 65.48it/s]   

ep_rew_main =  [1.507676e-06]


 98%|█████████▊| 294000/300000 [6:02:38<8:17:17,  4.97s/it]

test_rew_main =  0.044263550184293364


 99%|█████████▉| 296997/300000 [6:04:35<00:59, 50.49it/s]  

ep_rew_main =  [0.35137716]


 99%|█████████▉| 297000/300000 [6:06:51<6:32:01,  7.84s/it]

test_rew_main =  0.02397870305703833


100%|█████████▉| 299995/300000 [6:08:45<00:00, 59.69it/s]  

ep_rew_main =  [0.19846305]


100%|█████████▉| 299995/300000 [6:08:57<00:00, 59.69it/s]

test_rew_main =  0.005614001077933981
current_time =  2022-02-02T20:34:33.882469


100%|██████████| 300000/300000 [6:11:58<00:00, 13.44it/s]

current_time =  2022-02-02T20:35:46.734247





In [16]:
model = ac.q
print("Model_q's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

model = ac.pi
print("Model_pi's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model_q's state_dict:
q.0.weight 	 torch.Size([256, 41])
q.0.bias 	 torch.Size([256])
q.2.weight 	 torch.Size([256, 256])
q.2.bias 	 torch.Size([256])
q.4.weight 	 torch.Size([256, 256])
q.4.bias 	 torch.Size([256])
q.6.weight 	 torch.Size([1, 256])
q.6.bias 	 torch.Size([1])
Model_pi's state_dict:
pi.0.weight 	 torch.Size([256, 35])
pi.0.bias 	 torch.Size([256])
pi.2.weight 	 torch.Size([256, 256])
pi.2.bias 	 torch.Size([256])
pi.4.weight 	 torch.Size([256, 256])
pi.4.bias 	 torch.Size([256])
pi.6.weight 	 torch.Size([6, 256])
pi.6.bias 	 torch.Size([6])


In [17]:
print("pi_optimizer's state_dict:")
for var_name in pi_optimizer.state_dict():
    print(var_name, "\t", pi_optimizer.state_dict()[var_name])

print("q_optimizer's state_dict:")
for var_name in q_optimizer.state_dict():
    print(var_name, "\t", q_optimizer.state_dict()[var_name])



pi_optimizer's state_dict:
state 	 {0: {'step': 295000, 'square_avg': tensor([[1.5294e-07, 4.3574e-07, 1.7078e-07,  ..., 4.6295e-08, 1.3153e-07,
         3.2821e-08],
        [1.1260e-06, 1.0042e-06, 9.5667e-07,  ..., 3.2371e-07, 2.5242e-06,
         2.0768e-07],
        [3.3766e-07, 1.8660e-07, 2.4575e-07,  ..., 6.6838e-08, 3.6149e-07,
         3.6968e-08],
        ...,
        [3.4772e-07, 6.0496e-07, 3.3804e-07,  ..., 8.6878e-08, 3.6086e-07,
         6.7247e-08],
        [1.1408e-06, 7.1713e-07, 6.1176e-07,  ..., 3.4519e-07, 1.2215e-06,
         1.1788e-07],
        [2.8419e-07, 7.1120e-07, 3.0715e-07,  ..., 6.3061e-08, 4.3581e-07,
         7.5016e-08]], device='cuda:0')}, 1: {'step': 295000, 'square_avg': tensor([5.5312e-07, 3.3224e-06, 6.3727e-07, 5.1333e-07, 8.2843e-07, 6.8615e-07,
        8.8244e-07, 1.9534e-06, 6.2446e-07, 1.3387e-06, 4.0938e-07, 1.2642e-06,
        6.8603e-07, 9.9440e-07, 1.2369e-06, 7.1836e-07, 1.1587e-06, 1.1835e-06,
        9.5166e-07, 5.0816e-07, 1.7202e-0

In [18]:
now = datetime.now()

current_time = str(now.isoformat())



torch.save({
            'model of ac.q': ac.q.state_dict(),
            'model of ac.pi': ac.pi.state_dict(),
            'q_optimizer_state_dict': q_optimizer.state_dict(),
            'pi_optimizer_state_dict': pi_optimizer.state_dict(),
            
            }, "model_nn/model_nn_%s.pt" % current_time)



In [19]:
nep_log.stop()

Shutting down background jobs, please wait a moment...
Done!


Waiting for the remaining 140 operations to synchronize with Neptune. Do not kill this process.


All 140 operations synced, thanks for waiting!
