In [1]:
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
import os
import math
import pickle
from PIL import Image
from datetime import datetime

In [2]:
from huggingface_hub import notebook_login
#notebook_login()

In [2]:
# Load Stable Diffusion v1.4 (fp16 weights) and move it to the GPU. use_auth_token=True
# relies on a Hugging Face token already stored locally (notebook_login is commented out above).
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to("cuda")

# Keep a reference to the scheduler the pipeline shipped with (PNDM), presumably so it can be swapped back.
default_scheduler = pipe.scheduler

# DDIM has so far produced better results than the default PNDM scheduler, so use it instead.
# Other configurations kept for reference (not active):
#   PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
#   LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
#   DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="squaredcos_cap_v2", clip_sample=False)  # "oil painting"
ddim_scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
)
pipe.scheduler = ddim_scheduler


In [24]:
# Root folder for generated images and their .meta files. All paths below are built by plain
# string concatenation, so a trailing separator is enforced. Set the AIGEN_BASEPATH environment
# variable to use another location without editing this line.
basepath = os.path.join(os.environ.get("AIGEN_BASEPATH", "C:/Users/Skylar/Pictures/AIGen/"), "")

def ensure_write_to(filename):
    """Create the parent directory of `filename` (recursively) if it does not exist yet.

    A bare file name with no directory component needs nothing and is accepted as-is.
    """
    directory = os.path.dirname(filename)
    if directory:
        # exist_ok avoids the check-then-create race and makes repeated calls harmless.
        os.makedirs(directory, exist_ok=True)

def save_text_to_file(text, filename):
    """Append `text` to `filename`, creating its parent directory first if needed."""
    ensure_write_to(filename)
    with open(filename, mode="a") as handle:
        handle.write(text)

def get_text_from_file(text, filename, condense = True):
    """Read a text file and return its contents with line breaks normalised.

    text: unused; kept only so existing two-argument callers keep working.
    filename: path of the file to read.
    condense: when True (default) the lines are concatenated with their line breaks
        removed, giving a single-line string (e.g. for prompt files); when False the
        lines are joined with exactly one newline each (no trailing newline).
    """
    with open(filename, "r") as fp:
        # splitlines() drops the line terminators that readlines() would keep, so the
        # joins below no longer produce doubled newlines.
        lines = fp.read().splitlines()
    return "".join(lines) if condense else "\n".join(lines)

def save_image(metadata, image, filename="auto"):
    """Save `image` as a PNG plus a pickled copy of `metadata` with the same file stem.

    Files are written to <basepath><metadata['category']>/<stem>.png and <stem>.meta.
    The stem is `filename` (or metadata['prompt'] when filename == "auto") with spaces
    turned into "_", commas and Windows-illegal characters (<>:"|?*) removed, and an
    "_HHMMSS" timestamp appended. If that PNG already exists, "1" is appended to the
    stem until the name is free, so the image and its .meta always share a stem.
    An existing .meta file is never overwritten.
    """
    folder = metadata['category']
    folder = folder + "/" if folder else ""
    current_time = datetime.now().strftime("%H%M%S")

    stem = metadata['prompt'] if filename == "auto" else filename
    stem = stem.replace(" ", "_").replace(",", "")
    # basepath is a Windows path; these characters would make the save fail there.
    stem = "".join(ch for ch in stem if ch not in '<>:"|?*')
    stem = stem + "_" + current_time

    imgfile = "%s%s%s.png" % (basepath, folder, stem)
    ensure_write_to(imgfile)
    while os.path.exists(imgfile):
        # Grow the stem, not the full path: ".png" must stay last because PIL picks
        # the output format from the file extension.
        stem = stem + "1"
        imgfile = "%s%s%s.png" % (basepath, folder, stem)
    image.save(imgfile)

    # Same directory as the PNG, which ensure_write_to already created.
    metafile = "%s%s%s.meta" % (basepath, folder, stem)
    if os.path.exists(metafile):
        print("Metadata found, not overwriting.")
    else:
        with open(metafile, "wb") as fp:
            pickle.dump(metadata, fp)

