### Check type of GPU and VRAM available.

In [None]:
# Print the GPU model plus its total and free memory (VRAM) as CSV without a header row
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

## Install Requirements

# Clone repos we need
!git clone https://github.com/Jan-Oliver/profaile-pic-dev.git
!git clone https://github.com/huggingface/diffusers.git

In [None]:
# Install all dependencies by running the repo's setup scripts from the repo
# root; each script is made executable first.
# DreamBooth training environment
!cd /content/profaile-pic-dev && chmod +x dreambooth/setup_dreambooth_env.sh
!cd /content/profaile-pic-dev && dreambooth/setup_dreambooth_env.sh
# Inference environment
!cd /content/profaile-pic-dev && chmod +x inference/setup_inference_env.sh
!cd /content/profaile-pic-dev && inference/setup_inference_env.sh

### Connect to GDrive

In [None]:
# Mount Google Drive at /content/drive so the trained weights can be copied there later.
from google.colab import drive
drive.mount('/content/drive')

### Just leave everything as it is

In [None]:
# Store model weights here
OUTPUT_DIR = "/content/stable_diffusion_weights/train"
# We will use prior preservation
CLASS_IMAGES_DIR = None
# 2.1 with 512x512 resolution 
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"
# Use Floating Point 16 -> Reduce VRAM of GPU
PRECISION = "fp16"
# Make folders
!mkdir -p $INSTANCE_IMAGES_DIR_DRIVE
!mkdir -p $OUTPUT_DIR

: 

### Adapt this:
- YOUR_NAME_ABBREVIATION: Use the first letter of your first name followed by your last name (e.g. `jseidenfuss`).
- MAN_OR_WOMAN: Set this to `man` if you are a man, or to `woman` if you are a woman.


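**Why do we need this?**
- This follows the DreamBooth paper, which you can find [here](https://dreambooth.github.io). The model learns your appearance under a unique identifier (your name abbreviation) combined with its class noun, e.g. `photo of jseidenfuss man`. The plain class prompt (e.g. `photo of a man`) is used to generate regularization images for prior preservation. These keep the model from forgetting what a generic man or woman looks like.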

In [None]:
# Unique identifier for you: first-name initial followed by last name
YOUR_NAME_ABBREVIATION = "jseidenfuss"
# Class noun for the subject: "man" or "woman"
MAN_OR_WOMAN = "man"

In [None]:
# A single DreamBooth concept. The instance prompt pairs your unique identifier
# with its class noun. The class prompt describes a generic member of the class;
# its images (class_data_dir) are the regularization images.
import json
import os

concepts_list = [
    dict(
        instance_prompt=f"photo of {YOUR_NAME_ABBREVIATION} {MAN_OR_WOMAN}",
        class_prompt=f"photo of a {MAN_OR_WOMAN}",
        instance_data_dir="/content/data/instance_images",
        class_data_dir="/content/data/class_images",
    )
]

# Make sure every data folder exists before images are uploaded or generated.
for concept in concepts_list:
    for dir_key in ("instance_data_dir", "class_data_dir"):
        os.makedirs(concept[dir_key], exist_ok=True)

# Written for the training script, which receives it via --concepts_list.
with open("concepts_list.json", "w") as fh:
    json.dump(concepts_list, fh, indent=4)

### Now upload 10 - 15 images into /content/data/instance_images
- They all have to be cropped to 512x512.
- To do that, use [this mass cropping tool](https://www.birme.net).
- Make sure only you appear in the images, and that they show a good variety of poses, expressions, etc.

**Specify the number of images you uploaded**

In [None]:
# Number of instance images you uploaded; NUM_CLASS_IMAGES, MAX_NUM_STEPS and
# LR_WARMUP_STEPS are derived from it.
N_IMAGES = 10

### Memory consumption

Use the table below to choose the best flags based on your memory and speed requirements. Tested on Tesla T4 GPU.


| `fp16` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | GB VRAM usage | Speed (it/s) |
| ---- | ------------------ | ----------------------------- | ----------------------- | --------------- | ---------- | ------------ |
| fp16 | 1                  | 1                             | TRUE                    | TRUE            | 9.92       | 0.93         |
| no   | 1                  | 1                             | TRUE                    | TRUE            | 10.08      | 0.42         |
| fp16 | 2                  | 1                             | TRUE                    | TRUE            | 10.4       | 0.66         |
| fp16 | 1                  | 1                             | FALSE                   | TRUE            | 11.17      | 1.14         |
| no   | 1                  | 1                             | FALSE                   | TRUE            | 11.17      | 0.49         |
| fp16 | 1                  | 2                             | TRUE                    | TRUE            | 11.56      | 1            |
| fp16 | 2                  | 1                             | FALSE                   | TRUE            | 13.67      | 0.82         |
| fp16 | 1                  | 2                             | FALSE                   | TRUE            | 13.7       | 0.83          |
| fp16 | 1                  | 1                             | TRUE                    | FALSE           | 15.79      | 0.77         |


Add the `--gradient_checkpointing` flag for around 9.92 GB VRAM usage.

Remove the `--use_8bit_adam` flag to use the full-precision optimizer. This requires 15.79 GB with `--gradient_checkpointing`, otherwise 17.8 GB.

Remove the `--train_text_encoder` flag to reduce memory usage further. Note that this degrades output quality.

### Launch the training

In [None]:
## Derived training schedule -- don't change
# 12 prior-preservation class images per uploaded instance image
NUM_CLASS_IMAGES = 12 * N_IMAGES
# 80 training steps per uploaded instance image
MAX_NUM_STEPS = 80 * N_IMAGES
# Learning-rate warm-up over the first tenth of training
LR_WARMUP_STEPS = int(MAX_NUM_STEPS / 10)

In [None]:
!accelerate launch /content/profaile-pic-dev/dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision=$PRECISION \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="polynomial" \
  --lr_warmup_steps=$LR_WARMUP_STEPS \
  --num_class_images=$NUM_CLASS_IMAGES \
  --sample_batch_size=4 \
  --max_train_steps=$MAX_NUM_STEPS \
  --save_interval=4000 \
  --save_min_steps=4000 \
  --save_sample_prompt= f"medium shot side profile portrait photo of the {YOUR_NAME_ABBREVIATION} {MAN_OR_WOMAN} warrior chief, tribal panther make up, blue on red, looking away, serious eyes, 50mm portrait, photography, hard rim lighting photography –ar 2:3 –beta –upbeta" \
  --save_sample_negative_prompt="blender, ugly, multiple hands, bad anatomy, bad proportions, unrealistic, full body, cropped, lowres, poorly drawn face, out of frame, poorly drawn hands, double, blurred, disfigured, deformed, repetitive, black and white" \
  --n_save_sample=4 \
  --save_guidance_scale=7.5 \
  --save_infer_steps=50 \
  --concepts_list="concepts_list.json" \
  --wandb_group_name=$GROUP_NAME \
  --wandb_project_name=$PROJECT_NAME  \
  --use_8bit_adam

# Reduce the `--save_interval` to lower than `--max_train_steps` to save weights from intermediate steps.
# `--save_sample_prompt` can be same as `--instance_prompt` to generate intermediate samples (saved along with weights in samples directory).

import os

# Weights directory to use for preview/inference. Leave blank ("") to pick the
# latest checkpoint, i.e. the numbered step folder in OUTPUT_DIR with the highest
# step. Or set a specific one, e.g. os.path.join(OUTPUT_DIR, "2250").
WEIGHTS_DIR = ""
if not WEIGHTS_DIR:
    _step_dirs = [
        d for d in os.listdir(OUTPUT_DIR)
        if d.isdigit() and os.path.isdir(os.path.join(OUTPUT_DIR, d))
    ]
    if not _step_dirs:
        raise FileNotFoundError(f"No numbered checkpoint folders found in {OUTPUT_DIR}")
    # Compare numerically so that e.g. "1000" ranks above "800".
    WEIGHTS_DIR = os.path.join(OUTPUT_DIR, max(_step_dirs, key=int))
print(f"[*] WEIGHTS_DIR={WEIGHTS_DIR}")

### Preview the results

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Preview grid of the sample images saved with each checkpoint: one row per
# checkpoint folder in OUTPUT_DIR, one column per sample image. Saved to grid.png.
weights_folder = OUTPUT_DIR

# Checkpoint folders are named by step number. Skip "0" and anything that is not a
# numbered folder with a samples/ sub-folder (int() would fail on e.g. "logs").
folders = sorted(
    (
        f for f in os.listdir(weights_folder)
        if f != "0" and f.isdigit()
        and os.path.isdir(os.path.join(weights_folder, f, "samples"))
    ),
    key=int,
)
if not folders:
    raise FileNotFoundError(f"No checkpoint folders with a 'samples' sub-folder found in {weights_folder}")

# Sorted file names, so the same sample index lands in the same column for every checkpoint.
sample_files = [sorted(os.listdir(os.path.join(weights_folder, f, "samples"))) for f in folders]

row = len(folders)
# The widest row decides the column count, so checkpoints with more samples still fit.
col = max(len(files) for files in sample_files)
if col == 0:
    raise FileNotFoundError(f"The 'samples' folders in {weights_folder} contain no images")
scale = 4
# squeeze=False keeps `axes` 2-D even for a single row and/or column,
# so axes[i, j] is always valid.
fig, axes = plt.subplots(row, col, figsize=(col*scale, row*scale),
                         gridspec_kw={'hspace': 0, 'wspace': 0}, squeeze=False)

for i, (folder, images) in enumerate(zip(folders, sample_files)):
    image_folder = os.path.join(weights_folder, folder, "samples")
    for j in range(col):
        currAxes = axes[i, j]
        currAxes.axis('off')
        if i == 0:
            currAxes.set_title(f"Image {j}")
        if j == 0:
            # Label each row with its checkpoint folder (step number).
            currAxes.text(-0.1, 0.5, folder, rotation=0, va='center', ha='center', transform=currAxes.transAxes)
        if j < len(images):
            img = mpimg.imread(os.path.join(image_folder, images[j]))
            currAxes.imshow(img, cmap='gray')

plt.tight_layout()
plt.savefig('grid.png', dpi=72)

### Push model to Google Drive so you can use it later

- Set `PATH_TO_COPY_MODEL_TO_GDRIVE` to a folder on your mounted Google Drive (under `/content/drive`) that does not exist yet.

In [None]:
import shutil

# Destination on the mounted Google Drive, e.g.
# "/content/drive/MyDrive/profaile-pic/jseidenfuss". shutil.copytree creates this
# folder itself and fails if it already exists, so choose a new path.
PATH_TO_COPY_MODEL_TO_GDRIVE = ""
if not PATH_TO_COPY_MODEL_TO_GDRIVE:
    raise ValueError("Set PATH_TO_COPY_MODEL_TO_GDRIVE to a folder on your mounted Google Drive first.")
# Copy the selected checkpoint folder (weights plus its samples/ folder) to Drive.
shutil.copytree(WEIGHTS_DIR, PATH_TO_COPY_MODEL_TO_GDRIVE)

## Part 2 Inference

### Util stuff

In [None]:
# Util functions
from PIL import Image
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, DDIMScheduler, EulerDiscreteScheduler
from IPython.display import display

# Load the inference pipeline in half precision only when training used fp16.
fp16 = PRECISION == "fp16"

def image_grid(imgs, rows, cols):
    """Tile equally sized PIL images into one RGB image of `rows` x `cols` cells.

    Images are placed row by row (left to right, then top to bottom). The
    cell size is taken from the first image.

    Args:
        imgs: sequence of PIL images; must contain exactly rows * cols items.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB PIL image of size (cols * width, rows * height).

    Raises:
        AssertionError: if len(imgs) != rows * cols.
    """
    assert len(imgs) == rows*cols

    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols*cell_w, rows*cell_h))
    for idx, tile in enumerate(imgs):
        grid_row, grid_col = divmod(idx, cols)
        canvas.paste(tile, box=(grid_col*cell_w, grid_row*cell_h))
    return canvas

# Folder with the fine-tuned weights to load. To use a model you saved to
# Google Drive earlier, replace this with the full path of that model on Drive.
model_path = WEIGHTS_DIR

# Schedulers built from the base model's scheduler config. Only the Euler one is
# plugged into the pipeline; scheduler_ddim is loaded but not used in this notebook
# (presumably kept so it can be swapped in).
scheduler_euler = EulerDiscreteScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")
scheduler_ddim = DDIMScheduler.from_pretrained(MODEL_NAME, subfolder="scheduler")

# One load call; only the dtype depends on the training precision. Safety checker disabled.
torch_dtype = torch.float16 if fp16 else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler_euler,
    safety_checker=None,
    torch_dtype=torch_dtype,
).to("cuda")

# CUDA random generator with a fixed seed for the generation cell.
g_cuda = torch.Generator(device='cuda').manual_seed(52362)

# Sampling settings used by the generation cell.
num_samples = 10
guidance_scale = 7
num_inference_steps = 70
height = 512
width = 512

### Play around with these parameters.
- Note: If you want to appear in the image, add your abbreviation and gender (e.g. `jseidenfuss man`) to the positive prompt.
- Run the next two cells to generate the images.

In [None]:
# Prompts for the generation cell. Include "{YOUR_NAME_ABBREVIATION} {MAN_OR_WOMAN}"
# in `prompt` to appear in the images; `negative_prompt` is passed to the pipeline
# as its negative prompt (terms to steer generation away from).
# NOTE(review): examples 2 and 3 are currently identical copies of example 1 --
# replace them with real alternatives or remove them.

# Example prompt 1
prompt = f"photo of face and shoulders of {YOUR_NAME_ABBREVIATION} {MAN_OR_WOMAN} in a blue suit from the front, front view, closeup, centered frame, symmetric, studio lighting, clear and realistic face, uhd faces, pexels, 85mm, casual pose, 35mm film roll photo, hard light, detailed skin texture, masterpiece, sharp focus, pretty, lovely, adorable, attractive, hasselblad, candid street portrait"
negative_prompt = "blender, ugly, multiple hands, bad anatomy, bad proportions, unrealistic, full body, cropped, lowres, poorly drawn face, out of frame, poorly drawn hands, double, blurred, disfigured, deformed, repetitive, black and white "

# Example prompt 2
#prompt = f"photo of face and shoulders of {YOUR_NAME_ABBREVIATION} {MAN_OR_WOMAN} in a blue suit from the front, front view, closeup, centered frame, symmetric, studio lighting, clear and realistic face, uhd faces, pexels, 85mm, casual pose, 35mm film roll photo, hard light, detailed skin texture, masterpiece, sharp focus, pretty, lovely, adorable, attractive, hasselblad, candid street portrait"
#negative_prompt = "blender, ugly, multiple hands, bad anatomy, bad proportions, unrealistic, full body, cropped, lowres, poorly drawn face, out of frame, poorly drawn hands, double, blurred, disfigured, deformed, repetitive, black and white "

# Example prompt 3
#prompt = f"photo of face and shoulders of {YOUR_NAME_ABBREVIATION} {MAN_OR_WOMAN} in a blue suit from the front, front view, closeup, centered frame, symmetric, studio lighting, clear and realistic face, uhd faces, pexels, 85mm, casual pose, 35mm film roll photo, hard light, detailed skin texture, masterpiece, sharp focus, pretty, lovely, adorable, attractive, hasselblad, candid street portrait"
#negative_prompt = "blender, ugly, multiple hands, bad anatomy, bad proportions, unrealistic, full body, cropped, lowres, poorly drawn face, out of frame, poorly drawn hands, double, blurred, disfigured, deformed, repetitive, black and white "

In [None]:
# Sample `num_samples` images for `prompt` with the fine-tuned pipeline,
# then show each one inline.
pipe_output = pipe(
    prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_images_per_prompt=num_samples,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=g_cuda,
)
images = pipe_output.images
for generated_image in images:
    display(generated_image)