<a href="https://colab.research.google.com/github/EugenHotaj/pytorch-generative/blob/master/notebooks/__draft__coding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
from typing import Optional
from dataclasses import dataclass
import heapq

In [5]:
# Sample text encoded by both coding schemes in this notebook. Its alphabet is
# six distinct symbols: the letters A-E and "_".
message = "A_DEAD_DAD_CEDED_A_BAD_BABE_A_BEADED_ABACA_BED"

## Huffman Codes

In [19]:
@dataclass
class Node:
    """A node of a Huffman tree.

    As built by ``build_tree``:
    - A leaf stores one symbol in ``data`` and has no children.
    - An internal node stores the concatenation of its children's symbols
      and its summed ``weight``.

    Ordering looks at ``weight`` only, which is what ``heapq`` needs.
    """

    data: str
    weight: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    def __lt__(self, other: "Node") -> bool:
        # Equal weights are not "less than" each other; no tie-breaker is used.
        return other.weight > self.weight

In [20]:
from collections import Counter

# Occurrence count of every symbol in ``message``; used as the Huffman weights.
# Counter is a dict subclass, so keys keep first-occurrence order.
abc = Counter(message)

In [21]:
def build_tree(abc):
    """Build a Huffman tree from a symbol -> weight mapping.

    Keeps all nodes in a ``heapq`` min-heap. It repeatedly merges the two
    lightest nodes into a parent whose weight is their sum, until one root
    remains.

    Args:
        abc: Mapping from each symbol to its weight (e.g. occurrence count).

    Returns:
        The root ``Node``. For a one-symbol alphabet the root is a bare leaf.

    Raises:
        IndexError: if ``abc`` is empty, since no root can be formed.
    """
    nodes = [Node(data=k, weight=v) for k, v in abc.items()]
    heapq.heapify(nodes)

    while len(nodes) > 1:
        n1 = heapq.heappop(nodes)
        n2 = heapq.heappop(nodes)
        merged = Node(
            data=n1.data + n2.data,
            weight=n1.weight + n2.weight,
            left=n1,
            right=n2,
        )
        # heappush (not list.append) keeps the heap invariant that the
        # following heappop calls rely on to return the lightest nodes.
        heapq.heappush(nodes, merged)

    return nodes[0]

def build_codes(tree):
    """Derive the symbol -> bit-string code table from a Huffman tree.

    Walks the tree depth-first, adding "0" for each left edge and "1" for
    each right edge. The path accumulated on reaching a leaf is that leaf's
    code word.

    Args:
        tree: Root node. Leaves have ``left is None``; internal nodes have
            both ``left`` and ``right`` set.

    Returns:
        Dict mapping each leaf's ``data`` to its code word.
    """

    def _dfs(node, codes, prefix):
        if node.left is None:
            codes[node.data] = prefix
        else:
            _dfs(node.left, codes, prefix + "0")
            _dfs(node.right, codes, prefix + "1")

    codes = {}
    if tree.left is None:
        # With a one-symbol alphabet the root itself is a leaf and its path is
        # empty. An empty code word would encode every message to "" and lose
        # the message length, so give the lone symbol a one-bit code instead.
        codes[tree.data] = "0"
        return codes
    _dfs(tree, codes, "")
    return codes

def encode(message, codes):
    """Encode ``message`` by concatenating each symbol's code word.

    Args:
        message: Text to encode.
        codes: Mapping from symbol to its bit-string code word.

    Returns:
        The concatenated bit string.

    Raises:
        KeyError: if ``message`` contains a symbol missing from ``codes``.
    """
    # str.join builds the result in linear time; repeated ``+=`` is not
    # guaranteed to be linear.
    return "".join(codes[symbol] for symbol in message)

def decode(message, codes):
    """Decode a bit string produced by ``encode`` back into text.

    Reads bits left to right and emits a symbol as soon as the accumulated
    bits match a code word. This is unambiguous only if ``codes`` is
    prefix-free, as the code tables from ``build_codes`` are.

    Args:
        message: Bit string of "0"/"1" characters.
        codes: Mapping from symbol to its code word; must be the same table
            used for encoding.

    Returns:
        The decoded string.

    Raises:
        ValueError: if bits remain after the last complete code word, i.e.
            ``message`` is truncated or was not encoded with ``codes``.
    """
    inverse_codes = {v: k for k, v in codes.items()}
    symbols = []
    current = ""
    for bit in message:
        current += bit
        if current in inverse_codes:
            symbols.append(inverse_codes[current])
            current = ""
    if current:
        # Previously these bits were silently dropped.
        raise ValueError(
            f"trailing bits {current!r} do not form a complete code word"
        )
    return "".join(symbols)

In [22]:
# Build the Huffman tree from the symbol counts in ``abc``, then derive the
# symbol -> bit-string code table from that tree.
tree = build_tree(abc)
codes = build_codes(tree)

In [23]:
# Round-trip check: encoding and then decoding with the same code table should
# reproduce the original message exactly.
encoded = encode(message, codes)
decoded = decode(encoded, codes)

print(decoded == message)

True


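## Asymmetric Numeral Systems

Here every symbol is treated as equally likely. With uniform probabilities, the ANS state update reduces to writing the message as a base-|alphabet| integer, one digit per symbol.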

In [17]:
# ANS alphabet: symbol -> digit. Only the key order is used by encode/decode.
# Sorting makes the assignment reproducible. Iteration order of a bare ``set``
# of strings depends on the per-process hash seed, so an integer encoded in
# one session could decode to garbage in another.
abc = {k: i for i, k in enumerate(sorted(set(message)))}

In [48]:
def encode(message, abc):
    """Encode ``message`` into a single integer state (uniform-probability ANS).

    With every symbol equally likely, each step appends one base-``len(abc)``
    digit to the state. How it works:
    - The state starts at 1, a sentinel, so leading symbols whose digit is 0
      are not lost.
    - The message is consumed in reverse, because ``decode`` recovers symbols
      from the least-significant digit first.

    Args:
        message: Text to encode; every character must be a key of ``abc``.
        abc: Alphabet. Only its key order matters: a key's position is its
            digit value.

    Returns:
        The encoded state, an int >= 1.

    Raises:
        KeyError: if ``message`` contains a symbol not in ``abc``.
    """
    # A one-symbol alphabet would give base 1: the state would never grow and
    # the message would be silently lost. Base 2 keeps it recoverable.
    base = max(len(abc), 2)
    digits = {k: i for i, k in enumerate(abc)}
    state = 1
    for symbol in reversed(message):
        state = state * base + digits[symbol]
    return state


def decode(message, abc):
    """Invert ``encode``: recover the text from its integer state.

    Args:
        message: Encoded state as returned by ``encode`` with the same ``abc``.
        abc: The same alphabet that was passed to ``encode``.

    Returns:
        The decoded string.

    Raises:
        ValueError: if ``message`` is not a state ``encode`` can produce.
            Such inputs (< 1, or reaching 0 mid-decode) previously looped
            forever.
    """
    if message < 1:
        raise ValueError(f"invalid ANS state {message!r}; expected an int >= 1")
    base = max(len(abc), 2)
    symbols = dict(enumerate(abc))
    decoded = []
    # Peel off least-significant digits until only the sentinel 1 remains.
    while message != 1:
        message, digit = divmod(message, base)
        if message == 0:
            raise ValueError("invalid ANS state: missing leading sentinel digit")
        decoded.append(symbols[digit])
    return "".join(decoded)
        



In [49]:
# ANS round trip: the whole message is encoded into one Python int, which is
# then decoded with the same alphabet.
encoded = encode(message, abc)
decoded = decode(encoded, abc)

In [50]:
# Display the decoded text; it should equal ``message``.
decoded

'A_DEAD_DAD_CEDED_A_BAD_BABE_A_BEADED_ABACA_BED'