In [None]:
from sys import stdout

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

from lz78 import Sequence, LZ78SPA
# from lz_python.lz import LZModel

In [None]:
# Automatically reload imported modules before each cell runs, so edits to
# local source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Make no GPU visible to TensorFlow (empty device list for the 'GPU' type).
# NOTE(review): presumably TensorFlow is only needed for data loading here and
# should stay off the GPU -- confirm.
tf.config.set_visible_devices([], 'GPU')

## Data Loading

In [None]:
class PG19DataLoader:
    """Iterate over PG-19 books as lists of byte values.

    Loads the ``pg19`` dataset through ``tensorflow_datasets`` and yields the
    ``book_text`` field of every book as a ``list`` of ints in ``0..255``.

    Args:
        data_type: Split name handed to ``tfds.load`` (this notebook uses
            ``"train"`` and ``"validation"``).
        start_index: Number of leading books to skip.
        batch_size: Number of books fetched per underlying ``tf.data`` batch.
            Batching only affects how books are fetched; iteration always
            yields one book at a time.
        normalize: Currently unused; kept so existing callers keep working.
    """

    def __init__(self, data_type: str, start_index: int = 0, batch_size: int = 1, normalize: str = 'none'):
        self.data = tfds.load('pg19', split=data_type, shuffle_files=False)
        books = self.data.skip(start_index)
        # Count books before batching so len() matches the number of items
        # produced by iteration for every batch size.
        self.num_books = len(books)
        self.dataset = (books
                        .batch(batch_size)
                        .prefetch(tf.data.experimental.AUTOTUNE))
        print(data_type, ": ", len(self))

    def __len__(self):
        """Return the number of books this loader yields."""
        return self.num_books

    def __iter__(self):
        """Yield each book as a list of byte values, one book at a time."""
        for batch in self.dataset:
            # `book_text` holds one bytes object per book in the batch. Walk
            # all of them: reading only index 0 would silently drop every
            # other book whenever batch_size > 1.
            for book_text in batch['book_text'].numpy():
                yield np.frombuffer(book_text, dtype=np.uint8).tolist()

## Set Up Models

In [None]:
class ConfigObject:
    """Expose the entries of a settings dictionary as instance attributes."""

    def __init__(self, config_dict):
        # Every key of `config_dict` becomes an attribute of this instance.
        vars(self).update(config_dict)


# Settings handed to the Python LZ model when it is constructed.
config = ConfigObject(dict(
    top_k=256,
    method="Depth-Guided",  # ensemble
    ensemble_max_num=6,
    min_depth=10,
    vocab_size=256,
    adaptive_gamma="none",
    gamma=1/256,
    lower_bound=1e-5,
    temp=1,
    ensemble_type="depth",
    lb_or_temp="lb_first",
))

In [None]:
# Build the Python-side LZ model from the settings in `config`.
# NOTE(review): `LZModel` is never imported in this notebook -- the
# `from lz_python.lz import LZModel` line in the imports cell is commented
# out -- so this cell raises NameError on a fresh kernel until that import is
# restored.
py_lz = LZModel(config)

In [None]:
# LZ78 model from the `lz78` package over a 256-symbol (byte) alphabet, using
# the same gamma (1/256) as the Python model's config.
# NOTE(review): the variable name suggests this is the Rust-backed
# implementation -- confirm.
rust_lz = LZ78SPA(alphabet_size=256, gamma=1/256, compute_training_loss=False)

## Train Models

In [None]:
N_TRAIN = 100  # number of training books fed to each model

stdout.flush()
train_dataloader = PG19DataLoader("train")
rust_lz.reset_state()
train_progress = tqdm(train_dataloader, desc="Building LZ tree")
for trn_iter, batch in enumerate(train_progress, start=1):
    # Single pass over the data: each book is trained on exactly once.
    stdout.flush()
    book = Sequence(batch, alphabet_size=256)
    rust_lz.train_on_block(book)
    # State is reset after every book, before the next one is trained on.
    rust_lz.reset_state()

    # Stop as soon as N_TRAIN books have been consumed.
    if trn_iter >= N_TRAIN:
        break

In [None]:
train_dataloader = PG19DataLoader("train")
train_progress = tqdm(train_dataloader, desc="Building LZ tree")
for trn_iter, batch in enumerate(train_progress, start=1):
    # Single pass over the data: every book is fed to the tree builder once.
    py_lz.build_tree(batch)

    # Stop as soon as N_TRAIN books have been consumed.
    if trn_iter >= N_TRAIN:
        break

## Evaluate Models

In [None]:
# Inference settings for `rust_lz`. The lower bound (1e-5), temperature (1),
# "lb_first" ordering and depth ensemble of 6 use the same values as the
# Python model's config.
# NOTE(review): adaptive_gamma is spelled "disabled" here but "none" in the
# Python config -- presumably equivalent; confirm.
# NOTE(review): backshift_ctx_len=10 has the same value as the Python config's
# min_depth=10 -- presumably the corresponding setting; confirm.
rust_lz.set_inference_config(
    lb=1e-5,
    temp=1,
    lb_or_temp_first="lb_first",
    ensemble_type="depth",
    ensemble_n=6,
    adaptive_gamma="disabled",
    backshift_parsing=True,
    backshift_ctx_len=10,
    backshift_break_at_phrase=True
)

