In [None]:
#library for data processing
import pandas as pd
from bs4 import BeautifulSoup
import string
import re
import itertools
import io
import json
import os
import sys
import ast
import time
import requests

import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import PIL
import scipy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import transformers
import diffusers
import accelerate
import clip
import torchvision.transforms.functional as TF

#from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

import datasets
from datasets import load_dataset

import pickle
import gym

# Notebook-wide configuration: working directory and the torch devices used below.
base_dir = '/content/drive/My Drive/Colab Notebooks/CV'

# `cpu` is always the host device; `device` is CUDA when torch reports it as available.
cpu = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class Clip(nn.Module):
  """Wraps a CLIP-style image encoder as a (negated) similarity criterion.

  Images are resized, channel-normalized and passed through `model.encode_image`;
  the resulting features are L2-normalized along dim 1.

  Args:
    model: object exposing `encode_image(images)`.
    mean, std: per-channel normalization statistics applied after resizing.
    input_size: (height, width) the images are resized to before encoding.

  NOTE(review): the default mean/std are the ImageNet statistics the notebook
  originally hard-coded. OpenAI CLIP's own preprocessing is believed to use
  different values; pass them via `mean`/`std` if exact CLIP preprocessing is
  wanted -- confirm before changing the defaults, since it alters the guidance.
  """
  def __init__(self, model, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), input_size=(224, 224)):
    super(Clip, self).__init__()
    self.model = model
    self.input_size = tuple(input_size)
    self.trans = torchvision.transforms.Normalize(mean=list(mean), std=list(std))

  def encode(self, x):
    """Return unit-norm image features for the image batch `x`."""
    x = TF.resize(x, self.input_size, interpolation=TF.InterpolationMode.BICUBIC)
    x = self.trans(x)

    image_features = self.model.encode_image(x)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)

    return image_features

  def forward(self, x, y):
    """Return the negated, 100x-scaled dot product between features of `x` and `y`.

    `y` is used as-is (it is not re-normalized here); lower values mean the
    image features are closer to `y`.
    """
    # Share the preprocessing/encoding path with `encode` so both stay in sync.
    image_features = self.encode(x)
    logits_per_image = 100 * image_features @ y.t()
    return -1 * logits_per_image

