## Install and import dependencies

In [None]:
# Pin TensorFlow and tensorflow-tpu to the same version (2.18.0); the TPU wheel is resolved from the libtpu release index.
!pip install -qq tensorflow==2.18.0
!pip install -qq tensorflow-tpu==2.18.0 --find-links=https://storage.googleapis.com/libtpu-tf-releases/index.html

In [None]:
# Hugging Face `datasets` plus `mlable` and `tokun`; `-U` upgrades them to the latest versions (not pinned).
!pip install -qq -U datasets mlable tokun

In [None]:
# Install tr1cot from the local packages in /content/libs/ (`--no-index` disables lookups on PyPI).
!pip install -qq --no-index -f '/content/libs/' tr1cot

In [None]:
import datetime
import functools
import itertools
import math
import os
import random
import urllib.request

import datasets as hd
import tensorflow as tf

import mlable.data
import mlable.metrics
import mlable.sampling
import mlable.shapes
import mlable.shaping.axes
import mlable.shaping.hilbert
import mlable.text

import tokun.data
import tokun.eval
import tokun.models.klvae
import tokun.models.vqvae
import tokun.pipeline.flat.preprocess
import tokun.pipeline.flat.postprocess
import tokun.pipeline.hilbert.preprocess
import tokun.pipeline.hilbert.postprocess
import tokun.pipeline.square.preprocess
import tokun.pipeline.square.postprocess

import tr1cot.models.cnn
import tr1cot.models.vit
import tr1cot.models.unet

In [None]:
# report the installed TensorFlow version
print(f"Tensorflow version {tf.__version__}")

## Set up the GPU / TPU

In [None]:
# DEBUGGING ####################################################################

# Show full, unfiltered Keras tracebacks (including framework-internal frames) to ease debugging.
tf.keras.config.disable_traceback_filtering()

In [None]:
# MIXED PRECISION ##############################################################

# Keras 'mixed_bfloat16' policy: compute in bfloat16 while keeping variables in float32.
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

In [None]:
# DEVICES ######################################################################

# keep per-op device placement logging off (flip to True to trace placement)
tf.debugging.set_log_device_placement(False)

# logical devices visible to TensorFlow, queried by type in the order CPU, GPU, TPU
CPU, GPU, TPU = (tf.config.list_logical_devices(__kind) for __kind in ('CPU', 'GPU', 'TPU'))

# pick the distribution strategy: TPU first, then every GPU, then the CPUs
if TPU:
    # a TPU needs its cluster connected and its system initialized before the strategy is built
    RESOLVER = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
    tf.config.experimental_connect_to_cluster(RESOLVER)
    tf.tpu.experimental.initialize_tpu_system(RESOLVER)
    DISTRIBUTION_STRATEGY = tf.distribute.TPUStrategy(RESOLVER)
else:
    # mirror across the GPUs when any are present, otherwise across the CPUs
    DISTRIBUTION_STRATEGY = tf.distribute.MirroredStrategy(GPU or CPU)

print('CPU: ', CPU)
print('GPU: ', GPU)
print('TPU: ', TPU)
print('DS: ', DISTRIBUTION_STRATEGY)

## Mode

In [None]:
# TOGGLE #######################################################################

IMPORT = False
DOWNLOAD = False
TRAINING = True
RANDOM = True

DATA = 'square' # 'flat' / 'hilbert' / 'square'
ARCH0 = 'vqvae' # 'klvae' / 'vqvae'
ARCH1 = 'unet' # 'vit' / 'cnn' / 'unet'

## Defining The Metadata

In [None]:
# COMMON PARAMETERS ############################################################

# shape of the byte encoding
BASE_CONFIG = {
    'batch_dim': 32, # B
    'token_dim': 3, # T
    'drop_dim': 1, # D, number of bytes dropped from the encoding
    'input_dim': 256, # U_i (bytes)
    'height_dim': 64, # H
}

