# Setup

## Install Dependencies

In [None]:
# Clone into the absolute path used by the training command further down, and skip
# the clone when the checkout already exists so the cell can be re-run safely.
![ -d /content/tr1cot ] || git clone https://github.com/apehex/tr1cot.git /content/tr1cot

In [None]:
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
%pip install accelerate bitsandbytes datasets diffusers ftfy peft tensorboard torchvision transformers xformers
# %pip install /content/tr1cot

## Import Dependencies

In [None]:
import shlex

import torch
import diffusers

# Config

In [None]:
# MODEL CONFIG #################################################################

# Identifier of the pretrained model: passed to `from_pretrained` when sampling
# and forwarded to the training command.
# Alternative: 'CompVis/stable-diffusion-v1-4'
MODEL_CONFIG = {
    'model_name': 'stable-diffusion-v1-5/stable-diffusion-v1-5',
}

In [None]:
# DATA CONFIG ##################################################################

# Dataset selection and column mapping, forwarded to the training command.
# The trailing comments keep the alternative values noted by the author.
DATASET_CONFIG = {
    'dataset_name': 'lambdalabs/naruto-blip-captions',  # 'apehex/ascii-art-datacompdr-12m'
    'dataset_config': 'default',
    # 'dataset_split': 'train',  # 'fixed'
    'image_column': 'image',  # 'content'
    'caption_column': 'text',  # 'caption'
    # 'max_samples': None,
}

In [None]:
# PATH CONFIG ##################################################################

# Output, cache and logging locations, forwarded to the training command.
PATH_CONFIG = {
    'output_dir': 'naruto-lora',
    'cache_dir': None,
    'logging_dir': 'logs',
}

In [None]:
# CHECKPOINT CONFIG ############################################################

# Checkpointing options, forwarded to the training command.
# The commented keys are kept as reminders of the other options the author listed.
CHECKPOINT_CONFIG = {
    # 'resume_from': None,
    # 'checkpoint_limit': None,
    'checkpoint_steps': 32,
}

In [None]:
# RANDOM CONFIG ################################################################

# Random seed shared by the notebook.
RANDOM_CONFIG = {
    'seed': 1337,
}

In [None]:
# PREPROCESSING CONFIG #########################################################

# Image preprocessing options, forwarded to the training command.
PREPROCESS_CONFIG = {
    'resolution': 512,
    'center_crop': True,
    'random_flip': True,
    'image_interpolation_mode': 'lanczos',
}

In [None]:
# POSTPROCESSING CONFIG ########################################################

# No postprocessing options yet; the empty mapping is still merged into the command.
POSTPROCESS_CONFIG = dict()

In [None]:
# TRAINING CONFIG ##############################################################

# Each mapping below is merged into the options of the training command.
# Boolean values are rendered as bare flags: True adds the flag, False drops it.

# Batch size and training length.
ITERATION_CONFIG = {
    'batch_dim': 1,
    # 'step_num': None,
    'epoch_num': 4,
}

# Learning rate and its schedule.
TRAINING_CONFIG = {
    'learning_rate': 1e-4,
    'scale_lr': False,
    'lr_scheduler': 'cosine',
    'lr_warmup_steps': 0,
}

# Adam hyper-parameters and gradient clipping.
OPTIMIZER_CONFIG = {
    'adam_beta1': 0.9,
    'adam_beta2': 0.999,
    'adam_weight_decay': 1e-2,
    'adam_epsilon': 1e-8,
    'max_grad_norm': 1.0,
}

# Loss options; False leaves `snr_gamma` out of the command.
LOSS_CONFIG = {
    'snr_gamma': False,
}

# Gradient accumulation and checkpointing.
GRADIENT_CONFIG = {
    'gradient_accumulation_steps': 16,
    'gradient_checkpointing': False,
}

# Numeric precision options.
PRECISION_CONFIG = {
    'mixed_precision': 'bf16',
    'allow_tf32': False,
    'use_8bit_adam': False,
}

# Data loading / distribution options.
DISTRIBUTION_CONFIG = {
    # 'local_rank': -1,
    'dataloader_num_workers': 0,
}

# Optional framework features.
FRAMEWORK_CONFIG = {
    'enable_xformers': False,
}

# Diffusion-specific options.
DIFFUSION_CONFIG = {
    'prediction_type': 'epsilon',
    'noise_offset': 0.0,
}

In [None]:
# TESTING CONFIG ###############################################################

# Validation settings: the prompt is also reused by the sampling cells of this notebook.
TESTING_CONFIG = {
    'validation_prompt': 'Sasuke l33t hacking with a smartphone.',
    'num_validation_images': 4,
    'validation_epochs': 2,
}

# Preprocess

In [None]:
# ARGS #########################################################################

def format_bool_option(name: str, value: bool) -> str:
    """Render a boolean option as a bare flag.

    Returns `--name` when `value` is True and the empty string when it is False.
    """
    __flag = f'--{name}'
    # repeating the flag 0 or 1 times either drops it or keeps it
    return __flag * int(value)

