<a href="https://colab.research.google.com/github/OdysseyGuy/codec/blob/main/Codec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [280]:
import heapq
import os
from collections import Counter, defaultdict
from multiprocessing.pool import Pool

import cv2
import numpy as np
from PIL import Image
from scipy.fft import dct, idct
from scipy.signal import convolve2d

class Downsampling():
    """Chroma subsampling of a single 2-D channel.

    The channel is averaged over ``fy x fx`` cells (``1 x 2`` for '4:2:2',
    ``2 x 2`` for '4:2:0'), and each average is repeated back over its cell.
    The output therefore has the same shape as the input, including for odd
    dimensions. '4:4:4' returns the input unchanged.
    """
    def __init__(self, ratio='4:2:0'):
        assert ratio in ('4:4:4', '4:2:2', '4:2:0'), "Please choose one of the following ('4:4:4', '4:2:2', '4:2:0')"
        self.ratio = ratio

    def __call__(self, x):
        # no subsampling
        if self.ratio == '4:4:4':
            return x

        x = np.asarray(x)
        height, width = x.shape
        # 4:2:2 halves horizontal chroma resolution only; 4:2:0 halves both.
        fy = 2 if self.ratio == '4:2:0' else 1
        fx = 2
        # Edge-pad odd dimensions so every cell is complete; the pad is cropped below.
        padded = np.pad(x, ((0, -height % fy), (0, -width % fx)), mode='edge')
        cells = padded.reshape(padded.shape[0] // fy, fy, padded.shape[1] // fx, fx)
        means = cells.mean(axis=(1, 3))
        out = np.repeat(np.repeat(means, fy, axis=0), fx, axis=1)[:height, :width]
        # Round to nearest instead of truncating toward zero.
        return np.rint(out).astype('int')

class ImageBlock():
    """Tile an ``(H, W, C)`` image into per-channel blocks and reassemble it.

    ``forward`` edge-pads the image so each spatial dimension is a multiple
    of the block size, then cuts it into ``block_height x block_width`` tiles,
    one per channel. ``backward`` writes the tiles back and crops away the
    padding recorded by the most recent ``forward`` call.
    """

    def __init__(self, block_height=8, block_width=8):
        self.block_height = block_height
        self.block_width = block_width
        self.left_padding = self.right_padding = self.top_padding = self.bottom_padding = 0

    def forward(self, image):
        """Split ``image`` into tiles.

        Returns:
            blocks: array of shape ``(n, block_height, block_width)``.
            indices: array of shape ``(n, 3)``. Row ``(i, j, k)`` gives the
                top-left corner and channel of the matching block in the
                padded image. Channel varies fastest, then column, then row.
        """
        self.image_channel = image.shape[2]

        # Amount needed to reach the next multiple of the block size (0 when
        # already aligned). Recomputed on every call so padding from a
        # previous image never leaks into this one.
        vpad = -image.shape[0] % self.block_height
        hpad = -image.shape[1] % self.block_width
        self.top_padding = vpad // 2
        self.bottom_padding = vpad - self.top_padding
        self.left_padding = hpad // 2
        self.right_padding = hpad - self.left_padding

        if vpad or hpad:
            # Replicate the border pixels outward.
            image = np.pad(image,
                           ((self.top_padding, self.bottom_padding),
                            (self.left_padding, self.right_padding),
                            (0, 0)),
                           mode='edge')

        # update size with padding
        self.image_height = image.shape[0]
        self.image_width = image.shape[1]

        # create blocks
        blocks = []
        indices = []
        for i in range(0, self.image_height, self.block_height):
            for j in range(0, self.image_width, self.block_width):
                for k in range(self.image_channel):
                    blocks.append(image[i:i + self.block_height, j:j + self.block_width, k])
                    indices.append((i, j, k))

        return np.array(blocks), np.array(indices)

    def backward(self, blocks, indices):
        """Reassemble tiles produced by ``forward`` and crop its padding.

        Returns an int array with the pre-padding image shape. Float tiles
        are truncated toward zero when stored into the int canvas.
        """
        image = np.zeros((self.image_height, self.image_width, self.image_channel)).astype('int')
        for block, (i, j, k) in zip(blocks, indices):
            image[i:i + self.block_height, j:j + self.block_width, k] = block

        # Crop with explicit end indices: a negative slice such as [:-0]
        # would select nothing when the padding on that side is 0.
        bottom = self.image_height - self.bottom_padding
        right = self.image_width - self.right_padding
        return image[self.top_padding:bottom, self.left_padding:right, :]

class DCT2D():
    """Separable 2-D DCT-II (``forward``) and its integer-rounded inverse (``backward``)."""
    def __init__(self, norm='ortho'):
        if norm is not None:
            assert norm == 'ortho', "norm needs to be in (None, 'ortho')"
        self.norm = norm

    def forward(self, x):
        """Apply a type-II DCT along axis 0 then axis 1; returns floats."""
        out = dct(dct(x, norm=self.norm, axis=0), norm=self.norm, axis=1)
        return out

    def backward(self, x):
        """Invert ``forward`` along both axes and round to the nearest integer.

        ``idct`` applies the scaling that matches ``forward`` for either
        ``norm`` setting. A bare type-3 DCT is only the exact inverse under
        'ortho'.
        """
        out = idct(idct(x, norm=self.norm, axis=0), norm=self.norm, axis=1)
        # Round before casting: floating-point error can leave e.g. 2.9999999,
        # which astype(int) alone would truncate to 2.
        return np.rint(out).astype(int)

class Quantization():
    """JPEG-style quantizer using the standard luminance/chrominance tables.

    The base tables are scaled by a quality factor ``q`` using the
    IJG/libjpeg formula. The scale is ``s = 5000 / q`` below 50 and
    ``200 - 2 * q`` otherwise. Each entry becomes
    ``floor((Q * s + 50) / 100)``, clamped to at least 1.
    """
    # Luminance
    Q_lum = np.array([[16,11,10,16,24,40,51,61],
                      [12,12,14,19,26,58,60,55],
                      [14,13,16,24,40,57,69,56],
                      [14,17,22,29,51,87,80,62],
                      [18,22,37,56,68,109,103,77],
                      [24,35,55,64,81,104,113,92],
                      [49,64,78,87,103,121,120,101],
                      [72,92,95,98,112,100,103,99]])

    # Chrominance
    Q_chr = np.array([[17,18,24,47,99,99,99,99],
                      [18,21,26,66,99,99,99,99],
                      [24,26,56,99,99,99,99,99],
                      [47,66,99,99,99,99,99,99],
                      [99,99,99,99,99,99,99,99],
                      [99,99,99,99,99,99,99,99],
                      [99,99,99,99,99,99,99,99],
                      [99,99,99,99,99,99,99,99]])

    def __init__(self, q=50):
        # Clamp to the usual 1..100 quality range; q <= 0 would divide by zero below.
        q = min(max(q, 1), 100)
        self.q = q
        s = (5000/q) if (q < 50) else (200 - 2 * q)
        # Near q=100 the scaled entries floor to 0 and forward() would divide
        # by zero, so every step size is clamped to at least 1 (as libjpeg does).
        self.Q_lum = np.maximum(np.floor((self.Q_lum * s + 50)/100), 1).astype('int16')
        self.Q_chr = np.maximum(np.floor((self.Q_chr * s + 50)/100), 1).astype('int16')

    def _table(self, channel_type):
        """Return the scaled table for ``'lum'`` (luminance) or ``'chr'`` (chrominance)."""
        if channel_type == 'lum':
            return self.Q_lum
        if channel_type == 'chr':
            return self.Q_chr
        raise ValueError("channel_type must be 'lum' or 'chr', got %r" % (channel_type,))

    def forward(self, x, channel_type):
        """Quantize coefficients ``x``: divide by the table and round to the nearest int."""
        return np.rint(np.divide(x, self._table(channel_type))).astype(int)

    def backward(self, x, channel_type):
        """Dequantize: multiply quantized coefficients by the same table."""
        return x * self._table(channel_type)

def get_zigzag_indices(N):
    """List the ``(row, col)`` cells of an N x N grid in JPEG zigzag order.

    Cells are grouped by anti-diagonal (``row + col``). Even diagonals run
    bottom-left to top-right (ascending column); odd diagonals run top-right
    to bottom-left (ascending row).
    """
    def order_key(cell):
        row, col = cell
        diagonal = row + col
        return (diagonal, col if diagonal % 2 == 0 else row)

    all_cells = ((r, c) for r in range(N) for c in range(N))
    return sorted(all_cells, key=order_key)

def zigzag_traverse(block):
    """Read an 8x8 block in JPEG zigzag order.

    ``block`` must support ``block[row][col]`` indexing (nested lists or a
    2-D array). Returns a plain list of its 64 elements. The list starts at
    the top-left corner and sweeps the anti-diagonals in alternating
    directions.
    """
    # Row-major offsets (row * 8 + col) of each cell, in zigzag order. The
    # table is spelled out so no ordering is recomputed for every block.
    flat_order = (
        0, 1, 8, 16, 9, 2, 3, 10,
        17, 24, 32, 25, 18, 11, 4, 5,
        12, 19, 26, 33, 40, 48, 41, 34,
        27, 20, 13, 6, 7, 14, 21, 28,
        35, 42, 49, 56, 57, 50, 43, 36,
        29, 22, 15, 23, 30, 37, 44, 51,
        58, 59, 52, 45, 38, 31, 39, 46,
        53, 60, 61, 54, 47, 55, 62, 63,
    )
    return [block[pos // 8][pos % 8] for pos in flat_order]

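# NOTE: unused sketch. Despite its name, this lays out the *forward* zigzag
# order (row-major flat indices) as an 8x8 grid, so it does not invert
# zigzag_traverse; reverse_zigzag_traverse below is the working inverse.
# def reverse_zig_zag_traverse(block):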
#     # this might perform better?
#     return np.array([
#         np.array([block[0], block[1], block[8], block[16], block[9], block[2], block[3], block[10]]),
#         np.array([block[17], block[24], block[32], block[25], block[18], block[11], block[4], block[5]]),
#         np.array([block[12], block[19], block[26], block[33], block[40], block[48], block[41], block[34]]),
#         np.array([block[27], block[20], block[13], block[6], block[7], block[14], block[21], block[28]]),
#         np.array([block[35], block[42], block[49], block[56], block[57], block[50], block[43], block[36]]),
#         np.array([block[29], block[22], block[15], block[23], block[30], block[37], block[44], block[51]]),
#         np.array([block[58], block[59], block[52], block[45], block[38], block[31], block[39], block[46]]),
#         np.array([block[53], block[60], block[61], block[54], block[47], block[55], block[62], block[63]])
#     ])

def reverse_zigzag_traverse(array, N=8):
    """Scatter a flat zigzag-ordered sequence back into an N x N float grid.

    Item ``k`` of ``array`` lands at cell ``get_zigzag_indices(N)[k]``.
    ``array`` must hold at least N*N items; any extras are ignored.
    """
    grid = np.zeros((N, N))
    for position, (row, col) in enumerate(get_zigzag_indices(N)):
        grid[row, col] = array[position]
    return grid


class HuffmanNode:
    """Node of a Huffman code tree.

    Leaves hold a ``symbol``. Internal nodes keep ``symbol=None`` and link
    to ``left`` (bit '0') and ``right`` (bit '1') children.
    """

    def __init__(self, symbol=None, frequency=None):
        self.symbol = symbol
        self.frequency = frequency
        # Children are attached by the tree builder after construction.
        self.left, self.right = None, None

    def __lt__(self, other):
        # Lets heapq order nodes by frequency alone.
        return self.frequency < other.frequency

def build_huffman_tree(frequencies):
    """Build a Huffman code tree from a ``{symbol: count}`` mapping.

    The two lowest-frequency nodes are repeatedly merged (via heapq) until a
    single root remains. Leaves carry ``symbol``; internal nodes have
    ``symbol=None``.

    Raises:
        ValueError: if ``frequencies`` is empty.
    """
    heap = [HuffmanNode(symbol=s, frequency=f) for s, f in frequencies.items()]
    if not heap:
        raise ValueError("cannot build a Huffman tree from an empty frequency table")

    if len(heap) == 1:
        # A lone leaf as the root would get the empty code "" and encode to
        # no bits at all. Hang it under an internal node instead, on both
        # sides so the tree has no None child, which gives it a 1-bit code.
        leaf = heap[0]
        root = HuffmanNode(frequency=leaf.frequency)
        root.left = root.right = leaf
        return root

    heapq.heapify(heap)
    while len(heap) > 1:
        left = heapq.heappop(heap)
        right = heapq.heappop(heap)
        internal_node = HuffmanNode(frequency=left.frequency + right.frequency)
        internal_node.left = left
        internal_node.right = right
        heapq.heappush(heap, internal_node)

    return heap[0]

def generate_huffman_codes(node, current_code="", codes=None):
    """Map every leaf symbol under ``node`` to its bit-string code.

    Each left edge appends '0' and each right edge appends '1', starting
    from ``current_code``. If ``codes`` is given, it is filled in place and
    returned; otherwise a new dict is returned. A root that is itself a leaf
    receives ``current_code`` unchanged.
    """
    table = {} if codes is None else codes
    pending = [(node, current_code)]
    while pending:
        current, prefix = pending.pop()
        if current.symbol is not None:
            table[current.symbol] = prefix
            continue
        # Push right before left so the left subtree is handled first
        # (the same preorder as a recursive walk).
        pending.append((current.right, prefix + "1"))
        pending.append((current.left, prefix + "0"))
    return table

def huffman_encode(data):
    """Huffman-encode a sequence of hashable symbols.

    Returns ``(bits, root)``. ``bits`` is a str of '0'/'1' characters and
    ``root`` is the tree that huffman_decode needs. ``data`` is iterated
    twice (once to count, once to encode), so it must be re-iterable.
    """
    tree = build_huffman_tree(Counter(data))
    code_table = generate_huffman_codes(tree)
    bits = "".join(map(code_table.__getitem__, data))
    return bits, tree


def huffman_decode(encoded_data, root):
    """Decode a '0'/'1' string by walking the Huffman tree from ``root``.

    '0' follows ``left`` and any other character follows ``right``. Each time
    a leaf is reached, its symbol is emitted and the walk restarts at the
    root. Trailing bits that do not complete a code are ignored.
    """
    symbols = []
    node = root
    for bit in encoded_data:
        node = node.left if bit == "0" else node.right
        if node.symbol is None:
            continue
        symbols.append(node.symbol)
        node = root
    return symbols

def run_length_encode(image):
    """Run-length encode the zeros of a 1-D sequence.

    Non-zero values are copied through unchanged. Each maximal run of zeros
    becomes the pair ``0, run_length``; for example
    ``[5, 0, 0, 0, 3, 0]`` becomes ``[5, 0, 3, 3, 0, 1]``.

    Returns a float64 numpy array.
    """
    encoded = []
    zeros_count = 0
    for value in np.asarray(image):
        if value == 0:
            zeros_count += 1
            continue
        if zeros_count:
            encoded.extend((0, zeros_count))
            zeros_count = 0
        encoded.append(value)
    if zeros_count:
        encoded.extend((0, zeros_count))
    # Build the array once; np.append inside the loop copied the whole
    # array on every element (quadratic time).
    return np.array(encoded, dtype=float)

def run_length_decode(encoded, length):
    """Invert run_length_encode into a float64 array of ``length`` values.

    A ``0`` in the stream is a marker followed by a zero-run length; any
    other value is copied through. Positions the stream does not reach stay
    zero.

    Raises:
        ValueError: if a zero marker has no run length after it, or the
            stream expands to more than ``length`` values.
    """
    stream = np.asarray(encoded)
    image = np.zeros(length)
    capacity = image.shape[0]
    idx = 0
    pos = 0
    # Walk the stream with an explicit cursor. Looking back at encoded[i - 1]
    # to skip run lengths wrapped around to the last element when i == 0.
    while pos < len(stream):
        value = stream[pos]
        if value == 0:
            if pos + 1 == len(stream):
                raise ValueError("run-length stream ends with a zero marker but no run length")
            run = int(stream[pos + 1])
            pos += 2
        else:
            run = 1
            pos += 1
        end = idx + run
        if end > capacity:
            raise ValueError("run-length stream decodes to more than %d values" % capacity)
        # Zero runs need no write because `image` starts out all zeros.
        if value != 0:
            image[idx] = value
        idx = end
    return image

# No subsampling for the luminance channel.
lum_downsample = Downsampling(ratio='4:4:4')
# Chroma is currently kept at full resolution as well ('4:4:4');
# use ratio='4:2:0' here to enable chroma subsampling.
chr_downsample = Downsampling(ratio='4:4:4')
image_block = ImageBlock(block_height=8, block_width=8)
dct2d = DCT2D(norm='ortho')
quant = Quantization(20)

with Image.open("12.png") as src:
    # Force 3-channel RGB: palette, RGBA or greyscale PNGs would otherwise
    # break the RGB -> YCrCb conversion below.
    rgb_img = np.asarray(src.convert("RGB"), np.uint8)

# RGB -> YCrCb. OpenCV orders the output channels Y, Cr, Cb.
ycc_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2YCrCb)
# Level-shift to a signed range centred on 0 before the DCT.
ycc_img = ycc_img.astype(int) - 128

# chroma subsampling
Y = lum_downsample(ycc_img[:, :, 0])
Cr = chr_downsample(ycc_img[:, :, 1])
Cb = chr_downsample(ycc_img[:, :, 2])

# Keep OpenCV's channel order (Y, Cr, Cb) so the decoder's YCrCb2RGB call matches.
ycc_img = np.stack((Y, Cr, Cb), axis=2)

# create 8x8 blocks
blocks, indices = image_block.forward(ycc_img)
# Tiles per row/column, including the partial edge tiles that ImageBlock
# pads out to full size (ceiling division).
nxblocks = (rgb_img.shape[1] + 7) // 8
nyblocks = (rgb_img.shape[0] + 7) // 8

def process_block(block, index):
    """Transform one 8x8 tile into its quantized coefficients in zigzag order.

    ``index`` is a ``(row, col, channel)`` triple as produced by
    ImageBlock.forward. Channel 0 is quantized with the luminance table and
    every other channel with the chrominance table. Uses the module-level
    ``dct2d`` and ``quant`` objects. Returns a 64-element numpy array.
    """
    # Keep the DCT output in floating point: quantization does its own
    # integer conversion, so casting here would add an extra truncation step.
    coefficients = dct2d.forward(block)
    channel_type = 'lum' if index[2] == 0 else 'chr'
    quantized = quant.forward(coefficients, channel_type)
    return np.array(zigzag_traverse(quantized))

# Entropy-code every block: zigzag + run-length per block, then a single
# Huffman pass over the concatenated stream.
rle_parts = [run_length_encode(process_block(block, index))
             for block, index in zip(blocks, indices)]
# Concatenate once: calling np.append inside the loop copied the whole
# stream for every block (quadratic time).
rle_encoded = np.concatenate(rle_parts) if rle_parts else np.array([])

encoded, root = huffman_encode(rle_encoded)

decoded = np.array(huffman_decode(encoded, root))
# The encoder emitted exactly 64 coefficients per block, padded edge
# blocks included, so size the decode from the actual block count.
num_blocks = len(blocks)
rle_decoded = run_length_decode(decoded, num_blocks * 64)
dec_blocks = np.zeros((num_blocks, 8, 8))

for i, index in enumerate(indices):
    block = reverse_zigzag_traverse(rle_decoded[i * 64:(i + 1) * 64], 8)
    # Use the block's own channel index, matching process_block, rather
    # than assuming a fixed channel count.
    channel_type = 'lum' if index[2] == 0 else 'chr'
    dequant = quant.backward(block, channel_type)
    decompressed = dct2d.backward(dequant)
    dec_blocks[i] = decompressed

ycc_img_compressed = image_block.backward(dec_blocks, indices)
# Clip before the uint8 cast: the reconstruction can overshoot [0, 255], and
# astype('uint8') would otherwise wrap (e.g. 256 -> 0, -1 -> 255).
ycc_img_compressed = np.clip(ycc_img_compressed + 128, 0, 255).astype('uint8')
rgb_img_compressed = cv2.cvtColor(ycc_img_compressed, cv2.COLOR_YCrCb2RGB)
# NOTE(review): the .jpeg extension makes PIL re-encode these pixels with its
# own JPEG encoder (a second lossy pass); save as .png to inspect this
# codec's output exactly.
Image.fromarray(rgb_img_compressed).save("result6.jpeg")