# YHLee - https://gist.github.com/yhlee-add/00f7aae96e547427fb857e24e9cde1d4

In [None]:
import numpy as np
from scipy.spatial.transform import Rotation as R

def spherical_grid(size: int, theta: float, phi: float, fov: float = 90.0):
    """
    Build the unit-sphere directions of a square perspective view.

    A ``size`` x ``size`` grid of cell centres is placed on the plane x = 1,
    spanning ``tan(fov / 2)`` along both in-plane axes, normalised to unit
    length and rotated by the intrinsic 'ZY' Euler angles ``(-theta, -phi)``
    (degrees).

    :return: array of shape (size, size, 3) holding xyz unit vectors.
    """
    half_extent = np.tan(np.radians(fov / 2))
    # Cell centres run from positive to negative along each in-plane axis.
    axis = np.linspace((size - 1) / size, (1 - size) / size, size) * half_extent
    horiz, vert = np.meshgrid(axis, axis, indexing='xy')

    rays = np.stack([np.ones_like(horiz), horiz, vert], axis=-1)
    rays = rays / np.linalg.norm(rays, axis=-1, keepdims=True)

    rotation = R.from_euler('ZY', [-theta, -phi], degrees=True)
    rotated = rotation.apply(rays.reshape(-1, rays.shape[-1]))
    return rotated.reshape((size, size, 3))

def erp_grid(pos, erp_h: int = 512):
    """
    Compute position on erp image from xyz coordinates.

    :param pos: array-like whose last axis holds xyz components of
        (approximately) unit vectors.
    :param erp_h: scale applied to the normalised coordinates; rows span
        ``[0, erp_h]`` and columns span ``[0, 2 * erp_h)``.
    :return: array with the same leading shape as ``pos`` and a last axis of
        ``(row, column)``.
    """
    pos = np.asarray(pos, dtype=float)
    # Normalisation round-off can leave |z| marginally above 1, which would
    # make arcsin return NaN; clamp into its valid domain first.
    lat = np.arcsin(np.clip(pos[..., 2], -1.0, 1.0))
    lon = np.arctan2(pos[..., 1], pos[..., 0])
    erp_pos = np.stack([0.5 - lat / np.pi, 1 - lon / np.pi], axis=-1) * erp_h
    return erp_pos

def erp_to_xyz(points, erp_h: int = 512):
    """
    Map (row, column) positions on the erp image back to xyz unit vectors.

    This is the inverse of ``erp_grid``: the row encodes latitude and the
    column encodes longitude, both scaled by ``erp_h``.
    """
    row = points[..., 0]
    col = points[..., 1]

    latitude = np.pi * (0.5 - row / erp_h)
    longitude = np.pi * (1 - col / erp_h)

    cos_lat = np.cos(latitude)
    return np.stack(
        [cos_lat * np.cos(longitude), cos_lat * np.sin(longitude), np.sin(latitude)],
        axis=-1,
    )

def xyz_to_normal_pers(pos, theta: float, phi: float, fov: float = 90.0):
    """
    Project xyz positions into the normalised image plane of one perspective view.

    Intended for use with F.grid_sample: the returned coordinates are divided
    by ``tan(fov / 2)`` and sign-flipped to match grid_sample's axis direction.

    :param pos: array of shape (h, w, 3).
    :return: ``(coords, mask)`` where ``coords`` has shape (h, w, 2) and
        ``mask`` has shape (h, w); the mask is True where the point lies in
        front of the view and both coordinates fall strictly inside (-1, 1).
    """
    rows, cols, channels = pos.shape
    assert channels == 3

    # Undo the view rotation so that the view looks along +x.
    view_rotation = R.from_euler('ZY', [-theta, -phi], degrees=True)
    local = view_rotation.inv().apply(pos.reshape(-1, 3))

    in_front = local[:, 0] > 1e-8

    # minus sign for matching the axis direction of F.grid_sample
    local = local / -local[:, :1]
    local = local / np.tan(np.radians(fov / 2))

    plane = local[:, 1:]
    inside = np.all((plane > -1) & (plane < 1), axis=-1)

    return plane.reshape((rows, cols, 2)), (in_front & inside).reshape((rows, cols))


# Hand-written (theta, phi) view directions in degrees; a later cell rebuilds
# `directions` from angle ranges, replacing this list.
directions = [
    (0, 0), (45, 0), (22.5, 30), (22.5, -30)
]

In [None]:
# Angular ranges (degrees) swept by the generated view directions.
theta_range = (0, 360)
phi_range = (-90, 90)

# Number of phi rings and how many theta samples each ring gets.
num_phi = 4
num_theta = [3, 6, 6, 3]
assert num_phi == len(num_theta)

