In [1]:
import itertools
import numpy as np
import os
import random
import torch
import torch.nn.functional as F
from collections import deque
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from itertools import chain
from PIL import Image
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import RandomHorizontalFlip
from transformers import CLIPTextModel, CLIPTokenizer, logging

# Setup
diffusion_model_id = 'runwayml/stable-diffusion-v1-5'
text_encoder_model_id = 'openai/clip-vit-large-patch14'
device = 'cuda'
seed = 1024
logging.set_verbosity_error()

# Apply the seed so that a fresh "Restart & Run All" reproduces the same run.
# Python's `random` drives the prompt-prefix choice in the dataset; torch drives
# timestep and noise sampling (and, presumably, DataLoader shuffling and the
# random flips - TODO confirm for the installed torchvision). NumPy is seeded
# for completeness.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Multi-Concept Textual Inversion settings
property_name_a = 'tornado'  # Name of first learned property to learn
placeholder_token_a = '<tornado>'  # Token that represents the first new property
initializer_token_a = 'tornado'  # Initial embedding for the first new property
property_name_b = 'ocean'  # Name of second property to learn
placeholder_token_b = '<ocean>'  # Token that represents the second new property
initializer_token_b = 'ocean'  # Initial embedding for the second new property

# Hugging Face access token.
# It is read from a local file (first line, surrounding whitespace removed) so
# the secret never appears in the notebook itself. `token` is pre-set to an
# empty string so the name exists even if reading the file fails.
token = ''
with open('hugging_face_token.txt', 'r') as secret:
    token = secret.readline().strip()

In [2]:
# Load model components
# Every network is loaded in half precision and moved to `device` right away.

# Text Encoder + Tokenizer
tokenizer = CLIPTokenizer.from_pretrained(text_encoder_model_id)
text_encoder = CLIPTextModel.from_pretrained(
    text_encoder_model_id,
    torch_dtype=torch.float16,
).to(device)

# Variational Autoencoder
vae = AutoencoderKL.from_pretrained(
    diffusion_model_id,
    subfolder='vae',
    revision='fp16',
    torch_dtype=torch.float16,
    use_auth_token=token,
).to(device)

# U-Net Model
u_net = UNet2DConditionModel.from_pretrained(
    diffusion_model_id,
    subfolder='unet',
    revision='fp16',
    torch_dtype=torch.float16,
    use_auth_token=token,
).to(device)

# Noise Scheduler
noise_scheduler = DDPMScheduler.from_config(
    diffusion_model_id,
    subfolder='scheduler',
    use_auth_token=token,
)

# Turn off gradient tracking for a collection of parameters
def freeze_params(params):
    """Mark every parameter in `params` as not requiring gradients.

    Args:
        params: Iterable of parameters (e.g. the result of `module.parameters()`).
            It is consumed once.

    Returns:
        None. The parameters are modified in place.
    """
    for frozen in params:
        frozen.requires_grad_(False)

# Freeze every pre-trained weight except the text encoder's token embeddings:
# the VAE and U-Net are frozen entirely, and within the text encoder the
# transformer stack, the final layer norm and the position embeddings are frozen.
for frozen_model in (vae, u_net):
    freeze_params(frozen_model.parameters())

clip_text_model = text_encoder.text_model
encoder_params_to_freeze = chain(
    clip_text_model.encoder.parameters(),
    clip_text_model.final_layer_norm.parameters(),
    clip_text_model.embeddings.position_embedding.parameters(),
)
freeze_params(encoder_params_to_freeze)

In [3]:
# Setup tokenizer and text encoder

# Add the placeholder tokens in tokenizer
num_added_tokens = tokenizer.add_tokens([placeholder_token_a, placeholder_token_b])

