In [1]:
#!cat ~/.bashrc

In [2]:
#!mv ../drone_dataset.pkl .

In [3]:
#!pip3 install --upgrade protobuf==3.20.0 

In [4]:
#!pip3 install transformers==4.5.1
#!pip3 install -U tokenizers
# The command below resolves several dependency problems.
#!pip3 uninstall tokenizers -y

In [5]:
from torch.utils.tensorboard import SummaryWriter
import argparse
import pickle
import random
import time
import gym
import d4rl
import torch
import numpy as np

import utils
from replay_buffer import ReplayBuffer
from lamb import Lamb
from stable_baselines3.common.vec_env import SubprocVecEnv
from pathlib import Path
from data import create_dataloader
from decision_transformer.models.decision_transformer import DecisionTransformer
from evaluation import create_vec_eval_episodes_fn, vec_evaluate_episode_rtg
from trainer import SequenceTrainer
from logger import Logger

from env import make_pytorch_env

# Episode-length cap: handed to DecisionTransformer as ``max_ep_len`` and to
# ``vec_evaluate_episode_rtg`` when collecting online rollouts.
# NOTE(review): data.py defines a similar variable -- confirm the two agree.
MAX_EPISODE_LEN = 4000 # Warning: there is a similar variable in data.py! 

pybullet build time: May 20 2022 19:44:17


In [6]:
import sys

# Inside a notebook the kernel's own command line must not reach argparse;
# blanking sys.argv makes parse_args() below fall back to the defaults.
sys.argv = ['']


def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. This converter accepts the
    usual spellings instead and rejects anything else.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"0"``, ``"yes"``, ``"no"``.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=10)
parser.add_argument("--env", type=str, default="drone_dataset")
#parser.add_argument("--env", type=str, default="antmaze-large-diverse-v2")

# model options
parser.add_argument("--K", type=int, default=20)
parser.add_argument("--embed_dim", type=int, default=512)
parser.add_argument("--n_layer", type=int, default=4)
parser.add_argument("--n_head", type=int, default=4)
parser.add_argument("--activation_function", type=str, default="relu")
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--eval_context_length", type=int, default=5)
# ordering: 0 = no positional embedding, any other value = absolute ordering
parser.add_argument("--ordering", type=int, default=0)

# shared evaluation options
parser.add_argument("--eval_rtg", type=int, default=3600)
parser.add_argument("--num_eval_episodes", type=int, default=10)

# shared training options
parser.add_argument("--init_temperature", type=float, default=0.1)
#parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--learning_rate", "-lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", "-wd", type=float, default=5e-4)
parser.add_argument("--warmup_steps", type=int, default=10000)

# pretraining options
parser.add_argument("--max_pretrain_iters", type=int, default=1)
parser.add_argument("--num_updates_per_pretrain_iter", type=int, default=5000)

# finetuning options
parser.add_argument("--max_online_iters", type=int, default=1500)
parser.add_argument("--online_rtg", type=int, default=7200)
parser.add_argument("--num_online_rollouts", type=int, default=1)
parser.add_argument("--replay_size", type=int, default=1000)
parser.add_argument("--num_updates_per_online_iter", type=int, default=300)
parser.add_argument("--eval_interval", type=int, default=10)

# environment options
parser.add_argument("--device", type=str, default="cuda")
# _str2bool (not bool) so that "--log_to_tb False" really disables TensorBoard.
parser.add_argument("--log_to_tb", "-w", type=_str2bool, default=True)
parser.add_argument("--save_dir", type=str, default="./exp")
parser.add_argument("--exp_name", type=str, default="default")

args = parser.parse_args()

