
BACKEND part should be run or the geomstats does not convert correctly


In [1]:
%env GEOMSTATS_BACKEND=pytorch
%load_ext autoreload
%autoreload 2

import trimesh
import pyrender
import numpy as np
import glob 
import h5py
from tqdm import tqdm
import torch
import torch.nn as nn
from torchdiffeq import odeint

env: GEOMSTATS_BACKEND=pytorch


In [2]:
import sys
sys.path.append("..")  

import os
import numpy as np
import torch
from einops import rearrange
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


from scipy.spatial.transform import Rotation
from geomstats.geometry.special_orthogonal import SpecialOrthogonal

# from utils.plotting import plot_so3
# from utils.optimal_transport import so3_wasserstein as wasserstein
# from FoldFlow.foldflow.utils.so3_helpers import norm_SO3, expmap
# from FoldFlow.foldflow.utils.so3_condflowmatcher import SO3ConditionalFlowMatcher
# from FoldFlow.so3_experiments.models.models import PMLP

from so3_helpers import norm_SO3, expmap
from so3_condflowmatcher import SO3ConditionalFlowMatcher
from models import PMLP

from torch.utils.data import DataLoader,Dataset
# from data.datasets import SpecialOrthogonalGroup

from geomstats._backend import _backend_config as _config
_config.DEFAULT_DTYPE = torch.cuda.FloatTensor 

In [3]:
so3_group = SpecialOrthogonal(n=3, point_type="matrix")
FM = SO3ConditionalFlowMatcher(manifold=so3_group)
def loss_fn(v, u, x):
    res = v - u
    norm = norm_SO3(x, res) # norm-squared on SO(3)
    loss = torch.mean(norm, dim=-1)
    return loss

dim = 9 # network ouput is 9 dimensional (3x3 matrix)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MLP with a projection at the end, projection on to the tanget space of the manifold
model = PMLP(dim=dim, time_varying=True).to(device)  
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# ODE inference on SO(3)
def inference(model, xt, t, dt):
    with torch.no_grad():
        vt = model(torch.cat([xt, t[:, None]], dim=-1)) # vt on the tanget of xt
        vt = rearrange(vt, 'b (c d) -> b c d', c=3, d=3)
        xt = rearrange(xt, 'b (c d) -> b c d', c=3, d=3)
        xt_new = expmap(xt, vt * dt)                   # expmap to get the next point
    return rearrange(xt_new, 'b c d -> b (c d)', c=3, d=3)
# def inference_recursive(model, x_0, steps=100, device='cuda'):
#     t = torch.linspace(0, 1, steps).to(device)
#     def ode_func(t, xt,dt):
#         # Reshape t to match model input expectations
#         t_batch = torch.full((x.shape[0], 1), t.item(), device=device)
#         with torch.no_grad():
#             vt = model(torch.cat([xt, t[:, None]], dim=-1)) # vt on the tanget of xt
#             vt = rearrange(vt, 'b (c d) -> b c d', c=3, d=3)
#             xt = rearrange(xt, 'b (c d) -> b c d', c=3, d=3)
#             xt_new = expmap(xt, vt * dt)                   # expmap to get the next point
#         return rearrange(xt_new, 'b c d -> b (c d)', c=3, d=3)
#         #return flow_model(x, t_batch)
#     # Integrate from t=0 to t=1
#     trajectory = odeint(
#         ode_func,
#         x_0,
#         t,
#         method='rk4'  # You can also try 'dopri5' for adaptive stepping
#     )
    
#     return trajectory

In [4]:
meshes = glob.glob("data/meshes/**/*.obj")
grasps = glob.glob("data/grasps/*.h5")
example_obj= meshes[0]
example_grasp = grasps[0]


example_obj_id = example_obj.split("/")[-1].split(".")[0]
print("Example obj: ", example_obj)

corresponding_grasps = [grasp for grasp in grasps if example_obj_id in grasp][0]

with h5py.File(example_grasp, 'r') as h5file:
    grasp_T = h5file['grasps']['transforms'][0,:,:]
grasp_T = torch.tensor(grasp_T).unsqueeze(0).float()
grasp_T.shape


Example obj:  data/meshes/CerealBox/a61cd12446207107d59ff053d1480d84.obj


torch.Size([1, 4, 4])

In [5]:
class GraspDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        so3_part = self.data[idx][:3,:3]
        translational_part = self.data[idx][:3,3]
        return so3_part, translational_part
    
grasp_dataset = GraspDataset(grasp_T.double())
trainloader = DataLoader(grasp_dataset, batch_size=100, shuffle=True)
testset = DataLoader(grasp_dataset, batch_size=100, shuffle=False)

