In [None]:
#@title Mount Drive

# Mount Google Drive so its files are reachable under /content/drive.
import google.colab.drive as drive

drive.mount(mountpoint='/content/drive')

In [None]:
#@title Install required dependencies

!pip install colormap
!pip install easydev
!pip install diffusers transformers xformers git+https://github.com/huggingface/accelerate.git
!pip install opencv-contrib-python
!pip install controlnet_aux

In [None]:
#@title Import Dependencies

from typing import List, Tuple, Dict, Union, Callable, TypedDict

from colormap import rgb2hex, hex2rgb

import os
import cv2
import random
import shutil
import ipywidgets as widgets

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image

import torch

from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler


# Mapping of category name -> list of prompt templates for that category.
# Declared as a plain Dict alias because the keys are arbitrary category names
# read from the spreadsheet, not fixed field names as a TypedDict would require.
PromptDesigns = Dict[str, List[str]]


# Mapping of category name -> generated PIL images for that category.
# Declared as a plain Dict alias because the keys are arbitrary category names,
# not fixed field names as a TypedDict would require.
CategoryMaps = Dict[str, List[Image.Image]]


# Device the pipeline is moved to before generation; assumes a CUDA GPU runtime.
device: str = 'cuda'

In [None]:
#@title Instantiate StableDiffusion 2.1

# Load the Stable Diffusion 2.1 weights in half precision.
sd2_pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16,
)
# Swap in a DPMSolverMultistepScheduler built from the current scheduler's config.
sd2_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    sd2_pipe.scheduler.config
)

In [None]:
#@title Core Functionality


def make_prompt_designs(excel_path:str) -> "PromptDesigns":
    """Read a spreadsheet of prompt templates and group them by category.

    The sheet must contain a ``category`` column and a ``prompt`` column; each
    row contributes one prompt template to its category.

    Args:
        excel_path: Path of the spreadsheet, passed to ``pandas.read_excel``.

    Returns:
        Dict mapping each category, in order of first appearance in the sheet,
        to the list of its prompts in sheet order. Rows whose category is
        missing are skipped (``groupby`` drops NaN keys by default).
    """
    df = pd.read_excel(excel_path)
    # groupby yields a sub-frame per category even when the category has a
    # single row; a label-based ``.loc`` lookup collapses that case to a Series
    # row, making ``['prompt']`` a plain string without ``.tolist()``.
    # sort=False keeps categories in order of first appearance.
    return {
        cat: group['prompt'].tolist()
        for cat, group in df.groupby('category', sort=False)
    }


def make_color_prompts(prompt_list:List[str], category:str, palette:List[str]) -> List[str]:
    """Fill the colour and category placeholders of each prompt template.

    Every ``#`` in a template is replaced, left to right, by the next colour of
    ``palette``, and every ``$`` by ``category``. When a template has more ``#``
    placeholders than the palette has colours, that template's colours are
    drawn at random (with replacement) from the palette instead. The resampling
    applies only to that template, not to later ones.

    Args:
        prompt_list: Prompt templates containing ``#`` / ``$`` placeholders.
        category: Text substituted for every ``$``.
        palette: Colour names, in placeholder order.

    Returns:
        One filled prompt per template, in the same order as ``prompt_list``.

    Raises:
        ValueError: If a template contains ``#`` but ``palette`` is empty.
    """
    cat_prompts = []
    for sample_prompt in prompt_list:
        pieces = sample_prompt.split('#')
        n_slots = len(pieces) - 1  # one colour slot between consecutive pieces
        if n_slots > len(palette):
            if not palette:
                raise ValueError(
                    f"prompt {sample_prompt!r} has '#' colour placeholders "
                    "but the palette is empty"
                )
            # Keep the resampled colours local so the caller's palette is still
            # used, positionally, for the remaining templates.
            colors = random.choices(palette, k=n_slots)
        else:
            colors = palette
        # pieces[:-1] pairs each slot with one colour; the trailing piece has none.
        filled = [piece + color for piece, color in zip(pieces[:-1], colors)]
        filled.append(pieces[-1])
        cat_prompts.append(''.join(filled).replace('$', category))
    return cat_prompts


def image_grid(imgs:List[Image.Image], rows:int, cols:int) -> Image.Image:
    """Tile PIL images into a ``rows`` x ``cols`` RGB grid image.

    Images are placed row-major: image ``i`` goes to row ``i // cols`` and
    column ``i % cols``. Every cell is sized after the first image, so the
    images are expected to share one size.

    Args:
        imgs: Exactly ``rows * cols`` PIL images (must be non-empty).
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.
    """
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid


def view_generated_imgs(gen_imgs:CategoryMaps, figsize:Tuple[int, int]=(20, 20)) -> None:
    """Display generated images, one figure per category, stacked vertically.

    Args:
        gen_imgs: Mapping of category name to its list of images.
        figsize: Size of each category's figure, in inches.
    """
    for cat, imgs in gen_imgs.items():
        if not imgs:
            continue  # nothing to draw; avoid creating an empty figure
        # A fresh figure per category: once a figure has gone through
        # plt.show() (which the inline backend closes), later plt.* calls
        # would target a different figure.
        fig, axes = plt.subplots(len(imgs), 1, figsize=figsize, squeeze=False)
        fig.suptitle(cat)
        for ax, img in zip(axes[:, 0], imgs):
            ax.imshow(img)
            ax.axis('off')
        plt.show()


