In [None]:
import os
import shutil
import subprocess
import urllib.request
import configparser
import random
import concurrent.futures
import time
from library.studiolabs_utils import (
    clone_or_update_repo,
    install_dependencies,
    create_dirs,
    download_model,
    get_config_from_folder,
    preprocess_folder,
    update_model_paths,
    get_config_file_paths,
    get_config_dict_from_ini,
    get_train_args
)


# --- 1. Resolve the working directories used by every later step ---
print('1.0 DEFINE DIRECTORIES')
dirs = create_dirs()

# --- 2. Fetch the training repo and install its requirements ---
print('2.0 CLONE REPO AND INSTALL DIRECTORIES')
# The default config.ini supplies the repo URL/branch and the install flags.
default_config = configparser.ConfigParser()
default_config.read(dirs['default_config'])

print(dirs['accelerate_config'])
clone_or_update_repo(
    url=default_config.get('UserSettings', 'repo_url'),
    save_directory=dirs['root_dir'],
    branch=default_config.get('UserSettings', 'branch'),
)

install_dependencies(
    dirs,
    verbose=default_config.getboolean('UserSettings', 'verbose'),
    install_xformers=default_config.getboolean('UserSettings', 'install_xformers'),
)

# Run the post-install shell commands in order: clear the pip cache, then
# write a default accelerate configuration. Failures are not checked.
for command in ("pip cache purge", "accelerate config default"):
    subprocess.run(command, shell=True)


from PIL import Image, ImageDraw, ImageFont
import textwrap
import matplotlib.font_manager as fm
from huggingface_hub import login
from huggingface_hub import HfApi
from huggingface_hub.utils import validate_repo_id, HfHubHTTPError

def authenticate(write_token):
    """Log in to the Hugging Face Hub and return the account info and client.

    Args:
        write_token: Hub access token passed to ``login`` and ``whoami``.

    Returns:
        A ``(whoami_result, api)`` tuple, where ``api`` is a fresh ``HfApi``
        instance and ``whoami_result`` is whatever ``api.whoami`` returns for
        the token.
    """
    # Also store the token in the git credential helper.
    login(write_token, add_to_git_credential=True)
    hub_client = HfApi()
    account_info = hub_client.whoami(write_token)
    return account_info, hub_client


