In [41]:
import h5py
import numpy as np
from tqdm import tqdm

Convert a D-FAUST frame (.npz) into a .ply mesh file

In [15]:
import open3d as o3d  # imported here so the cell runs on a fresh kernel (the main open3d import cell comes later)

file_path = "/usr/stud/srinivaa/storage/slurm/cadex/dfaust_resgistered/Humans/50004_hips/mesh_seq_registered/00000000.npz"
# One registered frame; the keys read below are "points" (vertices) and "triangles" (faces).
# Bound to `frame_npz` rather than `data` so it does not shadow torch.utils.data used later.
frame_npz = np.load(file_path)

# Build a triangle mesh from the frame and write it out (returns True on success).
o3d_mesh = o3d.geometry.TriangleMesh()
o3d_mesh.vertices = o3d.utility.Vector3dVector(frame_npz["points"])
o3d_mesh.triangles = o3d.utility.Vector3iVector(frame_npz["triangles"])
o3d.io.write_triangle_mesh("registered_dfaust.ply", o3d_mesh)

True

### Extract data from registered D-FAUST scans

In [10]:
filename = "/usr/stud/srinivaa/storage/slurm/cadex/dfaust_resgistered/registrations_f.hdf5"
# Opened read-only and deliberately left open because later cells index into `f`.
# NOTE(review): call f.close() once the needed arrays have been extracted.
f =  h5py.File(filename, "r")
a_dataset_keys = list(f.keys()) # names of the top-level datasets/groups in the file

In [11]:
# First and last top-level keys of the registrations file.
dataset_list = [a_dataset_keys[0],a_dataset_keys[-1]]

In [14]:
# Read the face array (last key of `dataset_list`) and the first vertex group
# (first key of the file) directly; no loop is needed because only one key of
# each is used.
# Assumption (verify against the HDF5 file): the last key holds the faces and
# the first key holds a vertex sequence.
data_face = f[dataset_list[-1]]
data_face_arr = np.asarray(data_face)

data_i = f[a_dataset_keys[0]]
data_i_arr = np.asarray(data_i)

In [65]:
import open3d as o3d  # local import so this cell also works on a fresh kernel

# Mesh from frame 0 of the vertex array plus the face array.
# Assumes data_i_arr is laid out (n_vertices, 3, n_frames) — TODO confirm against the HDF5 file.
o3d_mesh = o3d.geometry.TriangleMesh()
o3d_mesh.vertices = o3d.utility.Vector3dVector(data_i_arr[:, :, 0])
o3d_mesh.triangles = o3d.utility.Vector3iVector(data_face_arr)
o3d.io.write_triangle_mesh("registered_dfaust.ply", o3d_mesh)

True

## Custom dataset of CaDeX

In [1]:
import numpy as np
import os
import random
import dataset.oflow_dataset as oflow_dataset
from torch.utils import data
from torch.utils.data import DataLoader

Define a custom dataset class

