
BACKEND part should be run or the geomstats does not convert correctly


In [1]:
%env GEOMSTATS_BACKEND=pytorch
%load_ext autoreload

%autoreload 2


import trimesh
import pyrender
import numpy as np
import glob 
import h5py
from tqdm import tqdm

import torch
torch.set_default_dtype(torch.float64)
import torch.nn as nn
from torchdiffeq import odeint

env: GEOMSTATS_BACKEND=pytorch


In [2]:
import torch
from torch.utils.data import DataLoader
import sys
import yaml
sys.path.append("..")  
from sefmp.data.grasp_dataset import GraspDataset
from sefmp.core.utils import load_config, get_device
from sefmp.data.grasp_dataset import DataLoader, DataSelector, GraspDataModule
from datetime import datetime

from pathlib import Path


config_path = '../sefmp/configs/sanity_check.yaml'
with open(config_path) as f:
    cfg = yaml.safe_load(f)
device = get_device()

# Setup unique run name if not specified
if cfg["logging"]["run_name"] is None:
    cfg["logging"]["run_name"] = (
        f"sefmp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )


selector = DataSelector(
            grasp_id=cfg["data"].get("grasp_id"),
            object_id=cfg["data"].get("object_id"),
            item_name=cfg["data"].get("item_name"),
        )
grasp_data = GraspDataModule(
            data_root='../data',
            selectors=selector,
            sampler_opt='repeat',  # Using list of selectors
            batch_size=16,
            num_samples=2,  # Optional: limit total samples
        )
grasp_data.setup()

# # Get data loaders
# train_loader = grasp_data.train_dataloader()
# val_loader = grasp_data.val_dataloader()

# # Create a DataLoader instance
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# # Iterate through the DataLoader
# for batch in dataloader:
#     # Process your batch
#     print(batch)

INFO: Processing mesh 1/1: 56f2cfa7d89ef32a5eef6c5d029c7274.obj
INFO: Mesh ../data/meshes/Plant/56f2cfa7d89ef32a5eef6c5d029c7274.obj has changed, recomputing SDF
INFO: Computing SDF for 56f2cfa7d89ef32a5eef6c5d029c7274.obj
INFO: Processing mesh 1/1: 56f2cfa7d89ef32a5eef6c5d029c7274.obj
INFO: Loaded 56f2cfa7d89ef32a5eef6c5d029c7274.obj from cache
INFO: Processing mesh 1/1: 56f2cfa7d89ef32a5eef6c5d029c7274.obj
INFO: Loaded 56f2cfa7d89ef32a5eef6c5d029c7274.obj from cache


In [None]:
from matplotlib import pyplot as plt
import torch
import pytorch_lightning as pl
import wandb
from torch import Tensor
from typing import Tuple
import torch.nn.functional as F
from jaxtyping import Float
from typeguard import typechecked
import torch.nn as nn
import numpy as np

from models.inr import INR


def plot_image(
    mlp_model: INR, device: torch.device
) -> plt.Figure:  # Updated return type hint
    resolution = 28
    x = np.linspace(-1, 1, resolution)
    y = np.linspace(-1, 1, resolution)
    grid_x, grid_y = np.meshgrid(x, y)

    inputs = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)
    inputs_tensor = torch.tensor(inputs, dtype=torch.float32, device=device)

    with torch.no_grad():
        outputs = mlp_model(inputs_tensor).cpu().numpy()

    image = outputs.reshape(resolution, resolution)

    fig, ax = plt.subplots()
    ax.imshow(image, cmap="gray", extent=(-1, 1, -1, 1))
    plt.axis("off")
    return fig


def load_weights_into_inr(weights: Tensor, inr_model: INR) -> INR:
    """Helper function to load weights into INR model."""
    state_dict = {}
    start_idx = 0
    for key, param in inr_model.state_dict().items():
        param_size = param.numel()
        param_data = weights[start_idx : start_idx + param_size].reshape(param.shape)
        state_dict[key] = param_data
        start_idx += param_size
    inr_model.load_state_dict(state_dict)
    return inr_model