In [None]:
class UniversalGuidancePipeline(nn.Module):
  """Text-to-image latent diffusion sampling steered towards a style image.

  At every sampling step the predicted noise is shifted by the gradient of a
  CLIP feature similarity between the current clean-image estimate and the
  style image, and each step is repeated `self_recurrent_steps` times by
  re-noising the result back to the current noise level.

  Relies on the notebook-level `device`, `Clip`, `tqdm` and `plt` names.
  """

  # Constant the latents are divided by before they are handed to the VAE decoder.
  _VAE_SCALING = 0.18215

  def __init__(self, scheduler_type):
    """Load scheduler, tokenizer, CLIP guidance model and the Stable Diffusion parts.

    Args:
      scheduler_type: 'ddim' or 'pndm'.
    """
    super(UniversalGuidancePipeline, self).__init__()
    assert scheduler_type in ['ddim', 'pndm'], "Invalid Scheduler Type, should be 'ddim' or 'pndm'"
    if scheduler_type=="ddim":
      self.scheduler=diffusers.DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
    elif scheduler_type=="pndm":
      self.scheduler=diffusers.PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
    self.tokenizer=transformers.CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    #guidance model: frozen and in eval mode (the wrapper owns the CLIP weights,
    #so freezing the wrapper freezes them too)
    clip_model, _ = clip.load("RN50")
    self.l_func = Clip(clip_model)
    self.l_func.eval()
    for param in self.l_func.parameters():
      param.requires_grad = False
    self.l_func = torch.nn.DataParallel(self.l_func).to(device)

    ldm=diffusers.StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
    #keep autoencoder parameters frozen
    self.autoencoder=ldm.vae
    for param in self.autoencoder.parameters():
      param.requires_grad = False

    self.unet=ldm.unet #subject to training, f=8
    #text encoder that produces the prompt embeddings fed to the unet
    self.text_encoder=ldm.text_encoder

  def display_intermediates(self, clean_x0, timestep, g_step):
    """Plot the batch `clean_x0` (clamped to [0, 1]) in a grid with 4 columns."""
    print("Intermediate Generation for Timestep {:d}, Recurrent Step {:d}".format(timestep, g_step+1))
    n_imgs=clean_x0.size(0)
    clean_x0=torch.clamp(clean_x0, min=0, max=1)

    width=10
    n_cols=4
    n_rows=(n_imgs+n_cols-1)//n_cols #ceil division
    height=5*n_rows
    plt.figure(figsize=(width, height))
    for idx, img in enumerate(clean_x0):
      plt.subplot(n_rows, n_cols, idx+1)
      plt.imshow(img.permute(1,2,0).detach().cpu().numpy())
    plt.show()
    return

  def normalize(self, tensor):
    """Map values from [0, 1] to [-1, 1]."""
    return tensor*2-1

  def unnormalize(self, tensor):
    """Map values from [-1, 1] to [0, 1]."""
    return (tensor+1)*0.5

  def get_x0_from_zt(self, latents, noise_pred, timestep):
    """Decode the clean-image estimate implied by `latents` and `noise_pred`.

    Args:
      latents: noisy latents at `timestep`.
      noise_pred: predicted noise for those latents.
      timestep: scheduler timestep (int or 0-dim integer tensor).

    Returns:
      Decoded image batch passed through `unnormalize` (not clamped).
    """
    #alphas_cumprod is indexed by the timestep value itself, the same way the
    #diffusers scheduler's step() looks it up; indexing with timestep-1 paired
    #the unet output with the noise level of the neighbouring training step.
    alpha_bar=self.scheduler.alphas_cumprod[timestep]
    alpha_bar=alpha_bar.to(device=latents.device, dtype=latents.dtype)
    z0_pred=(latents-torch.sqrt(1-alpha_bar)*noise_pred)/torch.sqrt(alpha_bar)
    z0_pred=(1 / self._VAE_SCALING)*z0_pred
    clean_x0=self.autoencoder.decode(z0_pred).sample
    clean_x0=self.unnormalize(clean_x0)
    return clean_x0

  def get_prompt_embeddings(self, batch_size, prompt, a_prompt, n_prompt):
    """Embed the conditional and unconditional prompts for classifier-free guidance.

    Returns:
      Text-encoder hidden states for 2*batch_size prompts: the first half is
      `prompt + ", " + a_prompt`, the second half is `n_prompt`.
    """
    prompts=[prompt+", "+a_prompt for _ in range(batch_size)]
    uncond_prompts=[n_prompt for _ in range(batch_size)]
    whole_prompts=[*prompts, *uncond_prompts]
    with torch.no_grad():
      #truncation keeps over-long prompts within the tokenizer's maximum length
      #instead of passing a too-long sequence to the text encoder
      tokenized=self.tokenizer(whole_prompts, return_tensors='pt', padding=True, truncation=True)
      input_ids, attention_mask=tokenized.input_ids.to(device), tokenized.attention_mask.to(device)
      prompt_embeddings=self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state
    return prompt_embeddings

  def predict_noise(self, input_latents, timestep, whole_prompt_embeddings, cfg_strength, no_grad=True):
    """Run the unet on cond+uncond inputs and combine them with classifier-free guidance.

    Args:
      input_latents: latents for one batch; duplicated internally for cond/uncond.
      timestep: scheduler timestep passed to the unet.
      whole_prompt_embeddings: output of `get_prompt_embeddings` (cond first, uncond second).
      cfg_strength: classifier-free guidance scale.
      no_grad: when False, autograd is enabled so gradients can flow to `input_latents`.
    """
    batch_size=input_latents.size(0)
    with torch.set_grad_enabled(not no_grad):
      x_in = torch.cat([input_latents] * 2)
      whole_noise = self.unet(x_in, timestep, whole_prompt_embeddings).sample
      cond_noise=whole_noise[:batch_size]
      uncond_noise=whole_noise[batch_size:]
      noise_pred = uncond_noise + cfg_strength * (cond_noise - uncond_noise)
    return noise_pred

  def forward(self, image_size, batch_size, prompt, style_image, a_prompt='best quality, extremely detailed', 
              n_prompt="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
              num_inference_steps=500, cfg_strength=3, self_recurrent_steps=6, guidance_strength_coeff=6):
    """Sample `batch_size` images for `prompt`, guided towards `style_image`.

    Args:
      image_size: (height, width) of the output; latents use 1/8 of each.
      batch_size: number of images to generate.
      prompt, a_prompt, n_prompt: prompt, appended prompt and negative prompt.
      style_image: single image tensor (no batch axis) on `device`.
      num_inference_steps: number of scheduler steps.
      cfg_strength: classifier-free guidance scale.
      self_recurrent_steps: guidance repetitions per scheduler step (>= 1).
      guidance_strength_coeff: multiplier on the style gradient.

    Returns:
      (images, intermediates): final images clamped to [0, 1], and a list with
      the clean-image estimate of every (step, recurrent step). The
      intermediates are moved to the CPU so they do not accumulate on the
      accelerator; the list still grows with steps * recurrent steps.

    Raises:
      ValueError: if `self_recurrent_steps` is smaller than 1.
    """
    if self_recurrent_steps<1:
      raise ValueError("self_recurrent_steps must be at least 1")

    criterion = self.l_func
    intermediates=[]

    #universal guidance for style transfer
    latent_h, latent_w=image_size[0]//8, image_size[1]//8
    whole_prompt_embeddings=self.get_prompt_embeddings(batch_size, prompt, a_prompt, n_prompt)
    with torch.no_grad():
      style_features=self.l_func.module.encode(style_image.unsqueeze(dim=0))

    self.scheduler.set_timesteps(num_inference_steps)
    timesteps=self.scheduler.timesteps
    alphas_cumprod=self.scheduler.alphas_cumprod
    n_steps=len(timesteps)

    latents=torch.randn(batch_size, 4, latent_h, latent_w).to(device)
    #context manager so the bar is closed even if a step raises
    with tqdm(desc=prompt, total=n_steps) as pbar:
      #for each sampling step
      for idx, timestep in enumerate(timesteps):
        #noise levels of the current and the next (less noisy) step, looked up by
        #timestep value to agree with get_x0_from_zt
        alpha_bar=alphas_cumprod[timestep]
        if idx==n_steps-1:
          alpha_bar_prev=1
        else:
          alpha_bar_prev=alphas_cumprod[timesteps[idx+1]]
        coeff=alpha_bar/alpha_bar_prev

        #for each self-recurrent steps
        #NOTE(review): scheduler.step is called once per recurrent step; a scheduler
        #that keeps history between step() calls may not tolerate that -- confirm for 'pndm'.
        for g_step in range(self_recurrent_steps):
          #scoped grad mode: the global autograd state is restored afterwards
          #instead of being left disabled for the rest of the session
          with torch.enable_grad():
            #detach input latents for gradient computation. + allow gradient computation.
            input_latents=latents.detach().requires_grad_(True)

            #CFG-applied noise prediction
            noise_pred=self.predict_noise(input_latents, timestep, whole_prompt_embeddings, cfg_strength, no_grad=False)

            recons_image = self.get_x0_from_zt(input_latents, noise_pred, timestep)
            #self.display_intermediates(recons_image, timestep, g_step)

            #criterion returns the negated similarity, so `selected` is the similarity
            selected = -1 * criterion(recons_image, style_features)
            grad = torch.autograd.grad(selected.sum(), input_latents)[0]

          intermediates.append(recons_image.detach().cpu())

          with torch.no_grad():
            grad = grad * guidance_strength_coeff
            noise_pred = noise_pred.detach() - torch.sqrt(1-alpha_bar) * grad.detach()

            #z_(tau-1)
            latents_prev=self.scheduler.step(noise_pred, timestep, latents).prev_sample

            #re-noise z_(tau-1) back to the level of z_(tau) for the next recurrent step
            eps=torch.randn_like(latents).to(device)
            latents = torch.sqrt(coeff) * latents_prev + torch.sqrt(1-coeff) * eps

        #final prediction as new latent after guidance
        latents=latents_prev
        pbar.update(1)

    latents=latents.detach()
    with torch.no_grad():
      latents=(1 / self._VAE_SCALING)*latents
      images=self.autoencoder.decode(latents).sample
      images=self.unnormalize(images)
      images=torch.clamp(images, min=0, max=1)
    return images, intermediates

