In [None]:
#@title # GPU info { display-mode: "form" }
# Print the output of the NVIDIA driver tool for this runtime.
!nvidia-smi

# Stable Diffusion

## Setup

In [None]:
import os
from os import path as osp
from pathlib import Path as pth
from pathlib import PurePath as ppth

# Working directories, all relative to the current working directory.
upload_root = pth("upload")
results_dir = pth("results")
download = pth("download")
converted = pth("converted")

# One sub-folder of upload_root per kind of uploaded image.
deepbooru_image_dir = upload_root / "deepbooru_image"
init_image_dir = upload_root / "init_image"
mask_image_dir = upload_root / "mask_image"

# upload_root is listed before its children because mkdir is called without
# parents=True; exist_ok=True makes re-running this cell harmless.
for path in [upload_root, results_dir, deepbooru_image_dir, init_image_dir, mask_image_dir, download, converted]:
    path.mkdir(exist_ok=True)

In [None]:
# Python dependencies. Only diffusers and ipywidgets carry a version constraint.
# NOTE(review): ftfy is requested twice (first and last pip line below).
!pip install -q timm ninja accelerate transformers ftfy safetensors
!pip install -q diffusers==0.10.2
!pip install -q gradio
!pip install -q "ipywidgets>=7,<8"
!pip install -q ftfy pathvalidate omegaconf

# System tools: the MEGAcmd .deb, the diffusers checkpoint conversion script,
# and aria2.
# NOTE(review): --no-check-certificate turns off TLS verification for the
# MEGAcmd download.
# The apt install below expects the .deb under /content/.
!sudo apt update
!wget --no-check-certificate https://mega.nz/linux/repo/xUbuntu_18.04/amd64/megacmd-xUbuntu_18.04_amd64.deb
!wget -nc "https://raw.githubusercontent.com/huggingface/diffusers/main/scripts/convert_original_stable_diffusion_to_diffusers.py"
!sudo apt install /content/megacmd-xUbuntu_18.04_amd64.deb
!sudo apt install aria2 -y

In [None]:
import math
import inspect
import warnings
from typing import List, Optional, Union
from io import BytesIO
from random import SystemRandom
import gc

import selectors
import subprocess
import sys
import shlex

import ast
import re
import itertools
import json
import pathlib
from pathvalidate import sanitize_filename
import shutil

import tarfile

import numpy as np
from tqdm.auto import tqdm

from PIL import Image
from ipywidgets import widgets

import google
from google.colab import files
from IPython import display

import torch
from torch import autocast
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
)
from diffusers import DDIMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import DiffusionPipeline

import gradio as gr

from google.colab import files

In [None]:
def bash(*args, param=None):
    """Run shell-style command strings one after another, echoing their output live.

    Each command is tokenised with ``shlex.split`` and started directly (no
    shell is involved). Its stdout is forwarded to ``sys.stdout`` and its
    stderr to ``sys.stderr`` as data arrives.

    Args:
        *args: Command strings, executed sequentially.
        param: Optional dict of extra keyword arguments for ``subprocess.Popen``.
            ``stdout``, ``stderr`` and ``bufsize`` are always overridden. The
            dict itself is left untouched.

    Returns:
        None. Exit codes are not inspected.
    """
    import codecs

    # Work on a copy: the old signature used a mutable default that was
    # modified on every call (and callers' dicts were modified too).
    popen_kwargs = dict(param or {})
    # Line buffering (bufsize=1) is not valid for binary pipes; use the default.
    popen_kwargs["bufsize"] = -1
    popen_kwargs["stdout"] = subprocess.PIPE
    popen_kwargs["stderr"] = subprocess.PIPE

    for command in args:
        process = subprocess.Popen(shlex.split(command), **popen_kwargs)

        sinks = {process.stdout: sys.stdout, process.stderr: sys.stderr}
        # Incremental decoders cope with multi-byte characters split across reads.
        decoders = {
            stream: codecs.getincrementaldecoder("utf-8")(errors="replace")
            for stream in sinks
        }

        with selectors.DefaultSelector() as sel:
            for stream in sinks:
                sel.register(stream, selectors.EVENT_READ)

            # Keep reading until BOTH pipes hit EOF, so output still pending on
            # one stream is not lost when the other one closes first.
            while sel.get_map():
                for key, _ in sel.select():
                    stream = key.fileobj
                    chunk = stream.read1()
                    if chunk:
                        print(decoders[stream].decode(chunk), end="", file=sinks[stream])
                        continue

                    sel.unregister(stream)
                    tail = decoders[stream].decode(b"", final=True)
                    if tail:
                        print(tail, end="", file=sinks[stream])

        process.stdout.close()
        process.stderr.close()
        # Reap the child so sequential commands really run one after another.
        process.wait()

class mega_dl:
    """Small wrapper around the MEGAcmd command-line client.

    Creating an instance installs MEGAcmd when its ``mega-help`` executable is
    not found on PATH; ``dl`` downloads a link with ``mega-get``.
    """

    def __init__(self):
        # Look the executable up instead of launching it: the earlier probe
        # spawned `mega-help` without waiting for it and treated every possible
        # exception (not just "command not found") as "not installed".
        if shutil.which('mega-help') is None:
            bash(
                '''sudo apt update''',
                '''wget --no-check-certificate https://mega.nz/linux/repo/xUbuntu_18.04/amd64/megacmd-xUbuntu_18.04_amd64.deb''',
                '''sudo apt install /content/megacmd-xUbuntu_18.04_amd64.deb'''
            )

    def dl(self, url, dir):
        """Download ``url`` into ``dir`` using ``mega-get``.

        Both arguments are shell-quoted so that spaces or quote characters
        survive the ``shlex.split`` performed by ``bash``.
        """
        bash('mega-get {} {}'.format(shlex.quote(str(url)), shlex.quote(str(dir))))
