<a href="https://colab.research.google.com/github/EugenHotaj/pytorch-generative/blob/master/notebooks/__draft__coding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
import math
from typing import Optional
from dataclasses import dataclass
import heapq

In [12]:
message = "A_DEAD_DAD_CEDED_A_BAD_BABE_A_BEADED_ABACA_BED"

# Occurrence count of every symbol; keys appear in first-seen order.
abc = {symbol: message.count(symbol) for symbol in dict.fromkeys(message)}

# Empirical probability of each symbol, keyed in the same order as `abc`.
total = sum(abc.values())
abc_probs = {symbol: count / total for symbol, count in abc.items()}

## Huffman Codes

In [5]:
@dataclass
class Node:
    """A node of the Huffman tree built by ``build_tree``.

    Leaves hold one symbol in ``data`` and have no children; internal nodes
    hold the concatenation of their children's ``data`` and the sum of their
    weights. Nodes compare by ``weight`` only, so they can sit in a ``heapq``.
    """

    data: str  # symbol(s) covered by this subtree
    weight: int  # total count of the symbols in this subtree
    left: Optional["Node"] = None  # subtree reached by appending "0"
    right: Optional["Node"] = None  # subtree reached by appending "1"

    def __lt__(self, other: "Node") -> bool:
        """Orders nodes by weight alone; ``data`` and children are ignored."""
        return other.weight > self.weight

In [6]:
def build_tree(abc):
    """Builds a Huffman tree from symbol counts.

    Keeps all subtrees in a min-heap ordered by weight and repeatedly merges
    the two lightest ones until a single tree remains.

    Args:
        abc: Mapping from symbol to its count.

    Returns:
        The root ``Node``. Internal nodes have both ``left`` and ``right`` set
        (the lighter popped subtree on the left); leaves have neither.

    Raises:
        ValueError: if ``abc`` is empty.
    """
    if not abc:
        raise ValueError("cannot build a Huffman tree from an empty alphabet")

    nodes = [Node(data=k, weight=v) for k, v in abc.items()]
    heapq.heapify(nodes)

    while len(nodes) > 1:
        n1 = heapq.heappop(nodes)
        n2 = heapq.heappop(nodes)
        merged = Node(
            data=n1.data + n2.data,
            weight=n1.weight + n2.weight,
            left=n1,
            right=n2,
        )
        # heappush (not list.append) restores the heap invariant, so the next
        # heappop really returns the lightest remaining subtree.
        heapq.heappush(nodes, merged)

    return nodes[0]

def build_codes(tree):
    """Reads the prefix-free code of every leaf off a Huffman tree.

    Stepping to a ``left`` child appends "0" to the code, stepping to a
    ``right`` child appends "1".

    Args:
        tree: Root node. A node whose ``left`` is None is treated as a leaf
            and its ``data`` is used as the symbol.

    Returns:
        Dict mapping each leaf symbol to its bit string. If the whole tree is
        a single leaf (one-symbol alphabet) that symbol gets the code "0"
        instead of the empty string, which could not be encoded or decoded.
    """
    codes = {}
    # Explicit stack instead of recursion. The right child is pushed first so
    # the left subtree is fully visited before the right one (pre-order).
    stack = [(tree, "")]
    while stack:
        node, prefix = stack.pop()
        if node.left is None:
            codes[node.data] = prefix or "0"
        else:
            stack.append((node.right, prefix + "1"))
            stack.append((node.left, prefix + "0"))
    return codes

def encode(message, codes):
    """Huffman-encodes ``message`` as a string of '0'/'1' characters.

    Args:
        message: Iterable of symbols, each a key of ``codes``.
        codes: Mapping from symbol to its bit string (e.g. from ``build_codes``).

    Returns:
        The concatenation of the codes of all symbols, in message order.

    Raises:
        KeyError: if ``message`` contains a symbol missing from ``codes``.
    """
    # str.join builds the result in one pass instead of quadratic `+=`.
    return "".join(codes[m] for m in message)

def decode(message, codes):
    """Decodes a Huffman bit string produced by ``encode``.

    Args:
        message: String of '0'/'1' characters.
        codes: The symbol -> bit-string mapping used for encoding. It must be
            prefix-free, so the first code word that matches is the right one.

    Returns:
        The decoded string.

    Raises:
        ValueError: if ``message`` ends in the middle of a code word.
    """
    inverse_codes = {v: k for k, v in codes.items()}
    decoded = []
    current = ""
    for bit in message:
        current += bit
        if current in inverse_codes:
            decoded.append(inverse_codes[current])
            current = ""
    if current:
        # Leftover bits mean the input was truncated or corrupted.
        raise ValueError(f"trailing bits {current!r} do not form a complete code")
    return "".join(decoded)

