In [None]:
# Step 0: load the checkpoints (autoencoder, diffusion, and classifier models)
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
from lsdm.models.diffusion.ddim import DDIMSampler
from omegaconf import OmegaConf
from lsdm.util import instantiate_from_config
import albumentations
from torchvision import transforms
from lsdm.data.general import DatasetWithClassifierAPIValidation
from torch.utils.data import DataLoader

# TODO: create meta config file.
# Config files for the LSDM (autoencoder + diffusion) model and the classifier.
lsdm_config = "configs/latent-diffusion/lsdm_general.yaml"
classifier_config = "configs/classifier/classifier_general.yaml"
# Path to the BreastMam classifier checkpoint (empty until filled in).
classifier_ckpt = ""
# Default device that loaded models are moved to.
device = "cuda:0"
# Value handed to the sampler as its classifier guidance;
# for new csv files, choose "others".
key_feature = "others"

def load_model_from_config(config, ckpt, device=device):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Configuration object whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file that holds its weights under the
            ``"state_dict"`` key.
        device: Device the model is moved to after the weights are loaded;
            defaults to the module-level ``device`` setting.

    Returns:
        The instantiated model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the loaded checkpoint mapping has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # Tensors are mapped to CPU on load; the model is moved to the target
    # device afterwards.
    checkpoint = torch.load(ckpt, map_location=torch.device("cpu"))
    state_dict = checkpoint["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches, so report them instead of
    # silently discarding the information.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Missing keys ({len(missing)}): {missing}")
    if unexpected:
        print(f"Unexpected keys ({len(unexpected)}): {unexpected}")
    model.to(torch.device(device))  # move to the requested device
    model.eval()
    return model

def get_lsdm_model():
    """Load the LSDM model using the config file named by ``lsdm_config``.

    The checkpoint path is read from the config itself, at
    ``model.params.unet_config.params.ckpt_path``.

    Returns:
        The model returned by ``load_model_from_config``.
    """
    cfg = OmegaConf.load(lsdm_config)
    ckpt_path = cfg.model.params.unet_config.params.ckpt_path
    return load_model_from_config(cfg, ckpt_path)

# Build the LSDM model (autoencoder + diffusion) from its config and checkpoint.
lsdm_model = get_lsdm_model()
# Wrap the model in a DDIM sampler configured with the classifier settings.
sampler = DDIMSampler(
    lsdm_model,
    classifier_guidance=key_feature,
    classifier_config=classifier_config,
    classifier_ckpt=classifier_ckpt,
)