mega = mega_dl()

In [None]:
# OS-backed random source; read_config() uses it to draw a seed.
rand = SystemRandom()

# Default path argument of SD.init_tokenizer / SD.init_encoder.
model_path_clip = pth("openai/clip-vit-large-patch14")
# JSON file written by the settings cell and read back by read_config().
config_path = pth("conf.json")


# Define layouts
layouts = {
    # Three equal-width columns, used for the grid of generated images.
    "img_grid": widgets.Layout(
        grid_template_columns="repeat(3, 1fr)",
        display="inline-grid",
        gap="0px",
        padding="0px",
        margin="0px",
        align_content="flex-start",
    ),
    # Layout for a single image widget.
    "img": widgets.Layout(padding="0px", margin="0px", vertical_align="bottom"),
    # Output areas are capped in height and scroll when they overflow.
    "output": widgets.Layout(max_height="768px", overflow="auto")
}


def clear_dir(dir):
    """Leave ``dir`` as an empty directory, deleting whatever it held before."""
    already_present = dir.is_dir()
    if already_present:
        shutil.rmtree(dir)

    dir.mkdir()


def load_init_image(path):
    """Load an image as RGB and rescale it so that its shorter side is 512 px.

    Both sides of the target size are rounded down to a multiple of 8. The
    chosen size is printed.

    Args:
        path: Location of the image file (anything ``str()`` turns into a path).

    Returns:
        The resized ``PIL.Image.Image`` in RGB mode.

    Raises:
        Whatever ``Image.open`` raises for a missing or unreadable file. It
        never returns ``None``, so no extra check is needed.
    """
    # The context manager closes the underlying file as soon as the pixel data
    # has been converted, instead of leaving that to garbage collection.
    with Image.open(str(path)) as source:
        img = source.convert("RGB")

    # resize to fit 512
    res = np.array(img.size, np.int32)

    new_res = res * 512 // np.amin(res) // 8 * 8
    res_mult = np.mean(new_res / res)

    # use bicubic for downsampling and lanczos for upsampling
    inter = Image.LANCZOS if res_mult > 1 else Image.BICUBIC

    # Hand Pillow a tuple of plain ints rather than a numpy array.
    target_size = tuple(int(side) for side in new_res)
    img = img.resize(target_size, inter)

    print(f"{path}, resized to {target_size[0]}x{target_size[1]}")

    return img