In [None]:
def main_loop(model, optimizer, num_epochs=150, display=True):
    losses = []
    global_step = 0
    
    # Create a single progress bar for all epochs
    with tqdm(total=num_epochs * len(trainloader), desc="Training") as global_progress_bar:
        for epoch in range(num_epochs):
            epoch_losses = []
            
            if (epoch % 10) == 0:
                n_test = len(testset.dataset)
                traj = torch.tensor(Rotation.random(n_test).as_matrix()).to(device).reshape(-1, 9)
                for t in torch.linspace(0, 1, 200):
                    t = torch.tensor([t]).to(device).repeat(n_test).requires_grad_(True)
                    dt = torch.tensor([1/200]).to(device)
                    traj = inference(model, traj, t, dt)
                final_traj = rearrange(traj, 'b (c d) -> b c d', c=3, d=3)
            
            for _, (so3_data, trnslt_part) in enumerate(trainloader):
                optimizer.zero_grad()
                
                # Repeat the data if needed
                so3_data = so3_data.repeat(1000, 1, 1)
                x1 = so3_data.to(device).double()
                x0 = torch.tensor(Rotation.random(x1.size(0)).as_matrix(), dtype=torch.float64).to(device)
                
                t, xt, ut = FM.sample_location_and_conditional_flow_simple(x0, x1)
                
                vt = model(torch.cat([rearrange(xt, 'b c d -> b (c d)', c=3, d=3), t[:, None]], dim=-1))
                vt = rearrange(vt, 'b (c d) -> b c d', c=3, d=3)
                
                loss = loss_fn(vt, ut, xt)
                epoch_losses.append(loss.detach().item())
                losses.append(loss.detach().cpu().numpy())
                
                loss.backward()
                optimizer.step()
                
                # Update the global progress bar
                global_progress_bar.update(1)
                global_progress_bar.set_postfix({
                    'Epoch': epoch, 
                    'Loss': f'{loss.item():.4f}', 
                    'Avg Loss': f'{np.mean(epoch_losses):.4f}'
                })
                
                global_step += 1
    
    return model, np.array(losses)

# Run training
model, losses = main_loop(model, optimizer, num_epochs=1000, display=True)

Training: 100%|██████████| 1000/1000 [01:25<00:00, 11.73it/s, Epoch=999, Loss=0.0735, Avg Loss=0.0735]


In [7]:
n_test = len(grasp_dataset)
traj = torch.tensor(Rotation.random(n_test).as_matrix()).to(device).reshape(-1, 9)
for t in torch.linspace(0, 1, 200):
    t = torch.tensor([t]).to(device).repeat(n_test)
    dt = torch.tensor([1/200]).to(device)
    traj = inference(model, traj, t, dt)
final_traj = rearrange(traj, 'b (c d) -> b c d', c=3, d=3)
final_traj,grasp_dataset.data[:3,:3]

(tensor([[[-0.3679, -0.3304,  0.8692],
          [-0.1471,  0.9437,  0.2964],
          [-0.9181, -0.0188, -0.3958]]], dtype=torch.float64),
 tensor([[[-0.3808, -0.3274,  0.8647, -0.2228],
          [-0.1488,  0.9447,  0.2922, -0.1666],
          [-0.9126, -0.0174, -0.4085,  0.1002]]], dtype=torch.float64))

In [8]:
class SE3VelocityField(nn.Module):
    def __init__(self, input_dim=3, hidden_dim=64): #trial for translation 
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(input_dim + 1, hidden_dim),  # Include time t as dim+1
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)  # 3 for translation, we will implement SO3 later rt = expr0(tlogr0(r1)) with linalg inverse 
        )

    def forward(self, T, t):
        #T_flat = T.view(T.shape[0], -1)  # Flatten T
        input_data = torch.cat([T,t ], dim=1)
        return self.net(input_data)


In [9]:
def conditional_flow_matching_loss(flow_model, x):
    #Question: Should we calculate one for each time step or generate one time at a time?
    
    sigma_min = 1e-4
    t = torch.rand(x.shape[0], device=x.device).unsqueeze(-1)
    noise = torch.randn_like(x).to(x.device)

    x_t = (1 - (1 - sigma_min) * t) * noise + t* x
    optimal_flow = x - (1 - sigma_min) * noise
    predicted_flow = flow_model(x_t, t)

    return (predicted_flow - optimal_flow).square().mean()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SE3VelocityField().to(device)