directions = []
phis = np.linspace(*phi_range, num_phi, endpoint=True)
for ring_phi, ring_size in zip(phis, num_theta):
    # theta wraps around, so the end of the range is excluded.
    ring = [(theta, ring_phi) for theta in np.linspace(*theta_range, ring_size, endpoint=False)]
    directions.extend(ring)
    print(*ring)

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
logging.set_verbosity_error()

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as T

import matplotlib.pyplot as plt

In [None]:
# Device handed to torch wherever the cells below pass `device=device`.
device = "cuda"

In [None]:
# Height of the equirectangular noise latent; the cells below allocate it
# with a width of twice this value.
latent_erp_resolution = 512  # for hiwyn

In [None]:
import random
from tqdm.auto import tqdm
from scipy.spatial import KDTree

# stable diffusion RGB output size: side length of each square perspective
# view whose pixels are matched across directions
grid_size = 512

class ChoiceSet:
    """
    Set of hashable items supporting add, remove, membership test and
    seeded uniform random choice.

    Items live in a list (for random choice) alongside a dict mapping each
    item to its list position (for lookup). Removal swaps the last list
    element into the vacated slot so no shifting is needed.
    """

    def __init__(self, seq = None, seed = 42):
        """
        :param seq: optional iterable of initial items; duplicates are dropped
            (first occurrence wins) so the list and the position map can
            never disagree.
        :param seed: seed for the private random generator used by `choice`.
        """
        # dict.fromkeys de-duplicates while preserving first-seen order.
        self.items = [] if seq is None else list(dict.fromkeys(seq))
        self.item_to_pos = {item: i for i, item in enumerate(self.items)}
        self.rng = random.Random(seed)

    def add(self, item):
        """Insert `item`; adding an item already present is a no-op."""
        if item in self.item_to_pos:
            return
        self.item_to_pos[item] = len(self.items)
        self.items.append(item)

    def remove(self, item):
        """Remove `item`; raises KeyError when it is not present."""
        pos = self.item_to_pos.pop(item)
        last_item = self.items.pop()
        # If the removed item was not the last one, move the former last
        # element into its slot and update that element's recorded position.
        if pos != len(self.items):
            self.items[pos] = last_item
            self.item_to_pos[last_item] = pos

    def choice(self):
        """Return a random member; raises IndexError when the set is empty."""
        return self.rng.choice(self.items)

    def __len__(self):
        return len(self.items)

    def __contains__(self, item):
        return item in self.item_to_pos


def generate_points(size: int, view_directions=None):
    """
    Stack the sphere points of every view and index them with a KDTree.

    :param size: side length of each view's square grid.
    :param view_directions: iterable of (theta, phi) tuples; when omitted the
        module-level `directions` list is used, as before.
    :return: ``(points, tree)`` where ``points`` has shape
        (num_views * size * size, 3), ordered view by view, and ``tree`` is a
        KDTree built over those points.
    """
    if view_directions is None:
        view_directions = directions
    points = np.concatenate(
        [spherical_grid(size, *direction).reshape(-1, 3) for direction in view_directions]
    )
    return points, KDTree(points)

def do_matching():
    """
    Greedily merge nearby sphere points coming from different views.

    A random unassigned point is drawn repeatedly; every still-unassigned
    point within the distance threshold joins its group, nearest first and
    at most one point per view. A KDTree supplies the neighbour queries.

    :return: ``(result_points, assignment)`` where ``result_points[g]`` lists
        the xyz positions merged into group ``g`` and ``assignment`` maps each
        point index to its group index.
    """
    points, point_tree = generate_points(grid_size)
    cells_per_view = grid_size * grid_size

    # heuristic to make point order not swapped
    radius = np.pi / (grid_size * 4)

    unassigned = ChoiceSet(range(len(directions) * cells_per_view))
    result_points = []
    assignment = {}

    with tqdm(total=len(unassigned)) as pbar:
        while len(unassigned) > 0:
            seed_pos = points[unassigned.choice()]  # shape: [3]

            candidates = point_tree.query_ball_point(seed_pos, radius)
            offsets = points[candidates] - seed_pos  # shape: [len(candidates), 3]
            sq_dists = np.square(offsets).sum(axis=-1)  # shape: [len(candidates)]

            group_id = len(result_points)
            group = []
            views_taken = set()

            for _, idx in sorted(zip(sq_dists, candidates)):
                view = idx // cells_per_view  # which direction grid the point is from
                if view in views_taken or idx not in unassigned:
                    continue
                views_taken.add(view)
                unassigned.remove(idx)
                group.append(points[idx])
                assignment[idx] = group_id

            result_points.append(group)
            pbar.update(len(group))

    # sanity check: all points must be consumed
    assert len(assignment) == len(points)
    assert sum(len(l) for l in result_points) == len(points)

    return result_points, assignment