In [2]:
class HumansDataset(data.Dataset):
    """Sequence dataset over ``<dataset_folder>/<category>/<model>`` folders.

    Every model folder is cut into overlapping windows of ``length_sequence``
    frames and each window is one dataset item. The number of frames of a model
    is the number of entries in its ``ex_folder_name`` sub-folder.

    Args:
        dataset_folder (str): root folder with one sub-folder per category.
        fields (dict): field name -> object with a
            ``load(model_path, idx, c_idx, start_idx)`` method.
        split (str, optional): name of a ``<split>.lst`` file inside each category
            folder listing the models to use. When ``None``, or when that file does
            not exist, every sub-directory of the category folder is used.
        categories (list, optional): category folder names. When ``None``, every
            sub-directory of ``dataset_folder`` is used.
        no_except (bool): stored on the instance; not used by this class.
        transform (callable, optional): applied to the assembled sample dict.
        length_sequence (int): frames per window.
        n_files_per_sequence (int): if > 0, only the first this-many frames
            (after the offset) of each model are used.
        offset_sequence (int): number of leading frames skipped per model.
        ex_folder_name (str): sub-folder whose entries are counted as frames.
    """

    def __init__(
        self,
        dataset_folder,
        fields,
        split=None,
        categories=None,
        no_except=True,
        transform=None,
        length_sequence=2,
        n_files_per_sequence=-1,
        offset_sequence=0,
        ex_folder_name="pcl_seq",
        **kwargs
    ):
        # Attributes
        self.dataset_folder = dataset_folder
        self.fields = fields
        self.no_except = no_except
        self.transform = transform
        self.length_sequence = length_sequence
        self.n_files_per_sequence = n_files_per_sequence
        self.offset_sequence = offset_sequence
        self.ex_folder_name = ex_folder_name

        # Without this fallback, categories=None crashed when iterated.
        if categories is None:
            categories = self._list_subdirs(dataset_folder)

        # Placeholder metadata; "idx" is the category index handed to field.load().
        self.metadata = {
            c: {"id": c, "name": "n/a", "idx": c_idx} for c_idx, c in enumerate(categories)
        }

        # One entry per (model, window start) pair.
        self.models = []
        for c in categories:
            subpath = os.path.join(dataset_folder, c)
            models_c = self._read_model_names(subpath, split)
            models_len = self.get_models_seq_len(subpath, models_c)  # frames per model
            models_c, start_idx = self.subdivide_into_sequences(models_c, models_len)
            self.models += [
                {"category": c, "model": m, "start_idx": start_idx[i]}
                for i, m in enumerate(models_c)
            ]

    @staticmethod
    def _list_subdirs(folder):
        """Return the sorted names of the sub-directories of ``folder``."""
        return sorted(
            name for name in os.listdir(folder) if os.path.isdir(os.path.join(folder, name))
        )

    def _read_model_names(self, subpath, split):
        """Return one category's model names: from ``<split>.lst`` if it exists, else all sub-directories.

        Previously a missing split file (or ``split=None``) left ``models_c``
        undefined, or silently reused the previous category's list.
        """
        split_file = None if split is None else os.path.join(subpath, split + ".lst")
        if split_file is not None and os.path.exists(split_file):
            with open(split_file, "r") as fh:
                names = fh.read().split("\n")
        else:
            names = self._list_subdirs(subpath)
        # Drop empty lines (e.g. the trailing newline of the .lst file).
        return [name for name in names if len(name) > 0]

    def __len__(self):
        return len(self.models)

    def __getitem__(self, idx):
        """Load every field for window ``idx`` and return them as one flat dict.

        A field returning a dict contributes ``"<field>.<key>"`` entries (its
        ``None`` key maps to ``"<field>"`` itself). Any other return value is
        stored under ``"<field>"``. ``self.transform`` is applied last.
        """
        entry = self.models[idx]
        category = entry["category"]
        model = entry["model"]
        start_idx = entry["start_idx"]
        c_idx = self.metadata[category]["idx"]

        model_path = os.path.join(self.dataset_folder, category, model)

        # Named `sample` so it does not shadow the torch.utils.data module name.
        sample = {}
        for field_name, field in self.fields.items():
            field_data = field.load(model_path, idx, c_idx, start_idx)
            if isinstance(field_data, dict):
                for k, v in field_data.items():
                    if k is None:
                        sample[field_name] = v
                    else:
                        sample["%s.%s" % (field_name, k)] = v
            else:
                sample[field_name] = field_data

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def get_models_seq_len(self, subpath, models):
        """Return, per model, the number of entries in its ``self.ex_folder_name`` folder.

        This assumes that folder exists for every model.

        Args:
            subpath (str): path of the category folder
            models (list): list of model names
        """
        ex_folder_name = self.ex_folder_name
        return [len(os.listdir(os.path.join(subpath, m, ex_folder_name))) for m in models]

    def subdivide_into_sequences(self, models, models_len):
        """Split every model sequence into overlapping windows of ``length_sequence`` frames.

        Args:
            models (list): list of model names
            models_len (list): number of frames of each model

        Returns:
            tuple: ``(models_out, start_idx)``. Each model name is repeated once per
            window, and ``start_idx`` holds each window's first frame index.
        """
        length_sequence = self.length_sequence
        n_files_per_sequence = self.n_files_per_sequence
        offset_sequence = self.offset_sequence

        # Remove files before offset
        models_len = [l - offset_sequence for l in models_len]

        # Reduce to maximum number of files that should be considered
        if n_files_per_sequence > 0:
            models_len = [min(n_files_per_sequence, l) for l in models_len]

        models_out = []
        start_idx = []
        for idx, model in enumerate(models):
            for n in range(0, models_len[idx] - length_sequence + 1):
                models_out.append(model)
                start_idx.append(n + offset_sequence)

        return models_out, start_idx
    

