## Import deps

In [None]:
# the %pip magic installs in the environment of the running kernel, which `!pip` does not guarantee
%pip install -U datasets mlable tokun llaminate

In [None]:
import datetime
import functools
import itertools
import math
import os
import random
import urllib.request

import datasets as hd
import tensorflow as tf

import mlable.data
import mlable.metrics
import mlable.shaping
import mlable.text

import llaminate.models
import llaminate.pipeline
import llaminate.utils

In [None]:
# log the version of Tensorflow in use
print(f"Tensorflow version {tf.__version__}")

## Setup the GPU / TPU

In [None]:
# MIXED PRECISION #############################################################

# set the global dtype policy from an explicit policy object rather than its name
tf.keras.mixed_precision.set_global_policy(tf.keras.mixed_precision.Policy('mixed_bfloat16'))

In [None]:
# DEVICES #####################################################################

tf.debugging.set_log_device_placement(False)

# list the logical devices, by kind
CPU, GPU, TPU = (tf.config.list_logical_devices(__kind) for __kind in ('CPU', 'GPU', 'TPU'))

if not TPU:
    # mirror over the GPUs when there are any, and fall back on the CPUs otherwise
    DISTRIBUTION_STRATEGY = tf.distribute.MirroredStrategy(GPU or CPU)
else:
    # connect to the TPU cluster and initialize it before creating the strategy
    RESOLVER = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(RESOLVER)
    tf.tpu.experimental.initialize_tpu_system(RESOLVER)
    DISTRIBUTION_STRATEGY = tf.distribute.TPUStrategy(RESOLVER)

print(DISTRIBUTION_STRATEGY)

## Mode

In [None]:
# TOGGLE ######################################################################

# IMPORT: load the saved model and lower the learning rate
# DOWNLOAD: fetch the saved model first (only when IMPORT is also set)
# TRAINING: run the fitting cell
IMPORT, DOWNLOAD, TRAINING = False, False, True

## Defining The Metadata

In [None]:
# COMMON #######################################################################

# dimensions shared by the model and the data pipeline
BASE_CONFIG = dict(
    batch_dim=32,
    token_dim=4, # in bytes
    sample_dim=4 * 128, # in bytes
    drop_dim=0, # in bytes
    encoding='UTF-8',)

In [None]:
# MODEL PARAMETERS ############################################################

# keyword arguments of the transformer constructor
LLAMINATE_CONFIG = dict(
    layer_num=12,
    head_num=4,
    token_dim=BASE_CONFIG['token_dim'],
    embed_dim=64,
    head_dim=64 // 4, # numerically embed_dim // head_num
    hidden_dim=64 * 4, # numerically 4 * embed_dim
    epsilon=1e-6,
    dropout=0.0,)

In [None]:
# DERIVED PARAMETERS ##########################################################

META_CONFIG = {
    # tag made of the layer count, token dimension and embedding dimension, joined by "x"
    'version': 'x'.join(str(LLAMINATE_CONFIG[__k]) for __k in ('layer_num', 'token_dim', 'embed_dim')),
    # local file of the saved model
    'path': 'llaminate.keras',
    # remote location of the saved model, only read when both IMPORT and DOWNLOAD are set
    'url': '',}

In [None]:
# PIPELINE ####################################################################

# arguments of `tf.data.Dataset.batch`
BATCH_CONFIG = dict(
    batch_size=BASE_CONFIG['batch_dim'],
    drop_remainder=True,
    num_parallel_calls=tf.data.AUTOTUNE,)

# shuffling buffer, sized as a multiple of the batch size
SHUFFLE_CONFIG = dict(
    buffer_size=4096 * BATCH_CONFIG['batch_size'],)

# arguments of the preprocessing factory, the dimensions are taken from the common config
PIPELINE_CONFIG = {
    **{__k: BASE_CONFIG[__k] for __k in ('batch_dim', 'sample_dim', 'token_dim', 'drop_dim')},
    'data_weight': 1.0,
    'padding_weight': 0.0001,
    'separator': '\x1d',
    'encoding': BASE_CONFIG['encoding'],}

In [None]:
# TRAINING PARAMETERS #########################################################