# Merged point groups and the point-index -> group-index map for the current `directions`.
omni_points, omni_assign = do_matching()

In [None]:
# How many view pixels were merged into each group, and the observed range.
merge_count = [len(ps) for ps in omni_points]
fewest, most = min(merge_count), max(merge_count)

# The first member of every group marks where that group sits on the sphere.
anchors = np.array([ps[0] for ps in omni_points])

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(111, projection='3d')

# One discrete colour per possible merge count.
cmap = plt.get_cmap('viridis', most - fewest + 1)

sc = ax.scatter(
    anchors[:, 0],
    anchors[:, 1],
    anchors[:, 2],
    c=merge_count,
    cmap=cmap,
    vmin=fewest - 0.5,
    vmax=most + 0.5,
    edgecolor=None,
    s=0.01,
)

for set_label, set_lim, label in (
    (ax.set_xlabel, ax.set_xlim, 'X'),
    (ax.set_ylabel, ax.set_ylim, 'Y'),
    (ax.set_zlabel, ax.set_zlim, 'Z'),
):
    set_label(label)
    set_lim(-1, 1)
ax.set_box_aspect((1, 1, 1))

# Add colorbar
fig.colorbar(
    sc,
    ax=ax,
    label='Number of merged points',
    shrink=0.5,
    ticks=np.arange(fewest, most + 1)
)

plt.show()

