# Conditional DDIM interpolation

# Imports

In [10]:
from pathlib import Path
import numpy as np
import torch

In [1]:
from src.conditional_ddim import ConditionialDDIMPipeline

# Device

In [11]:
# Device on which the initial noise tensors are created for sampling.
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing at sampling time.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load pretrained pipeline

In [4]:
# Directory holding the saved pipeline, relative to the current working
# directory of the notebook kernel.
path = Path("experiments", "REMOVEME_functional_test", "full_pipeline_save")
# Raise explicitly rather than `assert`: assertions are stripped under
# `python -O` and give no hint about which location was looked up.
if not path.exists():
    raise FileNotFoundError(
        f"Pretrained pipeline directory not found: {path.resolve()}"
    )

In [12]:
# Restore the saved pipeline from `path`, then pull out the two components
# that the sampling loop further down drives by hand.
# NOTE(review): nothing in this cell moves the model to `device` — confirm the
# model ends up on the same device as the noise sampled in the loop.
pipeline = ConditionialDDIMPipeline.from_pretrained(path)
unet, scheduler = pipeline.unet, pipeline.scheduler

# Make a `UNet2DConditionModel` from a `UNet2DModel`

In [None]:
# Placeholder (Ellipsis) — not implemented yet.
# The sampling loop calls `unet_explicit_cond(image, t, class_embedding)` and
# reads `.sample` from the result, so this must be replaced by a model with
# that call signature before the loop can run.
# TODO: build it from `unet` as the section title describes.
unet_explicit_cond = ...

# Interpolate

## Get base class embeddings

In [None]:
# Placeholders (Ellipsis) — not implemented yet.
# The sampling loop blends them as
# `x * DMSO_embedding + (1 - x) * cyto_B30_embedding`, so both must support
# scalar multiplication and addition with each other.
# NOTE(review): presumably tensors of identical shape on `device` — confirm.
DMSO_embedding = ...
cyto_B30_embedding = ...

## Generate interpolated images

In [13]:
# Number of denoising steps; passed to `scheduler.set_timesteps` in the
# sampling loop.
num_inference_steps = 100

In [None]:
# Generate one image per interpolation weight between the two class
# embeddings, running a complete denoising loop for each weight.
# Results are collected in `image_list`, one entry per weight, in order.
num_interpolation_steps = 100
image_list = []

# The noise shape is identical for every sample, so build it once.
image_shape = (
    1,
    unet.config.in_channels,
    unet.config.sample_size,
    unet.config.sample_size,
)

# Sampling needs no gradients. Disabling autograd avoids building a graph
# across every denoising step, and lets the final `.numpy()` conversion work
# (it raises on tensors that still require grad).
with torch.no_grad():
    for x in np.linspace(0, 1, num_interpolation_steps):
        # Plain Python float so the blend below does not depend on how
        # NumPy scalars and the embedding type dispatch arithmetic.
        weight = float(x)

        # sample gaussian noise to begin generation loop
        # NOTE(review): fresh noise is drawn for every weight, so consecutive
        # images differ by noise as well as by embedding — confirm intended.
        image = torch.randn(image_shape, device=device, dtype=unet.dtype)

        # set step values
        scheduler.set_timesteps(num_inference_steps)

        # get interpolated class embedding
        # (weight 0 -> cyto_B30_embedding, weight 1 -> DMSO_embedding)
        class_embedding = (
            weight * DMSO_embedding + (1 - weight) * cyto_B30_embedding
        )

        # Iterate over the timesteps themselves: `scheduler.timesteps` is a
        # sequence of timestep values, not a count to hand to `range`.
        for t in scheduler.timesteps:
            # 1. predict noise model_output
            model_output = unet_explicit_cond(image, t, class_embedding).sample

            # 2. do x_t -> x_t-1
            # No `eta` is passed, so the scheduler's own default applies.
            image = scheduler.step(
                model_output,
                t,
                image,
            ).prev_sample

        # Map from [-1, 1] to [0, 1], then to channels-last NumPy for PIL.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = pipeline.numpy_to_pil(image)

        # NOTE(review): whatever `numpy_to_pil` returns is appended as-is; if
        # it returns a list per batch, `image_list` is a list of lists.
        image_list.append(image)