# Convert the initializer tokens and placeholder tokens to ids.
# Each initializer word is encoded on its own as plain text. Passing a *list*
# of strings to `tokenizer.encode` makes the tokenizer treat the items as
# already-tokenized vocabulary entries and look them up verbatim, which can
# return a different (or the unknown-token) id instead of the id of the word.
# An initializer must also map to exactly one token, otherwise there is no
# single embedding row to copy from.
token_ids = []
for initializer_token in (initializer_token_a, initializer_token_b):
    encoded_initializer = tokenizer.encode(initializer_token, add_special_tokens=False)
    if len(encoded_initializer) != 1:
        raise ValueError(
            f'Initializer token {initializer_token!r} must map to exactly one token id, '
            f'but the tokenizer produced {len(encoded_initializer)}: {encoded_initializer}')
    token_ids.append(encoded_initializer[0])
initializer_token_id_a, initializer_token_id_b = token_ids
placeholder_token_id_a = tokenizer.convert_tokens_to_ids(placeholder_token_a)
placeholder_token_id_b = tokenizer.convert_tokens_to_ids(placeholder_token_b)

# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))

# Initialize the newly added placeholder tokens with the embeddings of the initializer tokens
# (`.data` is used so the copy is not recorded by autograd)
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id_a] = token_embeds[initializer_token_id_a]
token_embeds[placeholder_token_id_b] = token_embeds[initializer_token_id_b]

In [4]:
# Prompt base templates

# Object-style prompt prefixes. TextualInversionDataset picks one of these at
# random for every sample and prepends it (followed by a space) to that image's
# prompt template.
object_templates = [
    'a photo of a',
    'a rendering of a',
    'a cropped photo of the',
    'the photo of a',
    'a photo of a clean',
    'a photo of a dirty',
    'a dark photo of the',
    'a photo of my',
    'a photo of the cool',
    'a close-up photo of a',
    'a bright photo of the',
    'a cropped photo of a',
    'a photo of the',
    'a good photo of the',
    'a photo of one',
    'a close-up photo of the',
    'a rendition of the',
    'a photo of the clean',
    'a rendition of a',
    'a photo of a nice',
    'a good photo of a',
    'a photo of the nice',
    'a photo of the small',
    'a photo of the weird',
    'a photo of the large',
    'a photo of a cool',
    'a photo of a small',
]

In [5]:
# Extended prompt templates: one entry per line of the concept pair's
# prompts.txt, with surrounding whitespace removed
prompt_templates = []
prompts_path = f'../data/{property_name_a}_{property_name_b}/prompts.txt'
with open(prompts_path, 'r') as prompt_file:
    prompt_templates = list(map(str.strip, prompt_file))

In [6]:
# Dataset class
class TextualInversionDataset(Dataset):
    """Image/prompt pairs used to train the placeholder-token embeddings.

    Every regular file in `data_root` whose name does not end in '.txt' is
    treated as a training image. Images are sorted by path, and the image at
    sorted position k is paired with `templates[k]`. The dataset reports
    `num_images * repeats` samples, so each image is visited `repeats` times
    per epoch.

    Args:
        data_root: Directory containing the training images.
        templates: Sequence of prompt strings, one per image, in sorted-path order.
        repeats: How many times each image is repeated within one epoch.
        flip_p: Probability passed to the random horizontal flip.

    Raises:
        ValueError: If there are fewer templates than images.
    """

    def __init__(
        self,
        data_root,
        templates,
        repeats=10,  # 100
        flip_p=0.5,
    ):
        self.data_root = data_root
        self.flip_p = flip_p
        self.flip_transform = RandomHorizontalFlip(p=self.flip_p)

        # Data settings
        # Sub-directories (for example a notebook checkpoint folder) are skipped:
        # they cannot be opened as images and would shift the image/prompt pairing.
        candidate_paths = (
            os.path.join(self.data_root, file_name)
            for file_name in os.listdir(self.data_root)
            if not file_name.endswith('.txt'))
        self.image_paths = sorted(path for path in candidate_paths if os.path.isfile(path))
        self.num_images = len(self.image_paths)
        self._length = self.num_images * repeats
        self.templates = templates

        # Fail early instead of with an IndexError in the middle of an epoch
        if len(self.templates) < self.num_images:
            raise ValueError(
                f'Found {self.num_images} images in {self.data_root!r} but only '
                f'{len(self.templates)} prompt templates; one template per image is required')

    def __len__(self):
        """Return the number of samples per epoch (images times repeats)."""
        return self._length

    def __getitem__(self, i):
        """Return the sample at index `i`.

        Returns:
            dict with
              'input_prompt': a random object template followed by the image's prompt,
              'pixel_values': float16 tensor in channel-first layout, scaled to
                  [-1, 1] and placed on `device`.
        """
        image_index = i % self.num_images

        # Get and prepare image.
        # Force three channels so RGBA / greyscale / palette files produce the
        # same (H, W, 3) array layout that the permute below relies on.
        with Image.open(self.image_paths[image_index]) as image_file:
            image = image_file.convert('RGB')
        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Map 0..255 to -1..1
        image = (image / 127.5 - 1.0).astype(np.float16)

        # Get text prompt
        text = f'{random.choice(object_templates)} {self.templates[image_index]}'

        # Create example
        example = {}
        example['input_prompt'] = text
        example['pixel_values'] = torch.from_numpy(image).permute(2, 0, 1).to(device)
        return example

