# References
- ~~[LoRA](https://github.com/cloneofsimo/lora)~~
- [Lora for Diffusers](https://github.com/haofanwang/Lora-for-Diffusers)
- [Long Prompts](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_examples#long-prompt-weighting-stable-diffusion)

# Dependencies

In [None]:
! git clone https://github.com/AguilarLagunasArturo/diffusers.git
! pip install -U git+https://github.com/AguilarLagunasArturo/diffusers.git
! pip install omegaconf
! pip install transformers safetensors accelerate

In [None]:
# ! pip install git+https://github.com/cloneofsimo/lora.git
# ! pip install accelerate

# Imports

In [None]:
import os
import torch
import random
import datetime
from PIL import Image

# Functions

In [None]:
def imgSave(image, path, override_name=None):
    """Save ``image`` into the directory ``path``, creating it when missing.

    Args:
        image: Object exposing ``save(filepath)``.
            NOTE(review): presumably a ``PIL.Image.Image`` - confirm with callers.
        path: Destination directory; created with ``os.makedirs`` if needed.
        override_name: File name used verbatim when truthy. Otherwise the name
            is built from the current local time as
            ``<year>-<month>-<day>_<hour>-<minute>-<second>.png``.
    """
    if override_name:
        save_name = override_name
    else:
        dt = datetime.datetime.now()
        # The third field must be the day of the month: repeating the year
        # there makes names collide across days of the same month.
        save_name = f'{dt.year}-{dt.month}-{dt.day}_{dt.hour}-{dt.minute}-{dt.second}.png'
    os.makedirs(path, exist_ok=True)
    image.save(os.path.join(path, save_name))

# Paths

In [None]:
# Directory layout: everything lives under one "Stable Diffusion" folder
# inside `drive_root`.
# NOTE(review): `drive_root` looks like the Colab Google Drive mount point -
# confirm the drive is mounted before running the cells below.
drive_root = '/content/drive/MyDrive'
root_path = os.path.join(drive_root, 'Stable Diffusion')

# Sub-folder names (relative to `root_path`) ...
models_folder, text2image_folder, text2video_folder = (
    '_models',
    'text2image',
    'text2video',
)

# ... and the corresponding full paths.
models_path, text2image_path, text2video_path = (
    os.path.join(root_path, folder)
    for folder in (models_folder, text2image_folder, text2video_folder)
)

In [None]:
# Split the entries of `models_path` into two groups by name: entries ending in
# one of `extensions` are collected as LoRA models, every other entry is
# collected as a base model.
# NOTE(review): base models are presumably diffusers-format folders - confirm.
base_models = []
lora_models = []
extensions = ('safetensors', 'ckpt')

# os.listdir() returns entries in arbitrary order; sort them so the index-based
# selection below picks the same models on every run.
for entry in sorted(os.listdir(models_path)):
    if entry.endswith(extensions):
        lora_models.append(entry)
    else:
        base_models.append(entry)

# Fail with an explicit message instead of a bare IndexError on the lookups below.
if not base_models or not lora_models:
    raise FileNotFoundError(
        f'Need at least one base model and one LoRA model in {models_path!r}; '
        f'found base_models={base_models}, lora_models={lora_models}'
    )

# Model selection: edit the indices here to use different models.
base_model = base_models[0]
lora_model = lora_models[-1]

print(f'base_model: {base_model}')
print(f'lora_model: {lora_model}')

base_model_path = os.path.join(models_path, base_model)
lora_model_path = os.path.join(models_path, lora_model)

# Name for the merged model: "<base>_<lora>" with spaces and the known weight
# extensions removed.
custom_model_name = f'{base_model}_{lora_model}'.replace(' ', '')
for ext in extensions:
    custom_model_name = custom_model_name.replace(f'.{ext}','')
custom_model_path = os.path.join(models_path, custom_model_name)

print(f'custom model: {custom_model_path}')

# Main

## Upscale

In [None]:
!pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git

In [None]:
from RealESRGAN import RealESRGAN
# Upscaler setup: build a RealESRGAN model on the GPU when CUDA is available
# (CPU otherwise) and load the weights file matching the chosen scale.
model_scale = "4"  # one of "2", "4", "8"

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = RealESRGAN(device, scale=int(model_scale))
# NOTE(review): download=True presumably fetches the weights when the file is
# not present locally - confirm against the Real-ESRGAN package.
model.load_weights(
    f'weights/RealESRGAN_x{model_scale}.pth',
    download=True,
)

# Usage sketch: up_img = model.predict(image), then imgSave(up_img, <directory>)

## Load & merge w/ LoRA

### Modules

In [None]:
from safetensors.torch import load_file
#from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler

### Img

In [None]:
# Key prefixes used inside the LoRA state dict for the two target networks.
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'


def _resolve_lora_target(root, layer_infos):
    """Return the sub-module of ``root`` named by the fragments in ``layer_infos``.

    LoRA keys flatten the module path with underscores, e.g.
    ``text_model_encoder_layers_0_self_attn_k_proj``, but attribute names may
    themselves contain underscores (``text_model``, ``self_attn``, ``k_proj``).
    Fragments are therefore tried one at a time and glued to the next fragment
    with ``_`` whenever the lookup fails.

    Args:
        root: Module to start from (the text encoder or the UNet).
        layer_infos: Underscore-split fragments of the flattened module path.

    Returns:
        The resolved sub-module.

    Raises:
        KeyError: If the fragments are exhausted before a lookup succeeds.
    """
    remaining = list(layer_infos)
    name = remaining.pop(0)
    layer = root
    while True:
        try:
            # Deliberately call __getattr__ directly instead of getattr().
            # Assuming `layer` is a torch.nn.Module, its __getattr__ only
            # consults registered parameters, buffers and sub-modules, whereas
            # getattr() would also return ordinary methods such as Module.to
            # and wrongly match the 'to' fragment of names like 'to_q'.
            layer = layer.__getattr__(name)
        except AttributeError:
            if not remaining:
                raise KeyError(
                    f'Cannot resolve LoRA target {"_".join(layer_infos)!r}: '
                    f'no attribute {name!r} on {type(layer).__name__}'
                ) from None
            fragment = remaining.pop(0)
            name = f'{name}_{fragment}' if name else fragment
            continue
        if not remaining:
            return layer
        name = remaining.pop(0)


def _merge_lora_weights(pipeline, state_dict, alpha):
    """Add ``alpha * (lora_up @ lora_down)`` in place to every targeted weight.

    Implements ``org_forward(x) + lora_up(lora_down(x)) * multiplier`` by
    folding the low-rank product directly into the layer weights.

    Args:
        pipeline: Object exposing ``text_encoder`` and ``unet`` modules.
        state_dict: Mapping of LoRA key -> tensor.
        alpha: Global multiplier applied to every low-rank update. The
            per-layer ``.alpha`` entries of the state dict are skipped.
    """
    visited = set()  # keys already merged as part of an up/down pair

    for key in state_dict:
        # Typical key:
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        if '.alpha' in key or key in visited:
            continue

        flat_name = key.split('.')[0]
        if 'text' in key:
            layer_infos = flat_name.split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].split('_')
            root = pipeline.text_encoder
        else:
            layer_infos = flat_name.split(LORA_PREFIX_UNET + '_')[-1].split('_')
            root = pipeline.unet
        target = _resolve_lora_target(root, layer_infos)

        # Each layer has an up/down pair; whichever key is met first handles both.
        if 'lora_down' in key:
            up_key, down_key = key.replace('lora_down', 'lora_up'), key
        else:
            up_key, down_key = key, key.replace('lora_up', 'lora_down')

        weight_up = state_dict[up_key]
        weight_down = state_dict[down_key]
        if len(weight_up.shape) == 4:
            # 4-D weights: drop the two trailing axes for the matrix product,
            # then restore them on the result.
            weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
            weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
            delta = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up.to(torch.float32), weight_down.to(torch.float32))
        target.weight.data += alpha * delta

        visited.update((up_key, down_key))


# load diffusers model
model_id = base_model_path
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

# load lora weight
model_path = lora_model_path
state_dict = load_file(model_path)

# Strength of the LoRA merge (other values tried: 0.54, 0.65).
alpha = 0.75

# directly update weight in diffusers model
_merge_lora_weights(pipeline, state_dict, alpha)

pipeline = pipeline.to("cuda")
# Replace the safety checker with a pass-through that returns the images
# unchanged together with a constant False flag.
pipeline.safety_checker = lambda images, clip_input: (images, False)


### Vid

In [None]:
! pip install stable_diffusion_videos
from stable_diffusion_videos import StableDiffusionWalkPipeline #, Interface

In [None]:
# Key prefixes used inside the LoRA state dict for the two target networks.
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'


# The helpers are defined in this cell so that it can be run on its own.
def _resolve_lora_target(root, layer_infos):
    """Return the sub-module of ``root`` named by the fragments in ``layer_infos``.

    LoRA keys flatten the module path with underscores, e.g.
    ``text_model_encoder_layers_0_self_attn_k_proj``, but attribute names may
    themselves contain underscores (``text_model``, ``self_attn``, ``k_proj``).
    Fragments are therefore tried one at a time and glued to the next fragment
    with ``_`` whenever the lookup fails.

    Args:
        root: Module to start from (the text encoder or the UNet).
        layer_infos: Underscore-split fragments of the flattened module path.

    Returns:
        The resolved sub-module.

    Raises:
        KeyError: If the fragments are exhausted before a lookup succeeds.
    """
    remaining = list(layer_infos)
    name = remaining.pop(0)
    layer = root
    while True:
        try:
            # Deliberately call __getattr__ directly instead of getattr().
            # Assuming `layer` is a torch.nn.Module, its __getattr__ only
            # consults registered parameters, buffers and sub-modules, whereas
            # getattr() would also return ordinary methods such as Module.to
            # and wrongly match the 'to' fragment of names like 'to_q'.
            layer = layer.__getattr__(name)
        except AttributeError:
            if not remaining:
                raise KeyError(
                    f'Cannot resolve LoRA target {"_".join(layer_infos)!r}: '
                    f'no attribute {name!r} on {type(layer).__name__}'
                ) from None
            fragment = remaining.pop(0)
            name = f'{name}_{fragment}' if name else fragment
            continue
        if not remaining:
            return layer
        name = remaining.pop(0)


def _merge_lora_weights(pipeline, state_dict, alpha):
    """Add ``alpha * (lora_up @ lora_down)`` in place to every targeted weight.

    Implements ``org_forward(x) + lora_up(lora_down(x)) * multiplier`` by
    folding the low-rank product directly into the layer weights.

    Args:
        pipeline: Object exposing ``text_encoder`` and ``unet`` modules.
        state_dict: Mapping of LoRA key -> tensor.
        alpha: Global multiplier applied to every low-rank update. The
            per-layer ``.alpha`` entries of the state dict are skipped.
    """
    visited = set()  # keys already merged as part of an up/down pair

    for key in state_dict:
        # Typical key:
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        if '.alpha' in key or key in visited:
            continue

        flat_name = key.split('.')[0]
        if 'text' in key:
            layer_infos = flat_name.split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].split('_')
            root = pipeline.text_encoder
        else:
            layer_infos = flat_name.split(LORA_PREFIX_UNET + '_')[-1].split('_')
            root = pipeline.unet
        target = _resolve_lora_target(root, layer_infos)

        # Each layer has an up/down pair; whichever key is met first handles both.
        if 'lora_down' in key:
            up_key, down_key = key.replace('lora_down', 'lora_up'), key
        else:
            up_key, down_key = key, key.replace('lora_up', 'lora_down')

        weight_up = state_dict[up_key]
        weight_down = state_dict[down_key]
        if len(weight_up.shape) == 4:
            # 4-D weights: drop the two trailing axes for the matrix product,
            # then restore them on the result.
            weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
            weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
            delta = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up.to(torch.float32), weight_down.to(torch.float32))
        target.weight.data += alpha * delta

        visited.update((up_key, down_key))


