# Setup

## Install Dependencies

In [None]:
# fetch the repository into `tr1cot`, relative to the current working directory
# NOTE(review): the training command built later points at `/content/tr1cot`, so this presumably expects to run from `/content` — confirm
!git clone 'https://github.com/apehex/tr1cot-torch.git' 'tr1cot'

In [None]:
!pip install -qq -U accelerate bitsandbytes datasets densecurves diffusers ftfy mlable-torch peft tensorboard torchvision transformers xformers
# !pip install /content/tr1cot

## Import Dependencies

In [None]:
import math

import PIL as pl
import numpy as np

import torch
import diffusers

# Config

In [None]:
# RANDOM CONFIG ################################################################

# each key is forwarded to the training script as a `--key=value` option
RANDOM_CONFIG = dict(
    random_seed=1337,
)

In [None]:
# VERSION CONFIG ###############################################################

# identifier of the base model, also used to load the sampling pipelines
VERSION_CONFIG = dict(
    # alternative: 'CompVis/stable-diffusion-v1-4'
    model_name='stable-diffusion-v1-5/stable-diffusion-v1-5',
)

In [None]:
# MODEL CONFIG #################################################################

# forwarded to the training script as `--lora_rank=8`
MODEL_CONFIG = dict(
    lora_rank=8,
)

In [None]:
# DATA CONFIG ##################################################################

# dataset options, the alternatives noted above each key go together
DATASET_CONFIG = dict(
    # alternative: 'lambdalabs/naruto-blip-captions'
    dataset_name='apehex/ascii-art-datacompdr-12m',
    dataset_config='default',
    # alternative: 'fixed'
    dataset_split='train',
    # alternative: 'image'
    image_column='content',
    # alternative: 'text'
    caption_column='caption',
    # NOTE(review): presumably 0 means "no limit" — confirm against the training script
    max_samples=0,
)

In [None]:
# PATH CONFIG ##################################################################

# relative paths; `output_dir` is also where the generation cell looks for the checkpoints
PATH_CONFIG = dict(
    output_dir='output',
    cache_dir='.cache',
    logging_dir='logs',
    project_name='text-to-text',
)

In [None]:
# CHECKPOINT CONFIG ############################################################

# the disabled keys are kept as a reminder of the available options
CHECKPOINT_CONFIG = dict(
    # resume_from=None,
    # checkpoint_limit=None,
    checkpoint_steps=128,
)

In [None]:
# PREPROCESSING CONFIG #########################################################

# the booleans are rendered as bare flags on the command line, and omitted when false
PREPROCESS_CONFIG = dict(
    image_resolution=512,
    center_crop=True,
    random_flip=True,
    interpolation_mode='lanczos',
)

In [None]:
# POSTPROCESSING CONFIG ########################################################

# placeholder: no postprocessing option is forwarded for now
POSTPROCESS_CONFIG = dict()

In [None]:
# TRAINING CONFIG ##############################################################

# every key below is forwarded to the training script as a command line option of the same name

ITERATION_CONFIG = dict(
    batch_dim=1,
    step_num=2048,
    epoch_num=4,
)

TRAINING_CONFIG = dict(
    learning_rate=1e-4,
    scale_lr=False,
    lr_scheduler='cosine',
    lr_warmup_steps=0,
)

OPTIMIZER_CONFIG = dict(
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_weight_decay=1e-2,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
)

LOSS_CONFIG = dict(
    # NOTE(review): rendered as `--snr_gamma=0.0`; confirm that the script treats 0 as "disabled"
    snr_gamma=0.0,
)

GRADIENT_CONFIG = dict(
    gradient_accumulation_steps=32,
    gradient_checkpointing=False,
)

PRECISION_CONFIG = dict(
    mixed_precision='bf16',
    allow_tf32=False,
    use_8bit_adam=False,
)

DISTRIBUTION_CONFIG = dict(
    # local_rank=-1,
    dataloader_num_workers=0,
)

FRAMEWORK_CONFIG = dict(
    enable_xformers=False,
)

DIFFUSION_CONFIG = dict(
    prediction_type='epsilon',
    noise_offset=0.0,
)

