In [None]:
#library for data processing
import pandas as pd
from bs4 import BeautifulSoup
import string
import re
import itertools
import io
import json
import os
import sys
import ast
import time
import requests
import random
import math
import inspect

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import PIL
import scipy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import einops
import transformers
import diffusers
import accelerate
#import clip
import torchvision.transforms.functional as TF

#from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

import datasets
from datasets import load_dataset

import pickle

#single Accelerator shared by every cell below: later cells read its device, gradient_accumulation_steps,
#mixed_precision and num_processes, and pass models/optimizers/loaders through accelerator.prepare.
#NOTE(review): constructed without arguments, so these settings come from accelerate's defaults or the launch
#environment — pass them explicitly here if a specific mixed precision / accumulation setup is required.
accelerator=accelerate.Accelerator()
device=accelerator.device

In [None]:
#login using your token
#later cells load from and push to the Hugging Face hub, so an authenticated session is needed.
#login() is called without a token argument, so it asks for one interactively.
from huggingface_hub import login
login()

In [None]:
#hub repository "<username>/<repo_name>" that stores the customized tokenizer and text encoder
username, repo_name="<YOUR USERNAME>", "<YOUR REPO NAME>"

In [None]:
#reset tokenizer and text_encoder to the pretrained stable-diffusion-2-1 versions
#WARNING: re-running this cell pushes the pretrained files over whatever is stored in `repo_name`,
#discarding placeholder tokens/embeddings pushed by earlier training runs.
#NOTE(review): push_to_hub uploads the freshly saved files; confirm that stale files already in the repo
#(e.g. an old added-tokens file) are actually removed, otherwise the "reset" may be incomplete.
text_encoder=transformers.CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder")
tokenizer=transformers.CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
#save tokenizer and text_encoder to your repo
tokenizer.push_to_hub(repo_name)
text_encoder.push_to_hub(repo_name)

In [None]:
#prompt templates for training: "{}" is replaced by the placeholder token.
#object templates describe a photo OF the concept.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

#style templates describe an image IN THE STYLE OF the concept.
#NOTE(review): "a cropped painting in the style of {}" and "a close-up painting in the style of {}" each appear
#twice, so they are drawn twice as often when a template is sampled uniformly from this list.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