# load diffusers model
model_id = base_model_path
pipeline = StableDiffusionWalkPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    safety_checker=None,  # Very important for videos...lots of false positives while interpolating
    revision="fp16",
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

# load lora weight
model_path = lora_model_path
state_dict = load_file(model_path)

# Strength of the LoRA merge (other values tried: 0.54, 0.65).
alpha = 0.75

# directly update weight in diffusers model
_merge_lora_weights(pipeline, state_dict, alpha)

pipeline = pipeline.to("cuda")
pipeline.safety_checker = None
pipeline.enable_attention_slicing()
# The video-generation cell refers to the pipeline under this name.
pipe_walk = pipeline

print(model_id)

## Gen img

### Single

In [None]:
# Draw a seed and report it so that a result can be reproduced later by
# assigning the printed value to `seed`.
seed = int(random.random()*100000)
print(f'seed: {seed}')

prompts = [
    '',
]
negative = [
    '',
]

# Use the last entry of each list (swap in random.choice(...) to sample instead).
p = prompts[-1]
n = negative[-1]

mobile = True
portrait = True

# Candidate sizes: 360x640 for mobile, 512x768 otherwise.
res = [360, 640] if mobile else [512, 768]
# Order as [height, width]: the larger value comes first for portrait,
# the smaller value first for landscape.
res = sorted(res, reverse=True) if portrait else sorted(res)

