In [None]:
%cd ..

#### Current code adapted from <a href="https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py">diffusers</a>  
  

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import random
import hashlib
from copy import deepcopy
from pathlib import Path

import PIL
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from omegaconf.dictconfig import DictConfig
from huggingface_hub import hf_hub_url, cached_download

from kandinsky2 import CONFIG_2_1, Kandinsky2_1 

### Training parameters 

In [3]:
device = 'cuda'
task_type = 'text2img'
cache_root = '/tmp/kandinsky2'

# Fill here -------------------------------------------------------
class_prompt = 'train' # Global object name (train, bear, etc)
instance_prompt = 'sapsan' # Unique object name (sapsan, *, etc)
# Root folder for the prior-preservation data: generated class images are written to
# <class_data_dir>/class_images and their index to <class_data_dir>/class_csv.csv.
class_data_dir = './finetune/input/sapsan'  # root folder for generated class images + class_csv.csv
# Folder with YOUR photos of the object; it is passed to the dataset as instance_data_root.
instance_data_dir = './finetune/input/sapsan/instance_images' # folder with your images
out_folder = './finetune/output/sapsan' # your folder for saved model and images
# ----------------------------------------------------------------

# Pretrained weights -----------------------------------------------
# model_path = os.path.join(out_folder, "decoder_fp16.ckpt") # None if not exists
model_path = None  # Original Kandinsky-2.1 unet if None

# Generated Images -----------------------------------------------
# csv_path =  os.path.join(class_data_dir, "class_csv.csv") # None if not exists
csv_path = None  # Create new images if None

os.makedirs(out_folder, exist_ok=True)

img_size = 512
epochs = 4
log_image_frequency = -1 # in epochs; -1 disable image logging
log_model_frequency = 50 # NOTE(review): not referenced by any cell shown in this notebook

prior_loss_weight = 1.0  # weight of the class (prior) loss added to the instance loss
num_class_images = 256  # class images are generated until this many exist
center_crop = False  # False -> random crop in the dataset transforms

# AdamW hyper-parameters
lr = 5e-6
beta1 = 0.9
beta2 = 0.999
weight_decay = 1e-2
epsilon = 1e-08

num_workers = 0
train_batch_size = 1  # instance images per step; each batch also carries as many class images
sample_batch_size = 1  # images generated per call while building the class set

### Helper functions

In [4]:
def show_image(image, figsize=(5, 5), cmap=None, title='', xlabel=None, ylabel=None, axis=False):
    """Display one image in a fresh matplotlib figure.

    Args:
        image: Anything ``imshow`` accepts (PIL image, array, ...).
        figsize: Figure size in inches.
        cmap: Optional colormap forwarded to ``imshow``.
        title: Title drawn above the image.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        axis: Forwarded to ``Axes.axis``; ``False`` hides ticks and frame.
    """
    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.axis(axis)
    plt.show()