In [7]:
class Experiment:
    def __init__(self, variant):
        """Build everything one run needs: data, model, optimizers, logger.

        Args:
            variant: dict of hyper-parameters (the notebook passes
                ``vars(args)``). Keys read here: ``env``, ``replay_size``,
                ``device`` (optional, defaults to ``"cuda"``), ``K``,
                ``eval_context_length``, ``embed_dim``, ``n_layer``,
                ``n_head``, ``activation_function``, ``dropout``,
                ``ordering``, ``init_temperature``, ``learning_rate``,
                ``weight_decay`` and ``warmup_steps``. The whole dict is
                also stored on ``self.variant`` and given to ``Logger``.
        """
        self.state_dim, self.act_dim, self.action_range = self._get_env_spec(variant)
        self.offline_trajs, self.state_mean, self.state_std = self._load_dataset(
            variant["env"]
        )
        # initialize by offline trajs
        self.replay_buffer = ReplayBuffer(variant["replay_size"], self.offline_trajs)

        # every trajectory collected online is also appended here
        self.aug_trajs = []

        self.device = variant.get("device", "cuda")
        self.target_entropy = -self.act_dim
        self.model = DecisionTransformer(
            state_dim=self.state_dim,
            act_dim=self.act_dim,
            action_range=self.action_range,
            max_length=variant["K"],
            eval_context_length=variant["eval_context_length"],
            max_ep_len=MAX_EPISODE_LEN,
            hidden_size=variant["embed_dim"],
            n_layer=variant["n_layer"],
            n_head=variant["n_head"],
            n_inner=4 * variant["embed_dim"],
            activation_function=variant["activation_function"],
            n_positions=1024,
            resid_pdrop=variant["dropout"],
            attn_pdrop=variant["dropout"],
            stochastic_policy=True,
            ordering=variant["ordering"],
            init_temperature=variant["init_temperature"],
            target_entropy=self.target_entropy,
        ).to(device=self.device)

        self.optimizer = Lamb(
            self.model.parameters(),
            lr=variant["learning_rate"],
            weight_decay=variant["weight_decay"],
            eps=1e-8,
        )
        # Linear warm-up to the base learning rate. Clamping to >= 1 keeps
        # ``warmup_steps=0`` (no warm-up) from dividing by zero.
        warmup_steps = max(variant["warmup_steps"], 1)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)
        )

        # The entropy temperature is tuned by its own optimizer.
        self.log_temperature_optimizer = torch.optim.Adam(
            [self.model.log_temperature],
            lr=1e-4,
            betas=[0.9, 0.999],
        )

        # track the training progress and
        # training/evaluation/online performance in all the iterations
        self.pretrain_iter = 0
        self.online_iter = 0
        self.total_transitions_sampled = 0
        # ``__call__`` resets this when a run starts; setting it here lets
        # ``pretrain``/``online_tuning`` be called directly without an
        # AttributeError on ``self.start_time``.
        self.start_time = time.time()
        self.variant = variant
        # Rewards are used unscaled for antmaze environments and multiplied
        # by 0.001 for everything else.
        self.reward_scale = 1.0 if "antmaze" in variant["env"] else 0.001
        self.logger = Logger(variant)

    def _get_env_spec(self, variant):
        """Probe a throw-away environment for its state/action dimensions.

        Args:
            variant: experiment hyper-parameters. Currently unused: the
                environment is built from the notebook-level ``args``.

        Returns:
            ``(state_dim, act_dim, action_range)`` where ``state_dim`` and
            ``act_dim`` are the first dimension of the observation and
            action spaces, and ``action_range`` is ``[low, high]`` taken
            from the action-space bounds and shrunk by ``1e-6`` on each side.
        """
        # NOTE(review): this reads the module-level ``args`` instead of
        # ``variant``; the two only agree when the experiment was created
        # with ``vars(args)``.
        env = make_pytorch_env(args)
        try:
            state_dim = env.observation_space.shape[0]
            act_dim = env.action_space.shape[0]

            # Stay strictly inside the bounds -- presumably so actions remain
            # valid under the squashed policy distribution; confirm.
            action_range = [
                float(env.action_space.low.min()) + 1e-6,
                float(env.action_space.high.max()) - 1e-6,
            ]

            print("action_range: {}".format(action_range))
        finally:
            # Always release the probe environment, even if reading its
            # spaces fails.
            env.close()
        return state_dim, act_dim, action_range

    def _save_model(self, path_prefix, is_pretrain_model=False):
        to_save = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "pretrain_iter": self.pretrain_iter,
            "online_iter": self.online_iter,
            "args": self.variant,
            "total_transitions_sampled": self.total_transitions_sampled,
            "np": np.random.get_state(),
            "python": random.getstate(),
            "pytorch": torch.get_rng_state(),
            "log_temperature_optimizer_state_dict": self.log_temperature_optimizer.state_dict(),
        }

        with open(f"{path_prefix}/model.pt", "wb") as f:
            torch.save(to_save, f)
        print(f"\nModel saved at {path_prefix}/model.pt")

        if is_pretrain_model:
            with open(f"{path_prefix}/pretrain_model.pt", "wb") as f:
                torch.save(to_save, f)
            print(f"Model saved at {path_prefix}/pretrain_model.pt")

    def _load_model(self, path_prefix):
        if Path(f"{path_prefix}/model.pt").exists():
            with open(f"{path_prefix}/model.pt", "rb") as f:
                checkpoint = torch.load(f)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            self.log_temperature_optimizer.load_state_dict(
                checkpoint["log_temperature_optimizer_state_dict"]
            )
            self.pretrain_iter = checkpoint["pretrain_iter"]
            self.online_iter = checkpoint["online_iter"]
            self.total_transitions_sampled = checkpoint["total_transitions_sampled"]
            np.random.set_state(checkpoint["np"])
            random.setstate(checkpoint["python"])
            torch.set_rng_state(checkpoint["pytorch"])
            print(f"Model loaded at {path_prefix}/model.pt")

    def _load_dataset(self, env_name):

        dataset_path = f"./data/{env_name}.pkl"
        with open(dataset_path, "rb") as f:
            trajectories = pickle.load(f)

        states, traj_lens, returns = [], [], []
        for path in trajectories:
            states.append(path["observations"])
            traj_lens.append(len(path["observations"]))
            returns.append(path["rewards"].sum())
        traj_lens, returns = np.array(traj_lens), np.array(returns)

        # used for input normalization
        states = np.concatenate(states, axis=0)
        state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
        num_timesteps = sum(traj_lens)

        print("=" * 50)
        print(f"Starting new experiment: {env_name}")
        print(f"{len(traj_lens)} trajectories, {num_timesteps} timesteps found")
        print(f"Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}")
        print(f"Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}")
        print(f"Average length: {np.mean(traj_lens):.2f}, std: {np.std(traj_lens):.2f}")
        print(f"Max length: {np.max(traj_lens):.2f}, min: {np.min(traj_lens):.2f}")
        print("=" * 50)

        sorted_inds = np.argsort(returns)  # lowest to highest
        num_trajectories = 1
        timesteps = traj_lens[sorted_inds[-1]]
        ind = len(trajectories) - 2
        while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] < num_timesteps:
            timesteps += traj_lens[sorted_inds[ind]]
            num_trajectories += 1
            ind -= 1
        sorted_inds = sorted_inds[-num_trajectories:]
        trajectories = [trajectories[ii] for ii in sorted_inds]

        return trajectories, state_mean, state_std

    def _augment_trajectories(
        self,
        online_envs,
        target_explore,
        n,
        randomized=False,
    ):
        """Roll out the current policy online and store the new trajectories.

        One episode is run per environment in ``online_envs`` with the
        stochastic policy (``use_mean=False``), conditioned on
        ``target_explore * self.reward_scale``. The resulting trajectories
        go into the replay buffer and ``self.aug_trajs``, and their lengths
        are added to ``self.total_transitions_sampled``.

        Args:
            online_envs: vectorised environment exposing ``num_envs``.
            target_explore: unscaled return-to-go used for exploration.
            n: not used by this method.
            randomized: not used by this method.

        Returns:
            dict with the mean return (``"aug_traj/return"``) and mean
            length (``"aug_traj/length"``) of the collected episodes.
        """
        # generate init state: the same scaled target for every environment
        scaled_target = target_explore * self.reward_scale
        target_return = [scaled_target for _ in range(online_envs.num_envs)]

        with torch.no_grad():
            rollout = vec_evaluate_episode_rtg(
                online_envs,
                self.state_dim,
                self.act_dim,
                self.model,
                max_ep_len=MAX_EPISODE_LEN,
                reward_scale=self.reward_scale,
                target_return=target_return,
                mode="normal",
                state_mean=self.state_mean,
                state_std=self.state_std,
                device=self.device,
                use_mean=False,
            )
        returns, lengths, trajs = rollout

        self.replay_buffer.add_new_trajs(trajs)
        self.aug_trajs += trajs
        self.total_transitions_sampled += np.sum(lengths)

        summary = {}
        summary["aug_traj/return"] = np.mean(returns)
        summary["aug_traj/length"] = np.mean(lengths)
        return summary

    def pretrain(self, eval_envs, loss_fn):
        """Run the offline pretraining phase.

        Repeats until ``self.pretrain_iter`` reaches
        ``variant["max_pretrain_iters"]``: build a dataloader over the
        offline trajectories, train for one iteration, evaluate, log the
        merged metrics, and save a checkpoint (also as the pretrain model).

        Args:
            eval_envs: vectorised environment used for evaluation.
            loss_fn: callable handed to ``SequenceTrainer.train_iteration``.
        """
        print("\n\n\n*** Pretrain ***")
        print("----------------")
        print("eval_envs: {}".format(eval_envs))
        print("loss_fn: {}".format(loss_fn))

        eval_fns = [
            create_vec_eval_episodes_fn(
                vec_env=eval_envs,
                eval_rtg=self.variant["eval_rtg"],
                state_dim=self.state_dim,
                act_dim=self.act_dim,
                state_mean=self.state_mean,
                state_std=self.state_std,
                device=self.device,
                use_mean=True,
                reward_scale=self.reward_scale,
            )
        ]

        trainer = SequenceTrainer(
            model=self.model,
            optimizer=self.optimizer,
            log_temperature_optimizer=self.log_temperature_optimizer,
            scheduler=self.scheduler,
            device=self.device,
        )

        writer = (
            SummaryWriter(self.logger.log_path) if self.variant["log_to_tb"] else None
        )
        try:
            while self.pretrain_iter < self.variant["max_pretrain_iters"]:
                # in every iteration, prepare the data loader
                dataloader = create_dataloader(
                    trajectories=self.offline_trajs,
                    num_iters=self.variant["num_updates_per_pretrain_iter"],
                    batch_size=self.variant["batch_size"],
                    max_len=self.variant["K"],
                    state_dim=self.state_dim,
                    act_dim=self.act_dim,
                    state_mean=self.state_mean,
                    state_std=self.state_std,
                    reward_scale=self.reward_scale,
                    action_range=self.action_range,
                )

                train_outputs = trainer.train_iteration(
                    loss_fn=loss_fn,
                    dataloader=dataloader,
                )
                # only the metrics dict is needed; the scalar reward is unused
                eval_outputs, _ = self.evaluate(eval_fns)
                outputs = {"time/total": time.time() - self.start_time}
                outputs.update(train_outputs)
                outputs.update(eval_outputs)
                self.logger.log_metrics(
                    outputs,
                    iter_num=self.pretrain_iter,
                    total_transitions_sampled=self.total_transitions_sampled,
                    writer=writer,
                )

                self._save_model(
                    path_prefix=self.logger.log_path,
                    is_pretrain_model=True,
                )

                self.pretrain_iter += 1
        finally:
            # Flush and release the TensorBoard event file even if an
            # iteration raises; previously the writer was never closed.
            if writer is not None:
                writer.close()

    def evaluate(self, eval_fns):
        eval_start = time.time()
        self.model.eval()
        outputs = {}
        for eval_fn in eval_fns:
            o = eval_fn(self.model)
            outputs.update(o)
        outputs["time/evaluation"] = time.time() - eval_start

        eval_reward = outputs["evaluation/return_mean_gm"]
        return outputs, eval_reward

    def online_tuning(self, online_envs, eval_envs, loss_fn):
        """Run the online finetuning phase.

        Repeats until ``self.online_iter`` reaches
        ``variant["max_online_iters"]``: collect new rollouts into the
        replay buffer, train for one iteration on the buffer, evaluate every
        ``variant["eval_interval"]`` iterations (and on the last one), log
        the metrics, and save a checkpoint.

        Args:
            online_envs: vectorised environment used to collect rollouts.
            eval_envs: vectorised environment used for evaluation.
            loss_fn: callable handed to ``SequenceTrainer.train_iteration``.
        """
        print("\n\n\n*** Online Finetuning ***")

        trainer = SequenceTrainer(
            model=self.model,
            optimizer=self.optimizer,
            log_temperature_optimizer=self.log_temperature_optimizer,
            scheduler=self.scheduler,
            device=self.device,
        )
        eval_fns = [
            create_vec_eval_episodes_fn(
                vec_env=eval_envs,
                eval_rtg=self.variant["eval_rtg"],
                state_dim=self.state_dim,
                act_dim=self.act_dim,
                state_mean=self.state_mean,
                state_std=self.state_std,
                device=self.device,
                use_mean=True,
                reward_scale=self.reward_scale,
            )
        ]
        writer = (
            SummaryWriter(self.logger.log_path) if self.variant["log_to_tb"] else None
        )
        try:
            while self.online_iter < self.variant["max_online_iters"]:

                outputs = {}
                augment_outputs = self._augment_trajectories(
                    online_envs,
                    self.variant["online_rtg"],
                    n=self.variant["num_online_rollouts"],
                )
                outputs.update(augment_outputs)

                dataloader = create_dataloader(
                    trajectories=self.replay_buffer.trajectories,
                    num_iters=self.variant["num_updates_per_online_iter"],
                    batch_size=self.variant["batch_size"],
                    max_len=self.variant["K"],
                    state_dim=self.state_dim,
                    act_dim=self.act_dim,
                    state_mean=self.state_mean,
                    state_std=self.state_std,
                    reward_scale=self.reward_scale,
                    action_range=self.action_range,
                )

                # finetuning: evaluate on every eval_interval-th iteration
                # and always on the final one
                is_last_iter = self.online_iter == self.variant["max_online_iters"] - 1
                evaluation = (
                    (self.online_iter + 1) % self.variant["eval_interval"] == 0
                    or is_last_iter
                )

                train_outputs = trainer.train_iteration(
                    loss_fn=loss_fn,
                    dataloader=dataloader,
                )
                outputs.update(train_outputs)

                if evaluation:
                    # only the metrics dict is needed here
                    eval_outputs, _ = self.evaluate(eval_fns)
                    outputs.update(eval_outputs)

                outputs["time/total"] = time.time() - self.start_time

                # log the metrics; the iteration axis continues after pretraining
                self.logger.log_metrics(
                    outputs,
                    iter_num=self.pretrain_iter + self.online_iter,
                    total_transitions_sampled=self.total_transitions_sampled,
                    writer=writer,
                )

                self._save_model(
                    path_prefix=self.logger.log_path,
                    is_pretrain_model=False,
                )

                self.online_iter += 1
        finally:
            # Flush and release the TensorBoard event file even if an
            # iteration raises; previously the writer was never closed.
            if writer is not None:
                writer.close()

    def __call__(self):
        """Run the experiment: optional pretraining, then optional finetuning.

        Seeds the RNGs, builds the evaluation environments, runs
        ``pretrain`` when ``variant["max_pretrain_iters"]`` is non-zero and
        ``online_tuning`` when ``variant["max_online_iters"]`` is non-zero.
        All vectorised environments are closed on exit, including when
        training raises.
        """
        # Seed from this experiment's own configuration rather than the
        # notebook-level ``args`` so the run matches the variant it was
        # constructed with.
        utils.set_seed_everywhere(self.variant["seed"])

        def loss_fn(
            a_hat_dist,     # action_preds
            a,              # action_target
            attention_mask, # padding_mask
            entropy_reg,    # self.model.temperature().detach()
        ):
            """Entropy-regularised negative log-likelihood of the actions.

            Returns ``(loss, nll, entropy)`` where
            ``loss = -(log_likelihood + entropy_reg * entropy)``.
            """
            # a_hat is a SquashedNormal Distribution; only positions with
            # attention_mask > 0 contribute to the likelihood term
            log_likelihood = a_hat_dist.log_likelihood(a)[attention_mask > 0].mean()

            entropy = a_hat_dist.entropy().mean()
            loss = -(log_likelihood + entropy_reg * entropy)

            return (
                loss,
                -log_likelihood,
                entropy,
            )

        def get_env_builder(seed, env_name, target_goal=None):
            """Return a zero-argument factory that builds one seeded env.

            ``env_name`` is kept for interface compatibility; the environment
            itself comes from ``make_pytorch_env(args)``.
            """
            def make_env_fn():
                # Imported inside the factory, which SubprocVecEnv runs in
                # a worker process -- presumably so d4rl is registered
                # there as well; confirm before removing.
                import d4rl

                # NOTE(review): built from the module-level ``args``.
                env = make_pytorch_env(args)
                env.seed(seed)

                if target_goal:
                    env.set_target_goal(target_goal)
                    print(f"Set the target goal to be {env.target_goal}")
                return env

            return make_env_fn

        print("\n\nMaking Eval Env.....")
        env_name = self.variant["env"]
        if "antmaze" in env_name:
            # antmaze: read the goal once from a temporary env so that every
            # evaluation/online env shares the same fixed target
            env = gym.make(env_name)
            target_goal = env.target_goal
            env.close()
            print(f"Generated the fixed target goal: {target_goal}")
        else:
            target_goal = None
        eval_envs = SubprocVecEnv(
            [
                get_env_builder(i, env_name=env_name, target_goal=target_goal)
                for i in range(self.variant["num_eval_episodes"])
            ]
        )

        try:
            self.start_time = time.time()
            if self.variant["max_pretrain_iters"]:
                self.pretrain(eval_envs, loss_fn)

            if self.variant["max_online_iters"]:
                print("\n\nMaking Online Env.....")
                # seeds are offset by 100 so online envs differ from eval envs
                online_envs = SubprocVecEnv(
                    [
                        get_env_builder(i + 100, env_name=env_name, target_goal=target_goal)
                        for i in range(self.variant["num_online_rollouts"])
                    ]
                )
                try:
                    self.online_tuning(online_envs, eval_envs, loss_fn)
                finally:
                    online_envs.close()
        finally:
            # Shut the worker processes down even when training fails, so
            # they are not left running after an error.
            eval_envs.close()

