In [1]:
import jax
from imax.project import cam2pixel, meshgrid, nearest_sampler, pixel2cam
from jax import numpy as jnp
from matplotlib import pyplot as plt

In [46]:
def log_prob(mu, scale_diag, x, normalize=False):
    log_unnormalized = -0.5 * (x / scale_diag - mu / scale_diag) ** 2
    if not normalize:
        return log_unnormalized#.sum()
    log_normalization = (0.5 * jnp.log(2.0 * jnp.pi)) + jnp.log(scale_diag)
    return (log_unnormalized - log_normalization)#.sum()


def downcast_safe(x, dtype, margin=0):
    x = jnp.clip(x, jnp.finfo(dtype).min + margin, jnp.finfo(dtype).max - margin)
    return x.astype(dtype)


def generate_2D_gaussian_splatting(
    sigma_x,
    sigma_y,
    # sigma_z,
    rho,
    coords,
    colors,
    image_size,
    channels,
):
    dtype = rho.dtype
    
    rho = rho[..., None]
    sigma = jnp.concatenate([sigma_x[..., None], sigma_y[..., None]], axis=1)

#     batch_size = colors.shape[0]
#     kernel_size = max(image_size)
#     sigma_x = sigma_x.reshape((batch_size, 1, 1))
#     sigma_y = sigma_y.reshape((batch_size, 1, 1))
#     # sigma_z = sigma_z.reshape((batch_size, 1, 1))
    
    
    
#     rho = rho.reshape((batch_size, 1, 1))

#     covariance = jnp.stack(
#         [
#             jnp.stack([sigma_x**2 + 1e-6, rho * sigma_x * sigma_y], axis=-1),
#             jnp.stack([rho * sigma_x * sigma_y, sigma_y**2 + 1e-6], axis=-1),
#         ],
#         axis=-2,
#     ).astype("float32")

    # Check for positive semi-definiteness
    # determinant = (sigma_x**2) * (sigma_y**2) - (rho * sigma_x * sigma_y) ** 2
    # determinant = jax.lax.clamp(1e-6, determinant, jnp.inf)

#     inv_covariance = jnp.linalg.inv(covariance)

#     ax_batch = jnp.linspace(-1.0, 1.0, num=kernel_size, dtype="float32")[None, ...]

#     # Expanding dims for broadcasting
#     ax_batch_expanded_x = jnp.tile(ax_batch[..., None], (1, 1, kernel_size))
#     ax_batch_expanded_y = jnp.tile(ax_batch[:, None, ...], (1, kernel_size, 1))

#     # Creating a batch-wise meshgrid using broadcasting
#     xx, yy = ax_batch_expanded_x, ax_batch_expanded_y

#     xy = jnp.stack([xx, yy], axis=-1)
#     xy = xy + coords[:, None, None, :]
#     print(xy.shape)
    
#     z = jnp.einsum("b...i,b...ij,b...j->b...", xy, -0.5 * inv_covariance, xy)

#     z = jnp.clip(
#         z, a_max=jnp.log(jnp.finfo(dtype).max)
#     )  # mitigate overflow in fp16 mode

#     _, covariance_log_det = jnp.linalg.slogdet(covariance)

#     kernel = jnp.exp(
#         z - jnp.log(2 * jnp.pi) + 0.5 * covariance_log_det.reshape((batch_size, 1, 1))
#     )

#     kernel_max = kernel.max(axis=[-1, -2], keepdims=True)
#     kernel_max = jnp.where(kernel_max == 0, jnp.ones_like(kernel_max), kernel_max)
#     kernel_normalized = downcast_safe(kernel / kernel_max, dtype)

#     kernel_reshaped = jnp.reshape(
#         jnp.tile(kernel_normalized, (1, channels, 1)),
#         (batch_size * channels, kernel_size, kernel_size),
#     )
#     kernel_rgb = kernel_reshaped.reshape(batch_size, channels, kernel_size, kernel_size)


    
    pixel_coords = meshgrid(image_size[0], image_size[1], dtype=dtype, is_homogeneous=False)
    pixel_coords = pixel_coords.reshape((2, -1)).T
    print(pixel_coords.shape)
    
    
    kernel = jax.vmap(
        lambda mu, sig: jax.vmap(
            lambda x: log_prob(mu, sig, x)
        )(pixel_coords)
    )(rho, sigma)
    
    kernel= kernel.reshape((-1, image_size[0], image_size[1], 2))
    
    print(kernel.shape)

