# In [None]:
import math
import time
from collections import Counter
from fractions import Fraction

# Load the source text and count the occurrences of every character.
with open('chinese_text.txt', 'r', encoding='utf-8') as fh:
    content = fh.read()
counter = Counter(content)
total = len(content)  # total number of characters (equals the sum of all counts)

# Distinct characters in first-seen order, and their relative frequencies.
symbols = [*counter]
probs = [count / total for count in counter.values()]

def shannon_code(symbols, probs):
    """Build a Shannon code for the given symbols.

    Symbols are ranked by decreasing probability (ties keep their input
    order). Each symbol with probability ``p`` is assigned the first
    ``l = ceil(-log2 p)`` bits of the binary expansion of the cumulative
    probability of all symbols ranked before it. Arithmetic is done exactly
    with ``Fraction``, so the resulting code is prefix-free whenever the
    probabilities sum to at most 1.

    Args:
        symbols: Sequence of hashable symbols (e.g. characters).
        probs: Probabilities aligned with ``symbols``; each must be > 0.
            Floats are used at their exact binary value; ``Fraction``
            inputs are also accepted.

    Returns:
        dict mapping each symbol to its codeword, a string of '0'/'1'.

    Raises:
        ValueError: if any probability is not strictly positive.
    """
    codes = {}
    cumulative = Fraction(0)

    # sorted() is stable, so equal probabilities keep their input order.
    ranked = sorted(zip(symbols, probs), key=lambda item: -item[1])

    for s, p in ranked:
        if p <= 0:
            raise ValueError(f"probability of {s!r} must be positive, got {p!r}")
        # Exact rational arithmetic: float round-off in the running sum or in
        # log2 can push a value across a binary boundary and make one
        # codeword a prefix of another.
        p_exact = Fraction(p)
        # Smallest l with 2**l >= 1/p, i.e. ceil(-log2 p), computed exactly.
        ratio_ceil = -(-p_exact.denominator // p_exact.numerator)
        length = (ratio_ceil - 1).bit_length()
        # A lone symbol (p == 1) would otherwise get an empty codeword, which
        # encodes the whole text as '' and cannot be decoded; use one bit.
        length = max(length, 1)
        # First `length` bits of the cumulative probability:
        # floor(cumulative * 2**length), zero-padded to `length` digits.
        value = (cumulative.numerator << length) // cumulative.denominator
        codes[s] = format(value, f"0{length}b")
        cumulative += p_exact
    return codes

codes = shannon_code(symbols, probs)

# Print every symbol's codeword, rendering newline and space visibly.
printable = {'\n': '\\n', ' ': "' '"}
for sym in symbols:
    display_s = printable.get(sym, sym)
    print(f"symbol: {display_s} code: {codes[sym]}")
    
# Source entropy H = sum(-p * log2 p), in bits per symbol. Summing the
# negated terms (instead of negating the sum) gives 0.0 rather than -0.0
# for a single-symbol source, so it does not print as "-0.0000".
entropy = sum(-p * math.log2(p) for p in probs)

# Average codeword length L = sum(p * len(code)), in bits per symbol.
avg_code_len = sum(p * len(codes[s]) for s, p in zip(symbols, probs))

# Coding efficiency H / L; reported as 0 when nothing was encoded.
efficiency = entropy / avg_code_len if avg_code_len > 0 else 0

print(f"\ninfo source entropy: {entropy:.4f}")
print(f"avg. codeword length: {avg_code_len:.4f}")
print(f"encoding efficiency: {efficiency:.4f}")

# Encode the entire text by concatenating each character's codeword.
encoded_text = ''.join(map(codes.__getitem__, content))

# Inverse mapping (codeword -> symbol) used by the decoder.
decode_map = dict(zip(codes.values(), codes.keys()))

def shannon_decode(encoded_str, decode_map):
    """Decode a bit string produced with a prefix-free code.

    Scans ``encoded_str`` left to right. At each position it tries codeword
    lengths 1, 2, ... up to the longest codeword in ``decode_map`` and
    takes the first (shortest) match.

    Args:
        encoded_str: Concatenated codewords, as a string of '0'/'1'.
        decode_map: Mapping from codeword string to symbol (a str).

    Returns:
        The decoded symbols joined into one string. If no codeword matches
        at some position (corrupt input), or only a partial codeword is left
        at the end (truncated input), decoding stops and returns what was
        decoded so far. An empty ``encoded_str`` or ``decode_map`` yields ''.
    """
    if not encoded_str or not decode_map:
        return ''
    max_code_len = max(len(code) for code in decode_map)
    end = len(encoded_str)
    decoded = []
    i = 0
    while i < end:
        # Cap the candidate length at the bits remaining so that a short
        # tail falls through to the `else` below instead of retrying forever.
        for length in range(1, min(max_code_len, end - i) + 1):
            chunk = encoded_str[i:i + length]
            if chunk in decode_map:
                decoded.append(decode_map[chunk])
                i += length
                break
        else:
            # No codeword starts at position i: stop decoding.
            break
    return ''.join(decoded)

# Decode and time it. perf_counter() is a monotonic, high-resolution clock
# meant for measuring intervals; time.time() follows the system clock,
# which can be adjusted.
start_time = time.perf_counter()
decoded_text = shannon_decode(encoded_text, decode_map)
end_time = time.perf_counter()

with open('decoded_text.txt', 'w', encoding='utf-8') as f:
    f.write(decoded_text)

print(f"\nwhether decoding is correct: {decoded_text == content}")
print(f"decoding time : {end_time - start_time:.6f} seconds")
print("\ndecoded text: ")
print(decoded_text)