In [43]:
!pip install bitarray



In [44]:
import math
from struct import pack
from bitarray import bitarray
import numpy as np
import cv2

In [45]:
# Side length, in samples, of the square blocks the encoder works on.
# Read as a module-level global by run_length_encode, write_quant_table and write_SOS.
N = 8 #block size

In [46]:
def rgb_to_ycbcr(r, g, b):
    """Convert three colour planes (R, G, B) into Y, Cb and Cr planes.

    The planes are lists of rows. The output has the same number of rows as
    ``r`` and the same number of columns as the first row of ``r``, so
    rectangular planes are supported as well as square ones.

    No offset is added to the chroma planes: the result is a pure linear
    combination of the inputs.

    Args:
        r, g, b: equally sized 2-D lists of numeric samples.

    Returns:
        Tuple ``(y, cb, cr)`` of 2-D lists of floats.
    """
    rows = len(r)
    cols = len(r[0]) if rows else 0

    y = [[0] * cols for _ in range(rows)]
    cb = [[0] * cols for _ in range(rows)]
    cr = [[0] * cols for _ in range(rows)]

    for i in range(rows):
        for j in range(cols):
            red, green, blue = r[i][j], g[i][j], b[i][j]
            y[i][j] = 0.299 * red + 0.587 * green + 0.114 * blue
            cb[i][j] = -0.169 * red - 0.331 * green + 0.5 * blue
            cr[i][j] = 0.5 * red - 0.419 * green - 0.081 * blue

    return y, cb, cr

In [47]:
def dct(a):
    """Recursive 1-D DCT-II (unnormalised) using a divide-and-conquer split.

    The sequence is split into pairwise sums and cosine-scaled pairwise
    differences of mirrored elements; each half is transformed recursively
    and the two results are interleaved.

    Intended for lengths that are powers of two (the encoder uses 8); other
    lengths are not validated here.

    Args:
        a: sequence of numbers.

    Returns:
        A new list of transform coefficients.
    """
    size = len(a)
    if size == 1:
        return list(a)

    mid = size // 2
    sums = []
    scaled_diffs = []
    for k in range(mid):
        front, back = a[k], a[-(k + 1)]
        sums.append(front + back)
        scaled_diffs.append((front - back) / (math.cos((k + 0.5) * math.pi / size) * 2.0))

    even_part = dct(sums)
    odd_part = dct(scaled_diffs)

    # Interleave: even-indexed outputs come from the sums, odd-indexed
    # outputs are sums of neighbouring entries of the difference transform.
    merged = []
    for k in range(mid - 1):
        merged.extend((even_part[k], odd_part[k] + odd_part[k + 1]))
    merged.extend((even_part[-1], odd_part[-1]))
    return merged

In [48]:
def norm_coeff(n):
    """Return the DCT normalisation factor: 1/sqrt(2) for index 0, otherwise 1.0."""
    return 1.0 / math.sqrt(2.0) if n == 0 else 1.0

def dct2(a):
    """Separable 2-D DCT of a square matrix, built from two passes of ``dct``.

    The first pass transforms every column, the second every row of the
    intermediate result. After each 1-D transform the coefficients are
    multiplied by sqrt(2/size), and the index-0 coefficient is additionally
    divided by sqrt(2).

    Args:
        a: square matrix (list of equally long rows).

    Returns:
        A new size x size list of lists with the transform coefficients.
    """
    size = len(a)
    scale = math.sqrt(2/size)
    coeffs = [[0 for _ in range(size)] for __ in range(size)]

    # Pass 1: transform each column of the input.
    for c in range(size):
        column = dct([a[r][c] for r in range(size)])
        for r in range(size):
            value = column[r] * scale
            coeffs[r][c] = value / math.sqrt(2) if r == 0 else value

    # Pass 2: transform each row of the intermediate matrix in place.
    for r in range(size):
        line = dct([coeffs[r][c] for c in range(size)])
        for c in range(size):
            value = line[c] * scale
            coeffs[r][c] = value / math.sqrt(2) if c == 0 else value

    return coeffs


