# Cardiac Valve 4D Time-Series Conversion to USD

This notebook demonstrates converting time-varying cardiac valve simulation data from VTK format to animated USD.

## Dataset: CHOP-Valve4D (Alterra)

One cardiac valve model with time-varying geometry:

- **Alterra**: 265 time steps (cardiac cycle simulation)

This dataset represents a 4D (3D + time) simulation of a prosthetic heart valve during a cardiac cycle.

## Goals

1. Load and inspect time-varying VTK data
2. Convert entire time series to animated USD
3. Handle large datasets efficiently
4. Preserve all simulation data as USD primvars
5. Create multiple variations (full resolution, subsampled, etc.)

In [None]:
from pathlib import Path
import re
import time as time_module

import shutil

## Configuration

Control which time series conversions to compute.

In [None]:
# Configuration: Control which conversions to run
# True  -> also convert the full time series (all frames); takes longer.
# False -> only compute the subsampled time series (faster, for preview).
COMPUTE_FULL_TIME_SERIES = True  # Set to False for a quicker, subsampled-only run

print("Time Series Configuration:")
print(f"  - Compute Full Time Series: {COMPUTE_FULL_TIME_SERIES}")
print("  - Compute Subsampled Time Series: Always enabled")
print()
if not COMPUTE_FULL_TIME_SERIES:
    print("⚠️  Full time series conversion is DISABLED for faster execution.")
    print("   Set COMPUTE_FULL_TIME_SERIES = True to enable full conversion.")
else:
    print("✓ Full time series conversion is ENABLED (this will take longer).")

In [None]:
import logging
import numpy as np

# Import the vtk_to_usd library
from physiomotion4d.vtk_to_usd import (
    VTKToUSDConverter,
    ConversionSettings,
    MaterialData,
    cell_type_name_for_vertex_count,
    read_vtk_file,
    validate_time_series_topology,
)

# Import USDTools for post-processing colormap
from physiomotion4d.usd_tools import USDTools
from physiomotion4d.usd_anatomy_tools import USDAnatomyTools

# Configure logging: attach a stderr handler to the root logger so records at
# INFO level and above (from any logger) are shown as "LEVEL: message".
# Note: basicConfig is a no-op if the root logger already has handlers, which
# some notebook front-ends install.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

## 1. Discover and Organize Time-Series Files

In [None]:
# Define data directories (Alterra only).
# Input lives two levels above the notebook's working directory; output is
# written next to the notebook and created on demand.
data_dir = Path.cwd().parent.parent / "data" / "CHOP-Valve4D"
Alterra_dir = data_dir.joinpath("Alterra")
output_dir = Path.cwd().joinpath("output", "valve4d-alterra")
output_dir.mkdir(exist_ok=True, parents=True)

print(
    "\n".join(
        [
            f"Data directory: {data_dir}",
            f"Output directory: {output_dir}",
            "\nDirectory status:",
            f"  Alterra: {'✓' if Alterra_dir.exists() else '✗'} {Alterra_dir}",
        ]
    )
)

In [None]:
def discover_time_series(directory, pattern=r"\.t(\d+)\.vtk$", glob_pattern="*.vtk"):
    """Discover and sort time-series VTK files.

    Args:
        directory: Directory (str or Path) to scan, non-recursively. A path
            that is not an existing directory yields an empty list.
        pattern: Regex searched in each file *name*; group 1 must capture the
            integer time step. Files whose name does not match are skipped.
        glob_pattern: Shell-style pattern selecting candidate files. Change it
            together with ``pattern`` to discover other formats (e.g.
            ``"*.vtp"`` with ``r"\\.t(\\d+)\\.vtp$"``).

    Returns:
        list: ``(time_step, file_path)`` tuples sorted by time step. The file
        name breaks ties so the order does not depend on ``glob`` order.
    """
    regex = re.compile(pattern)

    # Extract time step numbers and pair with files
    time_series = []
    for vtk_file in Path(directory).glob(glob_pattern):
        match = regex.search(vtk_file.name)
        if match:
            time_series.append((int(match.group(1)), vtk_file))

    # Sort by time step, then by name for a deterministic order
    time_series.sort(key=lambda item: (item[0], item[1].name))

    return time_series


# Discover the Alterra time series and summarize what was found
Alterra_series = discover_time_series(Alterra_dir)

for header_line in ("=" * 60, "Time-Series Discovery (Alterra)", "=" * 60):
    print(header_line)
print("\nAlterra:")
print(f"  Files found: {len(Alterra_series)}")
if len(Alterra_series) > 0:
    (_first_step, _first_path), (_last_step, _last_path) = (
        Alterra_series[0],
        Alterra_series[-1],
    )
    print(f"  Time range: t{_first_step} to t{_last_step}")
    print(f"  First file: {_first_path.name}")
    print(f"  Last file: {_last_path.name}")

## 2. Inspect First Frame

Examine the first time step to understand the data structure.

