In [10]:
import argparse, os, sys, glob, datetime, yaml
import numpy as np
import torch
from tqdm import trange
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from ldm.util import instantiate_from_config
import time

In [11]:
def load_model_from_config(config, sd):
    """Instantiate the model described by ``config`` and load weights from ``sd``.

    Args:
        config: Model configuration handed to ``instantiate_from_config``.
        sd: State dict to load, or ``None`` to keep the freshly initialised
            weights (``load_model`` passes ``None`` when no checkpoint is given).

    Returns:
        The model after ``.cuda()`` and ``.eval()`` have been called on it.
    """
    model = instantiate_from_config(config)
    if sd is not None:
        # strict=False: keys that are missing from / unexpected in ``sd`` are
        # tolerated instead of raising.
        model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model


def load_model(config, ckpt, gpu, eval_mode):
    """Build the model from ``config.model``, optionally restoring a checkpoint.

    Args:
        config: Configuration object exposing a ``model`` attribute.
        ckpt: Path to a checkpoint file; a falsy value skips loading.
        gpu: Accepted for interface compatibility; not used in this function.
        eval_mode: Accepted for interface compatibility; not used in this function.

    Returns:
        ``(model, global_step)`` where ``global_step`` is read from the
        checkpoint, or ``None`` when no checkpoint was given.
    """
    if not ckpt:
        checkpoint, global_step = {"state_dict": None}, None
    else:
        print(f"Loading model from {ckpt}")
        # Checkpoint is deserialised onto the CPU first.
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]

    model = load_model_from_config(config.model, checkpoint["state_dict"])
    return model, global_step

In [12]:
def custom_to_np(x):
    """Return a detached, contiguous CPU copy of ``x``.

    Despite the name, the result is still a ``torch.Tensor``; no conversion to
    a NumPy array and no value rescaling is performed here.
    """
    return x.detach().cpu().contiguous()

