<a href="https://colab.research.google.com/github/aniketmore311/ml-cookbook/blob/main/stable_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
!pip install diffusers --quiet
!pip install transformers accelerate scipy safetensors ftfy --quiet

In [20]:
!pip install basicsr facexlib --quiet

In [21]:
!pip install gfpgan realesrgan --quiet

In [22]:
# Notebook-wide UI state shared by the widget callbacks.
diffusion_pipeline = None   # loaded StableDiffusionPipeline (None until a model is loaded)
running_model_name = None   # model id that diffusion_pipeline was loaded from
imageWidget = None          # ipywidgets.Image that shows the current result
showing_image_name = None   # file name of the image currently displayed
status_message = ""         # text shown in the status line
imageHeight = 512           # height (px) applied to the image widget by update()
imageWidth = 512            # width (px) applied to the image widget by update()
# Common short:long side pairs and their aspect ratios:
# 512:512  -> 1:1
# 512:896  -> 4:7
# 512:912  -> ~9:16
# 512:1024 -> 1:2  (= 9:18)
# 512:1136 -> ~9:20

In [23]:
# function to generate images
def generate_image(image_prompt, suffix, negative_prompt, height_in_px, width_in_px,
                   inference_steps, guidance_scale, seed=None, pipeline=None):
  """Run the Stable Diffusion pipeline once and return the first generated image.

  Args:
    image_prompt: main text prompt.
    suffix: extra style keywords appended to the prompt (may be empty).
    negative_prompt: text describing what the image should avoid.
    height_in_px, width_in_px: output size in pixels.
    inference_steps: number of denoising steps.
    guidance_scale: classifier-free guidance strength.
    seed: RNG seed; a random one in [1, 1000000] is drawn when None. Pass the
      printed seed back in to reproduce an image.
    pipeline: pipeline to call; defaults to the module-level `diffusion_pipeline`.

  Returns:
    The first element of the pipeline result's `.images`.

  Raises:
    RuntimeError: if no pipeline has been loaded yet.
  """
  if pipeline is None:
    pipeline = diffusion_pipeline
  if pipeline is None:
    raise RuntimeError("no diffusion pipeline loaded; create one before generating")

  if seed is None:
    seed = random.randint(1, 1000000)
  print("generating image...")
  # print the seed up front so it is not lost if generation fails
  print("seed: " + str(seed))

  # join only non-empty parts so an empty suffix does not leave a dangling space
  prompt = " ".join(part for part in (image_prompt.strip(), suffix.strip()) if part)
  # a dedicated generator avoids reseeding torch's global RNG as a side effect
  generator = torch.Generator().manual_seed(seed)
  result = pipeline(prompt=prompt, generator=generator, negative_prompt=negative_prompt,
                    height=height_in_px, width=width_in_px,
                    num_inference_steps=inference_steps, guidance_scale=guidance_scale)
  return result.images[0]

In [24]:
# function to restore image
def restore(input_path, output_path, weight=0.5, only_center_face=False):
  """Restore faces in an image with GFPGAN v1.3 and write the result to disk.

  Args:
    input_path: image file to read (loaded with cv2 as a colour image).
    output_path: file the restored image is written to.
    weight: blend weight forwarded to `GFPGANer.enhance` (default 0.5, the
      value that was previously hard-coded).
    only_center_face: forwarded to `GFPGANer.enhance`.

  Raises:
    FileNotFoundError: if `input_path` cannot be read as an image.
    IOError: if the result cannot be written to `output_path`.
  """
  # The release URL is passed straight through as model_path (as before); the
  # local-file alternative would be "./GFPGANv1.3.pth".
  model_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'

  # read first so a bad path fails before the model is loaded
  input_img = cv2.imread(input_path, cv2.IMREAD_COLOR)
  if input_img is None:
    raise FileNotFoundError(f"could not read image: {input_path}")

  restorer = GFPGANer(
      model_path=model_url,
      upscale=1,            # keep the original resolution; upscaling is a separate step
      arch="clean",
      channel_multiplier=2,
      bg_upsampler=None,    # no background upsampler
  )

  # restore faces and paste them back into the full image
  _, _, restored_img = restorer.enhance(
      input_img,
      has_aligned=False,
      only_center_face=only_center_face,
      paste_back=True,
      weight=weight)
  if restored_img is None or not cv2.imwrite(output_path, restored_img):
    raise IOError(f"could not write restored image to: {output_path}")

