In [1]:
import os
from pathlib import Path

import torch
from cliport import agents
from cliport.dataset import RavensDataset, RavensMultiTaskDataset

import hydra
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from cliport.utils import utils

# Pin the cliport checkout location; the config cell reads this back as ROOT_DIR.
os.environ.update(CLIPORT_ROOT='/sfs/weka/scratch/ys5hd/cliport')

  warn(


In [2]:
# Locate the training config inside the cliport checkout and load it.
ROOT_DIR = os.environ['CLIPORT_ROOT']
config_file = 'train.yaml'
config_path = os.path.join(ROOT_DIR, f'cliport/cfg/{config_file}')
cfg = utils.load_hydra_config(config_path)

In [3]:
# Attach a Weights & Biases logger only when the config enables logging.
if cfg['train']['log']:
    wandb_logger = WandbLogger(name=cfg['tag'])
else:
    wandb_logger = None

In [4]:
# Run configuration: every knob a reader is expected to tune lives in this cell.

# LLM_PARSING selects the architecture variant:
#   - False          -> plain cliport architecture
#   - a task name    -> text pre-processed architecture for that task
# NOTE(review): presumably this should match TRAIN_TASK when set -- confirm.
LLM_PARSING = 'align-rope'

# Other tasks previously used here:
#   'put-block-in-bowl-seen-colors'
#   'packing-boxes-pairs-seen-colors'
#   'stack-block-pyramid-seq-unseen-colors'
TRAIN_TASK = 'align-rope'

TRAIN_AGENT = 'cliport'
TRAIN_NDEMOS = 100

# Keep LLM-parsing runs in their own experiment folder.
EXPS = 'exps_llm' if LLM_PARSING else 'exps'

cfg['train']['task'] = TRAIN_TASK
cfg['train']['agent'] = TRAIN_AGENT
cfg['train']['attn_stream_fusion_type'] = 'add'
cfg['train']['trans_stream_fusion_type'] = 'conv'
cfg['train']['lang_fusion_type'] = 'mult'
# Use the constant (not a second literal) so the demo count in the config
# cannot drift from the one baked into the train_dir name.
cfg['train']['n_demos'] = TRAIN_NDEMOS
cfg['train']['n_steps'] = 101000  # 201000
cfg['train']['exp_folder'] = EXPS
cfg['dataset']['cache'] = False

In [5]:
# Output folder for this run: <root>/<experiment folder>/<task>-<agent>-n<demos>-train
run_folder = f'{TRAIN_TASK}-{TRAIN_AGENT}-n{TRAIN_NDEMOS}-train'
cfg['train']['train_dir'] = f'{ROOT_DIR}/{EXPS}/{run_folder}'

In [6]:
# Checkpoint locations for this run.
hydra_dir = Path(os.getcwd())
checkpoint_path = os.path.join(cfg['train']['train_dir'], 'checkpoints')
last_checkpoint_path = os.path.join(checkpoint_path, 'last.ckpt')

# Resume only when a previous checkpoint exists AND the config allows resuming.
can_resume = os.path.exists(last_checkpoint_path) and cfg['train']['load_from_last_ckpt']
last_checkpoint = last_checkpoint_path if can_resume else None

# Checkpoint callback: tracks the configured metric, writes under 'best',
# and additionally keeps the most recent checkpoint (save_last).
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(checkpoint_path, 'best'),
    monitor=cfg['wandb']['saver']['monitor'],
    save_last=True,
    save_top_k=1,
)

In [7]:
# Turn the configured step budget into an epoch count (integer division by n_demos).
train_cfg = cfg['train']
max_epochs = train_cfg['n_steps'] // train_cfg['n_demos']

In [8]:
# Validate about 50 times over the whole run. Clamp to at least 1: for short
# runs (max_epochs < 50) the integer division yields 0, which is not a valid
# validation interval.
val_every_n_epochs = max(1, max_epochs // 50)

# Lightning trainer; resumes from the last checkpoint when one was found.
trainer = Trainer(
    gpus=cfg['train']['gpu'],
    fast_dev_run=cfg['debug'],
    logger=wandb_logger,
    checkpoint_callback=checkpoint_callback,
    max_epochs=max_epochs,
    automatic_optimization=False,
    check_val_every_n_epoch=val_every_n_epochs,
    resume_from_checkpoint=last_checkpoint,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [9]:
# When resuming, copy the saved epoch / global-step counters onto the trainer.
if last_checkpoint:
    print(f"Resuming: {last_checkpoint}")
    # Only two integers are needed here, so load the checkpoint onto the CPU
    # instead of restoring every tensor to the device it was saved from.
    last_ckpt = torch.load(last_checkpoint, map_location='cpu')
    trainer.current_epoch = last_ckpt['epoch']
    trainer.global_step = last_ckpt['global_step']
    # Drop the reference so the loaded checkpoint can be garbage-collected.
    del last_ckpt

Resuming: /sfs/weka/scratch/ys5hd/cliport/exps_llm/align-rope-cliport-n100-train/checkpoints/last.ckpt


In [10]:
# Dataset root, derived from ROOT_DIR so the checkout path is defined in one place.
DATA_DIR = os.path.join(ROOT_DIR, 'cliport/data/')

In [11]:
# Unpack the run settings consumed by the dataset and agent cells below.
train_cfg = cfg['train']
data_dir = DATA_DIR
task, agent_type = train_cfg['task'], train_cfg['agent']
n_demos, n_val = train_cfg['n_demos'], train_cfg['n_val']
# Display name for the agent: <task>-<agent>-<n_demos>
name = f'{task}-{agent_type}-{n_demos}'

In [12]:
# Build the train/val datasets; augmentation is enabled for the training split only.
dataset_type = cfg['dataset']['type']
if 'multi' not in dataset_type:
    # Single-task: each split lives in its own '<task>-<split>' folder.
    train_ds = RavensDataset(os.path.join(data_dir, f'{task}-train'), cfg, n_demos=n_demos, augment=True)
    val_ds = RavensDataset(os.path.join(data_dir, f'{task}-val'), cfg, n_demos=n_val, augment=False)
else:
    # Multi-task: the task name is passed as the dataset group.
    train_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='train', n_demos=n_demos, augment=True)
    val_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='val', n_demos=n_val, augment=False)