In [None]:
# TESTING CONFIG ###############################################################

# the same quoted prompt is repeated 4 times, each followed by a space
# the sampling cells extract the first prompt by splitting on the double quotes
TESTING_CONFIG = dict(
    validation_prompts=4 * '"A monkey in ASCII art with width 64, braille characters." ',
    validation_epochs=2,
)

# Preprocess

In [None]:
# ARGS #########################################################################

def format_bool_option(name: str, value: bool) -> str:
    """Render a boolean option as a bare command line flag.

    Returns `--name` when the value is true and an empty string otherwise.
    """
    __flag = '--{}'.format(name)
    # repeating the flag 0 or 1 time either drops it or keeps it
    return __flag * int(value)

def format_str_option(name: str, value: str) -> str:
    """Render a text option as `--name="value"`.

    The value is wrapped in double quotes as is: nothing is escaped.
    """
    return '--{}="{}"'.format(name, value)

def format_any_option(name: str, value: object) -> str:
    """Render a generic option as `--name=value`.

    The value is inserted with its default string formatting, without quotes.
    It is annotated with `object` since the builtin function `any` is not a type.
    """
    return f'--{name}={value}'

def format_option(name: str, value: object) -> str:
    """Render a single option according to the type of its value.

    - `None` is skipped (empty string), instead of producing the literal `--name=None`
    - booleans become bare flags, present only when true
    - strings are wrapped in double quotes
    - anything else is rendered as `--name=value`

    Returns the command line fragment, possibly empty.
    """
    # unset options are left out, the same way false flags are
    if value is None:
        return ''
    # booleans have their own branch, or they would be rendered as `--name=True`
    if isinstance(value, bool):
        return format_bool_option(name=name, value=value)
    if isinstance(value, str):
        return format_str_option(name=name, value=value)
    return format_any_option(name=name, value=value)

def format_command(prefix: str, options: dict) -> str:
    """Append the formatted options to the command prefix.

    The options are rendered in the iteration order of the dictionary and
    separated by single spaces; options rendered as an empty string still
    take a slot in the sequence.
    """
    __args = [format_option(name=__n, value=__o) for __n, __o in options.items()]
    __tail = ' '.join(__args)
    return prefix + ' ' + __tail

In [None]:
# COMMAND ######################################################################

# all the config sections are flattened into a single mapping of options
# the order of the sections is the order of the options on the command line
COMMAND = format_command(
    prefix='accelerate launch /content/tr1cot/scripts/train_text_to_text_lora.py',
    options={
        __key: __value
        for __section in (
            RANDOM_CONFIG,
            VERSION_CONFIG,
            MODEL_CONFIG,
            DATASET_CONFIG,
            PATH_CONFIG,
            CHECKPOINT_CONFIG,
            PREPROCESS_CONFIG,
            POSTPROCESS_CONFIG,
            ITERATION_CONFIG,
            TRAINING_CONFIG,
            OPTIMIZER_CONFIG,
            LOSS_CONFIG,
            GRADIENT_CONFIG,
            PRECISION_CONFIG,
            DISTRIBUTION_CONFIG,
            FRAMEWORK_CONFIG,
            DIFFUSION_CONFIG,
            TESTING_CONFIG,)
        for __key, __value in __section.items()})

# Train

In [None]:
# CLEAN ########################################################################

# NOTE(review): `INSTANCE_DIR` is not defined anywhere in the visible notebook
# unless it is set in the shell environment, it presumably expands to nothing and the path becomes `/.ipynb_checkpoints` — confirm which directory was meant
!rm -rf $INSTANCE_DIR/.ipynb_checkpoints

In [None]:
# SAMPLE #######################################################################

# baseline: sample the base model before any finetuning, for comparison with the later generations
pipe = diffusers.StableDiffusionPipeline.from_pretrained(VERSION_CONFIG['model_name'], torch_dtype=torch.float16)
pipe.to("cuda")