def create_reconstruction_visualizations(
    originals: Tensor,
    reconstructions: Tensor,
    inr_model: INR,
    prefix: str,
    batch_idx: int,
    global_step: int,
    is_fixed: bool = False,
) -> dict:
    """Create visualization grid for original-reconstruction pairs."""
    result_dict = {}

    # Create visualizations for each pair
    for i, (orig, recon) in enumerate(zip(originals, reconstructions)):
        # Generate figures
        original_fig = plot_image(load_weights_into_inr(orig, inr_model), orig.device)
        recon_fig = plot_image(load_weights_into_inr(recon, inr_model), recon.device)

        # Add to result dictionary with unique keys
        sample_type = "fixed" if is_fixed else "batch"
        result_dict[f"{prefix}/{sample_type}/original_{i}"] = wandb.Image(original_fig)
        result_dict[f"{prefix}/{sample_type}/reconstruction_{i}"] = wandb.Image(
            recon_fig
        )

        # Close figures
        plt.close(original_fig)
        plt.close(recon_fig)

    return result_dict


class Encoder(nn.Module):
    @typechecked
    def __init__(self, input_dim: int, hidden_dim: int, z_dim: int, **kwargs):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, z_dim)

    @typechecked
    def forward(
        self, x: Float[Tensor, "batch input_dim"]
    ) -> Float[Tensor, "batch z_dim"]:
        x = F.relu(self.fc1(x))
        z = self.fc2(x)
        return z


class Decoder(nn.Module):
    @typechecked
    def __init__(self, z_dim: int, hidden_dim: int, output_dim: int, **kwargs):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    @typechecked
    def forward(
        self, z: Float[Tensor, "batch z_dim"]
    ) -> Float[Tensor, "batch output_dim"]:
        z = F.relu(self.fc1(z))
        x_reconstructed = self.fc2(z)
        return x_reconstructed