In [None]:
#gathering style images
#every style image is resized (shorter side 512), center-cropped to 512x512
#and converted to a float tensor
preprocess=torchvision.transforms.Compose([
    torchvision.transforms.Resize(512),
    torchvision.transforms.CenterCrop(512),
    torchvision.transforms.ToTensor()
])

picasso_url='https://render.fineartamerica.com/images/rendered/default/print/6/8/break/images/artworkimages/medium/2/pablo-picasso-painting-raq-med.jpg'
gogh_url="https://th-thumbnailer.cdn-si-edu.com/GgmJe7fORYYh66TivAfZwNkCfv0=/fit-in/1600x0/filters:focal(640x640:641x641)/https://tf-cmsv2-smithsonianmag-media.s3.amazonaws.com/filer_public/3e/0b/3e0b2b4b-2b70-4309-a308-8bbf08360e94/national_gallery_of_the_faroe_islands_ai_exhibit_insprired_by_van_gogh.png"
pokemon_url="https://assets.reedpopcdn.com/-1645973608923.jpg/BROK/thumbnail/1600x900/quality/100/-1645973608923.jpg"
mario_url="https://www.pockettactics.com/wp-content/sites/pockettactics/2022/09/super-mario-maker-2-super-mario-bros-5-550x309.jpg"