In [7]:
# Build the Huffman tree from the symbol counts in `abc`, then read each
# symbol's bit string off the tree.
tree = build_tree(abc)
codes = build_codes(tree)

In [8]:
encoded = encode(message, codes)
decoded = decode(encoded, codes)

# Fail loudly on Restart & Run All if the Huffman round trip breaks,
# and still display the check as the cell's output.
assert decoded == message, "Huffman round trip did not reproduce the message"
decoded == message

True


## Symmetric Numeral Systems

In [195]:
def encode(message, abc):
    """Encodes ``message`` as one integer, using symbols as base-|abc| digits.

    A symbol's digit is its position in ``abc``'s iteration order. The state
    starts at a sentinel 1 so that leading zero-digits are not lost, and the
    message is folded in reverse so ``decode`` emits it front to back.

    Args:
        message: Sequence of symbols, each a key of ``abc``.
        abc: The alphabet; only its keys and their order are used.

    Returns:
        A positive int; 1 encodes the empty message.

    Raises:
        ValueError: if ``message`` is non-empty and ``abc`` has fewer than two
            symbols (base 0 or 1 cannot carry any information).
        KeyError: if ``message`` contains a symbol not in ``abc``.
    """
    n_abc = len(abc)
    if message and n_abc < 2:
        raise ValueError("need an alphabet of at least two symbols")
    idxs = {k: i for i, k in enumerate(abc)}
    encoded = 1
    for m in reversed(message):
        encoded = encoded * n_abc + idxs[m]
    return encoded


def decode(encoded, abc):
    """Inverts the base-|abc| ``encode``, peeling digits least-significant first.

    Args:
        encoded: Integer produced by ``encode`` with the same ``abc``.
        abc: The alphabet; only its keys and their order are used.

    Returns:
        The decoded message string.

    Raises:
        ValueError: if ``encoded`` is not a value ``encode`` could produce
            (it is below 1, or dividing it out reaches 0 before 1), or if the
            alphabet has fewer than two symbols while there is something to
            decode. Without these checks the loop would never terminate.
    """
    n_abc = len(abc)
    if encoded != 1 and n_abc < 2:
        raise ValueError("need an alphabet of at least two symbols")
    symbols = list(abc)
    decoded = []
    while encoded != 1:
        if encoded < 1:
            raise ValueError("not a valid encoded value")
        encoded, digit = divmod(encoded, n_abc)
        decoded.append(symbols[digit])
    return "".join(decoded)


In [197]:
encoded = encode(message, abc)
decoded = decode(encoded, abc)

# Enforce the round trip (not just display it) so Run All catches regressions.
assert decoded == message, "SNS round trip did not reproduce the message"
decoded == message

True

## Asymmetric Numeral Systems

In [198]:
def scale_abc(abc, n=10):
    """Quantizes symbol counts into integer frequencies summing to exactly 2**n.

    Each count is scaled to ``count * 2**n / total`` and floored. The units
    lost to flooring go to the symbols with the largest remainders. The
    arithmetic is done on integers so it stays exact even for large ``n``.
    Every symbol with a positive count gets a frequency of at least 1, since
    a zero frequency makes that symbol impossible to rANS-encode (the encoder
    divides by it).

    Args:
        abc: Mapping from symbol to its non-negative count.
        n: Precision in bits; the returned frequencies sum to ``2**n``.

    Returns:
        A new dict with the same keys in the same order, mapping each symbol
        to its integer frequency.

    Raises:
        ValueError: if no count is positive, or if ``2**n`` is smaller than the
            number of symbols with a positive count.
    """
    total = sum(abc.values())
    if total <= 0:
        raise ValueError("abc must contain at least one positive count")
    target = 1 << n
    present = [k for k, v in abc.items() if v > 0]
    if len(present) > target:
        raise ValueError(
            f"2**{n} = {target} slots cannot hold {len(present)} symbols")

    rescaled_abc = {}
    remainders = []
    for k, v in abc.items():
        q, r = divmod(v * target, total)  # exact floor(v * 2**n / total)
        rescaled_abc[k] = int(q)
        remainders.append((r, k))

    # Largest-remainder rounding: hand out the units lost to flooring.
    left = target - sum(rescaled_abc.values())
    remainders.sort(key=lambda x: -x[0])
    for _, k in remainders[:left]:
        rescaled_abc[k] += 1

    # A present symbol that rounded down to 0 borrows one unit from the
    # current largest frequency (always >= 2 here, so it stays positive).
    for k in present:
        if rescaled_abc[k] == 0:
            donor = max(rescaled_abc, key=rescaled_abc.get)
            rescaled_abc[donor] -= 1
            rescaled_abc[k] = 1
    return rescaled_abc