In [None]:
#Textual Inversion Dataset
class TextualInversionDataset(torch.utils.data.Dataset):
  """Pairs (repeated) concept images with randomly templated prompts for textual inversion.

  Each item is ``(input_ids, pixel_values)``: the tokenized prompt, built from a random template
  filled with ``placeholder_token``, and an augmented copy of one of the images.

  Args:
    tokenizer: CLIP tokenizer that already contains ``placeholder_token``.
    text_encoder: kept as an attribute only; the dataset never runs it.
    images: image tensor (C, H, W) or (N, C, H, W). Assumed to hold values in [0, 1]
      (e.g. the output of ``ToTensor``) when ``normalize`` is True — TODO confirm for new callers.
    initial_token: kept as an attribute only.
    placeholder_token: token string substituted into the prompt templates.
    learnable_property: "object" or "style"; selects the template list.
    size: size passed to ``CenterCrop`` / ``Resize``.
    repeats: how many times the image set is repeated in one pass over the dataset.
    interpolation: name of a ``torchvision.transforms.InterpolationMode`` value used by ``Resize``
      (e.g. "bicubic", "bilinear", "nearest").
    flip_p: probability of a random horizontal flip.
    center_crop: if True, center-crop to ``size`` before resizing.
    normalize: if True, rescale pixel values from [0, 1] to [-1, 1], the input range the
      Stable Diffusion VAE expects.
  """
  def __init__(self, tokenizer, text_encoder, images, initial_token, placeholder_token, learnable_property="object", size=512, repeats=100, 
               interpolation="bicubic", flip_p=0.5, center_crop=False, normalize=True):
    super(TextualInversionDataset, self).__init__()
    assert learnable_property in ['object', 'style'], 'Learnable Property should be either "object" or "style"'
    #settings
    self.learnable_property=learnable_property
    self.size=size
    self.repeats=repeats
    self.flip_p=flip_p
    self.center_crop=center_crop
    self.normalize=normalize

    #images preprocessing: assume input image is tensor.
    #`interpolation` is honored now (BICUBIC, the default, was previously hard-coded regardless of the argument)
    interpolation_mode=torchvision.transforms.InterpolationMode(interpolation)
    self.center_crop_transform=torchvision.transforms.CenterCrop(self.size)
    self.image_transforms=torchvision.transforms.Compose([
        torchvision.transforms.Resize(self.size, interpolation=interpolation_mode),
        torchvision.transforms.RandomHorizontalFlip(p=self.flip_p),
    ])

    #raw input to data: keep a single copy of the images and index modulo the image count,
    #instead of concatenating `repeats` copies (which multiplied memory use, possibly on the GPU)
    if images.dim()==3:
      images=images.unsqueeze(dim=0)
    self.images=images
    self.n_images=len(images)*repeats

    #inversion tokens
    self.initial_token=initial_token
    self.placeholder_token=placeholder_token

    #CLIP tokenizer encodes the prompts; the text encoder is stored for reference only.
    #It is deliberately NOT passed through accelerator.prepare: the caller already prepares it, and
    #preparing the same model twice can double-wrap it (e.g. in DDP) in multi-GPU runs.
    self.tokenizer=tokenizer
    self.text_encoder=text_encoder

    self.templates=imagenet_templates_small if self.learnable_property=='object' else imagenet_style_templates_small

  @property
  def data(self):
    """Repeated image tensor (equivalent to concatenating ``images`` ``repeats`` times), built on demand."""
    return self.images.repeat(self.repeats, *([1]*(self.images.dim()-1)))
  
  def __len__(self):
    #consider repetition of the whole dataset
    return self.n_images
  
  def __getitem__(self, idx):
    #keep IndexError semantics of indexing the repeated tensor
    if not -self.n_images <= idx < self.n_images:
      raise IndexError("index {} is out of range for dataset of size {}".format(idx, self.n_images))
    #repeated dataset: position idx maps to image idx % N
    image=self.images[idx % len(self.images)]
    random_template=random.choice(self.templates).format(self.placeholder_token)

    #get prompt tokenized input ids; a single prompt tokenizes to (1, model_max_length), squeezed to 1-D
    tokenized=self.tokenizer(random_template, padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length, return_tensors="pt")
    input_ids=tokenized.input_ids.squeeze(dim=0)

    #get image pixel values
    if self.center_crop:
      image=self.center_crop_transform(image)
    pixel_values=self.image_transforms(image)
    if self.normalize:
      #[0, 1] -> [-1, 1]
      pixel_values=pixel_values*2.0-1.0

    return input_ids, pixel_values

