In [None]:
%cd ..

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import random
from copy import deepcopy

import PIL
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from omegaconf.dictconfig import DictConfig
from huggingface_hub import hf_hub_url, cached_download

from kandinsky2 import CONFIG_2_1, Kandinsky2_1 

### Training parameters 
#### You must fill in `initializer_token`, `placeholder_token`, `data_root` and `out_folder`. The other parameters can be changed if needed.

In [3]:
device = 'cuda'
task_type = 'text2img'
cache_root = '/tmp/kandinsky2'

# Fill here -------------------------------------------------------
initializer_token = 'train' # Global object name (train, bear, etc)
placeholder_token = 'sapsan' # Unique object name (sapsan, *, etc)
data_root = './finetune/input/sapsan/instance_images' # your folder with images
out_folder = './finetune/output/sapsan' # your folder for saved embeddings
# ----------------------------------------------------------------

os.makedirs(out_folder, exist_ok=True)

# Seed every RNG the notebook draws from (template choice uses `random`;
# shuffling, flips, noise and timesteps use torch) so runs are repeatable.
# CUDA kernels may still introduce some non-determinism.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

img_size = 512
epochs = 3000
log_image_frequency = 250 # -1 disable image logging
log_embed_frequency = 250 # save intermediate embeddings every N epochs

# AdamW hyper-parameters
lr = 1e-4
beta1 = 0.9
beta2 = 0.999
weight_decay = 1e-2
epsilon = 1e-08

num_workers = 0
batch_size = 1

### Helper functions

In [4]:
def download_models_if_not_exist(
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
):
    """Download the Kandinsky 2.1 checkpoints needed for ``task_type``.

    Files come from the ``sberbank-ai/Kandinsky_2.1`` Hugging Face repo via
    ``cached_download``. They are stored under ``<cache_dir>/2_1``, and the
    text-encoder files under ``<cache_dir>/2_1/text_encoder``. Download order:
    decoder, prior, text encoder, MoVQ, CLIP stats.

    Args:
        task_type: ``"text2img"`` (fetches ``decoder_fp16.ckpt``) or
            ``"inpainting"`` (fetches ``inpainting_fp16.ckpt``).
        cache_dir: Root cache directory; a ``2_1`` sub-directory is appended.
        use_auth_token: Optional Hugging Face token forwarded to ``cached_download``.

    Raises:
        ValueError: If ``task_type`` is not one of the supported values.
    """
    decoder_names = {
        "text2img": "decoder_fp16.ckpt",
        "inpainting": "inpainting_fp16.ckpt",
    }
    if task_type not in decoder_names:
        # Without this check an unknown task_type left `config_file_url`
        # unbound and crashed with an UnboundLocalError.
        raise ValueError(
            f"Unknown task_type {task_type!r}; expected one of {sorted(decoder_names)}"
        )

    repo_id = "sberbank-ai/Kandinsky_2.1"
    cache_dir = os.path.join(cache_dir, "2_1")
    cache_dir_text_en = os.path.join(cache_dir, "text_encoder")

    def _fetch(repo_filename, target_dir, local_name):
        """Download ``repo_filename`` from the repo to ``target_dir/local_name``."""
        cached_download(
            hf_hub_url(repo_id=repo_id, filename=repo_filename),
            cache_dir=target_dir,
            force_filename=local_name,
            use_auth_token=use_auth_token,
        )

    for name in (decoder_names[task_type], "prior_fp16.ckpt"):
        _fetch(name, cache_dir, name)
    for name in [
        "config.json",
        "pytorch_model.bin",
        "sentencepiece.bpe.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ]:
        _fetch(f"text_encoder/{name}", cache_dir_text_en, name)
    for name in ("movq_final.ckpt", "ViT-L-14_stats.th"):
        _fetch(name, cache_dir, name)
    
def add_noise(original_samples, noise, timesteps, num_diffusion_timesteps=1000):
    """Forward-diffuse ``original_samples`` to the given ``timesteps``.

    Uses a linear beta schedule from ``0.00085`` to ``0.012`` (both scaled by
    ``1000 / num_diffusion_timesteps``) and returns
    ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``, where
    ``alpha_bar_t`` is the cumulative product of ``1 - beta``.

    Args:
        original_samples: Clean samples; the first dimension indexes the batch.
        noise: Noise broadcastable against ``original_samples``.
        timesteps: Integer tensor with one timestep per sample, each in
            ``[0, num_diffusion_timesteps)``.
        num_diffusion_timesteps: Length of the schedule (default 1000).

    Returns:
        Noisy samples with the dtype and device of ``original_samples``.
    """
    scale = 1000 / num_diffusion_timesteps
    beta_start = scale * 0.00085
    beta_end = scale * 0.012

    # Build the schedule in at least float32. A cumulative product over ~1000
    # steps in float16/bfloat16 accumulates large rounding error. float32 and
    # float64 inputs are computed in their own dtype, exactly as before.
    schedule_dtype = torch.promote_types(original_samples.dtype, torch.float32)
    betas = torch.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=schedule_dtype)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).to(original_samples.device)

    timesteps = timesteps.to(original_samples.device)
    alpha_bar = alphas_cumprod[timesteps].flatten()

    # Reshape (B,) -> (B, 1, ..., 1) so the coefficients broadcast over each sample.
    broadcast_shape = (-1,) + (1,) * (original_samples.dim() - 1)
    sqrt_alpha_prod = (alpha_bar ** 0.5).to(original_samples.dtype).reshape(broadcast_shape)
    sqrt_one_minus_alpha_prod = ((1 - alpha_bar) ** 0.5).to(original_samples.dtype).reshape(broadcast_shape)

    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