def create_repo(api, user, orgs_name, repo_name, repo_type, make_private=False):
    """Create a Hub repo if it does not exist yet and remember its id.

    The repo id is ``<orgs_name>/<repo_name>`` when ``orgs_name`` is set,
    otherwise ``<user["name"]>/<repo_name>``. The resulting id is stored in the
    module-level ``model_repo`` (for ``repo_type == "model"``) or
    ``datasets_repo`` (for any other ``repo_type``).

    Args:
        api: ``HfApi`` client used to create the repo.
        user: Mapping with a ``"name"`` key (as returned by ``authenticate``).
        orgs_name: Organization namespace; empty/falsy means "use the user".
        repo_name: Repo name; surrounding whitespace is stripped.
        repo_type: Repo type forwarded to ``api.create_repo``.
        make_private: Whether a newly created repo is private.

    Raises:
        Whatever ``validate_repo_id`` raises for a malformed repo id.
    """
    global model_repo
    global datasets_repo

    namespace = orgs_name if orgs_name else user["name"]
    repo_id = namespace + "/" + repo_name.strip()
    label = repo_type.capitalize()

    validate_repo_id(repo_id)
    try:
        api.create_repo(repo_id=repo_id, repo_type=repo_type, private=make_private)
    except HfHubHTTPError as e:
        # Only HTTP 409 (conflict) means the repo is already there. Any other
        # HTTP failure (bad token, missing permission, ...) must not be
        # reported as "exists"; keep going best-effort but say what happened.
        status_code = getattr(getattr(e, "response", None), "status_code", None)
        if status_code == 409:
            print(f"{label} repo '{repo_id}' exists, skipping create repo")
        else:
            print(f"{label} repo '{repo_id}' could not be created: {e}")
    else:
        print(f"{label} repo '{repo_id}' didn't exist, created repo")

    if repo_type == "model":
        model_repo = repo_id
        print(f"{label} repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
    else:
        datasets_repo = repo_id
        print(f"{label} repo '{repo_id}' link: https://huggingface.co/datasets/{repo_id}\n")

# The write token must exist before authenticating: read it from the same
# tokens file the training loop uses, instead of relying on a name that is
# only assigned later in the script.
write_token = get_config_dict_from_ini(
    os.path.join(dirs['default_configs_dir'], 'my_tokens.ini')
)['huggingface_token']
user, api = authenticate(write_token)

# Get the parameter values from the config file
model_url = default_config.get('DownloadModels', 'model_url')
vae_url = default_config.get('DownloadModels', 'vae_url')


# Fetch the base model and the VAE into their respective directories
download_model(model_url, dirs['pretrained_dir'])
download_model(vae_url, dirs['vae_dir'])

print('4.1 DATA CLEANING + BLIP Captioning + Custom Caption/Tag')
# Use BLIP for general images
# Use Waifu for anime/manga images
# Specified in the config file

# Set to True to train on the folders as they are, without preprocessing.
skip_cleaning_captioning=False

# Every non-hidden directory directly under the training data directory is
# treated as one training set.
subfolders = [
    os.path.join(dirs['train_data_dir'], f)
    for f in os.listdir(dirs['train_data_dir'])
    if os.path.isdir(os.path.join(dirs['train_data_dir'], f))
    and not f.startswith('.')
]


# For every training folder: optionally preprocess it, launch training, build
# a prompt-by-epoch grid of the sample images, then upload config, prompts,
# samples and the grid to the Hugging Face Hub.
#
# NOTE(review): fetch_image_locations, get_font, padding, left_space,
# top_space, text_offset and text_width_limit are used below but are not
# defined anywhere in the visible source — confirm they are provided by
# another cell/module before running.
for folder in subfolders:
    if not skip_cleaning_captioning:
        preprocess_folder(folder, dirs, default_config)

    # ------------------------------------------------------------------
    # 5. Start training
    # ------------------------------------------------------------------
    sample_prompt, config_file = get_config_file_paths(folder, dirs)

    # Fall back to the default config when the folder has none of its own.
    if config_file is None:
        train_config = default_config
    else:
        train_config = configparser.ConfigParser()
        train_config.read(config_file)

    token_dictionary = get_config_dict_from_ini(os.path.join(dirs['default_configs_dir'],'my_tokens.ini'))

    accelerate_conf = {
        "config_file" : dirs['accelerate_config'],
        "num_cpu_threads_per_process" : 1,
    }

    # Per-folder values that must win over anything read from the ini files.
    train_confing2 = {
        "train_data_dir" : folder,
        "huggingface_token" : token_dictionary['huggingface_token'],
        "sample_prompts" : sample_prompt
    }

    # Flatten all config sections into one dict; later sections override
    # earlier ones on key collisions, then tokens, then per-folder values.
    train_config_full = {}
    for section in (
        'model_arguments',
        'optimizer_arguments',
        'dataset_arguments',
        'dataset',
        'dataset_subset',
        'general',
        'training_arguments',
        'sample_prompt_arguments',
        'saving_arguments',
        'huggingface_arguments',
    ):
        train_config_full.update(dict(train_config.items(section)))
    train_config_full.update(token_dictionary)
    train_config_full.update(train_confing2)

    accelerate_args = get_train_args(accelerate_conf)
    train_args = get_train_args(train_config_full)
    print('TRAIN ARGS: ',train_args)
    final_args = f"accelerate launch {accelerate_args} train_db.py {train_args}"
    print(final_args)
    # check=True: abort the whole run if training fails.
    subprocess.run(final_args, shell=True, check=True)

    # ------------------------------------------------------------------
    # Build the sample image grid (rows = prompts, columns = epochs)
    # ------------------------------------------------------------------
    trained_model_name=train_config.get('training_arguments', 'output_name')
    output_directory=os.path.join(dirs['dreambooth_output_dir'],'sample',trained_model_name)
    image_locations = fetch_image_locations(output_directory)
    # File names are split on '_': field 1 is the prompt and field 2, minus
    # its first character, is the epoch.
    prompts = sorted({img_name.split('_')[1] for img_name in image_locations})
    epochs = sorted({img_name.split('_')[2][1:] for img_name in image_locations})
    font, fontsize = get_font()

    grid_name=trained_model_name+'_grid.png'
    grid_save_dir=os.path.join(dirs['dreambooth_sample_dir'],grid_name)
    max_image_width = 0
    max_image_height = 0

    # Find the maximum width and height among all images; every grid cell is
    # sized to fit the largest one.
    for image_file in image_locations:
        image_path = os.path.join(output_directory, image_file)
        with Image.open(image_path) as img:
            width, height = img.size
        max_image_width = max(max_image_width, width)
        max_image_height = max(max_image_height, height)

    canvas_width = ((max_image_width + padding) * len(epochs)) + padding + left_space
    canvas_height = ((max_image_height + padding) * len(prompts)) + padding + top_space
    canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')
    draw = ImageDraw.Draw(canvas)

    # Iterate over the image files and place them in the grid
    for i,epoch in enumerate(epochs):
        # Column header
        epoch_x = ((max_image_width + padding) * i) + padding + left_space
        epoch_y = text_offset
        draw.text((epoch_x, epoch_y), f'Epoch: {epoch}', fill='black', font=font)
        for k,prompt in enumerate(prompts):

            text=prompt
            # Shorten the text if it exceeds 200 characters
            if len(text) > 200:
                text = text[:200] + '...'

            # Wrap the text to fit within the limit
            wrapped_text = textwrap.wrap(text, width=int(text_width_limit / fontsize))

            # Row label starts at the top edge of the row's image cell
            y_start = top_space + padding + k * (padding + max_image_height)
            x = text_offset * 4

            # Draw the wrapped text, one line per fontsize step
            for line in wrapped_text:
                draw.text((x, y_start), line, font=font, fill='black')
                y_start += fontsize

            for img_name in image_locations:
                if epoch == img_name.split('_')[2][1:] and prompt == img_name.split('_')[1]:
                    image_path = os.path.join(output_directory, img_name)

                    # Calculate the position of the image in the grid
                    x = (max_image_width + padding) * i + padding + left_space
                    y = top_space + (max_image_height + padding) * k + padding

                    # Paste the image onto the canvas, closing the file afterwards
                    with Image.open(image_path) as img:
                        canvas.paste(img, (x, y))

    # Save the final image grid
    canvas.save(grid_save_dir)
    print('Saved')

    # ------------------------------------------------------------------
    # Upload to the Hugging Face Hub
    # ------------------------------------------------------------------
    # @markdown Login to Huggingface Hub
    # @markdown > Get **your** huggingface `WRITE` token [here](https://huggingface.co/settings/tokens)
    write_token = token_dictionary['huggingface_token']
    # @markdown Fill this if you want to upload to your organization, or just leave it empty.
    orgs_name = ""  # @param {type :"string"}
    # @markdown If your model/dataset repo does not exist, it will automatically create it.
    make_private = True
    user, api = authenticate(write_token)

    # @markdown This will be uploaded to model repo
    path_in_repo = ""  # @param {type :"string"}
    # @markdown Now you can save your config file for future use
    # @markdown Other Information
    commit_message = "uploading model"  # @param {type :"string"}

    if not commit_message:
        commit_message = "feat: upload checkpoint"

    # Make sure a model repo exists for each sample folder. Plain files are
    # skipped: the grid PNGs saved above live in this same directory and must
    # not be turned into repos.
    for f in os.listdir(dirs['dreambooth_sample_dir']):
        if f.startswith('.'):
            continue
        if not os.path.isdir(os.path.join(dirs['dreambooth_sample_dir'], f)):
            continue
        create_repo(api, user, orgs_name, f, "model", make_private)

    upload_repo_id = user["name"] + "/" + trained_model_name
    print("uploading to: ",upload_repo_id)

    if config_file is None:
        upload_config=dirs['default_config']
    else:
        upload_config=config_file

    print("uploading config files")
    api.upload_file(
        path_or_fileobj=upload_config,
        path_in_repo="config/config.ini",
        repo_id=upload_repo_id,
        repo_type=None,
    )

    print("uploading sample prompt")
    api.upload_file(
        path_or_fileobj=sample_prompt,
        path_in_repo="config/sample_prompt.txt",
        repo_id=upload_repo_id,
        repo_type=None,
    )

    print("uploading sample images")
    api.upload_folder(
        folder_path=os.path.join(dirs['dreambooth_sample_dir'], trained_model_name),
        repo_id=upload_repo_id,
        repo_type=None,
        path_in_repo="samples",
    )
    print("uploading sample image grid")
    api.upload_file(
        path_or_fileobj=grid_save_dir,
        path_in_repo=grid_name,
        repo_id=upload_repo_id,
        repo_type=None,
    )

In [None]:
# @title ## 6.2. Inference
# Form parameters for a single image-generation run. The values are collected
# into `config` below.
# NOTE(review): dreambooth_output_dir, vae_dir and inference_dir are not
# defined in the visible source (only dirs[...] entries are) — confirm they
# are set by an earlier cell before running this one.
v2 = False  # @param {type:"boolean"}
v_parameterization = False  # @param {type:"boolean"}
prompt = "RAW photo, mirox in a fancy suit, fashion magazine photoshoot, full body shot, high detailed skin, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"  # @param {type: "string"}
negative = "(weird eyes, disfigured eyes, looking different direction:1.3), cgi, 3d, render, mutated hands, mutated fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, bad quality, worst quality"  # @param {type: "string"}
model = os.path.join(dreambooth_output_dir,'Miroslav7.ckpt')  # @param {type: "string"}
vae = os.path.join(vae_dir,'vae-ft-mse-840000-ema-pruned.ckpt')  # @param {type: "string"}
outdir = inference_dir  # @param {type: "string"}
scale = 7  # @param {type: "slider", min: 1, max: 40}
sampler = "euler_a"  # @param ["ddim", "pndm", "lms", "euler", "euler_a", "heun", "dpm_2", "dpm_2_a", "dpmsolver","dpmsolver++", "dpmsingle", "k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a"]
steps = 35  # @param {type: "slider", min: 1, max: 100}
# NOTE(review): `precision` is never read below; `config` hard-codes "fp16": True.
precision = "fp16"  # @param ["fp16", "bf16"] {allow-input: false}
width = 512  # @param {type: "integer"}
height = 768  # @param {type: "integer"}
images_per_prompt = 12  # @param {type: "integer"}
batch_size = 1  # @param {type: "integer"}
clip_skip = 1  # @param {type: "slider", min: 1, max: 40}
seed = -1  # @param {type: "integer"}

# The negative prompt is appended to the positive one after a " --n " marker.
final_prompt = f"{prompt} --n {negative}"

# Option name -> value for the generation script. Entries set to None here
# (empty vae, non-positive seed, clip_skip when v2 is on) are meant to be
# left out of the command line.
config = {
    "v2": v2,
    "v_parameterization": v_parameterization,
    "ckpt": model,
    "outdir": outdir,
    "xformers": True,
    "vae": vae if vae else None,
    "fp16": True,
    "W": width,
    "H": height,
    "seed": seed if seed > 0 else None,  # seed <= 0 means "no fixed seed"
    "scale": scale,
    "sampler": sampler,
    "steps": steps,
    "max_embeddings_multiples": 3,
    "batch_size": batch_size,
    "images_per_prompt": images_per_prompt,
    "clip_skip": clip_skip if not v2 else None,  # only passed when v2 is off
    "prompt": final_prompt,
}

args = ""
for k, v in config.items():
    if isinstance(v, str):
        args += f'--{k}="{v}" '
    if isinstance(v, bool) and v:
        args += f"--{k} "
    if isinstance(v, float) and not isinstance(v, bool):
        args += f"--{k}={v} "
    if isinstance(v, int) and not isinstance(v, bool):
        args += f"--{k}={v} "

final_args = f"python gen_img_diffusers.py {args}"

os.chdir(repo_dir)
!{final_args}