In [None]:
# install clip-interrogator
!pip install clip-interrogator

In [None]:
!pip install --upgrade diffusers transformers accelerate

# Load Libraries and Pipelines

In [None]:
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInstructPix2PixPipeline
import torch
import requests
import numpy as np
import pandas as pd
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
#from clip_interrogator import Config, Interrogator
import glob 
import nltk
from typing import Optional, List, Union
# Fetch NLTK's 'punkt' tokenizer data.
# NOTE(review): no cell visible in this notebook calls nltk; this download may be removable — confirm.
nltk.download(info_or_id='punkt')

In [None]:
# Model checkpoints and target device for the pipelines loaded in this cell.
# `model_id` lists candidate Stable Diffusion checkpoints; only model_id[1] is loaded below.
model_id = ["nitrosocke/mo-di-diffusion", "Linaqruf/anything-v3.0", "darkstorm2150/Protogen_x3.4_Official_Release"]
instruct_model_id = "timbrooks/instruct-pix2pix"
clip_model_id = "ViT-L-14/openai"  # used when the CLIP Interrogator is created in the "Construct Prompt" section
device = "cuda"

# Both pipelines are loaded in float16 and moved to `device`; fail fast with a clear
# message before any weights are downloaded if no CUDA GPU is attached.
if device == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available: this notebook moves its float16 pipelines to a CUDA GPU. Switch to a GPU runtime.")

# Image-to-image pipeline called by multiple_rounds_img2img.
img2imgpipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id[1], torch_dtype=torch.float16).to(device)

# InstructPix2Pix pipeline used to edit generated images from a text instruction.
instructpixpipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(instruct_model_id, torch_dtype=torch.float16).to(device)

In [None]:
# User defined Functions
def image_grid(imgs, rows, cols):
    """Paste `imgs` row-major onto one RGB canvas of `rows` x `cols` cells.

    Every cell is sized like the first image; an image's cell is chosen from its
    index and `cols` alone (`rows` only sets the canvas height).
    """
    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for position, tile in enumerate(imgs):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas

def multiple_rounds_img2img(
  init_image: Image.Image,
  prompt: str,
  negative_prompt: str,
  strength_array: List[float],
  guidance_array: Union[List[float], List[int]],
  final_images_to_return: Optional[int] = 5,
  num_rounds: Optional[int] = 4,
  seed: Optional[int] = 123,
  num_inference_steps: int = 400,
  pipeline=None) -> List:
  """Run img2img repeatedly, feeding each round's output into the next round.

  Rounds 0 .. num_rounds-2 each generate one image from the previous round's
  output, using strength_array[i] and guidance_array[i]. The final round runs on
  the last intermediate image, uses strength_array[-1] and guidance_array[-1], and
  generates `final_images_to_return` candidates.

  Args:
    init_image: PIL image fed to the first round.
    prompt: Non-empty text prompt used in every round.
    negative_prompt: Negative prompt used in every round (may be empty).
    strength_array: Per-round `strength`; length must equal num_rounds.
    guidance_array: Per-round `guidance_scale`; length must equal num_rounds.
    final_images_to_return: Number of images generated by the final round.
    num_rounds: Total number of img2img rounds, including the final one.
    seed: Passed to torch.manual_seed before the first round and again before the final round.
    num_inference_steps: `num_inference_steps` passed to every pipeline call.
    pipeline: Img2img pipeline to call; defaults to the notebook-level `img2imgpipe`.

  Returns:
    The `.images` list produced by the final pipeline call.

  Raises:
    AssertionError: if a parameter check fails.
  """
  # Parameter checking
  ## init_image
  assert isinstance(init_image, Image.Image), "init_image must be an Image"

  ## prompt & negative_prompt
  assert isinstance(prompt, str) and len(prompt) > 0, "Prompt provided must be a comma separated string and cannot be an empty string"
  assert isinstance(negative_prompt, str), "Negative Prompt provided must be a comma separated string"

  ## num rounds
  assert num_rounds > 0, "num_rounds must be greater than 0"

  ## strength_array & guidance array
  assert len(strength_array) == num_rounds, 'strength_array length must be identical to num_rounds'
  assert len(guidance_array) == num_rounds, 'guidance_array length must be identical to num_rounds'

  ## final_images_to_return
  assert final_images_to_return > 0, "final_images_to_return must be greater than 0"

  ## seed
  assert isinstance(seed, int), "seed must be an integer"

  # Default to the notebook-level pipeline, looked up at call time.
  if pipeline is None:
    pipeline = img2imgpipe

  # Main Body
  torch.manual_seed(seed)
  output_image_array = [init_image]

  # Intermediate rounds: one image each, each round starting from the previous output.
  for idx in range(num_rounds - 1):
    round_output = pipeline(prompt=prompt,
                            image=output_image_array[idx],
                            strength=strength_array[idx],
                            guidance_scale=guidance_array[idx],
                            num_inference_steps=num_inference_steps,
                            num_images_per_prompt=1,
                            negative_prompt=negative_prompt)
    output_image_array.append(round_output.images[0])

  # Final round: must run once, AFTER all intermediate rounds (outside the loop).
  # Re-seeding here keeps the final round's random draws independent of how many
  # intermediate rounds consumed the global RNG.
  torch.manual_seed(seed)
  final_output = pipeline(prompt=prompt,
                          image=output_image_array[-1],
                          strength=strength_array[-1],
                          guidance_scale=guidance_array[-1],
                          num_inference_steps=num_inference_steps,
                          num_images_per_prompt=final_images_to_return,
                          negative_prompt=negative_prompt)

  return final_output.images