def generate_clip_emb(model,
        prompt,
        batch_size=1,
        prior_cf_scale=1,
        prior_steps="5",
        negative_prior_prompt="",
    ):
    """Encode ``prompt`` with CLIP and run the diffusion prior to get image embeddings.

    The CLIP text pass is written out explicitly here, and this function is not
    wrapped in ``torch.no_grad``. Inside the training loop, autograd can
    therefore track the lookup into ``model.clip_model.token_embedding``.
    Whether gradients also flow back through ``model.prior`` depends on its
    implementation -- NOTE(review): confirm.

    Args:
        model: Loaded Kandinsky 2.1 model; this uses its ``tokenizer2``,
            ``clip_model``, ``prior``, ``device`` and ``model_dtype``.
        prompt: Text prompt, repeated ``batch_size`` times.
        batch_size: Number of copies of ``prompt`` to encode.
        prior_cf_scale: Guidance scale handed to the prior (one value per prompt).
        prior_steps: Timestep respacing for the prior, given as a string.
        negative_prior_prompt: Text used for the second (presumably
            unconditional / negative) half of the batch.

    Returns:
        The prior's output cast to ``model.model_dtype``.
    """
    # One copy of the prompt and one guidance scale per batch element.
    prompts_batch = [prompt for _ in range(batch_size)]
    prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch)
    prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=model.device)
    # Pad/truncate to the prior's text context length.
    max_txt_length = model.prior.model.text_ctx
    tok, mask = model.tokenizer2.padded_tokens_and_mask(
        prompts_batch, max_txt_length
    )
    cf_token, cf_mask = model.tokenizer2.padded_tokens_and_mask(
        [negative_prior_prompt], max_txt_length
    )
    # Broadcast the single negative prompt to the batch size when shapes differ.
    if not (cf_token.shape == tok.shape):
        cf_token = cf_token.expand(tok.shape[0], -1)
        cf_mask = cf_mask.expand(tok.shape[0], -1)
    # Stack as [prompt rows; negative rows] -> 2 * batch_size rows.
    tok = torch.cat([tok, cf_token], dim=0)
    mask = torch.cat([mask, cf_mask], dim=0)
    tok, mask = tok.to(device=model.device), mask.to(device=model.device)
    # CLIP text pass: token + positional embeddings, transformer, ln_final.
    x = model.clip_model.token_embedding(tok).type(model.clip_model.dtype)
    x = x + model.clip_model.positional_embedding.type(model.clip_model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = model.clip_model.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = model.clip_model.ln_final(x).type(model.clip_model.dtype)
    txt_feat_seq = x
    # Pooled feature: the hidden state at the largest token id of each row
    # (presumably the end-of-text token), projected with text_projection.
    txt_feat = (x[torch.arange(x.shape[0]), tok.argmax(dim=-1)] @ model.clip_model.text_projection)
    txt_feat, txt_feat_seq = txt_feat.float().to(model.device), txt_feat_seq.float().to(model.device)
    
    img_feat = model.prior(
        txt_feat,
        txt_feat_seq,
        mask,
        prior_cf_scales_batch,
        timestep_respacing=prior_steps,
    )
    return img_feat.to(model.model_dtype)


def save_embeds(model, save_path, placeholder_token, t1_place_token_id, t2_place_token_id):
    """Save the learned placeholder embeddings of both text encoders.

    Writes, with ``torch.save``, a dict of the form::

        {'t1': {placeholder_token: <row of the tokenizer1 text-encoder input embeddings>},
         't2': {placeholder_token: <row of the CLIP token embedding>}}

    Each row is detached, moved to CPU and cloned. ``torch.save`` preserves
    views, so saving an un-cloned row of a CPU weight matrix would write the
    entire embedding table to disk.

    Args:
        model: Kandinsky model holding ``text_encoder`` and ``clip_model``.
        save_path: Destination path (or writable file object) for ``torch.save``.
        placeholder_token: Key under which both rows are stored.
        t1_place_token_id: Row index in the tokenizer1 embedding table.
        t2_place_token_id: Row index in the CLIP token embedding table.
    """
    t1_embeds = model.text_encoder.model.transformer.get_input_embeddings().weight[t1_place_token_id]
    t2_embeds = model.clip_model.token_embedding.weight[t2_place_token_id]
    learned_embeds_dict = {
        't1': {
            placeholder_token: t1_embeds.detach().cpu().clone(),
        },
        't2': {
            placeholder_token: t2_embeds.detach().cpu().clone(),
        },
    }
    torch.save(learned_embeds_dict, save_path)
    
    
def save_images(model, save_path, placeholder_token, img_size=512):
    """Sample four images of ``placeholder_token`` and save them as one strip.

    Uses the prompt ``"a photo of a {placeholder_token}"`` and stacks the four
    generated images horizontally into a single file at ``save_path``.

    This is called from inside the training loop, which puts
    ``model.text_encoder`` and ``model.clip_model`` in train mode. Both are
    switched to eval mode for sampling, so any dropout in them is inactive.
    Afterwards they are restored to their previous mode, even if generation
    raises. Sampling runs under ``torch.no_grad`` so no autograd graph is kept.
    """
    encoders = (model.text_encoder, model.clip_model)
    previous_modes = [encoder.training for encoder in encoders]
    for encoder in encoders:
        encoder.eval()
    try:
        with torch.no_grad():
            gen_images = model.generate_text2img(
                f"a photo of a {placeholder_token}",
                num_steps=50,
                batch_size=4,
                guidance_scale=7.5,
                h=img_size,
                w=img_size,
                sampler="p_sampler",
                prior_cf_scale=4,
                prior_steps="5",
            )
    finally:
        for encoder, was_training in zip(encoders, previous_modes):
            encoder.train(was_training)

    strip = np.hstack([np.array(img) for img in gen_images])
    Image.fromarray(strip).save(save_path)

def check_tokens_is_valid(model, placeholder_token, initializer_token):
    """Validate the chosen words against both Kandinsky tokenizers.

    ``placeholder_token`` must be new to both tokenizers, and
    ``initializer_token`` must already be known to both. For ``tokenizer1``,
    a word counts as known when ``encode`` yields exactly three ids
    (presumably start token, the word, end token).

    Raises:
        ValueError: With a message naming the first check that failed.
    """
    print("Check tokens...")
    clip_vocab = model.tokenizer2.encoder

    if placeholder_token in clip_vocab:
        raise ValueError(f"Word {placeholder_token} exists in tokenizer2. Please select another word.")
    if initializer_token not in clip_vocab:
        raise ValueError(f"Word {initializer_token} doesn't exist in tokenizer2. Please select another word.")

    def _is_single_t1_word(word):
        # Evaluated lazily, so tokenizer1 is only consulted once the tokenizer2 checks pass.
        return len(model.tokenizer1.encode(word)) == 3

    if _is_single_t1_word(placeholder_token):
        raise ValueError(f"Word {placeholder_token} exists in tokenizer1. Please select another word.")
    if not _is_single_t1_word(initializer_token):
        raise ValueError(f"Word {initializer_token} doesn't exists in tokenizer1. Please select another word.")
    print("Selected tokens are correct")

def show_image(image, figsize=(5, 5), cmap=None, title='', xlabel=None, ylabel=None, axis=False):
    """Display one image in a new matplotlib figure.

    ``title``, ``xlabel`` and ``ylabel`` label the plot. ``axis`` is passed to
    ``plt.axis``; the default ``False`` hides the axis lines and labels.
    """
    plt.figure(figsize=figsize)
    plt.imshow(image, cmap=cmap)
    decorations = (
        (plt.title, title),
        (plt.xlabel, xlabel),
        (plt.ylabel, ylabel),
        (plt.axis, axis),
    )
    for apply, value in decorations:
        apply(value)
    plt.show()

def show_images(images, n_rows=1, title='', figsize=(5, 5), cmap=None, xlabel=None, ylabel=None, axis=False):
    """Display ``images`` on a grid with (at most) ``n_rows`` rows.

    The column count is ``ceil(len(images) / n_rows)``, so every image is
    shown. Previously ``len(images) // n_rows`` silently dropped the remainder,
    and produced a zero-column grid when there were fewer images than rows.
    Unused grid cells are hidden. A single image is delegated to
    ``show_image``, and an empty input shows nothing.

    ``title``, ``xlabel``, ``ylabel`` and ``axis`` are applied to every subplot.
    """
    images = list(images)
    if not images:
        return
    n_rows = max(1, min(n_rows, len(images)))
    n_cols = -(-len(images) // n_rows)  # ceiling division
    if n_rows == n_cols == 1:
        show_image(images[0], title=title, figsize=figsize, cmap=cmap, xlabel=xlabel, ylabel=ylabel, axis=axis)
        return

    # squeeze=False always yields a 2-D array of Axes, so flatten() is safe for any grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    fig.tight_layout(pad=0.0)
    axes = axes.flatten()
    for ax, img in zip(axes, images):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.axis(axis)
    for ax in axes[len(images):]:
        ax.set_axis_off()  # empty cells of a partially filled last row
    plt.show()

### Dataset class functions
Text templates and the dataset class are adapted, with small changes, from the <a href="https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion">diffusers</a> textual-inversion example:  
  

In [5]:
# Caption templates used when learning an *object*. "{}" is replaced by the
# placeholder token (TextualInversionDataset picks one with random.choice).
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Caption templates used when learning a *style*.
# NOTE: "a cropped painting in the style of {}" and
# "a close-up painting in the style of {}" each appear twice, so random.choice
# picks them twice as often as the other entries.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

class TextualInversionDataset(Dataset):
    """Image/caption pairs for textual inversion.

    Every regular, non-hidden file in ``data_root`` is one training image.
    Each item is a dict with:

    * ``"text"``: a random caption template with the placeholder token
      substituted (style templates when ``learnable_property == "style"``,
      object templates otherwise);
    * ``"image"``: a float32 tensor of shape ``(3, img_size, img_size)``
      scaled to ``[-1, 1]``. The image is RGB-converted, optionally
      center-cropped to a square, bicubically resized, and randomly flipped
      horizontally with probability ``flip_p``.

    Raises:
        ValueError: If ``data_root`` contains no usable files.
    """

    def __init__(
        self,
        data_root,
        placeholder_token,
        img_size=512,
        learnable_property="object", # [object, style]
        flip_p=0.5,
        center_crop=False,
    ):
        self.data_root = data_root
        self.img_size = img_size
        self.placeholder_token = placeholder_token
        self.learnable_property = learnable_property
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = self._list_image_paths(self.data_root)
        if not self.image_paths:
            # An empty dataset would otherwise "train" silently without a single step.
            raise ValueError(f"No image files found in {data_root!r}")
        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    @staticmethod
    def _list_image_paths(data_root):
        """Return the sorted paths of regular, non-hidden files in ``data_root``.

        Skipping hidden entries and sub-directories keeps things like
        ``.ipynb_checkpoints`` or ``.DS_Store`` away from ``Image.open``.
        Sorting makes the order independent of the file system.
        """
        return [
            os.path.join(data_root, name)
            for name in sorted(os.listdir(data_root))
            if not name.startswith(".") and os.path.isfile(os.path.join(data_root, name))
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        # convert() loads the pixels, so the file handle can be closed right away.
        with Image.open(self.image_paths[i]) as raw_image:
            image = raw_image.convert("RGB")

        text = random.choice(self.templates).format(self.placeholder_token)

        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = img.shape[:2]
            min_side = min(h, w)
            top, left = (h - min_side) // 2, (w - min_side) // 2
            img = img[top : top + min_side, left : left + min_side]

        image = Image.fromarray(img)
        image = image.resize((self.img_size, self.img_size), resample=PIL.Image.Resampling.BICUBIC)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW for PyTorch.
        return {"text": text, "image": torch.from_numpy(image).permute(2, 0, 1)}

### Define config and create Kandinsky model

In [6]:
# This cell always loads decoder_fp16.ckpt with inpainting disabled. Fine-tuning
# here is therefore text-to-image only, so reject other task types up front
# instead of failing later on a checkpoint that was never downloaded.
if task_type != "text2img":
    raise ValueError(f"Only task_type='text2img' is supported for fine-tuning, got {task_type!r}")

config = DictConfig(deepcopy(CONFIG_2_1))

cache_dir = os.path.join(cache_root, "2_1")

# Full-precision text-to-image decoder, no upscaling or flash attention.
# Text-embedding caching is off, presumably because the token embeddings
# change during training.
config["model_config"]["up"] = False
config["model_config"]["use_fp16"] = False
config["model_config"]["inpainting"] = False
config["model_config"]["cache_text_emb"] = False
config["model_config"]["use_flash_attention"] = False

# Point every component at the locally cached files.
config["tokenizer_name"] = os.path.join(cache_dir, "text_encoder")
config["text_enc_params"]["model_path"] = os.path.join(cache_dir, "text_encoder")
config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th")
config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt")

model_path = os.path.join(cache_dir, "decoder_fp16.ckpt")
prior_path = os.path.join(cache_dir, "prior_fp16.ckpt")

download_models_if_not_exist(task_type=task_type, cache_dir=cache_root)

model = Kandinsky2_1(config, model_path, prior_path, device, task_type=task_type)



making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.


In [7]:
# Fail fast if the placeholder is already known or the initializer is not.
check_tokens_is_valid(
    model,
    placeholder_token=placeholder_token,
    initializer_token=initializer_token,
)

Check tokens...
Selected tokens are correct


In [8]:
# Convert the initializer_token and placeholder_token to tokenizer1 ids.
# Order matters: the placeholder must be added to the tokenizer before its id
# is looked up and before the embedding matrix is resized to the new vocab.
num_added_tokens = model.tokenizer1.add_tokens(placeholder_token)
print(f'Num added tokens: {num_added_tokens}')

# add_special_tokens=False, so index 0 is the first piece of the word itself.
t1_init_token_id = model.tokenizer1.encode(initializer_token, add_special_tokens=False)[0]
t1_place_token_id = model.tokenizer1.convert_tokens_to_ids(placeholder_token)

# Grow the text encoder's input embedding table so the new id has a row.
model.text_encoder.model.transformer.resize_token_embeddings(len(model.tokenizer1))

print(f'Initializer ID: {t1_init_token_id} | Placeholder token ID: {t1_place_token_id}')

Num added tokens: 1
Initializer ID: 25550 | Placeholder token ID: 250002


In [9]:
# Start the placeholder's tokenizer1 embedding as a copy of the initializer token's row.
token_embeds = model.text_encoder.model.transformer.get_input_embeddings().weight.data
token_embeds[t1_place_token_id].copy_(token_embeds[t1_init_token_id])

In [10]:
# Add placeholder_token to tokenizer2 by hand:
#   * give it the next free id in `encoder` (word -> id);
#   * map that id back in `decoder` (id -> word);
#   * seed the BPE `cache` with the whole word, presumably so it is not split
#     into sub-word pieces.
# If the cell is re-run, reuse the id already assigned. Otherwise the word
# would move to a new id and leave a stale entry in `decoder`.
t2p_index_to_add = model.tokenizer2.encoder.get(placeholder_token, len(model.tokenizer2.encoder))
model.tokenizer2.encoder[placeholder_token] = t2p_index_to_add
model.tokenizer2.decoder[t2p_index_to_add] = placeholder_token
model.tokenizer2.cache[placeholder_token] = placeholder_token

t2_place_token_id = model.tokenizer2.encode(placeholder_token)[0]
t2_place_token_str = model.tokenizer2.decode([t2_place_token_id])

print(f'Encode placeholder token: {t2_place_token_id} | Decode placeholder word: {t2_place_token_str}')

Encode placeholder token: 49408 | Decode placeholder word: lastochka


In [11]:
# 1. Convert the initializer_token to its tokenizer2 id
# 2. Create a CLIP token-embedding table with a row for placeholder_token
# 3. Copy the old weights and initialise the placeholder row from the initializer row

t2_init_token_id = model.tokenizer2.encode(initializer_token)[0]
t2_init_token_str = model.tokenizer2.decode([t2_init_token_id])
print(f'Encode initializer token: {t2_init_token_id} | Decode initializer word: {t2_init_token_str}')

old_vocab_size, t2_embed_size = model.clip_model.token_embedding.weight.shape
print(f'T2 old vocab size: {old_vocab_size} | T2 Embed size: {t2_embed_size}')

# Size the table from the placeholder id instead of always adding one row.
# Re-running this cell then no longer grows the table each time, and an id
# past old_vocab_size cannot index out of range.
new_vocab_size = max(old_vocab_size, t2_place_token_id + 1)
new_embed = nn.Embedding(new_vocab_size, t2_embed_size).to(device)
new_embed.weight.data[:old_vocab_size, :] = model.clip_model.token_embedding.weight.data.clone()
new_embed.weight.data[t2_place_token_id, :] = new_embed.weight.data[t2_init_token_id, :]

model.clip_model.token_embedding = new_embed

print(f'T2 new vocab size: {model.clip_model.token_embedding.weight.shape[0]}')

Encode initializer token: 3231 | Decode initializer word: train 
T2 old vocab size: 49408 | T2 Embed size: 768
T2 new vocab size: 49409


In [12]:
## Freeze all except embeddings
# Freeze whole modules first, then re-enable only the two token-embedding
# tables. Freezing just `clip_model.transformer` left CLIP's
# positional_embedding, ln_final and text_projection (all used in
# generate_clip_emb) trainable, along with any text-encoder embedding
# sub-modules not listed explicitly. Those are not in the optimizer, but
# autograd would still compute and accumulate gradients for them every step.
model.image_encoder.requires_grad_(False)
model.model.requires_grad_(False)
model.prior.requires_grad_(False)

model.clip_model.requires_grad_(False)
model.clip_model.token_embedding.requires_grad_(True)

model.text_encoder.model.transformer.requires_grad_(False)
model.text_encoder.model.LinearTransformation.requires_grad_(False)
model.text_encoder.model.transformer.get_input_embeddings().requires_grad_(True);

In [13]:
# Configure optimizer, dataset and dataloader.
# Only the parameters of the two token-embedding tables are handed to AdamW.
optimizer = torch.optim.AdamW(
    [
        *model.text_encoder.model.transformer.get_input_embeddings().parameters(),
        *model.clip_model.token_embedding.parameters(),
    ],
    betas=(beta1, beta2),
    lr=lr,
    eps=epsilon,
    weight_decay=weight_decay,
)

dataset = TextualInversionDataset(
    data_root,
    placeholder_token,
    img_size=img_size,
    center_crop=False,
)

train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [14]:
# Snapshot both embedding tables before training. After every optimizer step,
# all rows except the placeholder's are restored from these copies.
orig_t1_params = model.text_encoder.model.transformer.get_input_embeddings().weight.detach().clone()
orig_t2_params = model.clip_model.token_embedding.weight.detach().clone()

In [1]:
weight_dtype = model.model.dtype
model.clip_model.to(weight_dtype)

# Embedding modules being trained, and masks of the rows that must keep their
# original values (every row except the placeholder's). The masks are
# loop-invariant, so build them once. They are sized from the weight matrices
# and placed on the weights' device.
t1_embedding = model.text_encoder.model.transformer.get_input_embeddings()
t2_embedding = model.clip_model.token_embedding
index_no_updates_t1 = torch.arange(t1_embedding.weight.shape[0], device=t1_embedding.weight.device) != t1_place_token_id
index_no_updates_t2 = torch.arange(t2_embedding.weight.shape[0], device=t2_embedding.weight.device) != t2_place_token_id
# Index assignment requires matching dtypes, and clip_model was just cast to
# weight_dtype, so align the snapshots with the current weights.
orig_t1_rows = orig_t1_params.to(device=t1_embedding.weight.device, dtype=t1_embedding.weight.dtype)
orig_t2_rows = orig_t2_params.to(device=t2_embedding.weight.device, dtype=t2_embedding.weight.dtype)

progress_bar_epochs = tqdm(range(1, epochs + 1))

for epoch in progress_bar_epochs:
    model.text_encoder.train()
    model.clip_model.train()
    for batch in train_dataloader:
        model_kwargs = {}
        # Convert images to latent representation and add noise. The image
        # encoder is frozen, so no autograd graph is built for it.
        with torch.no_grad():
            latents = model.image_encoder.encode(batch["image"].to(device=device, dtype=weight_dtype))
        latents = latents * model.scale
        # The last batch may be smaller than `batch_size` (drop_last=False).
        current_batch_size = latents.shape[0]

        noise = torch.randn_like(latents)

        timesteps = torch.randint(0, 1000, (current_batch_size,), device=latents.device).long()

        noisy_latents = add_noise(latents, noise, timesteps).to(weight_dtype)

        # Conditioning from the prior (CLIP text -> image embedding).
        # NOTE: only the first caption of the batch is used for conditioning.
        image_emb = generate_clip_emb(
            model,
            batch["text"][0],
            batch_size=current_batch_size,
            prior_cf_scale=4,
            prior_steps="5",
            negative_prior_prompt="",
        )
        model_kwargs["image_emb"] = image_emb.to(weight_dtype)

        # Conditioning from the multilingual text encoder (tokenizer1).
        tokens = model.tokenizer1(
            batch["text"][0],
            padding="max_length",
            truncation=True,
            max_length=77,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        model_kwargs["full_emb"], model_kwargs["pooled_emb"] = model.text_encoder(
            tokens=tokens['input_ids'].long().to(device=device),
            mask=tokens['attention_mask'].to(device=device),
        )
        model_kwargs["full_emb"] = model_kwargs["full_emb"].to(weight_dtype)
        model_kwargs["pooled_emb"] = model_kwargs["pooled_emb"].to(weight_dtype)

        # Predict the noise. Only the first 4 output channels are compared
        # with it (the remaining channels are presumably learned variance).
        model_pred = model.model(noisy_latents, timesteps, **model_kwargs)[:, :4]

        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Only the placeholder rows may change. Restore every other row, which
        # undoes AdamW updates and weight decay on the rest of the tables.
        with torch.no_grad():
            t1_embedding.weight[index_no_updates_t1] = orig_t1_rows[index_no_updates_t1]
            t2_embedding.weight[index_no_updates_t2] = orig_t2_rows[index_no_updates_t2]

        progress_bar_epochs.set_postfix(loss=loss.item())

    if log_embed_frequency > 0 and epoch % log_embed_frequency == 0:
        embed_save_path = os.path.join(out_folder, f"{epoch}_epoch_embeds.bin")
        save_embeds(model, embed_save_path, placeholder_token, t1_place_token_id, t2_place_token_id)

    if log_image_frequency > 0 and (epoch % log_image_frequency == 0):
        images_root = os.path.join(out_folder, "images")
        os.makedirs(images_root, exist_ok=True)
        image_save_path = os.path.join(images_root, f"{epoch}_epoch_images.jpg")
        save_images(model, image_save_path, placeholder_token)

# Leave both encoders in inference mode for the sampling cells that follow.
model.text_encoder.eval()
model.clip_model.eval()

embed_save_path = os.path.join(out_folder, "learned_embeds.bin")
save_embeds(model, embed_save_path, placeholder_token, t1_place_token_id, t2_place_token_id)

In [None]:
# Sample two images for a plain prompt containing the learned token.
images = model.generate_text2img(
    f'photo of {placeholder_token}',
    batch_size=2,
    num_steps=50,
    sampler="ddim_sampler",
    guidance_scale=7.5,
    prior_cf_scale=4,
    prior_steps="8",
    h=img_size,
    w=img_size,
)
show_images(images, n_rows=1, figsize=(15, 15))

In [None]:
# Sample two images for a richer, quality-oriented prompt with the learned token.
images = model.generate_text2img(
    f'Professional high-quality photo of {placeholder_token}. photorealistic, 4k, HQ',
    batch_size=2,
    num_steps=50,
    sampler="ddim_sampler",
    guidance_scale=7.5,
    prior_cf_scale=4,
    prior_steps="8",
    h=img_size,
    w=img_size,
)
show_images(images, n_rows=1, figsize=(15, 15))