In [None]:
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, AutoencoderKL

In [None]:
# model_base must be the same base model the LoRA was trained on.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo may no longer be
# downloadable from the Hugging Face Hub; if loading fails, look for a mirror
# of the same weights — TODO confirm.
model_base = "runwayml/stable-diffusion-v1-5"

# LoRA checkpoint to load. Training runs are stored as
# finetune/lora/living_room/<run timestamp>/checkpoint-<step>/pytorch_model.bin.
# Other runs (e.g. 20230604-110114) or steps can be selected by changing the
# two values below; an older un-timestamped run lives at
# finetune/lora/living_room/checkpoint-500/.
lora_run = "20230604-162709"
lora_checkpoint_step = 15000
model_lora = (
    f"finetune/lora/living_room/{lora_run}"
    f"/checkpoint-{lora_checkpoint_step}/pytorch_model.bin"
)

# VAE loaded separately and passed to the pipeline in place of the base model's VAE.
model_vae = "stabilityai/sd-vae-ft-mse"

In [None]:
# Load the base Stable Diffusion pipeline in float16, substituting the
# separately loaded VAE and disabling the safety checker, then replace the
# default scheduler with a DPMSolverMultistepScheduler built from its config.
weights_dtype = torch.float16

vae = AutoencoderKL.from_pretrained(model_vae, torch_dtype=weights_dtype)

pipe = StableDiffusionPipeline.from_pretrained(
    model_base,
    torch_dtype=weights_dtype,
    safety_checker=None,
    vae=vae,
)

base_scheduler_config = pipe.scheduler.config
pipe.scheduler = DPMSolverMultistepScheduler.from_config(base_scheduler_config)

In [None]:
# Load the LoRA attention-processor weights into the UNet, then move the
# pipeline to the GPU.
# NOTE(review): newer diffusers releases deprecate `unet.load_attn_procs` for
# LoRA in favour of `pipe.load_lora_weights` — confirm against the pinned version.
pipe.unet.load_attn_procs(model_lora)

# Assign rather than leaving `pipe.to(...)` as the cell's last expression,
# which would print the pipeline's full repr as cell output.
pipe = pipe.to("cuda")

# Base vs. LoRA weighting

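The `scale` value passed via `cross_attention_kwargs` controls how strongly the LoRA update is applied on top of the base model. LoRA adds a learned low-rank delta $\Delta W$ to the base attention weights $W_0$. The effective weight is $W = W_0 + \text{scale} \cdot \Delta W$, so the base weights are always used.

- `scale=0` disables the LoRA update, leaving only the base model weights.
- `scale=1` applies the full LoRA update on top of the base weights, i.e. the fine-tuned model.
- Values in between apply a proportionally weaker LoRA update.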

In [None]:
# Prompt and seed shared by every comparison below, so the LoRA scale is the
# only thing that changes between the generated images.
seed = 5
prompt = (
    "interior design, living room, "
    "modern clean no clutter, high res, 4k"
)

## scale=0: Base model only

In [None]:
# Re-seed a fresh generator so each scale is sampled from the same random
# state; the LoRA scale is the only argument that differs between cells.
lora_scale = 0
generator = torch.Generator(device="cuda").manual_seed(seed)
pipe_output = pipe(
    prompt,
    generator=generator,
    guidance_scale=7.5,
    num_inference_steps=25,
    cross_attention_kwargs={"scale": lora_scale},
)
image = pipe_output.images[0]
image

## scale=0.5: LoRA update at half strength

In [None]:
# Re-seed a fresh generator so each scale is sampled from the same random
# state; the LoRA scale is the only argument that differs between cells.
lora_scale = 0.5
generator = torch.Generator(device="cuda").manual_seed(seed)
pipe_output = pipe(
    prompt,
    generator=generator,
    guidance_scale=7.5,
    num_inference_steps=25,
    cross_attention_kwargs={"scale": lora_scale},
)
image = pipe_output.images[0]
image

## scale=0.8: LoRA update at 80% strength

In [None]:
# Re-seed a fresh generator so each scale is sampled from the same random
# state; the LoRA scale is the only argument that differs between cells.
lora_scale = 0.8
generator = torch.Generator(device="cuda").manual_seed(seed)
pipe_output = pipe(
    prompt,
    generator=generator,
    guidance_scale=7.5,
    num_inference_steps=25,
    cross_attention_kwargs={"scale": lora_scale},
)
image = pipe_output.images[0]
image