In [None]:
# Seed before the experiment object is built; Experiment.__call__ seeds again
# at the start of the run.
utils.set_seed_everywhere(args.seed)
experiment = Experiment(variant=vars(args))

print(50 * "=")
experiment()

  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


action_range: [-0.999999, 0.999999]
Starting new experiment: drone_dataset
1544 trajectories, 3497627 timesteps found
Average return: -29.00, std: 2362.99
Max return: 3362.87, min: -5541.95
Average length: 2265.30, std: 1012.84
Max length: 4001.00, min: 919.00
Experiment log path: ./exp/2023.03.17/013157-default


Making Eval Env.....


pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17
pybullet build time: May 20 2022 19:44:17





*** Pretrain ***
----------------
eval_envs: <stable_baselines3.common.vec_env.subproc_vec_env.SubprocVecEnv object at 0x7fd9372d66a0>
loss_fn: <function Experiment.__call__.<locals>.loss_fn at 0x7fd93722c670>


pybullet build time: May 20 2022 19:44:17


Iteration 0
time/total: 292.4005665779114
time/training: 280.7637116909027
training/train_loss_mean: 1070.941534824503
training/train_loss_std: 3923.5855151933315
training/nll: -2.756171464920044
training/entropy: -2.9883909225463867
training/temp_value: 0.09604749946966153
evaluation/return_mean_gm: -17.947984481657798
evaluation/return_std_gm: 0.01928994054797549
evaluation/length_mean_gm: 4000.0
evaluation/length_std_gm: 0.0
time/evaluation: 11.619734525680542