print(res)

# The seed only has an effect when it is handed to the pipeline via a generator.
generator = torch.Generator(device=pipeline.device).manual_seed(seed)

with torch.no_grad():
    img = pipeline(
        prompt=p,
        negative_prompt=n,
        height=res[0],
        width=res[1],
        num_inference_steps=80,
        guidance_scale=8,
        max_embeddings_multiples=3,
        generator=generator,
    ).images[0]
img

In [None]:
# Run the current `img` through the upscaler model, then store the result in
# the text2image output folder.
upscaled_img = model.predict(img)
imgSave(upscaled_img, text2image_path)

### Loop

In [None]:
mobile = False
portrait = True

# Draw a seed and report it so that a whole batch can be reproduced later by
# assigning the printed value to `seed`.
seed = int(random.random()*100000)
print(f'seed: {seed}')

prompts = [
    '',
]
negative = [
    '',
]

# One generator for the whole batch: the seed only has an effect when it is
# handed to the pipeline, and its state advances from image to image.
generator = torch.Generator(device=pipeline.device).manual_seed(seed)

total = 100
for i in range(total):

    # Computed per iteration so `mobile` may be varied inside the loop,
    # e.g. mobile = random.choice([True, False]).
    res = [360, 640] if mobile else [512, 768]
    # Order as [height, width] so that the `portrait` flag is honoured:
    # larger value first for portrait, smaller value first for landscape.
    res = sorted(res, reverse=True) if portrait else sorted(res)

    p = random.choice(prompts)
    n = random.choice(negative)

    with torch.no_grad():
        img = pipeline(
            prompt=p,
            negative_prompt=n,
            height=res[0],
            width=res[1],
            num_inference_steps=75,
            guidance_scale=8,
            max_embeddings_multiples=3,
            generator=generator,
        ).images[0]

    # Mobile-sized results go to their own sub-folder.
    out_dir = os.path.join(text2image_path, 'mobile') if mobile else text2image_path
    imgSave(model.predict(img), out_dir)

    # Separate counter name so the negative prompt `n` is not clobbered.
    done = i + 1
    print(f"[+] Progress {done}/{total}: {(done/total)*100:.1f}%")

## Gen Vid

In [None]:
# Prompts to interpolate between, in order.
prompts = [
    '',
    '',
    '',
    '',
]
n = ''

# One random seed per prompt, derived from the prompt list so the two lists
# cannot get out of sync when prompts are added or removed.
seeds = [int(random.random()*1000) for _ in prompts]

video_path = pipe_walk.walk(
    prompts=prompts,
    negative_prompt=n,
    seeds=seeds,
    num_interpolation_steps=50,
    output_dir=text2video_path,
    # Name the output after the merged model (last component of its path).
    name=os.path.basename(custom_model_path),
    batch_size=2,
    width=360,
    height=640,
    guidance_scale=8,
    num_inference_steps=80,
    fps=10,
    upsample=True
    # make_video=False
)

# End session

In [None]:
# Final cell: calls `runtime.unassign()` from `google.colab`.
# NOTE(review): presumably this releases the Colab runtime so it stops
# consuming resources once the notebook has finished - confirm.
from google.colab import runtime
runtime.unassign()