def image_grid(imgs, rows='auto', cols='auto'):
    """Tile images into a single grid image, filled row by row.

    imgs: non-empty sequence of images whose `.size` is (width, height); every cell
        uses the size of the first image.
    rows, cols: grid dimensions, or 'auto' for at most 4 columns and as many rows as
        needed to hold every image.
    Returns a new RGB image of size (cols * w, rows * h); cells beyond len(imgs) are
    left at Image.new's default fill.
    Raises ValueError for an empty `imgs`; AssertionError if the grid cannot hold them.
    """
    count = len(imgs)
    if count == 0:
        raise ValueError("image_grid needs at least one image")
    if cols == 'auto':
        cols = min(4, count)
    if rows == 'auto':
        rows = math.ceil(count / cols)

    # A partially filled last row is fine; only reject grids too small for every image.
    assert count <= rows * cols, "%i images do not fit in a %ix%i grid" % (count, rows, cols)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

def run_pipe(metadata):
    """Render one image from a metadata dict (as built by generate_metadata).

    metadata['latents'] is handed to the pipeline as its initial noise, so re-running
    the same metadata with the same scheduler should reproduce the same image.
    Returns the first image of the pipeline output.
    """
    with autocast("cuda"):
        output = pipe(
            metadata['prompt'],
            # The pipeline keyword is `latents`; a misspelt name is silently swallowed
            # by the pipeline's **kwargs and fresh random noise is used instead.
            latents=metadata['latents'],
            num_inference_steps=metadata['num_iterations'],
            width=metadata['width'],
            height=metadata['height'],
        )
    return output["sample"][0]

def run_multiple(metadata, count):
    """Render `count` variations of the prompt/settings in `metadata`.

    Each run gets freshly generated latents (the input dict is not modified), so the
    images differ from each other; use run_pipe(metadata) to re-render the stored latents.
    Returns the list of images.
    """
    images = []
    for _ in range(count):
        seed, latents = generate_latent(metadata['width'], metadata['height'])
        images.append(run_pipe(dict(metadata, seed=seed, latents=latents)))
    return images