In [None]:
class OmniDiffusion(nn.Module):
    """
    Thin wrapper bundling the Stable Diffusion components used by this notebook.

    Largely copied from
    https://github.com/omerbt/MultiDiffusion/blob/master/panorama.py
    """

    def __init__(self, device):
        """Load VAE, tokenizer, text encoder, UNet and DDIM scheduler onto `device`."""
        super().__init__()

        self.device = device
        model_key = "stabilityai/stable-diffusion-2-base"

        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        """
        Encode prompts for classifier-free guidance.

        :param prompt: list of prompt strings.
        :param negative_prompt: list of negative-prompt strings.
        :return: tensor with the negative-prompt embeddings stacked before
            the prompt embeddings along the first axis.
        """
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings. Truncate here as well so an
        # over-long negative prompt is cut to the same length as the prompt
        # instead of producing a longer token sequence.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      truncation=True, return_tensors='pt')

        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode latents with the VAE and map the output from [-1, 1] to [0, 1], clamped."""
        # 0.18215 is the latent scaling constant used throughout this notebook.
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

# Build the pipeline on the notebook-wide `device` rather than a second hard-coded 'cuda'.
net = OmniDiffusion(device)

In [None]:
def cell_number_grid(size: int, theta: float, phi: float, fov: float = 90.0, latent_size: int = 64):
    """
    For every erp cell, find the perspective-view cell it falls into.

    :param size: erp height; the erp grid is ``size`` x ``2 * size`` and is
        sampled at cell centres.
    :param theta: view angle in degrees, forwarded to `xyz_to_normal_pers`.
    :param phi: view angle in degrees, forwarded to `xyz_to_normal_pers`.
    :param fov: field of view in degrees, forwarded to `xyz_to_normal_pers`.
    :param latent_size: side length of the square perspective cell grid
        (defaults to the previously hard-coded 64).
    :return: masked integer array of shape (size, 2 * size) holding the
        row-major cell number ``y * latent_size + x``; entries outside the
        view are masked.
    """
    rows, cols = np.meshgrid(np.arange(size) + 0.5, np.arange(size * 2) + 0.5, indexing='ij')
    points_xyz = erp_to_xyz(np.stack([rows, cols], axis=-1), erp_h=size)
    points_grid, mask = xyz_to_normal_pers(points_xyz, theta, phi, fov)

    # map [-1, 1] onto the latent_size by latent_size grid
    scaled = (points_grid - (-1)) / (2 / latent_size)
    # Points outside the view may be inf/NaN; make them finite before the
    # integer cast (they are masked out below regardless).
    finite = np.nan_to_num(np.floor(scaled))
    indices = np.clip(finite, 0, latent_size - 1).astype(int)
    indices_x, indices_y = indices[..., 0], indices[..., 1]
    linear_indices = indices_y * latent_size + indices_x

    return np.ma.array(linear_indices, mask=~mask)

In [None]:
def project_erp_to_pers(grid, latent_erp, requires_grad=False):
    """
    Accumulate erp latent cells into a 64 x 64 perspective latent.

    :param grid: masked integer array over the erp grid giving, for every
        unmasked erp cell, the target perspective cell number in [0, 4096).
    :param latent_erp: tensor of shape (batch, channels, H, W) whose last two
        axes match ``grid``.
    :param requires_grad: whether the zero-initialised accumulator requires grad.
    :return: tensor of shape (batch, channels, 64, 64) on ``latent_erp``'s device.
    """
    side = 64
    n_cells = side * side
    batch, channels = latent_erp.shape[:2]
    # Follow the input's device instead of relying on a module-level global.
    target = latent_erp.device

    index = torch.as_tensor(grid.compressed(), dtype=torch.long, device=target)

    # getmaskarray always yields a full boolean array, even when nothing is masked.
    rows, cols = np.where(~np.ma.getmaskarray(grid))
    source = latent_erp[:, :, torch.as_tensor(rows, device=target), torch.as_tensor(cols, device=target)]

    latent_pers = torch.zeros(
        (batch, channels, n_cells), dtype=latent_erp.dtype, device=target, requires_grad=requires_grad
    )
    latent_pers = latent_pers.index_add(dim=2, index=index, source=source)

    # Cells that received no erp sample keep a zero sum; clamping their count
    # to 1 avoids the 0 / 0 = NaN the unguarded division would produce.
    count = torch.bincount(index, minlength=n_cells).clamp(min=1)

    # Dividing a sum of `count` samples by sqrt(count) keeps unit variance,
    # assuming the erp entries are independent unit-variance noise.
    latent_pers = latent_pers / torch.sqrt(count.to(latent_pers.dtype))
    return latent_pers.reshape((batch, channels, side, side))

In [None]:
def ddim_pred_orig(scheduler, model_output, timestep, sample):
    """
    Estimate the clean sample from a noisy one (Tweedie's formula).

    Reads ``scheduler.alphas_cumprod[timestep]`` and returns
    ``(sample - sqrt(1 - alpha) * model_output) / sqrt(alpha)``.
    """
    alpha_bar = scheduler.alphas_cumprod[timestep]
    noise_scale = (1 - alpha_bar) ** (0.5)
    signal_scale = alpha_bar ** (0.5)

    denoised = sample - noise_scale * model_output
    return denoised / signal_scale

def ddim_pred_prev(scheduler, sample, timestep, pred_orig):
    """
    Step a noisy sample to the previous timestep given a clean-sample estimate
    (psi of synctweedies paper).

    The noise implied by ``sample`` and ``pred_orig`` at ``timestep`` is
    recovered and re-mixed with ``pred_orig`` using the cumulative alpha of the
    previous timestep; ``scheduler.final_alpha_cumprod`` is used once the
    previous timestep would be negative.
    """
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    prev_timestep = timestep - stride

    if prev_timestep >= 0:
        alpha_bar_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        alpha_bar_prev = scheduler.final_alpha_cumprod

    alpha_bar = scheduler.alphas_cumprod[timestep]
    implied_noise = (sample - alpha_bar ** (0.5) * pred_orig) / (1 - alpha_bar) ** (0.5)

    return alpha_bar_prev ** (0.5) * pred_orig + (1 - alpha_bar_prev) ** (0.5) * implied_noise

In [None]:
def do_encode(x):
    """Encode `x` with the VAE, sample the latent distribution and apply the 0.18215 scaling."""
    posterior = net.vae.encode(x).latent_dist
    return posterior.sample() * 0.18215

def do_decode(z):
    """Undo the 0.18215 latent scaling and decode `z` with the VAE."""
    unscaled = z / 0.18215
    return net.vae.decode(unscaled).sample

# For each view, the merged-group index of every pixel in flattened pixel
# order; view d owns the contiguous block of grid_size * grid_size point ids.
direction_indices = [
    [omni_assign[j] for j in range(d * grid_size * grid_size, (d + 1) * grid_size * grid_size)]
    for d in range(len(directions))
]

def synchronize_views(orig_pred_list):
    """
    Make the per-view clean-sample estimates agree where views overlap.

    Each latent estimate is decoded to RGB, pixels that were matched to the
    same merged point are averaged across views, and every view is re-encoded
    from the averaged values.

    :param orig_pred_list: list of latent estimates, one per view, in the
        same order as the module-level `direction_indices`.
    :return: the same list object, with each entry replaced by the re-encoded
        latent.
    """
    value = torch.zeros((1, 3, len(omni_points)), device=device)
    count = torch.tensor([len(ps) for ps in omni_points], device=device)

    decoded_shapes = []
    for d, orig_pred in enumerate(orig_pred_list):
        x = do_decode(orig_pred)
        # Remember the decoded shape so it does not have to be hard-coded as 512 x 512.
        decoded_shapes.append(x.shape)
        value[:, :, direction_indices[d]] += x.flatten(2)
    value /= count

    # Iterate over the views actually supplied rather than the global
    # `directions`, which may have been redefined with a different length.
    for d in range(len(orig_pred_list)):
        orig_pred_list[d] = do_encode(value[:, :, direction_indices[d]].reshape(decoded_shapes[d]))

    return orig_pred_list

In [None]:
# Seed torch's default and CUDA random generators so sampling is repeatable.
torch.manual_seed(42)
torch.cuda.manual_seed(42)


def prepare_text_embeddings(prompts, negative_prompts=''):
    """
    Turn prompts into text embeddings via `net.get_text_embeds`.

    A bare string for either argument is treated as a one-element list.
    """
    def as_list(value):
        return [value] if isinstance(value, str) else value

    # Prompts -> text embeds (negative prompts first, then prompts)
    return net.get_text_embeds(as_list(prompts), as_list(negative_prompts))

In [None]:
@torch.no_grad()
def foo():
    """
    Generate one image per entry of `directions` with synchronised denoising.

    A shared equirectangular noise latent is projected into every view; at
    each of the 50 DDIM steps every view gets a guided noise prediction, the
    resulting clean-sample estimates are synchronised across views, and each
    view is stepped to the previous timestep.

    :return: list of PIL images, one per view.
    """
    text_embeds = prepare_text_embeddings("A photo of the Dolomites")
    net.scheduler.set_timesteps(50)
    guidance_scale = 7.5

    grids = [cell_number_grid(latent_erp_resolution, *direction) for direction in directions]

    # step 0: shared noise on the equirectangular grid
    latent_erp = torch.randn((1, 4, latent_erp_resolution, latent_erp_resolution * 2), device=device)

    # step 1: initial noise of every perspective view
    latent_pers_list = [project_erp_to_pers(grid, latent_erp) for grid in grids]

    for t in tqdm(net.scheduler.timesteps):

        # step 2: clean-sample estimate of every view
        orig_pred_list = []

        with torch.autocast("cuda"):
            for latent_pers in latent_pers_list:
                # Duplicate the latent so one UNet pass yields both guidance branches.
                unet_out = net.unet(
                    torch.cat([latent_pers, latent_pers]), t, encoder_hidden_states=text_embeds
                )["sample"]

                uncond, cond = unet_out.chunk(2)
                guided = uncond + guidance_scale * (cond - uncond)

                orig_pred_list.append(ddim_pred_orig(net.scheduler, guided, t, latent_pers))

        # sync tweedies
        orig_pred_list = synchronize_views(orig_pred_list)

        # denoise each pers view
        latent_pers_list = [
            ddim_pred_prev(net.scheduler, sample, t, orig_pred)
            for sample, orig_pred in zip(latent_pers_list, orig_pred_list)
        ]

    # decode
    to_pil = T.ToPILImage()
    return [to_pil(net.decode_latents(latent_pers)[0]) for latent_pers in latent_pers_list]


# Run the pipeline; `res` receives the list of PIL images returned by foo().
res = foo()

In [None]:
# Projection for final result

# Height of the final equirectangular RGB image; the buffers below are
# allocated with a width of twice this value.
rgb_erp_resolution = 768

def project_direction(img, direction):
    """
    Resample one perspective image onto the equirectangular grid.

    :param img: float tensor of shape (1, C, H, W) holding the view image.
    :param direction: (theta, phi) of the view in degrees.
    :return: ``(projected, mask)`` where ``projected`` has shape
        (C, rgb_erp_resolution, 2 * rgb_erp_resolution) and is zero outside
        the view, and ``mask`` is the boolean coverage of the view.
    """
    rows, cols = np.meshgrid(
        np.arange(rgb_erp_resolution) + 0.5, np.arange(rgb_erp_resolution * 2) + 0.5, indexing='ij'
    )
    flow, mask = xyz_to_normal_pers(
        erp_to_xyz(np.stack([rows, cols], axis=-1), erp_h=rgb_erp_resolution), *direction
    )

    # Points outside the view can carry inf/NaN coordinates; make them finite
    # so grid_sample cannot emit NaN (those pixels are masked out anyway).
    flow = np.nan_to_num(flow, nan=0.0, posinf=2.0, neginf=-2.0)

    # Work on the image's own device instead of a hard-coded 'cuda'.
    flow = torch.tensor(flow[np.newaxis], device=img.device, dtype=torch.float32)
    mask = torch.tensor(mask, device=img.device)
    sampled = nn.functional.grid_sample(img, flow, mode='bilinear', padding_mode='border', align_corners=False)

    # Select rather than multiply: NaN * 0 would still be NaN.
    projected = torch.where(mask, sampled[0], torch.zeros_like(sampled[0]))
    return projected, mask

def project_direction_all(imgs):
    """
    Blend all view images into one equirectangular image.

    Each image is projected with `project_direction` using the matching entry
    of `directions`; overlapping pixels are averaged and uncovered pixels
    stay zero.

    :param imgs: list of (1, 3, H, W) float tensors, ordered like `directions`.
    :return: tensor of shape (1, 3, rgb_erp_resolution, 2 * rgb_erp_resolution).
    """
    # Accumulate on the images' device; fall back to the previous hard-coded
    # 'cuda' only when there is no image to take the device from.
    target = imgs[0].device if len(imgs) > 0 else torch.device('cuda')
    erp_shape = (rgb_erp_resolution, rgb_erp_resolution * 2)

    count = torch.zeros(erp_shape, dtype=torch.int32, device=target)
    value = torch.zeros((1, 3, *erp_shape), dtype=torch.float32, device=target)

    for img, direction in zip(imgs, directions):
        projected, mask = project_direction(img, direction)
        value += projected
        count += mask

    # Avoid dividing by zero where no view covers a pixel.
    count[count == 0] = 1
    return value / count

# Convert the generated views to tensors, blend them into one equirectangular
# image and turn the result into a PIL image for display.
T.ToPILImage()(
    project_direction_all([T.ToTensor()(img)[None].to(device) for img in res])[0]
)

In [None]:
# prompt: with matplotlib, subplot images. use `res` which is a list of PIL images with length 4

import matplotlib.pyplot as plt

# Show the first four generated views side by side, titled with their direction.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(15, 4.5))

for idx, panel in enumerate(axes.flat):
    panel.imshow(res[idx])
    panel.set_title(str(directions[idx]))

plt.show()

In [None]:
# One figure row per phi ring. Rings differ in size, so the index of a ring's
# first image is the running total of the earlier ring sizes, not
# i * num_theta[i].
ring_start = 0
for i in range(num_phi):
    # squeeze=False keeps `axes` a 2-D array even when a ring has a single view.
    fig, axes = plt.subplots(1, num_theta[i], figsize=(5 * num_theta[i], 5), squeeze=False)
    for j, ax in enumerate(axes.flat):
        ax.imshow(res[ring_start + j])
        ax.set_title(str(directions[ring_start + j]))
    ring_start += num_theta[i]

# Rotating Directions

In [None]:
import numpy as np
import torch
from tqdm.auto import tqdm


# Angular ranges (degrees) for the rotating-direction experiment.
theta_range = (0, 360)
phi_range = (-45, 45)

# Number of phi rings and how many theta samples each ring gets.
num_phi = 4
num_theta = [3, 4, 4, 3]
assert num_phi == len(num_theta)

directions = []
phis = np.linspace(*phi_range, num_phi, endpoint=True)
for ring_phi, ring_size in zip(phis, num_theta):
    # theta wraps around, so the end of the range is excluded.
    ring = [(theta, ring_phi) for theta in np.linspace(*theta_range, ring_size, endpoint=False)]
    directions.extend(ring)
    print(*ring)

In [None]:
def do_matching(directions):
    """
    Greedily merge nearby sphere points coming from the given views.

    A random unassigned point is drawn repeatedly; every still-unassigned
    point within the distance threshold joins its group, nearest first and
    at most one point per view. A KDTree supplies the neighbour queries.

    :param directions: sequence of (theta, phi) view directions in degrees.
    :return: ``(result_points, assignment)`` where ``result_points[g]`` lists
        the xyz positions merged into group ``g`` and ``assignment`` maps each
        point index (view-major, ``grid_size * grid_size`` per view) to its
        group index.
    """
    if len(directions) == 0:
        return [], {}

    # Build the points from the *argument*. `generate_points` reads the
    # module-level `directions`, which made this parameter ineffective.
    points = np.concatenate(
        [spherical_grid(grid_size, *direction).reshape(-1, 3) for direction in directions]
    )
    point_tree = KDTree(points)

    # heuristic to make point order not swapped
    threshold = (np.pi / (grid_size * 4))

    remaining = ChoiceSet(range(len(points)))
    result_points = []
    assignment = {}

    with tqdm(total=len(remaining)) as pbar:
        while remaining:
            t_i = remaining.choice() # target
            t_p = points[t_i] # target position. shape: [3]
            curr = []

            # find all neighbors
            neighbors = point_tree.query_ball_point(t_p, threshold)
            displacement = points[neighbors] - t_p # shape: [len(neighbors), 3]
            dist_square = (displacement * displacement).sum(axis=-1) # shape: [len(neighbors)]

            # to_merge[d] records the point already taken from view d, if any
            to_merge = [None] * len(directions)

            for _, i in sorted(zip(dist_square, neighbors)):
                d = i // (grid_size * grid_size) # get which direction grid it is from
                if i in remaining and to_merge[d] is None:
                    to_merge[d] = i
                    remaining.remove(i)
                    curr.append(points[i])
                    assignment[i] = len(result_points)

            result_points.append(curr)
            pbar.update(len(curr))

    # sanity check: all points must be consumed
    assert len(assignment) == len(points)
    assert sum(len(l) for l in result_points) == len(points)

    return result_points, assignment

def get_rotated_directions(base_directions, timestep, total_timesteps, rotation_per_step):
    """
    Apply horizontal rotation to directions based on the current timestep.

    The rotation is applied to ``theta`` (the angle about the Z axis in
    `spherical_grid`), leaving ``phi`` untouched, and grows by
    ``rotation_per_step`` degrees with every step.

    :param base_directions: List of original (theta, phi) directions in degrees.
    :param timestep: Index of the current step (0 for the first step).
    :param total_timesteps: Total number of steps. Kept for interface
        compatibility; the offset no longer depends on it.
    :param rotation_per_step: Horizontal rotation angle per timestep in degrees.
    :return: List of (theta, phi) directions with theta rotated and wrapped
        into [0, 360).
    """
    # "Per step" means the offset accumulates with the step index; dividing by
    # total_timesteps, as before, capped the whole run at rotation_per_step.
    rotation_offset = rotation_per_step * timestep
    return [
        ((theta + rotation_offset) % 360, phi)
        for theta, phi in base_directions
    ]

In [None]:
@torch.no_grad()
def foo_with_rotation_and_rematching(rotation_per_step=10.0):
    """
    Modified pipeline with horizontal rotation and dynamic rematching per timestep.

    At every DDIM step the view directions are rotated, the cross-view point
    matching is recomputed for them, and the clean-sample estimates are
    synchronised using that fresh matching.

    :param rotation_per_step: Degrees of horizontal rotation per timestep.
    :return: list of PIL images, one per view.
    """
    def _synchronize(orig_preds, points, indices):
        # Same averaging as `synchronize_views`, but driven by the matching
        # passed in. `synchronize_views` reads module-level matching data, so
        # calling it here ignored the per-step rematching.
        value = torch.zeros((1, 3, len(points)), device=device)
        count = torch.tensor([len(ps) for ps in points], device=device)
        decoded_shapes = []
        for d, orig_pred in enumerate(orig_preds):
            x = net.vae.decode(orig_pred / 0.18215).sample
            decoded_shapes.append(x.shape)
            value[:, :, indices[d]] += x.flatten(2)
        value /= count
        return [
            net.vae.encode(value[:, :, indices[d]].reshape(decoded_shapes[d])).latent_dist.sample() * 0.18215
            for d in range(len(orig_preds))
        ]

    text_embeds = prepare_text_embeddings("A photo of the Dolomites")
    net.scheduler.set_timesteps(50)
    guidance_scale = 7.5

    # Initialize ERP latent
    latent_erp = torch.randn((1, 4, latent_erp_resolution, latent_erp_resolution * 2), device=device)

    # Iterate through timesteps
    for i, t in enumerate(tqdm(net.scheduler.timesteps)):
        orig_pred_list = []

        # Calculate rotated directions for the current timestep
        rotated_directions = get_rotated_directions(
            directions, timestep=i, total_timesteps=len(net.scheduler.timesteps), rotation_per_step=rotation_per_step
        )

        # Perform do_matching for the rotated directions
        omni_points, omni_assign = do_matching(rotated_directions)

        # Per-view group indices for this step's matching
        direction_indices = [
            torch.tensor(
                [omni_assign[d * grid_size * grid_size + j] for j in range(grid_size * grid_size)],
                device=device,
            )
            for d in range(len(rotated_directions))
        ]

        # Update grids for the current timestep
        grids = [cell_number_grid(latent_erp_resolution, *d) for d in rotated_directions]

        # NOTE(review): the views are re-projected from `latent_erp` on every
        # step, but `latent_erp` is never updated, so the previous step's
        # denoised latents appear to be discarded. Carrying progress across
        # steps would need a perspective -> ERP back-projection, which this
        # notebook does not define. Confirm intent.
        latent_pers_list = [
            project_erp_to_pers(grid, latent_erp) for grid in grids
        ]

        # Denoising step
        with torch.autocast("cuda"):
            for latent_pers in latent_pers_list:
                latent_model_input = torch.cat([latent_pers, latent_pers])
                noise_pred = net.unet(latent_model_input, t, encoder_hidden_states=text_embeds)["sample"]

                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                orig_pred = ddim_pred_orig(net.scheduler, noise_pred, t, latent_pers)
                orig_pred_list.append(orig_pred)

        # Synchronize views across directions with this step's matching
        orig_pred_list = _synchronize(orig_pred_list, omni_points, direction_indices)

        # Step every view to the previous timestep
        latent_pers_list = [
            ddim_pred_prev(net.scheduler, sample, t, orig_pred)
            for orig_pred, sample in zip(orig_pred_list, latent_pers_list)
        ]

    # Decode final images
    result = [
        T.ToPILImage()(net.decode_latents(latent_pers)[0])
        for latent_pers in latent_pers_list
    ]

    return result


# Run the rotating/rematching pipeline; `res` receives the returned list of PIL images.
res = foo_with_rotation_and_rematching(rotation_per_step=10)  # 10 degrees of horizontal rotation per timestep

In [None]:
@torch.no_grad()
def precompute_matching(rotation_per_step, total_timesteps):
    """
    Run the cross-view matching once per timestep ahead of time.

    :param rotation_per_step: Degrees of horizontal rotation per timestep.
    :param total_timesteps: Total number of timesteps.
    :return: list with one ``(rotated_directions, omni_points, omni_assign)``
        tuple per timestep, in timestep order.
    """
    def match_for(step):
        # Rotate the base directions for this step, then match their points.
        rotated = get_rotated_directions(
            directions, timestep=step, total_timesteps=total_timesteps, rotation_per_step=rotation_per_step
        )
        points, assign = do_matching(rotated)
        return rotated, points, assign

    return [match_for(step) for step in range(total_timesteps)]


@torch.no_grad()
def foo_with_precomputed_matching(precomputed_matches):
    """
    Modified pipeline using precomputed matching results.

    :param precomputed_matches: one ``(rotated_directions, omni_points,
        omni_assign)`` tuple per DDIM step, indexed by step number.
    :return: list of PIL images, one per view.
    """
    def _synchronize(orig_preds, points, indices):
        # Same averaging as `synchronize_views`, but driven by the matching
        # passed in. `synchronize_views` reads module-level matching data, so
        # calling it here ignored the precomputed per-step matching.
        value = torch.zeros((1, 3, len(points)), device=device)
        count = torch.tensor([len(ps) for ps in points], device=device)
        decoded_shapes = []
        for d, orig_pred in enumerate(orig_preds):
            x = net.vae.decode(orig_pred / 0.18215).sample
            decoded_shapes.append(x.shape)
            value[:, :, indices[d]] += x.flatten(2)
        value /= count
        return [
            net.vae.encode(value[:, :, indices[d]].reshape(decoded_shapes[d])).latent_dist.sample() * 0.18215
            for d in range(len(orig_preds))
        ]

    text_embeds = prepare_text_embeddings("A photo of the Dolomites")
    net.scheduler.set_timesteps(50)
    guidance_scale = 7.5

    # Initialize ERP latent
    latent_erp = torch.randn((1, 4, latent_erp_resolution, latent_erp_resolution * 2), device=device)

    # Iterate through timesteps
    for i, t in enumerate(tqdm(net.scheduler.timesteps)):
        orig_pred_list = []

        # Load precomputed matching results for the current timestep
        rotated_directions, omni_points, omni_assign = precomputed_matches[i]

        # Per-view group indices for this step's matching
        direction_indices = [
            torch.tensor(
                [omni_assign[d * grid_size * grid_size + j] for j in range(grid_size * grid_size)],
                device=device,
            )
            for d in range(len(rotated_directions))
        ]

        # Update grids for the current timestep
        grids = [cell_number_grid(latent_erp_resolution, *d) for d in rotated_directions]

        # NOTE(review): the views are re-projected from `latent_erp` on every
        # step, but `latent_erp` is never updated, so the previous step's
        # denoised latents appear to be discarded. Carrying progress across
        # steps would need a perspective -> ERP back-projection, which this
        # notebook does not define. Confirm intent.
        latent_pers_list = [
            project_erp_to_pers(grid, latent_erp) for grid in grids
        ]

        # Denoising step
        with torch.autocast("cuda"):
            for latent_pers in latent_pers_list:
                latent_model_input = torch.cat([latent_pers, latent_pers])
                noise_pred = net.unet(latent_model_input, t, encoder_hidden_states=text_embeds)["sample"]

                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                orig_pred = ddim_pred_orig(net.scheduler, noise_pred, t, latent_pers)
                orig_pred_list.append(orig_pred)

        # Synchronize views across directions with this step's matching
        orig_pred_list = _synchronize(orig_pred_list, omni_points, direction_indices)

        # Step every view to the previous timestep
        latent_pers_list = [
            ddim_pred_prev(net.scheduler, sample, t, orig_pred)
            for orig_pred, sample in zip(orig_pred_list, latent_pers_list)
        ]

    # Decode final images
    result = [
        T.ToPILImage()(net.decode_latents(latent_pers)[0])
        for latent_pers in latent_pers_list
    ]

    return result


# Precompute matching results
total_timesteps = 50
# True division: 360 // 50 would truncate 7.2 degrees to 7 and leave the
# sweep 10 degrees short of a full turn over 50 steps.
rotation_per_step = 360 / total_timesteps
precomputed_matches = precompute_matching(rotation_per_step, total_timesteps)

# Run the modified pipeline
res = foo_with_precomputed_matching(precomputed_matches)