#     # Calculating the padding needed to match the image size
#     pad_h = int(image_size[0]) - kernel_size
#     pad_w = int(image_size[1]) - kernel_size

#     if pad_h < 0 or pad_w < 0:
#         raise ValueError("Kernel size should be smaller or equal to the image size.")

#     # Adding padding to make kernel size equal to the image size
#     padding = (
#         (0, 0),
#         (0, 0),
#         (pad_w // 2, pad_w // 2 + pad_w % 2),  # padding top and bottom
#         (pad_h // 2, pad_h // 2 + pad_h % 2),  # padding left and right
#     )

#     kernel_rgb_padded = jnp.pad(kernel_rgb, padding, "constant")
#     kernel_rgb_padded = jnp.transpose(kernel_rgb_padded, (0, 2, 3, 1))

    kernel = kernel.reshape((1, image_size[0], image_size[1], -1))
    #kernel = jnp.transpose(kernel, (3, 1, 2, 0))

    # Extracting shape information
    b, h, w, c = kernel.shape

    # Create a batch of 2D affine matrices
    intrinsics_0 = jnp.eye(2)
    intrinsics_0 = jnp.concatenate(
        [intrinsics_0, jnp.array([[(w - 1) / 2], [(h - 1) / 2]]), jnp.zeros((2, 1))],
        axis=1,
    )
    intrinsics_1 = jnp.array([[0, 0, 1, 0], [0, 0, 0, 1]])
    intrinsics = jnp.concatenate([intrinsics_0, intrinsics_1], axis=0)
    # intrinsics = jnp.tile(intrinsics[None, ...], (b,1,1))

    depth = jnp.ones(shape=(h, w))

    # Creating grid and performing grid sampling
    pixel_coords = meshgrid(h, w, dtype=dtype)  # "float32")
    # Convert pixel coordinates to the camera frame
    cam_coords = pixel2cam(depth, pixel_coords, intrinsics[:3, :3])

    src_pixel_coords = cam2pixel(cam_coords, intrinsics)
    
    print(src_pixel_coords.shape)

    if channels == 4:
        mask_value = jnp.array([0.0, 0.0, 0.0, 0.0])
    elif channels == 3:
        mask_value = jnp.array([0.0, 0.0, 0.0])
    else:
        raise NotImplementedError

    kernel_rgb_padded_translated = jax.vmap(
        lambda rgb: nearest_sampler(rgb, src_pixel_coords, mask_value=mask_value)
    )(kernel)

    rgb_values_reshaped = colors[..., None, None, :]

    final_image_layers = rgb_values_reshaped * kernel_rgb_padded_translated

    if channels == 4:
        final_image_layers = jnp.concatenate(
            [
                final_image_layers[..., :3] * final_image_layers[..., :1],
                final_image_layers[..., :1],
            ],
            axis=-1,
        )

    final_image = final_image_layers.sum(axis=0)
    # final_image = jax.lax.clamp(0.0, final_image, max=1.0)
    return final_image

In [47]:
rho = jnp.array([0.0, 0.0, -0.5])
sigma_x = jnp.array([0.4, 0.1, 0.1])
sigma_y = jnp.array([0.4, 0.1, 0.3])
vectors = jnp.array([(-0.5, -0.5), (0.8, 0.8), (0.5, 0.5)])
# colors = jnp.array([(1.0, 0.0, 0.0, 1.0), (0.0, 1.0, 0.0, 1.0), (0.0, 0.0, 1.0, 1.0)])
colors = jnp.array([(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)])
img_size = (256, 256)
channels = 3

final_image = generate_2D_gaussian_splatting(
    sigma_x, sigma_y, rho, vectors, colors, img_size, channels
)

plt.imshow(final_image)
plt.axis("off")
plt.tight_layout()
plt.show()

(65536, 2)
(3, 256, 256, 2)
(256, 256, 2)


TypeError: mul got incompatible shapes for broadcasting: (256, 256, 6), (1, 1, 3).