def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row onto a new RGB canvas of ``rows`` x ``cols`` cells.

    The cell size is taken from the first image. At most ``rows * cols``
    images are accepted (checked with an assertion).
    """
    assert len(imgs) <= rows * cols

    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))

    return canvas


def download_byte_img(img_bytes, path, fmt):
    """Write ``img_bytes`` to ``path`` and hand the file to ``files.download``.

    ``fmt`` is accepted but not used.
    """
    destination = str(path)

    with open(destination, "wb") as sink:
        sink.write(img_bytes)

    files.download(destination)


def pil_to_bytes(img, fmt):
    """Serialise ``img`` in format ``fmt`` to an in-memory buffer and return the bytes.

    The image is saved with ``compress_level=1``.
    """
    with BytesIO() as sink:
        img.save(sink, format=fmt, compress_level=1)
        return sink.getvalue()


def get_printable_json_config(config):
    """Return ``config`` without its first and last line, each remaining line stripped.

    Every kept line is followed by a newline.
    """
    inner_lines = config.splitlines()[1:-1]
    return "".join(entry.strip() + "\n" for entry in inner_lines)


def defined_kwargs(**kwargs):
    """Return the keyword arguments whose value is not ``None``."""
    kept = {}
    for name, value in kwargs.items():
        if value is not None:
            kept[name] = value
    return kept


def remove_keys(dictionary, keys):
    """Return a shallow copy of ``dictionary`` with every key in ``keys`` removed.

    The input is not modified. A key that is absent raises ``KeyError``.
    """
    trimmed = dictionary.copy()

    for unwanted in keys:
        del trimmed[unwanted]

    return trimmed


def rename_key(dictionary, old_name, new_name):
    """Return a copy of ``dictionary`` where ``old_name`` is stored under ``new_name``.

    The input is not modified; the renamed entry ends up last in the copy.
    A missing ``old_name`` raises ``KeyError``.
    """
    renamed = dictionary.copy()
    value = renamed.pop(old_name)
    renamed[new_name] = value

    return renamed


def read_config():
    """Load the generation settings from ``config_path``.

    A negative ``"seed"`` is replaced by a random integer between 0 and 2**14
    (inclusive) drawn from ``rand``.
    """
    with open(str(config_path)) as handle:
        settings = json.load(handle)

    wants_random_seed = settings["seed"] < 0
    if wants_random_seed:
        settings["seed"] = rand.randint(0, 2**14)

    return settings


def generate_label(prompt, config, seed, iter):
    """Build the HTML caption shown above a generated image.

    ``prompt`` is a ``(text, variation)`` pair. With a non-empty variation the
    caption lists the variation values; otherwise it shows the prompt text,
    shortened to 200 characters. The seed is always appended, the 1-based
    iteration only when ``config["n_iter"]`` is greater than one.
    """
    text, variation = prompt[0], prompt[1]

    shown = text[:200]
    if len(text) > 200:
        shown = shown + "... "

    if variation != ():
        head = f"variation: <b>{str(list(variation))[1:-1]}</b>"
    else:
        head = f"prompt: <b>{shown}</b>"

    parts = [head, f"seed: {seed}"]
    if config["n_iter"] > 1:
        parts.append(f"iter: {iter + 1}")

    return "; ".join(parts)


def parse_prompt(prompt):
    """Expand bracketed Python lists inside a prompt into prompt variations.

    Every ``[...]`` group (without nested brackets) is parsed with
    ``ast.literal_eval``, so entries must be Python literals, e.g.
    ``a ["red", "blue"] cat``. One prompt is produced for each combination of
    entries across all groups.

    Args:
        prompt: Prompt text, optionally containing list groups.

    Returns:
        A list of ``(expanded_prompt, combination)`` tuples, where
        ``combination`` holds the entry chosen for each group (``()`` when the
        prompt has no groups).

    Raises:
        ValueError / SyntaxError: from ``ast.literal_eval`` when a group is not
            a valid literal.
    """
    pattern = re.compile(r"\[([^[\]]*)\]")
    matches = list(pattern.finditer(prompt))

    options = [ast.literal_eval(m.group(0)) for m in matches]

    # Literal text in front of each group, plus whatever follows the last one.
    # Substituting by position (instead of str.replace on the group text) keeps
    # two identical groups independent and never re-touches inserted words.
    prefixes = []
    cursor = 0
    for m in matches:
        prefixes.append(prompt[cursor:m.start()])
        cursor = m.end()
    suffix = prompt[cursor:]

    variations = []
    for comb in itertools.product(*options):
        # str() lets non-string literals such as numbers be used as entries.
        body = "".join(prefix + str(word) for prefix, word in zip(prefixes, comb))
        variations.append((body + suffix, comb))

    return variations


def tags_to_prompt(prompt, tags, excluded_tags, max_tags, tag_confidence_threshold):
    """Append auto-tagger tags to a prompt.

    Args:
        prompt: Base prompt; ``/`` is removed and underscores become spaces.
        tags: Mapping of tag name to confidence. Iteration stops at the first
            entry below the threshold, so the mapping is assumed to be ordered
            by descending confidence (TODO confirm against the tagger output).
        excluded_tags: Comma-separated substrings; a tag containing any of them
            is skipped (and reported on stdout).
        max_tags: Stop once this many tags have been collected.
        tag_confidence_threshold: Minimum confidence for a tag to be used.

    Returns:
        ``prompt`` followed by the accepted tags, joined with ``", "``.
    """
    def prep(string):
        return string.replace("_", " ").strip()

    prompt = prep(prompt.replace("/", ""))

    # Blank entries are dropped: "" is a substring of every tag, so an empty
    # exclusion field (or a trailing comma) used to exclude all tags.
    banned = [entry for entry in (prep(t) for t in excluded_tags.split(",")) if entry]

    # Read the `tags` argument itself; it used to be overwritten and a global
    # named `preds` was iterated instead.
    selected = []
    for name, confidence in tags.items():
        if confidence < tag_confidence_threshold:
            break

        name = prep(name)

        if any(entry in name for entry in banned):
            print(f"Excluded {name}")
            continue

        selected.append(name)

        if len(selected) == max_tags:
            break

    # Join only the non-empty parts so no dangling ", " is produced.
    parts = [part for part in (prompt, ", ".join(selected)) if part != ""]

    return ", ".join(parts)


def upload_img(target_path):
    """Ask for a file upload and move the first uploaded file into ``target_path``.

    ``target_path`` is emptied first so it only ever holds the latest upload.

    Args:
        target_path: ``pathlib.Path`` of the destination directory.

    Returns:
        Path of the moved file inside ``target_path``.

    Raises:
        Exception: when the upload dialog returned no file. The previous
            content of ``target_path`` is kept in that case.
    """
    uploaded = files.upload()

    # Check before clearing, so a cancelled upload does not destroy the
    # previously uploaded image.
    if not uploaded:
        raise Exception("No file was uploaded")

    # clear upload directory (also works when it does not exist yet)
    clear_dir(target_path)

    # Only the first uploaded file is used; its name is resolved relative to
    # the current working directory.
    uploaded_path = pth(next(iter(uploaded)))

    new_uploaded_path = target_path / uploaded_path.name
    uploaded_path.rename(new_uploaded_path)

    assert new_uploaded_path.is_file()

    return new_uploaded_path


def convert_to_diffusers(model_path, out_path):
    """Convert an original Stable Diffusion checkpoint to the diffusers layout.

    Runs the downloaded ``convert_original_stable_diffusion_to_diffusers.py``
    script with the ``dpm`` scheduler type on ``cuda``.

    Args:
        model_path: Checkpoint file to convert.
        out_path: Directory the converted model is written to.
    """
    # Quote both paths: `bash` tokenises the command with shlex.split, so a
    # checkpoint name containing spaces would otherwise be split into
    # several arguments.
    command = " ".join(
        [
            "python convert_original_stable_diffusion_to_diffusers.py",
            f"--checkpoint_path {shlex.quote(str(model_path))}",
            f"--dump_path {shlex.quote(str(out_path))}",
            "--scheduler_type 'dpm' --device 'cuda'",
        ]
    )
    bash(command)

In [None]:
class Deepbooru:
    """Callable tagger backed by the ``spaces/hysts/DeepDanbooru`` Gradio interface.

    The interface is loaded on the first call and reused afterwards.
    """

    def __init__(self):
        self.interface = None

    def __call__(self, path):
        """Tag the image at ``path`` and return a ``{label: confidence}`` dict."""
        interface = self.interface
        if interface is None:
            interface = self.interface = gr.Interface.load("spaces/hysts/DeepDanbooru")

        # The interface hands back the location of a JSON file with the result.
        result_file = interface(path, 0)

        with open(result_file) as handle:
            entries = json.load(handle)["confidences"]

        return dict((entry["label"], entry["confidence"]) for entry in entries)

deepbooru = Deepbooru()

In [None]:
from transformers.feature_extraction_utils import FeatureExtractionMixin

class dummy_safety_checker:
    """No-op safety checker: returns the images unchanged together with ``False``."""

    def __call__(self, images, **kwargs):
        flagged = False
        return images, flagged

class dummy_feature_extractor(FeatureExtractionMixin):   
    """Stand-in feature extractor that does no work.

    Calling the instance, reading ``.pixel_values`` and calling ``.to(...)``
    all give back the instance itself, so a chained use such as
    ``extractor(...).to(...).pixel_values.to(...)`` resolves without computing
    anything.
    """

    def dummy(self, *args, **kwargs):
        # Accepts and ignores any arguments.
        return self

    def __init__(self, *args, **kwargs):
        # The parent initialiser is not called; only the two attributes
        # needed for the chaining described above are set.
        self.pixel_values = self
        self.to = self.dummy
        pass
    
    def __call__(self, *args, **kwargs):
        return self


class SD:
    """Container for the Stable Diffusion components used by this notebook.

    The tokenizer, text encoder, VAE and UNet are loaded separately through the
    ``init_*`` methods (several UNets can be merged by weighted averaging).
    ``generate`` then assembles a diffusers pipeline from them for each call.
    """

    def __init__(
        self,
        device="cuda",
        torch_dtype=torch.float16,
    ):
        self.device = device
        self.torch_dtype = torch_dtype

        self.revision = None

    def __get_ctx(self):
        """Keyword arguments for ``.to(...)``: target device and dtype."""
        return {"device": self.device, "dtype": self.torch_dtype}

    def init_tokenizer(self, path=str(model_path_clip), **kwargs):
        """Load the CLIP tokenizer from ``path`` into ``self.clip_tokenizer``."""
        self.clip_tokenizer = CLIPTokenizer.from_pretrained(path, **kwargs)

    def init_encoder(self, path=str(model_path_clip), **kwargs):
        """Load the CLIP text encoder from ``path`` into ``self.clip_model``."""
        self.clip_model = CLIPTextModel.from_pretrained(path, **kwargs)

    def init_vae(self, path, subfolder="vae", **kwargs):
        """Load the VAE from ``path`` into ``self.vae``.

        ``subfolder`` is only forwarded when it is truthy; ``kwargs`` override
        the defaults built here.
        """
        model_args = {}
        model_args["torch_dtype"] = self.torch_dtype
        model_args["pretrained_model_name_or_path"] = path
        if subfolder:
            model_args["subfolder"] = subfolder

        model_args.update(kwargs)

        self.vae = AutoencoderKL.from_pretrained(**model_args)

    def init_unet(self, model_path, **kwargs):
        """Load one UNet from the ``unet`` subfolder of ``model_path`` onto the device."""
        print(kwargs)
        return UNet2DConditionModel.from_pretrained(
            model_path, subfolder="unet", torch_dtype=self.torch_dtype, **kwargs
        ).to(self.device)

    def __init_composed_unet(self, unets, alphas, unet_kwargs_list):
        """Load several UNets and merge them into one by weighted averaging.

        Args:
            unets: Model paths; the first one provides the module that is returned.
            alphas: One non-negative weight per UNet; normalised to sum to 1.
            unet_kwargs_list: Per-UNet keyword arguments for ``init_unet`` (or None).

        Returns:
            The first UNet, with its weights replaced by the weighted average.

        Raises:
            Exception: if the weights sum to zero or the state dicts differ in keys.
        """
        def ensure_dict_format_validity(reference_dict, candidate_dict):
            # Compare the two dicts that are actually passed in (the earlier
            # version ignored its parameters and read enclosing variables).
            if not candidate_dict.keys() == reference_dict.keys():
                raise Exception("Combined models must have the same architecture")

            for key in reference_dict:
                assert candidate_dict[key].shape == reference_dict[key].shape

        n_unets = len(unets)

        if unet_kwargs_list is None:
            unet_kwargs_list = [{}] * n_unets

        assert n_unets == len(unet_kwargs_list)
        assert n_unets == len(alphas)

        # normalize alphas with plain Python floats: no float16 rounding of the
        # ratios, and a zero sum is reported instead of producing NaN weights
        total = float(sum(alphas))
        if total == 0:
            raise Exception("Merge ratios must not sum to zero")
        weights = [float(alpha) / total for alpha in alphas]

        base_unet = self.init_unet(unets[0], **unet_kwargs_list[0])
        base_unet_dict = base_unet.state_dict()

        # Only floating point tensors can be scaled in place by a float.
        blend_keys = [
            key for key in base_unet_dict
            if torch.is_floating_point(base_unet_dict[key])
        ]

        for key in blend_keys:
            base_unet_dict[key] *= weights[0]

        for i in range(1, n_unets):
            unet_dict = self.init_unet(unets[i], **unet_kwargs_list[i]).state_dict()

            ensure_dict_format_validity(base_unet_dict, unet_dict)

            for key in blend_keys:
                base_unet_dict[key] += unet_dict[key] * weights[i]

            # Release the merged-in model before the next one is loaded.
            del unet_dict
            gc.collect()
            torch.cuda.empty_cache()

        print(base_unet.load_state_dict(base_unet_dict))

        return base_unet

    def init_unets(
        self, unets, alphas=None, unet_kwargs_list=None,
    ):
        """Load one UNet, or merge several, into ``self.unet``.

        Args:
            unets: List of model paths.
            alphas: Merge weights (default: equal weights). Ignored for a single UNet.
            unet_kwargs_list: Per-UNet keyword arguments for ``init_unet``.
        """
        n_unets = len(unets)

        if unet_kwargs_list is None:
            unet_kwargs_list = [{}] * n_unets
        if alphas is None:
            alphas = [1] * n_unets

        assert n_unets == len(unet_kwargs_list)

        if n_unets == 1:
            unet = self.init_unet(unets[0], **unet_kwargs_list[0])
        else:
            unet = self.__init_composed_unet(unets, alphas, unet_kwargs_list)

        # garbage collect
        gc.collect()
        torch.cuda.empty_cache()

        self.unet = unet

    def to_device(self):
        """Switch text encoder, VAE and UNet to eval mode and move them to the device/dtype."""
        ctx = self.__get_ctx()

        self.clip_model = self.clip_model.eval().to(**ctx)
        self.vae = self.vae.eval().to(**ctx)
        self.unet = self.unet.eval().to(**ctx)

    def initialize_scheduler(
        self, type, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    ):
        """Create a scheduler of the requested kind.

        Args:
            type: One of ``"ddim"``, ``"lms"``, ``"dpm"`` or ``"euler"``.
            beta_start, beta_end, beta_schedule: Forwarded to the scheduler.

        Raises:
            Exception: for any other ``type``.
        """
        betas = {
            "beta_start": beta_start,
            "beta_end": beta_end,
            "beta_schedule": beta_schedule,
        }

        if type == "ddim":
            return DDIMScheduler(clip_sample=False, set_alpha_to_one=False, **betas)
        if type == "lms":
            return LMSDiscreteScheduler(**betas)
        if type == "dpm":
            return DPMSolverMultistepScheduler(**betas)
        if type == "euler":
            # NOTE(review): only this scheduler is created with
            # prediction_type="v_prediction"; confirm it matches the loaded UNet.
            return EulerDiscreteScheduler(prediction_type="v_prediction", **betas)

        raise Exception(f"{type} is not a supported scheduler.")

    # Original (misspelt) method name, kept so existing callers keep working.
    intitalize_scheduler = initialize_scheduler

    def generate(self, seed, scheduler_type, **kwargs):
        """Run one generation and return the resulting images.

        The pipeline is chosen from the inputs: inpainting when ``mask_image``
        is given, img2img when only ``image`` is given, text-to-image otherwise.
        Keyword arguments whose value is ``None`` are dropped before the call.

        Args:
            seed: Seed for the random generator.
            scheduler_type: Passed to ``initialize_scheduler``.
            **kwargs: Arguments for the pipeline call (prompt, image, ...).

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        # .get(): both keys are optional. "is not None" avoids truth-testing
        # objects such as tensors. The mask is checked first because an
        # inpainting request also carries an image.
        if kwargs.get("mask_image") is not None:
            pipe_type = StableDiffusionInpaintPipeline
        elif kwargs.get("image") is not None:
            pipe_type = StableDiffusionImg2ImgPipeline
        else:
            pipe_type = StableDiffusionPipeline

        test_pipe = pipe_type(
            vae=self.vae,
            text_encoder=self.clip_model,
            tokenizer=self.clip_tokenizer,
            unet=self.unet,
            scheduler=self.initialize_scheduler(scheduler_type),
            safety_checker=dummy_safety_checker(),
            feature_extractor=dummy_feature_extractor()
        ).to(self.device)

        kwargs = {k: j for k, j in kwargs.items() if j is not None}

        # Create the generator on the configured device, not a fixed "cuda".
        generator = torch.Generator(device=self.device).manual_seed(seed)

        return test_pipe(**kwargs, generator=generator).images