In [None]:
#Inversion Trainer
class InversionTrainer():
  """Textual Inversion trainer for Stable Diffusion 2.1.

  A fresh copy of the pretrained CLIP text encoder (``trainable_text_encoder``) is trained so that
  only the embedding of a new placeholder token changes, while the VAE, the UNet and the rest of the
  text encoder stay frozen. After training, the learned embedding is copied into the text encoder
  loaded from ``username/repo_name`` and, optionally, tokenizer and text encoder are pushed back.

  Uses the notebook globals ``accelerator``, ``username``, ``repo_name`` and ``TextualInversionDataset``.
  """
  #hub id of the base model every frozen component is loaded from
  PRETRAINED_MODEL_ID="stabilityai/stable-diffusion-2-1"
  #number of token embeddings in the pretrained CLIP vocabulary (checked by check_embedding_alignment)
  N_DEFAULT_EMBEDDINGS=49408

  def __init__(self):
    #frozen models for inversion training
    self.autoencoder=diffusers.AutoencoderKL.from_pretrained(self.PRETRAINED_MODEL_ID, subfolder="vae")
    self.cn_unet=diffusers.UNet2DConditionModel.from_pretrained(self.PRETRAINED_MODEL_ID, subfolder="unet")

    #tokenizer from my repo, so tokens added by earlier runs are kept
    self.tokenizer=transformers.CLIPTokenizer.from_pretrained('{:s}/{:s}'.format(username, repo_name))

    #train the pretrained encoder and copy its learned placeholder embedding into my repo's text_encoder
    self.trainable_text_encoder=transformers.CLIPTextModel.from_pretrained(self.PRETRAINED_MODEL_ID, subfolder="text_encoder")
    self.text_encoder=transformers.CLIPTextModel.from_pretrained('{:s}/{:s}'.format(username, repo_name))

    #DDPM scheduler to get timesteps & corruption during training.
    self.train_scheduler=diffusers.DDPMScheduler.from_pretrained(self.PRETRAINED_MODEL_ID, subfolder="scheduler")

    #accelerator prepare
    self.autoencoder, self.cn_unet, self.text_encoder, self.trainable_text_encoder=accelerator.prepare(self.autoencoder, self.cn_unet, self.text_encoder, self.trainable_text_encoder)

  @staticmethod
  def _unwrap(model):
    """Return the underlying model when accelerator.prepare wrapped it (DDP / DataParallel), else the model itself."""
    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
      return model.module
    return model
  
  def freeze_module(self, module):
    """Disable gradients for every parameter of ``module``."""
    for param in module.parameters():
      param.requires_grad=False
    return
  
  def unfreeze_module(self, module):
    """Enable gradients for every parameter of ``module``."""
    for param in module.parameters():
      param.requires_grad=True
    return
  
  def freeze_all(self):
    """Freeze everything except the trainable text encoder's token embeddings."""
    text_model=self._unwrap(self.trainable_text_encoder).text_model
    self.freeze_module(text_model.encoder)
    self.freeze_module(text_model.embeddings.position_embedding)
    self.freeze_module(text_model.final_layer_norm)

    #freeze other modules
    self.freeze_module(self.autoencoder)
    self.freeze_module(self.cn_unet)
    return
  
  def unfreeze_all(self):
    """Re-enable gradients for the trainable text encoder, the VAE and the UNet."""
    self.unfreeze_module(self.trainable_text_encoder)
    self.unfreeze_module(self.autoencoder)
    self.unfreeze_module(self.cn_unet)
    return
  
  def check_embedding_alignment(self):
    """Check that the default token embeddings of ``self.text_encoder`` equal the pretrained ones.

    Compares the first ``N_DEFAULT_EMBEDDINGS`` rows against a freshly loaded pretrained text encoder,
    prints a summary and returns True when every row matches exactly, else False.
    """
    embeddings=self._unwrap(self.text_encoder).get_input_embeddings().weight.data

    #compared to pretrained text encoder
    target_text_encoder=transformers.CLIPTextModel.from_pretrained(self.PRETRAINED_MODEL_ID, subfolder="text_encoder").to(embeddings.device)
    target_embeddings=target_text_encoder.get_input_embeddings().weight.data

    #vectorized row comparison: a single device sync instead of one per token
    n_default=self.N_DEFAULT_EMBEDDINGS
    mismatched_rows=(embeddings[:n_default]!=target_embeddings[:n_default]).any(dim=1)
    n_wrong=int(mismatched_rows.sum().item())
    if n_wrong>0:
      print("{:d} default embeddings doesn't match".format(n_wrong))
      return False
    print("All default embeddings match")
    return True
  
  def plot_train_logs(self, train_logs):
    """Plot ``train_logs['loss_history']`` against a 1-based step index."""
    loss_history=train_logs['loss_history']
    plt.figure(figsize=(10,5))
    plt.plot(np.arange(1, len(loss_history)+1, 1), loss_history)
    plt.xlabel("Update Steps")
    plt.ylabel("Loss")
    plt.show()
    return

  def _add_placeholder_token(self, placeholder_token, initial_token):
    """Add ``placeholder_token`` to the tokenizer, resize both text encoders and initialize its embedding.

    The initial token is validated before the tokenizer is mutated, so a bad ``initial_token`` leaves
    the trainer unchanged. Returns the placeholder token id.

    Raises:
      ValueError: if ``initial_token`` is not exactly one token, or ``placeholder_token`` already exists.
    """
    initial_token_id=None
    if initial_token is not None:
      token_ids=self.tokenizer.encode(initial_token, add_special_tokens=False)
      if len(token_ids)!=1:
        raise ValueError("The initializer token must be a single token.")
      initial_token_id=token_ids[0]

    print("Initial Tokenizer Length: {:d}".format(len(self.tokenizer)))
    #add placeholder token and check that it doesn't exist inside the current dictionary.
    num_added_tokens=self.tokenizer.add_tokens(placeholder_token)
    print("No. of added tokens: {:d}".format(num_added_tokens))
    if num_added_tokens==0:
      raise ValueError(
          f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
          " `placeholder_token` that is not already in the tokenizer."
      )
    placeholder_token_id=self.tokenizer.convert_tokens_to_ids(placeholder_token)
    print("Placeholder Token: '{:s}', Placeholder Token ID: {:d}".format(placeholder_token, placeholder_token_id))

    #expand the size of embeddings of both text encoders
    #NOTE(review): in multi-GPU runs the models are already DDP-wrapped here; resizing after wrapping
    #may not be tracked by the wrapper — confirm before relying on multi-GPU training.
    encoders=(self._unwrap(self.trainable_text_encoder), self._unwrap(self.text_encoder))
    for encoder in encoders:
      encoder.resize_token_embeddings(len(self.tokenizer))

    if initial_token_id is not None:
      print("Initial Token: '{:s}', Initial Token ID: {:d}".format(initial_token, initial_token_id))
      #initialize the placeholder in BOTH encoders; the trainable one is the one being optimized
      #(previously only the non-trained encoder was initialized, so the init was lost)
      for encoder in encoders:
        token_embeds=encoder.get_input_embeddings().weight.data
        token_embeds[placeholder_token_id]=token_embeds[initial_token_id]

    print("Final Tokenizer Length: {:d}".format(len(self.tokenizer)))
    return placeholder_token_id
    
  def train(self, images, learnable_property, initial_token, placeholder_token, lr=5e-4, scale_lr=True, max_train_steps=2000, train_batch_size=4, gradient_checkpointing=True, seed=42, save=True):
    """Learn the CLIP token embedding of ``placeholder_token`` from ``images``.

    Args:
      images: image tensor (C, H, W) or (N, C, H, W), passed to ``TextualInversionDataset``.
      learnable_property: "object" or "style" (selects the prompt templates).
      initial_token: existing single token whose embedding initializes the placeholder, or None.
      placeholder_token: new token string to learn; must not already be in the tokenizer.
      lr: base learning rate.
      scale_lr: if True, multiply ``lr`` by gradient-accumulation steps x batch size x process count.
      max_train_steps: number of synchronized optimizer update steps.
      train_batch_size: per-process batch size.
      gradient_checkpointing: trade recomputation for memory in the text encoder and UNet.
      seed: seed for python/numpy/torch RNGs via accelerate; None leaves the RNGs untouched.
      save: push tokenizer and text encoder to the hub when the default embeddings still match.

    Returns:
      dict with 'config' (hyperparameters) and 'loss_history' (loss of every micro-batch).

    Raises:
      ValueError: empty ``images``, unsupported scheduler prediction type, invalid initial token,
        or a placeholder token that already exists.
    """
    #inherit hyperparams from Textual Inversion paper by default
    if seed is not None:
      accelerate.utils.set_seed(seed)

    #effective lr => loss is averaged over the effective batch, so scale the base lr with it.
    #Computed BEFORE the optimizer is built (previously it was computed afterwards and never used).
    effective_lr=lr
    if scale_lr:
      effective_lr=lr*accelerator.gradient_accumulation_steps*train_batch_size*accelerator.num_processes
    total_batch_size=train_batch_size*accelerator.num_processes*accelerator.gradient_accumulation_steps

    inversion_train_logs={
        'config': {
            'lr': lr,
            'effective_lr': effective_lr,
            'scale_lr': scale_lr,
            'max_train_steps': max_train_steps,
            'train_batch_size': train_batch_size,
            'total_batch_size': total_batch_size,
            'gradient_accumulation_steps': accelerator.gradient_accumulation_steps,
            'gradient_checkpointing': gradient_checkpointing,
            'mixed_precision': accelerator.mixed_precision,
            'seed': seed
        },
        'loss_history': [],
    }

    #expand images if input is a single image.
    if len(images.size())==3:
      images=images.unsqueeze(dim=0)
    if images.size(0)==0:
      raise ValueError("`images` is empty: at least one training image is required.")

    #the UNet must be trained against what it predicts: SD 2.1 ships a `v_prediction` scheduler config
    prediction_type=getattr(self.train_scheduler.config, "prediction_type", "epsilon")
    if prediction_type not in ("epsilon", "v_prediction"):
      raise ValueError("Unsupported scheduler prediction_type: {!r}".format(prediction_type))

    placeholder_token_id=self._add_placeholder_token(placeholder_token, initial_token)

    #setup inversion dataset and loader
    inversion_dataset=TextualInversionDataset(self.tokenizer, self.trainable_text_encoder, images, initial_token, placeholder_token, learnable_property=learnable_property)
    train_loader=torch.utils.data.DataLoader(inversion_dataset, batch_size=train_batch_size, shuffle=True)

    trainable_encoder=self._unwrap(self.trainable_text_encoder)
    #optimize the token-embedding matrix; every row except the placeholder is restored after each step
    optimizer=optim.AdamW(trainable_encoder.get_input_embeddings().parameters(), lr=effective_lr)

    #gradient checkpointing: reduce memory at the cost of recomputation
    if gradient_checkpointing:
      trainable_encoder.gradient_checkpointing_enable()
      self._unwrap(self.cn_unet).enable_gradient_checkpointing()

    #optimizer, train_data_loader subject to training
    optimizer, train_loader=accelerator.prepare(optimizer, train_loader)

    #freeze all modules except text_encoder's token embeddings
    self.freeze_all()
    self.autoencoder.eval()
    #presumably kept in train mode so diffusers' gradient checkpointing (train-mode only) stays active — TODO confirm
    self.cn_unet.train()

    #snapshot of all token embeddings: AdamW's weight decay would otherwise shrink every row even with
    #zero gradients, so non-placeholder rows are copied back after each optimizer step
    orig_embeds=trainable_encoder.get_input_embeddings().weight.data.clone()
    keep_rows=torch.ones(orig_embeds.size(0), dtype=torch.bool, device=orig_embeds.device)
    keep_rows[placeholder_token_id]=False

    vae=self._unwrap(self.autoencoder)
    latent_scale=getattr(vae.config, "scaling_factor", 0.18215)
    num_train_timesteps=self.train_scheduler.config.num_train_timesteps

    #recalculate total training steps as the size of the prepared dataloader may have changed.
    num_update_steps_per_epoch=math.ceil(len(train_loader)/accelerator.gradient_accumulation_steps)
    num_train_epochs=math.ceil(max_train_steps/num_update_steps_per_epoch)
    print("Dataset Size: {:d}, Update Steps per Epoch: {:d}, Train Epochs: {:d}".format(len(inversion_dataset), num_update_steps_per_epoch, num_train_epochs))

    mse_loss=nn.MSELoss(reduction="none")

    pbar=tqdm(desc="Textual Inversion Training", total=max_train_steps)
    global_step=0
    for epoch in range(1, num_train_epochs+1):
      self.trainable_text_encoder.train()
      for input_ids, pixel_values in train_loader:
        with accelerator.accumulate(self.trainable_text_encoder):
          #z_0: latents of the images (VAE is frozen, no graph needed)
          with torch.no_grad():
            init_latents=vae.encode(pixel_values).latent_dist.sample()*latent_scale
          eps=torch.randn_like(init_latents)
          timesteps=torch.randint(low=0, high=num_train_timesteps, size=(init_latents.size(0),), device=init_latents.device).long()
          #obtaining z_t from z_0 through corruption.
          noisy_latents=self.train_scheduler.add_noise(init_latents, eps, timesteps)

          prompt_embeddings=self.trainable_text_encoder(input_ids).last_hidden_state
          model_pred=self.cn_unet(sample=noisy_latents, timestep=timesteps, encoder_hidden_states=prompt_embeddings).sample

          #regression target must match the UNet's parameterization
          if prediction_type=="epsilon":
            target=eps
          else:
            target=self.train_scheduler.get_velocity(init_latents, eps, timesteps)

          #train the prompt embedding s.t. the frozen UNet predicts the target from the prompt and the noisy latent
          loss=mse_loss(model_pred.float(), target.float()).mean([1,2,3]).mean()
          inversion_train_logs['loss_history'].append(loss.item())
          accelerator.backward(loss)

          optimizer.step()
          optimizer.zero_grad()

          #only the placeholder embedding may change: restore every other row
          with torch.no_grad():
            token_weight=trainable_encoder.get_input_embeddings().weight
            token_weight[keep_rows]=orig_embeds[keep_rows]

        #progress once gradients are synchronized (i.e. an actual update step happened)
        if accelerator.sync_gradients:
          pbar.update(1)
          global_step+=1
        if global_step>=max_train_steps:
          break
      #also leave the epoch loop, not only the batch loop
      if global_step>=max_train_steps:
        break
    pbar.close()

    #plug the placeholder embedding of trainable into actual text_encoder
    target_encoder=self._unwrap(self.text_encoder)
    placeholder_embedding=trainable_encoder.get_input_embeddings().weight.data[placeholder_token_id]
    target_encoder.get_input_embeddings().weight.data[placeholder_token_id]=placeholder_embedding

    matches=self.check_embedding_alignment()
    if matches and save and accelerator.is_main_process:
      #save tokenizer and text_encoder to hf repository
      self.tokenizer.push_to_hub(repo_name)
      target_encoder.push_to_hub(repo_name)
      print("Push Success")

    #plot train loss
    self.plot_train_logs(inversion_train_logs)
    return inversion_train_logs

