Sai Saradha Kalidaikurichi Lakshmanan
EECS 531 - Assignment A2

Exercise 1:
Plot the basis functions of a 16x16 discrete cosine transform (DCT).

In this exercise, I have implemented the following:
1. Functions to build the basis functions of the DCT from scratch (for any size, input by the user)
2. Functions to generate the basis functions of the DFT
3. The 2D DCT via two 1D DCTs using SciPy, which shows that the 2D DCT is separable
4. The direct DCT from the formula (O(N^4))
5. The DCT on an image (basics of JPEG compression)

The Discrete Cosine Transform (DCT) expresses an image as a sum of cosine functions of different frequencies. It is one of the most commonly used techniques for image compression: JPEG compression is built on the DCT, and we explore that by performing the DCT on images. The DCT is a linear, invertible transform with several variants, the most commonly used being DCT-II and DCT-III (the inverse of DCT-II, also known as the IDCT). DCT-II transforms an image block of size $N \times N$ into DCT coefficients that can be expressed as:

$$S(k_{1},k_{2}) = \sqrt{\frac{4}{N^{2}}}C(k_{1})C(k_{2}) \sum_{n_{1}=0}^{N-1} \sum_{n_{2}=0}^{N-1} s(n_{1},n_{2})\cos\left(\frac{\pi(2n_{1}+1)k_{1}}{2N}\right) \cos\left(\frac{\pi(2n_{2}+1)k_{2}}{2N}\right)$$

where $k_{1}, k_{2}, n_{1}, n_{2} = 0, 1, \dots, N-1$, and $C(k) = 1/\sqrt{2}$ for $k = 0$ and $C(k) = 1$ otherwise.

DCT-III, also known as the IDCT, can be expressed as:

$$s(n_{1},n_{2}) = \sqrt{\frac{4}{N^{2}}}\sum_{k_{1}=0}^{N-1} \sum_{k_{2}=0}^{N-1} C(k_{1})C(k_{2}) S(k_{1},k_{2})\cos\left(\frac{\pi(2n_{1}+1)k_{1}}{2N}\right) \cos\left(\frac{\pi(2n_{2}+1)k_{2}}{2N}\right)$$

The indices $(k_{1}, k_{2})$ after the transformation are the spatial frequencies. The coefficient at $k_{1}=k_{2}=0$ is the DC component, and the coefficients at all other (higher) frequencies are the AC components.

We explore DCT and some of its properties in this exercise.
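
As a quick check of the two formulas above, the following is a minimal sketch (not part of the assignment code) that builds the orthonormal 1D DCT-II matrix with NumPy and applies it on both sides of a block. The helper name `dct_matrix` and the random test block are assumptions made for illustration. It confirms that the transform is orthonormal, that the DCT-III form inverts it, and that the DC coefficient equals $N$ times the block mean.

```python
import numpy as np

def dct_matrix(N):
    """Orthonormal 1D DCT-II matrix; row k is the k-th cosine basis vector."""
    n = np.arange(N)
    k = n.reshape(-1, 1)
    C = np.sqrt(2.0 / N) * np.cos(np.pi * (2 * n + 1) * k / (2 * N))
    C[0, :] = 1.0 / np.sqrt(N)  # the k = 0 row carries the extra 1/sqrt(2) factor
    return C

N = 16
C = dct_matrix(N)
rng = np.random.default_rng(0)
s = rng.uniform(0, 255, size=(N, N))  # stand-in for an image block

S = C @ s @ C.T        # forward 2D DCT-II
s_back = C.T @ S @ C   # inverse transform (DCT-III)

orthonormal = np.allclose(C @ C.T, np.eye(N))
roundtrip_ok = np.allclose(s_back, s)
dc_matches_mean = np.isclose(S[0, 0], N * s.mean())
print(orthonormal, roundtrip_ok, dc_matches_mean)
```

Because the matrix is orthonormal, its transpose is its inverse, which is why the same cosine terms appear in both the DCT-II and DCT-III expressions.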

In [1]:
import numpy as np
import os
import cv2
from scipy import ndimage
import scipy
import time
from math import sqrt, cos, pi
import matplotlib 
import matplotlib.pyplot as ax
from scipy.fftpack import dct, idct


In [2]:
class transforms_freq:
    
    def __init__(self):
        pass

    def dct_direct(self, img_folder):
        directory = os.fsencode(img_folder)
        # Here trying 16 x 16 basis fns, change the value below for different DCT coeff size
        fs = 16
        for file in os.listdir(directory):
            tic = time.time()
            filename = os.fsdecode(file)
            Img = self.read_image(img_folder, filename)
            if Img is None:
                continue
            r, c = Img.shape
            dct_coeff = np.zeros((r,c))
            for i in range(0, r):
                coeff1 = sqrt(2/r)
                if i == 0:
                    coeff1 = 1/sqrt(r)            
                for j in range(0, c):
                    coeff2 = sqrt(2/c)
                    if j == 0:
                        coeff2 = 1/sqrt(c)
                    for m in range(0, r):
                        for n in range(0, c):
                            dct_coeff[i,j] += Img[m, n]*np.cos((pi/(2*r))*(2*m+1)*i)* np.cos((pi/(2*c))*(2*n+1)*j)
                    dct_coeff[i,j] = coeff1*coeff2* dct_coeff[i,j]
            cv2.imshow('dct_image', dct_coeff)
            cv2.waitKey(0)
            self.idct_direct(dct_coeff)
            return
        
    def idct_direct(self, img):
        r, c = img.shape
        idct_coeff = np.zeros((r,c))
        for i in range(0, r):
            for j in range(0, c):
                for m in range(0, r):
                    # The scale factors belong to the frequency indices (m, n),
                    # not to the pixel indices (i, j)
                    coeff1 = 1/sqrt(r) if m == 0 else sqrt(2/r)
                    for n in range(0, c):
                        coeff2 = 1/sqrt(c) if n == 0 else sqrt(2/c)
                        idct_coeff[i,j] += coeff1*coeff2*img[m, n]*np.cos((pi/(2*r))*(2*i+1)*m)* np.cos((pi/(2*c))*(2*j+1)*n)
        print(idct_coeff)
        cv2.imshow('idct_image', idct_coeff.astype(np.uint8))
        cv2.waitKey(0)
        return
    
    def basis_dct(self, n1, n2, k1):
        k2 = k1
        if n1 == 0:
            coeff1 = 1/sqrt(k1)
        else:
            coeff1 = sqrt(2/k1)
        if n2 == 0:
            coeff2 = 1/sqrt(k2)
        else:
            coeff2 = sqrt(2/k2)
        x = np.array(list(range(k1)))
        y = np.array(list(range(k2)))
        Y, X = np.meshgrid(y, x)
        constv= coeff1*coeff2
        #cos1 = np.cos(pi*(n1/2.0/k1)*(2*X+1))
        #cos2 = np.cos(pi*(n2/2.0/k2)*(2*Y+1))
        cos1 = np.cos((pi/(2*k1))*((2*X)+1)*n1)
        cos2 = np.cos((pi/(2*k2))*((2*Y)+1)*n2)
        basis_mt = constv*np.multiply(cos1,cos2)
        return basis_mt

    def basis_idct(self, n1, n2, k1):
        k2 = k1
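        x = np.array(list(range(k1)))
        y = np.array(list(range(k2)))
        Y, X = np.meshgrid(y, x)
        # (n1, n2) is the pixel position; X and Y hold the frequency indices,
        # so the 1/sqrt(N) scale factor applies where X == 0 or Y == 0
        coeff1 = np.where(X == 0, 1/sqrt(k1), sqrt(2/k1))
        coeff2 = np.where(Y == 0, 1/sqrt(k2), sqrt(2/k2))
        constv= coeff1*coeff2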
        cos1 = np.cos(pi*(2*n1+1)*X/2.0/k1)
        cos2 = np.cos(pi*(2*n2+1)*Y/2.0/k2)
        #cos1 = np.cos((pi/(2*k1))*((2*n1)+1)*X)
        #cos2 = np.cos((pi/(2*k2))*((2*n2)+1)*Y)
        basis_mt = constv*np.multiply(cos1,cos2)
        return basis_mt

    def idct_f(self, fs):
        '''basis_fns = np.zeros((fs*fs, fs*fs))
        el = 0
        for i in range(0, fs):
            for j in range(0, fs):
                this_basis = self.basis_idct(i, j, fs)
                basis_fns[el,:] = np.ravel(this_basis)
                el += 1
        print(basis_fns)       
        new_idct = self.basis_plot(basis_fns, fs)
        return new_idct.T'''
        k=0
        l=0
        dump2 = np.zeros((fs, fs*fs))
        dump1 = np.zeros((fs, fs))
        idct_basis = np.zeros((fs*fs, fs*fs))
        for m in range(0, fs):
            dump2 = np.zeros((fs, fs*fs))
            for n in range(0, fs):
                dump1 = np.zeros((fs, fs))
                k=0
                for p in range(0, fs):
                    for q in range(0, fs):
                        ap = sqrt(2/fs)
                        if p==0:
                            ap = 1/sqrt(fs)
                        aq = sqrt(2/fs)
                        if q==0:
                            aq = 1/sqrt(fs)
                        dump1[k,l] = ap*aq*cos(pi*(2*m+1)*p/2/fs)*cos(pi*(2*n+1)*q/2/fs);
                        l += 1
                    k+= 1
                    l=0
                if n == 0:
                    dump2 = dump1
                    continue
                dump2 = np.hstack((dump2,dump1))
            if m==0:
                idct_basis = dump2
                continue
            idct_basis = np.vstack((idct_basis,dump2))
        return idct_basis

    def basis_plot(self, dct_c, fs):
        norm_inst = matplotlib.colors.Normalize(-1,1)
        new_dct = np.zeros((fs*fs, fs*fs))
        patx = 0
        paty = 0
        count=0
        for i in range(0, fs*fs):
            this_row = dct_c[i,:]
            this_row_array = np.reshape(this_row, (fs, fs))
            # To only visualize the basis plot, uncomment the below section
            # and the imshow section
            '''max_val = abs(np.amax(this_row_array))
            min_val = abs(np.amin(this_row_array))
            if max_val >= min_val:
                this_row_array = this_row_array/max_val
            else:
                this_row_array = this_row_array/min_val'''
            new_dct[patx:patx+fs, paty:paty+fs] = this_row_array
            '''patx += fs
            count += 1
            if (i>0 and count==fs):
                paty += fs
                patx = 0
                count = 0'''
            paty += fs
            count += 1
            if (i>0 and count==fs):
                patx += fs
                paty = 0
                count = 0
        #ax.figure()
        #ax.imshow(new_dct, cmap='gray', aspect='auto', norm=norm_inst)
        #ax.show()
        return new_dct

    def dct_f(self, fs):
        basis_fns = np.zeros((fs*fs, fs*fs))
        el = 0
        for i in range(0, fs):
            for j in range(0, fs):
                this_basis = self.basis_dct(i, j, fs)
                basis_fns[el,:] = np.ravel(this_basis)
                el += 1
        new_dct = self.basis_plot(basis_fns, fs)
        return new_dct

    def get_coeff(self, imgpat, bdct, fs):
        h, w = imgpat.shape
        dct_coeff = np.zeros((h,w))
        p = 0
        q = 0
        for i in range(0, fs*fs, fs):
            q = 0
            for j in range(0, fs*fs, fs):
                dct_patch = bdct[i:i+fs, j:j+fs]
                #coeff1 = sqrt(2/h)
                #if p == 0:
                        #coeff1 = 1/sqrt(h)
                #coeff2 = sqrt(2/w)
                #if q == 0:
                        #coeff2 = 1/sqrt(w)
                #constv= coeff1*coeff2
                Bpq = np.sum(np.multiply(imgpat, dct_patch))
                dct_coeff[p, q] = Bpq
                q+=1
            p += 1                
        return dct_coeff                
    
    def dct_img_calc(self, Img, fs):
        # get the basis fns
        basis_dct = self.dct_f(fs)
        # get the image size:
        r, c = Img.shape
        print(r, c)
        # center the images to [-128, 127] by subtracting 128 from each pixel in the image
        #Img = Img - 128
        # dct coefficients image:
        dct_coeff = np.zeros((r, c))
        # now let's split the image into fs x fs blocks and compute the dct image:
        for i in range(0, r, fs):
            for j in range(0, c, fs):
                img_patch = Img[i:i+fs, j:j+fs]
                # now with this image patch, we are going to find the dct image :
                dct_coeff[i:i+fs, j:j+fs] = self.get_coeff(img_patch, basis_dct, fs)              
        return dct_coeff

    def idct_img_calc(self, Img, fs):
        # get the basis fns
        basis_idct = self.idct_f(fs)
        # get the image size:
        r, c = Img.shape
        # dct coefficients image:
        idct_img = np.zeros((r, c))
        # now let's split the image into fs x fs blocks and compute the dct image:
        for i in range(0, r, fs):
            for j in range(0, c, fs):
                img_patch = Img[i:i+fs, j:j+fs]
                # now with this image patch, we are going to find the dct image :
                idct_img[i:i+fs, j:j+fs] = self.get_coeff(img_patch, basis_idct, fs)
        # idct_img += 128
        return idct_img

    def dct_scipy(self, img_folder):
        directory = os.fsencode(img_folder)
        for file in os.listdir(directory):
            tic = time.time()
            filename = os.fsdecode(file)
            Img = self.read_image(img_folder, filename)
            if Img is None:
                continue
            dct_img = dct( dct(Img, axis=0, norm='ortho' ), axis=1, norm='ortho' )
            print("DCT Image")
            cv2.imshow('DCT Image', dct_img.astype(np.uint8))
            idct_img = idct( idct( dct_img, axis=0 , norm='ortho'), axis=1 , norm='ortho')
            print("Showing IDCT image")
            cv2.imshow('IDCT Recovered image', idct_img.astype(np.uint8))
            # Without waitKey the OpenCV windows are never drawn
            cv2.waitKey(0)
        return

    def dct_im(self, img_folder):
        directory = os.fsencode(img_folder)
        # Here trying 16 x 16 basis fns, change the value below for different DCT coeff size
        fs = 16
        for file in os.listdir(directory):
            tic = time.time()
            filename = os.fsdecode(file)
            Img = self.read_image(img_folder, filename)
            if Img is None:
                continue
            dct_image = self.dct_img_calc(Img, fs)
            #dct_image = dct_image.astype(np.uint8)
            print("DCT Image")
            cv2.imshow('DCT Image', dct_image.astype(np.uint8))
            # Applying IDCT to recover the image from DCT Image
            idct_image = self.idct_img_calc(dct_image, fs)
            #idct_image = idct_image.astype(np.uint8)
            print("Showing IDCT image")
            cv2.imshow('IDCT Recovered image', idct_image.astype(np.uint8))
            cv2.waitKey(0)
        return

    def basis_dft_real(self, n1, n2, k1):
        k2 = k1
        coeff = 1/(k1*k2)
        # Sample at the integer positions 0..k1-1; linspace(0, k1, k1) gives non-integer steps
        x = np.arange(k1)
        y = np.arange(k2)
        X, Y = np.meshgrid(x, y)
        cos_val = np.cos((2*pi)*(((n1/k1)*X)+((n2/k2)*Y)))
        basis_mt = coeff*cos_val
        return basis_mt

    def basis_dft_imgnry(self, n1, n2, k1):
        k2 = k1
        coeff = -1/(k1*k2)
        # Sample at the integer positions 0..k1-1; linspace(0, k1, k1) gives non-integer steps
        x = np.arange(k1)
        y = np.arange(k2)
        X, Y = np.meshgrid(x, y)
        cos_val = np.sin((2*pi)*(((n1/k1)*X)+((n2/k2)*Y)))
        basis_mt = coeff*cos_val
        return basis_mt

    def dft_f(self, fs):
        # real basis fns
        r_basis_fns = np.zeros((fs*fs, fs*fs))
        el = 0
        for i in range(0, fs):
            for j in range(0, fs):
                this_basis = self.basis_dft_real(i, j, fs)
                r_basis_fns[el,:] = np.ravel(this_basis)
                el += 1
        self.basis_plot(r_basis_fns, fs)
        # imaginary basis fns
        i_basis_fns = np.zeros((fs*fs, fs*fs))
        el = 0
        for i in range(0, fs):
            for j in range(0, fs):
                this_basis = self.basis_dft_imgnry(i, j, fs)
                i_basis_fns[el,:] = np.ravel(this_basis)
                el += 1
        self.basis_plot(i_basis_fns, fs)
        return r_basis_fns, i_basis_fns
    
    def read_image(self, directory, filename):
        if filename.startswith("i") and (filename.endswith(".jpg") or filename.endswith(".bmp") or filename.endswith(".png") or filename.endswith(".gif")):
            print(filename)
            # os.path.join works whether or not the directory ends with a separator
            img = cv2.imread(os.path.join(directory, filename), 0)
            if img is None:
                print('Invalid image:' + filename)
                return None
            else:
                # cv2.imshow('Image', img)
                # cv2.waitKey(0)
                return img
   
    def main(self, transform_choice, input_val):
        options = {'dct_f':self.dct_f,
                   'dft_f':self.dft_f,
                   'dct_im': self.dct_im,
                   'dct_scipy': self.dct_scipy,
                   'dct_dir':self.dct_direct,
            }
        options[transform_choice](input_val)
        return

In [None]:
if __name__ == "__main__":

    # creating an instance of the transforms class
    transform_class = transforms_freq()

    # Call to the main function
    # Arguments to the main function are the transform choice and either a size
    # (for dct_f and dft_f) or, for the image-based operations (dct_im, dct_scipy,
    # dct_dir), the path to the directory of test images entered at the prompt

    tic = time.time()
    print('Choices are : \n 1. DCT (dct_f), \n 2. DFT (dft_f), \n 3. DCT with scipy (dct_scipy)  \n 4. DCT on image (dct_im) and \n 5. dct_direct (dct_dir)')
    choice_v = input("Enter the choice: ")
    img_folder = ""
    if choice_v != 'dct_f' and choice_v != 'dft_f':
        input_val = input("Enter the path to image directory: ")
    else:
        input_val = int(input("Enter the size: "))
    transform_class.main(choice_v, input_val)
    toc = time.time() - tic
    print("Running time: " + str(toc))


**Plotting DCT basis:**

The functions basis_dct and basis_idct compute the basis functions of a DCT of a given size (e.g. 8 or 16). The individual basis functions are not separated by borders in the plot, but I have confirmed that the values are the same. It is clear that the frequencies increase as we move from left to right and from top to bottom, and the last (bottom-right) tile is a complete checkerboard. The 256 basis functions (16 x 16) are plotted below:

![Alt text](imgs/dct_basis_256_noborder.png?raw=true "dct_basis_256_noborder")
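
The tiling in the plot can be reproduced compactly. The sketch below is an illustration rather than the assignment's `dct_f`/`basis_plot` code: it builds every 2D basis image as an outer product of two rows of the 1D DCT matrix and arranges them with `np.einsum`. The name `dct_basis_grid` is hypothetical. It also checks the two observations made above: the number of sign changes grows with the frequency index, and the bottom-right tile alternates in sign like a checkerboard.

```python
import numpy as np

def dct_basis_grid(N):
    """Tile all N*N 2D DCT-II basis images into one (N*N, N*N) array."""
    n = np.arange(N)
    k = n.reshape(-1, 1)
    C = np.sqrt(2.0 / N) * np.cos(np.pi * (2 * n + 1) * k / (2 * N))
    C[0, :] = 1.0 / np.sqrt(N)
    # Axes (k1, x, k2, y): tile (k1, k2) holds C[k1, x] * C[k2, y]
    grid = np.einsum("ax,by->axby", C, C).reshape(N * N, N * N)
    return grid, C

N = 16
grid, C = dct_basis_grid(N)

# Row k of the 1D matrix changes sign k times, so frequency grows with k
sign_changes = [int(np.sum(np.diff(np.sign(C[k])) != 0)) for k in range(N)]

# The highest-frequency tile alternates in sign in both directions
x, y = np.indices((N, N))
last_tile = grid[-N:, -N:]
is_checkerboard = np.array_equal(np.sign(last_tile), (-1.0) ** (x + y))
print(sign_changes, is_checkerboard)
```

To display the grid, something like `plt.imshow(grid, cmap='gray')` with `matplotlib.pyplot` imported as `plt` should give a picture comparable to the one above, up to the intensity normalization.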

**Plotting DFT basis:**

For a discrete signal $f(m, n)$ the 2D DFT is given as:

$$F(k,l) = \frac{1}{MN} \sum_{m=0}^{M-1}\sum_{n=0}^{N-1}f(m,n)\exp\left(-2\pi \mathrm{i} \left(\frac{km}{M}+\frac{ln}{N}\right)\right)$$

DFT includes both real and imaginary parts, hence the two parts of the basis functions can be given by:

$$\mathrm{Re} = \frac{\cos(\theta)}{MN} = \cos\left(2\pi\left(\frac{km}{M}+\frac{ln}{N}\right)\right)\frac{1}{MN}$$
$$\mathrm{Im} = - \frac{\sin(\theta)}{MN} = -\sin\left(2\pi\left(\frac{km}{M}+\frac{ln}{N}\right)\right)\frac{1}{MN}$$
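
To make the split into real and imaginary parts concrete, here is a small sketch that samples one basis pair at the integer positions $m = 0,\dots,M-1$ and $n = 0,\dots,N-1$ and compares the resulting coefficient with NumPy's FFT. The function name `dft_basis`, the chosen frequency pair and the random test signal are assumptions for illustration. `np.fft.fft2` does not include the $1/(MN)$ factor, so its output is divided by $MN$ before the comparison.

```python
import numpy as np

def dft_basis(k, l, M, N):
    """Real and imaginary parts of the (k, l) DFT basis image at integer (m, n)."""
    m, n = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
    theta = 2 * np.pi * (k * m / M + l * n / N)
    return np.cos(theta) / (M * N), -np.sin(theta) / (M * N)

M = N = 16
rng = np.random.default_rng(1)
f = rng.standard_normal((M, N))  # arbitrary real test signal

k, l = 3, 5
re, im = dft_basis(k, l, M, N)

# Projecting f onto the two parts gives the complex coefficient F(k, l)
F_kl = np.sum(f * re) + 1j * np.sum(f * im)
reference = np.fft.fft2(f)[k, l] / (M * N)
matches_fft = bool(np.isclose(F_kl, reference))
print(matches_fft)
```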

The real and imaginary parts of the 256 DFT basis functions (16 x 16) are plotted below:
![Alt text](imgs/dft_basis_real.png?raw=true "dft_basis_real")
![Alt text](imgs/dft_basis_imaginary.png?raw=true "dft_basis_imaginary")

There are three other implementations in this exercise: the direct DCT on an image, the DCT with SciPy to demonstrate separability, and my own implementation of the DCT on an image to demonstrate the basics of JPEG compression.

1. Direct DCT on image:
The 2D DCT computed directly from the DCT-II formula above is highly inefficient, as it involves O(N^4) computations for an N x N image. As the output below shows, it takes about 101 seconds even for a very small 48 x 48 image, as opposed to about 9 seconds for the same image with the block-based DCT implementation that I have done. There are FFT-based algorithms that speed up the process.

![Alt text](imgs/runtime_dctdir.PNG?raw=true "runtime_dctdir")

![Alt text](imgs/runtime_dctim.PNG?raw=true "runtime_dctim")

An even more efficient implementation uses the separability property of the DCT. In place of a single 2D DCT, we apply SciPy's 1D dct twice: first along axis 0 (down each column) and then along axis 1 (across each row) of the result, which gives the 2D DCT coefficients. The order of the two passes does not affect the result. The runtime, DCT image and reconstructed image (after applying the IDCT) are shown below:
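
Separability can also be checked numerically without SciPy. The sketch below compares the direct four-loop DCT-II formula with two passes of a 1D DCT on a small block; the function names and the 8 x 8 random block are assumptions chosen so that the direct version runs quickly, not the code used for the timings in this report.

```python
import numpy as np
from math import cos, pi, sqrt

def dct2_direct(s):
    """2D DCT-II by the four-loop formula, O(N^4) for an N x N block."""
    N = s.shape[0]
    S = np.zeros((N, N))
    for k1 in range(N):
        c1 = 1 / sqrt(2) if k1 == 0 else 1.0
        for k2 in range(N):
            c2 = 1 / sqrt(2) if k2 == 0 else 1.0
            total = 0.0
            for n1 in range(N):
                for n2 in range(N):
                    total += (s[n1, n2]
                              * cos(pi * (2 * n1 + 1) * k1 / (2 * N))
                              * cos(pi * (2 * n2 + 1) * k2 / (2 * N)))
            S[k1, k2] = (2.0 / N) * c1 * c2 * total
    return S

def dct1_rows(x):
    """Orthonormal 1D DCT-II applied to every row of x."""
    N = x.shape[1]
    n = np.arange(N)
    k = n.reshape(-1, 1)
    C = np.sqrt(2.0 / N) * np.cos(np.pi * (2 * n + 1) * k / (2 * N))
    C[0, :] = 1.0 / np.sqrt(N)
    return x @ C.T

def dct2_separable(s):
    """Two 1D passes: transform the rows, transpose, transform again, transpose back."""
    return dct1_rows(dct1_rows(s).T).T

rng = np.random.default_rng(2)
block = rng.uniform(0, 255, size=(8, 8))
same_result = np.allclose(dct2_direct(block), dct2_separable(block))
print(same_result)
```

The separable version costs two matrix products, which is O(N^3) for an N x N block instead of O(N^4), before any FFT-based speed-up is applied.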

![Alt text](imgs/runtime_dctscipy_tiger.PNG?raw=true "runtime_dctscipy_tiger")

The DCT on image that I have implemented computes the DCT block by block, on fs x fs sub-images. The DCT image we see here is therefore quite different from the output of SciPy or MATLAB, which use FFT-based algorithms and operate on the whole image rather than on sub-images. As we can see, however, the reconstructed image looks equivalent to the original image.
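
The difference between the block-based and whole-image transforms can be illustrated with a short sketch. It is a simplified stand-in for `dct_img_calc` and `idct_img_calc`, using matrix products instead of the tiled basis array, and it assumes the image height and width are multiples of the block size. The helper names and the random 32 x 48 test image are hypothetical.

```python
import numpy as np

def dct_matrix(N):
    """Orthonormal 1D DCT-II matrix."""
    n = np.arange(N)
    k = n.reshape(-1, 1)
    C = np.sqrt(2.0 / N) * np.cos(np.pi * (2 * n + 1) * k / (2 * N))
    C[0, :] = 1.0 / np.sqrt(N)
    return C

def blockwise(img, C, inverse=False):
    """Apply the 2D DCT (or its inverse) independently to each fs x fs block."""
    fs = C.shape[0]
    out = np.empty_like(img, dtype=float)
    for i in range(0, img.shape[0], fs):
        for j in range(0, img.shape[1], fs):
            block = img[i:i + fs, j:j + fs]
            out[i:i + fs, j:j + fs] = C.T @ block @ C if inverse else C @ block @ C.T
    return out

rng = np.random.default_rng(3)
img = rng.uniform(0, 255, size=(32, 48))  # dimensions are multiples of fs
fs = 16
C = dct_matrix(fs)

coeffs_block = blockwise(img, C)
coeffs_whole = dct_matrix(32) @ img @ dct_matrix(48).T

recovered = blockwise(coeffs_block, C, inverse=True)
reconstruction_ok = np.allclose(recovered, img)
layouts_differ = not np.allclose(coeffs_block, coeffs_whole)
energy_preserved = np.isclose(np.sum(coeffs_block ** 2), np.sum(img ** 2))
print(reconstruction_ok, layouts_differ, energy_preserved)
```

Both coefficient images carry the same energy as the input, but they arrange it differently, which is why the two DCT images look so unlike each other even though both reconstruct the original.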

![Alt text](imgs/dct_tiger.PNG?raw=true "dct_tiger")

![Alt text](imgs/dct_tiger_scipy.PNG?raw=true "dct_tiger_scipy")

Reconstructed Image:
![Alt text](imgs/idct_tiger.PNG?raw=true "idct_tiger")

JPEG Compression:
While the actual JPEG technique involves many sophisticated steps, the primary steps of the compression can be described at a high level. The DCT is the preferred transform for compression, mainly because of the boundary behavior of cosine functions. Compression has two sides: encoding and decoding. To send an image, it is compressed at the encoding side and reconstructed at the decoding side. JPEG is a lossy compression scheme, meaning that some information is lost and the reconstructed image is not exactly the same as the original. The steps for compression include:

1. Choose the right color model for the image (for example YCbCr space instead of the RGB color channel).
2. Select the appropriate resolution (downsampling if required)
3. Perform DCT and obtain the DCT coefficients, this provides the contribution of the different cosine frequencies to the image
4. Quantization - this results in loss of information, as a number of frequency components get values close or equal to 0
5. Encoding - the quantized result is encoded using an algorithm such as Huffman coding
6. Add appropriate information before transmission and in the receiving end, the reverse process is done. 
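
Steps 3 and 4, together with their reversal at the decoder, can be sketched on a single block as follows. This is a deliberately simplified illustration and not the JPEG algorithm itself: colour conversion, downsampling and entropy coding are left out, and one uniform quantization step (the hypothetical value `step = 16.0`) replaces the per-frequency quantization table that JPEG actually uses. The smooth synthetic block is also an assumption.

```python
import numpy as np

def dct_matrix(N):
    """Orthonormal 1D DCT-II matrix."""
    n = np.arange(N)
    k = n.reshape(-1, 1)
    C = np.sqrt(2.0 / N) * np.cos(np.pi * (2 * n + 1) * k / (2 * N))
    C[0, :] = 1.0 / np.sqrt(N)
    return C

fs = 8
C = dct_matrix(fs)

# Smooth synthetic block: a gentle gradient in both directions
x, y = np.meshgrid(np.arange(fs), np.arange(fs), indexing="ij")
block = 100.0 + 5.0 * x + 3.0 * y

step = 16.0  # hypothetical uniform step size, not a JPEG table entry

# Encoder: centre the pixel values, transform, quantize
coeffs = C @ (block - 128.0) @ C.T
quantized = np.round(coeffs / step)

# Decoder: dequantize, inverse transform, undo the centring
restored = C.T @ (quantized * step) @ C + 128.0

zero_fraction = np.mean(quantized == 0)
frobenius_error = np.linalg.norm(restored - block)
print(zero_fraction, frobenius_error)
```

Since the transform is orthonormal, the reconstruction error is exactly the quantization error of the coefficients, so a larger step gives more zeros (better compression) at the price of a larger error.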

To demonstrate this, we perform the DCT on images. In Exercise 2, we observe that even when some of the coefficients are zeroed, there is very little loss of information and the image still looks good.

![Alt text](imgs/krishna_dct_idct.PNG?raw=true "krishna_dct_idct")
![Alt text](imgs/small_dct_idct.PNG?raw=true "small_dct_idct")