Model saved at ./exp/2023.03.17/013157-default/model.pt
Model saved at ./exp/2023.03.17/013157-default/pretrain_model.pt


Making Online Env.....


pybullet build time: May 20 2022 19:44:17





*** Online Finetuning ***
Iteration 1
aug_traj/return: -47967.96691939609
aug_traj/length: 4000.0
time/training: 17.054861545562744
training/train_loss_mean: -0.7687976190852648
training/train_loss_std: 4.08821092570524
training/nll: -1.7699211835861206
training/entropy: -3.703155517578125
training/temp_value: 0.0972356226017003
time/total: 320.5533993244171

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 2
aug_traj/return: -47952.28645396787
aug_traj/length: 4000.0
time/training: 16.67607092857361
training/train_loss_mean: -1.7760222835264066
training/train_loss_std: 1.4195965302595506
training/nll: -2.382671356201172
training/entropy: -3.1791560649871826
training/temp_value: 0.09824595914798266
time/total: 345.9127948284149

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 3
aug_traj/return: -46197.923833808294
aug_traj/length: 4000.0
time/training: 16.780717849731445
training/train_loss_mean: -2.268551374918228
training/train_loss_std: 0.526949


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 18
aug_traj/return: -47957.170852563
aug_traj/length: 4000.0
time/training: 16.842525243759155
training/train_loss_mean: -4.223411689194072
training/train_loss_std: 0.518280237227426
training/nll: -4.688774108886719
training/entropy: -4.36580753326416
training/temp_value: 0.16820299801090216
time/total: 757.1206848621368

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 19
aug_traj/return: -47956.20255720327
aug_traj/length: 4000.0
time/training: 16.75088596343994
training/train_loss_mean: -4.167618503452139
training/train_loss_std: 1.1212507774507006
training/nll: -6.261611461639404
training/entropy: -4.979609966278076
training/temp_value: 0.1744141497269175
time/total: 781.502683877945

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 20
aug_traj/return: -47883.053667969536
aug_traj/length: 4000.0
time/training: 16.798699378967285
training/train_loss_mean: -4.429819267167783
training/


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 35
aug_traj/return: -47751.42436912152
aug_traj/length: 4000.0
time/training: 16.75510835647583
training/train_loss_mean: -4.819006631942983
training/train_loss_std: 0.5675186906262883
training/nll: -6.090849876403809
training/entropy: -5.551949977874756
training/temp_value: 0.30720605032251014
time/total: 1203.6214230060577

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 36
aug_traj/return: -47608.30450259453
aug_traj/length: 4000.0
time/training: 16.87517237663269
training/train_loss_mean: -4.935043278870668
training/train_loss_std: 0.6799752357696844
training/nll: -6.036075592041016
training/entropy: -5.790248394012451
training/temp_value: 0.31798687270535836
time/total: 1228.811190366745

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 37
aug_traj/return: -47914.48057926734
aug_traj/length: 4000.0
time/training: 16.964256048202515
training/train_loss_mean: -4.791202974527258
trai


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 52
aug_traj/return: -47948.776230884054
aug_traj/length: 4000.0
time/training: 20.009317636489868
training/train_loss_mean: -4.356282766412284
training/train_loss_std: 0.5901739624569622
training/nll: -6.7533063888549805
training/entropy: -5.418612003326416
training/temp_value: 0.5401923103180026
time/total: 1667.7827167510986

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 53
aug_traj/return: -47907.20319700872
aug_traj/length: 4000.0
time/training: 20.14599299430847
training/train_loss_mean: -4.289085617901032
training/train_loss_std: 0.5523072216514
training/nll: -8.041379928588867
training/entropy: -5.881241798400879
training/temp_value: 0.5577738381278594
time/total: 1698.300621509552

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 54
aug_traj/return: -47881.57765463727
aug_traj/length: 4000.0
time/training: 19.11288595199585
training/train_loss_mean: -4.1709617824302505
traini


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 69
aug_traj/return: -47610.44433196877
aug_traj/length: 4000.0
time/training: 16.972599267959595
training/train_loss_mean: -2.8698182197464464
training/train_loss_std: 0.36076591785608075
training/nll: -6.394449710845947
training/entropy: -4.186805725097656
training/temp_value: 0.8988428527999774
time/total: 2140.3827497959137

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 70
aug_traj/return: -47566.09739442542
aug_traj/length: 4000.0
time/training: 16.75585126876831
training/train_loss_mean: -2.8140322931944945
training/train_loss_std: 0.3717537137048823
training/nll: -6.742089748382568
training/entropy: -4.228518009185791
training/temp_value: 0.9196736658023592
evaluation/return_mean_gm: 0.0
evaluation/return_std_gm: 0.0
evaluation/length_mean_gm: 4000.0
evaluation/length_std_gm: 0.0
time/evaluation: 11.819518089294434
time/total: 2177.004120349884