In [25]:
# function to upscale image
def upscale(input_path, output_path, scale):
    """Upscale an image with Real-ESRGAN (RealESRGAN_x4plus) and write it to disk.

    Args:
        input_path: image file to read (cv2.IMREAD_UNCHANGED, so an alpha
            channel is kept if present).
        output_path: file the upscaled image is written to.
        scale: final scale factor, passed to `RealESRGANer.enhance` as
            `outscale`; the network itself is the fixed 4x model.

    Raises:
        FileNotFoundError: if `input_path` cannot be read as an image.
        IOError: if the result cannot be written to `output_path`.

    Note: the anime variant (RealESRGAN_x4plus_anime_6B) would use num_block=6
    and the v0.2.2.4 release URL instead.
    """
    model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
    netscale = 4
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=netscale)

    # read first so a bad path fails before any download / model load
    img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"could not read image: {input_path}")

    # Cache weights in ./weights next to the notebook rather than in the parent
    # of the working directory, which may not be writable.
    model_path = load_file_from_url(url=model_url, model_dir=os.path.join(os.getcwd(), 'weights'),
                                    progress=True, file_name=None)

    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=None,
        model=model,
        tile=0,          # 0 = process the whole image at once (no tiling)
        tile_pad=10,
        pre_pad=0,
        # fp16 is meant for CUDA; use fp32 when no GPU is available
        half=torch.cuda.is_available(),
        gpu_id=None)

    output, _ = upsampler.enhance(img, outscale=scale)
    if not cv2.imwrite(output_path, output):
        raise IOError(f"could not write upscaled image to: {output_path}")

In [26]:
from IPython.display import Image, clear_output, HTML, FileLink
import ipywidgets as widgets
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
import torch
import requests
import random
import os
from gfpgan import GFPGANer
import cv2
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from PIL import Image as PILImage

def createIntegerWidgetGroup(label_text,value):
    """Build a labelled integer input.

    Returns `(group, field)`: a VBox stacking a Label over an IntText, plus the
    IntText itself so callers can read `field.value`.
    """
    caption = widgets.Label(label_text)
    field = widgets.IntText(value=value)
    return widgets.VBox([caption, field]), field

def createFloatWidgetGroup(label_text,value):
    """Build a labelled float input.

    Returns `(group, field)`: a VBox stacking a Label over a FloatText, plus the
    FloatText itself so callers can read `field.value`.
    """
    caption = widgets.Label(label_text)
    field = widgets.FloatText(value=value)
    return widgets.VBox([caption, field]), field

def createTextAreaWidgetGroup(label_text, height_str, width_str, value):
    """Build a labelled multi-line text box.

    `height_str` / `width_str` are CSS sizes (e.g. "75px"). Returns
    `(group, field)`: a VBox stacking a Label over the Textarea, plus the
    Textarea itself so callers can read `field.value`.
    """
    caption = widgets.Label(label_text)
    field = widgets.Textarea(
        value=value,
        layout=widgets.Layout(height=height_str, width=width_str),
    )
    return widgets.VBox([caption, field]), field

# numeric input widgets for the generation / upscale settings
heightGrp, height = createIntegerWidgetGroup("height",512)
widthGrp, width = createIntegerWidgetGroup("width",512)
stepsGrp, steps = createIntegerWidgetGroup("steps",20)
guidanceGrp, guidance = createFloatWidgetGroup("guidance",7.5)
upscaleGrp, upscalefactor = createFloatWidgetGroup("upscale",2)

# dropdown widget: each option is a model id passed to StableDiffusionPipeline.from_pretrained
options = ["XpucT/Deliberate","SG161222/Realistic_Vision_V1.3","Lykon/DreamShaper","gsdf/Counterfeit-V2.5","andite/pastel-mix","andite/anything-v4.0","Anashel/rpg","darkstorm2150/Protogen_x3.4_Official_Release","dreamlike-art/dreamlike-photoreal-2.0", "stabilityai/stable-diffusion-2-1-base"]
modelLabelWidget = widgets.Label("Select the model")
# default to the first option so the default can never drift out of the list
modelWidget = widgets.Dropdown(options=options, value=options[0])
modelWidgetGroup = widgets.VBox([modelLabelWidget, modelWidget])

