# Codecs As Generative Models 

In this notebook, we explore how to use codecs as generative models by sampling from their implicit distributions.

## 1. Introduction

A codec is an algorithm for compressing and decompressing data, often of a specific modality such as text or video. A codec consists of an encoder (the code), which describes data in a more concise form, and a decoder, which reconstructs the data from its encoding. We'll mainly be concerned with the former.

Let $L_C(x)$ denote the length of $x$ when compressed using some code $C(\cdot)$.
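For example, taking $C$ to be zlib (the codec used in the implementation below), $L_C(x)$ is simply the compressed size. zlib emits whole bytes, so this length comes in multiples of 8 bits. The helper name `zlib_bits` is hypothetical, used only for illustration.

```python
import zlib

def zlib_bits(seq):
    """L_C(x) in bits, taking C to be zlib at its default level."""
    return 8 * len(zlib.compress(bytes(seq)))

sample_text = b"abcabcabcabcabcabcabcabc"
encoded = zlib.compress(sample_text)                 # encoder: x -> C(x)
assert zlib.decompress(encoded) == sample_text       # decoder recovers x exactly
print(8 * len(sample_text), zlib_bits(sample_text))  # raw bits vs. L_C(x)
```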

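According to the Kraft-McMillan inequality, if $C$ is uniquely decodable, then $\sum_x 2^{-L_C(x)} \le 1$. Hence $p_C(x) = 2^{-L_C(x)}$ defines a probability distribution (after normalising, if the inequality is strict). In other words, for any uniquely decodable code we can always find a *statistical model* whose ideal code-lengths $-\log_2 p_C(x)$ match those of the code (up to an additive constant when normalisation is needed).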
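As a toy check of the inequality, the snippet below computes $\sum_x 2^{-L_C(x)}$ for two small prefix codes and the distribution $p_C$ each one implies, normalising when the sum falls short of 1. Both code tables are made up for this example and do not come from any real codec.

```python
from fractions import Fraction

def kraft_sum(code):
    """Sum of 2^{-len(codeword)} over all codewords (exact arithmetic)."""
    return sum(Fraction(1, 2 ** len(word)) for word in code.values())

def implied_distribution(code):
    """p_C(x) = 2^{-L_C(x)}, divided by the Kraft sum so it adds up to 1."""
    total = kraft_sum(code)
    return {sym: Fraction(1, 2 ** len(word)) / total for sym, word in code.items()}

complete_code = {"a": "0", "b": "10", "c": "110", "d": "111"}  # every bit path used
incomplete_code = {"a": "0", "b": "10", "c": "110"}            # "111" left unused

print(kraft_sum(complete_code), kraft_sum(incomplete_code))  # 1 7/8
print(implied_distribution(incomplete_code))                 # a: 4/7, b: 2/7, c: 1/7
```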

Since $p_C$ is a generative model, it should in principle be possible to sample from it. In this way, codecs can be used to *create* data rather than just compress it. For most codecs the generated data is unlikely to be very interesting, because they aren't tuned to a *specific* source such as English text. Nonetheless, the exercise is useful for exposing the implicit statistical assumptions built into human-designed compression schemes.
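The method below samples by scoring candidates with $L_C$, but for a *complete* prefix code (Kraft sum exactly 1) there is an even more direct sampler, shown here as a separate toy illustration: feed fair coin flips into the decoder. A codeword of length $L$ is reached with probability $2^{-L}$, which is exactly $p_C$. The code table is hypothetical. For a real codec such as zlib, random bytes are almost never a valid compressed stream, since the format includes header and checksum fields.

```python
import random
from collections import Counter

# Hypothetical complete prefix code over four symbols (Kraft sum exactly 1).
PREFIX_CODE = {"a": "0", "b": "10", "c": "110", "d": "111"}
DECODE_TABLE = {bits: symbol for symbol, bits in PREFIX_CODE.items()}

def sample_symbol(rng):
    """Decode fair coin flips; symbol x comes out with probability 2^{-L_C(x)}."""
    bits = ""
    while bits not in DECODE_TABLE:  # terminates because the code is complete
        bits += rng.choice("01")
    return DECODE_TABLE[bits]

rng = random.Random(0)
n_samples = 20_000
counts = Counter(sample_symbol(rng) for _ in range(n_samples))
print({s: counts[s] / n_samples for s in PREFIX_CODE})  # roughly 0.5, 0.25, 0.125, 0.125
```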

## 2. Method

Let's be a little more precise about what we're attempting to do. 

Given some code $C(\cdot)$, we wish to sample from $p_C$ where $p_C(x) = 2^{-L_C(x)}$. 

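1. At each time step $t$, we score every candidate next token $x_t$ by $p_C(x_1,\dots,x_t) = 2^{-L_C(x_1,\dots,x_t)}$. Since the prefix $x_1,\dots,x_{t-1}$ is shared by all candidates, we treat this score as proportional to the conditional probability of $x_t$. We then use top-k sampling to redistribute the probability mass amongst the $k$ *most likely* extended sequences.

2. We then sample randomly from the top-k distribution, append the sampled token, and return to step 1.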
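The two steps above can be sketched as follows. This is a minimal sketch assuming zlib as the code and single bytes as tokens; the helper names (`zlib_bits`, `sample_next_token`, `sample_with_codec`) are hypothetical. It is not the `CodecSampler` class in the implementation, which instead searches over multi-token continuations.

```python
import random
import zlib

def zlib_bits(seq):
    """L_C(x) in bits, taking C to be zlib at its default level."""
    return 8 * len(zlib.compress(bytes(seq)))

def sample_next_token(prefix, k=5, alphabet=range(256), rng=random):
    """One round of steps 1-2: score every token, keep the top k, sample."""
    # Step 1: code length of every one-token extension of the prefix.
    lengths = {tok: zlib_bits(list(prefix) + [tok]) for tok in alphabet}
    # Keep the k shortest encodings (ties broken by token value). zlib sizes
    # are whole bytes, so many candidates tie.
    top = sorted(lengths, key=lambda tok: (lengths[tok], tok))[:k]
    # Subtract the minimum before exponentiating so 2^{-L} cannot underflow;
    # the shift cancels because random.choices only uses relative weights.
    best = lengths[top[0]]
    weights = [2.0 ** -(lengths[tok] - best) for tok in top]
    # Step 2: sample from the renormalised top-k distribution.
    return rng.choices(top, weights=weights, k=1)[0]

def sample_with_codec(prompt, steps=10, k=5, seed=0):
    rng = random.Random(seed)
    seq = list(prompt)
    for _ in range(steps):
        seq.append(sample_next_token(seq, k=k, rng=rng))
    return seq

print(bytes(sample_with_codec(b"abcabcabc", steps=6)))
```

Subtracting the best length before exponentiating matters in practice: for a few hundred bytes of input, $2^{-L_C}$ underflows to zero in double precision.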

## 3. Implementation

In [None]:
import zlib
import math
import numpy as np
import torch

In [None]:
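#@markdown Codec sampler class...

class CodecSampler:
    """Treats a compressor as an implicit generative model p_C(x) = 2^{-L_C(x)}."""

    def __init__(self, code, alphabet):
        self.code = code                # callable: list of ints -> compressed bytes
        self.alphabet = list(alphabet)  # candidate tokens (ints in 0..255)

    def length(self, sequence):
        # Code length L_C(x) in bits.
        return len(self.code(sequence)) * 8

    def complete(self, prompt, length=10, width=2, height=3):
        # Tree search: from each partial sequence, expand only the `width`
        # tokens whose extension compresses shortest, until sequences reach
        # `length` tokens; then return the `height` shortest-coding completions.
        # The number of leaves grows as width ** (length - len(prompt)).
        prompt = list(prompt)
        if len(prompt) >= length:
            raise ValueError("length must exceed len(prompt)")

        queue = [prompt]
        result = []

        while queue:

            sequence = queue.pop()  # last in, first out: depth-first order
            lengths = [self.length(sequence + [token]) for token in self.alphabet]

            # Indices (into the alphabet) of the shortest encodings...
            shortest = torch.topk(torch.tensor(lengths), k=min(width, len(lengths)), largest=False)

            for index in shortest.indices:

                child = sequence + [self.alphabet[index.item()]]
                array = result if len(child) == length else queue
                array.append(child)

        lengths = [self.length(res) for res in result]
        shortest = torch.topk(torch.tensor(lengths), k=min(height, len(result)), largest=False)

        # Rows are sorted from shortest to longest encoding.
        return torch.tensor(result)[shortest.indices]

    # Alternative (disabled): sample one token at a time from the `width`
    # shortest candidates.
    # def complete(self, sequence, size=10, width=2, height=3):

    #     sequence = list(sequence)

    #     for i in range(size):

    #         lengths = [self.length(sequence + [token]) for token in self.alphabet]
    #         shortest = torch.topk(torch.tensor(lengths), k=width, largest=False)

    #         probabilities = [1 / (2 ** (8 * (length/len(sequence)))) for length in shortest.values]
    #         probabilities = np.array(probabilities)
    #         probabilities = probabilities / probabilities.sum()

    #         # TODO: apply temperature

    #         choice = np.random.choice(shortest.indices, p=probabilities)
    #         sequence += [choice]

    #     return sequence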
In [None]:
import gzip

In [None]:
a = CodecSampler(lambda x: zlib.compress(bytes(x)), alphabet=range(256))
b = CodecSampler(lambda x: gzip.compress(bytes(x)), alphabet=range(256))

Wow! This is really cool. `zlib` successfully continues the alternating pattern. Let's see what other patterns it tends to recognise.

In [None]:
a.complete([1,2,3,4,5,6,7,8,9,10, 1,2,3], length=20)

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  1,  2,  3,  4,  5,  6,  7,  8,
          9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  1,  2,  3,  4,  5,  6,  7,  8,
          9,  0],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10,  1,  2,  3,  4,  5,  6,  7,  8,
          0,  0]])

In contrast, counting doesn't seem to be a pattern that `zlib` recognises.
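A rough way to see why: DEFLATE, the algorithm inside zlib, saves space by replacing a repeated substring of at least three bytes with a back-reference, and it has no mechanism for arithmetic progressions. The sketch below (with a hypothetical helper `zlib_bits` measuring $L_C$ in bits) compares two 20-byte sequences. Under $p_C$, the one with the shorter encoding is $2^{\Delta L}$ times more probable, where $\Delta L$ is the difference in code length.

```python
import zlib

def zlib_bits(seq):
    """L_C(x) in bits, taking C to be zlib at its default level."""
    return 8 * len(zlib.compress(bytes(seq)))

repeating = list(range(1, 11)) * 2  # 1..10 twice: the second run is a back-reference
counting = list(range(1, 21))       # 1..20: no repeated substring for DEFLATE to reuse

print(zlib_bits(repeating), zlib_bits(counting))
```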

In [None]:
prompt = '''
nucleus sampling also called top p sampling is a more advanced
version of top sampling that results in a more consistent sampling 
performance it cuts between selectable and non selectable tokens 
based on the sum of their probabilities totaled bar heights when 
going from left to right on the linked picture until the specified 
cut value p is reached as opposed to top k sampling which cuts 
based on position index it can be used similarity to top k sampling 
you can combine low value of nucleus sampling with high value of 
randomness while the other sampling methods set to off to break out
of loops with coherence once you ve got your head around how these 
work you can try layering settings it is difficult to explain 
sampling methods in simple terms without getting sloppy but it is 
not a complicated concept try to visualize the process using the 
bars as crutches randomness controls bar heights while sampling 
controls where the dividing line will be between bars that can be 
selected and bars tokens that will be discarded from participating in a'''.replace('\n', ' ').replace('  ', ' ')

In [None]:
prompt_chunks = [prompt[i:i+2] for i in range(0, len(prompt), 2)]

In [None]:
vocabulary = list(set(prompt_chunks))

In [None]:
len(vocabulary)

184

In [None]:
prompt_tokenized = [vocabulary.index(word) for word in prompt_chunks]

In [None]:
len(bytes(prompt_tokenized))

527

In [None]:
len(zlib.compress(bytes(prompt_tokenized)))

511

In [None]:
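import random

def get_next_token(prompt, model, n=5, width=2, height=2):
    # Despite its name, this extends `prompt` by n * 3 tokens: each round asks
    # the sampler for its `height` best 3-token continuations and picks one
    # uniformly at random.
    prompt = list(prompt)

    for _ in range(n):
        completions = model.complete(prompt, width=width, height=height, length=len(prompt) + 3)
        completion = random.choice(completions)[-3:]
        # `completion` is a tensor, so the appended tokens are 0-d tensors.
        prompt += completion

    #out = ''.join([vocabulary[val] for val in prompt])
    #return out

    return prompt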
    

In [None]:
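def complete(x):
    # Best-scoring (shortest-coding) 3-byte continuation of an ASCII string,
    # using the zlib sampler `a`.
    completions = a.complete(list(bytes(x, 'ascii')), width=2, height=5, length=len(x) + 3)
    return bytes(completions[0])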

In [None]:
bytes(get_next_token(bytes('abcdabcd', 'ascii'), a, width=5, n=10))

b'abcdabcdabcdabcdabcdab\xa9abcda\xa9\xa9\xabbbbbbbb'

In [None]:
!pip install python-rle



In [None]:
import rle


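def rle_encoder(data):
    # Each symbol byte is followed by its run length in ASCII digits (only when
    # the run is longer than 1). This is unambiguous only while symbols avoid
    # the ASCII digit bytes 48-57, as with alphabet=range(10) below.
    result = b''

    for char, rl in zip(*rle.encode(data)):

        result += bytes([char])
        if rl > 1:
            result += str(rl).encode('ascii')

    return result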

In [None]:
c = CodecSampler(rle_encoder, alphabet=range(10))

In [None]:
get_next_token(bytes('\x01', 'ascii'), c, width=1, height=1, n=6)

[1,
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8),
 tensor(8)]

In [None]:
(get_next_token(bytes('\x01', 'ascii'), c, width=4,height=2, n=10))

[1,
 tensor(6),
 tensor(6),
 tensor(6),
 tensor(6),
 tensor(6),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(9),
 tensor(7),
 tensor(7),
 tensor(7),
 tensor(7),
 tensor(7),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5),
 tensor(5)]

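As expected, under this model high likelihood is given to strings with many repeats: a run of any length costs one symbol byte plus a few digits, whereas every change of symbol costs at least another byte. This is the implicit statistical assumption behind run-length encoding.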
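To put numbers on this, the sketch below reimplements the same output format as `rle_encoder` with `itertools.groupby` instead of python-rle, assuming that `rle.encode` groups consecutive equal values (which is what run-length encoding means). The names `rle_encode_groupby` and `rle_bits` are hypothetical.

```python
from itertools import groupby

def rle_encode_groupby(data):
    """Symbol byte, then its run length in ASCII digits if the run is longer than 1."""
    out = b""
    for symbol, run in groupby(data):
        count = sum(1 for _ in run)
        out += bytes([symbol])
        if count > 1:
            out += str(count).encode("ascii")
    return out

def rle_bits(data):
    return 8 * len(rle_encode_groupby(data))

run = [5] * 12              # one run: b"\x05" + b"12" -> 3 bytes
alternating = [5, 6] * 6    # twelve runs of length 1 -> 12 bytes
print(rle_bits(run), rle_bits(alternating))  # 24 96
```

Both sequences have 12 symbols, yet under $p_C$ the single run has probability $2^{-24}$ against $2^{-96}$ for the alternating one, a factor of $2^{72}$.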

In [None]:
!pip install lzw3



In [None]:
import lzw3

In [None]:
from lzw3 import *

In [None]:
lzw3
