In [1]:
import itertools
import numpy as np
import os
import random
import torch
import torch.nn.functional as F
from collections import deque
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from itertools import chain
from PIL import Image
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import RandomHorizontalFlip
from transformers import CLIPTextModel, CLIPTokenizer

# Setup
diffusion_model_id = 'runwayml/stable-diffusion-v1-5'
text_encoder_model_id = 'openai/clip-vit-large-patch14'
device = 'cuda'
seed = 1024

# Hugging Face access token
token = ''
with open('hugging_face_token.txt', 'r') as secret:
    token = secret.readline().strip()

In [2]:
# Load model components

# Text Encoder + Tokenizer
tokenizer = CLIPTokenizer.from_pretrained(text_encoder_model_id)
text_encoder = CLIPTextModel.from_pretrained(text_encoder_model_id, torch_dtype=torch.float16)
text_encoder.to(device)

# Variational Autoencoder
vae = AutoencoderKL.from_pretrained(
    diffusion_model_id, subfolder='vae', torch_dtype=torch.float16,
    revision='fp16', use_auth_token=token)
vae.to(device)

# U-Net Model
u_net = UNet2DConditionModel.from_pretrained(
    diffusion_model_id, subfolder='unet', torch_dtype=torch.float16,
    revision='fp16', use_auth_token=token)
u_net.to(device)

# Noise Scheduler
noise_scheduler = DDPMScheduler.from_config(diffusion_model_id, subfolder='scheduler', use_auth_token=token)

# Freeze parameters for a model
def freeze_params(params):
    for param in params:
        param.requires_grad = False

# Freeze all pre-trained models except for token embeddings in the text encoder
freeze_params(vae.parameters())
freeze_params(u_net.parameters())
encoder_params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(encoder_params_to_freeze)

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.4.self_attn.q_proj.bias', 'vision_model.encoder.layers.11.layer_norm2.weight', 'vision_model.encoder.layers.0.mlp.fc1.weight', 'vision_model.encoder.layers.13.self_attn.q_proj.bias', 'vision_model.encoder.layers.3.self_attn.k_proj.weight', 'vision_model.encoder.layers.0.mlp.fc1.bias', 'vision_model.encoder.layers.18.self_attn.q_proj.weight', 'vision_model.encoder.layers.14.self_attn.v_proj.weight', 'vision_model.encoder.layers.17.layer_norm2.weight', 'vision_model.encoder.layers.6.mlp.fc2.bias', 'vision_model.encoder.layers.16.mlp.fc1.bias', 'vision_model.encoder.layers.12.mlp.fc1.weight', 'vision_model.encoder.layers.2.self_attn.q_proj.weight', 'vision_model.encoder.layers.6.self_attn.v_proj.bias', 'vision_model.encoder.layers.9.layer_norm1.bias', 'vision_model.encoder.layers.23.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.la

In [3]:
# Setup tokenizer and text encoder

# Special tokens
initializer_token = 'toy'  # Initial embedding for new property
placeholder_token = '<object>'  # Token that represents new property

# Add the placeholder token in tokenizer
num_added_tokens = tokenizer.add_tokens(placeholder_token)

# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))

# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

In [4]:
# Prompt templates
object_templates = [
    'a photo of a {}',
    'a rendering of a {}',
    'a cropped photo of the {}',
    'the photo of a {}',
    'a photo of a clean {}',
    'a photo of a dirty {}',
    'a dark photo of the {}',
    'a photo of my {}',
    'a photo of the cool {}',
    'a close-up photo of a {}',
    'a bright photo of the {}',
    'a cropped photo of a {}',
    'a photo of the {}',
    'a good photo of the {}',
    'a photo of one {}',
    'a close-up photo of the {}',
    'a rendition of the {}',
    'a photo of the clean {}',
    'a rendition of a {}',
    'a photo of a nice {}',
    'a good photo of a {}',
    'a photo of the nice {}',
    'a photo of the small {}',
    'a photo of the weird {}',
    'a photo of the large {}',
    'a photo of a cool {}',
    'a photo of a small {}',
]

In [5]:
# Dataset class
class TextualInversionDataset(Dataset):
    def __init__(
        self,
        data_root,
        learnable_property='object',
        repeats=10,  # 100
        flip_p=0.5,
        set='train',
        placeholder_token='<object>',
    ):
        self.data_root = data_root
        self.learnable_property = learnable_property
        self.placeholder_token = placeholder_token
        self.flip_p = flip_p
        self.flip_transform = RandomHorizontalFlip(p=self.flip_p)

        # Data settings
        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
        self.num_images = len(self.image_paths)
        self._length = self.num_images
        if set == 'train':
            self._length = self.num_images * repeats
        self.templates = object_templates

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        # Get and prepare image
        image = Image.open(self.image_paths[i % self.num_images])
        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        image = (image / 127.5 - 1.0).astype(np.float16)

        # Get text prompt
        text = random.choice(self.templates).format(self.placeholder_token)

        # Create example
        example = {}
        example['input_prompt'] = text
        example['pixel_values'] = torch.from_numpy(image).permute(2, 0, 1).to(device)
        return example

In [6]:
# Auxiliary functions

