In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and source edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [24]:
import typing as T
from pathlib import Path

import numpy as np
import seaborn as sns
import torch
from omegaconf import OmegaConf

from plaid.constants import ICLR_MODEL_ID
from plaid.constants import COMPRESSION_INPUT_DIMENSIONS
from plaid.datasets import NUM_ORGANISM_CLASSES, NUM_FUNCTION_CLASSES
from plaid.denoisers import FunctionOrganismDiT, FunctionOrganismUDiT, DenoiserKwargs
from plaid.diffusion import FunctionOrganismDiffusion

In [27]:
# --- Checkpoint location ------------------------------------------------------
# Both the weights and the training config live under
# <model_ckpt_dir>/<model_id>/.
model_id = ICLR_MODEL_ID
model_ckpt_dir = Path("/data/lux70/plaid/checkpoints/plaid-compositional")
model_path = model_ckpt_dir / model_id / "last.ckpt"
config_path = model_ckpt_dir / model_id / "config.yaml"

# --- Sampling settings --------------------------------------------------------
# NOTE(review): the conditioning indices are set to the class *counts*, i.e. one
# past the last valid class index -- presumably the "unconditional" label; confirm.
organism_idx: int = NUM_ORGANISM_CLASSES
function_idx: int = NUM_FUNCTION_CLASSES
cond_scale: float = 7
num_samples: int = -1
beta_scheduler_name: T.Optional[str] = "sigmoid"
sampling_timesteps: int = 20
batch_size: int = -1
length: int = 32  # the final length, after decoding back to structure/sequence, is twice this value
return_all_timesteps: bool = False

# --- Training config ----------------------------------------------------------
cfg = OmegaConf.load(config_path)
compression_model_id = cfg['compression_model_id']
# shorten_factor = COMPRESSION_SHORTEN_FACTORS[compression_model_id]
input_dim = COMPRESSION_INPUT_DIMENSIONS[compression_model_id]

# --- Denoiser -----------------------------------------------------------------
# Instantiate the denoiser class named by the config's `_target_` entry.
# UDiT supports skip connections and memory-efficient attention, while DiT does not.
# `_target_` is popped so the remaining entries can be splatted as constructor kwargs.
denoiser_kwargs = cfg.denoiser
denoiser_class = denoiser_kwargs.pop("_target_")

if denoiser_class == "plaid.denoisers.FunctionOrganismUDiT":
    denoiser = FunctionOrganismUDiT(**denoiser_kwargs, input_dim=input_dim)
elif denoiser_class == "plaid.denoisers.FunctionOrganismDiT":
    denoiser = FunctionOrganismDiT(**denoiser_kwargs, input_dim=input_dim)
else:
    raise ValueError(f"Unknown denoiser class: {denoiser_class}")

# --- Weights ------------------------------------------------------------------
# last.ckpt automatically links to the EMA.
# map_location="cpu" lets a checkpoint saved on GPU load on a machine without one;
# load_state_dict below copies the values into the denoiser's own parameters.
ckpt = torch.load(model_path, map_location="cpu")

# Checkpoint keys are prefixed with the attribute path of the denoiser inside the
# training module. torch.compile inserts an extra `_orig_mod.` level, so prefer
# that prefix when present and otherwise fall back to the plain `model.` prefix.
# (Previously, a checkpoint trained without torch.compile yielded an empty dict.)
compiled_prefix = "model._orig_mod."
plain_prefix = "model."
ckpt_state_dict = ckpt['state_dict']

if any(key.startswith(compiled_prefix) for key in ckpt_state_dict):
    key_prefix = compiled_prefix
else:
    key_prefix = plain_prefix

mod_state_dict = {
    key[len(key_prefix):]: value
    for key, value in ckpt_state_dict.items()
    if key.startswith(key_prefix)
}

if not mod_state_dict:
    raise RuntimeError(
        f"No denoiser weights found in {model_path}: expected state-dict keys "
        f"starting with '{compiled_prefix}' or '{plain_prefix}'."
    )

# Load weights, then strip `_target_` from the diffusion config.
# NOTE(review): the diffusion object itself is not constructed in this cell.
denoiser.load_state_dict(mod_state_dict)
diffusion_kwargs = cfg.diffusion
diffusion_kwargs.pop("_target_")



In [30]:
from diffusers import LMSDiscreteScheduler

# Linear multistep discrete scheduler configured with a "scaled_linear" beta
# schedule over 1000 training timesteps.
scheduler = LMSDiscreteScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
)

In [31]:
# Default height and width of Stable Diffusion.
height = width = 512

# Number of denoising steps.
num_inference_steps = 100

# Scale for classifier-free guidance.
guidance_scale = 7.5

# Seed the generator used to create the initial latent noise.
generator = torch.manual_seed(0)

# NOTE: this rebinds `batch_size`, which the model-loading cell set to -1.
batch_size = 8

# Configure the scheduler's timestep grid for the chosen number of steps.
scheduler.set_timesteps(num_inference_steps)

In [19]:
from tqdm.auto import tqdm

# Re-apply the step count so this cell does not depend on the previous cell
# having configured the scheduler.
scheduler.set_timesteps(num_inference_steps)

# Skeleton of the denoising loop. The commented-out body below is a reference
# sketch of a classifier-free-guidance update; it refers to names (`latents`,
# `unet`, `text_embeddings`) that are not defined in this notebook and has not
# yet been adapted to the denoiser loaded above.
for t in tqdm(scheduler.timesteps):
    # # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
    # latent_model_input = torch.cat([latents] * 2)

    # latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

    # # predict the noise residual
    # with torch.no_grad():
    #     noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

    # # perform guidance
    # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # # compute the previous noisy sample x_t -> x_t-1
    # latents = scheduler.step(noise_pred, t, latents).prev_sample

    # TODO: implement the sampling step. A loop body made only of comments is a
    # syntax error (IndentationError), so `pass` keeps the cell runnable.
    pass


  0%|          | 0/100 [00:00<?, ?it/s]