In [None]:
#Loading Images
preprocess=torchvision.transforms.Compose(
    [
        #resize to exactly 768x768 (aspect ratio is not preserved)
        torchvision.transforms.Resize((768, 768)),
        #PIL image -> float tensor (C, H, W)
        torchvision.transforms.ToTensor(),
    ]
)

def open_images(learnable_property, name):
  """Load ``./images/<name>/<name>_1.jpg``, ``<name>_2.jpg``, ... up to the first missing index.

  Each image is converted to RGB, passed through ``preprocess`` and moved to ``device``.

  Args:
    learnable_property: unused; kept for call-site compatibility.
    name: folder name and file-name prefix of the image set.

  Returns:
    Tensor stacking the preprocessed images along a new leading dimension.

  Raises:
    FileNotFoundError: if not even ``<name>_1.jpg`` exists.
  """
  #`import PIL` alone does not guarantee the Image submodule is loaded
  from PIL import Image

  img_list=[]
  idx=1
  while True:
    img_path=os.path.join(".", "images", name, "{:s}_{:d}.jpg".format(name, idx))
    if not os.path.exists(img_path):
      break
    #RGB conversion keeps grayscale/CMYK JPEGs at 3 channels; `with` closes the file handle
    with Image.open(img_path) as img:
      img_list.append(preprocess(img.convert("RGB")).to(device))
    idx+=1
  if not img_list:
    raise FileNotFoundError("No images found: expected {:s}".format(os.path.join(".", "images", name, name+"_1.jpg")))
  return torch.stack(img_list, dim=0)