In [3]:
def get_transforms(n_pcl=100, n_pt=512, n_pt_eval=10000):
    """Build the point / point-cloud subsampling transforms.

    Args:
        n_pcl (int): size passed to the training ``SubsamplePointcloudSeq``
            (with ``connected_samples=True``).
        n_pt (int): size passed to the training ``SubsamplePoints``.
        n_pt_eval (int): size passed to both validation transforms
            (non-random subsampling).

    Returns:
        tuple: ``(transf_pt, transf_pt_val, transf_pcl, transf_pcl_val)``.
    """
    transf_pt = oflow_dataset.SubsamplePoints(n_pt)
    transf_pt_val = oflow_dataset.SubsamplePointsSeq(n_pt_eval, random=False)
    transf_pcl_val = oflow_dataset.SubsamplePointcloudSeq(n_pt_eval, random=False)
    transf_pcl = oflow_dataset.SubsamplePointcloudSeq(n_pcl, connected_samples=True)

    return transf_pt, transf_pt_val, transf_pcl, transf_pcl_val

In [4]:
def get_data_fields(mode, loss_recon=True, loss_corr=True):
    """Return the data fields loaded for every dataset item.

    Args:
        mode (str): "train", "val" or "test".
        loss_recon (bool): in train mode, add the occupancy "points" field and
            the "mesh" field.
        loss_corr (bool): add the "pointcloud" field (in every mode).

    Returns:
        dict: field name -> oflow_dataset field object.
    """
    # These flags used to be the strings "false" / "true". Both strings are
    # truthy, so both branches always ran; the boolean defaults keep that
    # effective behavior while making the switches honest.
    fields = {}
    seq_len_train = 2
    seq_len_val = seq_len_train
    p_folder = "points_seq"  # points and their occupancy values
    pcl_folder = "pcl_seq"  # points, scale and loc
    mesh_folder = "mesh_registered"  # presumably per-frame registered meshes (points + faces)
    generate_interpolate = False
    unpackbits = False
    training_all = False
    training_multi_files = False
    n_training_frames = 2

    transf_pt, transf_pt_val, transf_pcl, transf_pcl_val = get_transforms()

    pts_iou_field = oflow_dataset.PointsSubseqField
    pts_corr_field = oflow_dataset.PointCloudSubseqField

    if mode == "train":
        if loss_recon:
            if training_all:
                fields["points"] = pts_iou_field(
                    p_folder,
                    transform=transf_pt,
                    all_steps=True,
                    seq_len=seq_len_train,
                    unpackbits=unpackbits,
                    use_multi_files=training_multi_files,
                )
            else:
                fields["points"] = pts_iou_field(
                    p_folder,
                    sample_nframes=n_training_frames,
                    transform=transf_pt,
                    seq_len=seq_len_train,
                    fixed_time_step=0,
                    unpackbits=unpackbits,
                    use_multi_files=training_multi_files,
                )
            fields["mesh"] = oflow_dataset.MeshField(mesh_folder, seq_len=seq_len_val)
    else:
        # Point clouds for the chamfer distance at evaluation. Because a subsampling
        # transform is applied, the point count in the config must be large enough.
        fields["points_mesh"] = pts_corr_field(
            pcl_folder, transform=transf_pcl_val, seq_len=seq_len_val
        )

    # Connectivity loss
    if loss_corr:
        fields["pointcloud"] = oflow_dataset.PointCloudField(mesh_folder, seq_len=seq_len_val)
    if mode == "test" and generate_interpolate:
        fields["mesh"] = oflow_dataset.MeshSubseqField(
            mesh_folder, seq_len=seq_len_val, only_end_points=True
        )
    return fields