# the remaining entries are appended in the same order; width_dim is derived from T and D
# so the byte width of a row stays consistent when the token / drop sizes change
BASE_CONFIG.update({
    'width_dim': 64 * (BASE_CONFIG['token_dim'] + BASE_CONFIG['drop_dim']), # W * (T + D), with W = 64
    'sample_dim': 1024, # S = L * (T + D); NOTE(review): PIPELINE_CONFIG['flat'] multiplies this by (T + D) again
    'order_num': 5, # O => H = W = 2 ** O in the hilbert layout only (independent of height_dim)
    'rank_num': 2, # R
    'start_rate': 0.98,
    'end_rate': 0.02,
    'epochs': 32,
    'steps': 2 ** 7,
    'epsilon': 1e-6,
    'dropout': 0.01,
    'trainable': False, # whether the tokun weights are trainable (False = frozen)
    'bigendian': True,
    'encoding': 'UTF-32-BE',})

In [None]:
# TOKUN PARAMETERS #############################################################

# tokun autoencoder classes by ARCH0; NOTE: the compile cell loads a pretrained tokun from disk instead of instantiating one
TOKUN_FACTORY = {
    'klvae': tokun.models.klvae.KlAutoEncoder,
    'vqvae': tokun.models.vqvae.QuantizedAutoEncoder,}

# hyper-parameters per tokun architecture; embed_dim is also part of the pretrained model's name (META_CONFIG)
TOKUN_CONFIG = {
    'vqvae': {
        'token_dim': BASE_CONFIG['token_dim'],
        'input_dim': BASE_CONFIG['input_dim'],
        'embed_dim': 64,
        'binary_dim': 8,},
    'klvae': {
        # one entry per stage of the encoder — presumably mirrored by the decoder; TODO confirm in tokun.models.klvae
        'block_cfg': [
            {'layer_num': 2, 'channel_dim': 64, 'group_dim': 32, 'head_dim': 32, 'head_num': 8, 'add_sampling': False, 'add_attention': False,},
            {'layer_num': 2, 'channel_dim': 128, 'group_dim': 32, 'head_dim': 32, 'head_num': 16, 'add_sampling': True, 'add_attention': False,},
            {'layer_num': 4, 'channel_dim': 256, 'group_dim': 64, 'head_dim': 64, 'head_num': 16, 'add_sampling': True, 'add_attention': True,},],
        'input_dim': BASE_CONFIG['input_dim'],
        'embed_dim': 64,
        # presumably 8 bits for each byte of a token
        'output_dim': 8 * BASE_CONFIG['token_dim'],
        # presumably the KL weight moves from beta_min to beta_max between step_min and step_max — TODO confirm
        'step_min': 0,
        'step_max':  BASE_CONFIG['steps'],
        'beta_min': 0.0001,
        'beta_max': 0.01,
        'dropout_rate': BASE_CONFIG['dropout'],
        'epsilon_rate': BASE_CONFIG['epsilon'],},}

In [None]:
# MODEL PARAMETERS #############################################################

# diffusion backbones, selected with ARCH1
MODEL_FACTORY = {
    'cnn': tr1cot.models.cnn.CnnDiffusionModel,
    'vit': tr1cot.models.vit.VitDiffusionModel,
    'unet': tr1cot.models.unet.UnetDiffusionModel,}

# hyper-parameters per backbone; start_rate / end_rate (presumably the bounds of the diffusion
# schedule) and the regularization rates all come from BASE_CONFIG so the backbones stay comparable
MODEL_CONFIG = {
    'cnn': {
        'block_num': 4,
        'latent_dim': [64, 128, 256],
        'start_rate': BASE_CONFIG['start_rate'],
        'end_rate': BASE_CONFIG['end_rate'],},
    'vit': {
        'patch_dim': [1, 1, 2, 2, 1, 1],
        'start_rate': BASE_CONFIG['start_rate'],
        'end_rate': BASE_CONFIG['end_rate'],
        # same source as the unet's dropout, so both backbones follow BASE_CONFIG['dropout']
        'dropout_rate': BASE_CONFIG['dropout'],},
    'unet': {
        'channel_dim': [64, 128, 128, 128, 192, 192],
        # None presumably lets the model derive its own defaults — TODO confirm in tr1cot.models.unet
        'group_dim': None,
        'head_dim': None,
        'head_num': None,
        'layer_num': 2,
        # per-stage flags, presumably one entry per channel_dim stage
        'add_attention': [0, 0, 1, 1, 0, 0],
        'add_downsampling': [0, 1, 1, 0, 0, 0],
        'add_upsampling': [0, 0, 0, 1, 1, 0],
        'start_rate': BASE_CONFIG['start_rate'],
        'end_rate': BASE_CONFIG['end_rate'],
        'dropout_rate': BASE_CONFIG['dropout'],
        'epsilon_rate': BASE_CONFIG['epsilon'],},}