def show_images(images):
  """Display one image tensor (C, H, W) or a batch (N, C, H, W) in a 4-column grid."""
  batch=images.unsqueeze(dim=0) if len(images.size())==3 else images
  n_cols=4
  n_rows=math.ceil(batch.size(0)/n_cols)
  plt.figure(figsize=(20, 5*n_rows))
  for position, img in enumerate(batch, start=1):
    plt.subplot(n_rows, n_cols, position)
    #channels-last numpy array for imshow
    plt.imshow(img.permute(1,2,0).detach().cpu().numpy())
  plt.show()

#test image set: ./images/3d_illustration/3d_illustration_<k>.jpg, stacked into one tensor
illustration_images=open_images("style", "3d_illustration")

In [None]:
#Inversion Training of 3D style
#create the trainer here: it was never instantiated anywhere, so `inversion_trainer.train` raised NameError on a fresh kernel
inversion_trainer=InversionTrainer()

images=illustration_images
initial_token="illustration"
placeholder_token="<3d_illustration>"
learnable_property="style"

#2k train steps take about 25min.s w/ A100 GPU
#5k steps will take about 60min.s w/ A100 GPU
#scale_lr=False keeps the effective learning rate at exactly 5e-3: the original trainer built its optimizer
#before scaling lr, so 5e-3 is the rate these step/time estimates were obtained with.
train_logs=inversion_trainer.train(images=images, initial_token=initial_token, placeholder_token=placeholder_token, learnable_property=learnable_property,
                                   lr=5e-3, scale_lr=False, max_train_steps=5000, train_batch_size=4)

In [None]:
#Inference Models
#base SD 2.1 pipeline with tokenizer and text encoder swapped for the versions stored in <username>/<repo_name>
#(the repo the training cell pushes to), so the learned placeholder token is understood.
pipeline=diffusers.StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipeline.tokenizer=transformers.CLIPTokenizer.from_pretrained('{:s}/{:s}'.format(username, repo_name))
pipeline.text_encoder=transformers.CLIPTextModel.from_pretrained('{:s}/{:s}'.format(username, repo_name))

#memory savers for inference
pipeline.enable_model_cpu_offload()
pipeline.enable_attention_slicing()
#NOTE(review): requires the xformers package; enabling it after attention slicing likely replaces the sliced
#attention processors set by the previous line — confirm whether both are intended.
pipeline.enable_xformers_memory_efficient_attention()

In [None]:
#Inference: the prompt must contain the placeholder token learned by the training cell
prompt="a desk at a office with computers in the style of <3d_illustration>"

generated=pipeline(prompt=prompt)
image=generated.images[0]

fig, ax=plt.subplots()
ax.imshow(image)
plt.show()