In [5]:
def get_inputs_field(mode, input_type="mesh_seq", cfg=None):
    """Return the field that provides the network inputs.

    Args:
        mode (str): "train", "val" or "test".
        input_type (str or None): one of None, "img_seq", "pcl_seq", "mesh_seq",
            "end_pointclouds", "idx". Defaults to "mesh_seq", the value that was
            previously hard-coded.
        cfg (dict, optional): config read as ``cfg["dataset"]["oflow_config"]``.
            Required for "img_seq", "pcl_seq" and "end_pointclouds".

    Returns:
        The input field object, or None when ``input_type`` is None.

    Raises:
        ValueError: for an unknown ``input_type``, or a config-driven type without ``cfg``.
    """
    import logging

    seq_len = 2  # sequence length passed to the point-cloud input fields

    if input_type in ("img_seq", "pcl_seq", "end_pointclouds") and cfg is None:
        # These branches used to read a global `cfg` that is only defined in a later
        # section of this notebook, so they failed with NameError.
        raise ValueError("Input type (%s) requires a cfg" % input_type)

    if input_type is None:
        return None

    if input_type == "img_seq":
        oflow_cfg = cfg["dataset"]["oflow_config"]
        # NOTE(review): `transforms` must be torchvision.transforms imported at
        # notebook level; no cell in this notebook imports it.
        if mode == "train" and oflow_cfg["img_augment"]:
            resize_op = transforms.RandomResizedCrop(
                oflow_cfg["img_size"], (0.75, 1.0), (1.0, 1.0)
            )
        else:
            resize_op = transforms.Resize((oflow_cfg["img_size"]))
        transform = transforms.Compose([resize_op, transforms.ToTensor()])
        return oflow_dataset.ImageSubseqField(
            oflow_cfg["img_seq_folder"], transform, random_view=(mode == "train")
        )

    if input_type == "pcl_seq":
        oflow_cfg = cfg["dataset"]["oflow_config"]
        transform = transforms.Compose(
            [
                oflow_dataset.SubsamplePointcloudSeq(
                    oflow_cfg["input_pointcloud_n"],
                    connected_samples=oflow_cfg["input_pointcloud_corresponding"],
                ),
                oflow_dataset.PointcloudNoise(oflow_cfg["input_pointcloud_noise"]),
            ]
        )
        training_multi_files = False
        if "training_multi_files" in oflow_cfg:
            if oflow_cfg["training_multi_files"] and mode == "train":
                training_multi_files = True
                logging.info(
                    "Oflow D-FAUST PCL Field use multi files to speed up disk performation"
                )
        return oflow_dataset.PointCloudSubseqField(
            oflow_cfg["pointcloud_seq_folder"],
            transform,
            seq_len=seq_len,
            use_multi_files=training_multi_files,
        )

    if input_type == "mesh_seq":
        return oflow_dataset.MeshField("mesh_registered")

    if input_type == "end_pointclouds":
        oflow_cfg = cfg["dataset"]["oflow_config"]
        transform = oflow_dataset.SubsamplePointcloudSeq(
            oflow_cfg["input_pointcloud_n"],
            connected_samples=oflow_cfg["input_pointcloud_corresponding"],
        )
        return oflow_dataset.PointCloudSubseqField(
            oflow_cfg["pointcloud_seq_folder"],
            only_end_points=True,
            seq_len=seq_len,
            transform=transform,
        )

    if input_type == "idx":
        return oflow_dataset.IndexField()

    raise ValueError("Invalid input type (%s)" % input_type)

