In [None]:
import numpy as np
import scipy.fftpack as fftpack
from skimage.io import imread, imsave
from skimage.util import img_as_ubyte
from skimage import img_as_float
from scipy.fftpack import idct
import os
import csv

In [None]:
def dct2(a):
    """Return the orthonormal 2-D DCT of ``a``.

    The transform is applied in two passes with ``fftpack.dct`` (SciPy's
    default DCT type, ``norm='ortho'``): first along the first axis, by
    transforming the transposed array and transposing back, then along the
    last axis.

    Assumes ``a`` is a 2-D array, which is how it is used in this notebook.
    """
    # Pass 1: transposing puts the first axis last, where fftpack.dct operates.
    first_axis_done = fftpack.dct(a.T, norm='ortho').T
    # Pass 2: transform along the (now restored) last axis.
    return fftpack.dct(first_axis_done, norm='ortho')

def idct2(a):
    """Return the orthonormal 2-D inverse DCT of ``a``.

    Mirrors ``dct2``: ``fftpack.idct`` with ``norm='ortho'`` is applied along
    the first axis (via a transpose round-trip) and then along the last axis.

    Assumes ``a`` is a 2-D array of DCT coefficients.
    """
    # Pass 1: invert along the first axis by working on the transposed array.
    first_axis_done = fftpack.idct(a.T, norm='ortho').T
    # Pass 2: invert along the last axis.
    return fftpack.idct(first_axis_done, norm='ortho')

def dct_compress(image, k_value):
    """Compress ``image`` by sparsifying its DCT coefficients tile by tile.

    The image is converted with ``img_as_float`` and transformed as a whole
    with ``dct2``. The resulting coefficient array is then walked in 8x8
    tiles (tiles on the bottom/right edge may be smaller), and in each tile
    only the ``k_value`` largest-magnitude coefficients survive
    (see ``keep_top_k``); the rest are set to zero.

    Parameters
    ----------
    image : array-like image accepted by ``img_as_float``; indexed as 2-D here.
    k_value : int, number of coefficients to keep per tile.

    Returns
    -------
    The thresholded coefficient array (DCT domain, not pixel values).
    """
    coefficients = dct2(img_as_float(image))
    n_rows, n_cols = coefficients.shape[0], coefficients.shape[1]

    for top in range(0, n_rows, 8):
        row_span = slice(top, top + 8)
        for left in range(0, n_cols, 8):
            col_span = slice(left, left + 8)
            # keep_top_k returns a new array, which overwrites the tile in place.
            coefficients[row_span, col_span] = keep_top_k(
                coefficients[row_span, col_span], k_value
            )

    return coefficients

def keep_top_k(matrix, k):
    """Return a copy of ``matrix`` keeping only its ``k`` largest-magnitude entries.

    All other entries are set to zero. The input is not modified, because
    ``flatten()`` produces a copy. The result has the same shape as ``matrix``.
    """
    values = matrix.flatten()
    # Indices ordered from largest to smallest absolute value.
    by_magnitude_desc = np.argsort(np.abs(values))[::-1]
    # Everything after the first k positions is discarded.
    discarded = by_magnitude_desc[k:]
    values[discarded] = 0
    return values.reshape(matrix.shape)

def dct_decompress(dct_compressed):
    """Reconstruct an 8-bit image from a 2-D array of DCT coefficients.

    The coefficients are inverted with an orthonormal inverse DCT along axis 0
    and then axis 1. The result is min-max rescaled to [0, 1] and converted
    with ``img_as_ubyte``.

    If the reconstruction is constant (max == min), the min-max rescale would
    be 0/0 and produce NaNs. In that case the values are clipped to [0, 1]
    instead, so ``img_as_ubyte`` always receives finite in-range input.

    NOTE(review): min-max rescaling stretches the reconstruction to the full
    [0, 1] range regardless of the original intensities, so any error metric
    against the original also measures that stretch. Confirm this is intended.

    Parameters
    ----------
    dct_compressed : 2-D array of DCT coefficients (e.g. from ``dct_compress``).

    Returns
    -------
    The reconstructed image as returned by ``img_as_ubyte``.
    """
    im_reconstructed = idct(idct(dct_compressed, axis=0, norm='ortho'), axis=1, norm='ortho')

    lowest = im_reconstructed.min()
    value_range = im_reconstructed.max() - lowest
    if value_range == 0:
        # Flat image: rescaling is undefined, so keep the value and clamp it.
        im_reconstructed_normalized = np.clip(im_reconstructed, 0.0, 1.0)
    else:
        im_reconstructed_normalized = (im_reconstructed - lowest) / value_range

    return img_as_ubyte(im_reconstructed_normalized)