def sd2_gen_images(
    palette: List[str],
    prompt_designs: PromptDesigns,
    detail_prompt: str = 'ambient lighting, extremely detailed, photorealistic',
    negative_prompt: str = '',
    gen_kwargs: Dict = {
        "guidance_scale": 10,
        "num_inference_steps": 20,
        "num_images_per_prompt": 4
    },
) -> CategoryMaps:
    """Generate images for every prompt of every category with ``sd2_pipe``.

    Each template is coloured with ``palette`` via ``make_color_prompts``, then
    ``detail_prompt`` is appended before calling the module-level ``sd2_pipe``.

    Args:
        palette: Colours substituted for the ``#`` placeholders.
        prompt_designs: Mapping of category name to prompt templates.
        detail_prompt: Text appended (after ", ") to every prompt; skipped when empty.
        negative_prompt: Negative prompt sent to the pipeline. An empty string
            selects the built-in default negative prompt.
        gen_kwargs: Extra keyword arguments forwarded to the pipeline call.
            Only read, never mutated, so the shared default dict is safe.

    Returns:
        Mapping of category name to all images generated for it.
    """
    default_negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, wooden, gradient, greek sculpture, floral design, face, text, table"
    effective_negative_prompt = negative_prompt or default_negative_prompt

    gen_imgs: CategoryMaps = {}
    for cat, prompt_list in prompt_designs.items():
        for cat_prompt in make_color_prompts(prompt_list, cat, palette):
            input_prompt = f"{cat_prompt}, {detail_prompt}" if detail_prompt else cat_prompt
            print(f"input prompt: {input_prompt}")
            output = sd2_pipe(
                input_prompt,
                negative_prompt=effective_negative_prompt,
                **gen_kwargs
            )
            gen_imgs.setdefault(cat, []).extend(output.images)

    return gen_imgs

In [None]:
#@title Load up the prompts from Excel sheet
# Spreadsheet with `category` and `prompt` columns; relative to the working
# directory, or a /content/drive/... path since Drive is mounted above.
prompts_path = 'prompts_ai_gen.xlsx' #@param {type:"string"}
prompt_designs = make_prompt_designs(prompts_path)

In [None]:
#@title Build the palette

primary = "red" #@param {type:"string"}
secondary = "metallic gold" #@param {type:"string"}
tertiary = "blue" #@param {type:"string"}
background = "white" #@param {type:"string"}

# Order matters: when a prompt has no more '#' placeholders than there are
# colours, the n-th '#' is filled with the n-th entry of this list.
user_palette: List[str] = list((primary, secondary, tertiary, background))

In [None]:
#@title Pick a category

#@markdown To choose a category, click on an option. If several are selected, only the first one is used.

# Offer every category found in the spreadsheet, with the first one preselected.
categories = list(prompt_designs)
options = widgets.SelectMultiple(
    options=categories,
    value=[categories[0]],
    description='Categories',
    disabled=False,
)
options

In [None]:
#@title Category you've picked

# The widget's value holds the selected options; only the first one is used.
if not options.value:
    raise ValueError("No category selected: select one in the widget above and re-run this cell.")
user_cat: str = options.value[0]
print(f"""You have picked "{user_cat}" category""")

In [None]:
#@title Generate Images for category

#@markdown Note: num_images is the number of images generated for each prompt, so the total number of images generated will be num_images * num_prompts (the number of prompts in the chosen category).

sd2_pipe.to(device)

# Restrict generation to the single category picked above.
custom_prompt: PromptDesigns = dict([(user_cat, prompt_designs[user_cat])])

guidance_scale = 10 #@param {type:"number"}
num_steps = 20 #@param {type:"integer"}
num_images = 1 #@param {type:"integer"}

gen_imgs_sd2 = sd2_gen_images(
    user_palette,
    custom_prompt,
    gen_kwargs=dict(
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        num_images_per_prompt=num_images,
    ),
)

In [None]:
#@title View the generated images

view_generated_imgs(gen_imgs=gen_imgs_sd2)

In [None]:
#@title Save all the generated images

#@markdown Note: all images are saved into the `output_dir` folder, which is then zipped into `{output_dir}_images.zip`

output_dir = 'outputs' #@param {type:"string"}
os.makedirs(output_dir, exist_ok=True)

for cat, imgs in gen_imgs_sd2.items():
    # Name files after the category being iterated (not the widget selection),
    # so images of different categories never overwrite each other.
    save_cat = cat.lower().replace('/', '_').replace(' ', '_')
    for i, img in enumerate(imgs):
        img.save(os.path.join(output_dir, f"gen_img_{save_cat}_{i}.png"))

# Zip the folder with the standard library instead of an interpolated shell
# command, so spaces or shell metacharacters in `output_dir` cannot break it.
# Archive entries are rooted at the folder's own name.
output_abs = os.path.abspath(output_dir)
shutil.make_archive(
    output_abs + "_images",
    "zip",
    root_dir=os.path.dirname(output_abs),
    base_dir=os.path.basename(output_abs),
)

In [None]:
#@title Download the `.zip` file onto your computer