class Autoencoder(pl.LightningModule):
    @typechecked
    def __init__(self, config: dict):
        super().__init__()
        self.save_hyperparameters(config)
        self.config = config

        # Initialize encoder and decoder
        self.encoder = Encoder(**config["model"])
        self.decoder = Decoder(**config["model"])

        # Initialize fixed validation and training samples
        self.fixed_val_samples: list[Tensor] | None = None
        self.fixed_train_samples: list[Tensor] | None = None
        self.fixed_sample_reconstructions: dict[str, list[Tensor]] = {}

        # Store optimizer and scheduler config
        self.optimizer_config = config["optimizer"]
        self.scheduler_config = config["scheduler"]

        # Initialize quality metrics
        self.best_val_loss = float("inf")

        # Create demo INR for visualization
        self.demo_inr = INR(up_scale=16)
        # Move demo INR to the same device as the model
        self.demo_inr = self.demo_inr.to(self.device)

    def setup(self, stage: str | None = None):
        """Setup fixed validation and training samples for tracking reconstruction progress."""
        if stage == "fit":
            try:
                num_samples = self.config["logging"]["num_samples_to_visualize"]

                # Setup validation samples
                if (
                    hasattr(self.trainer, "val_dataloaders")
                    and self.trainer.val_dataloaders is not None
                    and self.fixed_val_samples is None
                ):
                    val_batch = next(iter(self.trainer.val_dataloaders[0]))
                    self.fixed_val_samples = val_batch[:num_samples].clone()

                # Setup training samples
                if (
                    hasattr(self.trainer, "train_dataloader")
                    and self.trainer.train_dataloader is not None
                    and self.fixed_train_samples is None
                ):
                    train_batch = next(iter(self.trainer.train_dataloader()))
                    self.fixed_train_samples = train_batch[:num_samples].clone()

            except Exception as e:
                print(f"Warning: Could not setup fixed samples: {e}")
                self.fixed_val_samples = None
                self.fixed_train_samples = None

    @typechecked
    def encode(self, x: Float[Tensor, "batch feature_dim"]) -> Tensor:
        return self.encoder(x)

    @typechecked
    def decode(
        self, z: Float[Tensor, "batch latent_dim"]
    ) -> Float[Tensor, "batch feature_dim"]:
        return self.decoder(z)

    @typechecked
    def forward(
        self, input: Float[Tensor, "batch feature_dim"]
    ) -> Float[Tensor, "batch feature_dim"]:
        z = self.encode(input)
        dec = self.decode(z)
        return dec

    def compute_loss(
        self,
        inputs: Float[Tensor, "batch feature_dim"],
        reconstructions: Float[Tensor, "batch feature_dim"],
        prefix: str = "train",
    ) -> Tuple[Tensor, dict[str, Tensor]]:
        recon_loss = F.mse_loss(reconstructions, inputs)
        return recon_loss, {f"{prefix}/loss": recon_loss}

    def visualize_batch(self, batch: Tensor, prefix: str, batch_idx: int):
        """Visualize a batch of samples during training or validation."""
        if batch_idx % self.config["logging"]["log_every_n_steps"] == 0:
            with torch.no_grad():
                reconstructions = self(batch)

            # Log visualizations for a subset of the batch
            num_samples = min(
                self.config["logging"]["num_samples_to_visualize"], batch.shape[0]
            )  # Visualize up to num_samples_to_visualize
            vis_dict = create_reconstruction_visualizations(
                batch[:num_samples],
                reconstructions[:num_samples],
                self.demo_inr,
                prefix,
                batch_idx,
                self.global_step,
                is_fixed=False,
            )

            # Add step to wandb log
            vis_dict["global_step"] = self.global_step
            self.logger.experiment.log(vis_dict)

    def visualize_reconstructions(self, samples: Tensor, prefix: str, batch_idx: int):
        """Helper method to visualize fixed sample reconstructions during training or validation."""
        if (
            samples is not None
            and batch_idx % self.config["logging"]["log_every_n_steps"] == 0
        ):
            with torch.no_grad():
                reconstructions = self(samples)

            # Store reconstructions for this step
            step_key = f"{prefix}_step_{self.global_step}"
            self.fixed_sample_reconstructions[step_key] = reconstructions

            # Create and log visualizations
            vis_dict = create_reconstruction_visualizations(
                samples,
                reconstructions,
                self.demo_inr,
                prefix,
                batch_idx,
                self.global_step,
                is_fixed=True,
            )

            # Add step to wandb log
            vis_dict["global_step"] = self.global_step
            self.logger.experiment.log(vis_dict)

    @typechecked
    def training_step(
        self, batch: Float[Tensor, "batch feature_dim"], batch_idx: int
    ) -> Tensor:
        # Forward pass
        reconstructions = self(batch)
        loss, log_dict = self.compute_loss(batch, reconstructions, prefix="train")

        # Logging
        self.log_dict(log_dict, prog_bar=True, sync_dist=True)

        # Log gradient norm
        if batch_idx % self.config["trainer"]["log_every_n_steps"] == 0:
            total_norm = 0.0
            for p in self.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm**0.5
            self.log("train/grad_norm", total_norm, prog_bar=False, sync_dist=True)

        # Visualize both fixed samples and current batch
        self.visualize_reconstructions(self.fixed_train_samples, "train", batch_idx)
        self.visualize_batch(batch, "train_batch", batch_idx)

        return loss

    @typechecked
    def validation_step(
        self, batch: Float[Tensor, "batch feature_dim"], batch_idx: int
    ) -> dict[str, Tensor]:
        reconstructions = self(batch)
        val_loss, val_log_dict = self.compute_loss(batch, reconstructions, prefix="val")

        # Log validation metrics
        self.log_dict(val_log_dict, prog_bar=True, sync_dist=True)

        # Visualize both fixed samples and current batch
        if (
            batch_idx == 0
            and self.current_epoch % self.config["logging"]["sample_every_n_epochs"]
            == 0
        ):
            self.visualize_reconstructions(self.fixed_val_samples, "val", batch_idx)
            self.visualize_batch(batch, "val_batch", batch_idx)

        return val_log_dict

    def configure_optimizers(self):
        # Configure optimizer
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.optimizer_config["lr"],
            betas=tuple(self.optimizer_config["betas"]),
            eps=self.optimizer_config["eps"],
            weight_decay=self.optimizer_config["weight_decay"],
        )

        # Configure scheduler
        if self.scheduler_config["name"] == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.scheduler_config["T_max"],
                eta_min=self.scheduler_config["eta_min"],
            )
        else:
            return optimizer

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
            },
        }

tensor([1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1])


In [3]:
train_loader = grasp_data.train_dataloader()
grasp_data.train_dataset.index
for batch in train_loader:
    print(len(batch))
batch[0].shape, batch[1].shape, batch[2].shape

3


(torch.Size([16, 3, 3]), torch.Size([16, 3]), torch.Size([16, 32, 32, 32]))

In [None]:
import sys
sys.path.append("..")  

import os
import numpy as np
import torch
from einops import rearrange
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


from scipy.spatial.transform import Rotation
from geomstats.geometry.special_orthogonal import SpecialOrthogonal