In [49]:
def quantize(input_matrix, quantization_matrix):
    """Divide each coefficient by its quantisation step and round the quotient.

    Rounding uses the built-in ``round``. Only the leading size x size part
    of both matrices is read, where size is the number of rows of
    ``input_matrix``.

    Returns:
        A new list of lists holding the rounded quotients.
    """
    size = len(input_matrix)
    result = []
    for r in range(size):
        quantized_row = []
        for c in range(size):
            quantized_row.append(round(input_matrix[r][c] / quantization_matrix[r][c]))
        result.append(quantized_row)
    return result

In [50]:
def zigzag_order(matrix):
    """Flatten a square matrix by walking its anti-diagonals in zig-zag order.

    Diagonals are visited by increasing ``row + col``. Even-numbered
    diagonals are walked from bottom-left to top-right (row decreasing),
    odd-numbered ones from top-right to bottom-left (row increasing).

    Args:
        matrix: square matrix; its size is taken from the number of rows.

    Returns:
        List of size * size elements in zig-zag order.
    """
    size = len(matrix)
    ordered = []

    for diagonal in range(2 * size - 1):
        # Rows that actually intersect this anti-diagonal.
        top_row = max(0, diagonal - size + 1)
        bottom_row = min(diagonal, size - 1)
        rows = range(top_row, bottom_row + 1)
        if diagonal % 2 == 0:
            rows = reversed(rows)
        ordered.extend(matrix[row][diagonal - row] for row in rows)

    return ordered



In [51]:
def run_length_encode(zigzag_sequence):
    """Run-length encode the AC part of a zig-zag ordered coefficient block.

    Index 0 (the DC coefficient) is skipped. The result is a flat list of
    ``[zero_run, bit_length, value]`` triples, one for every non-zero
    coefficient. A run of 16 zeros that is followed later by a non-zero
    coefficient is emitted as the triple ``[15, 0, 0]``. If the sequence
    ends in one or more zeros, the two-element terminator ``[0, 0]`` is
    appended.

    The block length is taken from the sequence itself rather than from a
    module-level constant, so the function does not depend on global state.

    Args:
        zigzag_sequence: sequence of integer coefficients in zig-zag order.

    Returns:
        Flat list of triples, optionally followed by ``[0, 0]``.
    """
    total = len(zigzag_sequence)

    # Find the last non-zero element in the zigzag sequence.
    last_nonzero_index = total - 1
    while last_nonzero_index >= 0 and zigzag_sequence[last_nonzero_index] == 0:
        last_nonzero_index -= 1

    rle_result = []
    run_start = 1  # index of the first coefficient in the current zero run
    for position in range(1, last_nonzero_index + 1):
        coefficient = zigzag_sequence[position]
        zero_run = position - run_start
        # Emit on every non-zero value, or when 16 zeros have accumulated.
        if coefficient != 0 or zero_run == 15:
            rle_result += [zero_run, int(coefficient).bit_length(), coefficient]
            run_start = position + 1

    # Terminate the block when trailing zeros were dropped.
    if last_nonzero_index != total - 1:
        rle_result += [0, 0]

    return rle_result

In [52]:
def huffcode(code_dictionary, node, current_code):
    """Fill ``code_dictionary`` with the bit string of every leaf below ``node``.

    A node has the form ``[weight, [symbol, left, right]]``; a leaf has both
    children set to ``None``. Going to the left child appends "0" to the
    code, going to the right child appends "1". Leaves are recorded in
    left-to-right order.

    Args:
        code_dictionary: dict that receives ``symbol -> bit string``.
        node: root of the (sub)tree to walk.
        current_code: bit string accumulated on the way to ``node``.
    """
    # Explicit stack; the right child is pushed first so the left one is
    # processed first.
    pending = [(node, current_code)]
    while pending:
        current, prefix = pending.pop()
        payload = current[1]
        left_child, right_child = payload[1], payload[2]

        if left_child is None and right_child is None:
            code_dictionary[payload[0]] = prefix
            continue

        if right_child is not None:
            pending.append((right_child, prefix + "1"))
        if left_child is not None:
            pending.append((left_child, prefix + "0"))