def custom_to_pil(x):
    """Convert one tensor with values in [-1, 1] into an RGB PIL image.

    Args:
        x: Three-dimensional tensor whose first axis is moved last by the
            ``permute(1, 2, 0)`` below (presumably channel-first image data --
            confirm against the caller). Values outside [-1, 1] are clamped.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    # Map [-1, 1] -> [0, 1] before scaling to 8-bit.
    x = (x + 1.) / 2.
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    if x.ndim == 3 and x.shape[-1] == 1:
        # Image.fromarray cannot interpret an (H, W, 1) array; drop the
        # singleton channel so it is read as a greyscale image instead.
        x = x[..., 0]
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x

In [13]:
def save_logs(logs, gt, path, n_saved=0, key="sample", np_path=None):
    """Persist the batch stored under ``logs[key]`` and return the new counter.

    Args:
        logs: Mapping of log entries; nothing is written if ``key`` is absent.
        gt: Ground-truth batch stored next to the samples in the ``.npz`` file
            (only used when ``np_path`` is given).
        path: Directory for per-item PNG files (used when ``np_path`` is None).
        n_saved: Number of items saved so far; used in the file names.
        key: Entry of ``logs`` to save.
        np_path: If given, the whole batch is written as one ``.npz`` archive
            into this directory instead of individual PNGs.

    Returns:
        ``n_saved`` increased by the number of items written.
    """
    if key not in logs:
        return n_saved

    batch = logs[key]

    if np_path is not None:
        # One archive per batch: positional arrays are (samples, ground truth).
        samples = custom_to_np(batch)
        dims = "x".join(map(str, samples.shape))
        np.savez(os.path.join(np_path, f"{n_saved}-{dims}-samples.npz"), samples, gt)
        return n_saved + samples.shape[0]

    for item in batch:
        target = os.path.join(path, f"{key}_{n_saved:06}.png")
        custom_to_pil(item).save(target)
        n_saved += 1
    return n_saved

In [14]:
def sample(model,x):
    """Run ``model`` on ``x`` without gradient tracking and time the call.

    Args:
        model: Callable returning a 2-tuple whose first element is the output
            batch. If it exposes ``parameters()``, the input is moved to the
            device of its first parameter; otherwise ``"cuda"`` is used.
        x: Input batch tensor.

    Returns:
        dict with
            ``"sample"``: first element of the model output,
            ``"time"``: wall-clock seconds for input preparation plus the
            forward pass.
    """
    log = dict()
    t0 = time.time()
    x = x.to(memory_format=torch.contiguous_format).float()

    # Follow the model's device instead of assuming CUDA; fall back to the
    # previous hard-coded behaviour when the device cannot be determined.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    # Inference only: avoid building an autograd graph for every batch.
    with torch.no_grad():
        x_sample, _ = model(x.to(device))

    if device.type == "cuda":
        # CUDA kernels run asynchronously; wait so the timing covers the work.
        torch.cuda.synchronize(device)
    t1 = time.time()

    log["sample"] = x_sample
    log["time"] = t1 - t0
    return log


In [15]:
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
    """Run ``model`` over the PPG2ABP test set and save the outputs.

    Every batch of ``gt_image`` is passed through ``sample`` and written via
    ``save_logs``. When ``nplog`` is given, one combined ``.npz`` archive with
    all outputs and all ground-truth inputs is written there as well.

    Args:
        model: Model passed to ``sample``.
        logdir: Directory for per-item PNG output (used when ``nplog`` is None).
        batch_size: DataLoader batch size.
        vanilla: Unused; kept for interface compatibility.
        custom_steps: Unused; kept for interface compatibility.
        eta: Unused; kept for interface compatibility.
        n_samples: NOTE: the passed value is ignored; it is overwritten with
            ``len(dataset)`` below.
        nplog: Directory for ``.npz`` output, or None to skip NumPy archives.
    """
    # Local imports: the dataset module is project-specific and only used here.
    import ldm.data.ppg2abp as ppg2abp
    from tqdm import tqdm

    dataset = ppg2abp.PPG2ABPDataset_v3_Test()
    n_samples = len(dataset)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    n_saved = 0
    all_images = []
    all_gt_images = []
    print(f"Running conditional sampling for {n_samples} samples")
    # Iterate the loader itself. Calling next(iter(loader)) inside the loop
    # would create a fresh iterator each time and, with shuffle=False, return
    # the same first batch on every iteration. Iterating directly also covers
    # a trailing partial batch.
    for data in tqdm(test_loader, desc="Sampling Batches (conditional)"):
        gt_batch = data['gt_image']
        logs = sample(model, gt_batch)
        all_images.append(custom_to_np(logs["sample"]))
        all_gt_images.append(gt_batch)
        n_saved = save_logs(logs, gt_batch, logdir, n_saved=n_saved, key="sample", np_path=nplog)
        if n_saved >= n_samples:
            print(f'Finish after generating {n_saved} samples')
            break

    if nplog is None or not all_images:
        # No NumPy output directory (PNG-only run) or nothing was processed:
        # there is no combined archive to write.
        return

    all_img = np.concatenate(all_images, axis=0)[:n_samples]
    all_gt_img = np.concatenate(all_gt_images, axis=0)[:n_samples]
    shape_str = "x".join(str(dim) for dim in all_img.shape)
    nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
    # Positional arrays: arr_0 = model outputs, arr_1 = ground-truth inputs.
    np.savez(nppath, all_img, all_gt_img)

In [16]:
def get_parser():
    """Build the command-line parser for the sampling script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--resume``,
        ``--n_samples``, ``--eta``, ``--vanilla_sample``, ``--logdir``,
        ``--custom_steps`` and ``--batch_size``.
    """
    # (flags, add_argument keyword arguments), in registration order.
    option_table = (
        (("-r", "--resume"), dict(
            type=str,
            nargs="?",
            help="load from logdir or checkpoint in logdir",
            default=r"./models/first_stage_models/v3_ppg2abp-kl-f4_2",
        )),
        (("-n", "--n_samples"), dict(
            type=int,
            nargs="?",
            help="number of samples to draw",
            default=50000,
        )),
        (("-e", "--eta"), dict(
            type=float,
            nargs="?",
            help="eta for ddim sampling (0.0 yields deterministic sampling)",
            default=1.0,
        )),
        (("-v", "--vanilla_sample"), dict(
            default=False,
            action='store_true',
            help="vanilla sampling (default option is DDIM sampling)?",
        )),
        (("-l", "--logdir"), dict(
            type=str,
            nargs="?",
            help="extra logdir",
            default="none",
        )),
        (("-c", "--custom_steps"), dict(
            type=int,
            nargs="?",
            help="number of steps for ddim and fastdpm sampling",
            default=50,
        )),
        (("--batch_size",), dict(
            type=int,
            nargs="?",
            help="the bs",
            default=10,
        )),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser

In [17]:
# Timestamp used to give every run its own output directory.
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
sys.path.append(os.getcwd())
command = " ".join(sys.argv)

parser = get_parser()
# parse_known_args: arguments the parser does not recognise end up in
# `unknown` instead of aborting (inside a notebook, sys.argv holds the
# kernel's own launch arguments).
opt, unknown = parser.parse_known_args()
ckpt = None

if not os.path.exists(opt.resume):
    raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
    # --resume points at a checkpoint file: its directory is the logdir.
    # os.path.dirname handles the platform's separators, unlike splitting on
    # a hard-coded backslash.
    logdir = os.path.dirname(opt.resume)
    print(f'Logdir is {logdir}')
    ckpt = opt.resume
else:
    assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
    # Drop trailing separators of either flavour; keep the original string if
    # stripping would leave nothing (e.g. a bare root path).
    logdir = opt.resume.rstrip("\\/") or opt.resume
    ckpt = os.path.join(logdir, "abp.ckpt")
base_configs = sorted(glob.glob(os.path.join(logdir, "abp.yaml")))
print("config",base_configs)
if not base_configs:
    # Without a base config the merged config would lack the `model` section
    # and fail later with a much less helpful error.
    raise ValueError("Cannot find {}".format(os.path.join(logdir, "abp.yaml")))
opt.base = base_configs

configs = [OmegaConf.load(cfg) for cfg in opt.base]
# Only genuine `key=value` overrides are merged into the config; option-like
# leftovers (such as the Jupyter kernel's connection-file argument) are ignored.
cli_overrides = [arg for arg in unknown if "=" in arg and not arg.startswith("-")]
cli = OmegaConf.from_dotlist(cli_overrides)
config = OmegaConf.merge(*configs, cli)

gpu = True
eval_mode = True
print(opt.logdir,logdir.split(os.sep),os.sep)
if opt.logdir != "none":
    # Re-root the output under the extra logdir, keeping only the last path
    # component of the model directory.
    locallog = os.path.basename(os.path.normpath(logdir))
    print(locallog)
    print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
    logdir = os.path.join(opt.logdir, locallog)

print(config)

model, global_step = load_model(config, ckpt, gpu, eval_mode)
print(f"global step: {global_step}")
print(75 * "=")
print("logging to:")
# Output layout: <logdir>/samples/<global step, zero padded>/<timestamp>/{img,numpy}
logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
imglogdir = os.path.join(logdir, "img")
numpylogdir = os.path.join(logdir, "numpy")

os.makedirs(imglogdir)
os.makedirs(numpylogdir)
print(logdir)
print(75 * "=")

# write config out
sampling_file = os.path.join(logdir, "sampling_config.yaml")
sampling_conf = vars(opt)

with open(sampling_file, 'w') as f:
    yaml.dump(sampling_conf, f, default_flow_style=False)
print(sampling_conf)




config ['./models/first_stage_models/v3_ppg2abp-kl-f4_2\\abp.yaml']
none ['./models/first_stage_models/v3_ppg2abp-kl-f4_2'] \
{'model': {'base_learning_rate': 4.5e-06, 'target': 'ldm.models.autoencoder1D_v1.AutoencoderKL', 'params': {'monitor': 'val/rec_loss', 'embed_dim': 3, 'image_key': 'gt_image', 'ddconfig': {'double_z': True, 'z_channels': 3, 'resolution': 256, 'in_channels': 1, 'out_ch': 1, 'ch': 128, 'ch_mult': [1, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [16], 'dropout': 0.0}, 'lossconfig': {'target': 'ldm.modules.losses.contperceptual.LPIPSWithDiscriminator_2', 'params': {'disc_start': 20001, 'kl_weight': 0.001, 'disc_weight': 0.5}}}}, 'data': {'target': 'main.DataModuleFromConfig', 'params': {'batch_size': 16, 'num_workers': 8, 'train': {'target': 'ldm.data.ppg2abp.PPG2ABPDataset_v3_Train', 'params': {'data_len': -1, 'size': 256}}, 'validation': {'target': 'ldm.data.ppg2abp.PPG2ABPDataset_v3_Val', 'params': {'size': 256}}}}, 'lightning': {'callbacks': {'image_logger': 

  pl_sd = torch.load(ckpt, map_location="cpu")


making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64) = 192 dimensions.
making attention of type 'vanilla' with 512 in_channels
global step: 187108
logging to:
./models/first_stage_models/v3_ppg2abp-kl-f4_2\samples\00187108\2024-12-06-01-07-47
{'resume': './models/first_stage_models/v3_ppg2abp-kl-f4_2', 'n_samples': 50000, 'eta': 1.0, 'vanilla_sample': False, 'logdir': 'none', 'custom_steps': 50, 'batch_size': 10, 'base': ['./models/first_stage_models/v3_ppg2abp-kl-f4_2\\abp.yaml']}


In [9]:
# Process the test set with the loaded model; per-item images go to
# `imglogdir`, NumPy archives to `numpylogdir`.
run(
    model,
    imglogdir,
    batch_size=opt.batch_size,
    vanilla=opt.vanilla_sample,
    custom_steps=opt.custom_steps,
    eta=opt.eta,
    n_samples=opt.n_samples,
    nplog=numpylogdir,
)

print("done.")

True


Sampling Batches (conditional):   0%|          | 0/500 [00:00<?, ?it/s]

data prepared: (5000, 256, 2)
Running conditional sampling for 5000 samples


Sampling Batches (conditional): 100%|█████████▉| 499/500 [00:05<00:00, 91.70it/s] 

Finish after generating 5000 samples
done.





https://github.com/CompVis/latent-diffusion/issues/187

https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515

In [23]:
# Estimate a latent normalizer for the autoencoder: encode the whole
# validation set, take the standard deviation over every latent element, and
# print its reciprocal.
# The commented-out ImageNette / `from_pretrained` lines are the image-based
# setup this cell was adapted from; they are kept for reference only.
import torch
import torchvision
from torchvision.datasets.utils import download_and_extract_archive
from torchvision import transforms


num_workers = 4  # not used by the DataLoader below
batch_size = 12
# From https://github.com/fastai/imagenette
# IMAGENETTE_URL = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'

# Fixed seed for torch's random number generator.
torch.manual_seed(0)
# NOTE: this switches autograd off globally for the rest of the kernel
# session, not just for this cell.
torch.set_grad_enabled(False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# pretrained_model_name_or_path = 'CompVis/stable-diffusion-v1-4'
# vae = AutoencoderKL.from_pretrained(
#     pretrained_model_name_or_path,
#     subfolder='vae',
#     revision=None,
# )
# Reuse the model loaded earlier in the notebook (requires that cell to have run).
vae = model
vae.to(device)

size = 256  # not used in this cell
# image_transform = transforms.Compose([
#     transforms.Resize(size),
#     transforms.CenterCrop(size),
#     transforms.ToTensor(),
#     transforms.Normalize([0.5], [0.5]),
# ])

# root = 'dataset'
# download_and_extract_archive(IMAGENETTE_URL, root)
import ldm.data.ppg2abp as ppg2abp
dataset = ppg2abp.PPG2ABPDataset_v3_Val(data_len=-1)
n_samples = len(dataset)
# Despite its name, `train_loader` iterates the validation dataset, in order.
train_loader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)
# dataset = torchvision.datasets.ImageFolder(root, transform=image_transform)
# loader = torch.utils.data.DataLoader(
#     dataset,
#     batch_size=batch_size,
#     shuffle=True,
#     num_workers=num_workers,
# )

all_latents = []
for image_data in train_loader:
    image_data = image_data["gt_image"].to(device)
    # Presumably `encode` returns a distribution object and `.sample()` draws
    # one latent per input -- confirm against the model class.
    latents = vae.encode(image_data).sample()
    # Collect on the CPU so the accumulated latents do not stay on the device.
    all_latents.append(latents.cpu())

all_latents_tensor = torch.cat(all_latents)
# A single scalar std over all elements of all latents.
std = all_latents_tensor.std().item()
normalizer = 1 / std
print(f'{normalizer = }')


True
data prepared: (62528, 256, 2)
normalizer = 0.42144223881543325