GampipeSD = SD()


## App

In [None]:
#@title Authenticate { display-mode: "form" }

from huggingface_hub import notebook_login

# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably needed for gated or private model repositories; confirm.
notebook_login()


In [None]:
#@title ### Settings { display-mode: "form" }

# @markdown ### Supported models:
# @markdown <ul>
# @markdown <li>sd: runwayml/stable-diffusion-v1-5</li>
# @markdown <li>ad: Linaqruf/anything-v3.0</li>
# @markdown <li>wd: hakurei/waifu-diffusion</li>
# @markdown <li>nd: novelai-diffusion</li>
# @markdown <li>mse: stabilityai/sd-vae-ft-mse</li>
# @markdown <li>custom_model1: custom_model1</li>
# @markdown <li>custom_model2: custom_model2</li>
# @markdown </ul>

# @markdown ***

# Space-separated model names; with more than one name the UNets are merged.
unet_model_names = "custom_model2" #@param {type:"string"}
# Space-separated merge weights, one per UNet; empty means equal weights.
# The weights are only used when more than one UNet is listed.
unet_merge_ratio = "1 1" #@param {type:"string"}
# Model names (from the list above) used for the remaining components.
vae_model_name = "ad" #@param {type:"string"}
tokenizer_model_name = "ad" #@param {type:"string"}
encoder_model_name = "custom_model2" #@param {type:"string"}