def _rans_table(scaled_abc, n):
    """Validates a scaled alphabet and returns ``[(symbol, start, freq), ...]``.

    ``start`` is the cumulative frequency of all earlier symbols (dict order),
    so the ``[start, start + freq)`` slots tile ``[0, 2**n)`` without overlap.
    Raises ValueError if the frequencies do not sum to ``2**n``.
    """
    table = []
    start = 0
    for symbol, freq in scaled_abc.items():
        table.append((symbol, start, freq))
        start += freq
    if start != 1 << n:
        raise ValueError(
            f"scaled frequencies sum to {start}, expected 2**{n} = {1 << n}")
    return table


def encode(message, scaled_abc, n=10):
    """rANS-encodes ``message`` into a single integer state.

    Args:
        message: Sequence of symbols, each a key of ``scaled_abc``.
        scaled_abc: Symbol -> integer frequency, summing to ``2**n``
            (e.g. the output of ``scale_abc(abc, n)``).
        n: Precision in bits; must match the one used to build ``scaled_abc``.

    Returns:
        The final state; ``2**n`` encodes the empty message.

    Raises:
        ValueError: if the frequencies do not sum to ``2**n``, or a symbol in
            ``message`` has frequency 0 or ``2**n`` (neither can be coded).
        KeyError: if ``message`` contains a symbol not in ``scaled_abc``.
    """
    full = 1 << n
    # Cumulative starts are computed once instead of re-summed per symbol.
    slots = {s: (c_s, f_s) for s, c_s, f_s in _rans_table(scaled_abc, n)}

    # Start at 2**n rather than 1. Below a symbol's frequency, the symbol with
    # start 0 maps the state to itself, so it would be silently dropped. From
    # 2**n upward every step strictly increases the state, so `decode` can
    # stop exactly when it returns to 2**n.
    encoded = full
    for s in reversed(message):
        c_s, f_s = slots[s]
        if not 0 < f_s < full:
            raise ValueError(f"symbol {s!r} has unusable frequency {f_s}")
        encoded = ((encoded // f_s) << n) + encoded % f_s + c_s
    return encoded


def decode(encoded, abc, n=10):
    """Inverts ``encode``: recovers the message from an rANS state.

    Args:
        encoded: Integer produced by ``encode`` with the same table and ``n``.
        abc: The scaled alphabet (symbol -> frequency summing to ``2**n``)
            that was passed to ``encode``.
        n: Precision in bits used by ``encode``.

    Returns:
        The decoded message string.

    Raises:
        ValueError: if the frequencies do not sum to ``2**n``, or ``encoded``
            is not a state ``encode`` could have produced.
    """
    full = 1 << n
    mask = full - 1
    table = _rans_table(abc, n)

    decoded = []
    while encoded != full:
        # Every state `encode` emits stays above 2**n until the last symbol
        # is peeled off, so anything smaller is invalid (and would not end).
        if encoded < full:
            raise ValueError(f"invalid rANS state {encoded}")
        slot = encoded & mask
        # The symbol whose [start, start + freq) range contains the slot;
        # one always exists because the slots tile [0, 2**n).
        s, c_s, f_s = next(
            entry for entry in table if entry[1] <= slot < entry[1] + entry[2])
        if f_s == full:
            # A single full-range symbol maps every state to itself.
            raise ValueError(f"invalid rANS state {encoded}")
        decoded.append(s)
        encoded = f_s * (encoded >> n) + slot - c_s
    return "".join(decoded)

In [199]:
n = 32  # precision in bits: the scaled frequencies sum to 2**32
scaled_abc = scale_abc(abc, n)
encoded = encode(message, scaled_abc, n)
decoded = decode(encoded, scaled_abc, n)
# Enforce the round trip like the other sections, so Run All catches regressions.
assert decoded == message, "rANS round trip did not reproduce the message"