def format_size(size_in_bytes):
    """Format a byte count as a human-readable string such as ``"1.50 KB"``.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (TB) is reached, then rendered with two decimals followed by the unit.

    Sizes of 1024 TB or more are reported in TB (e.g. ``"1024.00 TB"``).
    Previously such values were divided once more while keeping the 'TB'
    label, understating them by a factor of 1024.

    Parameters
    ----------
    size_in_bytes : int or float, size in bytes.

    Returns
    -------
    str, formatted as ``"<value with 2 decimals> <unit>"``.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    size = float(size_in_bytes)
    unit_index = 0
    # Stop scaling at the last unit so the label always matches the value.
    while size >= 1024.0 and unit_index < len(units) - 1:
        size /= 1024.0
        unit_index += 1
    return f"{size:.2f} {units[unit_index]}"

def mean_squared_error(original_image, decompressed_image):
    """Return the mean squared error between two equally shaped images.

    Both inputs are converted to float64 before subtracting. Without this,
    unsigned-integer images (e.g. uint8) wrap around on subtraction and
    overflow on squaring, which silently yields a wrong MSE.

    Parameters
    ----------
    original_image, decompressed_image : array-like, same shape.

    Returns
    -------
    The mean of the element-wise squared differences.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    reference = np.asarray(original_image, dtype=np.float64)
    candidate = np.asarray(decompressed_image, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError("Both images must have the same shape for MSE calculation.")
    return np.mean((reference - candidate) ** 2)

def csv_write(file_path, adjusted_list, mean_score_list):
    """Write one CSV row per entry of ``adjusted_list`` to ``file_path``.

    Each row is ``[position, adjusted value, score]``: ``position`` counts
    from 1 and ``score`` is the entry of ``mean_score_list`` at the same index.
    The file is opened in write mode, so existing content is replaced.
    No header row is written.
    """
    with open(file_path, 'w', newline='') as handle:
        csv.writer(handle).writerows(
            [position + 1, size, mean_score_list[position]]
            for position, size in enumerate(adjusted_list)
        )

In [None]:
# Location of the source image that the experiment cell reads and measures.
# NOTE(review): this is an absolute, machine-specific path — update it (or
# switch to a path relative to the notebook) before running elsewhere.
image_path = '/Users/aravdhoot/Math-EE/ct_image_bw.png'

In [None]:
# Sweep the number of retained coefficients per 8x8 tile. For each k the
# image is compressed, reconstructed, saved as a JPEG, and both the saved
# file size and the MSE against the original are recorded.
# NOTE(review): range(1, 64) stops at k=63, so the "keep all 64 coefficients"
# case is never measured — confirm this is intended.
mean_score_dct = []
compress_size_dct = []

image = imread(image_path)
original_size = format_size(os.path.getsize(image_path))
print(f"Original Size—{original_size}")

# The float version of the original is the same for every k; compute it once.
image_as_float = img_as_float(image)

for k_value in range(1, 64):
    dct_compressed = dct_compress(image, k_value=k_value)
    image_reconstructed = dct_decompress(dct_compressed)

    output_name = f'dct_compressed_{k_value}.jpg'
    imsave(output_name, image_reconstructed)
    # Size is kept as a formatted string ("<number> <unit>"); a later cell parses it.
    compressed_size = format_size(os.path.getsize(output_name))

    print(image_reconstructed.shape)
    mse_score = mean_squared_error(image_as_float, img_as_float(image_reconstructed))
    print(f"K-Value—{k_value} || Compressed Size-{compressed_size} || Mean Score—{mse_score}")

    mean_score_dct.append(mse_score)
    compress_size_dct.append(compressed_size)

In [None]:
# Turn the formatted sizes ("<number> <unit>") into plain numbers for the CSV.
# Dropping the unit would mix scales (e.g. "900.00 B" vs "1.20 KB"), so every
# value is expressed in KB; entries already in KB are unchanged.
KB_PER_UNIT = {
    'B': 1.0 / 1024.0,
    'KB': 1.0,
    'MB': 1024.0,
    'GB': 1024.0 ** 2,
    'TB': 1024.0 ** 3,
}
adjusted_dct = [
    float(number) * KB_PER_UNIT[unit]
    for number, unit in (value.split(' ') for value in compress_size_dct)
]

# Write to a .csv file. The previous target, 'dct.ipynb', is opened in write
# mode by csv_write and would be truncated — presumably this notebook itself.
csv_write('dct.csv', adjusted_dct, mean_score_dct)