# @markdown ***

custom_model1 = "eimiss/EimisAnimeDiffusion_1.0v" #@param {type:"string"}
custom_model1_branch = "main" #@param ["main", "fp16"]
custom_model2 = "magnet:?xt=urn:btih:969cabc39c8363aeec824f5530bb0749b6452621&dn=dreambooth-calli-nsfw&tr=udp%3a%2f%2ftracker.opentrackr.org%3a1337%2fannounce&tr=udp%3a%2f%2f9.rarbg.com%3a2810%2fannounce&tr=udp%3a%2f%2ftracker.openbittorrent.com%3a6969%2fannounce&tr=http%3a%2f%2ftracker.openbittorrent.com%3a80%2fannounce&tr=udp%3a%2f%2fopentracker.i2p.rocks%3a6969%2fannounce&tr=https%3a%2f%2fopentracker.i2p.rocks%3a443%2fannounce&tr=udp%3a%2f%2ftracker.torrent.eu.org%3a451%2fannounce&tr=udp%3a%2f%2fopen.stealth.si%3a80%2fannounce&tr=udp%3a%2f%2ftracker1.bt.moack.co.kr%3a80%2fannounce&tr=udp%3a%2f%2ftracker.tiny-vps.com%3a6969%2fannounce&tr=udp%3a%2f%2ftracker.pomf.se%3a80%2fannounce&tr=udp%3a%2f%2ftracker.dler.org%3a6969%2fannounce" #@param {type:"string"}
custom_model2_branch = "main" #@param ["main", "fp16"]