def modify_huffcode(code_dictionary):
    """Reassign codes in increasing numeric order, grouped by code length.

    Only the *lengths* of the incoming codes are used. Symbols are grouped
    by length and numbered consecutively from the shortest length upwards;
    the running counter is doubled when moving to the next length. A symbol
    whose code would consist solely of 1-bits is moved to the front of the
    next longer group instead of receiving that code.

    Lengths 0 to 16 are supported; anything longer raises ``IndexError``.

    Args:
        code_dictionary: ``symbol -> bit string`` mapping.

    Returns:
        New dict ``symbol -> bit string`` with the reassigned codes.
    """
    symbols_by_length = [[] for _ in range(17)]
    for symbol, bits in code_dictionary.items():
        symbols_by_length[len(bits)].append(symbol)

    reassigned = {}
    next_code = 0
    for length, group in enumerate(symbols_by_length):
        for symbol in list(group):
            candidate = format(next_code, 'b').zfill(length)
            if candidate == '1' * length:
                # All-ones codes are not handed out: defer to the next length.
                symbols_by_length[length + 1].insert(0, symbol)
                continue
            reassigned[symbol] = candidate
            next_code += 1
        next_code <<= 1

    return reassigned

def huffcode_gen(data):
    """Build a Huffman code table for the symbols occurring in ``data``.

    Symbol frequencies are counted, a Huffman tree is built by repeatedly
    merging the two lightest trees, the tree is converted to bit strings
    with ``huffcode`` and the result is post-processed by
    ``modify_huffcode``.

    Args:
        data: non-empty iterable of hashable, mutually comparable symbols.

    Returns:
        Dict mapping each symbol to its bit string.
    """
    counts = {}
    for symbol in data:
        counts[symbol] = counts.get(symbol, 0) + 1

    # Each tree is [weight, [symbol, left, right]]; leaves have no children.
    forest = [[weight, [symbol, None, None]] for symbol, weight in counts.items()]
    forest.sort()

    while len(forest) > 1:
        lightest = forest.pop(0)
        runner_up = forest.pop(0)
        forest.append([lightest[0] + runner_up[0], [-1, lightest, runner_up]])
        forest.sort(key=lambda tree: tree[0])

    raw_codes = {}
    huffcode(raw_codes, forest[0], "")
    return modify_huffcode(raw_codes)



In [53]:
def write_quant_table(jpeg_out, quant_table, type):
    """Write one quantisation-table segment (marker FF DB) to ``jpeg_out``.

    The segment consists of the marker, the length field 67, one byte
    holding ``type`` and the N*N table entries in zig-zag order, one byte
    each.

    Args:
        jpeg_out: binary file-like object opened for writing.
        quant_table: N x N matrix of values that fit in one unsigned byte.
        type: value written as the single byte after the length field.
    """
    jpeg_out.write(b'\xff\xdb')  # marker
    jpeg_out.write(pack(">HB", 67, type))  # length of chunk, then table byte
    reordered = zigzag_order(quant_table)
    jpeg_out.write(b"".join(pack("B", reordered[k]) for k in range(N * N)))

In [54]:
def write_SOF(jpeg_out,height,width,components):
    """Write the frame header segment (marker FF C0) to ``jpeg_out``.

    Layout: marker, length ``8 + 3 * components``, sample precision 8,
    height, width, component count, then three bytes per component:
    id (1-based), sampling factors 0x11, and the quantisation table id
    (0 for the first component, 1 for all others).

    Args:
        jpeg_out: binary file-like object opened for writing.
        height, width: image dimensions, each written as a 16-bit value.
        components: number of colour components.
    """
    jpeg_out.write(b'\xff\xc0' + pack(">HB", 8 + 3 * components, 8))
    jpeg_out.write(pack(">HHB", height, width, components))

    component_id = 1
    while component_id <= components:
        table_id = 0 if component_id == 1 else 1
        jpeg_out.write(pack("BBB", component_id, 0x11, table_id))
        component_id += 1