# Encode input prompt
def encode_prompt(prompt):
    text_inputs = tokenizer(
        prompt, padding='max_length', max_length=tokenizer.model_max_length,
        truncation=True, return_tensors='pt')
    text_embeddings = text_encoder(text_inputs.input_ids.to(device))[0]
    return text_embeddings

In [7]:
# Save model to a pytorch file
def save_model(path_dir, filename):
    if not os.path.isdir(path_dir):
        os.makedirs(path_dir)
    learned_embeddings = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    torch.save({placeholder_token: learned_embeddings.detach().cpu()}, os.path.join(path_dir, filename))

# Model training
def train_model(property, data_root, optimizer, num_train_epochs=50, batch_size=1):
    # Initialize dataset
    train_dataset = TextualInversionDataset(data_root)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Learning rate scheduler
    max_train_steps = num_train_epochs * len(train_dataloader)
    lr_scheduler = LambdaLR(optimizer, lambda _: 1)

    # Training loop
    print('***** Running Training *****')
    print(f'  Num. Examples = {len(train_dataset)}')
    print(f'  Num. Epochs = {num_train_epochs}')
    loss_queue = deque(maxlen=10)
    for epoch in range(num_train_epochs):
        # Save current model
        if not (epoch % 10):
            save_model(f'saved_models/{property}', f'{property}_{epoch // 10}.pt')

        # Train for another epoch
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Convert images to latent space
            latents = vae.encode(batch['pixel_values']).latent_dist.sample()
            latents *= 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn(latents.shape, dtype=torch.float16).to(latents.device)
            bsz = latents.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            ).long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = encode_prompt(batch['input_prompt'])

            # Predict the noise residual
            noise_pred = u_net(noisy_latents, timesteps, encoder_hidden_states).sample

            # Backwards pass
            loss = F.mse_loss(noise_pred, noise, reduction='none').mean([1, 2, 3]).mean()
            loss.backward()

            # Zero out the gradients for all token embeddings except the placeholder token
            grads = text_encoder.get_input_embeddings().weight.grad
            index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
            grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

            # Optimizer pass
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Print logs
            current_loss = round(loss.detach().item(), 4)
            loss_queue.append(current_loss)
            recent_loss = round(sum(loss_queue) / 10, 4)
            print(f'loss: {current_loss}, last_10: {recent_loss}, lr: {lr_scheduler.get_last_lr()[0]}, '
                  f'epoch: {epoch + 1}, step: {step + 1}/{len(train_dataloader)}')
    print('***** Training Completed *****')

In [8]:
# Initialize optimizer
# TODO: Tune all and choose best one
#optimizer = torch.optim.SGD(text_encoder.get_input_embeddings().parameters(), lr=0.001)
optimizer = torch.optim.AdamW(text_encoder.get_input_embeddings().parameters(), lr=0.01, eps=1e-07)
#optimizer = torch.optim.Adam(text_encoder.get_input_embeddings().parameters(), lr=0.001, eps=1e-08)
#optimizer = torch.optim.Adagrad(text_encoder.get_input_embeddings().parameters(), lr=0.01, eps=1e-10)
#optimizer = torch.optim.Adadelta(text_encoder.get_input_embeddings().parameters(), lr=1.0, eps=1e-06)
#optimizer = torch.optim.RMSprop(text_encoder.get_input_embeddings().parameters(), lr=0.01, eps=1e-06)

# Training
property = 'cat_toy'
data_folder = f'../data/{property}'
train_model(property, data_folder, optimizer, num_train_epochs=5)

# Save final model
save_model(f'saved_models/{property}', f'{property}_final.pt')

***** Running Training *****
  Num. Examples = 60
  Num. Epochs = 5
loss: 0.044, last_10: 0.0044, lr: 0.01, epoch: 1, step: 1/60
loss: 0.3162, last_10: 0.036, lr: 0.01, epoch: 1, step: 2/60
loss: 0.05, last_10: 0.041, lr: 0.01, epoch: 1, step: 3/60
loss: 0.0212, last_10: 0.0431, lr: 0.01, epoch: 1, step: 4/60
loss: 0.1116, last_10: 0.0543, lr: 0.01, epoch: 1, step: 5/60
loss: 0.0041, last_10: 0.0547, lr: 0.01, epoch: 1, step: 6/60
loss: 0.322, last_10: 0.0869, lr: 0.01, epoch: 1, step: 7/60
loss: 0.3059, last_10: 0.1175, lr: 0.01, epoch: 1, step: 8/60
loss: 0.0262, last_10: 0.1201, lr: 0.01, epoch: 1, step: 9/60
loss: 0.01, last_10: 0.1211, lr: 0.01, epoch: 1, step: 10/60
loss: 0.0032, last_10: 0.117, lr: 0.01, epoch: 1, step: 11/60
loss: 0.2585, last_10: 0.1113, lr: 0.01, epoch: 1, step: 12/60
loss: 0.0111, last_10: 0.1074, lr: 0.01, epoch: 1, step: 13/60
loss: 0.0332, last_10: 0.1086, lr: 0.01, epoch: 1, step: 14/60
loss: 0.186, last_10: 0.116, lr: 0.01, epoch: 1, step: 15/60
loss: 0