### **Environment Set-up**

In [None]:
## git clone repository
!git clone -q https://github.com/Dohyeon-Kim1/Multimodal_StyleTransfer.git MST

## set directory and install required packages
## %pip (unlike !pip) installs into the environment of the running kernel
%cd /content/MST
%pip install -q -r requirements.txt

## download pretrained models
## URLs are quoted: an unquoted `&` makes the shell start wget in the background and cut the URL
## at `&`, so the download may still be running when the model-loading cell reads the file.
## Quoting also stops the shell from treating `?` as a glob character.
!wget -q -O "models/SAM/segment_anything/model_zoo/sam_vit_h.pth" "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
!wget -q -O "models/AdaIN/model_zoo/encoder.pth" "https://drive.google.com/u/0/uc?id=1EpkBA2K2eYILDSyPTt0fztz59UjAIpZU&export=download"
!wget -q -O "models/AdaIN/model_zoo/decoder.pth" "https://drive.google.com/u/0/uc?id=1bMfhMMwPeXnYSQI6cDWElSZxOxc6aVyr&export=download"

### **Set-up**

In [None]:
import os
import requests
import torch
import numpy as np
import cv2
from PIL import Image

## models
from diffusers import StableDiffusionPipeline
from models.SAM.segment_anything import sam_model_registry, SamPredictor
from models.AdaIN.inference import StyleTransfer

## utils
from utils.utils import empty_memory
from utils.utils import str_input, array_input, print_highlight
from utils.utils import show_image, show_masks, show_point_mask, show_image_mask_pairs
from utils.utils import mask_by_point, merge_masks

In [None]:
## load models
## The three models are loaded ONCE. Every later step moves the model it needs onto its
## inference device itself (`model.to(device)` in prepare_image / mask_style_pair /
## style_transfer), so the models can wait on MODEL_LOAD_DEVICE between steps.
## NOTE(review): Stable Diffusion is loaded in float16, which in practice needs a CUDA device
## for inference — keep device="cuda" when creating images.
MODEL_LOAD_DEVICE = "cpu"   ## where models are kept between steps [cpu/cuda]
SD_MODEL_ID = "stabilityai/stable-diffusion-2-1-base"
SAM_CHECKPOINT = "models/SAM/segment_anything/model_zoo/sam_vit_h.pth"
ADAIN_ENCODER = "models/AdaIN/model_zoo/encoder.pth"
ADAIN_DECODER = "models/AdaIN/model_zoo/decoder.pth"

print("Loading Text-to-Image Model ..")
stable_diffusion = StableDiffusionPipeline.from_pretrained(SD_MODEL_ID, torch_dtype=torch.float16).to(MODEL_LOAD_DEVICE)
print("Success!\n")

print("Loading Instance Segmentation Model ..")
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(MODEL_LOAD_DEVICE)
print("Success!\n")

print("Load Style Transfer Model:")
adain = StyleTransfer(enc_path=ADAIN_ENCODER, dec_path=ADAIN_DECODER).to(MODEL_LOAD_DEVICE)
print("Success!")

### **Step1: Prepare Content Image**

In [None]:
def prepare_image(model, usage, device="cpu", max_size=512):
  """Interactively load or generate an RGB image for the given usage.

  The user picks one of three sources:
    - "url":    download the image over HTTP(S)
    - "path":   open a local file (re-prompts until the path exists)
    - "create": generate it with the text-to-image `model`, repeating until the user keeps one

  Args:
    model: text-to-image pipeline. Only used for "create", where it is moved to `device` and
      called as `model(text, guidance_scale=5, num_inference_steps=25).images[0]`.
    usage: label used in prompts and titles (e.g. "content" or "style").
    device: device the text-to-image model is moved to before generating [cpu/cuda].
    max_size: if the longer side exceeds this many pixels, the image is downscaled (aspect
      ratio kept) so that side becomes about `max_size`, to limit memory use. Default 512.

  Returns:
    PIL.Image.Image in RGB mode.

  Raises:
    AssertionError: if an answer is not one of the offered options.
    requests.HTTPError: if the URL responds with an HTTP error status.
  """
  import io

  ## select how to load image
  flag = str_input(f"Select how to load the {usage} image: [url/path/create]\n")
  assert flag in ["url","path","create"]

  ## load image from url
  if flag == "url":

    url = str_input("Enter the url:\n")
    ## Read the whole body instead of `.raw`: `.raw` is not decoded for Content-Encoding
    ## (e.g. gzip), and an HTTP error page would only surface as an opaque PIL decode error.
    ## The timeout keeps an unresponsive server from hanging the notebook indefinitely.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

  ## load image from path
  elif flag == "path":

    path = str_input("Enter the path:\n")
    while not os.path.exists(path):
      path = str_input("There exists no path.\nEnter the path again:\n")
    ## `with` closes the underlying file; convert() returns an independent copy
    with Image.open(path) as opened:
      image = opened.convert("RGB")

  ## create image
  elif flag == "create":

    model.to(device)
    print_highlight(f"Creating {usage.capitalize()} Image ..")

    while True:

      text = str_input(f"Enter text description of {usage} image which you want to create:\n")
      image = model(text, guidance_scale=5, num_inference_steps=25).images[0]

      show_image(image, axis="off")
      sign = str_input(f"Would you create new {usage} image? [yes/no]\n")
      assert sign in ["yes","no"]

      if sign == "no":
        break

  ## resize image for memory when either width or height is larger than max_size
  w, h = image.size
  if max(w,h) > max_size:
    resize_scale = max_size / max(w,h)
    ## max(1, ...) keeps a very elongated image from collapsing to a zero-sized dimension
    image = image.resize((max(1, int(w*resize_scale)), max(1, int(h*resize_scale))))

  ## show image
  print_highlight(f"Selected {usage.capitalize()} Image")
  show_image(image, axis="off")

  empty_memory()

  return image