In [55]:
def write_huffman_table(jpeg_out,type,code):
    """Write one Huffman-table segment (marker FF C4) to ``jpeg_out``.

    Layout: marker, length ``19 + len(code)``, one byte holding ``type``,
    16 bytes giving the number of codes of length 1..16, then the symbols
    themselves ordered by code length (and by dict order within a length).

    Args:
        jpeg_out: binary file-like object opened for writing.
        type: value written as the single byte after the length field.
        code: dict mapping one-byte symbols to their bit strings.
    """
    jpeg_out.write(b'\xff\xc4')  # marker
    jpeg_out.write(pack(">HB", 19 + len(code), type))

    # Bucket k holds the symbols whose code is k + 1 bits long.
    buckets = [[] for _ in range(16)]
    for symbol, bits in code.items():
        buckets[len(bits) - 1].append(symbol)

    for bucket in buckets:
        jpeg_out.write(pack("B", len(bucket)))
    for bucket in buckets:
        for symbol in bucket:
            jpeg_out.write(pack("B", symbol))


In [56]:
def add_FF00(data):
    """Insert eight "0" characters after every all-ones group of eight bits.

    ``data`` is a string of "0"/"1" characters processed in consecutive
    groups of eight, starting at offset 0. A shorter trailing group is
    copied unchanged.

    Returns:
        The stuffed bit string.
    """
    pieces = []
    for offset in range(0, len(data), 8):
        octet = data[offset:offset + 8]
        pieces.append(octet)
        if octet == "11111111":
            pieces.append("00000000")
    return "".join(pieces)

In [57]:
def _append_block_bits(bits, dc_symbol, dc_vli, ac_symbols, ac_vlis, ac_ptr, ac_vli_ptr, dc_code, ac_code):
    """Append the entropy-coded bits of one block to ``bits``; return the advanced AC cursors."""
    bits.append(dc_code[dc_symbol])
    bits.append(dc_vli)

    covered = 1  # the DC coefficient is already accounted for
    while covered < N * N and ac_symbols[ac_ptr] != 0:
        symbol = ac_symbols[ac_ptr]
        bits.append(ac_code[symbol])
        bits.append(ac_vlis[ac_vli_ptr])
        # The high nibble of the symbol is the zero run preceding the value.
        covered += (symbol >> 4) + 1
        ac_ptr += 1
        ac_vli_ptr += 1

    if covered < N * N:
        # Block ended early: emit the end-of-block code and skip its symbol.
        bits.append(ac_code[0])
        ac_ptr += 1

    return ac_ptr, ac_vli_ptr


def write_SOS(jpeg_out,components,Y_dc_list, CB_dc_list, CR_dc_list, Y_ac_list, CB_ac_list, CR_ac_list,Y_dc_vli_list, CB_dc_vli_list, CR_dc_vli_list, Y_ac_vli_list, CB_ac_vli_list, CR_ac_vli_list,lum_dc_code,lum_ac_code,chr_dc_code,chr_ac_code):
    """Write the scan header (marker FF DA) and the entropy-coded image data.

    For every block, in order, the bits of the Y, Cb and Cr components are
    emitted: the Huffman code of the DC symbol followed by its amplitude
    bits, then one (code, amplitude) pair per AC symbol, then the code for
    symbol 0 if the block did not reach N*N coefficients.

    The bit stream is padded to a whole number of bytes with 1-bits, as the
    JPEG specification requires for the end of an entropy-coded segment
    (the previous implementation padded with 0-bits). Byte stuffing is then
    applied with ``add_FF00``.

    Note: the component selector bytes in the scan header are fixed for
    three components; ``components`` only affects the length and count
    fields.

    Args:
        jpeg_out: binary file-like object opened for writing.
        components: component count written into the scan header.
        *_dc_list: per-block DC symbols of each component.
        *_ac_list: concatenated AC symbols of each component; 0 ends a block.
        *_dc_vli_list / *_ac_vli_list: amplitude bit strings matching the
            DC symbols and the non-zero AC symbols respectively.
        lum_dc_code, lum_ac_code, chr_dc_code, chr_ac_code: dicts mapping
            symbols to Huffman bit strings.
    """
    jpeg_out.write(b'\xff\xda')
    jpeg_out.write(pack(">HB",6+2*components,components))
    jpeg_out.write(b'\x01\x00\x02\x11\x03\x11')
    jpeg_out.write(pack("BBB",0,63,0))

    bits = []
    Y_ac_ptr = Y_ac_vli_ptr = 0
    CB_ac_ptr = CB_ac_vli_ptr = 0
    CR_ac_ptr = CR_ac_vli_ptr = 0

    for block in range(len(Y_dc_list)):
        Y_ac_ptr, Y_ac_vli_ptr = _append_block_bits(
            bits, Y_dc_list[block], Y_dc_vli_list[block],
            Y_ac_list, Y_ac_vli_list, Y_ac_ptr, Y_ac_vli_ptr,
            lum_dc_code, lum_ac_code)
        CB_ac_ptr, CB_ac_vli_ptr = _append_block_bits(
            bits, CB_dc_list[block], CB_dc_vli_list[block],
            CB_ac_list, CB_ac_vli_list, CB_ac_ptr, CB_ac_vli_ptr,
            chr_dc_code, chr_ac_code)
        CR_ac_ptr, CR_ac_vli_ptr = _append_block_bits(
            bits, CR_dc_list[block], CR_dc_vli_list[block],
            CR_ac_list, CR_ac_vli_list, CR_ac_ptr, CR_ac_vli_ptr,
            chr_dc_code, chr_ac_code)

    data = "".join(bits)
    # Fill the last, incomplete byte with 1-bits.
    data += "1" * (-len(data) % 8)
    data = add_FF00(data)
    data=bitarray(data)
    jpeg_out.write(data)


