In [None]:
# Registers memory_profiler's IPython magics (%memit / %mprun). They report the
# Python process's host memory, not PyTorch CUDA allocations.
%load_ext memory_profiler

In [None]:
import torch
import matplotlib.pyplot as plt

from ldm_uncond.latent_diffusion_uncond import LDMPipeline

# Target device and numeric precision used throughout this notebook
# (pipelines and input noise alike).
DEVICE = torch.device("cuda")
DTYPE = torch.half  # same dtype object as torch.float16

### Initialize the pipeline, move it to the device and warm it up

In [None]:
# Build the baseline pipeline and cast it to the configured device/dtype in one
# step, then switch it to eval mode and call its warmup() hook before timing.
diffusion_pipeline = LDMPipeline().to(device=DEVICE, dtype=DTYPE)
diffusion_pipeline.eval()
diffusion_pipeline.warmup()

### Sample an image and measure runtime and memory

In [None]:
%%time
## Generate sample

noise = torch.randn((1, 3, 64, 64), dtype=DTYPE, device=DEVICE)
# with torch.cuda.amp.autocast():
%timeit %memit sample = diffusion_pipeline(noise)

### Visualize sample

In [None]:
# Show the first image of the batch.
# NOTE(review): assumes `sample` is a batch of channel-last (H, W, C) images with
# values in [0, 255], as the original `/ 255` scaling implies — confirm against
# LDMPipeline's output format.
# detach(): .numpy() raises if `sample` is still attached to an autograd graph.
# Indexing [0] first copies only one image to the host instead of the whole batch.
baseline_image = sample[0].detach().float().cpu().numpy() / 255
fig, ax = plt.subplots()
ax.imshow(baseline_image)
ax.set_title("Baseline pipeline sample")
ax.axis("off");

### Load a pipeline with the optimized UNet

In [None]:
# File passed to load_optimized_unet; kept as a named constant rather than an
# inline literal so it is easy to find and change.
OPTIMIZED_UNET_FILE = "uldm_unet_fp16_sim.ts"

# Same construction as the baseline pipeline, plus loading the optimized UNet
# before switching to eval mode and warming up.
optimized_diffusion_pipeline = LDMPipeline().to(device=DEVICE, dtype=DTYPE)
optimized_diffusion_pipeline.load_optimized_unet(OPTIMIZED_UNET_FILE)
optimized_diffusion_pipeline.eval()
optimized_diffusion_pipeline.warmup()

### Sample from the optimized pipeline and measure runtime and memory

In [None]:
%%time
## Generate sample

noise = torch.randn((1, 3, 64, 64), dtype=DTYPE, device=DEVICE)
# with torch.cuda.amp.autocast():
%timeit %memit sample = optimized_diffusion_pipeline(noise)

### Visualize sample from optimized model

In [None]:
# Show the first image of the batch.
# NOTE(review): assumes `sample` is a batch of channel-last (H, W, C) images with
# values in [0, 255], as the original `/ 255` scaling implies — confirm against
# LDMPipeline's output format.
# detach(): .numpy() raises if `sample` is still attached to an autograd graph.
# Indexing [0] first copies only one image to the host instead of the whole batch.
optimized_image = sample[0].detach().float().cpu().numpy() / 255
fig, ax = plt.subplots()
ax.imshow(optimized_image)
ax.set_title("Optimized pipeline sample")
ax.axis("off");