In [105]:
from typing import Dict, List

def get_probs(string : str) -> Dict[str, int]:
  return {x: string.count(x) / len(string) for x in set(string)}

probs = get_probs("TO BE OR NOT TO BE")
print(probs)

print(sum(probs.values()))

{'T': 0.16666666666666666, ' ': 0.2777777777777778, 'N': 0.05555555555555555, 'O': 0.2222222222222222, 'E': 0.1111111111111111, 'B': 0.1111111111111111, 'R': 0.05555555555555555}
1.0


In [107]:
from collections import namedtuple
Node = namedtuple('Node', ['symbol', 'prob', 'left_child', 'right_child'])

def build_tree(probs : Dict[str, int]) -> Node:
  # initialize nodes list
  nodes = []
  for k in sorted(probs, key=probs.get):
    nodes.append(Node(symbol=k, prob=probs[k], left_child=None, right_child=None))
  # merge trees until len(trees) <= 1
  while True:
    merged = _merge_nodes(nodes.pop(0), nodes.pop(0))
    nodes.append(merged)
    nodes.sort(key=lambda x: x.prob)
    if len(nodes) <= 1:
      break
  return nodes[0]

def _merge_nodes(node1, node2):
  return Node('', prob=node1.prob+node2.prob, left_child=node1, right_child=node2)

root = build_tree(probs)
print(root.prob)

1.0


In [108]:
def get_codes(root : Node) -> Dict[str, str]:
  codes = {}
  _get_codes_recc(node=root, curr_code='', codes=codes)
  return codes

def _get_codes_recc(node : Node, curr_code : str, codes: Dict[str, str]):
  if node.symbol == '':
    if node.left_child:
      code = curr_code + '0'
      _get_codes_recc(node.left_child, code, codes)
    if node.right_child:
      code = curr_code + '1'
      _get_codes_recc(node.right_child, code, codes)
  else:
    codes[node.symbol] = curr_code

codes = get_codes(root)
print(codes)


{'O': '00', 'E': '010', 'B': '011', ' ': '10', 'N': '1100', 'R': '1101', 'T': '111'}


In [109]:
def encode(string : str) -> str:
  probs = get_probs(string)
  root = build_tree(probs)
  codes = get_codes(root)
  return ''.join([codes[c] for c in string]), codes
encode("TO BE OR NOT TO BE")

('11100100110101000110110110000111101110010011010',
 {' ': '10',
  'B': '011',
  'E': '010',
  'N': '1100',
  'O': '00',
  'R': '1101',
  'T': '111'})

In [110]:
def decode(root:Node, code:str):
  decoded = ''
  curr_code = code
  while(len(curr_code) > 0):
    symbol, remain = _walk_tree(root, curr_code)
    decoded += symbol
    curr_code = remain
  return decoded


def _walk_tree(node : Node, curr_code : str):
  if node.symbol != '':
    return node.symbol, curr_code
  elif curr_code[0] == '0':
    return _walk_tree(node.left_child, curr_code[1:])
  else:
    return _walk_tree(node.right_child, curr_code[1:])

decode(root, '11100100110101000110110110000111101110010011010')

'TO BE OR NOT TO BE'

In [0]:
from copy import deepcopy

def expand_tree(root : Node, max_depth : int) -> Node:
  copy = deepcopy(root)
  return _expand_reccurent(copy, 0, max_depth, '')

def _expand_reccurent(node : Node, curr_depth : int, max_depth : int, curr_symbol : str):
  if curr_depth == max_depth:
    return deepcopy(node) # no expansion
  else:
    if node.symbol != '':
      node = Node(symbol=node.symbol, prob=node.prob, left_child=deepcopy(node), right_child=deepcopy(node)) # shadow nodes
    left_child = _expand_reccurent(node.left_child, curr_depth + 1, max_depth, curr_symbol)
    right_child = _expand_reccurent(node.right_child, curr_depth + 1, max_depth, curr_symbol)
    node = Node(symbol=node.symbol, prob=node.prob, left_child=left_child, right_child=right_child)
    return node

In [0]:
def get_sym_codes(extended_root : Node, max_depth : int) -> Dict[str, str]:
  sym_codes = {}
  _walk_extended(extended_root, 0, max_depth, '', sym_codes)
  return sym_codes

def _walk_extended(node : Node, curr_depth, max_depth, curr_code, sym_codes):
  if curr_depth == max_depth:
    sym_codes[curr_code] = node.symbol
  else:
    _walk_extended(node.left_child, curr_depth + 1, max_depth, curr_code + '0', sym_codes)
    _walk_extended(node.right_child, curr_depth + 1, max_depth, curr_code + '1', sym_codes)

In [0]:
def fast_decode(code : str, sym_codes : Dict[str, str], codes : Dict[str, str], max_depth : int):
  decoded = ''
  curr_code = code
  while len(code) > 0:
    curr_code = code[:max_depth].ljust(max_depth, '0')
    symbol = sym_codes.get(curr_code)
    symbol_length = len(codes[symbol])
    code = code[symbol_length:]
    decoded += symbol
  return decoded


In [125]:
max_depth = max([len(code) for code in codes.values()])
expanded = expand_tree(root, max_depth)
sym_codes = get_sym_codes(expanded, max_depth)
print(sym_codes)
code, codes = encode("TO BE OR NOT TO BE")
print(code)
fast_decode(code, sym_codes, codes, max_depth)

{'0000': 'O', '0001': 'O', '0010': 'O', '0011': 'O', '0100': 'E', '0101': 'E', '0110': 'B', '0111': 'B', '1000': ' ', '1001': ' ', '1010': ' ', '1011': ' ', '1100': 'N', '1101': 'R', '1110': 'T', '1111': 'T'}
11100100110101000110110110000111101110010011010


'TO BE OR NOT TO BE'