# text input widgets (prompt, style suffix appended to it, negative prompt)
promptGrp, prompt = createTextAreaWidgetGroup("write prompt", "75px","500px","photo of a white lamborghini on road, night, cyberpunk, city, cityscape, dreamy")
suffixGrp, suffix = createTextAreaWidgetGroup("suffix", "75px","500px","professional, 8k, HDR, highly detailed, high resolution, cinematic, realistic, intricate, masterpiece, award winning, trending on artstation")
negativePromptGrp, negativePrompt = createTextAreaWidgetGroup("negative prompt", "75px","500px","deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, low quality")

# button widgets
button = widgets.Button(description="generate")
restoreButton = widgets.Button(description="restore face")
upscaleButton = widgets.Button(description="upscale")
clearOpButton = widgets.Button(description="clear output")

# read-only status line (update() later rewrites it as "status: ...")
messageWidget = widgets.Text(value=f"{status_message}", disabled=True)

# read-only line naming the image file currently displayed
imageNameWidget = widgets.Text(value=f"showing: {showing_image_name}", disabled=True)

# image widget, initially sized to the requested generation size
imageWidget = widgets.Image(format="png", height=height.value, width=width.value)

# formatting widgets: assemble the individual widgets into the final layout
promptCol = widgets.VBox([modelWidgetGroup,promptGrp, suffixGrp, negativePromptGrp])
dimentionGroup = widgets.HBox([heightGrp,widthGrp, stepsGrp, guidanceGrp])  # generation settings row
btnGrp = widgets.HBox([button, restoreButton])
upscaleBtnGrp = widgets.HBox([upscaleGrp,upscaleButton])
clearOpBtnGrp = widgets.HBox([clearOpButton])
# every control row is placed through its HBox group (clearOpBtnGrp was otherwise unused)
controlsCol = widgets.VBox([dimentionGroup,btnGrp,upscaleBtnGrp,clearOpBtnGrp])
controlsRow = widgets.HBox([promptCol,controlsCol])
imageRow = widgets.HBox([imageWidget])
# top to bottom: image, file name, status, then prompts + controls side by side
containerVBox = widgets.VBox([imageRow,imageNameWidget,messageWidget,controlsRow])

# styling: margins first, then the narrow numeric inputs
for styled_widget, css_margin in (
    (button, "20px 5px 0 0"),
    (restoreButton, "20px 5px 0 0"),
    (upscaleButton, "35px 0 0 0"),
    (promptCol, "0px 20px 0px 0"),
    (clearOpButton, "20px 5px 0 0"),
):
    styled_widget.layout.margin = css_margin
for numeric_input in (height, width, steps, guidance, upscalefactor):
    numeric_input.layout.width = "100px"
upscalefactor.layout.margin = "0 10px 0 0"
# keep the notebook namespace free of loop variables
del styled_widget, css_margin, numeric_input

# handler runs on restore
def onRestore(btn):
  """Button callback: run face restoration on the shown image and display the result.

  The restored image is written to "image_rest.png" and becomes the shown image.
  On failure the error is reported in the status line and then re-raised.
  """
  global showing_image_name
  global status_message

  status_message = "restoring image..."
  update()

  try:
    restore(showing_image_name,"image_rest.png")
  except Exception as err:
    # otherwise the status would stay stuck on "restoring image..."
    status_message = f"image restoration failed: {err}"
    update()
    raise

  showing_image_name = "image_rest.png"
  status_message = "image restoration complete"
  update()

#handler runs on upscale
def onUpScale(btn):
  """Button callback: upscale the shown image by the factor in the "upscale" box.

  The result is written to "image_upscale.png", becomes the shown image, and the
  image widget is resized to its real pixel size. Errors are reported in the
  status line and then re-raised.
  """
  global showing_image_name
  global status_message
  global imageHeight
  global imageWidth

  factor = upscalefactor.value
  if factor <= 0:
    status_message = f"upscale factor must be positive, got {factor}"
    update()
    return

  status_message = "upscaling image please wait..."
  update()

  try:
    upscale(showing_image_name,"image_upscale.png",factor)
    # context manager closes the file handle once the size has been read
    with PILImage.open("image_upscale.png") as upscaled:
      w, h = upscaled.size
  except Exception as err:
    status_message = f"upscaling failed: {err}"
    update()
    raise

  imageHeight = h
  imageWidth = w
  showing_image_name = "image_upscale.png"
  status_message = "upscaling complete"
  update()