# arguments of `fit`, on top of the data and the callbacks
TRAINING_CONFIG = dict(
    epochs=8,
    batch_size=None,
    validation_split=None,
    validation_freq=list(range(1, 9)),
    # 'class_weight': {__c: 0.2 if __c in [0, 10] else 1. for __c in range(PIPELINE_CONFIG['channel_dim'])},
    verbose=1,)

# arguments of the AdamW optimizer
OPTIMIZER_CONFIG = dict(
    # the rate is divided by 10 when resuming from imported weights
    learning_rate=0.001 * (0.1 if IMPORT else 1.0),
    weight_decay=0.001,
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-7,
    clipnorm=1.0,
    amsgrad=False,
    use_ema=False,
    ema_momentum=0.99,
    ema_overwrite_frequency=1024,)

# arguments of the cosine decay schedule
# there is a single definition: the former first one (decay over 16384 * 8 steps) was overwritten right away
SCHEDULER_CONFIG = {
    'initial_learning_rate': OPTIMIZER_CONFIG['learning_rate'],
    # scales with the number of epochs -- presumably 8 * 1024 steps per epoch, TODO confirm
    'decay_steps': TRAINING_CONFIG['epochs'] * 8 * 1024,
    'alpha': 0.01,
    'name': 'cosine_lr',
    'warmup_target': None,
    'warmup_steps': 0,}

# arguments shared by the group accuracy metrics
METRICS_CONFIG = dict(
    depth=-1,
    axis=-1,
    dtype=tf.uint8,)

# arguments of the binary cross-entropy loss
LOSS_CONFIG = dict(
    from_logits=True,
    label_smoothing=0.0,
    axis=-1,
    reduction='sum_over_batch_size',
    name='loss',)

# arguments of the checkpoint callback: the whole model is saved after every epoch
CHECKPOINT_CONFIG = dict(
    filepath=META_CONFIG['path'],
    monitor='val_loss',
    mode='auto',
    save_freq='epoch',
    save_best_only=False,
    save_weights_only=False,
    verbose=1,)

# arguments of the Tensorboard callback: one log directory per model version and per run
TENSORBOARD_CONFIG = dict(
    log_dir=os.path.join('.logs/', META_CONFIG['version'], datetime.datetime.now().strftime("%Y%m%d-%H%M%S")),
    histogram_freq=1,
    embeddings_freq=4,
    profile_batch=(4, 8),
    write_graph=False,
    write_images=True,)

In [None]:
# DATASETS ####################################################################