# from utils.plotting import plot_so3
# from utils.optimal_transport import so3_wasserstein as wasserstein
# from FoldFlow.foldflow.utils.so3_helpers import norm_SO3, expmap
# from FoldFlow.foldflow.utils.so3_condflowmatcher import SO3ConditionalFlowMatcher
# from FoldFlow.so3_experiments.models.models import PMLP

from sefmp.models.so3_helpers import norm_SO3, expmap
from sefmp.models.so3_condflowmatcher import SO3ConditionalFlowMatcher
from sefmp.models.pmlp import PMLP


from torch.utils.data import DataLoader,Dataset
# from data.datasets import SpecialOrthogonalGroup

from geomstats._backend import _backend_config as _config
_config.DEFAULT_DTYPE = torch.cuda.FloatTensor 

In [5]:
so3_group = SpecialOrthogonal(n=3, point_type="matrix")
FM = SO3ConditionalFlowMatcher(manifold=so3_group)
def loss_fn(v, u, x):
    res = v - u
    norm = norm_SO3(x, res) # norm-squared on SO(3)
    loss = torch.mean(norm, dim=-1)
    return loss

dim = 9 # network ouput is 9 dimensional (3x3 matrix)
# MLP with a projection at the end, projection on to the tanget space of the manifold
model = PMLP(dim=dim, time_varying=True) 
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
device = 'cpu'
# ODE inference on SO(3)
def inference(model, xt, t, dt):
    with torch.no_grad():
        vt = model(torch.cat([xt, t[:, None]], dim=-1)) # vt on the tanget of xt
        vt = rearrange(vt, 'b (c d) -> b c d', c=3, d=3)
        xt = rearrange(xt, 'b (c d) -> b c d', c=3, d=3)
        xt_new = expmap(xt, vt * dt)                   # expmap to get the next point
    return rearrange(xt_new, 'b c d -> b (c d)', c=3, d=3)
# def inference_recursive(model, x_0, steps=100, device='cuda'):
#     t = torch.linspace(0, 1, steps).to(device)
#     def ode_func(t, xt,dt):
#         # Reshape t to match model input expectations
#         t_batch = torch.full((x.shape[0], 1), t.item(), device=device)
#         with torch.no_grad():
#             vt = model(torch.cat([xt, t[:, None]], dim=-1)) # vt on the tanget of xt
#             vt = rearrange(vt, 'b (c d) -> b c d', c=3, d=3)
#             xt = rearrange(xt, 'b (c d) -> b c d', c=3, d=3)
#             xt_new = expmap(xt, vt * dt)                   # expmap to get the next point
#         return rearrange(xt_new, 'b c d -> b (c d)', c=3, d=3)
#         #return flow_model(x, t_batch)
#     # Integrate from t=0 to t=1
#     trajectory = odeint(
#         ode_func,
#         x_0,
#         t,
#         method='rk4'  # You can also try 'dopri5' for adaptive stepping
#     )
    
#     return trajectory

In [7]:
meshes = glob.glob("../data/meshes/**/*.obj")
grasps = glob.glob("../data/grasps/*.h5")
example_obj= meshes[0]
example_grasp = grasps[0]


example_obj_id = example_obj.split("/")[-1].split(".")[0]
print("Example obj: ", example_obj)

corresponding_grasps = [grasp for grasp in grasps if example_obj_id in grasp][0]

with h5py.File(example_grasp, 'r') as h5file:
    grasp_T = h5file['grasps']['transforms'][0,:,:]
grasp_T = torch.tensor(grasp_T).unsqueeze(0).float()
grasp_T.shape


Example obj:  ../data/meshes/CerealBox/a61cd12446207107d59ff053d1480d84.obj


torch.Size([1, 4, 4])

In [8]:
class GraspDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        so3_part = self.data[idx][:3,:3]
        translational_part = self.data[idx][:3,3]
        return so3_part, translational_part
    
grasp_dataset = GraspDataset(grasp_T.double())
trainloader = DataLoader(grasp_dataset, batch_size=100, shuffle=True)
testset = DataLoader(grasp_dataset, batch_size=100, shuffle=False)