In [58]:
def get_vli(number):
    """Return the amplitude bit string of ``number``.

    * 0 is encoded as the empty string.
    * A positive value is encoded as its plain binary representation.
    * A negative value is encoded as the bitwise complement of its
      magnitude, zero-padded to the bit length of the magnitude.

    The argument is coerced with ``int()`` first, matching the
    ``int(...).bit_length()`` calls used elsewhere in this file, so integral
    floats and integer-like scalar types are accepted too.

    Args:
        number: integer-valued coefficient.

    Returns:
        String of "0"/"1" characters.
    """
    number = int(number)
    if number == 0:
        return ''

    if number > 0:
        return format(number, 'b')

    magnitude = -number
    width = magnitude.bit_length()
    # Complement within `width` bits: (2**width - 1) - magnitude.
    complemented = (~magnitude) & ((1 << width) - 1)
    return format(complemented, 'b').zfill(width)


In [59]:
def rle_to_bits(rle):
    """Split a run-length list into combined symbols and amplitude bit strings.

    ``rle`` is read in steps of three (run, size, value). Each step yields
    the symbol ``(run << 4) + size``. A symbol equal to 0 is recorded and
    ends processing without an amplitude; every other symbol gets the
    amplitude string ``get_vli(value)``.

    Returns:
        Tuple ``(symbols, amplitudes)`` of two lists.
    """
    symbols, amplitudes = [], []
    cursor = 0
    while cursor < len(rle):
        symbol = (rle[cursor] << 4) + rle[cursor + 1]
        symbols.append(symbol)
        if symbol == 0:
            break
        amplitudes.append(get_vli(rle[cursor + 2]))
        cursor += 3
    return symbols, amplitudes

