In [5]:
import sys
import yaml
import os
import random
import time
from collections import defaultdict

# Determinism-related environment variables. They are assigned before `jax`
# is imported further down in this cell — presumably so the backend picks them
# up at start-up (confirm).
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true --xla_gpu_autotune_level=0"
os.environ["TF_DETERMINISTIC_OPS"] = "1"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

# Put the parent directory on the import path; presumably that is where the
# `agents` and `utils` packages imported below live — verify.
sys.path.append("..")

import wandb
import jax
from dataclasses import asdict
import numpy as np
import tqdm
from agents import agents
from ml_collections import FrozenConfigDict
from utils.datasets import Dataset, GCDataset, HGCDataset
from utils.env_utils import make_env_and_datasets
# from utils.evaluation import evaluate
from utils.flax_utils import restore_agent, save_agent
from utils.log_utils import CsvLogger, get_exp_name, get_wandb_video, setup_wandb
from utils.config import GCTTTConfig, load_config
import matplotlib.pyplot as plt
import io
from PIL import Image
import importlib


In [20]:
# Read the experiment configuration and convert it to a plain dict.
cfg = load_config("../bc.yaml")
cfg_dict = asdict(cfg)

# Back-fill agent options the config does not set with the defaults returned
# by the selected agent module's get_config(); values already present win.
agent_module = importlib.import_module(f"agents.{cfg.agent.agent_name}")
agent_cfg = agent_module.get_config()
for key, default in agent_cfg.items():
    cfg_dict["agent"].setdefault(key, default)


In [21]:
# Inspect the merged configuration (config values plus agent defaults).
cfg_dict

