In [None]:
import torch
import os
import time
import pandas as pd
from tqdm.auto import tqdm
import argparse

import sys
sys.path.append('../.')
from utils.load_util import load_sdxl_models, load_pipe



# Distilled model to pair with the SDXL base model: "dmd", "lcm", "turbo" or "lightning".
distillation_type = 'dmd'
device = 'cuda:0'
weights_dtype = torch.bfloat16

# Load one shared pipeline plus a (unet, scheduler) pair for both the base and the
# distilled model; the generation cell below swaps these onto `pipe` mid-sampling.
sdxl_components = load_sdxl_models(
    distillation_type=distillation_type,
    weights_dtype=weights_dtype,
    device=device,
)
pipe, base_unet, base_scheduler, distilled_unet, distilled_scheduler = sdxl_components

In [None]:
# Guidance scales passed to the base and distilled halves of the hybrid run.
base_guidance_scale = 0
distilled_guidance_scale = 0

# Step index at which the base model stops. Set to None to derive it automatically
# from `run_distilled_from_timestep` (the closest matching base timestep is chosen).
run_base_till_timestep = None
# Index into the distilled scheduler's timesteps from which the distilled model runs.
run_distilled_from_timestep = 1

# How many total timesteps to configure on each scheduler.
base_num_inference_steps = 4
distilled_num_inference_steps = 4

# For paper-consistent results the base model reuses the distilled scheduler.
# NOTE: this makes `base_scheduler` and `distilled_scheduler` the SAME object, so each
# `set_timesteps` call below replaces the previous schedule. Snapshot each schedule right
# after it is set; otherwise the handoff search would compare the distilled schedule
# against itself whenever the two step counts differ.
base_scheduler = distilled_scheduler

base_scheduler.set_timesteps(base_num_inference_steps)
base_schedule_timesteps = torch.as_tensor(base_scheduler.timesteps).clone()

distilled_scheduler.set_timesteps(distilled_num_inference_steps)
distilled_schedule_timesteps = torch.as_tensor(distilled_scheduler.timesteps).clone()

# Automatically pick the natural point to switch off the base model: the base step whose
# timestep value is closest to the distilled model's starting timestep.
if run_base_till_timestep is None:
    distilled_timestep = distilled_schedule_timesteps[run_distilled_from_timestep]
    distance_to_handoff = (base_schedule_timesteps - distilled_timestep).abs()
    # Plain Python int (not a 0-d tensor) so it prints and compares cleanly downstream.
    run_base_till_timestep = int(distance_to_handoff.argmin())

In [None]:
import matplotlib.pyplot as plt
import numpy as np

num_images = 1
prompt = 'image of a dog'

# Accumulators: total wall-clock seconds, every generated image, and the seed that
# produced it (so any image can be regenerated later).
total_time = 0
all_images = []
all_seeds = []

pipe.set_progress_bar_config(disable=True)

for i in tqdm(range(num_images)):
    # Fresh random seed per image. `dtype=np.int64` keeps the upper bound valid where
    # NumPy's default integer is 32-bit (e.g. Windows), which would otherwise raise.
    seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    # Seeds torch's global RNG. NOTE(review): `generator` is not passed to `pipe`;
    # reproducibility assumes the pipe draws noise from the global RNG — confirm.
    generator = torch.manual_seed(seed)

    # Stage 1: base model, from step 0 up to the handoff step, returning latents.
    pipe.unet = base_unet
    pipe.scheduler = base_scheduler

    start_time = time.perf_counter()
    base_latents = pipe(prompt=prompt, from_timestep=0, till_timestep=run_base_till_timestep,
                        guidance_scale=base_guidance_scale,
                        num_inference_steps=base_num_inference_steps,
                        output_type='latent')

    # Stage 2: distilled model finishes from the handoff step and decodes the image.
    pipe.unet = distilled_unet
    pipe.scheduler = distilled_scheduler

    pil_image = pipe(prompt=prompt, start_latents=base_latents,
                     guidance_scale=distilled_guidance_scale,
                     from_timestep=run_distilled_from_timestep, till_timestep=None,
                     num_inference_steps=distilled_num_inference_steps)[0]
    # CUDA kernels run asynchronously; wait for them so the timing covers the real work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    runtime = end_time - start_time
    total_time += runtime

    all_images.append(pil_image)
    all_seeds.append(seed)
    display(pil_image)

if num_images > 0:
    print(f'seeds: {all_seeds}')
    print(f'average runtime: {total_time / num_images:.3f} s/image over {num_images} image(s)')

# Individual Pipe Inference

In [None]:
import torch
import os
import time
import pandas as pd
from tqdm.auto import tqdm
import argparse

import sys
# Make `utils` importable whether the kernel starts in the repo root or one level below
# it (the first section uses '../.'). Skip paths already present so re-running this cell
# does not keep growing sys.path.
for extra_path in ('.', '../.'):
    if extra_path not in sys.path:
        sys.path.append(extra_path)
from utils.load_util import load_sdxl_models, load_pipe

# None loads the base model. Otherwise pass a distillation type, presumably the same
# options accepted by `load_sdxl_models` ("dmd", "lcm", "turbo", "lightning") — confirm.
distillation_type = None
device = 'cuda:0'
weights_dtype = torch.bfloat16

pipe = load_pipe(
    distillation_type=distillation_type,
    weights_dtype=weights_dtype,
    device=device,
)

In [None]:
# Guidance scale and number of inference steps passed to the standalone pipe below.
guidance_scale, num_inference_steps = 8, 50

In [None]:
import matplotlib.pyplot as plt
import numpy as np

num_images = 1  # number of images to generate, each with its own random seed
prompt = 'image of a wizard'

# Accumulators: total wall-clock seconds, every generated image, and the seed that
# produced it (so any image can be regenerated later).
total_time = 0
all_images = []
all_seeds = []

pipe.set_progress_bar_config(disable=True)

for i in tqdm(range(num_images)):
    # Fresh random seed per image. `dtype=np.int64` keeps the upper bound valid where
    # NumPy's default integer is 32-bit (e.g. Windows), which would otherwise raise.
    seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    # Seeds torch's global RNG. NOTE(review): `generator` is not passed to `pipe`;
    # reproducibility assumes the pipe draws noise from the global RNG — confirm.
    generator = torch.manual_seed(seed)

    # Single end-to-end run of whichever model `load_pipe` returned (base or distilled).
    start_time = time.perf_counter()
    pil_image = pipe(prompt=prompt, guidance_scale=guidance_scale,
                     num_inference_steps=num_inference_steps)[0]
    # CUDA kernels run asynchronously; wait for them so the timing covers the real work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    runtime = end_time - start_time
    total_time += runtime

    all_images.append(pil_image)
    all_seeds.append(seed)
    display(pil_image)

if num_images > 0:
    print(f'seeds: {all_seeds}')
    print(f'average runtime: {total_time / num_images:.3f} s/image over {num_images} image(s)')