In [None]:
## Step 1: obtain the content image (the text-to-image model is only used for the "create" option)
content_image = prepare_image(
    usage="content",          ## usage of image [content/style]
    model=stable_diffusion,   ## text-to-image model
    device="cuda",            ## model's inference location [cpu/cuda]
)

### **Step2: Create Mask and Style Image Pairs**

In [None]:
def _ask_point(prompt, w, h):
  """Prompt for an (X,Y) click until it is a pair of coordinates inside the w x h image."""
  ## assumes array_input returns an indexable sequence of ints — TODO confirm
  point = array_input(prompt, dtype=int)
  while len(point) != 2 or not (0 <= point[0] < w and 0 <= point[1] < h):
    print(f"Invalid coordinate: X must satisfy 0<=X<{w} and Y must satisfy 0<=Y<{h}")
    point = array_input(prompt, dtype=int)
  return point


def _predict_point_mask(predictor, point):
  """Return the single SAM mask predicted for one positive click at `point`."""
  masks, _, _ = predictor.predict(point_coords=np.expand_dims(point, axis=0),
                                  point_labels=np.array([1]), multimask_output=False)
  ## multimask_output=False yields one mask; drop its leading mask axis
  return np.squeeze(masks, axis=0)


def _select_one_mask(predictor, content_image, w, h):
  """Let the user click points until they confirm one non-empty mask, and return it."""
  prompt = f"Enter (X,Y) coordinate: [ex) 100,100] [X<{w}, Y<{h}] \n"
  while True:
    point = _ask_point(prompt, w, h)
    mask = _predict_point_mask(predictor, point)
    ## predict() always returns an array, so "no mask" shows up as an all-False mask
    if mask.any():
      show_point_mask(content_image, point, mask)
      sign = str_input("Is this the mask you selected? [yes/no]\n")
      assert sign in ["yes","no"]
      if sign == "yes":
        return mask
    else:
      print(f"There's no mask at the point ({point[0]},{point[1]})")
    prompt = f"Enter (X,Y) coordinate again: [ex) 100,100] [X<{w}, Y<{h}]\n"


def _select_style_image(content_image, model_t2i, device):
  """Ask which style a masked region gets; None means it keeps the content image's pixels."""
  print_highlight("Selecting Style Image ..")
  sign = str_input("Would you apply content image style for the mask part? [yes/no]\n")
  assert sign in ["yes","no"]

  if sign == "yes":
    print_highlight("Selected Style Image")
    show_image(content_image, axis="off")
    return None
  return prepare_image(usage="style", model=model_t2i, device=device)


