In [None]:
# Step 0: load the checkpoints (autoencoder, diffusion, and classifier models)
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from lsdm.models.diffusion.ddim import DDIMSampler
from omegaconf import OmegaConf
from lsdm.util import instantiate_from_config
import albumentations
from torchvision import transforms
from lsdm.data.general import DatasetWithClassifierAPIValidation
from torch.utils.data import DataLoader

# ---- Parameters ------------------------------------------------------------
# TODO: create meta config file.
device = "cuda:0"  # default target device used by load_model_from_config below
key_feature = "others"  # passed to DDIMSampler as classifier_guidance; for new csv files, choose "others"
lsdm_config = "configs/latent-diffusion/lsdm_general.yaml"  # LSDM (autoencoder + diffusion) config
classifier_config = "configs/classifier/classifier_general.yaml"  # classifier config for guidance
classifier_ckpt = ""  # BreastMam classifier ckpt (TODO: fill in before running)

def load_model_from_config(config, ckpt, device=device, verbose=False):
    """Instantiate ``config.model`` and load the weights stored in ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint readable by ``torch.load``. Weights are taken
            from its ``"state_dict"`` entry; a bare state dict is also accepted.
        device: device the model is moved to (defaults to the notebook-level
            ``device`` parameter).
        verbose: if True, also print the full lists of missing / unexpected keys.

    Returns:
        The instantiated model, moved to ``device`` and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # Load onto CPU first so the checkpoint does not require the device it was saved from.
    pl_sd = torch.load(ckpt, map_location="cpu")
    # Weights are expected under "state_dict"; fall back to treating the file as a bare state dict.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    # strict=False tolerates partial checkpoints, so surface mismatches instead of hiding them.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing or unexpected:
        print(f"{len(missing)} missing and {len(unexpected)} unexpected keys when loading {ckpt}")
        if verbose:
            print("missing keys:", list(missing))
            print("unexpected keys:", list(unexpected))
    model.to(torch.device(device))
    model.eval()
    return model

def get_lsdm_model():
    """Build the LSDM model (autoencoder + diffusion) described by ``lsdm_config``.

    The checkpoint path is read from ``model.params.unet_config.params.ckpt_path``
    inside that config, and loading is delegated to ``load_model_from_config``.
    """
    lsdm_cfg = OmegaConf.load(lsdm_config)
    ckpt = lsdm_cfg.model.params.unet_config.params.ckpt_path
    return load_model_from_config(lsdm_cfg, ckpt)

# Build the LSDM model (autoencoder + diffusion) from `lsdm_config`.
lsdm_model = get_lsdm_model()
# Wrap it in a DDIM sampler set up for classifier guidance, using the classifier
# described by `classifier_config` and the checkpoint `classifier_ckpt`.
sampler = DDIMSampler(
    lsdm_model,
    classifier_guidance=key_feature,
    classifier_config=classifier_config,
    classifier_ckpt=classifier_ckpt,
)

In [None]:
# Step 0.1: randomly sample a batch from the dataset (only use the semantic maps for generation):
# Parameters:
bs = 4  # batch size; the sampling cells below also pass batch_size=bs
dataset_dir = ""  # TODO: set to the dataset root directory
sample_seed = None  # set to an int to draw the same batch on every run

dataset = DatasetWithClassifierAPIValidation(
    dataset_dir=dataset_dir,
    data_csv_name="data/BreastMam_Test_coordinate_switched.csv",
    feature_name="pathology",
    slice_name="new_name",
    make_binary=False,
    positive_threshold=1,
    label_scale=1,
    training_mode=True,
    crop_size=64,
    num_semantic_labels=3,  # [0,1,2]
    abnormal_area_threshold=2,  # used in masks
    mask_maxpooling_pixels=4,
    random_rotation=False,
    masked_image=False,
    masked_guidance=True,
    return_original_label=True,
)
# drop_last=True guarantees the drawn batch has exactly `bs` samples, which the
# sampling cells rely on (they pass batch_size=bs together with these conditionings).
loader_generator = torch.Generator().manual_seed(sample_seed) if sample_seed is not None else None
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True, drop_last=True, generator=loader_generator)
# Draw one random batch.
batch = next(iter(dataloader), None)
if batch is None:
    raise ValueError(f"Dataset has fewer than bs={bs} samples; cannot draw a full batch.")

# Collect the conditionings into a dict. Tensors fed to the model are moved to the
# same `device` the model was loaded on (Step 0), not a hard-coded "cuda".
c_dict = {
    "c_spade": [batch["label"].to(device)],
    # for SPADE+concat models, add "c_concat" conditionings:
    "c_concat": [lsdm_model.get_learned_conditioning(batch["concat"].to(device))],
    "mask": batch["mask"],
    "position": batch["position"],
    "crop_size_half": batch["crop_size_half"],
}
# show all the keys of the conditionings:
print(c_dict.keys())
# show all the file names picked as the current batch:
print(batch["filename"])
n_rows = (bs + 1) // 2  # round up so odd batch sizes still fit a 2-column grid
fig = plt.figure(figsize=(10, 10))
for i in range(bs):
    ax = fig.add_subplot(n_rows, 2, i + 1)
    ax.imshow(batch["original_image"][i, :, :], cmap="gray")

In [None]:
# Step 0.2: visualize the semantic maps of the sampled batch:
# Round the row count up so odd batch sizes still fit a 2-column grid
# (bs // 2 rows would make add_subplot fail for the last image).
n_rows = (bs + 1) // 2
fig = plt.figure(figsize=(10, 15))
for i in range(bs):
    ax = fig.add_subplot(n_rows, 2, i + 1)
    ax.imshow(batch["original_label"][i, :, :])

In [None]:
# Step 1: inference with classifier guidance towards class 0 (benign):
# Parameters:
sampling_steps = 20
class_label = 0  # target classifier label {0: benign, 1: malignant}
guidance_scale = 20.
latent_shape = [3, 128, 128]  # shape of the generated latent
# Arguments of `sampler.sample_classifier_guidance`:
# -- S: number of sampling steps using DDIM.
# -- conditioning: all the useful conditional information as a dict.
# -- class_label: target classifier label (0: benign; 1: malignant).
# -- batch_size: batch size of the generated samples.
# -- shape: shape of the generated latent.
# -- verbose: default: False.
# -- classifier_guidance_scale: key hyper-parameter of the classifier guidance; the scale of the guidance intensity.
samples_ddim, intermediates = sampler.sample_classifier_guidance(
    S=sampling_steps,
    conditioning=c_dict,
    class_label=class_label,
    batch_size=bs,
    shape=latent_shape,
    verbose=False,
    # alter this parameter to change the guidance scale:
    classifier_guidance_scale=guidance_scale,
)
n_rows = (bs + 1) // 2  # round up so odd batch sizes still fit a 2-column grid

# Latent preview: min-max normalize over the whole batch and convert to uint8.
# detach() so .numpy() works even if the sampler returns a tensor attached to autograd.
latent = samples_ddim.detach()
latent = (latent - latent.min()) / (latent.max() - latent.min()).clamp(min=1e-8)  # clamp avoids 0/0 on a constant batch
latent = latent.permute(0, 2, 3, 1).cpu().numpy()
latent = (latent * 255).astype(np.uint8)
fig1 = plt.figure(figsize=(10, 10))
for i in range(bs):
    ax1 = fig1.add_subplot(n_rows, 2, i + 1)
    ax1.imshow(latent[i, :, :, :])

# generated images: decoding is for visualization only, so no autograd graph is needed.
lsdm_model.eval()
with torch.no_grad():
    x_samples_ddim = lsdm_model.decode_first_stage(samples_ddim)
generated = (x_samples_ddim - x_samples_ddim.min()) / (x_samples_ddim.max() - x_samples_ddim.min()).clamp(min=1e-8)
generated = generated.permute(0, 2, 3, 1).cpu().numpy()
fig2 = plt.figure(figsize=(10, 10))
for i in range(bs):
    ax2 = fig2.add_subplot(n_rows, 2, i + 1)
    img = generated[i, :, :, :]
    # imshow rejects (H, W, 1) arrays; drop a singleton channel so grayscale output renders.
    if img.shape[-1] == 1:
        img = img[..., 0]
    ax2.imshow(img, cmap="gray")
    ax2.axis("off")
fig2.tight_layout()  # once, after all subplots exist

In [None]:
# Step 1b: inference with classifier guidance towards class 1 (malignant):
# Parameters:
sampling_steps = 20
class_label = 1  # target classifier label {0: benign, 1: malignant}
guidance_scale = 20.
latent_shape = [3, 128, 128]  # shape of the generated latent
# Arguments of `sampler.sample_classifier_guidance`:
# -- S: number of sampling steps using DDIM.
# -- conditioning: all the useful conditional information as a dict.
# -- class_label: target classifier label (0: benign; 1: malignant).
# -- batch_size: batch size of the generated samples.
# -- shape: shape of the generated latent.
# -- verbose: default: False.
# -- classifier_guidance_scale: key hyper-parameter of the classifier guidance; the scale of the guidance intensity.
samples_ddim, intermediates = sampler.sample_classifier_guidance(
    S=sampling_steps,
    conditioning=c_dict,
    class_label=class_label,  # 1: malignant
    batch_size=bs,
    shape=latent_shape,
    verbose=False,
    # alter this parameter to change the guidance scale:
    classifier_guidance_scale=guidance_scale,
)
n_rows = (bs + 1) // 2  # round up so odd batch sizes still fit a 2-column grid

# Latent preview: min-max normalize over the whole batch and convert to uint8.
# detach() so .numpy() works even if the sampler returns a tensor attached to autograd.
latent = samples_ddim.detach()
latent = (latent - latent.min()) / (latent.max() - latent.min()).clamp(min=1e-8)  # clamp avoids 0/0 on a constant batch
latent = latent.permute(0, 2, 3, 1).cpu().numpy()
latent = (latent * 255).astype(np.uint8)
fig1 = plt.figure(figsize=(10, 10))
for i in range(bs):
    ax1 = fig1.add_subplot(n_rows, 2, i + 1)
    ax1.imshow(latent[i, :, :, :])

# generated images: decoding is for visualization only, so no autograd graph is needed.
lsdm_model.eval()
with torch.no_grad():
    x_samples_ddim = lsdm_model.decode_first_stage(samples_ddim)
generated = (x_samples_ddim - x_samples_ddim.min()) / (x_samples_ddim.max() - x_samples_ddim.min()).clamp(min=1e-8)
generated = generated.permute(0, 2, 3, 1).cpu().numpy()
fig2 = plt.figure(figsize=(10, 10))
for i in range(bs):
    ax2 = fig2.add_subplot(n_rows, 2, i + 1)
    img = generated[i, :, :, :]
    # imshow rejects (H, W, 1) arrays; drop a singleton channel so grayscale output renders.
    if img.shape[-1] == 1:
        img = img[..., 0]
    ax2.imshow(img, cmap="gray")
    ax2.axis("off")
fig2.tight_layout()  # once, after all subplots exist