In [None]:
# DERIVED MODEL PARAMETERS #####################################################

# run names: tokun is named after its architecture and token x embedding sizes (e.g. 'vqvae.3x64'), tr1cot after its backbone
META_CONFIG = {
    'tokun': f"{ARCH0}.{BASE_CONFIG['token_dim']}x{TOKUN_CONFIG[ARCH0]['embed_dim']}",
    'tr1cot': f'{ARCH1}',}

# remote file (in each project's GitHub repository, named after META_CONFIG) and local path of each model
IO_CONFIG = {
    __project: {
        'url': f'https://github.com/apehex/{__project}/raw/main/models/{META_CONFIG[__project]}.keras',
        'path': f'{__project}.keras',}
    for __project in ('tokun', 'tr1cot')}

In [None]:
# PREPROCESSING ################################################################

# ANSI terminal escape sequences, stripped from the text before encoding
ANSI_REGEX = r'\x1b\[[0-9;]*[mGKHF]'

FILTER_CONFIG = {'pattern': '.*'} # alternative: '.*[Cc]ats.*'

BATCH_CONFIG = dict(
    batch_size=BASE_CONFIG['batch_dim'],
    drop_remainder=True,
    num_parallel_calls=tf.data.AUTOTUNE)

PIPELINE_FACTORY = {
    'flat': tokun.pipeline.flat.preprocess.factory,
    'hilbert': tokun.pipeline.hilbert.preprocess.factory,
    'square': tokun.pipeline.square.preprocess.factory,}

# every layout shares the batching and text-encoding settings; only the geometry in the middle differs
PIPELINE_CONFIG = {
    __layout: {
        'batch_dim': BATCH_CONFIG['batch_size'],
        'token_dim': BASE_CONFIG['token_dim'],
        **__geometry,
        'pattern': ANSI_REGEX,
        'rewrite': '',
        # group-separator character (U+001D)
        'separator': '\u001d',
        'encoding': BASE_CONFIG['encoding'],
        'bigendian': BASE_CONFIG['bigendian'],
        'targets': False,}
    for __layout, __geometry in {
        'flat': {
            'drop_dim': BASE_CONFIG['drop_dim'],
            'sample_dim': (BASE_CONFIG['token_dim'] + BASE_CONFIG['drop_dim']) * BASE_CONFIG['sample_dim'],},
        'hilbert': {
            'order_num': BASE_CONFIG['order_num'],
            'rank_num': BASE_CONFIG['rank_num'],},
        'square': {
            'drop_dim': BASE_CONFIG['drop_dim'],
            'height_dim': BASE_CONFIG['height_dim'],
            'width_dim': BASE_CONFIG['width_dim'],},}.items()}

In [None]:
# POSTPROCESSING ###############################################################

POSTPROCESSING_FACTORY = {
    'flat': tokun.pipeline.flat.postprocess.factory,
    'hilbert': tokun.pipeline.hilbert.postprocess.factory,
    'square': tokun.pipeline.square.postprocess.factory,}

# each layout reuses the relevant entries of its preprocessing config, then adds the decoding options
POSTPROCESSING_CONFIG = {
    __layout: {
        **{__key: PIPELINE_CONFIG[__layout][__key] for __key in __keys},
        'threshold': 0.0,
        'errors': 'replace',}
    for __layout, __keys in (
        ('flat', ('drop_dim', 'encoding', 'bigendian')),
        ('hilbert', ('order_num', 'rank_num', 'encoding', 'bigendian')),
        ('square', ('drop_dim', 'encoding', 'bigendian')),)}

In [None]:
# TRAINING PARAMETERS ##########################################################

# fit() settings; batch_size is None because the datasets are already batched in the preprocessing cell
TRAINING_CONFIG = {
    'epochs': BASE_CONFIG['epochs'],
    'batch_size': None,
    'validation_split': None,
    # validation runs after epochs 1..8 only; NOTE(review): with 32 epochs the last 24 report no val_loss — confirm intended
    'validation_freq': list(range(1, 9)),
    # 'class_weight': {__c: 1. if __c == 0 else 1. for __c in range(256)}, # there are 3 times more 0s than other bytes
    'verbose': 1,}