# Work from /content/ so the relative paths used below resolve there.
%cd /content/


# Known models. "path" is used as-is unless it starts with "https://mega.nz/"
# or "magnet:", in which case the preparation loop below downloads it first.
# "branch" is passed on as the `revision` argument when the model is loaded.
models = [
    {"name": "sd", "path": "runwayml/stable-diffusion-v1-5", "branch": "main"},
    {"name": "ad", "path": "Linaqruf/anything-v3.0", "branch": "main"},
    {"name": "wd", "path": "hakurei/waifu-diffusion", "branch": "main"},
    {"name": "nd", "path": "https://mega.nz/file/ThZi2CTJ#2Hu_glv74Q60-F-M_0AbWdCtfyL4bTZsoGfBZk9rjXk", "branch": None},
    {"name": "mse", "path": "stabilityai/sd-vae-ft-mse", "branch": "main"},
    {"name": "custom_model1", "path": custom_model1, "branch": custom_model1_branch},
    {"name": "custom_model2", "path": custom_model2, "branch": custom_model2_branch}
]


def string_to_list(string):
    """Split ``string`` on single spaces after trimming surrounding whitespace.

    An empty input yields ``[""]``; consecutive spaces yield empty entries.
    """
    trimmed = string.strip()
    return trimmed.split(" ")

def __get_model_params(model_name):
    """Return the entry of ``models`` whose ``"name"`` equals ``model_name``.

    Raises:
        Exception: when no entry matches. Without this the function returned
            ``None`` and a mistyped name only surfaced later as an unrelated
            attribute error.
    """
    for k in models:
        if k["name"] == model_name:
            return k

    known = ", ".join(entry["name"] for entry in models)
    raise Exception(f"Unknown model name '{model_name}'. Supported names: {known}")

def get_model_params(model_names):
    """Look up the ``models`` entry for every name in ``model_names``, keeping the order."""
    resolved = []
    for name in model_names:
        resolved.append(__get_model_params(name))
    return resolved


# Turn the form strings into lists; an empty ratio field means "no explicit weights".
unet_model_names = string_to_list(unet_model_names)
if (unet_merge_ratio := string_to_list(unet_merge_ratio))[0] == "":
    unet_merge_ratio = None
else:
    unet_merge_ratio = [float(w) for w in unet_merge_ratio]


# Every model referenced by any component, without duplicates.
model_names = unet_model_names + [vae_model_name, tokenizer_model_name, encoder_model_name]
model_names = list(set(model_names))

# Resolve each model to per-component paths, downloading where necessary.
n_magnet_models_used = 0
all_models = []
for k in get_model_params(model_names):
    model = k.copy()

    if k["path"].startswith("https://mega.nz/"):
        clear_dir(download)
        archive_path = pth(download, "animefull-final-pruned_unet.tar.gz")
        folder_path = pth(download, "animefull-final-pruned")

        print(os.getcwd())

        if not archive_path.is_file():
            bash(f"mega-get {k['path']} {str(download)}")

        # folder_path is a directory, so test with is_dir (is_file never matched).
        if not folder_path.is_dir():
            # NOTE(review): the archive is unpacked into the working directory,
            # while folder_path points inside `download`; confirm the archive
            # layout really produces folder_path.
            # NOTE(review): extractall trusts the member paths of a downloaded
            # archive.
            with tarfile.open(str(archive_path)) as archive:
                archive.extractall("./")

        model["path"] = {
            "unet": str(folder_path),
            "vae": str(folder_path),
            "tokenizer": str(folder_path),
            "encoder": str(folder_path),
        }

    elif k["path"].startswith("magnet:"):
        # Each magnet model reuses (and clears) the same download/converted
        # directories, so a second one would destroy the first: allow one only.
        if n_magnet_models_used >= 1:
            raise Exception("Only one magnet model can be used at a time")

        clear_dir(download)
        clear_dir(converted)

        bash(f"aria2c -d {str(download)} --seed-time=0 {k['path']}")

        model_paths = sorted(download.rglob("*.ckpt"))
        if not model_paths:
            raise Exception(f"No .ckpt file found in {download} after downloading {k['name']}")
        model_path = model_paths[0]
        if len(model_paths) > 1:
            print(f"Warning: found multiple models {model_paths}, using {model_path}")

        convert_to_diffusers(model_path, str(converted))

        model["path"] = {
            "unet": str(converted / "unet"),
            "vae": str(converted / "vae"),
            "tokenizer": str(converted),
            "encoder": str(converted),
        }

        n_magnet_models_used += 1

    else:
        # Any other path is handed to from_pretrained unchanged.
        model["path"] = {
            "unet": k["path"],
            "vae": k["path"],
            "tokenizer": k["path"],
            "encoder": k["path"],
        }

    all_models.append(model)


