# Text Diffusion

## Install The Dependencies

In [None]:
# install / upgrade the notebook dependencies (IPython shell escape, not plain Python)
!pip install -U accelerate datasets densecurves diffusers[training] torchvision

In [None]:
import functools
import glob
import math
import os
import re

import accelerate
import accelerate.utils
import datasets
import diffusers
import diffusers.optimization
import numpy as np
import torch
import torch.nn.functional
import torchvision
import tqdm

import PIL as pillow
import matplotlib.axes as mpaxes
import matplotlib.colors as mpcolors
import matplotlib.pyplot as mpplot

import densecurves.hilbert

## Define The Config

In [None]:
# BASE #########################################################################

# every text sample is laid out on a height_dim x width_dim character grid,
# one character per pixel
BASE_CONFIG = dict(
    height_dim=64,
    width_dim=64,
    # character used to fill short rows / missing rows (alternative: '\x00')
    padding_str=' ',
)

In [None]:
# RANDOM #######################################################################

# seed of the sampling generator used when evaluating the pipeline
RANDOM_CONFIG = dict(seed=1337)

In [None]:
# MODEL ########################################################################

# keyword arguments of diffusers.UNet2DModel
MODEL_CONFIG = dict(
    # only the height is used: assumes square samples (height_dim == width_dim)
    sample_size=BASE_CONFIG['height_dim'],
    # 3 channels in and out: the RGB encoding of each character
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    # attention only in the two deepest (lowest resolution) levels
    down_block_types=('DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'AttnDownBlock2D', 'AttnDownBlock2D',),
    up_block_types=('AttnUpBlock2D', 'AttnUpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'),
    act_fn='silu',
    norm_eps=1e-05,
    norm_num_groups=16,
)

# other UNet2DModel options, currently not set:
# 'attention_head_dim': 8,
# 'center_input_sample': False,
# 'downsample_padding': 1,
# 'flip_sin_to_cos': True,
# 'freq_shift': 0,
# 'mid_block_scale_factor': 1,

In [None]:
# PATH #########################################################################

PATH_CONFIG = dict(
    # accelerate project dir and destination of the saved pipeline
    output_dir='output',
    # dataset download cache
    cache_dir='.cache',
    # used as the tracker project name in train_loop
    logging_dir='logs',
)

In [None]:
# DATASET ######################################################################

# keyword arguments of datasets.load_dataset
DATASET_CONFIG = dict(
    # alternatives: 'apehex/ascii-art-datacompdr-12m', 'huggan/smithsonian_butterflies_subset'
    path='apehex/ascii-art',
    name='asciiart',
    # alternative: 'fixed'
    split='train',
    cache_dir=PATH_CONFIG['cache_dir'],
)

In [None]:
# CHECKPOINT ###################################################################

# train_loop evaluates and saves the pipeline every `checkpoint_epoch_num` epochs
CHECKPOINT_CONFIG = dict(checkpoint_epoch_num=1)

In [None]:
# TRAINING #####################################################################

ITERATION_CONFIG = dict(
    batch_size=32,
    epoch_num=32,
    # NOTE(review): presumably the number of batches per epoch (len(DATALOADER));
    # the LR schedule length below depends on it being accurate — confirm
    step_num=166,
)

# arguments of diffusers.optimization.get_cosine_schedule_with_warmup
SCHEDULER_CONFIG = dict(
    num_warmup_steps=128,
    # intended total number of scheduler steps over the whole training
    num_training_steps=ITERATION_CONFIG['epoch_num'] * ITERATION_CONFIG['step_num'],
)

# arguments of torch.optim.AdamW
OPTIMIZER_CONFIG = dict(
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)

ACCELERATE_CONFIG = dict(
    # NOTE(review): not read by train_loop (not passed to accelerate.Accelerator)
    sync_gradients=True,
    gradient_accumulation_steps=1,
    mixed_precision='fp16',
    log_with='tensorboard',
)

In [None]:
# DIFFUSION ####################################################################

DIFFUSION_CONFIG = dict(
    batch_size=ITERATION_CONFIG['batch_size'],
    # used both as the noise scheduler's number of training timesteps and as
    # the number of denoising steps when sampling
    num_inference_steps=1024,
)

## Download The Dataset

In [None]:
# DOWNLOAD #####################################################################

# fetch the configured dataset split (path, name, split and cache_dir come from DATASET_CONFIG);
# with a `split` given, load_dataset returns a single Dataset rather than a DatasetDict
DATASET = datasets.load_dataset(**DATASET_CONFIG)

## Data Visualization

In [None]:
# SUBPLOTS #####################################################################

class PlotContext:
    """Create a grid of matplotlib subplots on enter, lay it out and display it on exit.

    Usage: ``with PlotContext(rows=2, cols=2) as (figure, axes): ...``

    Args:
        rows: number of subplot rows.
        cols: number of subplot columns.
        zoom: size of each subplot; the figure size is (zoom[0] * cols, zoom[-1] * rows).
        show: whether the x / y axes of every subplot are visible.
        **kwargs: forwarded to ``matplotlib.pyplot.subplots``.
    """

    def __init__(self, rows: int, cols: int, zoom: tuple=(4, 4), show: bool=False, **kwargs) -> None:
        self._rows = rows
        self._cols = cols
        self._zoom = zoom
        self._show = show
        self._args = dict(kwargs)
        self._size = (zoom[0] * cols, zoom[-1] * rows)
        self._figure = None
        self._axes = None

    def __enter__(self) -> tuple:
        self._figure, self._axes = mpplot.subplots(nrows=self._rows, ncols=self._cols, figsize=self._size, **self._args)
        # toggle the axis lines and ticks of every subplot
        for _axis in self._figure.axes:
            _axis.get_xaxis().set_visible(self._show)
            _axis.get_yaxis().set_visible(self._show)
        # hand the figure and the axes (as returned by subplots) to the with-block
        return (self._figure, self._axes)

    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
        # lay out this figure specifically, not whichever figure happens to be current
        self._figure.tight_layout()
        mpplot.show()
        # returning None lets any exception raised in the with-block propagate

In [None]:
# IMAGE GRID ###################################################################

def imgrid(images):
    """Paste a list of equally sized pillow images into one RGB grid image.

    The number of columns is the largest power of two not exceeding sqrt(len(images)).
    The number of rows is rounded up so that every image gets a cell; the last row
    may be partially empty.

    Raises IndexError when `images` is empty.
    """
    # parse the shape (pillow sizes are (width, height))
    __count = len(images)
    __width, __height = images[0].size
    # distribute evenly across rows and cols
    __cols = 2 ** int(0.5 * math.log2(__count))
    # round up: a floor division would silently drop the trailing images
    # whenever count is not a multiple of cols (e.g. 5 images on 2 columns)
    __rows = (__count + __cols - 1) // __cols
    # blank canvas holding the whole grid
    __grid = pillow.Image.new('RGB', size=(__cols * __width, __rows * __height))
    # paste each image in its cell, row-major
    for __i, __image in enumerate(images):
        __grid.paste(__image, box=((__i % __cols) * __width, (__i // __cols) * __height))
    # single image
    return __grid

In [None]:
# IMAGE WITH CAPTION OVERLAY ###################################################

def matshow(axes: mpaxes.Axes, data: iter=(), curve: iter=(), text: iter=(), family: iter=None, cmap: mpcolors.Colormap=None) -> None:
    """Draw an array, a curve and / or a character overlay on a matplotlib axes.

    Args:
        axes: target matplotlib axes.
        data: 2D (or RGB) array shown with ``axes.matshow``; skipped when empty.
        curve: pair-like (xs, ys), the first and last items are plotted in black; skipped when empty.
        text: rows of characters; each character other than ' ' and '\\x00' is written
            in white at (column, row), i.e. centered on the matching pixel.
        family: font family of the overlay.
        cmap: colormap passed to ``axes.matshow``.
    """
    if len(data):
        axes.matshow(data, cmap=cmap)
    if len(curve):
        axes.plot(curve[0], curve[-1], color='black')
    for __row_idx, __row in enumerate(text):
        for __col_idx, __char in enumerate(__row):
            # blanks and null padding are left empty
            if __char in ' \x00':
                continue
            axes.text(__col_idx, __row_idx, str(__char), va='center', ha='center', color='white', family=family)

## Preprocessing Operations

In [None]:
# CLEAN ########################################################################

ANSI_REGEX = r'\x1b\[[0-9;]*[mGKHF]'

def clean(text: str, pattern: str=ANSI_REGEX, rewrite: str='') -> str:
    """Replace every match of `pattern` in `text` with `rewrite`.

    With the defaults, this strips ANSI escape sequences of the form
    ESC [ <digits and ';'> <one of m G K H F>, e.g. color codes ending in 'm'.
    Other sequences (e.g. ESC[2J) are left untouched.
    """
    __compiled = re.compile(pattern)
    return __compiled.sub(rewrite, text)

In [None]:
# 1D => 2D #####################################################################

def chunk(seq: list, size: int, repeats: bool=True) -> list:
    """Cut a sequence into consecutive slices of `size` items (the last one may be shorter).

    Args:
        seq: any sliceable sequence (str, list, tuple, ...).
        size: number of items per slice; it is the `range` step, so 0 raises ValueError.
        repeats: when False, duplicate slices are removed, keeping the first occurrence
            of each; the slices must then be hashable (e.g. str or tuple, not list).

    Returns:
        list of slices, in their order of appearance in `seq`.
    """
    __chunks = (seq[__i:__i + size] for __i in range(0, len(seq), size))
    # dict.fromkeys deduplicates while keeping insertion order; a set would
    # return the chunks in an arbitrary, hash-seed dependent order
    return list(__chunks) if repeats else list(dict.fromkeys(__chunks))

def split(text: str, height: int=64, width: int=64, separator: str='\n') -> list:
    """Split a text into at most `height` non-empty rows of at most `width` characters.

    Args:
        text: the input string.
        height: maximum number of rows; a non-positive value means no limit.
        width: maximum number of characters per row; a non-positive value means no limit.
        separator: row delimiter; when empty, the text is cut into chunks of `width` characters.

    Returns:
        list of rows; the height limit is applied before empty rows are removed.
    """
    # break on the separator (typically newlines), or into fixed-size pieces
    if separator:
        __lines = text.split(separator)
    else:
        __lines = chunk(text, width)
    # a non-positive limit means "no limit" (a plain [:width] would drop one character for -1)
    __max_width = width if width > 0 else None
    __max_height = height if height > 0 else None
    # NOTE(review): empty rows are dropped, including blank lines inside the art — confirm intended
    return [__line[:__max_width] for __line in __lines[:__max_height] if __line]

def pad(rows: list, height: int=64, width: int=64, value: str='\x00') -> list:
    """Pad every row to `width` characters, then append blank rows up to `height` rows.

    Nothing is truncated: rows already at least `width` long, or a row count already
    at least `height`, are kept as is. Returns a new list; `rows` is not modified.
    """
    __padded = [__row + value * (width - len(__row)) for __row in rows]
    __padded.extend([value * width] * (height - len(rows)))
    return __padded

In [None]:
# RGB ENCODING #################################################################

def rgb_utf(rows: list) -> np.ndarray:
    """Encode a rectangular block of text as a (height, width, 3) uint8 array.

    Each character is encoded in UTF-32-BE (4 bytes); the 3 trailing bytes become
    the R, G, B channels of its pixel. All rows must have the same length.
    """
    __shape = (len(rows), len(rows[0]), 4)
    __codes = np.array([list(__row.encode('utf-32-be')) for __row in rows], dtype=np.uint8)
    # the leading byte of a UTF-32-BE unit is 0 for every codepoint (max 0x10FFFF), so drop it
    return __codes.reshape(__shape)[:, :, 1:]

# CUSTOM COLOR SCHEMES #########################################################

def mix_channels(channels: np.ndarray) -> np.ndarray:
    """Mix the last channel into the first two: (c0 + c_last, c1 + c_last, c_last), each modulo 256.

    The input should be a wide integer dtype (the caller passes int32) so that the sums do
    not overflow before the modulo; the result keeps the input dtype.
    """
    __last = channels[-1]
    __modulus = np.array([256, 256, 256], dtype=channels.dtype)
    return np.mod([channels[0] + __last, channels[1] + __last, __last], __modulus)

def rgb_mixed(rows: list) -> np.ndarray:
    """Encode text like `rgb_utf`, then mix the channels of every pixel with `mix_channels`.

    Returns a (height, width, 3) uint8 array, the same dtype as `rgb_utf` and `rgb_hilbert`.
    This matters because text_to_image passes the result to pillow.Image.fromarray(mode='RGB'),
    which reads the buffer as 8-bit channels (an int32 array would be decoded as garbage).
    """
    # widen before mixing so that the channel sums cannot overflow uint8
    __codes = rgb_utf(rows).astype(np.int32)
    __mixed = np.apply_along_axis(mix_channels, arr=__codes, axis=-1)
    # mix_channels already reduces modulo 256, so this cast is lossless
    return __mixed.astype(np.uint8)

def rgb_hilbert(rows: list) -> np.ndarray:
    """Encode text as colors by mapping each codepoint to a point of a 3D Hilbert curve.

    Each character becomes the coordinates returned by densecurves.hilbert.point(ord(char),
    order=8, rank=3), stored as the R, G, B bytes of its pixel. NOTE(review): assumes those
    coordinates are 3 ints in [0, 256) — confirm against densecurves.
    All rows must have the same length.
    """
    __shape = (len(rows), len(rows[0]), 3)
    # one 3-coordinate point per character (not the 4 UTF-32 bytes used by rgb_utf)
    __points = [[densecurves.hilbert.point(ord(__char), order=8, rank=3) for __char in __row] for __row in rows]
    return np.array(__points, dtype=np.uint8).reshape(__shape)

In [None]:
# TEXT TO IMAGE ################################################################

def text_to_image(examples: dict, height: int=BASE_CONFIG['height_dim'], width: int=BASE_CONFIG['width_dim'], padding: str='\x00', encode: callable=rgb_utf) -> list:
    """Turn a batch of text samples into pillow RGB images, one character per pixel.

    Args:
        examples: batch dict whose 'content' entry is a list of strings.
        height: number of text rows kept (and padded up to).
        width: number of characters kept per row (and padded up to).
        padding: character used to fill the grid.
        encode: maps the padded rows to an RGB array; it should return uint8 data,
            as expected by pillow.Image.fromarray in 'RGB' mode.

    Returns:
        list of pillow images, in the order of examples['content'].
    """
    __images = []
    for __content in examples['content']:
        # drop ANSI escape codes, then cut into at most height rows of width characters
        __rows = split(clean(__content), height=height, width=width, separator='\n')
        # fill the grid up to height x width with the padding character
        __grid = pad(__rows, height=height, width=width, value=padding)
        __images.append(pillow.Image.fromarray(encode(__grid), mode='RGB'))
    return __images

In [None]:
# IMAGE OPERATIONS #############################################################

# ToTensor scales the uint8 channels to [0, 1] and Normalize(0.5, 0.5) maps them to
# [-1, 1]; denorm applies the inverse mapping
operations = torchvision.transforms.Compose([
    # torchvision.transforms.Resize((BASE_CONFIG['height_dim'], BASE_CONFIG['width_dim'])),
    # no RandomHorizontalFlip: each pixel is a character, so mirroring the image reverses
    # the character order of every row and turns the samples into mirrored text
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),])

In [None]:
# END-TO-END ###################################################################

def preprocess(examples: dict, transforms: callable, height: int=BASE_CONFIG['height_dim'], width: int=BASE_CONFIG['width_dim'], padding: str='\x00', encode: callable=rgb_utf):
    """Encode a batch of texts as images, then apply `transforms` to each image.

    Returns:
        {'images': [transforms(image) for each text in examples['content']]}
    """
    __rendered = text_to_image(examples=examples, height=height, width=width, padding=padding, encode=encode)
    __transformed = list(map(transforms, __rendered))
    return {'images': __transformed,}

def collate_fn(examples: iter):
    """Stack the 'images' tensors of several samples into a single batch.

    Returns:
        {'images': tensor} with a new leading batch axis, contiguous memory layout and float32 dtype.
    """
    __tensors = [__sample['images'] for __sample in examples]
    __batch = torch.stack(__tensors).to(memory_format=torch.contiguous_format)
    return {'images': __batch.float(),}

## Postprocessing Operations

In [None]:
# TENSOR TO IMAGE ##############################################################

def transpose(data: torch.Tensor) -> np.ndarray:
    """Move the channel axis last and convert to numpy: (C, H, W) -> (H, W, C), (B, C, H, W) -> (B, H, W, C).

    The tensor is detached and copied to the CPU first, so tensors produced on a GPU
    or tracked by autograd are accepted too.
    """
    __perm = (0, 2, 3, 1) if data.dim() == 4 else (1, 2, 0)
    return data.detach().cpu().permute(__perm).numpy()

def denorm(data: np.ndarray) -> np.ndarray:
    """Map values from [-1, 1] back to integer bytes in [0, 255] (int32 array).

    Values outside [-1, 1] (e.g. raw model samples) are clipped, so the result can
    always be turned into bytes by `decode`.
    """
    __scaled = np.round(255 * (0.5 * data + 0.5))
    return np.clip(__scaled, 0, 255).astype(np.int32)

In [None]:
# IMAGE TO TEXT ################################################################

def restore(data: np.ndarray) -> np.ndarray:
    """Prepend a zero byte on the last axis: 3-byte RGB pixels -> 4-byte UTF-32-BE code units.

    Inverse of the leading-byte removal done by rgb_utf; keeps the dtype of `data`.
    """
    # one zero per pixel, with a singleton last axis
    __zeros = np.zeros(tuple(data.shape)[:-1] + (1,), dtype=data.dtype)
    # np.concatenate works on every NumPy version (np.concat only exists since NumPy 2.0)
    return np.concatenate([__zeros, data], axis=-1)

def decode(data: np.ndarray) -> str:
    """Decode UTF-32-BE byte arrays back into text, one string per row.

    The last two axes (width, 4 bytes per character) are flattened into a single byte
    sequence per row; the leading axes (e.g. batch, height) are kept, and no newlines
    are inserted. Undecodable 4-byte units become U+FFFD.

    Returns:
        numpy array of strings (shape data.shape[:-2] for inputs with at least 2 axes).
        NOTE: numpy string arrays drop trailing '\\x00' characters.
    """
    def __to_text(__codes: np.ndarray) -> str:
        return bytes(__codes.tolist()).decode('utf-32-be', errors='replace')

    __flat = data.reshape(tuple(data.shape)[:-2] + (math.prod(data.shape[-2:]),))
    return np.apply_along_axis(__to_text, arr=__flat, axis=-1)

In [None]:
# CAST #########################################################################

def unpack(data: np.ndarray) -> list:
    """Placeholder for the CAST step; not implemented.

    NOTE(review): always returns None despite the `list` annotation; nothing in the
    visible notebook calls it.
    """
    return None

## Preprocess The Dataset

In [None]:
# APPLY ########################################################################

# fixed arguments of `preprocess`, taken from the base config
__preprocess_kwargs = {
    'transforms': operations,
    'height': BASE_CONFIG['height_dim'],
    'width': BASE_CONFIG['width_dim'],
    'padding': BASE_CONFIG['padding_str'],
    'encode': rgb_utf,}

__preprocess = functools.partial(preprocess, **__preprocess_kwargs)

# the transform is applied on the fly whenever DATASET is indexed
DATASET.set_transform(__preprocess)

In [None]:
# CHECK ########################################################################

with PlotContext(rows=2, cols=2, zoom=(8, 8), show=False) as (__fig, __axes):
    # round-trip four preprocessed samples back to text and overlay it on their colors
    for __image, __ax in zip(DATASET[136:140]['images'], __axes.flat):
        __colors = denorm(transpose(__image))
        matshow(axes=__ax, data=__colors, text=decode(restore(__colors)))

In [None]:
# COLLATE ######################################################################

# shuffled batches of preprocessed samples; collate_fn stacks them into one float tensor
DATALOADER = torch.utils.data.DataLoader(DATASET, batch_size=ITERATION_CONFIG['batch_size'], shuffle=True, collate_fn=collate_fn)

## Init The Model

In [None]:
# CREATE #######################################################################

# UNet trained to predict the added noise (see train_loop); its architecture comes from MODEL_CONFIG
MODEL = diffusers.UNet2DModel(**MODEL_CONFIG)

In [None]:
# RUN ##########################################################################

# a single preprocessed sample, with a leading batch axis of size 1
__sample = DATASET[0]['images'].unsqueeze(0)

print('Input shape:', __sample.shape)
# shape check only: no need to build the autograd graph
with torch.no_grad():
    print('Output shape:', MODEL(__sample, timestep=0).sample.shape)

## Set Up The Training Environment

In [None]:
# PATHS ########################################################################

# create the working directories up front (no error when they already exist)
for __key in ('cache_dir', 'output_dir', 'logging_dir'):
    os.makedirs(PATH_CONFIG[__key], exist_ok=True)

In [None]:
# OPTIMIZER ####################################################################

# AdamW over every UNet parameter; lr, betas, weight_decay and eps come from OPTIMIZER_CONFIG
OPTIMIZER = torch.optim.AdamW(MODEL.parameters(), **OPTIMIZER_CONFIG)

In [None]:
# SCHEDULERS ###################################################################

# learning rate: linear warmup then cosine decay, step counts from SCHEDULER_CONFIG
LR_SCHEDULER = diffusers.optimization.get_cosine_schedule_with_warmup(OPTIMIZER, **SCHEDULER_CONFIG)

# DDPM noise schedule; its number of training timesteps reuses num_inference_steps
__timestep_num = DIFFUSION_CONFIG['num_inference_steps']
NOISE_SCHEDULER = diffusers.DDPMScheduler(num_train_timesteps=__timestep_num)

In [None]:
# CALLBACK #####################################################################

def evaluate(config: dict, pipeline: callable) -> None:
    """Sample 4 images from `pipeline` and display their RGB encodings on a 2x2 grid.

    Args:
        config: reads 'num_inference_steps' and 'seed'.
        pipeline: a diffusers pipeline (e.g. DDPMPipeline) whose output has an
            `.images` list of pillow images.
    """
    # a private generator makes sampling reproducible without reseeding the global
    # torch RNG; reseeding it would make the training loop replay the same noise,
    # timesteps (and presumably shuffling) after every evaluation
    __generator = torch.Generator().manual_seed(config['seed'])
    # sample from random noise (returns PIL.Image objects)
    __images = pipeline(
        batch_size=4,
        num_inference_steps=config['num_inference_steps'],
        generator=__generator).images
    # display in a subplot
    with PlotContext(rows=2, cols=2, zoom=(4, 4), show=False) as (__fig, __axes):
        for __i, __image in enumerate(__images):
            # extract the byte data of this sample only
            __colors = np.asarray(__image)
            # the text overlay is disabled; decode(restore(__colors)) would recover the characters
            matshow(data=__colors, axes=__axes[__i // 2][__i % 2])

In [None]:
# LOOP #########################################################################

def train_loop(config, model, dataloader, optimizer, lr_scheduler, noise_scheduler):
    """Train the denoising UNet with accelerate, evaluating and saving the pipeline regularly.

    Args:
        config: merged settings; reads 'output_dir', 'logging_dir', 'mixed_precision',
            'gradient_accumulation_steps', 'log_with', 'epoch_num' and
            'checkpoint_epoch_num', plus what `evaluate` reads ('seed', 'num_inference_steps').
        model: the UNet to train; it predicts the noise added to the images.
        dataloader: yields dicts with an 'images' tensor batch.
        optimizer: optimizer over the model parameters.
        lr_scheduler: learning-rate scheduler, stepped after each optimizer step.
        noise_scheduler: DDPM scheduler used to noise the inputs and to build the sampling pipeline.
    """
    # `import tqdm` alone does not load the `tqdm.auto` submodule
    import tqdm.auto
    # init project; NOTE(review): logs presumably end up in output_dir/logging_dir
    # (e.g. output/logs, as in the %tensorboard cell) — confirm
    __project = accelerate.utils.ProjectConfiguration(
        project_dir=config['output_dir'],
        logging_dir=config['output_dir'])
    # init accelerator
    __accelerator = accelerate.Accelerator(
        mixed_precision=config['mixed_precision'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        log_with=config['log_with'],
        project_config=__project)
    # init tensorboard logging
    if __accelerator.is_main_process:
        __accelerator.init_trackers(config['logging_dir'])
    # automatically handle distribution and mixed precision
    __model, __optimizer, __dataloader, __lr_scheduler, __noise_scheduler = __accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler, noise_scheduler)
    # total step, accumulated over all epochs
    __step = 0
    # each epoch trains on the whole dataset
    for __epoch in range(config['epoch_num']):
        # progress inside each epoch
        __pbar = tqdm.auto.tqdm(total=len(__dataloader), disable=not __accelerator.is_local_main_process)
        __pbar.set_description(f'Epoch {__epoch}')
        # iterate over the dataset samples
        for __batch in __dataloader:
            __clean = __batch['images']
            # sample noise directly on the batch device, with the batch dtype
            __noise = torch.randn_like(__clean)
            # sample a different timestep for each image
            __timesteps = torch.randint(0, __noise_scheduler.config.num_train_timesteps, (int(__clean.shape[0]),), device=__clean.device).long()
            # add noise to the clean images according to the noise magnitude at each timestep
            __inputs = __noise_scheduler.add_noise(__clean, __noise, __timesteps)
            # accumulate gradients over several steps
            with __accelerator.accumulate(__model):
                # predict the noise residual
                __pred = __model(__inputs, __timesteps, return_dict=False)[0]
                # compute the loss
                __loss = torch.nn.functional.mse_loss(__pred, __noise)
                # compute the gradients
                __accelerator.backward(__loss)
                # clip gradients to avoid explosion, only once they are fully accumulated
                if __accelerator.sync_gradients:
                    __accelerator.clip_grad_norm_(__model.parameters(), 1.0)
                # apply the gradients
                __optimizer.step()
                # update the learning rate
                __lr_scheduler.step()
                # reset the gradients
                __optimizer.zero_grad()
            # log the progress
            __logs = {'loss': __loss.detach().item(), 'lr': __lr_scheduler.get_last_lr()[0], 'step': __step}
            # display on the progress bar
            __pbar.update(1)
            __pbar.set_postfix(**__logs)
            # save in the logs
            __accelerator.log(__logs, step=__step)
            # update the overall training step
            __step += 1
        # release the progress bar of this epoch
        __pbar.close()

        # evaluate the model regularly
        if __accelerator.is_main_process:
            if (__epoch + 1) % config['checkpoint_epoch_num'] == 0 or __epoch == config['epoch_num'] - 1:
                # only build the pipeline when it is actually used
                __pipeline = diffusers.DDPMPipeline(unet=__accelerator.unwrap_model(__model), scheduler=__noise_scheduler)
                evaluate(config, __pipeline)
                __pipeline.save_pretrained(config['output_dir'])
    # flush and close the trackers (tensorboard writer)
    __accelerator.end_training()

## Train The Diffusion Model

In [None]:
# COLLATE ARGS #################################################################

# merge the settings read by train_loop / evaluate; later dicts override earlier ones
CONFIG = {}
for __part in (RANDOM_CONFIG, PATH_CONFIG, CHECKPOINT_CONFIG, ITERATION_CONFIG, ACCELERATE_CONFIG, DIFFUSION_CONFIG):
    CONFIG.update(__part)

# positional arguments of train_loop
ARGS = (CONFIG, MODEL, DATALOADER, OPTIMIZER, LR_SCHEDULER, NOISE_SCHEDULER)

In [None]:
# RUN ##########################################################################

# run train_loop(*ARGS) with a single process
# NOTE(review): the later sampling cells reuse MODEL, which assumes the single-process launch
# runs in this kernel and trains MODEL in place — confirm
accelerate.notebook_launcher(train_loop, ARGS, num_processes=1)

## Postprocess

## Evaluate The Model

In [None]:
# SAMPLE #######################################################################

# sample_images = sorted(glob.glob(f'{PATH_CONFIG["output_dir"]}/samples/*.png'))
# pillow.Image.open(sample_images[-1])

## Inspect The Logs

In [None]:
# load the TensorBoard notebook extension (provides the %tensorboard magic)
%load_ext tensorboard

In [None]:
# NOTE(review): train_loop presumably writes its events under output/logs (accelerate project
# logging_dir=output_dir, tracker project name 'logs') — keep this path in sync with PATH_CONFIG
%tensorboard --logdir output/logs/

In [None]:
# sample 4 images from the trained MODEL with a fixed seed
__config = {'seed': 42, **DIFFUSION_CONFIG}
__pipeline = diffusers.DDPMPipeline(unet=MODEL, scheduler=NOISE_SCHEDULER)
# private generator: reproducible samples without reseeding the global torch RNG
__images = __pipeline(
    batch_size=4,
    num_inference_steps=__config['num_inference_steps'],
    generator=torch.Generator().manual_seed(__config['seed'])).images
# evaluate({**{'seed': 42}, **DIFFUSION_CONFIG}, __pipeline)

In [None]:
# RGB byte arrays of the sampled pillow images
__data = list(map(np.asarray, __images))
# decode each of them back into text rows
__text = list(map(lambda __d: decode(restore(__d)), __data))

In [None]:
# inspect the raw RGB bytes of the first sampled image
__data[0]