def _load_pipeline(model_name):
    """Load `model_name` as an fp16 StableDiffusionPipeline on CUDA.

    The scheduler is replaced by an EulerDiscreteScheduler built from the model's
    own scheduler config, and the safety checker is disabled.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
    pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline.safety_checker = None
    pipeline.to("cuda")
    return pipeline

def onGenerate(btn):
    """Button callback: generate an image with the selected model and show it.

    Loads a new pipeline when none is loaded or a different model is selected.
    The image is saved to "image.png" and the widget is sized to the image's
    real dimensions. Errors are reported in the status line and re-raised.
    """
    import gc

    global diffusion_pipeline
    global running_model_name
    global showing_image_name
    global status_message
    global imageHeight
    global imageWidth

    status_message = "generating image please wait..."
    update()

    model_name = modelWidget.value
    try:
      if diffusion_pipeline is None or running_model_name != model_name:
        status_message = "creating new pipeline wait a few moments..."
        render()
        # drop the old pipeline first so its GPU memory can be reclaimed before loading the next one
        diffusion_pipeline = None
        running_model_name = None
        gc.collect()
        torch.cuda.empty_cache()
        diffusion_pipeline = _load_pipeline(model_name)
        # only record the model once the pipeline is fully loaded onto the GPU
        running_model_name = model_name
        status_message = "pipeline created generating image..."
        update()
      image = generate_image(prompt.value, suffix.value, negativePrompt.value, height.value, width.value, steps.value, guidance.value)
      image.save("image.png")
    except Exception as err:
      status_message = f"image generation failed: {err}"
      update()
      raise

    # PIL reports (width, height); use the real size rather than re-reading the widgets
    imageWidth, imageHeight = image.size
    showing_image_name = "image.png"
    status_message = "image generation completed"
    update()

def onClearOp(btn):
  """Button callback for "clear output": wipe stray cell output by redrawing the UI."""
  # the clicked button itself is not needed; a full redraw is the whole job
  return render()

# wire each button to its click handler
for clicked_button, click_handler in (
    (button, onGenerate),
    (restoreButton, onRestore),
    (upscaleButton, onUpScale),
    (clearOpButton, onClearOp),
):
  clicked_button.on_click(click_handler)
del clicked_button, click_handler

def update():
  """Push the current status message and displayed image into the widgets.

  Reads the module-level `status_message`, `showing_image_name`, `imageHeight`
  and `imageWidth` (read-only, so no `global` statement is needed). When no
  image exists yet only the text widgets are refreshed.
  """
  messageWidget.value = f"status: {status_message}"
  imageNameWidget.value = f"showing: {showing_image_name}"
  if showing_image_name is None:
    # nothing to show yet; avoid open(None)
    return
  # context manager so the file handle is closed after reading
  with open(showing_image_name, "rb") as image_file:
    imageWidget.value = image_file.read()
  imageWidget.height = imageHeight
  imageWidget.width = imageWidth


def render():
  """Clear this cell's output, redisplay the whole UI, then refresh its contents."""
  # wait=True: the old output is cleared only when new output arrives to replace it
  clear_output(wait=True)
  display(containerVBox)
  update()


In [27]:
# generate the first image, then draw the UI around it
print("\ncreating new pipeline wait a few moments...")
model_name = modelWidget.value
pipeline = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
pipeline.safety_checker = None
pipeline.to("cuda")
# publish the pipeline only once it is fully loaded onto the GPU, so onGenerate
# never reuses a half-initialised pipeline after a failure here
diffusion_pipeline = pipeline
running_model_name = model_name
print("\npipeline created")

image = generate_image(prompt.value, suffix.value, negativePrompt.value, height.value, width.value, steps.value, guidance.value)
image.save("image.png")
showing_image_name = "image.png"
# PIL reports (width, height); keep the widget size in sync with the real image
imageWidth, imageHeight = image.size

# render() -> update() loads "image.png" into the image widget
render()

VBox(children=(HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02\x00\x00\x00\x02\x0…

generating image...


  0%|          | 0/20 [00:00<?, ?it/s]

seed: 918518