In [6]:
# Training data fields, plus the network-input field stored under "inputs".
fields = get_data_fields("train")
inputs_field = get_inputs_field("train")

if inputs_field is not None:
    fields["inputs"] = inputs_field

In [7]:
# Dataset root and the category sub-folder(s) to load.
# NOTE(review): hard-coded absolute path; consider a configurable DATA_DIR.
dataset_folder = "/usr/stud/srinivaa/storage/slurm/data/animals"
categories = ["SMAL"]

Instantiate a custom dataset object

In [8]:
# Two-frame windows that skip the first frame of every model (offset_sequence=1).
# Frames are counted in each model's "mesh_registered" folder.
dataset = HumansDataset(
        dataset_folder,
        fields,
        split="train",
        categories=categories,
        length_sequence=2,
        n_files_per_sequence=-1,
        offset_sequence=1,
        ex_folder_name="mesh_registered",
    )

In [9]:
# Inspect one item: a flat dict of the loaded field arrays.
dataset[1]

<numpy.lib.npyio.NpzFile object at 0x7f9a5372df10>
['/usr/stud/srinivaa/storage/slurm/data/animals/SMAL/big_cats/points_seq/450-122410176-lions-natural-habitat.npz', '/usr/stud/srinivaa/storage/slurm/data/animals/SMAL/big_cats/points_seq/MaleLion800.npz']


