In [1]:
import torch 
import numpy as np
import os
import sys
sys.path.append("..")
from src.model import Encoder, Decoder
from src.dataset import ResidualDataset
from src.HuffmanCompression import HuffmanCoding
import pdb
from copy import copy
%load_ext autoreload
%autoreload 2

In [2]:
checkpoint_enc = '../checkpoint/test/encoder-epoch_49_iter_40550.pth'
checkpoint_dec = '../checkpoint/test/decoder-epoch_49_iter_40550.pth'
enc = Encoder()
dec = Decoder()
# load both models
enc.load_state_dict(torch.load(checkpoint_enc, map_location='cpu'))
dec.load_state_dict(torch.load(checkpoint_dec, map_location='cpu'))

<All keys matched successfully>

In [3]:
train_data_path = '../data'
device = torch.device('cpu')
dSet_train = ResidualDataset(train_data_path, 'final_test', device)
dataset_train = torch.utils.data.DataLoader(dSet_train,
                                            batch_size=1, shuffle=True,
                                            num_workers=0)

dataset mode: final_test, length: 86


In [4]:
bin_out_list = []
with torch.no_grad():
    for idx, sample in enumerate(dataset_train):
        model_input = sample['image']
        bin_out = enc(model_input).squeeze().numpy()
        bin_out_list.append(bin_out)
bin_out_arr = ((np.array(bin_out_list) + 1)/2).astype(np.uint8)
bin_out_arr.shape

(86, 32, 45, 150)

In [5]:
test_bin = bin_out_arr[:10,:,:,:]
print(test_bin.shape)

(10, 32, 45, 150)


In [39]:
N,C,H,W = test_bin.shape
chunk_size = 50 # W // 5
huffmap ={}
i=j=k=l=0
# N = H = W = 1
while i<N:
    j=0
    while j<C:
        k=0
        while k<H:
            channel = np.split(test_bin[i][j][k][:],chunk_size)
            for num in channel:
                s =""
                for ele in num:
                    s += str(ele)
                huffmap[s] = huffmap.get(s, 0) + 1
            k+=1
        j+=1
    i+=1
            

In [51]:
N,C,H,W = test_bin.shape
tile_x = 5
tile_y = 5
huffmap ={}

pH = tile_y - H % tile_y
pW = tile_x - W % tile_x
test_bin_padded = np.pad(test_bin, ((0,0), (0,0), (0, pH), (0, pW)))

for n in range(N):
    for c in range(C):
        for i in range(H // tile_x):
            ii = i*tile_y
            for j in range(W // tile_y):
                jj = j*tile_x
                tile = test_bin_padded[n,c, ii:ii+tile_y, jj:jj+tile_x].flatten()
                s = ""
                for ele in tile:
                    s += str(ele)
                huffmap[s] = huffmap.get(s, 0) + 1
                
                

In [52]:
huffmap

{'1111110010110100000010011': 1,
 '1111100100001000010001100': 1,
 '1010110101101101111010101': 1,
 '0011101111011110001101111': 1,
 '1111111100111101111111111': 3,
 '1111110110100100011000000': 1,
 '1101101011000100001010111': 1,
 '1100011101100100110011100': 1,
 '1100101101001101000001110': 1,
 '1111101011011111100101011': 1,
 '1101111000100001010000000': 1,
 '1011001100101010010110000': 1,
 '1101100010001000111010010': 1,
 '1111111111011110111110011': 1,
 '1111111111111111111111111': 1588,
 '0011000011010010101010100': 1,
 '0001000000100010000110110': 1,
 '0110001001010011100111001': 1,
 '0110000001110100011100000': 1,
 '0011011101101011000000000': 1,
 '1111110111100110100100000': 1,
 '1011011100001001010100111': 1,
 '1011100101010110101101100': 1,
 '0110110111000001101110001': 1,
 '1100111100010010100011000': 1,
 '0010101111000110000100100': 1,
 '1110100010001000000010000': 1,
 '0001010010010100011111101': 1,
 '1101110001111001011001111': 1,
 '0000010011111101011110000': 1,
 '11111

In [53]:
h = HuffmanCoding()
h.make_heap(huffmap)
h.merge_nodes()
h.make_codes()
h.codes

{'0110010100011010101101110': '0000000000000000',
 '1001100100000011011001000': '0000000000000001',
 '0110011111101111010111010': '0000000000000010',
 '0000110001011010100011110': '0000000000000011',
 '1111111001010101011011011': '0000000000000100',
 '0100110101100011110001011': '0000000000000101',
 '1000100000000111001100100': '0000000000000110',
 '0101001010110100010110000': '0000000000000111',
 '0111110001100000110001010': '0000000000001000',
 '0011100111101001101000010': '0000000000001001',
 '0000000001110000010000101': '0000000000001010',
 '0111110011101100010101011': '0000000000001011',
 '0000000000000010111011101': '0000000000001100',
 '1101111111110100001010000': '0000000000001101',
 '1011100011110111100111101': '0000000000001110',
 '1110111001101001001000010': '0000000000001111',
 '1011110111000111111111101': '0000000000010000',
 '1111111101011011110111010': '0000000000010001',
 '1100000000101101110011100': '0000000000010010',
 '1100100100010011100100111': '0000000000010011',


In [54]:
tot_bits = 0
entropy =0.
avg = 0
total_counts = 0.
codes = h.codes
for key in huffmap.keys():
    tot_bits += huffmap[key] * len(codes[key])
    total_counts += huffmap[key]
print("Total bits in data: {}, Total counts: {}".format(N*C*H*W, total_counts))
print("Total bits needed: {}".format(tot_bits))
for key in huffmap.keys():
    avg += huffmap[key] * len(codes[key])/total_counts
    p = float(huffmap[key])/total_counts
    entropy += -(p * np.log2(p))
print("Entropy: {}".format(entropy))
print("Avg code work length: {}".format(avg))
print("How close are we to Entropy :{} ".format(entropy/avg))

Total bits in data: 2160000, Total counts: 86400.0
Total bits needed: 1377848
Entropy: 15.86816720862458
Avg code work length: 15.947314814795451
How close are we to Entropy :0.9950369321048682 
