In [None]:
!pip install transformers diffusers ftfy accelerate

In [None]:
import os
import sys

# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK while `torch` is being imported, so the
# variable must already be in the environment *before* the `import torch` below;
# assigning it afterwards does not enable the CPU fallback for unsupported MPS ops.
# MPS only exists on macOS, so only set it there, and keep any value the user exported.
# If torch was already imported in this kernel, restart the kernel for this to apply.
if sys.platform == "darwin":
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

In [None]:
# Pick the best available backend: NVIDIA CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
elif torch.backends.mps.is_available():
    torch_device = "mps"
else:
    torch_device = "cpu"

if torch_device == "mps" and os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
    # PyTorch reads this variable when `torch` is imported, so setting it here (after
    # `import torch`) does not enable the CPU fallback in this kernel. Still set it so
    # child processes inherit it, but warn instead of failing silently.
    import warnings

    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    warnings.warn(
        "PYTORCH_ENABLE_MPS_FALLBACK was set after `import torch`, so unsupported MPS ops "
        "will still raise. Set it before importing torch and restart the kernel.",
        stacklevel=1,
    )

In [None]:
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_8step_unet.safetensors"  # Use the correct ckpt for your step setting!

# Half precision only pays off on an accelerator; many fp16 kernels are slow or
# unimplemented on CPU, so fall back to float32 there.
model_dtype = torch.float32 if torch_device == "cpu" else torch.float16

# Load model: build a UNet with the SDXL base architecture, then overwrite its weights
# with the SDXL-Lightning checkpoint. `from_config` expects a config dict (passing a
# repo id directly is deprecated), so fetch the config with `load_config` first.
unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config).to(torch_device, model_dtype)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=torch_device))
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, unet=unet, torch_dtype=model_dtype, variant="fp16"
).to(torch_device)

# Ensure sampler uses "trailing" timesteps, as the Lightning checkpoints expect.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

In [None]:
# Local textual-inversion embedding files, mapped to the prompt token each one binds.
TEXTUAL_INVERSIONS = {
    "concept-art.bin": "<concept-art>",
    "doose-s-realistic-art-style.bin": "<doose-realistic>",
    "line-art.bin": "<line_art>",
    "rickyart.bin": "<RickyArt>",
    "tony-diterlizzi-s-planescape-art.bin": "<tony-diterlizzi-planescape>",
}

# load_textual_inversion raises if a token is already in the tokenizer vocabulary, so
# skip tokens added by an earlier run of this cell to keep the cell re-runnable.
# NOTE(review): with no tokenizer/text_encoder arguments these load into
# pipe.tokenizer / pipe.text_encoder only; SDXL's second encoder (tokenizer_2 /
# text_encoder_2) will not know the tokens. If these are SD 1.x embeddings (likely,
# judging by the names), only the first encoder is conditioned — confirm.
already_loaded = set(pipe.tokenizer.get_vocab())
for embedding_path, token in TEXTUAL_INVERSIONS.items():
    if token not in already_loaded:
        pipe.load_textual_inversion(embedding_path, token)

In [None]:
# Ensure using the same inference steps as the loaded model (8 for the 8-step ckpt)
# and CFG set to 0. A seeded generator fixes the initial noise so re-runs reproduce
# the same image.
pipe(
    "A mouse in the style of <concept-art>",
    num_inference_steps=8,
    guidance_scale=0,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]

In [None]:
def guide_loss(diffusion_pipeline, step, timestep, kwarg):
    """Per-step ``callback_on_step_end`` hook for the diffusers pipeline (debug stub).

    The pipeline calls this after each denoising step and reads the tensors it tracks
    (e.g. ``latents``) back out of the returned dict, so the callback must hand
    ``kwarg`` back; returning ``None`` makes the pipeline call fail.

    Args:
        diffusion_pipeline: The running pipeline instance; only its ``dtype`` is read.
        step: Index of the denoising step that just finished.
        timestep: Scheduler timestep for that step.
        kwarg: Dict of callback tensors supplied by the pipeline.

    Returns:
        ``kwarg``, unmodified.
    """
    print(diffusion_pipeline.dtype)
    # NOTE(review): the name suggests a guidance loss is meant to be computed here;
    # any modified tensors should be written back into ``kwarg`` before returning.
    return kwarg

In [None]:
# Ensure using the same inference steps as the loaded model (8 for the 8-step ckpt)
# and CFG set to 0. `guide_loss` runs after every denoising step; a seeded generator
# fixes the initial noise so re-runs reproduce the same image.
pipe(
    "Sunset on mountains in the style of <tony-diterlizzi-planescape>",
    num_inference_steps=8,
    guidance_scale=0,
    callback_on_step_end=guide_loss,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]