{'points': array([[[ 0.6439394 , -0.11363637, -0.6136364 ],
         [-0.5984849 ,  0.7348485 , -0.5984849 ],
         [-0.70454544,  0.75      , -0.31060606],
         ...,
         [-0.03787879,  0.46212122,  0.5530303 ],
         [ 0.11363637, -0.0530303 ,  0.37121212],
         [-0.6439394 , -0.37121212, -0.6287879 ]],
 
        [[-0.20454545,  0.02272727, -0.5378788 ],
         [ 0.02272727, -0.70454544,  0.08333334],
         [-0.6439394 ,  0.3409091 , -0.21969697],
         ...,
         [ 0.68939394, -0.06818182,  0.68939394],
         [-0.5833333 ,  0.09848485, -0.38636363],
         [ 0.20454545,  0.41666666,  0.12878788]]], dtype=float32),
 'points.occ': array([[[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]],
 
        [[0.],
         [0.],
         [0.],
         ...,
         [0.],
         [0.],
         [0.]]], dtype=float32),
 'points.time': array([0., 1.], dtype=float32),
 'mesh.vertices': array([[[0.22017938, 0.60612228,

### Study the Shape class

In [12]:
from core.models.utils_arap.shape_utils import Shape
import torch

In [13]:
# Load the item once: each dataset[...] call re-reads the files and re-applies
# the transform, so vertices and triangles should come from the same load.
sample_1 = dataset[1]
shape_1 = Shape(
    vert=torch.from_numpy(sample_1["mesh.vertices"]),
    triv=torch.from_numpy(sample_1["mesh.triangles"]),
)

### Convert .npz file into .ply file

In [5]:
# activate cadex_exp venv
import open3d as o3d
import numpy as np

In [23]:
vertices_path = "/usr/stud/srinivaa/code/new_CaDeX/CaDeX/cdc_1.npz"
# Load `vertices_path`; this previously loaded the stale `file_path` (the
# D-FAUST frame from the first section) by mistake.
data_vertices = np.load(vertices_path)

# np.vstack merges the leading axis, so a (T, N, 3) "points" array becomes a
# (T*N, 3) cloud — assumed layout, verify.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.vstack(data_vertices["points"]))
o3d.io.write_point_cloud("cdc_1.ply", pcd)

True

In [83]:
# Open each archive once; element [0][0] of every array is kept.
with np.load("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/external_results/cdc.npz") as cdc_npz:
    cdc_vertices = torch.from_numpy(cdc_npz["points"][0][0])
    query_faces = torch.from_numpy(cdc_npz["faces"][0][0])
with np.load("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/external_results/query.npz") as query_npz:
    query_vertices = torch.from_numpy(query_npz["points"][0][0])

In [1]:
import numpy as np

# Open the archive once and read both arrays from it.
with np.load("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/external_results/cdc.npz") as cdc_npz:
    cdc_vertices = cdc_npz["points"]
    cdc_faces = cdc_npz["faces"]

### Study of the deformation (ARAP) loss (the MDS term is currently commented out)

In [12]:
import numpy as np
import torch
from numba import jit
import time

In [13]:
# Open each archive once; cdc.npz supplies both the CDC points and the faces.
with np.load("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/external_results/cdc.npz") as cdc_npz:
    cdc_vertices = torch.from_numpy(cdc_npz["points"])
    faces = torch.from_numpy(cdc_npz["faces"])
with np.load("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/external_results/query.npz") as query_npz:
    query_vertices = torch.from_numpy(query_npz["points"])

In [19]:
def my_zeros(shape):
    """Zero-filled float32 tensor of the requested shape, on the default device."""
    dtype = torch.float32
    return torch.zeros(shape, dtype=dtype)


def my_ones(shape):
    """One-filled float32 tensor of the requested shape, on the default device."""
    dtype = torch.float32
    return torch.ones(shape, dtype=dtype)

def arap_exact(vert_diff_cdc, vert_diff_query, neigh, n_vert):
    """Best-fit per-vertex rotations (Kabsch) between two sets of edge vectors.

    For every vertex v, S_v = sum_e d_cdc[e] d_query[e]^T is accumulated over
    all edges e in ``neigh`` that touch v. With S_v = U Sigma V^T, the rotation is
    R_v = U diag(1, 1, det(U V^T)) V^T. This is the proper rotation (det = +1)
    minimising sum_e |d_cdc[e] - R_v d_query[e]|^2.

    Args:
        vert_diff_cdc (Tensor): (E, 3) edge vectors of the first shape.
        vert_diff_query (Tensor): (E, 3) edge vectors of the second shape.
        neigh (LongTensor): (E, 2) vertex indices of each edge.
        n_vert (int): number of vertices.

    Returns:
        Tensor: (n_vert, 3, 3) rotations with the same dtype and device as the inputs.
    """
    # Per-edge outer products d_cdc d_query^T, shape (E, 3, 3).
    S_neigh = torch.bmm(vert_diff_cdc.unsqueeze(2), vert_diff_query.unsqueeze(1))

    # Allocate with the inputs' dtype/device. A fixed float32 CPU buffer made
    # index_add fail for float64 or CUDA inputs.
    S = vert_diff_cdc.new_zeros((n_vert, 3, 3))

    # Each edge contributes to both of its endpoints.
    S = torch.index_add(S, 0, neigh[:, 0], S_neigh)
    S = torch.index_add(S, 0, neigh[:, 1], S_neigh)

    # Kabsch algorithm
    U, _, V = torch.svd(S, compute_uv=True)
    R = torch.bmm(U, V.transpose(1, 2))

    # Flip the last singular direction where U V^T is a reflection (det = -1).
    Sigma = R.new_ones((R.shape[0], 1, 3))
    Sigma[:, :, 2] = torch.det(R).unsqueeze(1)

    R = torch.bmm(U * Sigma, V.transpose(1, 2))
    return R

def arap_energy_exact(vert_cdc, vert_query, neigh, lambda_reg_len=1e-6):
    """ARAP energy between two vertex sets that share the edge list ``neigh``.

    Per-vertex rotations come from ``arap_exact``, and each edge uses the mean
    of its two endpoint rotations. The returned value is
    0.1 * (sum_e |d_cdc[e] - R_e d_query[e]|^2
           + lambda_reg_len * |vert_cdc - vert_query|^2).

    Args:
        vert_cdc (Tensor): (V, 3) vertices of the first shape.
        vert_query (Tensor): (V, 3) vertices of the second shape.
        neigh (LongTensor): (E, 2) edge vertex indices.
        lambda_reg_len (float): weight of the vertex-displacement regulariser.

    Returns:
        Tensor: scalar energy.
    """
    src = neigh[:, 0]
    dst = neigh[:, 1]

    edge_cdc = vert_cdc[src] - vert_cdc[dst]
    edge_query = vert_query[src] - vert_query[dst]

    # Rotations fitted so that edge_cdc ~= R @ edge_query around each vertex.
    rot_vert = arap_exact(edge_cdc, edge_query, neigh, vert_cdc.shape[0])
    rot_edge = 0.5 * (rot_vert[src] + rot_vert[dst])

    # Rotate the query edges and measure how far they are from the cdc edges.
    edge_query_rot = torch.bmm(rot_edge, edge_query.unsqueeze(2)).squeeze()
    residual = edge_cdc - edge_query_rot

    energy = residual.norm() ** 2 + lambda_reg_len * (vert_cdc - vert_query).norm() ** 2
    return 0.1 * energy


In [20]:
def get_neigh(triv):
    """Edge list of a triangle mesh built from the corner pairs of every face.

    All (0, 1) corner pairs come first, then (0, 2), then (1, 2). An edge shared
    by two faces appears once per face. Returns an int64 tensor of shape (3F, 2).
    """
    corner_pairs = ([0, 1], [0, 2], [1, 2])
    edges = torch.cat([triv[:, pair] for pair in corner_pairs], dim=0)
    return edges.to(torch.long)

In [21]:
def _loss_deform(query_vertices, query_triangles, canonical):
    """Sum of ``_loss_deform_single`` over the leading (batch) dimension of ``canonical``.

    The edge list is built once from ``query_triangles[0][0]`` and reused for
    every batch element. This presumably assumes all items share one mesh
    topology — TODO confirm.
    """
    neighbours = get_neigh(query_triangles[0][0])

    total = 0
    for b in range(canonical.shape[0]):
        total = total + _loss_deform_single(query_vertices[b], canonical[b], neighbours)
    return total

In [22]:
def _loss_deform_single(query_vertices, canonical_vertices, neighbours):
    E_deform_list = []

    for i in range(query_vertices.shape[0]):
        
        torch.set_num_threads(i + 1)
        E_y = arap_energy_exact(
            query_vertices[i], canonical_vertices[i],neighbours
        )

        
        if torch.isnan(E_y) ==  False:
            E_deform_list.append(E_y)

    E_deform_tensor = torch.tensor(E_deform_list)
    E_deform_mean = torch.mean(E_deform_tensor)

    return E_deform_mean

In [23]:
# perf_counter is a monotonic, high-resolution clock meant for elapsed-time measurement.
start = time.perf_counter()
deform_loss_i = _loss_deform(query_vertices.cpu(), faces.cpu(), cdc_vertices.cpu())
print("Time taken to execute deformation loss(in seconds):", time.perf_counter() - start)

Time taken to execute deformation loss(in seconds): 4.660088300704956


In [19]:
# Sanity check: the .cpu() copy passed to the loss is not a CUDA tensor.
faces.cpu().is_cuda

False

## Study of correspondence loss

In [1]:
import numpy as np
import torch
from init.config_utils import load_config
from core.net_bank.lpdc_encoder import SpatioTemporalResnetPointnetCDC
from core.net_bank.oflow_point import ResnetPointnet
from core.net_bank.oflow_decoder import DecoderCBatchNorm
from torch import distributions as dist

In [2]:
# Training config; default_path presumably supplies default values (confirm in init.config_utils).
# NOTE(review): hard-coded absolute paths.
cfg = load_config("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/configs/dfaust/training/dfaust_w_st.yaml", default_path="/usr/stud/srinivaa/code/new_CaDeX/CaDeX/init/default.yaml")

In [3]:
# Input point sequence and raw CDC coordinates (presumably exported by an earlier run).
seq_pc = torch.from_numpy(np.load("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/external_results/query.npz")["points"])
uncompressed_cdc = torch.from_numpy(np.load("/usr/stud/srinivaa/code/new_CaDeX/CaDeX/external_results/cdc.npz")["points"])
# Squash the raw CDC values into (-0.5, 0.5).
cdc = torch.sigmoid(uncompressed_cdc) - 0.5

In [4]:
# Encoder configured from cfg["model"]["homeomorphism_encoder"].
homeomorphism_encoder = SpatioTemporalResnetPointnetCDC(
                dim=3,
                **cfg["model"]["homeomorphism_encoder"],
                global_geometry=True,
                # * Note: global_geometry is still True here, but this c_g is not used
            )

In [5]:
# Networks looked up by name in the following cells.
network_dict = torch.nn.ModuleDict(
    {
        "homeomorphism_encoder" : homeomorphism_encoder,
        "canonical_geometry_encoder": ResnetPointnet(
                    dim=3, **cfg["model"]["canonical_geometry_encoder"]
                ),
        "canonical_geometry_decoder": DecoderCBatchNorm(
                    dim=3, z_dim=0, **cfg["model"]["canonical_geometry_decoder"]
                )
    }
)

In [6]:
def decode_by_cdc(observation_c, query, decoder=None):
    """Occupancy of query points in the canonical deformation coordinate (CDC) space.

    Args:
        observation_c (Tensor): conditioning code passed to the decoder (in this
            notebook, the canonical-geometry encoder output).
        query (Tensor): (B, T, N, 3) query points.
        decoder (callable, optional): called as ``decoder(points, None, observation_c)``
            with points of shape (B, T*N, 3); its output must reshape to (B, T, N).
            Defaults to ``network_dict["canonical_geometry_decoder"]``.

    Returns:
        torch.distributions.Bernoulli: occupancy distribution with (B, T, N) logits.
    """
    if decoder is None:
        decoder = network_dict["canonical_geometry_decoder"]

    B, T, N, _ = query.shape
    # Flatten frames and points so the decoder sees one point set per batch item.
    logits = decoder(query.reshape(B, -1, 3), None, observation_c).reshape(B, T, N)
    return dist.Bernoulli(logits=logits)

In [9]:
# Compute the global code here so this cell does not depend on first running
# the encoder cell further below (out-of-order execution).
c_g = network_dict["canonical_geometry_encoder"](
    uncompressed_cdc.reshape(uncompressed_cdc.shape[0], -1, 3)
)
pr = decode_by_cdc(observation_c=c_g, query=cdc)
occ_hat = pr.probs

In [11]:
# (B, T, N) occupancy probabilities; (2, 17, 6890) for this data.
occ_hat.shape

torch.Size([2, 17, 6890])

In [11]:
# Run the homeomorphism encoder on the input sequence; the first output is discarded
# and c_t is kept (presumably per-time-step codes — confirm).
_,c_t = network_dict["homeomorphism_encoder"](seq_pc)

In [7]:
# Global geometry code from the raw CDC points, flattened to (batch, points, 3).
# The batch size comes from the data instead of the former hard-coded 2.
c_g = network_dict["canonical_geometry_encoder"](
    uncompressed_cdc.reshape(uncompressed_cdc.shape[0], -1, 3)
)

In [None]:
# Global geometry code shape; (2, 128) for this data.
c_g.shape

torch.Size([2, 128])