In [7]:
# Auxiliary functions

# Turn a prompt into text-encoder embeddings
def encode_prompt(prompt):
    """Tokenize `prompt` and run it through the text encoder.

    The prompt is padded and truncated to the tokenizer's maximum length, and
    the token ids are moved to `device` before encoding.

    Args:
        prompt: Prompt text (train_model passes the batch's 'input_prompt' entry).

    Returns:
        The first element of the text encoder's output (presumably the
        per-token hidden states - TODO confirm against the encoder's output type).
    """
    tokenized = tokenizer(
        prompt,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=tokenizer.model_max_length,
    )
    input_ids = tokenized.input_ids.to(device)
    return text_encoder(input_ids)[0]

In [8]:
# Write the learned placeholder embeddings to a pytorch file
def save_model(path_dir, filename):
    """Save the current embeddings of both placeholder tokens.

    The file holds a dict mapping each placeholder token string to its
    detached embedding row, copied to the CPU.

    Args:
        path_dir: Output directory; created if it does not exist yet.
        filename: Name of the file written inside `path_dir`.
    """
    os.makedirs(path_dir, exist_ok=True)
    embedding_table = text_encoder.get_input_embeddings().weight
    checkpoint = {
        placeholder_token_a: embedding_table[placeholder_token_id_a].detach().cpu(),
        placeholder_token_b: embedding_table[placeholder_token_id_b].detach().cpu(),
    }
    torch.save(checkpoint, os.path.join(path_dir, filename))

# Model training
def train_model(data_root, optimizer, num_train_epochs=50, batch_size=1):
    """Train the two placeholder-token embeddings on the images in `data_root`.

    For every batch the images are encoded to latents, noised at a random
    timestep, and the U-Net's noise prediction is compared with the true noise
    (mean squared error). Only the embedding rows of the two placeholder tokens
    are allowed to change: gradients of all other rows are zeroed before the
    optimizer step, and those rows are restored to their starting values after
    it, so optimizer-side updates such as weight decay cannot move them either.

    A checkpoint is written at the start of every 10th epoch (epoch 0 included,
    which therefore holds the initial embeddings).

    Args:
        data_root: Directory with the training images.
        optimizer: Optimizer over the text encoder's input-embedding parameters.
        num_train_epochs: Number of passes over the dataset.
        batch_size: Number of samples per step.
    """
    # Initialize dataset
    train_dataset = TextualInversionDataset(data_root, prompt_templates)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler (constant factor of 1: the optimizer's lr is used as-is)
    lr_scheduler = LambdaLR(optimizer, lambda _: 1)

    # Rows of the embedding table that must keep their starting values:
    # everything except the two placeholder tokens. The mask is built from the
    # actual placeholder ids instead of assuming they are the last rows.
    embedding_weights = text_encoder.get_input_embeddings().weight
    frozen_rows = torch.ones(
        embedding_weights.shape[0], dtype=torch.bool, device=embedding_weights.device)
    frozen_rows[[placeholder_token_id_a, placeholder_token_id_b]] = False
    original_embeddings = embedding_weights.detach().clone()

    # Training loop
    print('***** Running Training *****')
    print(f'  Num. Examples = {len(train_dataset)}')
    print(f'  Num. Epochs = {num_train_epochs}')
    loss_queue = deque(maxlen=20)
    for epoch in range(num_train_epochs):
        # Save current model
        if not (epoch % 10):
            save_model(f'saved_models/mcti/{property_name_a}_{property_name_b}',
            f'{property_name_a}_{property_name_b}_{epoch // 10}.pt')

        # Train for another epoch
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Convert images to latent space
            latents = vae.encode(batch['pixel_values']).latent_dist.sample()
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            # (drawn directly with the latents' shape, dtype and device)
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            ).long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = encode_prompt(batch['input_prompt'])

            # Predict the noise residual
            noise_pred = u_net(noisy_latents, timesteps, encoder_hidden_states).sample

            # Backwards pass
            loss = F.mse_loss(noise_pred, noise, reduction='none').mean([1, 2, 3]).mean()
            loss.backward()

            # Zero out the gradients for all token embeddings except the placeholder tokens
            with torch.no_grad():
                embedding_weights.grad[frozen_rows] = 0

            # Optimizer pass
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Undo any change the optimizer made to non-placeholder rows
            # (e.g. weight decay, which acts even where the gradient is zero)
            with torch.no_grad():
                embedding_weights[frozen_rows] = original_embeddings[frozen_rows]

            # Print logs
            current_loss = round(loss.detach().item(), 4)
            loss_queue.append(current_loss)
            recent_loss = round(sum(loss_queue) / len(loss_queue), 4)
            print(f'loss: {current_loss}, last_20: {recent_loss}, lr: {lr_scheduler.get_last_lr()[0]}, '
                  f'epoch: {epoch + 1}/{num_train_epochs}, step: {step + 1}/{len(train_dataloader)}')
    print('***** Training Completed *****')