def get_by_model_name(model_name):
    """Return the first entry of ``all_models`` named ``model_name`` (``None`` if absent)."""
    candidates = (entry for entry in all_models if entry["name"] == model_name)
    return next(candidates, None)


# Map the form selections to their prepared entries in all_models.
unet_models = [get_by_model_name(k) for k in unet_model_names]
vae_model = get_by_model_name(vae_model_name)
tokenizer_model = get_by_model_name(tokenizer_model_name)
encoder_model = get_by_model_name(encoder_model_name)

# The "mse" VAE is loaded without a subfolder, every other VAE from "vae".
vae_subfolder = None if vae_model_name == "mse" else "vae"
# For "ad" the VAE path is replaced by a different location.
if vae_model["name"] == "ad":
    vae_model["path"]["vae"] = "ckpt/anything-v3-vae-swapped"

# Load every component; each entry's branch is passed as `revision`.
GampipeSD.init_unets([k["path"]["unet"] for k in unet_models], unet_merge_ratio, [{"revision": k["branch"]} for k in unet_models])
GampipeSD.init_vae(vae_model["path"]["vae"], vae_subfolder, revision=vae_model["branch"])
GampipeSD.init_tokenizer(tokenizer_model["path"]["tokenizer"], revision=tokenizer_model["branch"], subfolder="tokenizer")
GampipeSD.init_encoder(encoder_model["path"]["encoder"], torch_dtype=torch.float16, revision=encoder_model["branch"], subfolder="text_encoder")

# Switch the loaded models to eval mode on the configured device and dtype.
GampipeSD.to_device()

In [None]:
# @title ### Upload Autotagger Image { display-mode: "form" }

# The settings cell checks for this name when the autotagger is enabled.
deepbooru_image_path = upload_img(deepbooru_image_dir)


In [None]:
# @title ### Upload Init Image { display-mode: "form" }

# The settings cell checks for this name when an init image is requested.
init_image_path = upload_img(init_image_dir)


In [None]:
# @title # Settings { display-mode: "form" }

#@markdown ## General Settings
# Filled at the end of this cell and written to config_path as JSON.
config = {}

prompt = "sks mori_calliope"  # @param {type:"string"}
# NOTE(review): "poo quality" looks like a typo for "poor quality"; confirm.
negative_prompt = "poo quality, bad quality"  # @param {type:"string"}


# @markdown ***

aspect_ratio = "vertical" #@param ["vertical", "horizontal", "square"]
guidance_scale = 12  # @param {type:"slider", min:-50, max:50, step:1}
steps = 50  # @param {type:"slider", min:1, max:150, step:1}
# A negative seed is replaced by a random one in read_config().
seed = -1  # @param {type:"integer"}
# @markdown ***

# Stored in the config as "n_iter".
number_of_images = 3  # @param {type:"slider", min:1, max:12, step:1}

# @markdown ***
# @markdown #### Danbooru autotagger
use_auto_tagger = False #@param {type:"boolean"}
max_tags = 16  # @param {type:"slider", min:1, max:50, step:1}
# Percentage; divided by 100 before it is passed to tags_to_prompt.
confidence_thresh = 1  # @param {type:"slider", min:1, max:99, step:1}
excluded_tags = "censor, pubic, futa, penis"  # @param {type:"string"}

if use_auto_tagger:
    # Referencing the name raises NameError when the upload cell was never run.
    # Catch only that, so unrelated errors are not mislabelled.
    try:
        deepbooru_image_path
    except NameError:
        raise Exception("No autotagger image was uploaded") from None
    
    # Tag the uploaded image and append the accepted tags to the prompt.
    preds = deepbooru(str(deepbooru_image_path))
    prompt = tags_to_prompt(prompt, preds, excluded_tags, max_tags, confidence_thresh / 100)

# @markdown ***
# @markdown ### img2img settings


use_init_image = False #@param {type:"boolean"}
init_image_strength = 18  # @param {type:"slider", min:0, max:100, step:1}


# Output size: 512x512, with the longer side raised to 768 for non-square ratios.
width = height = 512
if aspect_ratio == "vertical":
  height = 768
elif aspect_ratio == "horizontal":
  width = 768


init_image = None
mask_image = None
if use_init_image:
    # Referencing the name raises NameError when the upload cell was never run.
    # Catch only that, so unrelated errors are not mislabelled.
    try:
        init_image_path
    except NameError:
        raise Exception("No init image was uploaded") from None
    
    init_image = load_init_image(init_image_path)

    display.display(init_image)

    # The slider is a percentage; the stored strength is its complement in [0, 1].
    init_image_strength = float(1 - init_image_strength / 100)
    # With an init image no explicit output size is stored in the config.
    width = height = None
else:
    init_image_strength = None


# Prompts are lower-cased and trimmed before they are stored.
prompt = prompt.lower().strip()
negative_prompt = negative_prompt.lower().strip()

config["prompt"] = prompt
config["negative_prompt"] = negative_prompt
config["guidance_scale"] = float(guidance_scale)
config["init_image_strength"] = init_image_strength
config["seed"] = seed
config["n_iter"] = number_of_images
config["steps"] = steps

config["width"] = width
config["height"] = height

# Persist the settings so the generate cell can read them back via read_config().
config_json = json.dumps(config, indent=2)

with open(str(config_path), "w") as f:
    f.write(config_json)

# Print the settings without the surrounding braces.
print(get_printable_json_config(config_json))


### Generate

In [None]:
# @title Generate { display-mode: "form" }
# Read config

# set shortcuts required by cross attention control code
# device = GampipeSD.device
# dtype = GampipeSD.torch_dtype