{'run_group': 'debug',
 'seed': 0,
 'env_name': 'pointmaze-medium-stitch-v0',
 'data_ratio': 1.0,
 'working_dir': 'meta_exp',
 'restore_path': None,
 'restore_epoch': None,
 'agent': {'agent_name': 'gciql',
  'actor_geom_sample': False,
  'actor_hidden_dims': (512, 512, 512),
  'actor_loss': 'ddpgbc',
  'actor_p_curgoal': 0.0,
  'actor_p_randomgoal': 0.0,
  'actor_p_trajgoal': 1.0,
  'alpha': 0.3,
  'batch_size': 1024,
  'const_std': True,
  'dataset_class': 'GCDataset',
  'discount': 0.99,
  'discrete': False,
  'encoder': None,
  'expectile': 0.9,
  'frame_stack': None,
  'gc_negative': True,
  'layer_norm': True,
  'lr': 0.0003,
  'p_aug': 0.0,
  'tau': 0.005,
  'value_geom_sample': True,
  'value_hidden_dims': (512, 512, 512),
  'value_p_curgoal': 0.2,
  'value_p_randomgoal': 0.3,
  'value_p_trajgoal': 0.5},
 'finetune': {'ratio': 0.5,
  'num_steps': 0,
  'lr': 3e-05,
  'actor_loss': 'ddpgbc',
  'alpha': None,
  'batch_size': 1024,
  'fix_actor_goal': 0.0,
  'mc_quantile': 0.2,
  '

In [22]:
# Build the environment and fetch its train / validation datasets.
config_agent = cfg_dict["agent"]
env, train_dataset, val_dataset = make_env_and_datasets(
    cfg.env_name,
    cfg.data_ratio,
    frame_stack=config_agent["frame_stack"],
)

Downloading dataset from: https://rail.eecs.berkeley.edu/datasets/ogbench/pointmaze-medium-stitch-v0.npz


pointmaze-medium-stitch-v0.npz: 100%|██████████| 18.8M/18.8M [00:02<00:00, 9.47MB/s]


Downloading dataset from: https://rail.eecs.berkeley.edu/datasets/ogbench/pointmaze-medium-stitch-v0-val.npz


pointmaze-medium-stitch-v0-val.npz: 100%|██████████| 1.88M/1.88M [00:01<00:00, 1.80MB/s]


In [23]:
# Inspect the agent sub-config (the "agent" entry of cfg_dict).
config_agent

{'agent_name': 'gciql',
 'actor_geom_sample': False,
 'actor_hidden_dims': (512, 512, 512),
 'actor_loss': 'ddpgbc',
 'actor_p_curgoal': 0.0,
 'actor_p_randomgoal': 0.0,
 'actor_p_trajgoal': 1.0,
 'alpha': 0.3,
 'batch_size': 1024,
 'const_std': True,
 'dataset_class': 'GCDataset',
 'discount': 0.99,
 'discrete': False,
 'encoder': None,
 'expectile': 0.9,
 'frame_stack': None,
 'gc_negative': True,
 'layer_norm': True,
 'lr': 0.0003,
 'p_aug': 0.0,
 'tau': 0.005,
 'value_geom_sample': True,
 'value_hidden_dims': (512, 512, 512),
 'value_p_curgoal': 0.2,
 'value_p_randomgoal': 0.3,
 'value_p_trajgoal': 0.5}

In [24]:
train_dataset

# Per the repr below, every array has 1,005,000 rows:
# observations: float32, shape (1005000, 2) — presumably [x, y] positions
# actions:      float32, shape (1005000, 2) — maybe [deltaX, deltaY]; the values shown lie in [-1, 1]
# terminals:    float32, shape (1005000,) — 0./1. flags, not bool
# valids:       float32, shape (1005000,) — 0./1. flags, not bool


FrozenDict({
    observations: array([[2.0033117e+01, 1.5632947e+01],
           [2.0099489e+01, 1.5443025e+01],
           [2.0299490e+01, 1.5243025e+01],
           ...,
           [1.6030459e-02, 1.1852762e+01],
           [9.9748522e-03, 1.2051859e+01],
           [1.1539937e-02, 1.1851859e+01]], shape=(1005000, 2), dtype=float32),
    actions: array([[ 0.33186212, -0.9496147 ],
           [ 1.        , -1.        ],
           [-0.9027821 , -0.3318183 ],
           ...,
           [-0.03027803,  0.9954832 ],
           [ 0.00782542, -1.        ],
           [ 0.47040606,  0.07431631]], shape=(1005000, 2), dtype=float32),
    terminals: array([0., 0., 0., ..., 0., 1., 1.], shape=(1005000,), dtype=float32),
    valids: array([1., 1., 1., ..., 1., 1., 0.], shape=(1005000,), dtype=float32),
})

In [25]:
# Draw one batch of config_agent["batch_size"] transitions. The displayed
# result is a dict holding the four dataset fields plus `next_observations`.
train_dataset.sample(config_agent["batch_size"])


{'actions': array([[ 0.9847515 , -0.16461378],
        [ 0.08414634,  0.11178224],
        [-0.8433495 , -0.8098009 ],
        ...,
        [-0.83299917, -0.06097023],
        [ 0.12648928, -1.        ],
        [-0.3431872 ,  0.6470249 ]], shape=(1024, 2), dtype=float32),
 'observations': array([[ 1.5343468 , 12.337385  ],
        [ 7.9657097 , 15.891599  ],
        [20.04253   , 16.10107   ],
        ...,
        [12.126947  ,  3.9037666 ],
        [11.908254  ,  8.13164   ],
        [19.223696  , -0.16270353]], shape=(1024, 2), dtype=float32),
 'terminals': array([0., 0., 0., ..., 0., 0., 0.], shape=(1024,), dtype=float32),
 'valids': array([1., 1., 1., ..., 1., 1., 1.], shape=(1024,), dtype=float32),
 'next_observations': array([[ 1.7312971 , 12.304462  ],
        [ 7.982539  , 15.913956  ],
        [19.873861  , 15.93911   ],
        ...,
        [11.960348  ,  3.8915727 ],
        [11.933552  ,  7.9316406 ],
        [19.155058  , -0.03329854]], shape=(1024, 2), dtype=float32)}

In [None]:


    # Set up logger.
    # NOTE(review): every statement in this cell carries a four-space indent,
    # as if lifted from a function body — confirm it runs as a cell or dedent.
    #
    # Short environment label built from the 1st and 3rd dash-separated fields
    # of the env name (indices 0 and 2), e.g. "pointmaze-stitch" for
    # "pointmaze-medium-stitch-v0". It is not used elsewhere in this cell.
    # NOTE(review): an earlier comment said "first and second part" — confirm
    # that skipping index 1 is intended.
    env_name_split = cfg.env_name.split("-")
    wandb_env_name = f"{env_name_split[0]}-{env_name_split[2]}"
    exp_name = get_exp_name(cfg)
    setup_wandb(
        project="TTT_AllFinalRuns",
        group=cfg.run_group,
        name=exp_name,
        config=cfg_dict,
    )

    # Write the expanded config next to the run's other artifacts.
    # NOTE(review): cfg_dict holds tuples (e.g. the *_hidden_dims entries);
    # confirm the dumped YAML can be read back by whatever consumes it.
    os.makedirs(cfg.working_dir, exist_ok=True)
    with open(os.path.join(cfg.working_dir, "config.yaml"), "w") as f:
        yaml.dump(cfg_dict, f)

    # Set up environment and dataset.
    # This repeats the make_env_and_datasets call from the earlier "Load data"
    # cell, so the raw datasets are fetched again before being wrapped below.
    config_agent = cfg_dict["agent"]
    env, train_dataset, val_dataset = make_env_and_datasets(
        cfg.env_name, cfg.data_ratio, frame_stack=config_agent["frame_stack"]
    )
    # Seed both the environment reset and its action space with the run seed.
    env.reset(seed=cfg.seed)
    env.action_space.seed(cfg.seed)

    # Pick the dataset wrapper named in the agent config; an unknown name
    # raises KeyError from the dict lookup.
    dataset_class = {
        "GCDataset": GCDataset,
        "HGCDataset": HGCDataset,
    }[config_agent["dataset_class"]]
    # From here on `train_dataset` / `val_dataset` refer to the wrapped
    # datasets, not the raw ones returned by make_env_and_datasets.
    train_dataset: GCDataset | HGCDataset = dataset_class(Dataset.create(**train_dataset), config_agent)
    # The validation split is optional; it is wrapped only when present.
    if val_dataset is not None:
        val_dataset: GCDataset | HGCDataset = dataset_class(Dataset.create(**val_dataset), config_agent)

    # Initialize agent.
    # Seed Python's and NumPy's global RNGs before the first dataset sample is
    # drawn below.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)

    # A single-transition batch whose observations/actions are handed to the
    # agent constructor — presumably to fix input shapes (confirm).
    example_batch = train_dataset.sample(1)
    if config_agent["discrete"]:
        # Fill with the maximum action to let the agent know the action space size.
        example_batch["actions"] = np.full_like(
            example_batch["actions"], env.action_space.n - 1
        )

    # Look up the agent class by name in the `agents` registry and build it
    # from the seed, the example observations/actions and the agent config.
    agent_class = agents[config_agent["agent_name"]]
    agent = agent_class.create(
        cfg.seed,
        example_batch["observations"],
        example_batch["actions"],
        config_agent,
    )

    # Restore agent.
    # Only runs when a restore path is configured (it is None in the config
    # displayed earlier in this notebook).
    if cfg.restore_path is not None:
        agent = restore_agent(agent, cfg.restore_path, cfg.restore_epoch)

    # Prepare for training: no training step happens in this cell, it only
    # creates the CSV loggers and the timing references.
    # Train and eval metrics go to separate CSV files under cfg.working_dir.
    train_logger = CsvLogger(os.path.join(cfg.working_dir, "train.csv"))
    eval_logger = CsvLogger(os.path.join(cfg.working_dir, "eval.csv"))
    # Two independent clock reads; how they are used later is not shown here.
    first_time = time.time()
    last_time = time.time()

In [None]:
/cluster/home/anmari/gc_ttt/venv/bin/python