def show_images(images, n_rows=1, title='', figsize=(5, 5), cmap=None, xlabel=None, ylabel=None, axis=False):
    """Display a sequence of images on an ``n_rows`` x ``n_cols`` grid.

    The number of columns is ``ceil(len(images) / n_rows)`` so that every image
    is shown even when the count is not a multiple of ``n_rows``; grid cells
    left without an image are blanked.

    Args:
        images: Non-empty sequence of images accepted by ``imshow``.
        n_rows: Number of grid rows (>= 1).
        title: Title drawn above every image.
        figsize: Size of the whole figure in inches.
        cmap: Optional colormap forwarded to ``imshow``.
        xlabel: Label of the x axis of every image.
        ylabel: Label of the y axis of every image.
        axis: Forwarded to ``Axes.axis``; ``False`` hides ticks and frame.

    Raises:
        ValueError: If ``images`` is empty or ``n_rows`` is smaller than 1.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("show_images() needs at least one image.")
    if n_rows < 1:
        raise ValueError(f"n_rows must be >= 1, got {n_rows}.")

    # Ceil division: floor division silently dropped the trailing images.
    n_cols = -(-n_images // n_rows)
    if n_rows == n_cols == 1:
        show_image(images[0], title=title, figsize=figsize, cmap=cmap, xlabel=xlabel, ylabel=ylabel, axis=axis)
        return

    # squeeze=False always yields a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    fig.tight_layout(pad=0.0)
    flat_axes = axes.flatten()
    for ax, img in zip(flat_axes, images):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.axis(axis)
    # Hide the cells that received no image so they do not render as empty frames.
    for ax in flat_axes[n_images:]:
        ax.axis(False)
    plt.show()
        
def download_models_if_not_exist(
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
):
    """Download the Kandinsky 2.1 files needed for ``task_type`` into the cache.

    Files are stored under ``<cache_dir>/2_1`` (text-encoder files under
    ``<cache_dir>/2_1/text_encoder``) using their plain file names.

    Args:
        task_type: ``"text2img"`` or ``"inpainting"``; selects the decoder checkpoint.
        cache_dir: Root cache folder; the ``2_1`` sub-folder is appended here.
        use_auth_token: Forwarded unchanged to ``cached_download``.

    Raises:
        ValueError: If ``task_type`` is not one of the supported values.
    """
    repo_id = "sberbank-ai/Kandinsky_2.1"
    decoder_by_task = {
        "text2img": "decoder_fp16.ckpt",
        "inpainting": "inpainting_fp16.ckpt",
    }
    # Validate up front: previously an unknown task fell through both branches
    # and crashed later with an UnboundLocalError on `config_file_url`.
    if task_type not in decoder_by_task:
        raise ValueError(
            f"Unsupported task_type {task_type!r}; expected one of {sorted(decoder_by_task)}."
        )

    cache_dir = os.path.join(cache_dir, "2_1")
    cache_dir_text_en = os.path.join(cache_dir, "text_encoder")
    text_encoder_files = [
        "config.json",
        "pytorch_model.bin",
        "sentencepiece.bpe.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]

    # (file name inside the repo, local file name, local folder) in download order.
    downloads = [
        (decoder_by_task[task_type], decoder_by_task[task_type], cache_dir),
        ("prior_fp16.ckpt", "prior_fp16.ckpt", cache_dir),
    ]
    downloads += [(f"text_encoder/{name}", name, cache_dir_text_en) for name in text_encoder_files]
    downloads += [
        ("movq_final.ckpt", "movq_final.ckpt", cache_dir),
        ("ViT-L-14_stats.th", "ViT-L-14_stats.th", cache_dir),
    ]

    # NOTE(review): `cached_download` is deprecated in recent huggingface_hub
    # releases; kept because this notebook imports it - confirm the pinned version.
    for repo_filename, local_name, target_dir in downloads:
        cached_download(
            hf_hub_url(repo_id=repo_id, filename=repo_filename),
            cache_dir=target_dir,
            force_filename=local_name,
            use_auth_token=use_auth_token,
        )
    
def add_noise(original_samples, noise, timesteps):
    """Noise ``original_samples`` to the given diffusion ``timesteps``.

    A 1000-step linear beta schedule (0.00085 -> 0.012) is rebuilt on every
    call, and the result is
    ``sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise``
    where ``alpha_bar`` is the cumulative product of ``1 - beta``.

    Args:
        original_samples: Clean samples; their dtype and device are used for the schedule.
        noise: Noise tensor broadcastable against ``original_samples``.
        timesteps: Integer tensor of schedule indices.

    Returns:
        The noisy samples.
    """
    n_steps = 1000
    schedule_scale = 1000 / n_steps
    betas = torch.linspace(
        schedule_scale * 0.00085,
        schedule_scale * 0.012,
        n_steps,
        dtype=original_samples.dtype,
    )
    alpha_bar = torch.cumprod(1.0 - betas, dim=0).to(
        device=original_samples.device, dtype=original_samples.dtype
    )
    alpha_bar_t = alpha_bar[timesteps.to(original_samples.device)]

    def _as_coefficient(values):
        # Flatten, then append singleton dims so the coefficient broadcasts
        # over every non-batch dimension of the samples.
        flat = values.flatten()
        missing_dims = max(original_samples.dim() - flat.dim(), 0)
        return flat.reshape(flat.shape[0], *([1] * missing_dims))

    signal_coeff = _as_coefficient(alpha_bar_t ** 0.5)
    noise_coeff = _as_coefficient((1 - alpha_bar_t) ** 0.5)
    return signal_coeff * original_samples + noise_coeff * noise
    
def save_images(model, save_path, instance_prompt, class_prompt, img_size=512):
    """Sample a 2 x 4 preview grid from ``model`` and write it to ``save_path``.

    The top row is sampled for "<instance_prompt> <class_prompt>", the bottom
    row for "<class_prompt>" alone. Each row holds two images from a plain
    prompt followed by two images from a "professional photo" prompt.

    Args:
        model: Object exposing ``generate_text2img``.
        save_path: Destination file of the combined grid.
        instance_prompt: Unique object name.
        class_prompt: Global object name.
        img_size: Height and width requested for every sample.
    """
    sampling_kwargs = dict(
        num_steps=50,
        batch_size=2,
        guidance_scale=7.5,
        h=img_size,
        w=img_size,
        sampler="p_sampler",
        prior_cf_scale=4,
        prior_steps="5",
    )

    grid_rows = []
    # Row order matters for the saved layout: instance subject first, class second.
    for subject in (f"{instance_prompt} {class_prompt}", f"{class_prompt}"):
        row_prompts = (
            f"a photo of a {subject}",
            f"Professional high-quality photo of a {subject}. photorealistic, 4k, HQ",
        )
        row_images = []
        for prompt in row_prompts:
            row_images.extend(model.generate_text2img(prompt, **sampling_kwargs))
        grid_rows.append(np.hstack([np.array(img) for img in row_images]))

    Image.fromarray(np.vstack(grid_rows)).save(save_path)

In [5]:
# prompts from original dreambooth repo + from diffusers + custom words
# https://github.com/google/dreambooth/blob/main/dataset/prompts_and_classes.txt
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py

# Prompt templates for the instance images; "{}" receives "<instance> <class>".
instance_templates = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Prompt templates for the generated class images: every instance template
# (copied into a new list) followed by scene / attribute variations.
class_templates = [
    *instance_templates,
    "a {} in the jungle",
    "a {} in the snow",
    "a {} on the beach",
    "a {} on a cobblestone street",
    "a {} on top of pink fabric",
    "a {} on top of a wooden floor",
    "a {} with a city in the background",
    "a {} with a mountain in the background",
    "a {} with a blue house in the background",
    "a {} on top of a purple rug in a forest",
    "a {} with a wheat field in the background",
    "a {} with a tree and autumn leaves in the background",
    "a {} with the Eiffel Tower in the background",
    "a {} floating on top of water",
    "a {} floating in an ocean of milk",
    "a {} on top of green grass with sunflowers around it",
    "a {} on top of a mirror",
    "a {} on top of the sidewalk in a crowded street",
    "a {} on top of a dirt road",
    "a {} on top of a white rug",
    "a red {}",
    "a purple {}",
    "a shiny {}",
    "a wet {}",
    "a cube shaped {}",
]
           
# Style / quality tags; a random subset is appended to each class-image prompt
# (see get_prompt_extention and the class-image generation cell).
extend_words = [
    'photorealistic', 'epic', 'high quality',
    'cinematic', 'extremely high detail', 
    'cinematic lighting', 'trending on artstation', 
    'cgsociety', 'realistic rendering of Unreal Engine 5', 
    '8k', '4k', 'HQ', 'wallpaper',
]

def get_prompt_extention(extend_words):
    """Return a random, comma-separated subset of ``extend_words``.

    Between 0 and ``len(extend_words)`` distinct words are drawn with the
    ``random`` module and shuffled; an empty draw yields ``''``.
    """
    how_many = random.randint(0, len(extend_words))
    picked = random.sample(extend_words, how_many)
    random.shuffle(picked)
    # Joining an empty list already gives '', so no special case is needed.
    return ', '.join(picked)

class DreamBoothDataset(Dataset):
    """DreamBooth training pairs: one instance image plus one class (prior) image.

    Item ``i`` combines instance image ``i % num_instance_images`` with row
    ``i % len(class_csv)`` of ``class_csv``; the dataset length is the number
    of rows in ``class_csv``. Instance prompts are drawn at random from the
    module-level ``instance_templates`` list.

    Args:
        instance_data_root: Directory holding the user's instance images.
        instance_prompt: Unique object name.
        class_prompt: Global object name.
        class_csv: pandas DataFrame with ``image_path`` and ``prompt`` columns
            describing the class images.
        size: Side length the images are resized / cropped to.
        center_crop: Use a center crop instead of a random crop.

    Raises:
        ValueError: If ``instance_data_root`` does not exist or contains no files.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        class_prompt,
        class_csv, # info about image_path and prompt
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        # Keep regular, non-hidden files only (skips sub-folders, .DS_Store,
        # .ipynb_checkpoints, ...) and sort them so indexing is deterministic.
        self.instance_images_path = sorted(
            path
            for path in self.instance_data_root.iterdir()
            if path.is_file() and not path.name.startswith(".")
        )
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            # Fail here rather than with a ZeroDivisionError inside __getitem__.
            raise ValueError(f"No instance images found in {self.instance_data_root}.")

        self.instance_prompt = instance_prompt
        self.class_prompt = class_prompt
        self.class_csv = class_csv
        self._length = class_csv.shape[0]

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy and close the underlying file."""
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        example = {}

        instance_path = self.instance_images_path[index % self.num_instance_images]
        example["instance_images"] = self.image_transforms(self._load_rgb(instance_path))
        example["instance_prompt"] = random.choice(instance_templates).format(f'{self.instance_prompt} {self.class_prompt}')

        # Read from the frame given to the constructor, not a notebook global.
        class_row = self.class_csv.iloc[index % self._length]
        example["class_images"] = self.image_transforms(self._load_rgb(class_row['image_path']))
        example["class_prompt"] = class_row['prompt']

        return example
    
    
def collate_fn(examples, with_prior_preservation=True):
    """Merge dataset items into one batch dict with "prompt" and "image".

    Instance prompts / images come first. When ``with_prior_preservation`` is
    true the class prompts / images are appended after them, so the batch
    holds twice as many samples as ``examples``.

    Returns:
        ``{"prompt": list of str, "image": contiguous float tensor}``.
    """
    prompts = [item["instance_prompt"] for item in examples]
    images = [item["instance_images"] for item in examples]

    if with_prior_preservation:
        prompts.extend(item["class_prompt"] for item in examples)
        images.extend(item["class_images"] for item in examples)

    stacked = torch.stack(images).to(memory_format=torch.contiguous_format).float()
    return {"prompt": prompts, "image": stacked}

### Define config and create Kandinsky model

In [6]:
# Work on a deep copy so the CONFIG_2_1 object imported from kandinsky2 is not mutated.
config = DictConfig(deepcopy(CONFIG_2_1))

# Same "2_1" sub-folder that download_models_if_not_exist() appends to cache_root.
cache_dir = os.path.join(cache_root, "2_1")

# UNet options for this run.
# NOTE(review): use_fp16=False presumably keeps the weights in full precision
# for training - confirm against the kandinsky2 model factory.
config["model_config"]["up"] = False
config["model_config"]["use_fp16"] = False
config["model_config"]["inpainting"] = False
config["model_config"]["cache_text_emb"] = False
config["model_config"]["use_flash_attention"] = False

# Point every sub-model at the files stored in the local cache.
config["tokenizer_name"] = os.path.join(cache_dir, "text_encoder")
config["text_enc_params"]["model_path"] = os.path.join(cache_dir, "text_encoder")
config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th")
config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt")

# Fall back to the stock decoder checkpoint unless the parameters cell set model_path.
model_path = os.path.join(cache_dir, "decoder_fp16.ckpt") if model_path is None else model_path
prior_path = os.path.join(cache_dir, "prior_fp16.ckpt")

# Fetch the checkpoints referenced above into cache_root/2_1.
download_models_if_not_exist(task_type=task_type, cache_dir=cache_root)



### Generate class images for prior loss

In [1]:
# Load the index of already generated class images, or start with an empty one.
class_csv = pd.read_csv(csv_path) if csv_path is not None else None

if class_csv is None:
    class_csv = pd.DataFrame(columns=['image_path', 'prompt'])
cur_class_images = class_csv.shape[0]

if cur_class_images < num_class_images:
    model = Kandinsky2_1(config, model_path, prior_path, device, task_type=task_type)

    # image.save() does not create missing folders, so make sure the target exists.
    class_images_dir = os.path.join(class_data_dir, 'class_images')
    os.makedirs(class_images_dir, exist_ok=True)

    num_new_images = num_class_images - cur_class_images
    print(f"Number of class images to sample: {num_new_images}.")

    # Collect rows in a plain list and build the frame once at the end:
    # DataFrame.append was removed in pandas 2.0 and re-copied the frame per row.
    class_records = class_csv[['image_path', 'prompt']].to_dict('records')

    # Ceil division so at least num_new_images are produced when
    # sample_batch_size does not divide the amount evenly.
    num_batches = -(-num_new_images // sample_batch_size)
    for index_example in tqdm(range(num_batches), desc="Generating class images"):
        prompt = random.choice(class_templates).format(class_prompt) + '. ' + get_prompt_extention(extend_words)
        images = model.generate_text2img(
            prompt,
            num_steps=50,
            batch_size=sample_batch_size,
            guidance_scale=7.5,
            h=img_size,
            w=img_size,
            sampler="p_sampler",
            prior_cf_scale=4,
            prior_steps="8",
        )

        for image in images:
            # The content hash keeps file names unique within one batch.
            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
            image_filename = os.path.join(class_images_dir, f"{index_example + cur_class_images}-{hash_image}.jpg")
            image.save(image_filename)
            class_records.append({
                'image_path': image_filename,
                'prompt': prompt
            })

    # Release the sampling model before the training model is created.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    class_csv = pd.DataFrame(class_records, columns=['image_path', 'prompt'])
    cur_class_images = class_csv.shape[0]

    # index=False keeps a stray "Unnamed: 0" column out of the file on reload.
    class_csv.to_csv(f'{class_data_dir}/class_csv.csv', index=False)

In [2]:
# (Re)create the Kandinsky 2.1 pipeline that will be fine-tuned; the class-image
# generation cell deletes its own temporary `model` when it had to sample images.
model = Kandinsky2_1(config, model_path, prior_path, device, task_type=task_type)

In [9]:
## Freeze all except unet
# model.model is the only sub-module whose parameters keep requires_grad=True;
# it is also the only one handed to the optimizer.
model.model.requires_grad_(True)

# Image encoder, prior, CLIP model and text encoder get requires_grad=False.
# The trailing ';' only suppresses the notebook's output repr.
model.image_encoder.requires_grad_(False);
model.prior.requires_grad_(False);
model.clip_model.requires_grad_(False);
model.text_encoder.requires_grad_(False);

In [10]:
# Configure optimizer, dataset and dataloader
# Only the UNet parameters (model.model) are optimized.
optimizer = torch.optim.AdamW(
    model.model.parameters(),
    lr=lr,
    betas=(beta1, beta2),
    weight_decay=weight_decay,
    eps=epsilon,
)

# Instance images are read from instance_data_dir; class images and their
# prompts come from the class_csv frame built in the generation cell.
train_dataset = DreamBoothDataset(
    instance_data_root=instance_data_dir,
    instance_prompt=instance_prompt,
    class_prompt=class_prompt,
    class_csv=class_csv,
    size=img_size,
    center_crop=center_crop,
)

# collate_fn is used with its default with_prior_preservation=True, so each
# batch carries 2 * train_batch_size samples (instance samples, then class samples).
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn
)

In [11]:
def generate_clip_emb(
        model,
        prompts_batch,
        negative_prompts_batch,
        prior_cf_scale=1,
        prior_steps="5",
    ):
    """Run the CLIP text tower and the prior to get image embeddings for prompts.

    Args:
        model: Kandinsky pipeline object; this function reads its ``device``,
            ``tokenizer2``, ``clip_model``, ``prior`` and ``model_dtype``.
        prompts_batch: List of prompt strings.
        negative_prompts_batch: List of negative prompts. When its tokenized
            shape differs from the prompts' it is expanded with
            ``Tensor.expand``, which only works for a single negative prompt.
        prior_cf_scale: Scalar repeated once per prompt and passed to the prior.
        prior_steps: Passed to the prior as ``timestep_respacing``.

    Returns:
        The output of ``model.prior(...)`` cast to ``model.model_dtype``.
        NOTE(review): presumably one image embedding per prompt - confirm
        against the prior implementation.

    Note:
        Nothing here disables autograd; wrap the call in ``torch.no_grad()``
        if gradients through CLIP / the prior are not wanted.
    """
    prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch)
    prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=model.device)
    # Tokenize prompts and negative prompts, padded to the prior's text context length.
    max_txt_length = model.prior.model.text_ctx
    tok, mask = model.tokenizer2.padded_tokens_and_mask(
        prompts_batch, max_txt_length
    )
    cf_token, cf_mask = model.tokenizer2.padded_tokens_and_mask(
        negative_prompts_batch, max_txt_length
    )
    # Broadcast the negative prompt to the prompt batch size when the shapes differ.
    if not (cf_token.shape == tok.shape):
        cf_token = cf_token.expand(tok.shape[0], -1)
        cf_mask = cf_mask.expand(tok.shape[0], -1)
    # Prompts first, negative prompts second, along the batch dimension.
    tok = torch.cat([tok, cf_token], dim=0)
    mask = torch.cat([mask, cf_mask], dim=0)
    tok, mask = tok.to(device=model.device), mask.to(device=model.device)
    # Manual forward pass through the CLIP text tower so that both the
    # per-token sequence and the pooled feature are available.
    x = model.clip_model.token_embedding(tok).type(model.clip_model.dtype)
    x = x + model.clip_model.positional_embedding.type(model.clip_model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = model.clip_model.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = model.clip_model.ln_final(x).type(model.clip_model.dtype)
    txt_feat_seq = x
    # Pooled feature: the position with the largest token id in each row
    # (presumably the end-of-text token - confirm), projected by text_projection.
    txt_feat = (x[torch.arange(x.shape[0]), tok.argmax(dim=-1)] @ model.clip_model.text_projection)
    txt_feat, txt_feat_seq = txt_feat.float().to(model.device), txt_feat_seq.float().to(model.device)
    
    img_feat = model.prior(
        txt_feat,
        txt_feat_seq,
        mask,
        prior_cf_scales_batch,
        timestep_respacing=prior_steps,
    )
    return img_feat.to(model.model_dtype)

In [3]:
# The trainable UNet's dtype drives every cast below; the CLIP model is moved
# to the same dtype before generate_clip_emb() is called.
weight_dtype = model.model.dtype
model.clip_model.to(weight_dtype)

progress_bar_epochs = tqdm(range(1, epochs + 1))

for epoch in progress_bar_epochs:
    model.model.train()
    for batch in train_dataloader:
        prompts = batch["prompt"]
        model_kwargs = {}

        # Convert images to latent representation and add noise.
        # The image encoder is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            latents = model.image_encoder.encode(batch["image"].to(device=device, dtype=weight_dtype))
            latents = latents * model.scale

        # Size everything from the real batch: with prior preservation the batch
        # holds instance AND class samples (2 * train_batch_size), and the last
        # batch of an epoch may be smaller still.
        bsz = latents.shape[0]
        timesteps = torch.randint(0, 1000, (bsz,), device=latents.device).long()
        noise = torch.randn_like(latents)
        noisy_latents = add_noise(latents, noise, timesteps).to(weight_dtype)

        # Conditioning from the frozen prior / CLIP / text encoder.
        with torch.no_grad():
            image_emb = generate_clip_emb(
                model,
                prompts,
                [""] * len(prompts),  # one empty negative prompt per sample
                prior_cf_scale=4,
                prior_steps="8",
            )

            # Second Model Parameters
            tokens = model.tokenizer1(
                prompts,
                padding="max_length",
                truncation=True,
                max_length=77,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            full_emb, pooled_emb = model.text_encoder(
                tokens=tokens['input_ids'].long().to(device=device),
                mask=tokens['attention_mask'].to(device=device),
            )

        model_kwargs["image_emb"] = image_emb.to(weight_dtype)
        model_kwargs["full_emb"] = full_emb.to(weight_dtype)
        model_kwargs["pooled_emb"] = pooled_emb.to(weight_dtype)

        # Predict the noise; only the first 4 output channels are compared with it.
        # NOTE(review): the remaining channels are presumably a learned variance - confirm.
        model_pred = model.model(noisy_latents, timesteps, **model_kwargs)[:, :4]

        # First half of the batch = instance samples, second half = class samples
        # (the order produced by collate_fn).
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(noise, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        progress_bar_epochs.set_postfix(loss=loss.item())

    if log_image_frequency > 0 and (epoch % log_image_frequency == 0):
        images_root = os.path.join(out_folder, "images")
        os.makedirs(images_root, exist_ok=True)
        image_save_path = os.path.join(images_root, f"{epoch}_epoch_images.jpg")
        # Sample in eval mode; train() is restored at the top of the next epoch.
        model.model.eval()
        save_images(model, image_save_path, instance_prompt, class_prompt, img_size=img_size)

# Leave the UNet in eval mode for the saving / sampling cells that follow.
model.model.eval()

In [13]:
# Save our new unet model
# Only the UNet state dict is written. The file name matches the commented-out
# `model_path` example in the parameters cell, so it can be loaded from there.
# NOTE(review): use_fp16 was set to False in the config cell, so despite the
# "fp16" in the name the weights are presumably stored in full precision - confirm.
out_model_path = os.path.join(out_folder, "decoder_fp16.ckpt")
torch.save(model.model.state_dict(), out_model_path)

In [4]:
# Sanity check of the fine-tuned model: plain prompt, "p_sampler".
images = model.generate_text2img(
            f'photo of {instance_prompt} {class_prompt}',
            num_steps=50, 
            batch_size=2, 
            guidance_scale=7.5,
            h=img_size, 
            w=img_size,
            sampler="p_sampler", 
            prior_cf_scale=4,
            prior_steps="8",
        )
show_images(images, n_rows=1, figsize=(15, 15))

In [5]:
# Same subject with quality tags in the prompt; this cell uses "ddim_sampler".
images = model.generate_text2img(
            f'Professional high-quality photo of {instance_prompt} {class_prompt}. photorealistic, 4k, HQ',
            num_steps=50, 
            batch_size=2, 
            guidance_scale=7.5,
            h=img_size, 
            w=img_size,
            sampler="ddim_sampler", 
            prior_cf_scale=4,
            prior_steps="8",
        )
show_images(images, n_rows=1, figsize=(15, 15))

In [6]:
# Prompt that puts another object ("helicopter") in front of the learned
# subject tokens, sampled with "ddim_sampler".
images = model.generate_text2img(
            f'a photo of a helicopter {instance_prompt} {class_prompt}.',
            num_steps=50, 
            batch_size=2, 
            guidance_scale=7.5,
            h=img_size, 
            w=img_size,
            sampler="ddim_sampler", 
            prior_cf_scale=4,
            prior_steps="8",
        )

show_images(images, n_rows=1, figsize=(15, 15))