In [None]:
# Read first frame of Alterra and print a summary of its geometry and data arrays
if Alterra_series:
    print("=" * 60)
    print("Alterra - First Frame Analysis")
    print("=" * 60)

    first_file = Alterra_series[0][1]
    mesh_data = read_vtk_file(first_file, extract_surface=True)

    print(f"\nFile: {first_file.name}")
    print("\nGeometry:")
    print(f"  Points: {len(mesh_data.points):,}")
    print(f"  Faces: {len(mesh_data.face_vertex_counts):,}")
    print(f"  Normals: {'Yes' if mesh_data.normals is not None else 'No'}")
    print(f"  Colors: {'Yes' if mesh_data.colors is not None else 'No'}")

    # Bounding box; np.min/np.max raise ValueError on an empty array, so guard it
    if len(mesh_data.points) > 0:
        bbox_min = np.min(mesh_data.points, axis=0)
        bbox_max = np.max(mesh_data.points, axis=0)
        bbox_size = bbox_max - bbox_min
        print("\nBounding Box:")
        print(f"  Min: [{bbox_min[0]:.3f}, {bbox_min[1]:.3f}, {bbox_min[2]:.3f}]")
        print(f"  Max: [{bbox_max[0]:.3f}, {bbox_max[1]:.3f}, {bbox_max[2]:.3f}]")
        print(f"  Size: [{bbox_size[0]:.3f}, {bbox_size[1]:.3f}, {bbox_size[2]:.3f}]")
    else:
        print("\nBounding Box: n/a (no points)")

    print(f"\nData Arrays ({len(mesh_data.generic_arrays)}):")
    for i, array in enumerate(mesh_data.generic_arrays, 1):
        print(f"  {i}. {array.name}:")
        print(f"     - Type: {array.data_type.value}")
        print(f"     - Components: {array.num_components}")
        print(f"     - Interpolation: {array.interpolation}")
        print(f"     - Elements: {len(array.data):,}")
        if array.data.size > 0:
            print(f"     - Range: [{np.min(array.data):.6f}, {np.max(array.data):.6f}]")

    # Cell types, grouped by face vertex count: the extracted surface may mix
    # triangles, quads and larger polygons.
    unique_counts, num_each = np.unique(
        mesh_data.face_vertex_counts, return_counts=True
    )
    print("\nCell types (faces by vertex count):")
    for u, n in zip(unique_counts, num_each):
        name = cell_type_name_for_vertex_count(int(u))
        print(f"  {name} ({u} vertices): {n:,} faces")

In [None]:
# Colorization happens after USD conversion: convert to USD first, then apply
# a colormap to a selected primvar with USDTools.

# matplotlib colormap name used for primvar colorization. It was previously
# reported as "viridis" while colorization actually applied "jet"; the value
# now matches what is rendered.
DEFAULT_COLORMAP = "jet"

# Enable automatic colorization (will pick strain/stress primvars if available)
ENABLE_AUTO_COLORIZATION = True

print("Colorization will be applied after USD conversion using USDTools methods")
print("  - USDTools.list_mesh_primvars() for inspection")
print("  - USDTools.pick_color_primvar() for selection")
print("  - USDTools.apply_colormap_from_primvar() for coloring")
print(f"  - Colormap: {DEFAULT_COLORMAP}")

In [None]:
# Configure conversion settings

# Create converter settings
settings = ConversionSettings(
    triangulate_meshes=True,
    compute_normals=False,  # Use existing normals if available
    preserve_point_arrays=True,
    preserve_cell_arrays=True,
    separate_objects_by_cell_type=False,
    # The colorization cell chooses mesh prim paths based on this flag
    separate_objects_by_connectivity=True,
    up_axis="Y",
    times_per_second=60.0,  # 60 FPS for smooth animation
    use_time_samples=True,
)

print("Conversion settings configured")
print(f"  - Triangulate: {settings.triangulate_meshes}")
print(f"  - Separate objects by cell type: {settings.separate_objects_by_cell_type}")
print(
    f"  - Separate objects by connectivity: {settings.separate_objects_by_connectivity}"
)
print(f"  - FPS: {settings.times_per_second}")
print(f"  - Up axis: {settings.up_axis}")

## 3. Convert Full Time Series - Alterra

In [None]:
# Create material for Alterra
# Note: Vertex colors will be applied post-conversion by USDTools
Alterra_material = MaterialData(
    name="Alterra_valve",
    diffuse_color=(0.85, 0.4, 0.4),
    roughness=0.4,
    metallic=0.0,
    use_vertex_colors=False,  # USDTools will bind vertex color material during colorization
)

print("=" * 60)
print("Converting Alterra Time Series")
print("=" * 60)
print(f"Dataset: {len(Alterra_series)} frames")

