In [209]:
from collections import Counter,defaultdict
import math
import pandas as pd
import copy

In [256]:
"""
@author: KhangTran2503
"""

class Shannon_FanoTree():
    """Shannon-Fano encoder/decoder for a sequence of hashable symbols.

    ``compress`` sorts the distinct symbols by descending frequency and
    recursively splits them into two groups whose total frequencies are as
    close as possible; the left group gets a ``'0'`` appended to its code and
    the right group a ``'1'``.

    After ``compress`` every row of ``self.dict`` has the layout
    ``[frequency, -log2(p), code, frequency * len(code)]``.
    """

    def __init__(self):
        self._reset()

    def _reset(self):
        """Clear all per-message state so the same instance can be reused."""
        self.data = []                 # symbols of the message being compressed
        self.frequency = []            # (symbol, count) pairs, most common first
        self.dict = defaultdict(list)  # symbol -> [freq, -log2(p), code, bits used]
        self.B1 = 0                    # size after coding, in bits
        self.B0 = 0                    # size before coding (8 bits per symbol)
        self.ideal_entropy = 0         # sum over symbols of p * -log2(p)
        self.encode = ''               # code words of the message, space-separated

    def add_node(self, name, code):
        """Store the final code word of symbol ``name`` and update the totals.

        Expects ``self.dict[name]`` to already hold the symbol's frequency
        (``compress`` puts it there) and ``self.data`` to be the full message.

        Args:
            name: the symbol being finalised.
            code: its code word, a string of ``'0'``/``'1'`` characters.
        """
        row = self.dict[name]
        f = row[0]
        pi = f / len(self.data)
        lencode = len(code)
        entropy = math.log(1 / pi, 2)
        row.append(entropy)
        row.append(code)
        row.append(lencode * f)
        self.B1 += f * lencode
        self.ideal_entropy += pi * entropy

    def recur(self, Range, code):
        """Assign code words to the symbols ``self.frequency[l:r]``.

        Going to the left half appends ``'0'`` to ``code``, going to the right
        half appends ``'1'``.

        Args:
            Range: ``(l, r)`` half-open index range into ``self.frequency``.
            code: the prefix shared by every symbol in the range.
        """
        l, r = Range
        if l == r:
            return
        if l + 1 == r:
            # A single symbol is left: the accumulated prefix is its code.
            self.add_node(self.frequency[l][0], code)
            return

        idx = self.break_point(Range)
        self.recur((l, idx), code + '0')
        self.recur((idx, r), code + '1')

    def compress(self, data):
        """Build the code table for ``data`` and encode it.

        Any state left over from a previous call is discarded first, so an
        instance can compress several messages one after another.

        Args:
            data: an iterable of hashable symbols (e.g. a ``str``).
        """
        self._reset()
        self.data = list(data)
        self.B0 = len(self.data) * 8

        # Distinct symbols, most frequent first.
        self.frequency = Counter(self.data).most_common()
        for key, val in self.frequency:
            self.dict[key].append(val)

        if len(self.frequency) == 1:
            # With one distinct symbol the recursion would hand out the empty
            # code word, which cannot survive the space-separated encoding
            # (the decoder would see no tokens).  Use one bit per symbol.
            self.add_node(self.frequency[0][0], '0')
        else:
            self.recur((0, len(self.frequency)), '')

        self.encode = ' '.join(self.dict[x][2] for x in self.data)

    def break_point(self, Range):
        """Return the split index that best balances the two halves.

        Finds ``idx`` with ``L < idx < R`` minimising
        ``|sum(freq[L:idx]) - sum(freq[idx:R])|``; ties go to the smallest
        ``idx``.  For ranges with fewer than two symbols ``L + 1`` is returned.

        Args:
            Range: ``(L, R)`` half-open index range into ``self.frequency``.
        """
        L, R = Range
        total = sum(f for _, f in self.frequency[L:R])

        mindiff = None
        part1 = 0
        breakp = L + 1
        # idx == R would leave the right half empty, so the last symbol is
        # never considered as the end of the left half.
        for i in range(L, R - 1):
            part1 += self.frequency[i][1]
            diff = abs(total - 2 * part1)
            if mindiff is None or diff < mindiff:
                mindiff = diff
                breakp = i + 1
        return breakp

    def share_encode(self):
        """Return ``(encoded_text, code_table)``, the input ``decoder`` expects."""
        return (self.encode, self.dict)

    def get_DataFrame(self):
        """Return the code table as a DataFrame indexed by symbol."""
        return pd.DataFrame.from_dict(self.dict, orient='index',
                                columns=['Freq', '-log2(pi)', 'Code', 'Bits used'])

    def get_information(self):
        """Print the code table and the compression statistics.

        Ratios whose denominator is zero (nothing has been compressed, or the
        message was empty) are reported as ``nan`` instead of raising
        ``ZeroDivisionError``.
        """
        nan = float('nan')
        ratio = self.B0 / self.B1 if self.B1 else nan
        bits_per_symbol = (8 * self.B1) / self.B0 if self.B0 else nan

        print(self.get_DataFrame())
        print('\nB0                  : {}'.format(self.B0))
        print('B1                  : {}'.format(self.B1))
        print('Compression Ratio   : {}'.format(ratio))
        print('Ideal Entropy       : {}'.format(self.ideal_entropy))
        print('Compression Entropy : {}'.format(bits_per_symbol))

    @staticmethod
    def decoder(encode):
        """Decode the pair produced by ``share_encode`` back into a string.

        Args:
            encode: ``(encoded_text, code_table)`` where ``encoded_text`` holds
                whitespace-separated code words and each ``code_table`` row
                stores the code word at index 2.

        Returns:
            The decoded symbols joined into one ``str`` (so the symbols must
            themselves be strings).
        """
        code, encode_dict = encode
        decode_dict = {val[2]: key for key, val in encode_dict.items()}
        return ''.join(decode_dict[word] for word in code.split())

# 1. Encode

In [263]:
# Sample message to compress.
Text = 'HELLO'
# Build the Shannon-Fano code table for the message and encode it.
encoder = Shannon_FanoTree()
encoder.compress(Text)
# Print the per-symbol table followed by the size/entropy statistics.
encoder.get_information()

   Freq  -log2(pi) Code  Bits used
L     2   1.321928    0          2
H     1   2.321928   10          2
E     1   2.321928  110          3
O     1   2.321928  111          3

B0                  : 40
B1                  : 10
Compression Ratio   : 4.0
Ideal Entropy       : 1.9219280948873623
Compression Entropy : 2.0


In [264]:
# (encoded text, code table) pair -- everything the decoder needs.
share_encode = encoder.share_encode()

In [270]:
# Element 0 of the shared pair is the space-separated sequence of code words.
print('Data is encoded:', share_encode[0], sep='\n')

Data is encoded:
10 110 0 0 111


# 2. Decode

In [265]:
# `decoder` is a static method, so it is called on the class directly;
# no encoder instance is required on the receiving side.
print('Data after decode: ')
print(Shannon_FanoTree.decoder(share_encode))

Data after decode: 
HELLO