def mask_style_pair(content_image, model_is, model_t2i, device="cpu"):
  """Interactively build (mask, style image) pairs that partition the content image.

  Either one mask covering the whole image is created, or the user clicks regions that SAM turns
  into masks (several clicks may be merged into one mask), choosing a style image per mask. The
  pixels not covered by any selected mask form a final "background" pair, which is skipped when
  it would be empty.

  Args:
    content_image: PIL RGB image to be stylised.
    model_is: SAM model used through `SamPredictor`; moved to `device` when masks are clicked.
    model_t2i: text-to-image model passed to `prepare_image` when a style image is created.
    device: inference device for both models [cpu/cuda].

  Returns:
    dict mapping "0", "1", ... to `[mask, style_image]`: mask is a boolean array over the image,
    style_image a PIL image or None (None = keep the content image's own pixels there).

  Raises:
    AssertionError: if a yes/no answer is not "yes" or "no".
  """
  ## show content image
  print_highlight("Content Image")
  show_image(content_image)

  ## select whether to apply style transfer to all parts of the content image
  sign = str_input("Do you want to apply style transfer to all parts of the content image? [yes/no]\n")
  assert sign in ["yes","no"]

  ## style transfer about the entire content image
  if sign == "yes":

    print_highlight("Creating Mask ..")
    entire_mask = np.ones_like(np.array(content_image)[:,:,0], dtype=bool)

    print_highlight("Created Mask")
    show_masks(content_image, [entire_mask], axis="off", only_arr=True, fix_color=True)

    print_highlight("Selecting Style Image ..")
    style_image = prepare_image(usage="style", model=model_t2i, device=device)

    mask_style = {"0": [entire_mask, style_image]}
    print_highlight("Created Mask & Style Image Pair")
    show_image_mask_pairs(content_image, [mask_style["0"]])

    empty_memory()

    return mask_style

  ## set image to manual mask generator
  model_is = model_is.to(device)
  manual_generator = SamPredictor(model_is)
  manual_generator.set_image(np.array(content_image))

  empty_memory()

  w, h = content_image.size
  idx = 0
  mask_style = {}

  ## create user-selected mask & style image pairs
  while True:

    print_highlight(f"Creating Mask & Style Image Pair [{idx+1}] ..")

    print_highlight("Creating Mask ..")
    selected_masks = []
    while True:
      selected_masks.append(_select_one_mask(manual_generator, content_image, w, h))

      sign = str_input("Would you add more part to the mask? [yes/no]\n")
      assert sign in ["yes","no"]
      if sign == "no":
        break

    merged_mask = merge_masks(selected_masks)
    print_highlight("Created Mask")
    show_masks(content_image, [merged_mask], axis="off", only_arr=True, fix_color=True)

    style_image = _select_style_image(content_image, model_t2i, device)

    mask_style[str(idx)] = [merged_mask, style_image]
    print_highlight(f"Created Mask & Style Image Pair [{idx+1}]")
    show_image_mask_pairs(content_image, [mask_style[str(idx)]])

    sign = str_input("Would you create more Mask & Style Image Pair? [yes/no]\n")
    assert sign in ["yes","no"]

    idx += 1

    if sign == "no":
      break

  ## create mask & style image pair about the background (not selected parts)
  print_highlight("Creating Mask & Style Image Pair [background] ..")

  print_highlight("Creating Mask ..")
  ## assumes merge_masks is a pixel-wise union of boolean masks — TODO confirm
  background_mask = np.logical_not(merge_masks([mask for mask, _ in mask_style.values()]))

  if background_mask.any():
    print_highlight("Created Mask")
    show_masks(content_image, [background_mask], axis="off", only_arr=True, fix_color=True)

    style_image = _select_style_image(content_image, model_t2i, device)

    mask_style[str(idx)] = [background_mask, style_image]
    print_highlight("Created Mask & Style Image Pair [background]")
    show_image_mask_pairs(content_image, [mask_style[str(idx)]])
  else:
    ## every pixel is already covered, so there is nothing to ask a background style for
    print_highlight("No background left: every pixel belongs to a selected mask")

  empty_memory()

  return mask_style

In [None]:
## Step 2: build the (mask, style image) pairs for the content image
mask_style = mask_style_pair(
    model_is=sam,                  ## instance segmentation model
    model_t2i=stable_diffusion,    ## text-to-image model
    content_image=content_image,   ## content image -> PIL.Image
    device="cuda",                 ## model's inference location [cpu/cuda]
)

### **Step3: Style Transfer**

In [None]:
def style_transfer(content_image, mask_style_pair, model, device="cpu"):
  """Compose the output image region by region from the mask & style image pairs.

  For each pair, the masked pixels are taken either from the content image itself (style image
  None) or from `model`'s style transfer of the content image with that pair's style image.

  Args:
    content_image: PIL RGB image to stylise.
    mask_style_pair: dict of `[mask, style_image]` values; mask is a boolean array over the
      image (assumed (H, W) matching content_image — TODO confirm), style_image a PIL image or
      None to keep the content pixels there.
    model: style transfer model, moved to `device` and called as
      `model(content, style, alpha=1, device=device)` with float images scaled to [0, 1].
      Its output is resized to the content size if the shapes differ.
    device: inference device for the model [cpu/cuda].

  Returns:
    PIL.Image.Image (uint8 RGB). Pixels covered by no mask are black.
  """
  model.to(device)

  ## content as floats in [0, 1]; computed once instead of once per pair
  content = np.array(content_image) / 255
  new_image = np.zeros_like(content, dtype=np.float32)

  ## create style transfered image
  for mask, style_image in mask_style_pair.values():

    ## not style transfer for the mask part
    if style_image is None:
      region = content

    ## style transfer for the mask part
    else:
      region = model(content, np.array(style_image)/255, alpha=1, device=device)
      if region.shape != new_image.shape:
        region = cv2.resize(region, (new_image.shape[1], new_image.shape[0]))

    ## Assign instead of accumulating: with `+=`, overlapping masks summed past 1.0 and the
    ## out-of-range uint8 cast below produced garbage pixels. Where masks overlap, the later
    ## pair now wins.
    new_image = np.where(np.asarray(mask, dtype=bool)[..., None], region, new_image)

  ## np.ndarray to PIL.Image; clip guards against model outputs outside [0, 1]
  new_image = Image.fromarray((np.clip(new_image, 0, 1) * 255).astype(np.uint8))

  empty_memory()

  return new_image

In [None]:
## Step 3: run style transfer for every (mask, style image) pair and compose the result
new_image = style_transfer(
    model=adain,                   ## style transfer model
    mask_style_pair=mask_style,    ## mask & style image pairs -> Dictionary
    content_image=content_image,   ## content image -> PIL.Image
    device="cuda",                 ## model's inference location [cpu/cuda]
)

In [None]:
## show the final style-transferred image (the PIL.Image returned by style_transfer)
display(new_image)