In [162]:
from numba import jit, prange
import numpy as np

@jit#(cache=True)
def Lambda(x) -> float:
    """Return the DCT normalisation factor for frequency index ``x``.

    Args:
        x: Frequency index (any value usable in a truth test).

    Returns:
        ``1/sqrt(2)`` when ``x`` is zero (falsy), ``1.0`` otherwise.
    """
    # Both branches yield a float so the return type does not depend on
    # which branch is taken.
    return 1.0 if x else 1/np.sqrt(2)

@jit#(nopython=True, fastmath=True, nogil=True, cache=True, parallel=True)
def precompute_cos() -> np.ndarray:
    """Build the 8x8 cosine lookup table used by the DCT routines.

    Returns:
        Array ``t`` of shape (8, 8) where ``t[u, i] = cos(pi/8 * (i + 0.5) * u)``;
        the first axis is the frequency index, the second the sample index.
    """
    step = np.pi/8
    table = np.zeros((8,8))
    for freq in prange(8):
        for pos in prange(8):
            # same evaluation order as the formula: (pi/8 * (pos + 0.5)) * freq
            angle = step*(pos+.5)*freq
            table[freq,pos] = np.cos(angle)
    return table
@jit#(nopython=True, fastmath=True, nogil=True, cache=True, parallel=True)
def forward_dct(X:np.ndarray):
    """Compute the 2-D DCT-II of every 8x8 block of ``X``.

    For each block, every coefficient is
    ``Y[u, v] = 2/8 * Lambda(u) * Lambda(v) * sum_ij X[i, j] * cos_table[u, i] * cos_table[v, j]``.

    Args:
        X: 4-D array of shape (N, M, 8, 8); the first two axes index blocks,
            the last two index samples inside a block.

    Returns:
        Float array of shape (N, M, 8, 8) holding the coefficients of each block.

    Raises:
        ValueError: if the two trailing axes of ``X`` are not both of length 8.
    """
    N, M, block_h, block_w = X.shape
    # Compiled numba code does not bounds-check by default, so a wrongly shaped
    # input would silently read out of bounds instead of failing.
    if block_h != 8 or block_w != 8:
        raise ValueError("forward_dct expects blocks of shape (N, M, 8, 8)")

    # precompute cosines
    cos_table = precompute_cos()

    # per-frequency normalisation sqrt(2/8) * Lambda(k), hoisted out of the loops
    scale = np.empty(8)
    for k in range(8):
        scale[k] = np.sqrt(2/8) * Lambda(k)

    # iterate blocks; only the outermost loop is a prange because numba
    # parallelises just the outermost prange of a nest
    Y = np.zeros((N, M, 8, 8))
    for h in prange(N):
        for w in range(M):

            # iterate coefficients
            for u in range(8):
                for v in range(8):

                    # iterate pixels, accumulating in a local scalar rather
                    # than read-modify-writing the output element each step
                    acc = 0.0
                    for i in range(8):
                        for j in range(8):
                            acc += (
                                X[h,w,i,j] *
                                cos_table[u,i] *
                                cos_table[v,j])
                    Y[h,w,u,v] = acc * (scale[u] * scale[v])
    return Y

@jit#(nopython=True, fastmath=True, nogil=True, cache=True, parallel=True)
def backward_dct(Y:np.ndarray):
    """Reconstruct 8x8 sample blocks from their 2-D DCT coefficients.

    For each block, every sample is
    ``X[i, j] = 2/8 * sum_uv Lambda(u) * Lambda(v) * Y[u, v] * cos_table[u, i] * cos_table[v, j]``.

    Args:
        Y: 4-D array of shape (N, M, 8, 8); the first two axes index blocks,
            the last two index coefficients inside a block.

    Returns:
        Float array of shape (N, M, 8, 8) holding the reconstructed samples.

    Raises:
        ValueError: if the two trailing axes of ``Y`` are not both of length 8.
    """
    N, M, block_h, block_w = Y.shape
    # Compiled numba code does not bounds-check by default, so a wrongly shaped
    # input would silently read out of bounds instead of failing.
    if block_h != 8 or block_w != 8:
        raise ValueError("backward_dct expects blocks of shape (N, M, 8, 8)")

    # precompute cosines
    cos_table = precompute_cos()

    # Lambda(u) * Lambda(v) for every coefficient, hoisted out of the hot loop
    lam = np.empty((8, 8))
    for u in range(8):
        for v in range(8):
            lam[u,v] = Lambda(u) * Lambda(v)

    # overall normalisation, loop-invariant
    norm = np.sqrt(2/8) * np.sqrt(2/8)

    # iterate blocks; only the outermost loop is a prange because numba
    # parallelises just the outermost prange of a nest
    X = np.zeros((N, M, 8, 8))
    for h in prange(N):
        for w in range(M):

            # iterate pixels
            for i in range(8):
                for j in range(8):

                    # iterate coefficients, accumulating in a local scalar
                    acc = 0.0
                    for u in range(8):
                        for v in range(8):
                            acc += (
                                lam[u,v] *
                                Y[h,w,u,v] *
                                cos_table[u,i] *
                                cos_table[v,j])
                    X[h,w,i,j] = acc * norm
    return X

