# In [None]:
import jax
import jax.numpy as jnp
from PIL import Image
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn
from ott.problems.linear import linear_problem
import matplotlib.pyplot as plt

EPS_SMALL = 1e-10


# ---------- Utility Functions ----------

def load_and_normalize_image(path_or_array, size=(64, 64)):
    """Load an image from a path or an array as a float32 ``(H, W, 3)`` array.

    Args:
        path_or_array: a filesystem path (``str``) or an array-like RGB image of
            shape ``(h, w, 3)``. Integer arrays are treated as 8-bit (0-255) and
            scaled to [0, 1] like files are; float arrays are used as-is.
        size: target ``(H, W)`` - height first, the same convention that
            ``morph_sequence`` uses (``H, W = size``). Any 2-sequence works.

    Returns:
        ``jnp.ndarray`` of shape ``(H, W, 3)``. No global normalization is
        applied, so relative brightness between images is preserved.
    """
    height, width = size
    if isinstance(path_or_array, str):
        # PIL's resize expects (width, height): the reverse of our (H, W) order.
        # The context manager releases the underlying file handle.
        with Image.open(path_or_array) as src:
            rgb = src.convert("RGB").resize((width, height), Image.Resampling.LANCZOS)
        return jnp.array(rgb, dtype=jnp.float32) / 255.0

    img = jnp.asarray(path_or_array)
    if jnp.issubdtype(img.dtype, jnp.integer):
        # Match the [0, 1] range produced by the file-loading branch.
        img = img.astype(jnp.float32) / 255.0
    if img.shape != (height, width, 3):
        img = jax.image.resize(img, (height, width, 3), method="linear")
    return img  # No global normalization - preserve brightness


@jax.jit
def compute_ot_plan(a, b, pos_a, pos_b, epsilon=0.01):
    """Solve an entropic OT problem between two weighted point clouds.

    Args:
        a: weights attached to the points in ``pos_a``.
        b: weights attached to the points in ``pos_b``.
        pos_a: source point coordinates.
        pos_b: target point coordinates.
        epsilon: entropic regularization passed straight to ``PointCloud``.
            It is not marked static, so under ``jax.jit`` it is traced.

    Returns:
        The transport plan (``.matrix`` of the Sinkhorn output). Per OTT's
        convention, rows index ``pos_a`` and columns index ``pos_b``.
        NOTE(review): this is a balanced problem, so ``a`` and ``b`` are
        expected to carry equal total mass - confirm callers ensure that.
    """
    problem = linear_problem.LinearProblem(
        pointcloud.PointCloud(pos_a, pos_b, epsilon=epsilon),
        a=a,
        b=b,
    )
    # Log-sum-exp mode keeps the iterations stable for small epsilon.
    result = sinkhorn.Sinkhorn(lse_mode=True, max_iterations=6000)(problem)
    return result.matrix


def interpolate_transport(P, pos_a, pos_b, t, shape):
    """Render the displacement interpolation of per-channel transport plans.

    Each entry ``P[c, i, j]`` is a packet of mass moving in a straight line from
    ``pos_a[i]`` to ``pos_b[j]``. At time ``t`` the packet sits at
    ``(1 - t) * pos_a[i] + t * pos_b[j]``. It is splatted onto the four
    surrounding pixels with bilinear weights, so ``t = 0`` shows the mass at its
    source positions and ``t = 1`` at its target positions.

    Args:
        P: plans of shape ``(C, N, M)``; rows index ``pos_a``, columns ``pos_b``.
        pos_a: ``(N, 2)`` source coordinates as ``(x, y)`` in pixel units.
        pos_b: ``(M, 2)`` target coordinates as ``(x, y)`` in pixel units.
        t: interpolation time, 0 = source and 1 = target.
        shape: output ``(H, W)``.

    Returns:
        Array of shape ``(H, W, C)``. Corners falling outside the canvas are
        clamped onto the border, so the total mass of each channel is preserved.
    """
    H, W = shape

    # Rows of P index pos_a and columns index pos_b, so lay the pair grid out
    # the same way: interp[i, j] is the position of packet P[:, i, j].
    interp = (1 - t) * pos_a[:, None, :] + t * pos_b[None, :, :]  # (N, M, 2)
    x, y = interp[..., 0], interp[..., 1]

    x0 = jnp.floor(x).astype(jnp.int32)
    y0 = jnp.floor(y).astype(jnp.int32)
    wx = x - x0
    wy = y - y0

    # The geometry is identical for every channel. Compute the four bilinear
    # corners (clamped pixel indices + weight) once and share them, instead of
    # broadcasting the positions to every channel.
    corners = []
    for dx, dy, weight in (
        (0, 0, (1 - wx) * (1 - wy)),
        (1, 0, wx * (1 - wy)),
        (0, 1, (1 - wx) * wy),
        (1, 1, wx * wy),
    ):
        xi = jnp.clip(x0 + dx, 0, W - 1)
        yi = jnp.clip(y0 + dy, 0, H - 1)
        corners.append((yi, xi, weight))

    def splat(Pc):
        """Scatter-add one channel's plan onto an (H, W) canvas."""
        canvas = jnp.zeros((H, W), dtype=Pc.dtype)
        for yi, xi, weight in corners:
            canvas = canvas.at[yi, xi].add(Pc * weight)
        return canvas

    out = jax.vmap(splat)(P)  # (C, H, W)
    return out.transpose(1, 2, 0)


