In [1]:
import torch
from einops import rearrange
from functorch import vmap

In [10]:
# Example inputs to simulate functionality
def generate_example_inputs(batch_size):
    # Generate random SO(3) rotation matrices
    x0 = torch.randn(batch_size, 3, 3, requires_grad=True)  # Simulated input (not guaranteed to be SO(3))
    x1 = torch.randn(batch_size, 3, 3, requires_grad=True)
    t = torch.rand(batch_size, requires_grad=True)  # Random time values
    return x0, x1, t

# The demonstration function for compute_conditional_flow_simple
def demo_compute_conditional_flow_simple(batch_size):
    # Step 1: Generate example inputs
    x0, x1, t = generate_example_inputs(batch_size)
    print("Initial rotation matrices x0:", x0)
    print("Shape of x0:", x0.shape)
    print("Initial rotation matrices x1:", x1)
    print("Shape of x1:", x1.shape)
    print("Time values t:", t)
    print("Shape of t:", t.shape)

    # Step 2: Flatten rotation matrices
    x0_flat = rearrange(x0, "b c d -> b (c d)", c=3, d=3)
    x1_flat = rearrange(x1 * t[:, None, None], "b c d -> b (c d)", c=3, d=3)  # Include t in computation
    print("Flattened x0:", x0_flat)
    print("Shape of flattened x0:", x0_flat.shape)
    print("Flattened x1:", x1_flat)
    print("Shape of flattened x1:", x1_flat.shape)

    # Step 3: Define the derivative computation helper function
    def index_time_derivative(i):
        return torch.autograd.grad(
            outputs=x1_flat,
            inputs=t,
            grad_outputs=i,
            create_graph=True,
            retain_graph=True
        )[0]

    # Step 4: Use vmap to compute derivatives for each coordinate
    identity_matrix = torch.eye(9).to(x0.device).repeat(batch_size, 1, 1)  # Identity per batch element
    print("Identity matrix for vmap:", identity_matrix)
    print("Shape of identity matrix:", identity_matrix.shape)

    x1_dot = vmap(index_time_derivative, in_dims=1)(identity_matrix)
    print("Raw derivatives x1_dot:", x1_dot)
    print("Shape of raw derivatives x1_dot:", x1_dot.shape)

    # Step 5: Reshape the result back into (batch, 3, 3) format
    x1_dot = rearrange(x1_dot, "(c d) b -> b c d", c=3, d=3)
    print("Reshaped derivatives x1_dot:", x1_dot)
    print("Shape of reshaped x1_dot:", x1_dot.shape)

    # Output the result
    return x1_dot


# Run the demonstration
batch_size = 2  # Adjust batch size as needed
output = demo_compute_conditional_flow_simple(batch_size)
print("Flow derivative computed for batch:", output)


Initial rotation matrices x0: tensor([[[-1.6438, -0.2921,  0.8579],
         [-0.5794,  0.5201, -1.2595],
         [ 0.2056, -0.1793, -0.7726]],

        [[ 1.6532,  0.4226,  0.7754],
         [-1.5206,  1.0661, -1.9345],
         [-0.2122, -0.2998, -0.7949]]], requires_grad=True)
Shape of x0: torch.Size([2, 3, 3])
Initial rotation matrices x1: tensor([[[-0.7484,  0.4143, -0.0590],
         [ 0.6639,  0.7768, -0.1967],
         [-1.6091,  0.3850, -0.5262]],

        [[ 0.7747,  0.3806, -1.4528],
         [-0.7646,  0.9479,  0.5471],
         [-0.8111, -0.6761,  0.3001]]], requires_grad=True)
Shape of x1: torch.Size([2, 3, 3])
Time values t: tensor([0.8670, 0.8481], requires_grad=True)
Shape of t: torch.Size([2])
Flattened x0: tensor([[-1.6438, -0.2921,  0.8579, -0.5794,  0.5201, -1.2595,  0.2056, -0.1793,
         -0.7726],
        [ 1.6532,  0.4226,  0.7754, -1.5206,  1.0661, -1.9345, -0.2122, -0.2998,
         -0.7949]], grad_fn=<ViewBackward0>)
Shape of flattened x0: torch.Size([2, 

  x1_dot = vmap(index_time_derivative, in_dims=1)(identity_matrix)