Model saved at ./exp/2023.03.17/013157-default/mo


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 86
aug_traj/return: -47062.70364373794
aug_traj/length: 4000.0
time/training: 16.775705337524414
training/train_loss_mean: -2.6032217491979215
training/train_loss_std: 0.3691478529565869
training/nll: -5.130330562591553
training/entropy: -2.974463701248169
training/temp_value: 0.9793126772520233
time/total: 2600.462861061096

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 87
aug_traj/return: -33103.54510033952
aug_traj/length: 4000.0
time/training: 16.817216396331787
training/train_loss_mean: -2.595251476957301
training/train_loss_std: 0.3469684098558466
training/nll: -5.735482215881348
training/entropy: -3.067293167114258
training/temp_value: 0.9813368937010538
time/total: 2626.1555993556976

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 88
aug_traj/return: -34494.42216322507
aug_traj/length: 4000.0
time/training: 16.84565782546997
training/train_loss_mean: -2.6359197297340256
tra


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 103
aug_traj/return: -7552.609516144821
aug_traj/length: 4000.0
time/training: 17.076557159423828
training/train_loss_mean: -2.676235420970487
training/train_loss_std: 0.3884949827821909
training/nll: -6.234414577484131
training/entropy: -3.3683090209960938
training/temp_value: 0.983830596311183
time/total: 3049.9392795562744

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 104
aug_traj/return: -6490.319826870499
aug_traj/length: 4000.0
time/training: 17.030917167663574
training/train_loss_mean: -2.6426926999125455
training/train_loss_std: 0.3293009054994761
training/nll: -5.220583438873291
training/entropy: -2.892394781112671
training/temp_value: 0.9844566843493402
time/total: 3074.9993369579315

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 105
aug_traj/return: -6512.579623722526
aug_traj/length: 4000.0
time/training: 17.17139172554016
training/train_loss_mean: -2.6331545719775287


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 120
aug_traj/return: -6898.177895194984
aug_traj/length: 4000.0
time/training: 17.279197931289673
training/train_loss_mean: -2.626926362020163
training/train_loss_std: 0.38445024409236095
training/nll: -6.193855285644531
training/entropy: -3.298154592514038
training/temp_value: 0.9824656918676765
evaluation/return_mean_gm: -717.1454130582473
evaluation/return_std_gm: 77.30368891124724
evaluation/length_mean_gm: 4000.0
evaluation/length_std_gm: 0.0
time/evaluation: 12.029415607452393
time/total: 3501.783799648285

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 121
aug_traj/return: -15672.898474555152
aug_traj/length: 4000.0
time/training: 16.9880211353302
training/train_loss_mean: -2.6600703841117794
training/train_loss_std: 0.3587519458125076
training/nll: -5.145882606506348
training/entropy: -2.8740978240966797
training/temp_value: 0.9842117329422697
time/total: 3526.652708053589

Model saved at ./ex


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 137
aug_traj/return: -8197.67714263307
aug_traj/length: 4000.0
time/training: 17.398804187774658
training/train_loss_mean: -2.634006962008786
training/train_loss_std: 0.36830486835975984
training/nll: -6.01640510559082
training/entropy: -2.928622007369995
training/temp_value: 0.9855863323048509
time/total: 3941.5842649936676

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 138
aug_traj/return: -8173.051522319389
aug_traj/length: 4000.0
time/training: 17.030654191970825
training/train_loss_mean: -2.5954894051386685
training/train_loss_std: 0.4424242514862133
training/nll: -5.977717399597168
training/entropy: -2.8527939319610596
training/temp_value: 0.9839786526820341
time/total: 3966.6646716594696

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 139
aug_traj/return: -6694.98815648973
aug_traj/length: 4000.0
time/training: 17.055579900741577
training/train_loss_mean: -2.6003795893493598


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 154
aug_traj/return: -8400.296253617069
aug_traj/length: 4000.0
time/training: 16.979113340377808
training/train_loss_mean: -2.5844321946122157
training/train_loss_std: 0.34786957683916525
training/nll: -6.413856029510498
training/entropy: -3.3433310985565186
training/temp_value: 0.9839310487191949
time/total: 4392.9406769275665

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 155
aug_traj/return: -7776.166593166357
aug_traj/length: 4000.0
time/training: 17.04395079612732
training/train_loss_mean: -2.576502857819322
training/train_loss_std: 0.46293990955500547
training/nll: -6.173624515533447
training/entropy: -3.2884438037872314
training/temp_value: 0.9849173079718785
time/total: 4417.955722570419

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 156
aug_traj/return: -9355.90024441358
aug_traj/length: 4000.0
time/training: 17.031273365020752
training/train_loss_mean: -2.55113025153996


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 171
aug_traj/return: -9726.584858819506
aug_traj/length: 4000.0
time/training: 17.340940475463867
training/train_loss_mean: -2.5303124267692345
training/train_loss_std: 0.37564898501764044
training/nll: -4.5968708992004395
training/entropy: -2.671544313430786
training/temp_value: 0.9828480731639933
time/total: 4844.732668161392

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 172
aug_traj/return: -6785.263368702157
aug_traj/length: 4000.0
time/training: 16.750747442245483
training/train_loss_mean: -2.551644861986912
training/train_loss_std: 0.35066308443798727
training/nll: -5.941537380218506
training/entropy: -2.893120527267456
training/temp_value: 0.9817717982741934
time/total: 4869.380698680878

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 173
aug_traj/return: -5500.373999196534
aug_traj/length: 4000.0
time/training: 16.924652814865112
training/train_loss_mean: -2.50335973141413


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 188
aug_traj/return: -13591.045147660398
aug_traj/length: 4000.0
time/training: 17.105010747909546
training/train_loss_mean: -2.472049836345318
training/train_loss_std: 0.34563195576633554
training/nll: -5.399677753448486
training/entropy: -3.2277069091796875
training/temp_value: 0.9781780663455848
time/total: 5284.67707824707

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 189
aug_traj/return: -9886.081204651055
aug_traj/length: 4000.0
time/training: 17.103915214538574
training/train_loss_mean: -2.505549484902785
training/train_loss_std: 0.3850089946277032
training/nll: -4.296379089355469
training/entropy: -2.4665133953094482
training/temp_value: 0.9821146653871615
time/total: 5310.0856211185455

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 190
aug_traj/return: -11465.88270583403
aug_traj/length: 4000.0
time/training: 17.069074392318726
training/train_loss_mean: -2.46679430655561


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 205
aug_traj/return: -8086.661272623094
aug_traj/length: 4000.0
time/training: 17.421685457229614
training/train_loss_mean: -2.445928780086866
training/train_loss_std: 0.3353718555113776
training/nll: -5.408965587615967
training/entropy: -3.0405025482177734
training/temp_value: 0.979048211725852
time/total: 5738.155985355377

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 206
aug_traj/return: -9232.285602900525
aug_traj/length: 4000.0
time/training: 17.28055214881897
training/train_loss_mean: -2.4503194267185413
training/train_loss_std: 0.33906993572175964
training/nll: -4.502511024475098
training/entropy: -3.412924289703369
training/temp_value: 0.9778813842183153
time/total: 5763.955520868301

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 207
aug_traj/return: -7876.730702322083
aug_traj/length: 4000.0
time/training: 16.995806455612183
training/train_loss_mean: -2.448891958119341
t


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 222
aug_traj/return: -8735.694118237607
aug_traj/length: 4000.0
time/training: 17.048571586608887
training/train_loss_mean: -2.4445667984511252
training/train_loss_std: 0.3665160402729505
training/nll: -5.401331901550293
training/entropy: -3.153722047805786
training/temp_value: 0.98137378543919
time/total: 6193.634870767593

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 223
aug_traj/return: -6273.130197273625
aug_traj/length: 4000.0
time/training: 17.32750654220581
training/train_loss_mean: -2.4156238983941023
training/train_loss_std: 0.3685071003892754
training/nll: -5.3691582679748535
training/entropy: -2.97286319732666
training/temp_value: 0.9783183260681361
time/total: 6218.865560054779

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 224
aug_traj/return: -10112.848354633788
aug_traj/length: 4000.0
time/training: 16.9351589679718
training/train_loss_mean: -2.418508875064226
trai


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 239
aug_traj/return: -7366.668213906221
aug_traj/length: 4000.0
time/training: 17.5063533782959
training/train_loss_mean: -2.362501470609195
training/train_loss_std: 0.5240057552566347
training/nll: -4.519344806671143
training/entropy: -2.340442180633545
training/temp_value: 0.9776267876056174
time/total: 6635.820787191391

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 240
aug_traj/return: -8561.175439160454
aug_traj/length: 4000.0
time/training: 17.294710874557495
training/train_loss_mean: -2.37127045118559
training/train_loss_std: 0.38146542008325235
training/nll: -4.856265544891357
training/entropy: -2.8749539852142334
training/temp_value: 0.9787355207290405
evaluation/return_mean_gm: -999.5065299793445
evaluation/return_std_gm: 3078.967862920957
evaluation/length_mean_gm: 4000.0
evaluation/length_std_gm: 0.0
time/evaluation: 11.788743257522583
time/total: 6672.811394929886