def format_str_option(name: str, value: str) -> str:
    """Render a string option as a single shell-safe `--name=value` token.

    The value is escaped with `shlex.quote`, so characters that are special to the
    shell (double quotes, `$`, backticks, backslashes, spaces...) reach the script
    unchanged instead of breaking or altering the command line.

    Args:
        name: option name, without the leading dashes.
        value: raw option value.

    Returns:
        The `--name=value` fragment, with the value quoted only when required.
    """
    return f'--{name}={shlex.quote(value)}'

def format_any_option(name: str, value: any) -> str:
    """Render an option of any other type as `--name=value`, using the default formatting of the value."""
    return '--{}={}'.format(name, value)

def format_option(name: str, value: object) -> str:
    """Render a single option as a command line fragment, according to its type.

    - `None` means "not set": the option is left out (empty string) instead of
      being sent as the literal text `None`.
    - `bool` values become bare flags, present only when True.
    - `str` values are rendered by `format_str_option`.
    - anything else is rendered as `--name=value`.

    Args:
        name: option name, without the leading dashes.
        value: option value.

    Returns:
        The fragment for this option, possibly empty.
    """
    # an unset option must not reach the script as the string "None"
    if value is None:
        return ''
    if isinstance(value, bool):
        return format_bool_option(name=name, value=value)
    if isinstance(value, str):
        return format_str_option(name=name, value=value)
    return format_any_option(name=name, value=value)

def format_command(prefix: str, options: dict) -> str:
    """Build the full command line: `prefix` followed by every formatted option.

    Options are rendered in the iteration order of `options` and separated by single
    spaces; options that render to an empty string still occupy a slot in the join.
    """
    __fragments = [format_option(name=__k, value=__v) for __k, __v in options.items()]
    __suffix = ' '.join(__fragments)
    return prefix + ' ' + __suffix

In [None]:
# COMMAND ######################################################################

# Merge every config section into one flat mapping of options, then render the
# shell command that launches the training script.
# RANDOM_CONFIG is included so that the seed actually reaches the script.
# NOTE(review): assumes the script accepts a `--seed` option — confirm.
COMMAND = format_command(
    prefix='accelerate launch /content/tr1cot/scripts/train_text_to_text_lora.py',
    options={
        **MODEL_CONFIG,
        **DATASET_CONFIG,
        **PATH_CONFIG,
        **CHECKPOINT_CONFIG,
        **RANDOM_CONFIG,
        **PREPROCESS_CONFIG,
        **POSTPROCESS_CONFIG,
        **ITERATION_CONFIG,
        **TRAINING_CONFIG,
        **OPTIMIZER_CONFIG,
        **LOSS_CONFIG,
        **GRADIENT_CONFIG,
        **PRECISION_CONFIG,
        **DISTRIBUTION_CONFIG,
        **FRAMEWORK_CONFIG,
        **DIFFUSION_CONFIG,
        **TESTING_CONFIG,})

# Train

In [None]:
# CLEAN ########################################################################

# Remove the `.ipynb_checkpoints` folder located under $INSTANCE_DIR.
# NOTE(review): INSTANCE_DIR is not defined in any visible cell; if it is also unset
# in the shell environment, the path presumably collapses to `/.ipynb_checkpoints`.
# Confirm which directory is meant to be cleaned.
!rm -rf $INSTANCE_DIR/.ipynb_checkpoints

In [None]:
# SAMPLE #######################################################################

# Baseline: sample from the base model, before any fine-tuning, with the validation prompt.
pipe = diffusers.StableDiffusionPipeline.from_pretrained(MODEL_CONFIG['model_name'], torch_dtype=torch.float16)
pipe.to("cuda")

prompt = TESTING_CONFIG['validation_prompt']
# Seed the sampler with the notebook seed so that re-running the cell is repeatable.
generator = torch.Generator(device="cuda").manual_seed(RANDOM_CONFIG['seed'])
# Last expression of the cell: the generated image is displayed as the cell output.
pipe(prompt, num_inference_steps=30, guidance_scale=7.5, generator=generator).images[0]

In [None]:
# RUN ##########################################################################

# Run the training command in a shell; IPython substitutes the value of the COMMAND string.
!{COMMAND}

# Generate

In [None]:
# SAMPLE #######################################################################

# Sample from the base model with the trained attention weights loaded on top.
pipe = diffusers.StableDiffusionPipeline.from_pretrained(MODEL_CONFIG['model_name'], torch_dtype=torch.float16)
# The weights are read from a checkpoint folder inside the training output directory
# (the key is 'output_dir'; the step number is hard-coded).
pipe.unet.load_attn_procs(PATH_CONFIG['output_dir'] + '/checkpoint-128')
pipe.to("cuda")

prompt = TESTING_CONFIG['validation_prompt']
# Seed the sampler with the notebook seed so that re-running the cell is repeatable.
generator = torch.Generator(device="cuda").manual_seed(RANDOM_CONFIG['seed'])
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, generator=generator).images[0]
image.save("sasuke.png")

In [None]:
# POSTPROCESS ##################################################################