def generate_latent(width=512, height=512, seed=None):
    """Draw the initial latent noise for one image on the CUDA device.

    width, height: output image size in pixels; the latent grid is 8x smaller per side.
    seed: optional integer seed; when None a fresh non-deterministic seed is drawn.
    Returns (seed, latents), latents shaped (1, pipe.unet.in_channels, height // 8, width // 8).
    """
    generator = torch.Generator(device="cuda")
    if seed is None:
        seed = generator.seed()
    generator = generator.manual_seed(seed)

    image_latents = torch.randn(
        (1, pipe.unet.in_channels, height // 8, width // 8),
        generator=generator,
        device="cuda",
    )
    return (seed, image_latents)

def generate_metadata(prompt, category="uncategorised", num_iterations=50, width=512, height=512, latent=None):
    """Bundle everything run_pipe needs (and save_image stores) for one image.

    latent: optional (seed, latents) tuple as returned by generate_latent; a new one is
        drawn for the given width/height when omitted.
    """
    if latent is None:
        latent = generate_latent(width, height)
    return {
        'prompt': prompt,
        'category': category,
        'num_iterations': num_iterations,
        'width': width,
        'height': height,
        'seed': latent[0],
        'latents': latent[1],
        # Label only: it is not read from pipe.scheduler, so keep it in sync by hand.
        'scheduler': "ddim"
    }

def load_metadata(filename):
    """Load metadata pickled by save_image from <basepath><filename>.meta.

    Uses pickle, which can execute arbitrary code: only load .meta files you wrote yourself.
    """
    with open("%s%s.meta" % (basepath, filename), "rb") as fp:
        metadata = pickle.load(fp)
    return metadata

def run_prompts(prompts, num_images=8, num_iterations=50, height=512, width=512, folder=""):
    """Generate and save `num_images` images for every entry of `prompts`.

    prompts: dict mapping a name to a spec dict with a 'prompt' key and an optional
        'folder' key; the name becomes the file-name stem ("<name>_<i>").
    folder: category used for specs without a 'folder' key; when empty, the name with
        spaces turned into "/" is used (e.g. "realistic forest_witch" -> "realistic/forest_witch").
    Each image gets its own metadata (and therefore its own seed and latents), so the
    .meta file saved next to an image describes that image.
    """
    for name, spec in prompts.items():
        prompt = spec['prompt']
        print(prompt)
        if 'folder' in spec:
            category = spec['folder']
        else:
            category = folder or name.replace(" ", "/")

        for i in range(num_images):
            print("%i / %i" % (i + 1, num_images))
            metadata = generate_metadata(prompt, category=category, num_iterations=num_iterations, width=width, height=height)
            image = run_pipe(metadata)
            save_image(metadata, image, filename="%s_%i" % (name, i))

def refill_folder(folder, num_images=8, num_iterations=50, height=512, width=512):
    """Generate more images for `folder` from the prompt stored in "<folder>/prompt.txt".

    `folder` is used as-is to build the prompt path (it is not joined with basepath).
    NOTE(review): confirm that is where prompt.txt is expected to live.
    Images are saved under category `folder` with the file-name stem "refill_<i>".
    """
    # get_text_from_file takes (text, filename); its first argument is unused, so pass None.
    prompt = get_text_from_file(None, "%s/prompt.txt" % folder)
    run_prompts({"refill": {"prompt": prompt, "folder": folder}}, num_images, num_iterations, height, width)

In [22]:
# Single-image test: render one 512x640 image with 60 inference steps and display it.
prompt = (
    "photo of a gorgeous cat, "
    "in the style of stefan kostic, "
    "realistic, sharp focus, 8k high definition, "
    "insanely detailed, intricate, elegant, "
    "art by stanley lau and artgerm"
)
metadata = generate_metadata(prompt=prompt, num_iterations=60, width=512, height=640)
image = run_pipe(metadata=metadata)
image

In [18]:
save_image(metadata=metadata, image=image, filename="metatest")  # image + .meta under basepath/uncategorised/

In [23]:
# Re-render from previously saved metadata.
# NOTE(review): this stem (with its HHMMSS suffix) comes from an earlier run; point it at a
# .meta file that exists under basepath before running on a fresh machine/kernel.
metadata = load_metadata(filename="uncategorised/metatest_145607")
image = run_pipe(metadata=metadata)
image

In [21]:
save_image(metadata=metadata, image=image, filename="metatest2")  # save the re-rendered image for comparison

In [None]:
# ------------------------------------------------------------------------------------------------------------------------

In [47]:
# Run a list of prompts: 20 images per prompt, 100 inference steps, 512 wide x 640 high.
run_prompts(
    {
        "realistic forest_witch": {
            "prompt": "photo of a forest witch, in the style of stefan kostic, realistic, sharp focus, 8k high definition, insanely detailed, intricate, elegant, art by stanley lau and artgerm",
        },
    },
    num_images=20,
    num_iterations=100,
    height=640,
    width=512,
)


1 / 20


0it [00:00, ?it/s]

2 / 20


0it [00:00, ?it/s]

3 / 20


0it [00:00, ?it/s]

4 / 20


0it [00:00, ?it/s]

5 / 20


0it [00:00, ?it/s]

6 / 20


0it [00:00, ?it/s]

7 / 20


0it [00:00, ?it/s]

8 / 20


0it [00:00, ?it/s]

9 / 20


0it [00:00, ?it/s]

10 / 20


0it [00:00, ?it/s]

11 / 20


0it [00:00, ?it/s]

12 / 20


0it [00:00, ?it/s]

13 / 20


0it [00:00, ?it/s]

14 / 20


0it [00:00, ?it/s]

15 / 20


0it [00:00, ?it/s]

16 / 20


0it [00:00, ?it/s]

17 / 20


0it [00:00, ?it/s]

18 / 20


0it [00:00, ?it/s]

19 / 20


0it [00:00, ?it/s]

20 / 20


0it [00:00, ?it/s]

1 / 20


0it [00:00, ?it/s]

2 / 20


0it [00:00, ?it/s]

3 / 20


0it [00:00, ?it/s]

4 / 20


0it [00:00, ?it/s]

5 / 20


0it [00:00, ?it/s]

6 / 20


0it [00:00, ?it/s]

7 / 20


0it [00:00, ?it/s]

8 / 20


0it [00:00, ?it/s]

9 / 20


0it [00:00, ?it/s]

10 / 20


0it [00:00, ?it/s]

11 / 20


0it [00:00, ?it/s]

12 / 20


0it [00:00, ?it/s]

13 / 20


0it [00:00, ?it/s]

14 / 20


0it [00:00, ?it/s]

15 / 20


0it [00:00, ?it/s]

16 / 20


0it [00:00, ?it/s]

17 / 20


0it [00:00, ?it/s]

18 / 20


0it [00:00, ?it/s]

19 / 20


0it [00:00, ?it/s]

20 / 20


0it [00:00, ?it/s]

1 / 20


0it [00:00, ?it/s]

2 / 20


0it [00:00, ?it/s]

3 / 20


0it [00:00, ?it/s]

4 / 20


0it [00:00, ?it/s]

5 / 20


0it [00:00, ?it/s]

6 / 20


0it [00:00, ?it/s]

7 / 20


0it [00:00, ?it/s]

8 / 20


0it [00:00, ?it/s]

9 / 20


0it [00:00, ?it/s]

10 / 20


0it [00:00, ?it/s]

11 / 20


0it [00:00, ?it/s]

12 / 20


0it [00:00, ?it/s]

13 / 20


0it [00:00, ?it/s]

14 / 20


0it [00:00, ?it/s]

15 / 20


0it [00:00, ?it/s]

16 / 20


0it [00:00, ?it/s]

17 / 20


0it [00:00, ?it/s]

18 / 20


0it [00:00, ?it/s]

19 / 20


0it [00:00, ?it/s]

20 / 20


0it [00:00, ?it/s]

1 / 20


0it [00:00, ?it/s]

2 / 20


0it [00:00, ?it/s]

3 / 20


0it [00:00, ?it/s]

4 / 20


0it [00:00, ?it/s]

5 / 20


0it [00:00, ?it/s]

6 / 20


0it [00:00, ?it/s]

7 / 20


0it [00:00, ?it/s]

8 / 20


0it [00:00, ?it/s]

9 / 20


0it [00:00, ?it/s]

10 / 20


0it [00:00, ?it/s]

11 / 20


0it [00:00, ?it/s]

12 / 20


0it [00:00, ?it/s]

13 / 20


0it [00:00, ?it/s]

14 / 20


0it [00:00, ?it/s]

15 / 20


0it [00:00, ?it/s]

16 / 20


0it [00:00, ?it/s]

17 / 20


0it [00:00, ?it/s]

18 / 20


0it [00:00, ?it/s]

19 / 20


0it [00:00, ?it/s]

20 / 20


0it [00:00, ?it/s]

1 / 20


0it [00:00, ?it/s]

2 / 20


0it [00:00, ?it/s]

3 / 20


0it [00:00, ?it/s]

4 / 20


0it [00:00, ?it/s]

5 / 20


0it [00:00, ?it/s]

6 / 20


0it [00:00, ?it/s]

7 / 20


0it [00:00, ?it/s]

8 / 20


0it [00:00, ?it/s]

9 / 20


0it [00:00, ?it/s]

10 / 20


0it [00:00, ?it/s]

11 / 20


0it [00:00, ?it/s]

12 / 20


0it [00:00, ?it/s]

13 / 20


0it [00:00, ?it/s]

14 / 20


0it [00:00, ?it/s]

15 / 20


0it [00:00, ?it/s]

16 / 20


0it [00:00, ?it/s]

17 / 20


0it [00:00, ?it/s]

18 / 20


0it [00:00, ?it/s]

19 / 20


0it [00:00, ?it/s]

20 / 20


0it [00:00, ?it/s]