# for each dataset: the HuggingFace path and config name, the split slices and the features to read
# the disabled entries are kept below as a catalog, uncomment them to add the data to the mix
DATASETS_CONFIG = {
    'pt-fineweb-edu': {
        'path': 'HuggingFaceFW/fineweb-edu',
        'name': 'sample-10BT',
        # ten consecutive slices of 10% each
        'splits': ['train[{}%:{}%]'.format(__p, __p + 10) for __p in range(0, 100, 10)],
        'features': ['text'],},
    # 'pt-fineweb-kor': {
    #     'path': 'HuggingFaceFW/fineweb-2',
    #     'name': 'kor_Hang',
    #     'splits': [f'train[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['text'],},
    # 'pt-fineweb-fin': {
    #     'path': 'HuggingFaceFW/fineweb-2',
    #     'name': 'fin_Latn',
    #     'splits': [f'train[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['text'],},
    # 'pt-wikipedia': {
    #     'path': 'wikimedia/wikipedia',
    #     'name': '20231101.en',
    #     'splits': [f'train[{__p}%:{__p + 9}%]' for __p in range(0, 80, 8)],
    #     'features': ['text'],},
    # 'tp-wikipedia-1': {
    #     'path': 'wikimedia/wikipedia',
    #     'name': '20231101.en',
    #     'splits': [f'train[{__p}%:{__p + 1}%]' for __p in range(80, 90, 1)],
    #     'features': ['text'],},
    # 'tp-wikipedia-2': {
    #     'path': 'wikimedia/wikipedia',
    #     'name': '20231101.en',
    #     'splits': [f'train[{__p}%:{__p + 1}%]' for __p in range(90, 100, 1)],
    #     'features': ['text'],},
    # 'ft-retro-ascii-art': {
    #     'path': 'jdpressman/retro-ascii-art-v1',
    #     'name': None,
    #     'train': 'train',
    #     'splits': [f'train[{__p}%:{__p + 10}%]+validation[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['prompt', 'art_aic'],},
    # 'ft-stack-exchange': {
    #     'path': 'Alignment-Lab-AI/Stack-Exchange-April',
    #     'name': None,
    #     'splits': [f'train[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['question', 'answer'],},
    # 'ft-math': {
    #     'path': 'HuggingFaceTB/finemath',
    #     'name': 'finemath-3plus',
    #     'splits': [f'train[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['text'],},
    # 'cot-text-dolphin': {
    #     'path': 'cognitivecomputations/dolphin-r1',
    #     'name': 'reasoning-deepseek',
    #     'splits': [f'train[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['reasoning', 'answer'],},
    # 'cot-text-openthoughts': {
    #     'path': 'open-thoughts/OpenThoughts-114k',
    #     'name': 'default',
    #     'splits': [f'train[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['problem', 'solution'],},
    # 'ft-asciiart-asciiart': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'asciiart',
    #     'splits': [f'train[{__p}%:{__p + 9}%]' for __p in range(0, 100, 10)],
    #     'features': ['content'],},
    # 'ft-asciiart-copypasta': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'copypasta',
    #     'splits': [f'train[{__p}%:{__p + 9}%]' for __p in range(0, 100, 10)],
    #     'features': ['content'],},
    # 'ft-asciiart-graffiti': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'graffiti',
    #     'splits': [f'train[{__p}%:{__p + 9}%]' for __p in range(0, 100, 10)],
    #     'features': ['content'],},
    # 'ft-asciiart-images': {
    #     'path': 'apehex/ascii-art',
    #     'name': 'images',
    #     'splits': [f'train[{__p}%:{__p + 9}%]' for __p in range(0, 100, 10)],
    #     'features': ['content'],},
    # 'ft-asciiart-datacompdr': {
    #     'path': 'apehex/ascii-art-datacompdr-12m',
    #     'name': 'default',
    #     'splits': [f'train[{__p}%:{__p + 9}%]' for __p in range(0, 100, 10)],
    #     'features': ['content'],},
    # 'cot-math-numi': {
    #     'path': 'AI-MO/NuminaMath-CoT',
    #     'name': None,
    #     'splits': [f'train[{__p}%:{__p + 10}%]' for __p in range(0, 100, 10)],
    #     'features': ['problem', 'solution'],},
}

In [None]:
# create the log directory of this run, if it is missing
os.makedirs(name=TENSORBOARD_CONFIG['log_dir'], exist_ok=True)

## Loading The Weights

In [None]:
# DOWNLOAD WEIGHTS ############################################################

# fetch the saved model only when it is going to be imported
if IMPORT and DOWNLOAD:
    urllib.request.urlretrieve(url=META_CONFIG['url'], filename=META_CONFIG['path'])

## Loading The Data

In [None]:
# DOWNLOAD ####################################################################

# one list per configured dataset, holding one shuffled and unbatched tf dataset per split
DATASETS = {
    __key: [
        hd.load_dataset(split=__split, path=__cfg['path'], name=__cfg['name']).to_tf_dataset(batch_size=None, shuffle=True)
        for __split in __cfg['splits']]
    for __key, __cfg in DATASETS_CONFIG.items()}

## Checking The Data

In [None]:
# STATS #######################################################################

# statistics on the first split of each dataset, computed with a count of 2048
STATS = {
    __key: mlable.data.stats(dataset=__splits[0], features=DATASETS_CONFIG[__key]['features'], count=2048)
    for __key, __splits in DATASETS.items()}

print(STATS)

In [None]:
# display one element of the first split, as it is at this point of the notebook (before the preprocessing cell)
__b = iter(DATASETS['pt-fineweb-edu'][0])
next(__b)

## Preprocess

In [None]:
# ITERATE #####################################################################