In [9]:
def main_loop(model, optimizer, num_epochs=1, display=True):
    losses = []
    global_step = 0
    
    # Create a single progress bar for all epochs
    with tqdm(total=num_epochs * len(trainloader), desc="Training") as global_progress_bar:
        for epoch in range(num_epochs):
            epoch_losses = []
            
            if (epoch % 10) == 0:
                n_test = len(testset.dataset)
                traj = torch.tensor(Rotation.random(n_test).as_matrix()).to(device).reshape(-1, 9)
                for t in torch.linspace(0, 1, 200):
                    t = torch.tensor([t]).to(device).repeat(n_test).requires_grad_(True)
                    dt = torch.tensor([1/200]).to(device)
                    traj = inference(model, traj, t, dt)
                final_traj = rearrange(traj, 'b (c d) -> b c d', c=3, d=3)
            
            for _, (so3_data, trnslt_part) in enumerate(trainloader):
                optimizer.zero_grad()
                
                # Repeat the data if needed
                so3_data = so3_data.repeat(1000, 1, 1)
                x1 = so3_data.to(device).double()
                x0 = torch.tensor(Rotation.random(x1.size(0)).as_matrix(), dtype=torch.float64).to(device)
                
                t, xt, ut = FM.sample_location_and_conditional_flow_simple(x0, x1)
                
                vt = model(torch.cat([rearrange(xt, 'b c d -> b (c d)', c=3, d=3), t[:, None]], dim=-1))
                vt = rearrange(vt, 'b (c d) -> b c d', c=3, d=3)
                
                loss = loss_fn(vt, ut, xt)
                epoch_losses.append(loss.detach().item())
                losses.append(loss.detach().cpu().numpy())
                
                loss.backward()
                optimizer.step()
                
                # Update the global progress bar
                global_progress_bar.update(1)
                global_progress_bar.set_postfix({
                    'Epoch': epoch, 
                    'Loss': f'{loss.item():.4f}', 
                    'Avg Loss': f'{np.mean(epoch_losses):.4f}'
                })
                
                global_step += 1
    
    return model, np.array(losses)

# Run training
model, losses = main_loop(model, optimizer, num_epochs=1000, display=True)

Training: 100%|██████████| 1000/1000 [02:11<00:00,  7.61it/s, Epoch=999, Loss=0.0863, Avg Loss=0.0863]


In [20]:
n_test = len(grasp_dataset)
traj = torch.tensor(Rotation.random(n_test).as_matrix()).reshape(-1, 9)
for t in torch.linspace(0, 1, 200):
    t = torch.tensor([t],dtype=torch.float64).repeat(n_test)
    dt = torch.tensor([1/200])
    traj = inference(model, traj, t, dt)
final_traj = rearrange(traj, 'b (c d) -> b c d', c=3, d=3)
final_traj,grasp_dataset.data[:3,:3]

(tensor([[[ 0.2007,  0.3251, -0.9241],
          [ 0.9171,  0.2694,  0.2940],
          [ 0.3445, -0.9065, -0.2440]]]),
 tensor([[[-0.3808, -0.3274,  0.8647, -0.2228],
          [-0.1488,  0.9447,  0.2922, -0.1666],
          [-0.9126, -0.0174, -0.4085,  0.1002]]]))

In [21]:
class SE3VelocityField(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=64): #trial for translation 
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(input_dim + 1, hidden_dim),  # Include time t as dim+1
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)  # 3 for translation, we will implement SO3 later rt = expr0(tlogr0(r1)) with linalg inverse 
        )

    def forward(self, T, t):
        #T_flat = T.view(T.shape[0], -1)  # Flatten T
        input_data = torch.cat([T,t ], dim=1)
        return self.net(input_data)


In [22]:
def conditional_flow_matching_loss(flow_model, x):
    #Question: Should we calculate one for each time step or generate one time at a time?
    
    sigma_min = 1e-4
    t = torch.rand(x.shape[0], device=x.device).unsqueeze(-1)
    noise = torch.randn_like(x).to(x.device)

    x_t = (1 - (1 - sigma_min) * t) * noise + t* x
    optimal_flow = x - (1 - sigma_min) * noise
    predicted_flow = flow_model(x_t, t)

    return (predicted_flow - optimal_flow).square().mean()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SE3VelocityField().to(device)
x = grasp_T[:, :3, 3]
x_train = x.repeat(1000, 1).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100000):
    model.zero_grad()
    loss = conditional_flow_matching_loss(model,x_train)
    if epoch % 100 == 0:
        print(f'Epoch: {epoch}',loss.item())
    loss.backward()
    optimizer.step()