# Load Data

In [None]:
# Input photos matching /content/IMG* (presumably uploaded to the Colab runtime).
# glob returns paths in arbitrary order, and imgs[0] is later used as the img2img
# starting image, so sort to make that choice stable across runs.
filenames = sorted(glob.glob("/content/IMG*"))
assert filenames, "No input images matched /content/IMG* - upload the photos first"

raw_images = [Image.open(path) for path in filenames]
# Normalize to 3-channel RGB.
imgs = [img.convert("RGB") for img in raw_images]

In [None]:
# Show all loaded photos, 4 per row; rows = ceil(len(imgs) / 4), so photos beyond a
# fixed 2x4 layout are not silently left off the grid.
image_grid(imgs, rows=(len(imgs) + 3) // 4, cols=4)

# Construct Prompt

In [None]:
# The top-level clip_interrogator import is commented out in the imports cell, so bring
# the names into scope here; otherwise this cell raises NameError on a fresh kernel.
from clip_interrogator import Config, Interrogator

# CLIP Interrogator used to caption the input photo.
ci = Interrogator(Config(clip_model_name = clip_model_id))

In [None]:
# Ask the CLIP Interrogator for a comma-separated text prompt describing the first photo.
clip_prompt = ci.interrogate_fast(image=imgs[0])

In [None]:
# Keep only the leading description (the text before the first comma) and
# drop the comma-separated terms that follow it.
img_description_prompt = clip_prompt.partition(",")[0]

In [None]:
# Style keywords appended to the image description to steer the generation style.
trigger_words = ", ".join(["cartoon", "Pixar", "Disney character", "3D render", "modern disney style"])

In [None]:
# Final prompt = image description followed by the style trigger words.
augmented_prompt = f"{img_description_prompt}, {trigger_words}"

In [None]:
# Show the assembled prompt so it can be reviewed before generation.
print(f"{augmented_prompt}")

# Img2Img with Multiple Rounds

In [None]:
# Manual override of the CLIP-derived prompt: a hand-written subject description
# replaces the caption. The shared `trigger_words` supply the style terms, so
# editing them in one place also updates this prompt.
augmented_prompt = ", ".join(["a stuffed brown meerkat dressed in a zebra suit", trigger_words])

In [None]:
# Multi-round img2img on the first photo; the final round yields 5 candidate images.
# strength/guidance lists hold one entry per round and must match num_rounds.
returned_imgs = multiple_rounds_img2img(
    imgs[0],
    augmented_prompt,
    negative_prompt="disfigured, misaligned, ugly, blurry, grumpy, grey, dark, big eyes, person, human, fuzzy, furry",
    num_rounds=4,
    strength_array=[0.7, 0.6, 0.5, 0.4],
    guidance_array=[20.0, 18.0, 16.0, 14.0],
    final_images_to_return=5,
    seed=123,
)

In [None]:
# Original photo followed by every final-round candidate, 3 per row.
# rows = ceil((1 + len(returned_imgs)) / 3), so no candidate is left off the grid.
image_grid([imgs[0]] + returned_imgs, rows=(len(returned_imgs) + 3) // 3, cols=3)

In [None]:
#returned_imgs[3].save("cartoonised_image.png")

# Instruct Pix2Pix to Edit Generated Images

---



In [None]:
# Text instruction describing the edit to apply.
prompt_to_edit = "Put a waterfall in the background"

# Fixed seed so the edits are reproducible.
torch.manual_seed(123)

# Edit the candidate at index 3 of the multi-round results and generate 4 variants.
# `edited_image` holds the pipeline output; its `.images` list is displayed below.
edited_image = instructpixpipe(
    prompt=prompt_to_edit,
    image=returned_imgs[3],
    num_inference_steps=100,
    image_guidance_scale=2.10,
    guidance_scale=7.5,
    num_images_per_prompt=4,
)

In [None]:
# Source image followed by every InstructPix2Pix variant, 3 per row.
# rows = ceil((1 + number of edits) / 3), so no variant is left off the grid.
image_grid([returned_imgs[3]] + edited_image.images, rows=(len(edited_image.images) + 3) // 3, cols=3)