In [22]:
import torch
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from transformers import AutoFeatureExtractor
from morphodiff.train import CustomStableDiffusionPipeline
from morphodiff.perturbation_encoder import PerturbationEncoderInference

# Everything in this notebook runs on the CPU. To try a GPU instead, build the
# device from "cuda" after checking torch.cuda.is_available().
device = torch.device("cpu")

# Root of the unzipped checkpoint folder
# (bbbc021_14_compounds_morphodiff_ckpt/checkpoint). The components loaded
# below are read from subfolders of this directory.
ckpt_path = "/proj/aicell/users/x_aleho/MorphoDiff/models/bbbc021_14_compounds_morphodiff_ckpt/checkpoint"

# UNet weights come from the `unet_ema` subfolder; the VAE and the noise
# scheduler come from their own subfolders of the same checkpoint.
unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet_ema")
vae = AutoencoderKL.from_pretrained(ckpt_path, subfolder="vae")
scheduler = DDPMScheduler.from_pretrained(ckpt_path, subfolder="scheduler")

# The feature extractor is addressed by its full directory path rather than
# through the `subfolder` argument.
feature_extractor = AutoFeatureExtractor.from_pretrained(f"{ckpt_path}/feature_extractor")

# Perturbation encoder for BBBC021 in 'conditional' mode; it is handed to the
# pipeline in the slot normally occupied by the CLIP text encoder.
# Adjust `dataset_id` if a different dataset/experiment is used.
perturbation_encoder = PerturbationEncoderInference(
    dataset_id="BBBC021_experiment_01_resized",
    model_type="conditional",
    model_name="SD",
)

# Assemble the pipeline from the pieces loaded above. The safety checker is
# explicitly disabled (None).
pipeline_components = {
    "vae": vae,
    "unet": unet,
    "text_encoder": perturbation_encoder,
    "feature_extractor": feature_extractor,
    "scheduler": scheduler,
    "safety_checker": None,
}
pipeline = CustomStableDiffusionPipeline(**pipeline_components)

# Place the assembled pipeline on the selected device.
pipeline = pipeline.to(device)

The config attributes {'decay': 0.9999, 'inv_gamma': 1.0, 'min_decay': 0.0, 'optimization_step': 500, 'power': 0.6666666666666666, 'update_after_step': 0, 'use_ema_warmup': False} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.




In [None]:
import torch
import matplotlib.pyplot as plt

# Make sure the pipeline is on `device`. This repeats the move done in the
# setup cell, so it only matters if `device` was changed and this cell is
# re-run on its own.
pipeline = pipeline.to(device)

# Compound name passed to the pipeline as the prompt.
prompt = "aphidicolin"

# Seeded generator, created on the same device as the pipeline, so repeated
# runs of this cell start from the same random state.
generator = torch.Generator(device=device).manual_seed(42)

# Run inference.
# Autocast is only switched on for CUDA. On CPU, torch.autocast defaults to
# bfloat16, which lowers numerical precision and can be slow or unsupported
# for some ops; the diffusers docs also discourage wrapping pipelines in
# autocast. With enabled=False the context manager is a no-op and the
# pipeline runs in the dtype its weights were loaded in.
use_autocast = device.type == "cuda"
with torch.autocast(device.type, enabled=use_autocast):
    output = pipeline(prompt, generator=generator, guidance_scale=1.0)

# The pipeline returns a batch of images; keep the first one.
image = output.images[0]

# Display inline using the explicit figure/axes interface.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(image)
ax.axis("off")
ax.set_title(f"Generated for '{prompt}'")
plt.show()

  0%|          | 0/50 [00:00<?, ?it/s]