# Inference with DreamBooth

We load the fine-tuned DreamBooth models and use them to generate images that rely on the newly learned vocabulary (the `sks` style token).

In [1]:
# Mount Google Drive at /content/drive so that files under MyDrive (the trained
# DreamBooth models and the results folder used below) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# List the working directory; after mounting it should contain drive/.
%ls

[0m[01;34mdrive[0m/  [01;34msample_data[0m/


In [5]:
# Relative paths between the diffusers DreamBooth example folder and the
# project folder on the mounted Drive (/content/drive/MyDrive/SketchToReality).
# Assumes diffusers is cloned at /content/diffusers (which is what the three
# '..' hops of db_to_drive_path imply). Both folders sit three levels below
# /content, so each direction needs exactly three '..' hops.
# NOTE: both paths are relative to the current working directory and only
# resolve when cwd is the corresponding starting folder.
db_to_drive_path = './../../../drive/MyDrive/SketchToReality/'  # dreambooth example dir -> Drive project folder (to save results)
drive_to_db_path = './../../../diffusers/examples/dreambooth'   # Drive project folder -> dreambooth example dir (to come back)

In [6]:
# Check the working directory: it should be the diffusers DreamBooth example folder (train_dreambooth*.py).
%ls

README.md              requirements.txt             train_dreambooth_flax.py
README_sdxl.md         test_dreambooth_lora_edm.py  train_dreambooth_lora.py
requirements_flax.txt  test_dreambooth_lora.py      train_dreambooth_lora_sdxl.py
requirements_sdxl.txt  test_dreambooth.py           train_dreambooth.py


# Loading the original images

We need these to retrieve the labels of the evaluation images; the labels are used to build the prompts (captions) given to our DreamBooth models.

In [1]:
import pandas as pd
from PIL import Image
def dataframe_toPILlist(df, show_img = False):
    """Load every image referenced by ``df`` together with its label.

    Args:
        df: DataFrame with a ``'path'`` column (image file paths) and a
            ``'label'`` column.
        show_img: if True, print each label and show its image inline. This
            relies on IPython's ``display`` being available (as in a notebook).

    Returns:
        ``(image_list, label_list)``: in-memory PIL images and their labels,
        in row order.
    """
    image_list = []
    label_list = []
    for path, lab in zip(df['path'], df['label']):
        # Copy the pixels into memory and close the file immediately. A bare
        # Image.open() is lazy and keeps the file handle open until the pixel
        # data is read, leaking one descriptor per image on large datasets.
        with Image.open(path) as src:
            img = src.copy()
        image_list.append(img)
        label_list.append(lab)

        if show_img:
            print(f'\nLabel {lab}')
            display(img)

    return image_list, label_list

In [5]:
# Evaluation split: each row must provide an image 'path' and its 'label'
# (the two columns dataframe_toPILlist reads). The CSV path is relative to the
# current working directory.
original_test_dataset_path = './Data/data/sketch/eval_image_captioning.csv'
eval_dataframe = pd.read_csv(original_test_dataset_path)
PILimages, labels = dataframe_toPILlist(
    df=eval_dataframe,
    show_img=False,
)

# Inference

In [8]:
from PIL import Image
import numpy as np
def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB grid, filled row by row.

    Every cell has the size of the first image; the other images are pasted at
    their own size, so all inputs are expected to share one size.

    Args:
        imgs: sequence of exactly ``rows * cols`` PIL images.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols, (
        f'expected rows*cols = {rows * cols} images, got {len(imgs)}')

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)  # row-major placement
        grid.paste(img, box=(col * w, row * h))
    return grid

In [9]:
# Check the working directory before loading the model (relative paths below depend on it).
%ls

README.md              requirements.txt             train_dreambooth_flax.py
README_sdxl.md         test_dreambooth_lora_edm.py  train_dreambooth_lora.py
requirements_flax.txt  test_dreambooth_lora.py      train_dreambooth_lora_sdxl.py
requirements_sdxl.txt  test_dreambooth.py           train_dreambooth.py


In [10]:
from torch import autocast
from diffusers import StableDiffusionPipeline
import torch

# Fine-tuned DreamBooth checkpoint stored on Google Drive (mounted at /content/drive).
# Use an absolute path: the working directory changes in this notebook (the
# %ls just above shows the diffusers dreambooth example folder), and a
# relative './drive/...' path only resolves from /content.
model_id = "/content/drive/MyDrive/SketchToReality/dreambooth_cat_results" #@param {type:"string"}
# Load the weights in float16 and move the whole pipeline to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
def generate_images(pipe, labels, resample_if_nsfw=True, num_inference_steps=50,
                    guidance_scale=7.5, max_retries=100,
                    prompt_template='A picture of {} in the style of sks'):
    """Generate one image per label with a DreamBooth Stable Diffusion pipeline.

    Args:
        pipe: callable pipeline whose result exposes ``.images`` (a list).
        labels: iterable of labels; each is inserted into ``prompt_template``.
        resample_if_nsfw: if True, regenerate totally black outputs (mean pixel
            value 0, treated as NSFW-filtered) up to ``max_retries`` times.
        num_inference_steps: denoising steps passed to ``pipe``.
        guidance_scale: classifier-free guidance scale passed to ``pipe``.
        max_retries: upper bound on regenerations per label, so the loop stops
            even if every sample keeps coming back black.
        prompt_template: ``str.format`` template with one ``{}`` for the label.

    Returns:
        List of generated images, one per label, in label order.
    """
    def _sample(prompt):
        # One image per prompt; the same settings are used for retries.
        return pipe(prompt, num_images_per_prompt=1,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale).images

    generated_images = []
    for label in labels:
        prompt = prompt_template.format(label)
        images = _sample(prompt)
        if resample_if_nsfw:
            for _ in range(max_retries):
                if np.mean(images) != 0:  # not totally black -> keep it
                    break
                images = _sample(prompt)
        # Previously this extended the list with itself, so nothing was ever returned.
        generated_images.extend(images)

    return generated_images