In [164]:
# read image
import time

import jpeglib
im = jpeglib.read_spatial('IMG_0311.jpeg')
x = im.spatial
N, M, C = x.shape

# pad height and width up to the next multiple of 8 by repeating the border
# pixels, so images whose size is not a multiple of 8 can still be tiled
# (an equal-section np.split would raise ValueError for them)
pad_h = -N % 8
pad_w = -M % 8
x_padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')

# split into 8x8 blocks: (rows, 8, cols, 8, C) -> (rows, cols, 8, 8, C)
xb = x_padded.reshape((N + pad_h)//8, 8, (M + pad_w)//8, 8, C).swapaxes(1, 2)

# input to DCTs: first channel only, as one contiguous array
X = np.ascontiguousarray(xb[:,:,:,:,0])

# warm-up on a single block so JIT compilation is not included in the timings
warmup = np.ascontiguousarray(X[:1,:1])
backward_dct(forward_dct(warmup))

# DCTs
print(X[0,0])
start = time.perf_counter()
Y = forward_dct(X)
end = time.perf_counter()
print(Y[0,0])
print(f"forward {end - start}")
start = time.perf_counter()
X2 = backward_dct(Y)
end = time.perf_counter()
print(X2[0,0])
print(f"backward {end - start}")



[[169 169 169 170 170 170 170 170]
 [169 169 169 170 170 170 170 170]
 [169 169 169 170 170 170 170 171]
 [169 169 170 170 170 170 171 171]
 [169 169 170 170 170 170 171 171]
 [169 170 170 170 170 171 171 171]
 [170 170 170 170 170 171 171 171]
 [170 170 170 170 170 171 171 171]]
[[ 1.36000000e+03 -4.13967549e+00 -1.35655988e-13 -3.78659296e-01
   4.27065149e-14  3.48332835e-01 -1.05510213e-12 -3.13637073e-01]
 [-2.31503406e+00  1.42108547e-14 -8.37152602e-01  7.10542736e-15
   1.03357634e+00  1.42108547e-14 -3.46759961e-01  0.00000000e+00]
 [-1.40680284e-13  7.68177757e-01 -1.42108547e-14  9.06127446e-01
   3.55271368e-15 -1.80239956e-01 -1.77635684e-15 -5.13279967e-01]
 [ 1.85377553e-02 -7.10542736e-15 -2.93968901e-01  7.10542736e-15
  -4.68525867e-01 -3.55271368e-15 -1.21765906e-01  7.10542736e-15]
 [ 8.03887339e-14 -9.75451610e-02  3.55271368e-15 -2.77785117e-01
  -7.10542736e-15  4.15734806e-01 -5.32907052e-15  4.90392640e-01]
 [ 3.13058976e-01  7.10542736e-15  1.96423740e-01 -7.1

In [153]:
# Time how long jpeglib takes to provide the DCT coefficients of the same file.
import time  # imported here so this cell does not depend on another cell having run

import jpeglib
start = time.perf_counter()
jpeg = jpeglib.read_dct('IMG_0311.jpeg')
# NOTE(review): the attribute access is kept inside the timed region —
# presumably it forces jpeglib to actually load the coefficients; confirm.
dct_Y = jpeg.Y
end = time.perf_counter()
print(f"jpeglib {end - start}")

jpeglib 0.10067510604858398