Model saved at ./exp/2


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 256
aug_traj/return: -9383.614641289994
aug_traj/length: 4000.0
time/training: 18.01411771774292
training/train_loss_mean: -2.336283973092475
training/train_loss_std: 0.5446902771403193
training/nll: -6.2662248611450195
training/entropy: -3.3122200965881348
training/temp_value: 0.9791530903612303
time/total: 7091.604229211807

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 257
aug_traj/return: -8171.666568126158
aug_traj/length: 4000.0
time/training: 17.212060689926147
training/train_loss_mean: -2.303185155809916
training/train_loss_std: 0.7850275484218285
training/nll: -5.453782081604004
training/entropy: -3.007634162902832
training/temp_value: 0.9760169489246792
time/total: 7116.565243959427

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 258
aug_traj/return: -5972.626396526156
aug_traj/length: 4000.0
time/training: 17.24878239631653
training/train_loss_mean: -2.3231932364502494
t


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 273
aug_traj/return: -8516.872988886531
aug_traj/length: 4000.0
time/training: 17.1646785736084
training/train_loss_mean: -2.2705013166335206
training/train_loss_std: 0.6218253724076183
training/nll: -4.846070766448975
training/entropy: -3.060990571975708
training/temp_value: 0.9756979519265062
time/total: 7546.565254449844

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 274
aug_traj/return: -7946.051146932306
aug_traj/length: 4000.0
time/training: 17.41611909866333
training/train_loss_mean: -2.3055289735038644
training/train_loss_std: 0.5053246360198216
training/nll: -5.753830432891846
training/entropy: -3.329118490219116
training/temp_value: 0.9745835220270008
time/total: 7572.035917520523

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 275
aug_traj/return: -8253.05495946162
aug_traj/length: 4000.0
time/training: 17.098983764648438
training/train_loss_mean: -2.270779771231772
trai


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 290
aug_traj/return: -5773.557212604587
aug_traj/length: 4000.0
time/training: 17.47034978866577
training/train_loss_mean: -2.28632096286601
training/train_loss_std: 0.35957680588273816
training/nll: -5.902886867523193
training/entropy: -3.360706329345703
training/temp_value: 0.974260734305682
evaluation/return_mean_gm: -35939.81826060127
evaluation/return_std_gm: 5972.799667435256
evaluation/length_mean_gm: 4000.0
evaluation/length_std_gm: 0.0
time/evaluation: 11.95029616355896
time/total: 8005.0950610637665

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 291
aug_traj/return: -9985.161611141351
aug_traj/length: 4000.0
time/training: 17.704586029052734
training/train_loss_mean: -2.2322770474261966
training/train_loss_std: 0.38345402220476665
training/nll: -5.295780181884766
training/entropy: -3.084667921066284
training/temp_value: 0.970265024390104
time/total: 8030.687732696533