# the prompts are stored as a sequence of quoted strings: the first one sits between the first pair of double quotes
prompt = TESTING_CONFIG['validation_prompts'].split('"')[1]
# seed the sampler so that the baseline can be reproduced
__generator = torch.Generator(device='cuda').manual_seed(RANDOM_CONFIG['random_seed'])
__baseline = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, generator=__generator).images[0]

# the training runs in a separate process: release the GPU memory held by this kernel beforehand
del pipe
torch.cuda.empty_cache()

# display the sample
__baseline

In [None]:
# RUN ##########################################################################

# run the training in a shell; IPython substitutes `{COMMAND}` with the string held by the Python variable
!{COMMAND}

# Generate

In [None]:
# SAMPLE #######################################################################

# derive the checkpoint from the config instead of repeating the step count by hand
# NOTE(review): assumes the script saves `checkpoint-{step}` every `checkpoint_steps`, so that the latest one is the largest multiple below `step_num` — confirm
__step = ITERATION_CONFIG['step_num'] - (ITERATION_CONFIG['step_num'] % CHECKPOINT_CONFIG['checkpoint_steps'])

# same base model as the training, with the safety checker disabled
pipe = diffusers.StableDiffusionPipeline.from_pretrained(VERSION_CONFIG['model_name'], safety_checker=None, torch_dtype=torch.float16)
# plug the finetuned attention weights into the UNet
pipe.unet.load_attn_procs(PATH_CONFIG['output_dir'] + f'/checkpoint-{__step}')
pipe.to("cuda")

In [None]:
# the first prompt sits between the first pair of double quotes
prompt = TESTING_CONFIG['validation_prompts'].split('"')[1]
# sample a single image with the pipeline loaded in the previous cell
image = pipe(
    prompt,
    num_inference_steps=256,
    guidance_scale=8,
).images[0]
# keep a copy on disk, in the current working directory
image.save("test.png")

In [None]:
# POSTPROCESS ##################################################################

def restore(data: np.ndarray) -> np.ndarray:
    """Prepend a zero to every vector along the last axis.

    The last axis grows by one: with 3 values per position, the result holds
    4 values per position, which `decode` reads as one UTF-32-BE code unit.

    Args:
        data: array (or array-like) with at least one axis.

    Returns:
        An array with the same dtype and leading axes, and a last axis of size `data.shape[-1] + 1`.
    """
    __data = np.asarray(data)
    # single channel of zeros, with the same leading axes and dtype as the data
    __zeros = np.zeros(tuple(__data.shape)[:-1] + (1,), dtype=__data.dtype)
    # `np.concatenate` exists in all the NumPy versions, while the `np.concat` alias requires NumPy 2.0
    return np.concatenate([__zeros, __data], axis=-1)

def decode(data: np.ndarray) -> np.ndarray:
    """Decode byte values into text, one string per row.

    The last two axes are flattened into a single byte sequence, which is
    interpreted as UTF-32-BE. Invalid sequences are replaced rather than raised.
    The strings do not include newlines.

    Args:
        data: array of byte values, expected in the range 0-255. Any integer
            dtype is accepted, the values are cast to single bytes.

    Returns:
        An array of strings with the shape `data.shape[:-2]`.
    """
    # force single bytes: the raw buffer of a wider dtype would interleave padding bytes with the data
    __data = np.asarray(data).astype(np.uint8)
    # keep the leading axes, and merge the last two axes into a single sequence
    __shape = tuple(__data.shape)[:-2] + (math.prod(__data.shape[-2:]),)
    __bytes = __data.reshape(__shape)
    # interpret each sequence as UTF-32-BE text
    return np.apply_along_axis(lambda __r: __r.tobytes().decode('utf-32-be', errors='replace'), arr=__bytes, axis=-1)

In [None]:
# shrink the image to 64 x 64 pixels, then expose the pixels as an array
# NOTE(review): presumably one pixel per character, with 3 channels per pixel — confirm
__rgb = np.asarray(
    image.resize(
        size=(64, 64),
        resample=pl.Image.Resampling.LANCZOS))

In [None]:
# inspect the raw pixel values before decoding
__rgb

In [None]:
# prepend a zero to the channels of each pixel, then read every row of pixels as a line of UTF-32-BE text
decode(restore(__rgb))