In [None]:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import numpy as np

### Creating a Test Optical Flow Pattern

In [None]:
# Motion of the Camera
# T multiplies the A matrices and Ω multiplies the B matrices in the
# motion-field cell (v = (1 / Z) * (A @ T) + B @ Ω).
T_unorm = jnp.array([3, 0.3, 2])
T = T_unorm  # / jnp.linalg.norm(T_unorm)
Ω = jnp.array([0, 0.9, 0])

# Intrinsic Parameters
# res is (rows, cols): the later cells iterate y over res[0] and x over res[1],
# and reshape the flow to (res[0], res[1], 2).
res = (100, 100)
f = 1

# Depth Value at each pixel
min_depth = 5  # Example depth ranges
max_depth = 10
# One depth sample per pixel, drawn uniformly from [min_depth, max_depth) with a
# fixed PRNG key so the pattern is reproducible across runs.
Z = jax.random.uniform(
    jax.random.PRNGKey(0), (res[0] * res[1], 1), minval=min_depth, maxval=max_depth
)
# Principal point at the image centre. The x offset must come from the column
# count (res[1]) and the y offset from the row count (res[0]); the two were
# swapped before, which only went unnoticed because res is square.
K = jnp.array([[f, 0, res[1] / 2], [0, f, res[0] / 2], [0, 0, 1]])
K_inv = jnp.linalg.inv(K)

In [None]:
### Calculate A Matrix
# A holds one 2x3 block per pixel, [[-f, 0, x], [0, -f, y]], where (x, y) are
# the pixel coordinates mapped through K_inv. It is built in a single
# vectorised pass rather than a per-pixel Python loop, which dispatched
# thousands of tiny JAX ops.

# Pixel grid in row-major order (y outer, x inner) so that row k of A refers to
# the same pixel as before and matches the later reshape to (res[0], res[1], 2).
ys, xs = jnp.meshgrid(jnp.arange(res[0]), jnp.arange(res[1]), indexing="ij")
pix_h = jnp.stack(
    [xs.ravel(), ys.ravel(), jnp.ones(xs.size, dtype=xs.dtype)], axis=-1
)

# Row-vector form of `K_inv @ [x, y, 1]`, applied to every pixel at once.
norm_cord = pix_h @ K_inv.T
x_norm = norm_cord[:, 0]
y_norm = norm_cord[:, 1]

f_col = jnp.full_like(x_norm, f)
zero_col = jnp.zeros_like(x_norm)
# Stack the two rows of each block along axis 1 -> shape (res[0] * res[1], 2, 3).
A = jnp.stack(
    [
        jnp.stack([-f_col, zero_col, x_norm], axis=-1),
        jnp.stack([zero_col, -f_col, y_norm], axis=-1),
    ],
    axis=1,
)

In [None]:
### Calculate B Matrix
# B holds one 2x3 block per pixel,
#   [[ x*y/f,      -(f + x^2/f),  y],
#    [ f + y^2/f,  -x*y/f,       -x]],
# where (x, y) are the pixel coordinates mapped through K_inv.
# The top-left entry previously omitted the `/ f` that its counterpart in the
# second row has; the result is unchanged for f == 1 but was wrong otherwise.
# It is built in one vectorised pass rather than a per-pixel Python loop.

# Pixel grid in row-major order (y outer, x inner) so that row k of B refers to
# the same pixel as before and matches the later reshape to (res[0], res[1], 2).
ys, xs = jnp.meshgrid(jnp.arange(res[0]), jnp.arange(res[1]), indexing="ij")
pix_h = jnp.stack(
    [xs.ravel(), ys.ravel(), jnp.ones(xs.size, dtype=xs.dtype)], axis=-1
)

# Row-vector form of `K_inv @ [x, y, 1]`, applied to every pixel at once.
norm_cord = pix_h @ K_inv.T
x_norm = norm_cord[:, 0]
y_norm = norm_cord[:, 1]

xy_over_f = (x_norm * y_norm) / f
# Stack the two rows of each block along axis 1 -> shape (res[0] * res[1], 2, 3).
B = jnp.stack(
    [
        jnp.stack([xy_over_f, -(f + (x_norm**2) / f), y_norm], axis=-1),
        jnp.stack([f + (y_norm**2) / f, -xy_over_f, -x_norm], axis=-1),
    ],
    axis=1,
)

In [None]:
### Calculate the motion field
# Per-pixel velocity: a translational term scaled by inverse depth plus a
# rotational term that does not involve depth.
inv_depth = 1 / Z
v = inv_depth * (A @ T) + B @ Ω
# Fold the flat per-pixel list back into an image-shaped (rows, cols, 2) grid.
flow = jnp.reshape(v, (res[0], res[1], 2))

In [None]:
# Draw one arrow every `spacing` pixels so individual vectors stay legible.
spacing = 4
xval = np.arange(0, flow.shape[1], spacing)
yval = np.arange(0, flow.shape[0], spacing)
xx, yy = np.meshgrid(xval, yval)

# Sub-sample the flow once and hand matplotlib a plain NumPy array.
flow_sub = np.asarray(flow[::spacing, ::spacing])

# Explicit Axes instead of the pyplot state machine, with a title and axis
# labels so the figure stands on its own.
fig, ax = plt.subplots()
ax.quiver(
    xx,
    yy,
    flow_sub[..., 0],
    flow_sub[..., 1],
    scale=1,
    scale_units="xy",
    angles="xy",
    units="xy",
)
# Trailing semicolon keeps the return value out of the cell output.
ax.set(
    title=f"Synthetic motion field (sampled every {spacing} px)",
    xlabel="x (pixel column)",
    ylabel="y (pixel row)",
);

In [None]:
# Show the flat motion-field array (one row per pixel, before it is reshaped
# into `flow`) as this cell's output.
v