# NOTE: the lists in DATASETS are overwritten in place, so running this cell twice batches the data twice
for __name, __splits in DATASETS.items():
    # specialized preprocessing fn
    __preprocess = llaminate.pipeline.preprocess_factory(
        features=DATASETS_CONFIG[__name]['features'],
        **PIPELINE_CONFIG)
    # batch each split and then map the preprocessing on the batches
    __splits[:] = [
        __d.batch(**BATCH_CONFIG).map(__preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        for __d in __splits]

In [None]:
# CONCATENATE #################################################################

# names of the datasets that are merged in the global train / test sets
DATASET_KEYS = set(DATASETS.keys()) - {'ft-retro-ascii-art'}

# the last split of each dataset is held out for testing, all the others are used for training
FINE_TRAIN = functools.reduce(lambda __l, __r: __l.concatenate(__r), DATASETS['pt-fineweb-edu'][:-1])
FINE_TEST = DATASETS['pt-fineweb-edu'][-1]

# the iteration order of a set of strings changes from one kernel to the next (hash randomization)
# so the names are sorted to concatenate the datasets in the same order on every run
DATASET_TRAIN = functools.reduce(lambda __l, __r: __l.concatenate(__r), [__d for __n in sorted(DATASET_KEYS) for __d in DATASETS[__n][:-1]]) # - {'pt-wikipedia'}
DATASET_TEST = functools.reduce(lambda __l, __r: __l.concatenate(__r), [DATASETS[__n][-1] for __n in sorted(DATASET_KEYS)]) # - {'pt-wikipedia'}

In [None]:
# CHECK DATASET ###############################################################

# a single batch (inputs, targets, weights), which is also used to build the model later on
__X, __T, __W = next(iter(FINE_TRAIN.take(1)))

print(FINE_TRAIN.element_spec)
print(FINE_TEST.element_spec)

print(DATASET_TRAIN.element_spec)
print(DATASET_TEST.element_spec)

# the datasets are batched before this point, so the cardinality counts batches and not individual samples
# NOTE: a negative number is a special value of `tf.data` (infinite / unknown cardinality), not a count
print('fine: {:,} / {:,} batches'.format(FINE_TRAIN.cardinality().numpy(), FINE_TEST.cardinality().numpy()))
print('total: {:,} / {:,} batches'.format(DATASET_TRAIN.cardinality().numpy(), DATASET_TEST.cardinality().numpy()))

## Initializing The Model

In [None]:
# OVERALL SCOPE ###############################################################

with DISTRIBUTION_STRATEGY.scope():
    # COSINE LR ###############################################################
    # the constant rate of the optimizer config is replaced by the schedule
    cosine_lr = tf.keras.optimizers.schedules.CosineDecay(**SCHEDULER_CONFIG)
    OPTIMIZER_CONFIG['learning_rate'] = cosine_lr

    # METRICS #################################################################
    # the same accuracy, with groups of 1, 4 and `token_dim`
    byte_accuracy, character_accuracy, token_accuracy = (
        mlable.metrics.BinaryGroupAccuracy(group=__group, name=__label, **METRICS_CONFIG)
        for __group, __label in (
            (1, 'byte_accuracy'),
            (4, 'character_accuracy'),
            (PIPELINE_CONFIG['token_dim'], 'token_accuracy')))

    # WEIGHTS #################################################################
    # start from a fresh model, which is replaced by the saved one when importing
    LLAMINATE = llaminate.models.Transformer(**LLAMINATE_CONFIG)
    if IMPORT and os.path.isfile(META_CONFIG['path']):
        LLAMINATE = tf.keras.models.load_model(META_CONFIG['path'], compile=False)

    # COMPILE #################################################################
    LLAMINATE.compile(
        weighted_metrics=[byte_accuracy, character_accuracy, token_accuracy],
        loss=tf.keras.losses.BinaryCrossentropy(**LOSS_CONFIG),
        optimizer=tf.keras.optimizers.AdamW(**OPTIMIZER_CONFIG))

    # BUILD ###################################################################
    # run the model, the metrics and the loss once on the sample batch
    # the targets stand in for the predictions in the last two calls
    LLAMINATE(__X, training=False)
    LLAMINATE.compute_metrics(__X, __T, __T, __W)
    LLAMINATE.compute_loss(__X, __T, __T, __W)

In [None]:
# print the summary of the model, which was built on a sample batch in the previous cell
LLAMINATE.summary()

## Train

In [None]:
# TRAIN #######################################################################

if TRAINING:
    with DISTRIBUTION_STRATEGY.scope():
        # callbacks
        cp_callback = tf.keras.callbacks.ModelCheckpoint(**CHECKPOINT_CONFIG)
        tb_callback = tf.keras.callbacks.TensorBoard(**TENSORBOARD_CONFIG)
        tn_callback = tf.keras.callbacks.TerminateOnNaN()
        # data: first split for training, 128 batches of the last split for validation
        # NOTE(review): FINE_TRAIN / DATASET_TRAIN are built earlier but not used here -- confirm that fitting on a single split is intended
        __train_data = DATASETS['pt-fineweb-edu'][0].prefetch(tf.data.AUTOTUNE)
        __valid_data = DATASETS['pt-fineweb-edu'][-1].take(128).prefetch(tf.data.AUTOTUNE)
        # model fitting
        TRAINING_HISTORY = LLAMINATE.fit(
            x=__train_data,
            validation_data=__valid_data,
            callbacks=[cp_callback, tb_callback, tn_callback],
            **TRAINING_CONFIG)

## Dataviz

In [None]:
# iterator on the last split, which is the one held out for validation
__i = iter(DATASETS['pt-fineweb-edu'][-1])

In [None]:
# take the next batch (inputs, targets, weights) and run the model on the inputs
__x, __t, __w = next(__i)
__y = LLAMINATE(__x)

In [None]:
# postprocess the outputs of the model and display the first 4 entries
__s = llaminate.pipeline.postprocess(__y)
__s[:4]

In [None]:
# same postprocessing on the targets, to compare with the predictions displayed above
__s = llaminate.pipeline.postprocess(__t)
__s[:4]

In [None]:
# DATA ########################################################################

# handwritten text samples: a Korean encyclopedia extract and three Python class definitions
# (the last two hold the same code, indented with 4 and 2 spaces respectively)
# NOTE(review): SAMPLES is not read by any of the following cells of this notebook
SAMPLES = [
    """위키백과, 우리 모두의 백과사전.\nt-분포 확률적 임베딩(t-SNE)은 데이터의 차원 축소에 사용되는 기계 학습 알고리즘 중 하나로, 2002년 샘 로이스Sam Rowise와 제프리 힌튼에 의해 개발되었다.[1] t-SNE는 비선형 차원 축소 기법으로, 고차원 데이터를 특히 2, 3차원 등으로 줄여 가시화하는데에 유용하게 사용된다. 구체적으로 t-SNE는 비슷한 데이터는 근접한 2, 3차원의 지점으로, 다른 데이터는 멀리 떨어진 지점으로 맵핑한다.""",
    """class Encoder(tf.keras.models.Model):\n    def __init__(self, depth: int, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, attention: bool=False, **kwargs) -> None:\n        super(Encoder, self).__init__(**kwargs)\n        self._encoder = tf.keras.Sequential([\n            tf.keras.Input(shape=(encoding_dim,), batch_size=batch_dim, name='input'), # (B * G ^ D, U)\n            tf.keras.layers.Dense(units=embedding_dim, activation=None, use_bias=False, kernel_initializer='glorot_uniform', bias_initializer=None, name='embed-1'),] # (B * G ^ D, U) => (B * G ^ D, E)\n            + [tokun.layers.TokenizeBlock(left_axis=-2, right_axis=-1, token_dim=token_dim, latent_dim=latent_dim, attention=attention, name='tokenize' + (__i + 1) * '-4') for __i in range(depth)]) # (B * G ^ i, E) => (B * G ^ (i-1), E)\n\n    def call(self, x: tf.Tensor) -> tf.Tensor:\n        return self._encoder(x)\n""",
    """class AutoEncoder(tf.keras.models.Model):\n    def __init__(self, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, **kwargs) -> None:\n        super(AutoEncoder, self).__init__(**kwargs)\n        self._encoder = Encoder(token_dim=token_dim, encoding_dim=encoding_dim, embedding_dim=embedding_dim, latent_dim=latent_dim, batch_dim=batch_dim)\n        self._decoder = Decoder(token_dim=token_dim, encoding_dim=encoding_dim, embedding_dim=embedding_dim, latent_dim=latent_dim, batch_dim=batch_dim)\n\n    def call(self, x: tf.Tensor) -> tf.Tensor:\n        return self._decoder(self._encoder(x))""",
    """class AutoEncoder(tf.keras.models.Model):\n  def __init__(self, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, **kwargs) -> None:\n    super(AutoEncoder, self).__init__(**kwargs)\n    self._encoder = Encoder(token_dim=token_dim, encoding_dim=encoding_dim, embedding_dim=embedding_dim, latent_dim=latent_dim, batch_dim=batch_dim)\n    self._decoder = Decoder(token_dim=token_dim, encoding_dim=encoding_dim, embedding_dim=embedding_dim, latent_dim=latent_dim, batch_dim=batch_dim)\n\n  def call(self, x: tf.Tensor) -> tf.Tensor:\n    return self._decoder(self._encoder(x))"""]

In [None]:
# CACHE #######################################################################

# the dimensions come from the configuration: the `N_*` constants used before are not defined in this notebook
__cache = llaminate.utils.create_cache(
    batch_dim=BASE_CONFIG['batch_dim'],
    # NOTE(review): assumes that the cache covers a whole sample, counted in tokens -- TODO confirm
    cache_dim=BASE_CONFIG['sample_dim'] // BASE_CONFIG['token_dim'],
    head_dim=LLAMINATE_CONFIG['head_dim'],
    layer_num=LLAMINATE_CONFIG['layer_num'],
    head_num=LLAMINATE_CONFIG['head_num'])
__step = 4

In [None]:
# PREPROCESS ##################################################################

# handwritten prompt, encoded with the same sample dimension as the training data
__prompt = """Skynet is an artificial neural network-based conscious group mind and artificial general superintelligence system that serves as the antagonistic force of the Terminator franchise."""
__inputs = mlable.text.preprocess(
    text=__prompt,
    token_size=PIPELINE_CONFIG['sample_dim'],
    expand_dims=[1],
    output_dtype=tf.uint8)

In [None]:
# split axis -2 by a factor `token_dim`, into a new last axis
# NOTE(review): presumably this groups the bytes by token, inferred from the arguments -- confirm against mlable
__inputs = mlable.shaping.divide(__inputs, input_axis=-2, output_axis=-1, factor=PIPELINE_CONFIG['token_dim'], insert=True)

In [None]:
# PREDICT #####################################################################

# run the model on the encoded prompt and postprocess its outputs
__predictions = LLAMINATE(__inputs)
__outputs = llaminate.pipeline.postprocess(__predictions)

In [None]:
# display the prompt cut in chunks of size 4 -- NOTE(review): 4 looks like `token_dim`, confirm before replacing the literal
mlable.text.chunk(__prompt, size=4)

In [None]:
# display the postprocessed predictions for the prompt
__outputs

In [None]:
# iterate on a dataset that is actually loaded: the 'pt-wikipedia' entry is disabled in DATASETS_CONFIG
__batch = iter(DATASETS['pt-fineweb-edu'][1])

In [None]:
# take the next batch and run the model on its inputs
# NOTE(review): the forward pass is in training mode, unlike the other inference cells -- confirm that it is intended
__x, __y, __m = next(__batch)
__p = LLAMINATE(inputs=__x, training=True, mask=None)

In [None]:
# interpret the targets and the predictions with the binary flag set
__yt, __yp = (mlable.text.interpret(__v, binary=True) for __v in (__y, __p))
# decode the inputs, the interpreted targets and the interpreted predictions
__it, __ot, __op = (mlable.text.decode(__v) for __v in (__x, __yt, __yp))

In [None]:
# first two entries of the decoded inputs, targets and predictions, in that order
for __text in (__it, __ot, __op):
    print(__text[:2])

## Logs

In [None]:
# the %tensorboard magic is only registered once the extension has been loaded
%load_ext tensorboard
%tensorboard --logdir .logs