In [9]:
# Initialize optimizer: only the text encoder's input-embedding parameters are optimized
embedding_params = text_encoder.get_input_embeddings().parameters()
optimizer = torch.optim.AdamW(embedding_params, lr=1e-4, eps=1e-7)
# Alternative optimizer kept for reference:
#optimizer = torch.optim.RMSprop(text_encoder.get_input_embeddings().parameters(), lr=1e-3, eps=1e-7)

# Training
data_folder = f'../data/{property_name_a}_{property_name_b}'
train_model(data_folder, optimizer, num_train_epochs=40)

# Save final model
final_model_dir = f'saved_models/mcti/{property_name_a}_{property_name_b}'
final_model_name = f'{property_name_a}_{property_name_b}_final.pt'
save_model(final_model_dir, final_model_name)

***** Running Training *****
  Num. Examples = 80
  Num. Epochs = 40
loss: 0.0173, last_20: 0.0173, lr: 0.0001, epoch: 1/40, step: 1/80
loss: 0.0239, last_20: 0.0206, lr: 0.0001, epoch: 1/40, step: 2/80
loss: 0.0272, last_20: 0.0228, lr: 0.0001, epoch: 1/40, step: 3/80
loss: 0.0234, last_20: 0.023, lr: 0.0001, epoch: 1/40, step: 4/80
loss: 0.0396, last_20: 0.0263, lr: 0.0001, epoch: 1/40, step: 5/80
loss: 0.0144, last_20: 0.0243, lr: 0.0001, epoch: 1/40, step: 6/80
loss: 0.5229, last_20: 0.0955, lr: 0.0001, epoch: 1/40, step: 7/80
loss: 0.0116, last_20: 0.085, lr: 0.0001, epoch: 1/40, step: 8/80
loss: 0.2158, last_20: 0.0996, lr: 0.0001, epoch: 1/40, step: 9/80
loss: 0.1498, last_20: 0.1046, lr: 0.0001, epoch: 1/40, step: 10/80
loss: 0.3965, last_20: 0.1311, lr: 0.0001, epoch: 1/40, step: 11/80
loss: 0.937, last_20: 0.1983, lr: 0.0001, epoch: 1/40, step: 12/80
loss: 0.0184, last_20: 0.1844, lr: 0.0001, epoch: 1/40, step: 13/80
loss: 0.0368, last_20: 0.1739, lr: 0.0001, epoch: 1/40, ste