In [60]:
def pad_image(original_red_channel, original_green_channel, original_blue_channel):
    """Extend three colour planes to dimensions that are multiples of 8.

    The target height and width are derived from the red plane. Positions
    beyond a plane's own last row or column repeat that plane's last row or
    column (edge replication).

    Returns:
        Tuple ``(red, green, blue)`` of new 2-D lists.
    """
    new_height = ((len(original_red_channel) + 7) // 8) * 8
    new_width = ((len(original_red_channel[0]) + 7) // 8) * 8

    def replicate_edges(channel):
        # Clamp every coordinate to the plane's own last valid index.
        last_row = len(channel) - 1
        last_col = len(channel[0]) - 1
        padded = []
        for row in range(new_height):
            source_row = channel[min(row, last_row)]
            padded.append([source_row[min(col, last_col)] for col in range(new_width)])
        return padded

    padded_red_channel = replicate_edges(original_red_channel)
    padded_green_channel = replicate_edges(original_green_channel)
    padded_blue_channel = replicate_edges(original_blue_channel)
    return padded_red_channel, padded_green_channel, padded_blue_channel


In [61]:
def encode(r, g, b):
    """Encode three colour planes as a baseline JPEG file.

    Steps: pad the planes to multiples of 8, scale the two standard
    quantisation tables by a user-supplied factor, then for every 8x8 block
    level-shift by 128, convert to Y/Cb/Cr, apply the 2-D DCT, quantise,
    zig-zag, run-length encode the AC coefficients and difference-code the
    DC coefficient. Huffman tables are built from the collected symbols and
    everything is written to ``<name>.jpg``.

    The function is interactive: it prompts for the quantisation scale and
    for the output file name.

    Args:
        r, g, b: equally sized 2-D lists of samples in the range 0..255.
    """
    height, width = len(r), len(r[0])
    r, g, b = pad_image(r, g, b)

    quant_lum_matrix = [[16, 11, 10, 16, 24, 40, 51, 61],
                        [12, 12, 14, 19, 26, 48, 60, 55],
                        [14, 13, 16, 24, 40, 57, 69, 56],
                        [14, 17, 22, 29, 51, 87, 80, 62],
                        [18, 22, 37, 56, 68, 109, 103, 77],
                        [24, 35, 55, 64, 81, 104, 113, 92],
                        [49, 64, 78, 87, 103, 121, 120, 101],
                        [72, 92, 95, 98, 112, 100, 103, 99]]

    quant_chr_matrix = [[17, 18, 24, 47, 99, 99, 99, 99],
                        [18, 21, 26, 66, 99, 99, 99, 99],
                        [24, 26, 56, 99, 99, 99, 99, 99],
                        [47, 66, 99, 99, 99, 99, 99, 99],
                        [99, 99, 99, 99, 99, 99, 99, 99],
                        [99, 99, 99, 99, 99, 99, 99, 99],
                        [99, 99, 99, 99, 99, 99, 99, 99],
                        [99, 99, 99, 99, 99, 99, 99, 99]]

    scale = float(input("Enter the scaling on the standard quant matrix:"))
    # Scaled entries are clamped to 1..255 so each fits in one table byte.
    quant_lum_matrix = [[min(max(1, int(quant_lum_matrix[i][j] * scale)), 255) for j in range(N)] for i in range(N)]
    quant_chr_matrix = [[min(max(1, int(quant_chr_matrix[i][j] * scale)), 255) for j in range(N)] for i in range(N)]

    Y_dc_bits_list, CB_dc_bits_list, CR_dc_bits_list = [], [], []
    Y_ac_bits_list, CB_ac_bits_list, CR_ac_bits_list = [], [], []
    Y_dc_vli_list, CB_dc_vli_list, CR_dc_vli_list = [], [], []
    Y_ac_vli_list, CB_ac_vli_list, CR_ac_vli_list = [], [], []

    prev_Y_dc, prev_CB_dc, prev_CR_dc = 0, 0, 0

    for i in range(0, len(r)//8):
        for j in range(0, len(r[0])//8):
            # Coerce to Python int before the level shift: fixed-width
            # unsigned samples (e.g. NumPy uint8) would otherwise wrap
            # around instead of going negative.
            cur_r = [[int(r[x][y]) - 128 for y in range(j*8, j*8 + 8)] for x in range(i*8, i*8 + 8)]
            cur_g = [[int(g[x][y]) - 128 for y in range(j*8, j*8 + 8)] for x in range(i*8, i*8 + 8)]
            cur_b = [[int(b[x][y]) - 128 for y in range(j*8, j*8 + 8)] for x in range(i*8, i*8 + 8)]

            y, cb, cr = rgb_to_ycbcr(cur_r, cur_g, cur_b)
            Y, CB, CR = dct2(y), dct2(cb), dct2(cr)

            Y, CB, CR = quantize(Y, quant_lum_matrix), quantize(CB, quant_chr_matrix), quantize(CR, quant_chr_matrix)
            Y_z, CB_z, CR_z = zigzag_order(Y), zigzag_order(CB), zigzag_order(CR)

            Y_rle, CB_rle, CR_rle = run_length_encode(Y_z), run_length_encode(CB_z), run_length_encode(CR_z)
            Y_rl_class, Y_vli = rle_to_bits(Y_rle)
            CB_rl_class, CB_vli = rle_to_bits(CB_rle)
            CR_rl_class, CR_vli = rle_to_bits(CR_rle)

            Y_ac_bits_list += Y_rl_class
            CB_ac_bits_list += CB_rl_class
            CR_ac_bits_list += CR_rl_class

            Y_ac_vli_list += Y_vli
            CB_ac_vli_list += CB_vli
            CR_ac_vli_list += CR_vli

            # DC coefficients are coded as the difference to the previous block.
            Y_dc, CB_dc, CR_dc = Y_z[0] - prev_Y_dc, CB_z[0] - prev_CB_dc, CR_z[0] - prev_CR_dc
            Y_dc_bits_list.append(int(Y_dc).bit_length())
            CB_dc_bits_list.append(int(CB_dc).bit_length())
            CR_dc_bits_list.append(int(CR_dc).bit_length())

            Y_dc_vli_list.append(get_vli(Y_dc))
            CB_dc_vli_list.append(get_vli(CB_dc))
            CR_dc_vli_list.append(get_vli(CR_dc))

            prev_Y_dc, prev_CB_dc, prev_CR_dc = Y_z[0], CB_z[0], CR_z[0]

    # Cb and Cr share one pair of Huffman tables.
    lum_dc_code = huffcode_gen(Y_dc_bits_list)
    lum_ac_code = huffcode_gen(Y_ac_bits_list)
    chr_dc_code = huffcode_gen(CB_dc_bits_list + CR_dc_bits_list)
    chr_ac_code = huffcode_gen(CB_ac_bits_list + CR_ac_bits_list)

    output_file = input("Enter output file name: ")
    # The context manager closes the file even if a write below raises.
    with open(output_file + ".jpg", "wb") as jpeg_out:
        jpeg_out.write(b'\xff\xd8\xff\xe0')
        jpeg_out.write(pack(">H", 16))
        jpeg_out.write(b'JFIF\x00\x01\x01\x01\x00H\x00H\x00\x00')

        write_quant_table(jpeg_out, quant_lum_matrix, 0)
        write_quant_table(jpeg_out, quant_chr_matrix, 1)

        # The frame header carries the original, unpadded dimensions.
        write_SOF(jpeg_out, height, width, 3)

        write_huffman_table(jpeg_out, 0, lum_dc_code)
        write_huffman_table(jpeg_out, 16, lum_ac_code)
        write_huffman_table(jpeg_out, 1, chr_dc_code)
        write_huffman_table(jpeg_out, 17, chr_ac_code)

        write_SOS(jpeg_out, 3, Y_dc_bits_list, CB_dc_bits_list, CR_dc_bits_list,
                  Y_ac_bits_list, CB_ac_bits_list, CR_ac_bits_list,
                  Y_dc_vli_list, CB_dc_vli_list, CR_dc_vli_list,
                  Y_ac_vli_list, CB_ac_vli_list, CR_ac_vli_list,
                  lum_dc_code, lum_ac_code, chr_dc_code, chr_ac_code)

        jpeg_out.write(b'\xff\xd9')


In [65]:
file = input("Enter raw image file name: ")
img = cv2.imread(file)
# cv2.imread signals failure by returning None instead of raising.
if img is None:
    raise FileNotFoundError("Could not read image file: {!r}".format(file))

# OpenCV stores channels in B, G, R order. tolist() yields nested lists of
# plain Python ints, so later arithmetic cannot wrap around as uint8 would.
r = img[:, :, 2].tolist()
g = img[:, :, 1].tolist()
b = img[:, :, 0].tolist()

encode(r,g,b)

Enter raw image file name: lena_colored_256.bmp
Enter the scaling on the standard quant matrix:10
Enter output file name: out