# ---------- Main Morph Function ----------

def morph_sequence(images, times, fps=24, size=(64, 64), debug=False):
    """Build frames that morph through ``images`` via per-channel optimal transport.

    Each RGB channel of two consecutive keyframes is treated as a mass
    distribution over the pixel grid. An entropic OT plan is solved per channel
    (``compute_ot_plan``), and in-between frames are rendered by displacement
    interpolation (``interpolate_transport``).

    Args:
        images: image paths or arrays, one per keyframe.
        times: strictly increasing keyframe timestamps in seconds.
        fps: frame rate used to decide how many frames each segment gets.
        size: working resolution ``(H, W)``.
        debug: if True, plot the loaded images and print plan/frame statistics.

    Returns:
        List of ``(H, W, 3)`` arrays. Each segment contributes
        ``round((t_end - t_start) * fps)`` frames, starting at its first
        keyframe, and the final keyframe is appended once at the end.
        Returns an empty list when ``images`` is empty.
    """
    assert len(images) == len(times), "Each image must have a corresponding timestamp"
    assert all(t0 < t1 for t0, t1 in zip(times, times[1:])), "Timestamps must be increasing"

    imgs = [load_and_normalize_image(img, size=size) for img in images]
    if not imgs:
        return []

    if debug:
        print("Loaded Images:")
        for i, img in enumerate(imgs):
            plt.imshow(img)
            plt.title(f"Image {i}")
            plt.axis("off")
            plt.show()

    H, W = size
    x, y = jnp.meshgrid(jnp.linspace(0, W - 1, W), jnp.linspace(0, H - 1, H))
    # Row-major (x, y) pixel coordinates, matching img.reshape(-1, 3) ordering.
    positions = jnp.stack([x.ravel(), y.ravel()], axis=1)

    frames = []
    for i, (img_a, img_b) in enumerate(zip(imgs, imgs[1:])):
        t_start, t_end = times[i], times[i + 1]
        # round(), not int(): float products such as 0.29 * 100 fall just short.
        num_frames = int(round((t_end - t_start) * fps))

        # (3, H*W) per-channel pixel masses.
        a = img_a.reshape(-1, 3).T.astype(jnp.float32)
        b = img_b.reshape(-1, 3).T.astype(jnp.float32)

        # Balanced Sinkhorn needs equal total mass on both sides, but channel
        # brightness differs between images. Solve on unit-mass distributions
        # (EPS_SMALL keeps an all-black channel from dividing by zero) and
        # restore absolute brightness per frame below.
        mass_a = a.sum(axis=1)
        mass_b = b.sum(axis=1)
        a_prob = (a + EPS_SMALL) / (a + EPS_SMALL).sum(axis=1, keepdims=True)
        b_prob = (b + EPS_SMALL) / (b + EPS_SMALL).sum(axis=1, keepdims=True)

        P = jax.vmap(lambda a_ch, b_ch:
                     compute_ot_plan(a_ch, b_ch, positions, positions))(a_prob, b_prob)

        if debug:
            print(f"\nOT Plan stats (Frame {i}):", P.shape)
            print("P[0] min/max:", P[0].min(), P[0].max(), "sum:", P[0].sum())

        for f in range(num_frames):
            alpha = f / num_frames
            # t = 0 renders the source image, t -> 1 the target image.
            canvas = interpolate_transport(P, positions, positions, alpha, size)
            # Blend each channel's total brightness linearly between keyframes.
            img_interp = canvas * ((1 - alpha) * mass_a + alpha * mass_b)

            if debug and f == 0:
                print("Interpolated Frame stats:", img_interp.min(), img_interp.max())

            frames.append(jnp.clip(img_interp, 0, 1))

    # Intermediate keyframes are already rendered as the alpha = 0 frame of the
    # following segment; only the final keyframe still needs appending.
    frames.append(imgs[-1])
    return frames


# ---------- Usage ----------

images = ["./image1.jpg", "./image2.jpg"]
times = [0.0, 1.0]
# Frames are generated AND played back at this rate. A mismatch would stretch
# or squash the morph relative to the timestamps in `times`.
FPS = 40

frames = morph_sequence(images, times, fps=FPS, debug=True)

import matplotlib.animation as animation

fig, ax = plt.subplots()
ax.axis('off')

# Initialize with the first frame
im = ax.imshow(frames[0])


def update(frame):
    """FuncAnimation callback: show ``frame`` in the shared image artist."""
    im.set_array(frame)
    return [im]


# `interval` (ms between frames) governs live playback; `fps` below governs the GIF.
ani = animation.FuncAnimation(fig, update, frames=frames, interval=1000 / FPS, blit=True)

# Save as GIF (requires pillow)
ani.save('morph.gif', writer='pillow', fps=FPS)
plt.close(fig)