In [13]:
# Look up the agent class by its configured name, then instantiate it.
agent_cls = agents.names[agent_type]
agent = agent_cls(name, cfg, train_ds, val_ds, llm_parsing=LLM_PARSING)

Attn FCN - Stream One: plain_resnet_lat, Stream Two: clip_lingunet_lat, Stream Fusion: add
Transport FCN - Stream One: plain_resnet_lat, Stream Two: clip_lingunet_lat, Stream Fusion: conv
Agent: align-rope-cliport-100, Logging: False


In [14]:
# Main training loop
# Only the agent is passed to fit(); no dataloaders are supplied here.
# NOTE(review): presumably the agent exposes the datasets it was constructed
# with to the trainer -- confirm in the agent implementation.
trainer.fit(agent)

Set SLURM handle signals.

  | Name      | Type                            | Params
--------------------------------------------------------------
0 | attention | TwoStreamAttentionLangFusionLat | 194 M 
1 | transport | TwoStreamTransportLangFusionLat | 388 M 


Validation sanity check: 0it [00:00, ?it/s]

pybullet build time: Sep 22 2020 00:56:01

Attn Err - Dist: 89.00, Theta: 0.00
Transport Err - Dist: 238.00, Theta: 0.00




Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2380.00, Theta: 0.00
Transport Err - Dist: 8677.00, Theta: 142.07


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2619.00, Theta: 0.00
Transport Err - Dist: 8139.00, Theta: 72.61


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1828.00, Theta: 0.00
Transport Err - Dist: 8042.00, Theta: 58.12


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2228.00, Theta: 0.00
Transport Err - Dist: 8897.00, Theta: 1.57


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2302.00, Theta: 0.00
Transport Err - Dist: 8573.00, Theta: 7.85


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1739.00, Theta: 0.00
Transport Err - Dist: 7765.00, Theta: 3.14


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2442.00, Theta: 0.00
Transport Err - Dist: 8527.00, Theta: 1.57


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2007.00, Theta: 0.00
Transport Err - Dist: 7692.00, Theta: 2.27


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2120.00, Theta: 0.00
Transport Err - Dist: 8430.00, Theta: 6.28


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2375.00, Theta: 0.00
Transport Err - Dist: 8679.00, Theta: 1.57


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Attn Err - Dist: 1933.00, Theta: 0.00
Transport Err - Dist: 7366.00, Theta: 6.11


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2060.00, Theta: 0.00
Transport Err - Dist: 9069.00, Theta: 1.57


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2089.00, Theta: 0.00
Transport Err - Dist: 9217.00, Theta: 3.14


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2447.00, Theta: 0.00
Transport Err - Dist: 7842.00, Theta: 11.87


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1632.00, Theta: 0.00
Transport Err - Dist: 7340.00, Theta: 23.91


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1715.00, Theta: 0.00
Transport Err - Dist: 7768.00, Theta: 25.13


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1787.00, Theta: 0.00
Transport Err - Dist: 6889.00, Theta: 14.49


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1958.00, Theta: 0.00
Transport Err - Dist: 7118.00, Theta: 9.60


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2059.00, Theta: 0.00
Transport Err - Dist: 6239.00, Theta: 15.01


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2661.00, Theta: 0.00
Transport Err - Dist: 5786.00, Theta: 13.09


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2057.00, Theta: 0.00
Transport Err - Dist: 6643.00, Theta: 12.92


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1783.00, Theta: 0.00
Transport Err - Dist: 5358.00, Theta: 9.95


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2157.00, Theta: 0.00
Transport Err - Dist: 5377.00, Theta: 3.14


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 3640.00, Theta: 0.00
Transport Err - Dist: 4577.00, Theta: 13.61


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1944.00, Theta: 0.00
Transport Err - Dist: 6034.00, Theta: 5.93


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2372.00, Theta: 0.00
Transport Err - Dist: 4722.00, Theta: 21.99


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1555.00, Theta: 0.00
Transport Err - Dist: 7432.00, Theta: 21.64


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2288.00, Theta: 0.00
Transport Err - Dist: 6094.00, Theta: 27.58


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2123.00, Theta: 0.00
Transport Err - Dist: 4293.00, Theta: 11.52


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2357.00, Theta: 0.00
Transport Err - Dist: 4534.00, Theta: 17.98


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1697.00, Theta: 0.00
Transport Err - Dist: 4847.00, Theta: 16.58


Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1963.00, Theta: 0.00
Transport Err - Dist: 4677.00, Theta: 10.47


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1726.00, Theta: 0.00
Transport Err - Dist: 5641.00, Theta: 27.58


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 2048.00, Theta: 0.00
Transport Err - Dist: 6010.00, Theta: 7.85


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1765.00, Theta: 0.00
Transport Err - Dist: 4543.00, Theta: 6.81


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validating: 0it [00:00, ?it/s]


Attn Err - Dist: 1817.00, Theta: 0.00
Transport Err - Dist: 6195.00, Theta: 5.59


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