x = grasp_T[:, :3, 3]
x_train = x.repeat(1000, 1).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100000):
    model.zero_grad()
    loss = conditional_flow_matching_loss(model,x_train)
    if epoch % 100 == 0:
        print(f'Epoch: {epoch}',loss.item())
    loss.backward()
    optimizer.step()



Epoch: 0 1.0490559339523315
Epoch: 100 0.18860095739364624
Epoch: 200 0.17464306950569153
Epoch: 300 0.14537473022937775
Epoch: 400 0.10317289084196091
Epoch: 500 0.09229150414466858
Epoch: 600 0.09045705199241638
Epoch: 700 0.07828841358423233
Epoch: 800 0.07142215222120285
Epoch: 900 0.06375863403081894
Epoch: 1000 0.06571261584758759
Epoch: 1100 0.055483411997556686
Epoch: 1200 0.05109038203954697
Epoch: 1300 0.05603230372071266
Epoch: 1400 0.04873558506369591
Epoch: 1500 0.04899878799915314
Epoch: 1600 0.049523208290338516
Epoch: 1700 0.05539128556847572
Epoch: 1800 0.04877127707004547
Epoch: 1900 0.05637989193201065
Epoch: 2000 0.04719003662467003
Epoch: 2100 0.03758149594068527
Epoch: 2200 0.04356033727526665
Epoch: 2300 0.034619301557540894
Epoch: 2400 0.040414467453956604
Epoch: 2500 0.034439023584127426
Epoch: 2600 0.036813341081142426
Epoch: 2700 0.0403117835521698
Epoch: 2800 0.03305071219801903
Epoch: 2900 0.031559381633996964
Epoch: 3000 0.030370160937309265
Epoch: 3100 0.

In [10]:
def run_flow(flow_model, x_0, steps=100, device='cuda'):
    t = torch.linspace(0, 1, steps).to(device)
    def ode_func(t, x):
        # Reshape t to match model input expectations
        t_batch = torch.full((x.shape[0], 1), t.item(), device=device)
        return flow_model(x, t_batch)
    # Integrate from t=0 to t=1
    trajectory = odeint(
        ode_func,
        x_0,
        t,
        method='rk4'  # You can also try 'dopri5' for adaptive stepping
    )
    
    return trajectory

noise = torch.randn_like(grasp_T[:,:3,3]).to(device)
trajectory = run_flow(model, noise, steps=100, device=device)
print(trajectory[-1],grasp_T[:,:3,3])

tensor([[-0.2271, -0.1689,  0.1003]], grad_fn=<SelectBackward0>) tensor([[-0.2228, -0.1666,  0.1002]])


In [11]:
import h5py

def explore_group(group, indent=""):
    """Recursively explore an HDF5 group and its contents"""
    for name, item in group.items():
        if isinstance(item, h5py.Group):
            print(f"{indent}Group: {name}")
            print(f"{indent}  Contents: {list(item.keys())}")
            explore_group(item, indent + "  ")
        elif isinstance(item, h5py.Dataset):
            print(f"{indent}Dataset: {name}")
            print(f"{indent}  Shape: {item.shape}")
            print(f"{indent}  Type: {item.dtype}")
            try:
                print(f"{indent}  First few values: {item[:2]}")
            except Exception as e:
                print(f"{indent}  Could not print values: {e}")

with h5py.File(example_grasp, 'r') as h5file:
    print("Top-level groups:", list(h5file.keys()))
    
    print("\nExploring complete structure:")
    for top_group_name in h5file.keys():
        print(f"\n=== {top_group_name} ===")
        top_group = h5file[top_group_name]
        explore_group(top_group)

Top-level groups: ['grasps', 'gripper', 'object']

Exploring complete structure:

=== grasps ===
Group: qualities
  Contents: ['flex']
  Group: flex
    Contents: ['object_in_gripper', 'object_motion_during_closing_angular', 'object_motion_during_closing_linear', 'object_motion_during_shaking_angular', 'object_motion_during_shaking_linear']
    Dataset: object_in_gripper
      Shape: (2000,)
      Type: int64
      First few values: [1 1]
    Dataset: object_motion_during_closing_angular
      Shape: (2000,)
      Type: float64
      First few values: [0.54927611 0.15047622]
    Dataset: object_motion_during_closing_linear
      Shape: (2000,)
      Type: float64
      First few values: [0.09023842 0.02939418]
    Dataset: object_motion_during_shaking_angular
      Shape: (2000,)
      Type: float64
      First few values: [0.0185062  0.04676599]
    Dataset: object_motion_during_shaking_linear
      Shape: (2000,)
      Type: float64
      First few values: [0.00300531 0.00718987]
Dat