Model saved at ./exp/2


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 307
aug_traj/return: -9282.944549930144
aug_traj/length: 4000.0
time/training: 17.03701615333557
training/train_loss_mean: -2.199245014923169
training/train_loss_std: 0.6284865609087349
training/nll: -5.58589506149292
training/entropy: -3.1124846935272217
training/temp_value: 0.9710787399203452
time/total: 8451.824660778046

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 308
aug_traj/return: -8116.532262542861
aug_traj/length: 4000.0
time/training: 17.45321297645569
training/train_loss_mean: -2.2061684851324994
training/train_loss_std: 0.38097177144891065
training/nll: -4.604641437530518
training/entropy: -2.3073830604553223
training/temp_value: 0.9704120620580688
time/total: 8477.096631526947

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 309
aug_traj/return: -7069.438388019611
aug_traj/length: 4000.0
time/training: 17.575999975204468
training/train_loss_mean: -2.0769818102828315



Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 324
aug_traj/return: -6098.69783824416
aug_traj/length: 4000.0
time/training: 18.019481658935547
training/train_loss_mean: -2.084327027193389
training/train_loss_std: 1.2723766246595023
training/nll: -5.385732173919678
training/entropy: -3.3683249950408936
training/temp_value: 0.9664586104005058
time/total: 8911.60150551796

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 325
aug_traj/return: -10038.794006336386
aug_traj/length: 4000.0
time/training: 17.353841304779053
training/train_loss_mean: -2.2200345699547572
training/train_loss_std: 0.32195986831348206
training/nll: -4.789292335510254
training/entropy: -2.7749106884002686
training/temp_value: 0.9683660229511163
time/total: 8936.868010759354

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 326
aug_traj/return: -8244.029450551
aug_traj/length: 4000.0
time/training: 18.07238006591797
training/train_loss_mean: -2.1835693718714615
tr


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 341
aug_traj/return: -5608.19664106972
aug_traj/length: 4000.0
time/training: 18.102673053741455
training/train_loss_mean: -2.182212884066968
training/train_loss_std: 0.35746132781393614
training/nll: -5.952455997467041
training/entropy: -3.6766602993011475
training/temp_value: 0.9660233842743641
time/total: 9372.601047039032

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 342
aug_traj/return: -6966.700813427651
aug_traj/length: 4000.0
time/training: 18.24293065071106
training/train_loss_mean: -2.0969889452792594
training/train_loss_std: 0.7155207368411638
training/nll: -5.877196788787842
training/entropy: -3.604707717895508
training/temp_value: 0.9634321667053717
time/total: 9398.76132106781

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 343
aug_traj/return: -9364.778647256298
aug_traj/length: 4000.0
time/training: 17.74156093597412
training/train_loss_mean: -2.155165228404029
tra


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 358
aug_traj/return: -9935.148942406728
aug_traj/length: 4000.0
time/training: 18.244296312332153
training/train_loss_mean: -2.0369531526097306
training/train_loss_std: 1.3078473634141392
training/nll: -5.856361389160156
training/entropy: -3.442504644393921
training/temp_value: 0.9612225273823148
time/total: 9821.608304262161

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 359
aug_traj/return: -7985.691925330221
aug_traj/length: 4000.0
time/training: 17.923965454101562
training/train_loss_mean: -1.997750395229237
training/train_loss_std: 1.8683766622303903
training/nll: -5.843387603759766
training/entropy: -3.757492780685425
training/temp_value: 0.9585916310129935
time/total: 9847.527424097061

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 360
aug_traj/return: -5922.671640686589
aug_traj/length: 4000.0
time/training: 17.914639949798584
training/train_loss_mean: -2.1622029891451047



Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 375
aug_traj/return: -9119.451715251065
aug_traj/length: 4000.0
time/training: 17.12938928604126
training/train_loss_mean: -2.1202294443232135
training/train_loss_std: 0.365484138197905
training/nll: -4.643568515777588
training/entropy: -2.6249682903289795
training/temp_value: 0.960356372680296
time/total: 10284.983618974686

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 376
aug_traj/return: -7520.204158382667
aug_traj/length: 4000.0
time/training: 17.831810235977173
training/train_loss_mean: -2.1210366329134707
training/train_loss_std: 0.35702962144996003
training/nll: -6.705935955047607
training/entropy: -3.7340590953826904
training/temp_value: 0.9586006224241985
time/total: 10311.034532546997

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 377
aug_traj/return: -9890.095277624907
aug_traj/length: 4000.0
time/training: 17.91273546218872
training/train_loss_mean: -2.020100821831679


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 392
aug_traj/return: -10434.53201275126
aug_traj/length: 4000.0
time/training: 17.212172508239746
training/train_loss_mean: -1.845073584586917
training/train_loss_std: 4.037782256155695
training/nll: -5.112393856048584
training/entropy: -3.396005392074585
training/temp_value: 0.9540708026654756
time/total: 10750.112160682678

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 393
aug_traj/return: -7172.1106581086815
aug_traj/length: 4000.0
time/training: 17.381750106811523
training/train_loss_mean: -2.0348171297104547
training/train_loss_std: 0.3996353648318517
training/nll: -4.916017055511475
training/entropy: -3.1013176441192627
training/temp_value: 0.9515032605045515
time/total: 10775.618800401688

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 394
aug_traj/return: -3547.0745203928745
aug_traj/length: 4000.0
time/training: 17.581472396850586
training/train_loss_mean: -2.0840212787101


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 409
aug_traj/return: -7786.355351931866
aug_traj/length: 4000.0
time/training: 18.029974222183228
training/train_loss_mean: -2.044885565759641
training/train_loss_std: 0.605425750401292
training/nll: -5.344016075134277
training/entropy: -3.07611346244812
training/temp_value: 0.9492623255884033
time/total: 11204.70546078682

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 410
aug_traj/return: -8289.558749552632
aug_traj/length: 4000.0
time/training: 17.98828148841858
training/train_loss_mean: -1.8860402325419217
training/train_loss_std: 1.748819144836115
training/nll: -4.196528911590576
training/entropy: -2.5834217071533203
training/temp_value: 0.9522118973688157
evaluation/return_mean_gm: -15991.394026818807
evaluation/return_std_gm: 15128.220444277902
evaluation/length_mean_gm: 4000.0
evaluation/length_std_gm: 0.0
time/evaluation: 11.854246854782104
time/total: 11242.601642847061

Model saved at ./exp


Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 426
aug_traj/return: -7936.517270759405
aug_traj/length: 4000.0
time/training: 17.63629984855652
training/train_loss_mean: -1.984261800498651
training/train_loss_std: 0.9603150609160136
training/nll: -3.909060001373291
training/entropy: -2.4349496364593506
training/temp_value: 0.9488709386551943
time/total: 11671.641090631485

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 427
aug_traj/return: -10789.874035625333
aug_traj/length: 4000.0
time/training: 17.444542169570923
training/train_loss_mean: -1.8865301310605513
training/train_loss_std: 2.26553216318398
training/nll: -5.222695827484131
training/entropy: -3.245690107345581
training/temp_value: 0.9490744583799458
time/total: 11697.2963616848

Model saved at ./exp/2023.03.17/013157-default/model.pt
Iteration 428
aug_traj/return: -7694.525167215611
aug_traj/length: 4000.0
time/training: 17.756925344467163
training/train_loss_mean: -2.0006578664831114
t

In [None]:
def study_env(env):
    """Print the state dimension, action dimension and action range of ``env``.

    ``env`` must expose ``observation_space.shape`` and an ``action_space``
    providing ``shape``, ``low`` and ``high`` (the latter two supporting
    ``min()`` / ``max()``). Nothing is returned; the three values are written
    to stdout, one per line.
    """
    n_state = env.observation_space.shape[0]
    act_space = env.action_space
    n_action = act_space.shape[0]

    # The range spans the smallest lower bound and the largest upper bound
    # over all action components, pulled inwards by 1e-6 at both ends.
    lower_bound = float(act_space.low.min()) + 1e-6
    upper_bound = float(act_space.high.max()) - 1e-6

    report = (
        ("state_dim: {}", n_state),
        ("act_dim: {}", n_action),
        ("action_range: {}", [lower_bound, upper_bound]),
    )
    for template, value in report:
        print(template.format(value))


In [None]:
# Build two environments so their spaces can be compared side by side:
# the one returned by make_pytorch_env(args) and the gym environment
# 'antmaze-large-diverse-v2' (presumably registered by the d4rl import).
my_env = make_pytorch_env(args)
their_env = gym.make('antmaze-large-diverse-v2')

In [None]:
# Print state/action dimensions and action range of the environment built by make_pytorch_env.
study_env(my_env)

In [None]:
# Print the same report for the 'antmaze-large-diverse-v2' reference environment.
study_env(their_env)

In [None]:
# Smoke test: reset the environment, then take one step passing the bare integer 2 as the action.
# NOTE(review): whether a scalar is a valid action depends on my_env's action space — confirm.
my_env.reset()
my_env.step(2)

In [None]:
# Display `args` (presumably the argparse namespace parsed earlier in the notebook).
args

In [None]:
# Inspect the action space of the reference environment.
their_env.action_space

