In [None]:
import numpy as np

def scale_interval_edges(orig_interval_edges, min_val, max_val):
    new_edges = min_val + (max_val - min_val)*orig_interval_edges
    return new_edges

def symbol_to_index(symbol, alphabeth):
    assert len(set(alphabeth)) == len(alphabeth), 'Redundant alphabeth'
    assert symbol in alphabeth, 'Symbol {} not in alphabeth'.format(symbol)
    return alphabeth.index(symbol) + 1 #

def get_interval_from_symbol(current_symbol, alphabeth, current_min, current_max, orig_interval_edges):
    curr_signal_ind = symbol_to_index(current_symbol, alphabeth)
    curr_interval_edges = scale_interval_edges(orig_interval_edges, current_min, current_max)
    new_min = curr_interval_edges[curr_signal_ind - 1]
    new_max = curr_interval_edges[curr_signal_ind]
    
    return (new_min, new_max)

def interval_edges(pmf):
    return np.array([np.sum(pmf[:i]) for i in range(len(pmf) + 1)])

def arithmetic_intervals(alphabeth, signal, pmf):
    orig_interval_edges = interval_edges(pmf)
    signal_list = list(signal)
    curr_min, curr_max = get_interval_from_symbol(signal_list[0], alphabeth, 0.0, 1.0, orig_interval_edges)
    intervals = [(curr_min, curr_max)]
    
    for i, symbol in enumerate(signal_list[1:]):
        curr_min, curr_max = get_interval_from_symbol(symbol, alphabeth, curr_min, curr_max, orig_interval_edges)
        intervals.append((curr_min, curr_max))
    return intervals

def shortest_binary(d_interval):
    d_min, d_max = d_interval
    assert d_min < d_max, 'Need strictly increasing interval'
    assert d_min >= 0, 'Negative lower bound on interval'
    assert d_max < 1, 'Upper interval bound greater or equal to 1'
    
   
    c_min = 0.0
    c_max = 1.0
    
    
    k = 1
    bin_seq = []
    while True:
        # Allways add 1 to bin_seq if possible
        if c_min < d_min and c_min + 1/2**k < d_max:
            c_min = c_min + 1/2**k
            bin_seq.append(1)
        else:
            if c_max > d_max and c_max - 1/2**k > d_min:
                c_max = c_max - 1/2**k
                bin_seq.append(0)
            else:
                # No change is made to the current interval, we are finished
                break
        k = k + 1
    return bin_seq            

def arithmetic_encoding(alphabeth, pmf, signal):
    intervals = arithmetic_intervals(alphabeth, signal, pmf)
    bin_seq = shortest_binary(intervals[-1])
    return bin_seq

alphabeth = ['a', 'b', 'c']
pmf = np.array([0.6, 0.2, 0.2]) 
signal = 'abccba'

encoded_signal = arithmetic_encoding(alphabeth, pmf, signal)
print('Signal:         {}'.format(signal))
print('Encoded signal: {}'.format(''.join([str(s) for s in encoded_signal])))


Signal:         abccba
Encoded signal: 0111101010000