# AdamW settings; the learning rate is 10x lower when resuming (IMPORT), and the optimizer
# actually receives a cosine schedule built from SCHEDULER_CONFIG in the compile cell
OPTIMIZER_CONFIG = {
    'learning_rate': 0.0001 * (0.1 if IMPORT else 1.0),
    'weight_decay': 0.000001,
    'beta_1': 0.9,
    'beta_2': 0.999,
    'epsilon': 1e-7,
    'clipnorm': 0.1,
    'amsgrad': False,
    'use_ema': False,
    'ema_momentum': 0.99,
    'ema_overwrite_frequency': BASE_CONFIG['steps'] // 8,}
    # 'gradient_accumulation_steps': 2,}

# Keras CosineDecay with linear warmup: starts at 1e-4 x the target rate, rises to OPTIMIZER_CONFIG['learning_rate']
# over the first steps // 8 steps, then cosine-decays over decay_steps = epochs x steps
SCHEDULER_CONFIG = {
    'initial_learning_rate': 0.0001 * OPTIMIZER_CONFIG['learning_rate'],
    'decay_steps': TRAINING_CONFIG['epochs'] * BASE_CONFIG['steps'],
    'alpha': 0.01,
    'name': 'cosine_lr',
    'warmup_target': OPTIMIZER_CONFIG['learning_rate'],
    'warmup_steps': BASE_CONFIG['steps'] // 8,}

# only referenced by the commented-out BinaryCrossentropy loss; the model compiles with MeanAbsoluteError
LOSS_CONFIG = {
    'from_logits': True,
    'label_smoothing': 0.0,
    'axis': -1,
    'reduction': 'sum_over_batch_size',
    'name': 'ce_loss',}

# only referenced by the commented-out BinaryGroupAccuracy metrics
METRICS_CONFIG = {
    'depth': 8,
    'from_logits': True,}

# save the full model to the tr1cot path at the end of every epoch (not only the best one)
CHECKPOINT_CONFIG = {
    'filepath': IO_CONFIG['tr1cot']['path'],
    'monitor': 'val_loss',
    'mode': 'auto',
    'save_freq': 'epoch',
    'save_best_only': False,
    'save_weights_only': False,
    'verbose': 1,}

# one timestamped log directory per run, under .logs/<ARCH1>/
TENSORBOARD_CONFIG = {
    'log_dir': os.path.join('.logs/', META_CONFIG['tr1cot'], datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
    'histogram_freq': 1,
    'embeddings_freq': 0,
    # 'profile_batch': (0, 4),
    'write_graph': True,
    'write_images': True,}

In [None]:
# DATASETS #####################################################################

# Hugging Face datasets keyed by a local name; only the uncommented entries are loaded.
# The loader reads 'path' / 'name' / 'split' (hd.load_dataset) and the preprocessing reads 'features'.
DATASETS_CONFIG = {
    # 'pt-fineweb-edu': {
    #     'path': 'HuggingFaceFW/fineweb-edu',
    #     'name': 'sample-10BT',
    #     'split': 'train',
    #     'features': ['text'],},
    # 'pt-fineweb-kor': {
    #     'path': 'HuggingFaceFW/fineweb-2',
    #     'name': 'kor_Hang',
    #     'split': 'train',
    #     'features': ['text'],},
    # 'pt-fineweb-fin': {
    #     'path': 'HuggingFaceFW/fineweb-2',
    #     'name': 'fin_Latn',
    #     'split': 'train',
    #     'features': ['text'],},
    # 'pt-wikipedia': {
    #     'path': 'wikimedia/wikipedia',
    #     'name': '20231101.en',
    #     'split': 'train',
    #     'features': ['text'],},
    # 'tp-wikipedia-1': {
    #     'path': 'wikimedia/wikipedia',
    #     'name': '20231101.en',
    #     'split': 'train',
    #     'features': ['text'],},
    # 'tp-wikipedia-2': {
    #     'path': 'wikimedia/wikipedia',
    #     'name': '20231101.en',
    #     'split': 'train',
    #     'features': ['text'],},
    # NOTE(review): the 'train' key below is never read by the loader (only path/name/split/features are)
    # 'ft-retro-ascii-art': {
    #     'path': 'jdpressman/retro-ascii-art-v1',
    #     'name': None,
    #     'train': 'train',
    #     'split': 'train',
    #     'features': ['prompt', 'art_aic'],},
    # 'ft-stack-exchange': {
    #     'path': 'Alignment-Lab-AI/Stack-Exchange-April',
    #     'name': None,
    #     'split': 'train',
    #     'features': ['question', 'answer'],},
    # 'ft-math': {
    #     'path': 'HuggingFaceTB/finemath',
    #     'name': 'finemath-3plus',
    #     'split': 'train',
    #     'features': ['text'],},
    # 'cot-text-dolphin': {
    #     'path': 'cognitivecomputations/dolphin-r1',
    #     'name': 'reasoning-deepseek',
    #     'split': 'train',
    #     'features': ['reasoning', 'answer'],},
    # 'cot-text-openthoughts': {
    #     'path': 'open-thoughts/OpenThoughts-114k',
    #     'name': 'default',
    #     'split': 'train',
    #     'features': ['problem', 'solution'],},
    # 'ft-asciiart-asciiart': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'asciiart',
    #     'split': 'train',
    #     'features': ['content'],},
    # 'ft-asciiart-copypasta': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'copypasta',
    #     'split': 'train',
    #     'features': ['content'],},
    # 'ft-asciiart-graffiti': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'graffiti',
    #     'split': 'train',
    #     'features': ['content'],},
    # 'ft-asciiart-images': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'images',
    #     'split': 'train',
    #     'features': ['content'],},
    # the only active dataset: ASCII art, text taken from the 'content' column
    'ft-asciiart-datacompdr': {
        'path': 'apehex/ascii-art-datacompdr-12m',
        'name': 'default',
        'split': 'fixed',
        'features': ['content'],},
    # 'cot-math-numi': {
    #     'path': 'AI-MO/NuminaMath-CoT',
    #     'name': None,
    #     'split': 'train',
    #     'features': ['problem', 'solution'],},
}

In [None]:
# PLOT #########################################################################

## Downloading The Model Weights

In [None]:
# IMPORT #######################################################################

# the pretrained tokun is always fetched; tr1cot only when resuming from the published weights
__downloads = ['tokun'] + (['tr1cot'] if IMPORT and DOWNLOAD else [])
for __project in __downloads:
    urllib.request.urlretrieve(IO_CONFIG[__project]['url'], IO_CONFIG[__project]['path'])

## Downloading The Data

In [None]:
# DOWNLOAD #####################################################################

# load each configured split and expose it as an unbatched, unshuffled tf.data.Dataset
DATASETS = {
    __name: hd.load_dataset(**{__key: __args[__key] for __key in ('path', 'name', 'split')}).to_tf_dataset(batch_size=None, shuffle=False)
    for __name, __args in DATASETS_CONFIG.items()}

In [None]:
# STATS ########################################################################

# summary statistics over the first 2048 samples of each dataset's text features
STATS = {
    __name: mlable.data.stats(dataset=__dataset, features=DATASETS_CONFIG[__name]['features'], count=2048)
    for __name, __dataset in DATASETS.items()}

print(STATS)

In [None]:
# VIZ ##########################################################################

# __i = iter(DATASETS['ft-asciiart-datacompdr'])

In [None]:
# __s = next(__i)
# print(__s['caption'].numpy().decode('utf-8'), __s['labels'].numpy().decode('utf-8'), len(__s['content'].numpy().decode('utf-8')))
# print(__s['content'].numpy().decode('utf-8'))

## Preprocess

In [None]:
# ITERATE ######################################################################

# __filter = lambda __s: True
# __filter = lambda __s: tf.strings.regex_full_match(__s['labels'], **FILTER_CONFIG)

# NOTE: DATASETS is rebound entry by entry, so re-running this cell without re-running the
# download cell would batch and preprocess the data twice
for __name, __dataset in DATASETS.items():
    # preprocessing fn specialized on this dataset's text features
    __preprocess = PIPELINE_FACTORY[DATA](features=DATASETS_CONFIG[__name]['features'], **PIPELINE_CONFIG[DATA])
    # batch first, then preprocess whole batches in parallel
    DATASETS[__name] = __dataset.batch(**BATCH_CONFIG).map(__preprocess, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# POSTPROCESS ##################################################################

# three decoders over the same settings: greedy, sampling (temp / top-p / top-k), and a 0.5 threshold
__postprocess_greedy = POSTPROCESSING_FACTORY[DATA](**POSTPROCESSING_CONFIG[DATA])
__postprocess_sampler = POSTPROCESSING_FACTORY[DATA](**POSTPROCESSING_CONFIG[DATA], temp=1.0, topp=0.9, topk=4)
__postprocess_probs = POSTPROCESSING_FACTORY[DATA](**{**POSTPROCESSING_CONFIG[DATA], 'threshold': 0.5})

In [None]:
# CONCATENATE ##################################################################

# concatenate every dataset except 'random', in DATASETS_CONFIG (insertion) order; a list rather
# than a set keeps this order reproducible across sessions (str hashing is randomized per process),
# which matters because the first concatenated batch becomes the test batch
DATASET_KEYS = [__n for __n in DATASETS if __n != 'random']

if not DATASET_KEYS:
    raise ValueError('No dataset to concatenate: enable at least one entry in DATASETS_CONFIG.')

DATASET_ALL = functools.reduce(lambda __l, __r: __l.concatenate(__r), [DATASETS[__n] for __n in DATASET_KEYS])
DATASET_DIM = DATASET_ALL.cardinality().numpy()

# hold out the first batch for validation, train on the next BASE_CONFIG['steps'] batches
DATASET_TEST = DATASET_ALL.take(1)
DATASET_TRAIN = DATASET_ALL.skip(1).take(BASE_CONFIG['steps'])

# RANDOM_TEST = DATASETS['random'].take(128)
# RANDOM_TRAIN = DATASETS['random'].skip(128)

In [None]:
# INSPECT ######################################################################

# one batch to build the models with, plus a float32 zero tensor shaped from __X.shape filtered on axis 0
# (presumably one value per sample, the second model input) — TODO confirm mlable.shapes.filter semantics
__X = next(iter(DATASET_TRAIN.take(1)))
__V = tf.zeros(mlable.shapes.filter(__X.shape, axes=[0]), dtype=tf.float32)

# element structure of each split, then its number of batches
for __split in (DATASET_TRAIN, DATASET_TEST):
    print(__split.element_spec)

print(f'train: {DATASET_TRAIN.cardinality().numpy():,}')
print(f'test:  {DATASET_TEST.cardinality().numpy():,}')

## Initialize the model

In [None]:
# COMPILE ######################################################################

with DISTRIBUTION_STRATEGY.scope():
    # metrics
    # byte_accuracy = mlable.metrics.BinaryGroupAccuracy(group=1, name='byte_accuracy', **METRICS_CONFIG)
    # token_accuracy = mlable.metrics.BinaryGroupAccuracy(group=BASE_CONFIG['token_dim'], name='token_accuracy', **METRICS_CONFIG)
    # cosine LR: the schedule goes into a copy of the optimizer settings, so OPTIMIZER_CONFIG
    # (owned by the config cell) keeps its plain float learning rate
    __optimizer_config = {
        **OPTIMIZER_CONFIG,
        'learning_rate': tf.keras.optimizers.schedules.CosineDecay(**SCHEDULER_CONFIG),}
    # weights: resume from disk when requested and available, otherwise start from a fresh model
    if IMPORT and os.path.isfile(IO_CONFIG['tr1cot']['path']):
        MODEL = tf.keras.models.load_model(IO_CONFIG['tr1cot']['path'], compile=False)
    else:
        MODEL = MODEL_FACTORY[ARCH1](**MODEL_CONFIG[ARCH1])
    # vq-vae: pretrained tokun, trainable or frozen according to BASE_CONFIG['trainable']
    TOKUN = tf.keras.models.load_model(IO_CONFIG['tokun']['path'], compile=False)
    TOKUN.trainable = BASE_CONFIG['trainable']
    MODEL.set_vae(TOKUN)
    # compile with an L1 (mean absolute error) loss; the accuracy metrics above are disabled
    MODEL.compile(
        optimizer=tf.keras.optimizers.AdamW(**__optimizer_config),
        loss=tf.keras.losses.MeanAbsoluteError(reduction='sum_over_batch_size'), # tf.keras.losses.BinaryCrossentropy(**LOSS_CONFIG),
        weighted_metrics=[]) # byte_accuracy, token_accuracy
    # build tokun
    TOKUN(__X, training=False)
    # encode inputs into tokun's latent space
    __L = TOKUN.encode(__X, training=False)
    # build the model in the latent space, then run the metric and loss computations once on that batch
    MODEL((__L, __V), training=False)
    MODEL.compute_metrics((__L, __V), __L, __L)
    MODEL.compute_loss((__L, __V), __L, __L)
    # normalize the latent space
    # MODEL.adapt(DATASET_TRAIN)

In [None]:
# INSPECT ######################################################################

# layer-by-layer summary of the diffusion model built in the compile cell
MODEL.summary()

In [None]:
# sanity check on the encoded batch: metrics first, then loss, with the latents as both y and y_pred
for __compute in (MODEL.compute_metrics, MODEL.compute_loss):
    print(__compute((__L, __V), __L, __L))

In [None]:
# DATAVIZ ######################################################################

def unpack(data: tf.Tensor) -> list:
    """Turn a 2D tensor of byte strings into one text per row.

    The byte strings of each row are joined with newlines and decoded as UTF-8,
    replacing invalid sequences instead of raising.
    """
    texts = []
    for row in data.numpy().tolist():
        texts.append(b'\n'.join(row).decode('utf-8', errors='replace'))
    return texts

def generate_samples(model: tf.keras.models.Model=MODEL, sample_num: int=1, step_num: int=8, eta_rate: float=0.1) -> list:
    """Generate texts with the diffusion model and decode them with the sampling postprocessor.

    Args:
        model: the tr1cot model; it must expose `generate_samples`. The default is the
            MODEL bound when this cell ran.
        sample_num: number of samples to draw.
        step_num: number of generation steps, forwarded as `total_step`.
        eta_rate: forwarded unchanged to `model.generate_samples`.

    Returns:
        A list with one decoded text per generated sample (rows are unpacked with `unpack`).
    """
    __logits = model.generate_samples(sample_num=sample_num, total_step=step_num, eta_rate=eta_rate)
    __text = __postprocess_sampler(__logits)
    # return mlable.text.unpack(__text) # 1D
    return unpack(__text) # 2D

def print_sample(epoch: int=None, logs: dict=None, step_num: int=32, model: tf.keras.models.Model=MODEL) -> None:
    """Generate a single sample and print it.

    Shaped as a Keras `LambdaCallback(on_epoch_end=...)` hook: `epoch` and `logs`
    are accepted for that protocol and ignored.
    """
    samples = generate_samples(sample_num=1, step_num=step_num, model=model)
    print(samples[0])

## Train

In [None]:
# TRAIN ########################################################################

if TRAINING:
    with DISTRIBUTION_STRATEGY.scope():
        # callbacks: checkpoint each epoch, TensorBoard logs, abort on NaN loss, print a generated sample after each epoch
        cp_callback = tf.keras.callbacks.ModelCheckpoint(**CHECKPOINT_CONFIG)
        tb_callback = tf.keras.callbacks.TensorBoard(**TENSORBOARD_CONFIG)
        tn_callback = tf.keras.callbacks.TerminateOnNaN()
        gs_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=print_sample)
        # prefetch both pipelines so input preparation overlaps with training
        __train_data, __valid_data = (__d.prefetch(tf.data.AUTOTUNE) for __d in (DATASET_TRAIN, DATASET_TEST))
        # fit model
        TRAINING_HISTORY = MODEL.fit(
            x=__train_data,
            validation_data=__valid_data,
            callbacks=[cp_callback, tb_callback, tn_callback, gs_callback],
            **TRAINING_CONFIG)

## Dataviz

In [None]:
# DATASET SAMPLES ##############################################################

# first training batch again (DATASET_TRAIN is built without shuffling, so presumably the same batch as before)
__X = next(iter(DATASET_TRAIN.take(1)))
# NOTE(review): `logits=True` presumably asks tokun for raw logits — confirm against tokun's call signature
__Y = TOKUN(__X, logits=True)

In [None]:
# decode tokun's reconstruction with the sampling postprocessor, one text per row
__O_T = unpack(__postprocess_sampler(__Y))

In [None]:
# index of the reconstructed sample to display (presumably < batch_dim)
__i = 0
print(__O_T[__i])

In [None]:
# GENERATE #####################################################################

# draw 4 samples with 256 generation steps each
__s = generate_samples(sample_num=4, step_num=256, model=MODEL)

In [None]:
# index of the generated sample to display (4 were generated in the previous cell)
__i = 2
print(__s[__i])

In [None]:
# load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
# browse every run logged under .logs (the root of TENSORBOARD_CONFIG['log_dir'])
%tensorboard --logdir .logs