def load_style_image(url, timeout=30):
  """Download `url` and return it as a 3-channel 512x512 tensor on `device`.

  Raises the `requests` HTTP error for a non-success status instead of trying
  to decode an error page as an image, and gives up after `timeout` seconds.
  """
  response = requests.get(url, timeout=timeout)
  response.raise_for_status()
  #convert to RGB so palette, grayscale and RGBA files all yield exactly 3 channels
  image = PIL.Image.open(io.BytesIO(response.content)).convert("RGB")
  return preprocess(image).to(device)

picasso = load_style_image(picasso_url)
gogh = load_style_image(gogh_url)
pokemon = load_style_image(pokemon_url)
mario = load_style_image(mario_url)

#display example style images
style_examples = {"picasso": picasso, "gogh": gogh, "pokemon": pokemon, "mario": mario}
fig, axes = plt.subplots(1, len(style_examples), figsize=(20,10), squeeze=False)
for ax, (style_name, style_tensor) in zip(axes[0], style_examples.items()):
  ax.imshow(style_tensor.permute(1,2,0).cpu().numpy())
  ax.set_title(style_name)
plt.show()

In [None]:
# Build the guidance pipeline once; this loads the pretrained weights it wraps.
pipeline = UniversalGuidancePipeline(scheduler_type='ddim')

In [None]:
#generation settings
n_gen=2
prompt="a hot air balloon cruising over a local village at daytime"
style_image=pokemon

#seed torch's RNG so the initial latents and the re-noising draws, and therefore
#the generated images, are the same on every Restart & Run All
seed=0
torch.manual_seed(seed)

styled_images, intermediates=pipeline(image_size=(512,512), batch_size=n_gen, prompt=prompt, style_image=style_image, num_inference_steps=200)

In [None]:
# Lay the generated samples out side by side in a single row.
print("Styled Images")
styled_fig=plt.figure(figsize=(15,5))
for position, sample in enumerate(styled_images, start=1):
  panel=styled_fig.add_subplot(1, n_gen, position)
  # channels-first tensor -> channels-last array for imshow
  panel.imshow(sample.permute(1,2,0).detach().cpu().numpy())
plt.show()

In [None]:
#show every `show_every`-th stored intermediate in a square grid
show_every=10
len_interm=len(intermediates)
n_imgs=len_interm//show_every
n_row=int(np.sqrt(n_imgs))+1

plt.figure(figsize=(25,25))
for idx in range(n_imgs):
  plt.subplot(n_row,n_row,idx+1)
  img_idx=show_every*idx
  #each entry holds the whole generated batch for one guidance step; plot its
  #first sample. squeeze(dim=0) only removed the batch axis for a batch of one,
  #so with n_gen>1 the 4-D tensor made permute(1,2,0) fail.
  img=torch.clamp(intermediates[img_idx][0], min=0, max=1)
  plt.imshow(img.permute(1,2,0).detach().cpu().numpy())
plt.show()