# Replace the Python model's settings before evaluation. The values repeat
# the ones used to construct the model.
py_lz.config = ConfigObject(dict(
    top_k=256,
    method="Depth-Guided",  # ensemble
    ensemble_max_num=6,
    min_depth=10,
    vocab_size=256,
    adaptive_gamma="none",
    gamma=1/256,
    lower_bound=1e-5,
    temp=1,
    ensemble_type="depth",
    lb_or_temp="lb_first",
))

In [None]:
# Use the first book of the validation split (a list of byte values) as the
# text for the spot checks below.
val_dataloader = PG19DataLoader("validation")
test_seq = next(iter(val_dataloader))

In [None]:
# Slice the book into 1024-byte windows that start every 512 bytes (so
# consecutive windows overlap by half), then keep only the first 10.
test_seqs = [
    test_seq[start:start + 1024]
    for start in range(0, len(test_seq) - 1023, 512)
][:10]

In [None]:
# Score every window in full, passing no separate context sequences.
# NOTE: `res` is reassigned by the next cell, so this result is only
# available until that cell runs.
res = rust_lz.compute_test_loss_parallel(
    [Sequence(seq, alphabet_size=256) for seq in test_seqs], output_patch_info=False
)

In [None]:
stdout.flush()

# Split each window in half: the first 512 bytes are passed as context and
# the remaining bytes are the input that gets scored.
ctxs = [Sequence(window[:512], alphabet_size=256) for window in test_seqs]
inputs = [Sequence(window[512:], alphabet_size=256) for window in test_seqs]

res = rust_lz.compute_test_loss_parallel(
    inputs,
    ctxs,
    num_threads=32,
    output_prob_dists=False,
    output_per_symbol_losses=False,
)

In [None]:
# Average log loss of each window from `rust_lz`, printed as one array for
# comparison with the Python model's losses.
print(np.array([x["avg_log_loss"] for x in res]))

In [None]:
# Mean of the per-position values returned by the Python model for each
# window, printed as one array next to the `rust_lz` losses.
py_lz_losses = []
for seq in test_seqs:
    # Only the second returned value is used; the other two are ignored.
    _depths, per_symbol_values, _ = py_lz.get_depth_and_perplexity(seq)
    window_mean = np.mean(per_symbol_values)
    py_lz_losses.append(float(window_mean))
print(np.array(py_lz_losses))

## Time Full Validation

In [None]:
# Accumulate the summed per-window average log loss and the window count over
# the whole validation split. The next cell divides one by the other.
log_loss = 0
n_seqs = 0
val_dataloader = PG19DataLoader("validation")
for book in tqdm(val_dataloader):
    stdout.flush()

    # 1024-byte windows starting every 512 bytes. A dedicated name is used so
    # the notebook-level `test_seqs` sample from the earlier cells is not
    # overwritten by this loop.
    val_windows = [book[i:i+1024] for i in range(0, len(book)-1023, 512)]
    if not val_windows:
        # Book shorter than one window: nothing to score, so do not call the
        # model with empty input lists.
        continue

    rust_lz.reset_state()

    # The first 64 bytes of each window are context; the rest is scored.
    inputs = [Sequence(window[64:], alphabet_size=256) for window in val_windows]
    ctxs = [Sequence(window[:64], alphabet_size=256) for window in val_windows]

    res = rust_lz.compute_test_loss_parallel(
        inputs, ctxs, num_threads=32, output_prob_dists=False, output_per_symbol_losses=False
    )

    log_loss += np.sum(np.array([x["avg_log_loss"] for x in res]))
    n_seqs += len(val_windows)

In [None]:
# Mean over all validation windows of the per-window average log loss
# accumulated by the previous cell. Requires n_seqs > 0, i.e. at least one
# window was scored.
print(f"Val loss: {float(log_loss / n_seqs)}")

## Return Patch Information

In [None]:
# Take only the first 40 bytes of the first validation book so the patch
# output below stays short.
# NOTE: this reassigns `test_seq`, replacing the full book loaded earlier.
val_dataloader = PG19DataLoader("validation")
test_seq = next(iter(val_dataloader))[:40]

In [None]:
# Score the short sample with per-symbol losses and patch information
# requested. The following cells read `res['patch_info']` and
# `res['log_losses']` from this result.
res = rust_lz.compute_test_loss( # also works for the parallel version!
    Sequence(test_seq, alphabet_size=256), output_prob_dists=False, output_per_symbol_losses=True, output_patch_info=True
)

In [None]:
# Print each patch as an inclusive index range: element 0 of each entry is
# printed as the start and element 1 minus one as the end.
# NOTE(review): this assumes element 1 is an exclusive end index -- confirm.
# Author's note: the output looks reasonable but has not been debugged in
# depth. TODO: verify further.
for info in res['patch_info']:
    print(f"{info[0]} through {info[1] - 1}")

In [None]:
# Stem plot of the per-symbol log losses returned for the short sample.
# Uses the explicit figure/axes interface and labels both axes so the figure
# can be read on its own.
fig, ax = plt.subplots()
ax.stem(np.array(res['log_losses']))
ax.set_title("Log Loss per Symbol")
ax.set_xlabel("Symbol index")
ax.set_ylabel("Log loss")
plt.show()