# Convert Alterra (full resolution)
if COMPUTE_FULL_TIME_SERIES and Alterra_series:
    converter = VTKToUSDConverter(settings)

    Alterra_files = [file_path for _, file_path in Alterra_series]
    # USD time codes are the raw time-step numbers parsed from the file names
    Alterra_times = [float(time_step) for time_step, _ in Alterra_series]

    output_usd = output_dir / "Alterra_full.usd"

    print(f"\nConverting to: {output_usd}")
    print(f"Time codes: {Alterra_times[0]:.1f} to {Alterra_times[-1]:.1f}")
    print("\nThis may take several minutes...\n")

    start_time = time_module.perf_counter()

    # Read every frame into memory up front: the full sequence is passed to
    # both the topology validation and the converter below.
    mesh_data_sequence = [read_vtk_file(f, extract_surface=True) for f in Alterra_files]

    # Validate topology consistency across time series
    validation_report = validate_time_series_topology(
        mesh_data_sequence, filenames=Alterra_files
    )
    if not validation_report["is_consistent"]:
        print(
            f"Warning: Found {len(validation_report['warnings'])} topology/primvar issues"
        )
        if validation_report["topology_changes"]:
            print(
                f"  Topology changes in {len(validation_report['topology_changes'])} frames"
            )

    # Convert to USD (preserves all primvars from VTK)
    stage = converter.convert_mesh_data_sequence(
        mesh_data_sequence=mesh_data_sequence,
        output_usd=output_usd,
        mesh_name="AlterraValve",
        time_codes=Alterra_times,
        material=Alterra_material,
    )

    elapsed = time_module.perf_counter() - start_time
    print(f"\nRead + validate + convert took {elapsed:.1f} s")

    # Keep an untouched copy (Alterra_full.save.usd) before the colorization
    # step post-processes output_usd.
    shutil.copy(output_usd, output_usd.with_suffix(".save.usd"))

In [None]:
if COMPUTE_FULL_TIME_SERIES and Alterra_series:
    # Post-process: apply colormap visualization using USDTools
    if ENABLE_AUTO_COLORIZATION:
        usd_tools = USDTools()
        usd_anatomy_tools = USDAnatomyTools(stage)

        # Prim paths depend on how the converter split the mesh (see `settings`).
        # mesh_path1 receives the primvar colormap; mesh_path2 (when not None)
        # receives the anatomy "bone" material.
        if settings.separate_objects_by_connectivity:
            mesh_path1 = "/World/Meshes/AlterraValve_object3"
            mesh_path2 = "/World/Meshes/AlterraValve_object4"
        elif settings.separate_objects_by_cell_type:
            mesh_path1 = "/World/Meshes/AlterraValve_triangle1"
            # Pointing mesh_path2 at the same prim would replace its colormap
            # material with the bone material. TODO(review): set the second
            # object's prim name if this split produces one.
            mesh_path2 = None
        else:
            mesh_path1 = "/World/Meshes/AlterraValve"
            mesh_path2 = None

        # Inspect and select primvar for coloring
        primvars = usd_tools.list_mesh_primvars(str(output_usd), mesh_path1)
        print(primvars)
        color_primvar = usd_tools.pick_color_primvar(
            primvars, keywords=("strain", "stress")
        )

        if color_primvar:
            print(f"\nApplying colormap to '{color_primvar}'")
            usd_tools.apply_colormap_from_primvar(
                str(output_usd),
                mesh_path1,
                color_primvar,
                intensity_range=(25, 200),
                cmap=DEFAULT_COLORMAP,
                use_sigmoid_scale=True,
                bind_vertex_color_material=True,
            )
            if mesh_path2 is not None:
                # NOTE(review): this edits the in-memory `stage`; no Save() is
                # visible in this cell. Confirm the material reaches the file.
                mesh_prim = stage.GetPrimAtPath(mesh_path2)
                if mesh_prim:
                    usd_anatomy_tools.apply_anatomy_material_to_prim(
                        mesh_prim, usd_anatomy_tools.bone_params
                    )
                else:
                    print(f"Warning: prim not found: {mesh_path2}")
        else:
            print("\nNo strain/stress primvar found for coloring")

    # Repeat the topology summary from the conversion cell
    if not validation_report["is_consistent"]:
        print(
            f"Warning: Found {len(validation_report['warnings'])} topology/primvar issues"
        )
        if validation_report["topology_changes"]:
            print(
                f"  Topology changes in {len(validation_report['topology_changes'])} frames"
            )

    print(f"\nOutput: {output_usd}")
    print(f"  Size: {output_usd.stat().st_size / (1024 * 1024):.2f} MB")
    print(f"  Time range: {stage.GetStartTimeCode()} - {stage.GetEndTimeCode()}")
    print(
        f"  Duration: {(stage.GetEndTimeCode() - stage.GetStartTimeCode()) / settings.times_per_second:.2f} seconds @ {settings.times_per_second} FPS"
    )
elif not COMPUTE_FULL_TIME_SERIES:
    print("⏭️  Skipping Alterra full time series (COMPUTE_FULL_TIME_SERIES = False)")
else:
    print("⚠️  No Alterra time-series files found; nothing to post-process.")