In [60]:
from collections import deque
from bitarray import bitarray
from unidecode import unidecode

In [82]:
# worcs only with ASCII characters
class StaticHuffman:
    
    class _Node:
        def __init__(self, weight, left=None, right=None, parent=None, letter=None):
            self.left = left
            self.right = right
            self.letter = letter
            self.parent = parent
            self.weight = weight
        
        def __str__(self):
            return str(self.encoding)
        
        def __repr__(self):
            return str(self.encoding)
        
    def __init__(self, text=None, file_name=None):
        self.root = None
        self.leafs = dict()
        self.encodings = dict()
        
        if text:
            self._build(unidecode(text))
            if file_name:
                self.encode_to_file(text, file_name)
            
    def _cache_leaf_representations(self):
        for (l, n) in self.leafs.items():
            e = []
            node = n
            while node.parent is not None:
                if node is node.parent.left:
                    e.append(False)
                else:
                    e.append(True)
                node = node.parent
            e = bitarray(reversed(e))
            self.encodings[l] = e
        
    def _build(self, text):
        alphabet_statistics = dict()
        
        for a in text:
            if a not in alphabet_statistics:
                alphabet_statistics[a] = 1
            else:
                alphabet_statistics[a] += 1
        leafs = [self._Node(w, letter=l) for (l, w) in alphabet_statistics.items()]
        self.leafs = {l.letter:l for l in leafs}
        leafs.sort(key=lambda x:x.weight)
        left = deque(leafs)
        right = deque()
        
        def get_min(left, right):
            if len(left) == 0:
                return right.popleft()
            if len(right) == 0:
                return left.popleft()
            
            if left[0].weight <= right[0].weight:
                return left.popleft()
            else:
                return right.popleft()
            
        # Build Huffman tree
        while(len(left) + len(right) > 1):
            n1 = get_min(left, right)
            n2 = get_min(left, right)
            top = self._Node(n1.weight + n2.weight, left=n1, right=n2)
            n1.parent = top
            n2.parent = top
            right.append(top)
        # Create encodings for each leaf
        self._cache_leaf_representations()
        self.root = right.popleft()
    
    # depricated
    def _encode(self, c):
        return self.leafs[c]
    
    def _encode_tree(self):
        #raise NotImplementedError()
        s = [self.root]
        encoded_tree = bitarray()
        
        while s:
            node = s.pop()
            if node.letter is None:
                encoded_tree.append(False)
                s.append(node.right)
                s.append(node.left)
            else:
                encoded_tree.append(True)
                encoded_tree.frombytes(node.letter.encode())
        print(encoded_tree)
        return encoded_tree
    
    def _decode_tree(self, encoded_data):
        if self.root is not None:
            raise RuntimeError('Tree should be empty')
        
        offset = 0
        
        # skip padding
        while encoded_data[offset]:
            offset += 1
        
        self.root = self._Node(0)
        s = [self.root]
        #offset+=1
        while s:
            node = s.pop()
            if encoded_data[offset]:
                node.letter = encoded_data[offset+1:offset+9].tobytes().decode()
                self.leafs[node.letter] = node
                offset += 9
            else:
                left = self._Node(0, parent=node)
                right = self._Node(0, parent=node)
                
                node.left = left
                node.right = right
                
                s.append(right)
                s.append(left)
                offset += 1
                
        self._cache_leaf_representations()
        
        return encoded_data[offset:]
                
    def encode_to_file(self, text, file_name):
        if file_name is None:
            raise RuntimeError('Specify file_name to be the name of file to write to')
        
        with open(file_name, 'wb') as f:
            tree_encoded = self._encode_tree()
            encoded_text = bitarray()
            encoded_text.encode(self.encodings, text)
            
            padding = 8 - (tree_encoded.length() + encoded_text.length()) % 8
            if padding != 8:
                tree_encoded = bitarray([True]*(padding)) + tree_encoded
            # danger
            (tree_encoded + encoded_text).tofile(f)
    
    @staticmethod
    def decode(file):
        with open(file, 'rb') as f:
            encoded_data = bitarray()
            encoded_data.fromfile(f)
            
            hc = StaticHuffman()
            encoded_data = hc._decode_tree(encoded_data)
            result = encoded_data.decode(h.encodings)
            return ''.join(result)

In [83]:
h = StaticHuffman("aabc", file_name='test_file')

bitarray('01011000010101100010101100011')


In [84]:
StaticHuffman.decode('test_file')

'aabc'