# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""X-MAGICAL cross-embodiment pretraining script."""


In [1]:

import os.path as osp
import subprocess

from absl import app
from absl import logging
from configs.constants import ALGORITHMS
from configs.constants import EMBODIMENTS
from torchkit.experiment import string_from_kwargs
from torchkit.experiment import unique_id
import yaml
import random


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# pylint: disable=logging-fstring-interpolation

# Mapping from pretraining algorithm to config file.
ALGO_TO_CONFIG = {
    "xirl": "configs/xmagical/pretraining/tcc.py",
    "lifs": "configs/xmagical/pretraining/lifs.py",
    "tcn": "configs/xmagical/pretraining/tcn.py",
    "goal_classifier": "configs/xmagical/pretraining/classifier.py",
    "raw_imagenet": "configs/xmagical/pretraining/imagenet.py",
}

# We want to pretrain on the entire 1k demonstrations.
MAX_DEMONSTRATIONS = -1


In [3]:
embodiment = "shortstick"

algo = "xirl"

unique_name = True,

random_number = random.randint(1, 1000)  # You can adjust the range as needed
experiment_name = f"/home/user/xirl/exp{random_number}"



In [4]:


embodiments = EMBODIMENTS if embodiment is None else [embodiment]

for embodiment in embodiments:
    # Generate a unique experiment name.
    kwargs = {
        "dataset": "xmagical",
        "mode": "cross",
        "algo": algo,
        "embodiment": embodiment,
    }
    if unique_name:
      kwargs["uid"] = unique_id()
    logging.info("Experiment name: %s", experiment_name)
    
    # Train on all classes but the given embodiment.
    trainable_embs = tuple(EMBODIMENTS - set([embodiment]))
    
    subprocess.run(
        [
            "python",
            "pretrain.py",
            "--experiment_name",
            experiment_name,
            # "--raw_imagenet" if algo == "raw_imagenet" else "",
            "--raw_imagenet" if algo == "raw_imagenet" else None,  # Use None instead of ""
            "--config",
            f"{ALGO_TO_CONFIG[algo]}",
            "--config.data.pretrain_action_class",
            f"{repr(trainable_embs)}",
            "--config.data.downstream_action_class",
            f"{repr(trainable_embs)}",
            "--config.data.max_vids_per_class",
            f"{MAX_DEMONSTRATIONS}",
        ],
        check=True,
        stdout=subprocess.PIPE,  # Capture standard output
        stderr=subprocess.PIPE   # Capture standard error
    )
    
    # Note: This assumes that the config.root_dir value has not been
    # changed to its default value of 'tmp/xirl/pretrain_runs/'.
    exp_path = osp.join("/tmp/xirl/pretrain_runs/", experiment_name)
    
    # The 'goal_classifier' baseline does not need to compute a goal embedding.
    if algo != "goal_classifier":
      subprocess.run(
          [
              "python",
              "compute_goal_embedding.py",
              "--experiment_path",
              exp_path,
          ],
          check=True,
      )
    
    # Dump experiment metadata as yaml file.
    with open(osp.join(exp_path, "metadata.yaml"), "w") as fp:
      yaml.dump(kwargs, fp)



CalledProcessError: Command '['python', 'pretrain.py', '--experiment_name', '/home/user/xirl/exp956', '', '--config', 'configs/xmagical/pretraining/tcc.py', '--config.data.pretrain_action_class', "('longstick', 'mediumstick', 'gripper')", '--config.data.downstream_action_class', "('longstick', 'mediumstick', 'gripper')", '--config.data.max_vids_per_class', '-1']' returned non-zero exit status 1.