In [None]:
# Same smoke test on the reference environment: reset, then step with the bare integer 2.
# NOTE(review): whether a scalar is a valid action depends on their_env's action space — confirm.
their_env.reset()
their_env.step(2)

In [None]:
#experiment.variant
#experiment.model.forward

In [None]:
# Display `loss`. It is not assigned in any of the visible cells, so this
# relies on state left in the kernel by earlier code — TODO confirm where it is set.
loss

In [None]:
# Display the bound `forward` method of the experiment's model.
experiment.model.forward

In [None]:
# Display the model object held by the experiment.
experiment.model

In [None]:
import math
# 1e-310 is below the smallest normal double (~2.2e-308) but still representable
# as a subnormal, so this returns a finite value (about -713.8) instead of raising.
math.log(1e-310)

In [None]:
# Load the object previously saved to action_preds.pt. The later cells call
# .log_likelihood(), .entropy() and .perplexity on it.
action_preds = torch.load('action_preds.pt')


In [None]:
# Load the object saved to a.pt; it is passed to action_preds.log_likelihood() below.
a = torch.load("a.pt")

In [None]:
# Display the log-likelihood of `a` under the loaded `action_preds` object.
action_preds.log_likelihood(a)

In [None]:
# Keep the log-likelihood result so it can be post-processed (see the nan_to_num cell).
sefude = action_preds.log_likelihood(a)

In [None]:
# Display the object loaded from a.pt.
a

In [None]:
# Display the object loaded from a.pt (repeated inspection cell).
a

In [None]:
# Display the object loaded from a.pt (repeated inspection cell).
a

In [None]:
# Show the log-likelihoods with NaN / infinite entries replaced by finite numbers.
# The result is only displayed; `sefude` itself is not modified.
torch.nan_to_num(sefude)

In [None]:
# Display the object loaded from action_preds.pt.
action_preds

In [None]:
# Display the first element of the first entry of `a`.
a[0][0]

In [None]:
# math.log of a negative number raises ValueError ("math domain error"),
# so this cell stops a "Run All" unless it is skipped.
math.log(-0.3)

In [None]:
# Display the mean of the entropy values reported by `action_preds`.
action_preds.entropy().mean()

In [None]:
# Probe the log-likelihood of the constant 10.
# NOTE(review): presumably chosen to be far outside the valid action range — confirm.
action_preds.log_likelihood(10)

In [None]:
# Display the `perplexity` attribute; it is referenced without parentheses, so
# if it is a method this shows the bound method rather than calling it.
action_preds.perplexity

In [None]:
# Compare a freshly constructed state-embedding layer with the objects saved to disk.
import torch
state_dim = 4
hidden_size = 512

# Newly (randomly) initialised linear layer vs. the object stored in embed_state.pt.
# NOTE(review): 'cuda' is hard-coded, so this cell requires a GPU — confirm that is intended.
embed_state = torch.nn.Linear(state_dim, hidden_size).to('cuda')
embed_state_2 = torch.load('embed_state.pt').to('cuda')
states = torch.load('states.pt').to('cuda')
# Embeddings produced by the fresh layer vs. the embeddings stored in state_embeddings.pt.
state_embeddings = embed_state(states)
state_embeddings_2 = torch.load('state_embeddings.pt').to('cuda')


In [None]:
# Display the first entry of the loaded `states`.
states[0]

In [None]:
# Print the embeddings computed by the freshly initialised layer.
print("state_embeddings {}".format(state_embeddings))


In [None]:
# Print the embeddings loaded from state_embeddings.pt for comparison.
print("state_embeddings 2 {}".format(state_embeddings_2))


In [None]:
# Display the weight matrix of the freshly initialised layer.
embed_state.weight

In [None]:
# Display the weight of the object loaded from embed_state.pt.
embed_state_2.weight

In [None]:
# Display the freshly initialised linear layer.
embed_state

In [None]:
# Display the object loaded from embed_state.pt.
embed_state_2

In [None]:
# Apply the freshly initialised layer to `states` and display the result.
embed_state(states)

In [None]:
# Apply the layer loaded from embed_state.pt to the same `states` for comparison.
embed_state_2(states)

In [None]:
# NOTE(review): presumably a deliberate undefined name — evaluating it raises
# NameError, which halts "Run All" before the cells below. Confirm intent.
stoppppppppppp

In [None]:
import torch
# Display the version string of the torch build imported by this kernel.
torch.__version__

In [None]:
# Shell command: show the lines of `pip list` that contain "torch".
!pip list | grep torch

In [None]:
# Use the %pip magic rather than `!pip3` so the upgrade is applied to the
# environment of the running kernel (a bare `pip3` on PATH may belong to a
# different interpreter). Restart the kernel afterwards so the upgraded torch
# is the one that gets imported.
%pip install --upgrade torch

In [None]:
# Normalizing the rewards to see whether that solves the problem

In [None]:
import pickle

# Load the drone dataset and the antmaze reference dataset for side-by-side inspection.
# NOTE: pickle.load can execute arbitrary code — only open files from a trusted source.
with open('data/drone_dataset.pkl', 'rb') as f:
    my_data = pickle.load(f)
    
with open('data/antmaze-large-diverse-v2.pkl', 'rb') as f:
    their_data = pickle.load(f)

In [None]:
# Print max / min / mean of the ACTIONS of every entry in my_data.
# The values inspected come from data['actions'], so the local is named
# `actions` (it was previously called `rewards`, which made the printed
# statistics easy to misread as reward statistics).
for data in my_data:
    actions = data['actions']
    print("max: {}".format(np.max(actions)))
    print("min: {}".format(np.min(actions)))
    print("mean: {}".format(np.mean(actions)))
    print('----------------')

In [None]:
# Shape of the 'observations' array of the first entry in the drone dataset.
np.shape(my_data[0]['observations'])

In [None]:
# Shape of the 'observations' array of the first entry in the antmaze dataset, for comparison.
np.shape(their_data[0]['observations'])

In [None]:
import pickle

# One-off preprocessing: clip every entry's actions into [-1, 1], store them
# as float32, and write the dataset back to disk.

# NOTE: pickle.load can execute arbitrary code — only open files from a trusted source.
with open('data/drone_dataset.pkl', 'rb') as f:
    my_data = pickle.load(f)

for data in my_data:
    # Earlier one-off conversions, kept for reference:
    #data['rewards']   = np.float32(data['rewards'].flatten())
    #data['terminals'] = np.float32(data['terminals'].flatten())
    # np.clip is the direct equivalent of np.minimum(np.maximum(x, -1), 1).
    data['actions'] = np.float32(np.clip(data['actions'], -1, 1))
    #data['observations'] = np.float32(data['observations'])
    #data['next_observations'] = np.float32(data['next_observations'])

# WARNING: this overwrites the file that was just read, so the unclipped
# actions are lost. Clipping is idempotent, so re-running the cell is harmless.
with open('data/drone_dataset.pkl', 'wb') as handle:
    pickle.dump(my_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Min-max rescaling of `v` onto [0, 1]; the result is only displayed, not stored.
# NOTE(review): `v` is not assigned in any visible cell, and the denominator is
# zero whenever v.max() == v.min() — confirm both before relying on this.
(v - v.min()) / (v.max() - v.min())