In [1]:
pip install mani_skill

Collecting mani_skill
  Downloading mani_skill-3.0.0b15-py3-none-any.whl.metadata (3.2 kB)
Collecting gymnasium==0.29.1 (from mani_skill)
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting sapien==3.0.0.b1 (from mani_skill)
  Downloading sapien-3.0.0b1-cp310-cp310-manylinux2014_x86_64.whl.metadata (10 kB)
Collecting transforms3d (from mani_skill)
  Downloading transforms3d-0.4.2-py3-none-any.whl.metadata (2.8 kB)
Collecting trimesh (from mani_skill)
  Downloading trimesh-4.5.3-py3-none-any.whl.metadata (18 kB)
Collecting mplib==0.1.1 (from mani_skill)
  Downloading mplib-0.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Collecting fast_kinematics==0.2.2 (from mani_skill)
  Downloading fast_kinematics-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting pytorch_kinematics==0.7.4 (from mani_skill)
  Downloading pytorch_kinematics-0.7.4-py3-none-any.whl.metadata (14 kB)
Collecting pynvml (fr

In [2]:
import os
from glob import glob
import json
import h5py
from tqdm.notebook import tqdm
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

import gymnasium as gym
import mani_skill.envs
from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from mani_skill.utils.registration import TimeLimitWrapper

  import pkg_resources
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
  warn("Failed to find system lib

In [3]:
def yield_mini_batches(data_folder):
    all_dirs = os.listdir(data_folder)
    random.shuffle(all_dirs)

    buffer_episode = None  

    for curr_dir in all_dirs:
        print(f"Processing: {curr_dir}")
        sub_dir_path = os.path.join(data_folder, curr_dir, curr_dir)
        
        h5_files = [f for f in os.listdir(sub_dir_path) if f.endswith(".h5")]
        random.shuffle(h5_files)

        for h5f in h5_files:
            h5_path = os.path.join(sub_dir_path, h5f)
            
            episode_generator = load_trajectories(h5_path)
            for episode in episode_generator:
                if buffer_episode is not None:
                    rgb = torch.cat([buffer_episode["rgb"], episode["rgb"]], dim=0)
                    acts = torch.cat([buffer_episode["actions"], episode["actions"]], dim=0)
                    buffer_episode = None 
                else:
                    buffer_episode = episode
                    continue

                mini_batch = {
                    "rgb": rgb,
                }
                mini_acts = acts
                yield mini_batch, mini_acts

                del mini_batch, mini_acts, rgb, state, acts

    if buffer_episode is not None:
        mini_batch = {
            "rgb": buffer_episode["rgb"],
        }
        mini_acts = buffer_episode["actions"]
        yield mini_batch, mini_acts

        del mini_batch, mini_acts, buffer_episode
        
def load_trajectories(h5_path):
    json_path = h5_path.replace(".h5", ".json")

    with open(json_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    
    with h5py.File(h5_path, "r") as hf:
        for i, episode in enumerate(data["episodes"]):
            if episode["success"]:
                group_name = f"traj_{i}"             
                grp = hf[group_name]

                rgb_np = grp["obs"]["rgb"][:-1] 
                rgb_np = np.transpose(rgb_np, (0, 3, 1, 2)) 
                rgb_tensor = torch.from_numpy(rgb_np).float()

                acts_np = grp["actions"][:]
                acts_tensor = torch.from_numpy(acts_np).float()

                episode_data = {
                    "rgb": rgb_tensor,
                    "actions": acts_tensor
                }
                yield episode_data

  and should_run_async(code)


In [4]:
class Logger:
    def __init__(self, log_wandb=False):
        self.log_wandb = log_wandb
    def add_scalar(self, tag, scalar_value, step):
        if self.log_wandb:
            wandb.log({tag: scalar_value}, step=step)

In [6]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        cnn = models.resnet18(pretrained=True)
        cnn.fc = nn.Identity()
        
        self.extractors = nn.Sequential(cnn, nn.ReLU())
        self.mlp = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 8),
            nn.Tanh()
        )
    
    def forward(self, observations):
        obs = observations["rgb"]
        if obs.shape[3] == 6:
           obs = obs.float().permute(0,3,1,2)
        obs = obs / 255
        feats_1 = extractor(obs[:, :3, :, :]) 
        feats_2 = extractor(obs[:, 3:, :, :])
        feats = torch.concat((feats_1, feats_2), dim=1)

        return self.mlp(feats)

In [7]:
def evaluate(model, eval_envs, logger, global_step):
    print(f"Evaluating at {global_step} steps")
    model.eval()
    eval_obs, _ = eval_envs.reset()
    eval_metrics = defaultdict(list)
    for _ in tqdm(range(80)):
        with torch.no_grad():
            eval_obs, eval_rew, eval_terminations, eval_truncations, eval_infos = eval_envs.step(model(eval_obs))
            if "final_info" in eval_infos:
                mask = eval_infos["_final_info"]
                for k, v in eval_infos["final_info"]["episode"].items():
                    eval_metrics[k].append(v)
    for k, v in eval_metrics.items():
        mean = torch.stack(v).float().mean()
        logger.add_scalar(f"eval/{k}", mean, global_step)
        print(f"eval_{k}_mean={mean}")

In [8]:
def train(
          data_folder,
          device,
          lr=1e-4,
          checkpoint_dir="checkpoints",
          eval_freq = 10
):
    os.makedirs(checkpoint_dir, exist_ok=True)
    print("Running training")

    env_kwargs = dict(obs_mode="rgb", render_mode="all", sim_backend="gpu")
    eval_envs = gym.make("StackCube-v1", num_envs=100, **env_kwargs)
    eval_envs = TimeLimitWrapper(eval_envs, max_episode_steps=80)
    eval_envs = FlattenRGBDObservationWrapper(eval_envs, rgb=True, depth=False, state=True)
    eval_envs = ManiSkillVectorEnv(eval_envs, 100, ignore_terminations=True, record_metrics=True)
    
    wandb.login(key="")
    config = {}
    config["eval_env_cfg"] = dict(
                                  **env_kwargs, 
                                  num_envs=100, 
                                  env_id="StackCube-v1", 
                                  reward_mode="normalized_dense", 
                                  env_horizon=80
    )
    wandb.init(
                project="pretrain",
                sync_tensorboard=False,
                config=config,
                name="pretrain",
                save_code=True,
    )
    logger = Logger(log_wandb=True)
    
    model = FeatureExtractor().to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    global_step = 0
    for i, (obs_chunk, acts_chunk) in enumerate(yield_mini_batches(data_folder)):
        if i % eval_freq == 1:
            evaluate(model, eval_envs, logger, global_step)
            model.train()
        global_step += 160
        for _ in range(3):
            obs_chunk["rgb"] = obs_chunk["rgb"].to(device)
            obs_chunk["state"] = obs_chunk["state"].to(device)
            acts_chunk = acts_chunk.to(device)

            pred_actions = model(obs_chunk)  
            loss = criterion(pred_actions, acts_chunk)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        logger.add_scalar("train/loss", loss.item(), global_step)

        del obs_chunk, acts_chunk, pred_actions
        torch.cuda.empty_cache()

    ckpt_path = f"model.pth"
    torch.save({
        "epoch": 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }, ckpt_path)
    print(f"Checkpoint saved: {ckpt_path}")

    print("Training finished.")

In [None]:
data_folder = "/kaggle/input/trajectories-dataset"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train(
    data_folder=data_folder, 
    device=device, 
    lr=1e-4, 
    checkpoint_dir="checkpoints", 
    eval_freq=50
)

Running training
Downloading PhysX GPU library to /root/.sapien/physx/105.1-physx-5.3.1.patch0 from Github. This can take several minutes. If it fails to download, please manually download fhttps://github.com/sapien-sim/physx-precompiled/releases/download/105.1-physx-5.3.1.patch0/linux-so.zip and unzip at /root/.sapien/physx/105.1-physx-5.3.1.patch0.
Download complete.


  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtolya1111[0m ([33mtolya1111-lomonosov-moscow-state-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
/usr/local/lib/python3.10/dist-packages/pydantic/main.py:1309: PydanticDeprecatedSince20: The `copy` method is deprecated; use `model_copy` instead. See the docstring of `BaseModel.copy` for details about how to handle `include` and `exclude`. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 149MB/s] 


Processing: traj_3
Evaluating at 160 steps


  0%|          | 0/80 [00:00<?, ?it/s]

eval_success_once_mean=0.0
eval_return_mean=2.1634199619293213
eval_episode_len_mean=80.0
eval_reward_mean=0.027042750269174576
eval_success_at_end_mean=0.0
Evaluating at 8160 steps


  0%|          | 0/80 [00:00<?, ?it/s]

eval_success_once_mean=0.0
eval_return_mean=9.145692825317383
eval_episode_len_mean=80.0
eval_reward_mean=0.11432116478681564
eval_success_at_end_mean=0.0
Processing: traj_8
Evaluating at 16160 steps


  0%|          | 0/80 [00:00<?, ?it/s]

eval_success_once_mean=0.0
eval_return_mean=14.242785453796387
eval_episode_len_mean=80.0
eval_reward_mean=0.17803481221199036
eval_success_at_end_mean=0.0
Processing: traj_18
Evaluating at 24160 steps


  0%|          | 0/80 [00:00<?, ?it/s]

eval_success_once_mean=0.009999999776482582
eval_return_mean=20.734220504760742
eval_episode_len_mean=80.0
eval_reward_mean=0.2591777443885803
eval_success_at_end_mean=0.0
Evaluating at 32160 steps


  0%|          | 0/80 [00:00<?, ?it/s]

eval_success_once_mean=0.0
eval_return_mean=18.621978759765625
eval_episode_len_mean=80.0
eval_reward_mean=0.2327747344970703
eval_success_at_end_mean=0.0
Evaluating at 40160 steps


  0%|          | 0/80 [00:00<?, ?it/s]

eval_success_once_mean=0.009999999776482582
eval_return_mean=19.915142059326172
eval_episode_len_mean=80.0
eval_reward_mean=0.24893926084041595
eval_success_at_end_mean=0.0
Processing: traj_16
Evaluating at 48160 steps


  0%|          | 0/80 [00:00<?, ?it/s]

eval_success_once_mean=0.03999999910593033
eval_return_mean=22.920148849487305
eval_episode_len_mean=80.0
eval_reward_mean=0.28650185465812683
eval_success_at_end_mean=0.0