# clip_tokenizer = GampipeSD.clip_tokenizer
# clip_model = GampipeSD.clip_model
# clip = clip_model.text_model
# unet = GampipeSD.unet
# vae = GampipeSD.vae
# scheduler = GampipeSD.scheduler


# Load the stored settings (a negative seed is replaced by a random one).
config = read_config()
printable_config = get_printable_json_config(json.dumps(config, indent=2))

# Render the settings as an HTML paragraph, one line per entry.
print_config = "<p>"
for line in printable_config.splitlines():
    print_config += line + "<br>"
print_config += "</p>"

config_widget = widgets.HTML(
    value=print_config,
    placeholder="Config",
)


# Get prompt variations
prompts = parse_prompt(config["prompt"])


# Create outputs widgets to control order of outputs
# (children: 0 = settings, 1 = image grid, 2 = output captured during generation)
out = widgets.Output(layout=layouts["output"])
out1 = widgets.Output(layout=layouts["output"])
out2 = widgets.Output(layout=layouts["output"])
app = widgets.VBox([out, out1, out2])
display.display(app)

# Grid that receives one widget per generated image.
gb = widgets.GridBox([], layout=layouts["img_grid"])
with app.children[0]:
    display.display(config_widget)

with app.children[1]:
    display.display(gb)

# # Include button icons
# with app.children[2]:
#     display.display(
#         display.HTML(
#             """<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> """
#         )
#     )


# Generate
# PNG-encoded bytes of every generated image, in generation order.
images = []
# Counts the generated images; not read anywhere in this cell.
i = 0
for prompt in prompts:
    for iter in range(config["n_iter"]):
        # set seed
        # (base seed plus the iteration index, so iterations use consecutive seeds)
        seed = config["seed"] + iter

        # Output produced while generating is captured by the third output widget.
        with app.children[2]:
            # Only the first image returned by the pipeline is kept.
            img = GampipeSD.generate(
                seed=seed,
                prompt=prompt[0],
                negative_prompt=config["negative_prompt"],
                guidance_scale=config["guidance_scale"],
                strength=config["init_image_strength"],
                image=init_image,
                mask_image=mask_image,
                width=config["width"],
                height=config["height"],
                num_inference_steps=config["steps"],
                scheduler_type="dpm",
            )[0]

        compressed = pil_to_bytes(img, "png")
        images.append(compressed)

        label = generate_label(prompt, config, seed, iter)

        # Create widgets
        info = widgets.HBox(
            [
                widgets.HTML(
                    value=label,
                    placeholder="Label",
                ),
            ]
        )
        image_widget = widgets.VBox(
            [info, widgets.Image(value=compressed, layout=layouts["img"])]
        )

        # Assign a new tuple to the grid's children to append the new widget.
        gb.children = (*gb.children, image_widget)

        i += 1


# SwinIR

## Setup

In [None]:
# Start from a fresh clone of SwinIR.
# NOTE(review): `rm -r` reports an error when the folder does not exist yet;
# `rm -rf` would stay quiet on the first run.
!rm -r SwinIR
!git clone https://github.com/JingyunLiang/SwinIR.git
# NOTE(review): timm is also installed, unpinned, in the install cell near the
# top of the notebook.
!pip install timm

In [None]:
def is_img(fn):
    """Return True when the file name ``fn`` ends in .png, .jpg or .jpeg.

    The comparison ignores case, so names such as ``PHOTO.JPG`` are accepted
    as well.
    """
    return fn.lower().endswith((".png", ".jpg", ".jpeg"))


def get_images(path):
    """Return the names of the image files found directly inside ``path``.

    The result is built as a new list. Removing entries from the listing while
    iterating over it (as done before) skips the element after each removal,
    so some non-image files were left in the result.
    """
    return [name for name in os.listdir(path) if is_img(name)]


def force_make_dir(path):
    """Recreate ``path`` as an empty directory, dropping any existing content."""
    stale = path.is_dir()
    if stale:
        shutil.rmtree(path)

    path.mkdir(exist_ok=True)


In [None]:
import os
from os import path as osp
import glob
import shutil

# Directories used by the upscaler.
result_root = pth("results")
upload = pth("SwinIR Upload")
swin_result = result_root / "swinir_real_sr_x4_large"

# force_make_dir deletes existing content, so running this cell empties
# "results" and the upload folder. result_root precedes swin_result in the
# list because the latter lives inside it.
for path in [result_root, upload, swin_result]:
    force_make_dir(path)


## App

### Upscale

In [None]:
#@title Upscale
# When enabled, --tile 256 is added to the SwinIR command below.
tiled = True #@param {type:"boolean"}

import cv2 as cv

# Remove results of a previous run.
!rm -r $swin_result/*

# SwinIR-Large
# Both branches run the same command on the upload folder; only --tile differs.
if tiled:
  !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq "$upload" --scale 4 --large_model --tile 256
else:
  !python SwinIR/main_test_swinir.py --task real_sr --model_path experiments/pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth --folder_lq "$upload" --scale 4 --large_model
# Replace any earlier results/SwinIR_large with the fresh output folder.
shutil.rmtree('results/SwinIR_large', ignore_errors=True)
shutil.move('results/swinir_real_sr_x4_large', 'results/SwinIR_large')
# Re-save every PNG with compression level 9.
# NOTE(review): cv.imread returns None for unreadable files, which would make
# cv.imwrite fail; confirm that cannot happen here.
for path in sorted(glob.glob(os.path.join('results/SwinIR_large', '*.png'))):
  img = cv.imread(path)
  cv.imwrite(path, img, [cv.IMWRITE_PNG_COMPRESSION, 9])
  # os.rename(path, path.replace('SwinIR.png', 'SwinIR_large.png'))