Epoch: 0 1.0780157394451935
Epoch: 100 0.1735529375801334
Epoch: 200 0.15793626708497058
Epoch: 300 0.1419872366921193
Epoch: 400 0.12431893279608217
Epoch: 500 0.09352229379606569
Epoch: 600 0.09848773438141888
Epoch: 700 0.07494786140355303
Epoch: 800 0.07107831549405334
Epoch: 900 0.06115682387155495
Epoch: 1000 0.06291244604575141
Epoch: 1100 0.052585109247082984
Epoch: 1200 0.05301342046876092
Epoch: 1300 0.062307890074896344
Epoch: 1400 0.050254247895819185
Epoch: 1500 0.04238361949063993
Epoch: 1600 0.056443570236101824
Epoch: 1700 0.0417227896832179
Epoch: 1800 0.049891370174069756
Epoch: 1900 0.03843867549317297
Epoch: 2000 0.04052325279315961
Epoch: 2100 0.05866970867426109
Epoch: 2200 0.03767757816371865
Epoch: 2300 0.040966510002591425
Epoch: 2400 0.03880663284367613
Epoch: 2500 0.039523345338779166
Epoch: 2600 0.037861939227961834
Epoch: 2700 0.03878776995874324
Epoch: 2800 0.03329922798936867
Epoch: 2900 0.034159328225875
Epoch: 3000 0.029399214354817498
Epoch: 3100 0.028

KeyboardInterrupt: 

In [10]:
def run_flow(flow_model, x_0, steps=100, device='cuda'):
    t = torch.linspace(0, 1, steps).to(device)
    def ode_func(t, x):
        # Reshape t to match model input expectations
        t_batch = torch.full((x.shape[0], 1), t.item(), device=device)
        return flow_model(x, t_batch)
    # Integrate from t=0 to t=1
    trajectory = odeint(
        ode_func,
        x_0,
        t,
        method='rk4'  # You can also try 'dopri5' for adaptive stepping
    )
    
    return trajectory

noise = torch.randn_like(grasp_T[:,:3,3]).to(device)
trajectory = run_flow(model, noise, steps=100, device=device)
print(trajectory[-1],grasp_T[:,:3,3])

tensor([[-0.2271, -0.1689,  0.1003]], grad_fn=<SelectBackward0>) tensor([[-0.2228, -0.1666,  0.1002]])


In [5]:
import h5py

def explore_group(group, indent=""):
    """Recursively explore an HDF5 group and its contents"""
    for name, item in group.items():
        if isinstance(item, h5py.Group):
            print(f"{indent}Group: {name}")
            print(f"{indent}  Contents: {list(item.keys())}")
            explore_group(item, indent + "  ")
        elif isinstance(item, h5py.Dataset):
            print(f"{indent}Dataset: {name}")
            print(f"{indent}  Shape: {item.shape}")
            print(f"{indent}  Type: {item.dtype}")
            try:
                print(f"{indent}  First few values: {item[:2]}")
            except Exception as e:
                print(f"{indent}  Could not print values: {e}")

with h5py.File(example_grasp, 'r') as h5file:
    print("Top-level groups:", list(h5file.keys()))
    
    print("\nExploring complete structure:")
    for top_group_name in h5file.keys():
        print(f"\n=== {top_group_name} ===")
        top_group = h5file[top_group_name]
        explore_group(top_group)

Top-level groups: ['grasps', 'gripper', 'object']

Exploring complete structure:

=== grasps ===
Group: qualities
  Contents: ['flex']
  Group: flex
    Contents: ['object_in_gripper', 'object_motion_during_closing_angular', 'object_motion_during_closing_linear', 'object_motion_during_shaking_angular', 'object_motion_during_shaking_linear']
    Dataset: object_in_gripper
      Shape: (2000,)
      Type: int64
      First few values: [1 1]
    Dataset: object_motion_during_closing_angular
      Shape: (2000,)
      Type: float64
      First few values: [0.54927611 0.15047622]
    Dataset: object_motion_during_closing_linear
      Shape: (2000,)
      Type: float64
      First few values: [0.09023842 0.02939418]
    Dataset: object_motion_during_shaking_angular
      Shape: (2000,)
      Type: float64
      First few values: [0.0185062  0.04676599]
    Dataset: object_motion_during_shaking_linear
      Shape: (2000,)
      Type: float64
      First few values: [0.00300531 0.00718987]
Dat