# An infodump on JPEG compression

In [4]:
# Python libraries which we will need later

from collections import namedtuple
from math import pi, cos, sin, sqrt
from PIL import Image
import struct

In [5]:
# These libraries are used for various explanations and demos
# They are *NOT* required for any of the actual JPEG code itself
import numpy as np
import scipy as sp
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
%matplotlib inline

In [11]:
# this will be useful later
def divroundup(num, divisor):
    return (num + divisor - 1) // divisor

def clamp(val):
    val = round(val)
    if val < 0:
        return 0
    if val > 255:
        return 255
    return val

## Preface

A while back, I happened to come across the following meme on fedi:

<img src="imgs/fedi-meme.png" width="500" alt="Screenshot of a meme by @retr0id@retr0.id, featuring a variation on the &quot;how to draw an owl&quot; meme. The first step says &quot;Get some of these thingies&quot; and features a grid of DCT basis functions."/>

Heh, humorous. Especially if you *do* actually know how JPEG works.

But wait, how many people *actually* know how JPEG works? Not just in theory but in enough detail to build a working JPEG decoder and encoder? In this notebook, I hope to explain it in enough detail to not just be able to _implement_ JPEG (since there are plenty of existing software libraries for that) but to understand the "how and why" of how JPEG works.

This article is being written at the beginning of the year 2025, at a time when social structures around the world are... struggling. Technological change has played a massive role in this (for both "good" and "bad"!), and, although technologists, activists, and other stakeholders all have their own ideas about what should be done about this situation, I, the author, believe that a powerful tool to wield against current trends is to *understand* how technological systems work. Take them apart, layer by layer. See what is hiding under the hood. Perhaps use this resulting knowledge to design newer, better systems, or perhaps use it to inform other work.

With that, let's encode the rest of the fucking owl!

## The JPEG specification

Before we try to encode our own JPEGs, we will start by trying to decode one. For this exercise, we will be working with the following image of London Pride by MangakaMaiden Photography, obtained [from Wikimedia Commons](https://commons.wikimedia.org/wiki/File:London_Pride_2023_(53079352061).jpg):

<img src="imgs/London_Pride_2023_(53079352061).jpg" width="1000" alt="photo of London Pride"/>

In order to be able to decode a JPEG, the first thing we are going to need is the _specification_ defining how the format works. This standard is called either **ISO/IEC 10918-1** or **ITU-T Recommendation T.81**. Here we run into our first tiny hurdle: due to the way certain standards bodies operate, access to this standard requires a payment of 216 CHF if we attempt to obtain it directly from those organizations.

Fortunately, not every standards organization works this way, and restricting the spread of information in this manner is not aligned with certain ideological trends on the Web, so the specification can instead be had from the W3C's website [here](https://www.w3.org/Graphics/JPEG/itu-t81.pdf).

## How to read a specification

The first thing that I usually do when confronted with a specification is to start with the table of contents, followed by very quickly skimming the rest of the document. In this case, we can see that the "main" part of the specification is actually very short, only about twenty pages! Although there is important information here, most of the actual details are specified in the various annexes.

Section 3, "Definitions, abbreviations and symbols" contains, well, exactly what it says it contains. When reading a formal specification, it is _important_ to at least remain aware of these definitions. Formal specification documents often use specialized terminology and jargon in order to refer to ideas in a concise way, and the usage doesn't always correspond with common definitions of words (e.g. 3.1.111 "run").

The remainder of the "main" part, especially section 4 "General", gives a high-level overview of how JPEG works and how and in what order the various steps (which are explained in the annexes) are put together. If we read through it, we find that JPEG is actually quite complicated! There are many different choices that can be made during the encoding process and multiple _modes of operation_ (section 4.5) which can be used.

Finally, each of the annexes explains one particular operation or sub-step in detail. We will reference each of them as needed.

A reasonable question one might ask would be "how little can we get away with implementing and have it still work?" After all, technical standards are, in the end, only words on a page and don't intrinsically come with any enforcement mechanisms (see related: "MUST (BUT WE KNOW YOU WON'T)" in the April Fools' [RFC 6919](https://datatracker.ietf.org/doc/html/rfc6919#page-3)). As this notebook is just a proof-of-concept, we are only going to implement the bare essentials needed for this particular image. This code will only handle a _baseline DCT_ image with _interleaved_ components in a single _scan_. (This is the most common way to encode a JPEG, but actually determining that our example image is encoded in this way requires decoding some basic metadata first.)

## Decoding the very-computer-y bits

Even though the JPEG specification requires a bunch of mathematics and more-theory-laden algorithms, file formats typically also contain "generic computer-y" information. What I mean by this phrase is "information, often metadata, which is encoded in ways that are common across many different file formats." This would include information such as the image width and height or data lengths, encoded using bytes.

Manipulating the "generic computer-y" information in a JPEG file requires basic familiarity with working with binary files and data representations. Working with binary files is outside the scope of the current notebook, but we will be using Python's [`struct` module](https://docs.python.org/3/library/struct.html) to handle it.
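As a quick taste of what `struct` does for us, here is a minimal sketch that unpacks a hand-made header fragment. The bytes are invented for illustration (they are *not* taken from the example photo); the only thing to take away is how the format string maps onto big-endian fields.

```python
import struct

# A made-up fragment laid out like the start of a frame header:
# 2-byte marker, 2-byte length, then 1-byte precision, 2-byte height,
# 2-byte width, and 1-byte component count. All values are illustrative.
fragment = bytes([0xff, 0xc0, 0x00, 0x11, 0x08, 0x01, 0xe0, 0x02, 0x80, 0x03])

# ">" selects big-endian byte order, "H" is an unsigned 16-bit integer,
# and "B" is an unsigned 8-bit integer.
marker, length = struct.unpack(">HH", fragment[:4])
precision, height, width, num_components = struct.unpack(">BHHB", fragment[4:10])

print(f"marker=0x{marker:04x} length={length}")
print(f"{width} x {height}, precision {precision}, {num_components} components")
```

Note that `struct.unpack` always returns a tuple, which is why single-value reads elsewhere in this notebook are written as `(value,) = struct.unpack(...)`.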

The relevant part of the JPEG specification we will need for this is Annex B, which explains how a JPEG file is broken up by _markers_.

### Parsing markers

JPEG _markers_ are listed in Table B.1. Each marker has a "code assignment" (i.e. bytes which will be found in the file), a shorthand symbol, and a description. Many of these markers are associated with some information, and this is called a _marker segment_. These segments all store the length of their associated data, as explained in B.1.1.4. The "actual image data" is stored as _entropy-coded data_, which is handled differently and does not contain a length. Entropy-coded data is encoded with _byte stuffing_ so that it can never contain a marker, which means that the end of the entropy-coded data occurs when a marker is seen.
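To make byte stuffing concrete, here is a minimal sketch of undoing it. It assumes the input is a single stretch of entropy-coded data from which any markers have already been split off, so every `0xff` in it is followed by a stuffed `0x00`. The function name `remove_byte_stuffing` is my own and does not come from the specification.

```python
def remove_byte_stuffing(entropy_coded):
    """Turn each stuffed 0xff 0x00 pair back into a single 0xff data byte."""
    out = bytearray()
    i = 0
    while i < len(entropy_coded):
        byte = entropy_coded[i]
        out.append(byte)
        followed_by_zero = (
            i + 1 < len(entropy_coded) and entropy_coded[i + 1] == 0x00
        )
        # Skip the stuffed zero byte that follows a 0xff data byte
        i += 2 if (byte == 0xff and followed_by_zero) else 1
    return bytes(out)

stuffed = bytes([0x12, 0xff, 0x00, 0x34, 0xff, 0x00])
unstuffed = remove_byte_stuffing(stuffed)
print(unstuffed.hex())
```

An encoder does the reverse: whenever the compressed bits happen to produce a `0xff` byte, it writes an extra `0x00` after it so that a decoder scanning for markers is never fooled.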

As a diagram, a JPEG file looks something like this:
```text
+---------------------------+
| SOI marker (0xff 0xd8)    |
+---------------------------+   \                                                               \
| APPn marker (0xff 0xe_)   |   |                                                               |
+---------------------------+   |- APPn marker segment                                          |
| APPn metadata contents    |   |                                                               |
+---------------------------+   X                                                               |
| other markers (0xff 0x__) |   |                                                               |
+---------------------------+   |- other marker segments (e.g. data tables, Start of Frame)     |
| marker segment contents   |   |                                                               |- repeated as appropriate/necessary
+---------------------------+   X                                                               |
| SOS marker (0xff 0xda)    |   |                                                               |
+---------------------------+   |                                                               |
| scan metadata             |   |- "actual image data"                                          |
+---------------------------+   |                                                               |
| entropy-coded data        |   |                                                               |
+---------------------------+   /                                                               /
| EOI marker (0xff 0xd9)    |
+---------------------------+
```

The precise rules for how these segments must be ordered are illustrated in Figure B.16.

We can write some code to decode all of this:

In [12]:
# Load file into memory
with open('imgs/London_Pride_2023_(53079352061).jpg', 'rb') as f:
    jpeg_data = f.read()

In [13]:
# a very minimal parser for single-scan, baseline DCT JPEG images

STANDALONE_MARKERS = {
    0xffd8:     "SOI",
    0xffd9:     "EOI",
}

SEGMENT_MARKERS = {
    0xffc0:     "SOF0",
    0xffc4:     "DHT",
    0xffdb:     "DQT",
    0xffdd:     "DRI",
}

JpegSegment = namedtuple('JpegSegment', ['marker', 'data'])
JpegImage = namedtuple('JpegImage', ['marker_segs', 'scan_hdr', 'scan_data'])

def parse_jpeg_segments(data):
    got_a_scan = False
    segs = []
    # store a list of scan data, separated at restart intervals
    scan_data = []

    offset = 0
    while True:
        # try to read a marker
        (marker,) = struct.unpack(">H", data[offset:offset+2])
        if marker == 0xffff:
            # there is a fill byte here, which we skip
            offset += 1
        elif marker in STANDALONE_MARKERS:
            # these markers are labeled with a * and do not contain data
            marker_human_name = STANDALONE_MARKERS[marker]
            print(f"got {marker_human_name} @ 0x{offset:08x}")
            offset += 2
        elif marker in SEGMENT_MARKERS:
            # these markers contain a length and then data
            marker_human_name = SEGMENT_MARKERS[marker]
            (len_,) = struct.unpack(">H", data[offset+2:offset+4])
            print(f"got {marker_human_name} of length 0x{len_:04x} @ 0x{offset:08x}")
            payload = data[offset+4:offset+2+len_]
            segs.append(JpegSegment(marker, payload))
            offset += 2 + len_
        elif marker in range(0xffe0, 0xfff0):
            # these markers *also* contain a length and then data,
            # but APPn is specifically used for metadata which we are ignoring
            (len_,) = struct.unpack(">H", data[offset+2:offset+4])
            print(f"got APP{marker-0xffe0} of length 0x{len_:04x} @ 0x{offset:08x}")
            payload = data[offset+4:offset+2+len_]
            segs.append(JpegSegment(marker, payload))
            offset += 2 + len_
        elif marker == 0xffda:
            print(f"got SOS @ 0x{offset:08x}")
            assert not got_a_scan, "multi-scan not supported"
            got_a_scan = True

            # scan header, which contains a length
            (sos_hdr_len,) = struct.unpack(">H", data[offset+2:offset+4])
            sos_hdr = data[offset+4:offset+2+sos_hdr_len]

            # this now contains entropy-coded data
            offset += 2 + sos_hdr_len
            scan_data_i_start_off = offset
            while offset < len(data):
                if data[offset] != 0xff:
                    # normal data
                    offset += 1
                else:
                    # either a marker or _byte stuffing_
                    if data[offset+1] == 0x00:
                        # skip over byte stuffing for now
                        offset += 2
                    else:
                        # skip over fill bytes
                        while data[offset+1] == 0xff:
                            offset += 1
                        
                        if data[offset+1] in range(0xd0, 0xd8):
                            # print(f"got RST{data[offset+1] - 0xd0} @ 0x{offset:08x}")
                            scan_data.append(data[scan_data_i_start_off:offset])
                            offset += 2
                            scan_data_i_start_off = offset
                        else:
                            break
            # last scan
            scan_data.append(data[scan_data_i_start_off:offset])
        else:
            print(f"unknown marker 0x{marker:04x} @ 0x{offset:08x}")
            raise NotImplementedError

        if marker == 0xffd9:
            # break on EOI
            break
    
    return JpegImage(segs, sos_hdr, scan_data)

In [None]:
parsed_jpeg = parse_jpeg_segments(jpeg_data)

From this output, there are a few things to notice:

First of all, there are a number of `APPn` segments (described in Table B.1 as "Reserved for application segments") in this example file. These segments are used to store metadata which is outside of the scope of the JPEG standard. For example, the first `APP1` segment stores [Exif](https://en.wikipedia.org/wiki/Exif) data and includes information about the camera used to take the photo (a Nikon D5300 in this case). Other segments store Adobe Photoshop metadata, Adobe XMP, and an [ICC profile](https://en.wikipedia.org/wiki/ICC_profile) describing how to interpret the colors in the image.

We will not be parsing any of this metadata in our proof-of-concept, but it is important to note that this metadata can result in unintentional privacy breaches. Some online services automatically strip out very sensitive information such as GPS coordinates, but this should not be relied upon in all cases.
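If you want to at least know *which* kind of metadata an `APPn` segment carries, a rough sketch is to look at the identifier string at the start of its payload. The prefixes below are conventions defined by the respective metadata formats (not by T.81), the list is far from complete, and the names `APP_SIGNATURES` and `describe_app_segment` are hypothetical helpers of my own.

```python
# (marker, payload prefix, description) for a few widely used conventions.
APP_SIGNATURES = [
    (0xffe0, b"JFIF\x00", "JFIF header"),
    (0xffe1, b"Exif\x00\x00", "Exif metadata"),
    (0xffe1, b"http://ns.adobe.com/xap/1.0/\x00", "Adobe XMP"),
    (0xffe2, b"ICC_PROFILE\x00", "ICC profile"),
    (0xffed, b"Photoshop 3.0\x00", "Photoshop metadata"),
    (0xffee, b"Adobe", "Adobe color transform info"),
]

def describe_app_segment(marker, payload):
    """Guess what an APPn payload contains from its leading identifier."""
    for sig_marker, prefix, description in APP_SIGNATURES:
        if marker == sig_marker and payload.startswith(prefix):
            return description
    return f"unrecognized APP{marker - 0xffe0} segment"

# A fake payload: just the Exif identifier followed by a few filler bytes
print(describe_app_segment(0xffe1, b"Exif\x00\x00MM\x00*"))
print(describe_app_segment(0xffe5, b"something else"))
```

Applied to the `marker` and `data` fields of each parsed segment, something like this gives a quick inventory of what a file might be leaking before you share it.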

Second of all, the presence of the `SOF0` marker (as opposed to any other `SOFn`) confirms that we have a baseline DCT image, and the presence of only a single `SOS` marker confirms that the information is stored as a single scan.
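For reference, this is how I read the other `SOFn` entries in Table B.1. The dictionary below is a lookup sketch only (our parser still rejects everything except `SOF0`), and `describe_sof` is a hypothetical helper name.

```python
# Start-of-frame markers and the coding process each one announces.
# 0xffc4 (DHT), 0xffc8 (reserved), and 0xffcc (DAC) are deliberately
# absent because they are not frame headers.
SOF_PROCESSES = {
    0xffc0: "baseline DCT",
    0xffc1: "extended sequential DCT, Huffman coding",
    0xffc2: "progressive DCT, Huffman coding",
    0xffc3: "lossless (sequential), Huffman coding",
    0xffc5: "differential sequential DCT, Huffman coding",
    0xffc6: "differential progressive DCT, Huffman coding",
    0xffc7: "differential lossless (sequential), Huffman coding",
    0xffc9: "extended sequential DCT, arithmetic coding",
    0xffca: "progressive DCT, arithmetic coding",
    0xffcb: "lossless (sequential), arithmetic coding",
    0xffcd: "differential sequential DCT, arithmetic coding",
    0xffce: "differential progressive DCT, arithmetic coding",
    0xffcf: "differential lossless (sequential), arithmetic coding",
}

def describe_sof(marker):
    """Return the coding process for a SOFn marker, or None if not a SOFn."""
    return SOF_PROCESSES.get(marker)

print(describe_sof(0xffc0))
print(describe_sof(0xffc2))
```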

### Frame and scan parameters

Metadata which *is* within the scope of the JPEG specification is stored in the `SOF0` and `SOS` segments. Their data fields are specified in B.2.2 and B.2.3 respectively. Let's print them out:

In [15]:
JpegComponentInfo = namedtuple('JpegComponentInfo', ['quant_idx', 'dc_idx', 'ac_idx'])
JpegSof0SosInfo = namedtuple('JpegSof0SosInfo', ['x', 'y', 'component_info'])

def parse_sof_sos(jpeg):
    # decode SOF0
    quant_table_map = {}
    for seg in jpeg.marker_segs:
        if seg.marker == 0xffc0:
            # we have already skipped over the length, Lf
            (P, Y, X, Nf) = struct.unpack(">BHHB", seg.data[:6])
            print(f"~~~~~ SOF0 ~~~~~")
            print(f"Sample precision (bits per sample): {P}")
            print(f"Size: {X} x {Y}")
            print(f"Number of components: {Nf}")
            for cidx in range(Nf):
                component_params = seg.data[6+cidx*3:6+cidx*3+3]
                (Ci, HVi, Tqi) = struct.unpack(">BBB", component_params)
                Hi = HVi >> 4
                Vi = HVi & 0xf
                print(f"Component index {cidx}:")
                print(f"\tCi identifier: {Ci}")
                print(f"\tHori sampling factor: {Hi}")
                print(f"\tVert sampling factor: {Vi}")
                print(f"\tQuantization table: {Tqi}")

                assert Hi == 1, "chroma subsampling not supported"
                assert Vi == 1, "chroma subsampling not supported"

                assert Ci not in quant_table_map, "duplicate Ci"
                quant_table_map[Ci] = Tqi
            
            assert P == 8, "only 8bpp supported"
            assert Nf == 3, "only 3-component images supported"
    
    # decode SOS
    # again we have already skipped over Ls
    print(f"~~~~~ SOS ~~~~~")
    Ns = jpeg.scan_hdr[0]
    print(f"Number of components: {Ns}")
    component_info = []
    for cidx in range(Ns):
        component_params = jpeg.scan_hdr[1+cidx*2:1+cidx*2+2]
        Csj = component_params[0]
        Tdj = component_params[1] >> 4
        Taj = component_params[1] & 0xf
        print(f"Component index {cidx}:")
        print(f"\tCi identifier: {Csj}")
        print(f"\tDC entropy coding table: {Tdj}")
        print(f"\tAC entropy coding table: {Taj}")

        component_info.append(JpegComponentInfo(quant_table_map[Csj], Tdj, Taj))

    assert Ns == 3, "only 3-component scans supported"
    
    return JpegSof0SosInfo(X, Y, component_info)

In [None]:
parsed_sof_sos_info = parse_sof_sos(parsed_jpeg)

From this information, we now have the dimensions of the image. Our example image happens to be 6000 pixels wide and 4000 pixels tall. Each sample is stored with 8 bits of precision, and there are three color components per pixel. We also have various indices which associate color components with data tables for "quantization," "DC entropy coding," and "AC entropy coding." These techniques are the core of what makes JPEG actually able to compress images as well as it does.

One important piece of information that we are skipping over in this demo is the pair of horizontal and vertical sampling factors. These factors are used for indicating [chroma subsampling](https://en.wikipedia.org/wiki/Chroma_subsampling), a process where color information is stored at a lower resolution compared to the brightness information. We can often get away with doing this because the human visual system is much more sensitive to changes in brightness than it is to changes in color. Support for chroma subsampling is quite common in formats for "multimedia" (such as JPEG and Blu-ray) but is less common in formats for "computer graphics" (such as PNG or GIF). For the image we are working with, the sampling factors are all 1, which indicates that chroma subsampling has not been applied (also known as "4:4:4" sampling).
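Even though this notebook does not decode subsampled images, it is easy to sketch what the sampling factors mean for the size of each component. The function below follows my reading of the formula in A.1.1 of T.81 (each dimension is scaled by the component's factor over the largest factor, rounding up); `component_dimensions` is a hypothetical name, and the 2x2 luma factors in the second call are the usual way "4:2:0" is expressed, not something read from our file.

```python
def component_dimensions(width, height, sampling_factors):
    """Return the (width, height) in samples of each component.

    sampling_factors is a list of (H, V) pairs, one per component.
    """
    h_max = max(h for h, _v in sampling_factors)
    v_max = max(v for _h, v in sampling_factors)
    dims = []
    for h, v in sampling_factors:
        # Integer form of ceil(width * h / h_max) and ceil(height * v / v_max)
        comp_w = (width * h + h_max - 1) // h_max
        comp_h = (height * v + v_max - 1) // v_max
        dims.append((comp_w, comp_h))
    return dims

# Our example image: all factors are 1, so nothing is subsampled
print(component_dimensions(6000, 4000, [(1, 1), (1, 1), (1, 1)]))
# A hypothetical "4:2:0" encoding of an image of the same size
print(component_dimensions(6000, 4000, [(2, 2), (1, 1), (1, 1)]))
```

In the second case the two color components shrink to a quarter of the samples each, which is where the space savings come from.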

Something which is explicitly **not** specified in T.81 itself is how to actually interpret the three color components ("Application-dependent information, e.g. colour space, is outside the scope of this Specification"). We only know that there are three of them. We will revisit this problem later.

### Quantization tables

Before we can start decoding, we need to read the data tables referenced by the `SOF0` and `SOS` metadata. These tables are stored in their corresponding marker segments. We will start with the quantization tables first as they are simpler.

In [17]:
def parse_dqt(jpeg):
    quant_tables = [None] * 4
    for seg in jpeg.marker_segs:
        if seg.marker == 0xffdb:
            # for 8-bit tables, each table is 65 bytes long
            for i in range(len(seg.data) // 65):
                Pq = seg.data[i*65] >> 4
                Tq = seg.data[i*65] & 0xf
                table_i = seg.data[i*65 + 1:(i+1)*65]

                print(f"Got quantization table {Tq}")
                quant_tables[Tq] = table_i
                assert Pq == 0, "only 8-bit quantization tables supported"
    return quant_tables

In [None]:
quant_tables = parse_dqt(parsed_jpeg)

Each quantization table is just an array of 64 numbers!

### Huffman tables

The next set of tables we need to read are those for DC and AC entropy coding. _Entropy coding_ is a category of techniques for performing _lossless_ data compression, storing information in (hopefully) less space while still allowing exactly the same original data to be recovered at the end. Even though JPEG as a whole (at least in the DCT modes being used here) is a form of _lossy_ compression (the output is not exactly the same as the original data), it still makes use of lossless compression as a sub-step. The supported entropy coding techniques are _Huffman coding_ and _arithmetic coding_. Our image is using Huffman coding.

Unlike all of the previous wrangling of bytes and file structures, Huffman coding is the first "fancy algorithm" we have run into. We will present some code which parses the JPEG Huffman trees before explaining how it works:

In [19]:
# store codes as a binary tree
def add_codeword_to_tree(tree, code, code_len, code_val):
    for biti in range(code_len):
        # start at the MSB
        bit = 1 if code & (1 << (code_len - 1 - biti)) else 0
        
        if biti == code_len - 1:
            # Leaf
            assert tree[bit] is None
            tree[bit] = code_val
        else:
            # Intermediate node
            if tree[bit] is None:
                tree[bit] = [None, None]
            tree = tree[bit]

def parse_dht(jpeg):
    huff_tables_ac = [None] * 2
    huff_tables_dc = [None] * 2
    for seg in jpeg.marker_segs:
        if seg.marker == 0xffc4:
            huff_data = seg.data
            while huff_data:
                huff_tree = [None, None]
                # indices
                Tc = huff_data[0] >> 4
                Th = huff_data[0] & 0xf
                # L_i
                num_codewords = huff_data[1:17]
                huff_data = huff_data[17:]
                print(f"Got Huffman table {Tc}, {Th}")

                min_code_len = 0
                for (i, num) in enumerate(num_codewords):
                    if num != 0:
                        min_code_len = i + 1
                        break
                assert min_code_len > 0, "no codewords!"

                code_wip = 0
                # for each code length...
                for code_len in range(min_code_len, 17):
                    num_codewords_of_this_len = num_codewords[code_len - 1]
                    # ...read the list of values
                    for _j in range(num_codewords_of_this_len):
                        code_val = huff_data[0]
                        huff_data = huff_data[1:]
                        print(f"{code_wip:0{code_len}b} = 0x{code_val:02x}")
                        add_codeword_to_tree(huff_tree, code_wip, code_len, code_val)
                        code_wip += 1
                    code_wip <<= 1

                if Tc == 0:
                    huff_tables_dc[Th] = huff_tree
                elif Tc == 1:
                    huff_tables_ac[Th] = huff_tree
                else:
                    assert False, "invalid Tc"

    return (huff_tables_dc, huff_tables_ac)

In [None]:
huff_tables = parse_dht(parsed_jpeg)

## Huffman and prefix codes

The fundamental idea behind Huffman coding is that we can make data take up less space if we use fewer bits to refer to the more common data items and more bits to refer to the less common data items. The input data items are called the _source symbols_, the bits that we use to represent them are called _codewords_, and the particular mapping between the data items and the codewords is called a _code_. Huffman _coding_ is one particular technique (an algorithm) for creating codes which are provably "optimal" under certain assumptions. (It is not possible for a lossless compression algorithm to be optimal in all cases. [Relevant xkcd](https://xkcd.com/1381/).)

Huffman coding generates a _prefix code_, a code where no codeword is a prefix of another codeword. Prefix codes are useful because they are _uniquely decodable_: they can be decoded without ambiguity and without having to put any separators between the codewords.

### Examples

For example, this is a prefix code:

| Data   | Codeword |
| ------ | -------- |
| `a`    | 01       |
| `b`    | 100      |
| `c`    | 101      |

If we were to assume that each input data item was originally stored using 8 bits (one byte, perhaps no other bytes beyond `a`/`b`/`c` are ever used?), this particular code easily achieves a compression ratio of over 50% because it is able to store each item using only 2-3 bits.

If we change up some of the codeword assignments, then the following is no longer a prefix code, because the code for `a` (01) is a prefix of the code for `c` (011):

| Data   | Codeword |
| ------ | -------- |
| `a`    | 01       |
| `b`    | 100      |
| `c`    | 011      |

However, this code is still uniquely decodable! We can show this by coming up with a procedure for decoding it. An ambiguity can potentially arise only when we encounter a 01 input (is it going to be `a` or `c`?). To create a decoding procedure, we will list out all possibilities for the next bits which can follow and then specify what our decoder will do upon encountering each of them:

| Input | Decode                                                                             |
| ----- | ---------------------------------------------------------------------------------- |
| 010   | `a-` (consume 2 bits)                                                              |
| 01100 | `ab` (consume 5 bits; there is no valid 00x codeword, so it cannot decode as `c?`) |
| 01101 | `c-` (consume 3 bits)                                                              |
| 01110 | `cb` (consume 5 bits; the next bit must be 0 or else the input is invalid)         |
| 01111 | invalid (no code can start with 11)                                                |

In mathematical terms, prefix codes are _sufficient_ but not _necessary_ to be uniquely decodable.

However, both of the example codes above quite obviously contain redundant information: no codeword starts with 00 or 11, so some bit patterns are never used. (This can be verified more formally using [Kraft's inequality](https://en.wikipedia.org/wiki/Kraft%E2%80%93McMillan_inequality).) For example, it's possible to do the following:

| Data   | Code |
| ------ | ---- |
| `a`    | 0    |
| `b`    | 10   |
| `c`    | 11   |

This is also a prefix code, and it is shorter than all of the above examples.

In contrast, this is not a prefix code, and it is also not uniquely decodable:

| Data   | Code |
| ------ | ---- |
| `a`    | 1    |
| `b`    | 11   |
| `c`    | 111  |

Without separators, it is not possible to tell whether a long string of 1s, such as `111111`, is supposed to be all `a`s, all `b`s, all `c`s, or some other combination of source symbols.
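The claims about the four example codes can be checked mechanically. The following sketch tests the prefix property, computes the Kraft sum (the sum of $2^{-\text{length}}$ over all codewords), and counts how many ways a bit string can be split into codewords. All three function names are my own.

```python
from fractions import Fraction

def is_prefix_code(codewords):
    """True if no codeword is a prefix of a different codeword."""
    return not any(
        i != j and b.startswith(a)
        for i, a in enumerate(codewords)
        for j, b in enumerate(codewords)
    )

def kraft_sum(codewords):
    """Sum of 2^-length over all codewords."""
    return sum(Fraction(1, 2 ** len(word)) for word in codewords)

def count_parses(bits, codewords):
    """Count the ways `bits` can be split into a sequence of codewords."""
    ways = [1] + [0] * len(bits)   # ways[n] = number of parses of bits[:n]
    for end in range(1, len(bits) + 1):
        for word in codewords:
            if bits[:end].endswith(word):
                ways[end] += ways[end - len(word)]
    return ways[-1]

examples = {
    "first prefix code": ["01", "100", "101"],
    "non-prefix but decodable": ["01", "100", "011"],
    "shorter prefix code": ["0", "10", "11"],
    "ambiguous code": ["1", "11", "111"],
}
for name, words in examples.items():
    print(f"{name}: prefix={is_prefix_code(words)}, Kraft sum={kraft_sum(words)}")

print(count_parses("111111", examples["ambiguous code"]))
print(count_parses("01100", examples["non-prefix but decodable"]))
```

The Kraft sums come out as 1/2, 1/2, 1, and 7/8. The first two codes leave half of the available "code space" unused, while the shorter prefix code uses all of it. The last code shows that a Kraft sum of at most 1 is necessary but not sufficient for a particular code to be uniquely decodable: `111111` can be parsed in 24 different ways, whereas `01100` under the second code has exactly one parse.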

### Huffman as used in JPEG

Because of the optimality of Huffman coding, the name "Huffman code" is often used to refer to any prefix code regardless of how the code was actually created. When we are _decoding_ as we are now, we don't need to actually understand how to construct a Huffman code or otherwise make a code which minimizes redundant information. We only need to understand how to decode the prefix code stored in the JPEG file we are reading. In order to do this, we need to have the map from codewords back to (source) symbol values. Instead of being listed out explicitly like in the above example tables, this information is encoded in the `DHT` segment in an abbreviated form, and we need to understand how that works.

JPEG uses _length-limited canonical Huffman codes_ in order to store this mapping information even more efficiently. A _canonical_ Huffman code is a prefix code where the numerical values of the codewords (i.e. the particular `0`/`1` bits) are generated in a specific way (sorting by length, generating sequential binary values for each code of a given length, and appending 0 bits when the next length is reached). This _canonical_ assignment of codewords allows the table to be represented by only storing the _total number_ of codewords of each particular bit length and the associated source symbols. The actual bits which make up each of the codewords do not have to be stored but are instead reconstructed by following the same algorithm used to generate them.

For example, the following code (which is canonical according to this algorithm):

| Data   | Code |
| ------ | ---- |
| `a`    | 0    |
| `b`    | 10   |
| `c`    | 11   |

can instead be stored as the following:

| Codeword length | # codewords of this length | Source data symbols list |
| --------------- | -------------------------- | ------------------------ |
| 1               | 1                          | `a`                      |
| 2               | 2                          | `b`, `c`                 |

This can be turned back into the full table above by following the specific codeword assignment algorithm to calculate the codewords for length 1, followed by those for length 2 (and so on, for codes with more codewords).
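Here is that reconstruction as a small standalone sketch. It is the same idea as the loop inside `parse_dht`, stripped of the byte handling: count upwards within a length, and shift left by one bit whenever the length increases. The name `canonical_codewords` is hypothetical.

```python
def canonical_codewords(counts, symbols):
    """Rebuild codewords from the compact representation.

    counts[i] is the number of codewords of length i + 1, and symbols
    lists the source symbols in order of increasing codeword length.
    """
    table = {}
    remaining = iter(symbols)
    code = 0
    for length, count in enumerate(counts, start=1):
        for _ in range(count):
            # Write the current code value as exactly `length` bits
            table[next(remaining)] = format(code, f"0{length}b")
            code += 1
        # Moving on to the next length appends a 0 bit
        code <<= 1
    return table

print(canonical_codewords([1, 2], ["a", "b", "c"]))
```

This prints `{'a': '0', 'b': '10', 'c': '11'}`, matching the full table. In a real `DHT` segment, `counts` always has 16 entries (one for each length from 1 to 16 bits) and the symbols are byte values.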

### Our implementation

In our code, we will store prefix codes as binary trees. Every time an input bit is consumed, the tree is traversed one level until a leaf node is eventually reached containing a source symbol value. This works and is a common "theoretical" way of representing prefix codes but is not the most computationally efficient way of doing this.
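A minimal sketch of that traversal, assuming the same tree shape that `add_codeword_to_tree` builds (each internal node is a two-element list indexed by the bit, and anything that is not a list is a leaf). The function `decode_symbols` and the string-of-bits input are simplifications of my own; the real decoder later in this notebook has to pull bits out of bytes instead.

```python
def decode_symbols(tree, bits):
    """Decode a string of '0'/'1' characters using a nested-list code tree."""
    symbols = []
    node = tree
    for bit in bits:
        node = node[int(bit)]          # descend one level per input bit
        if node is None:
            raise ValueError("bit sequence does not match any codeword")
        if not isinstance(node, list):
            symbols.append(node)       # reached a leaf: emit and restart
            node = tree
    if node is not tree:
        raise ValueError("input ended in the middle of a codeword")
    return symbols

# Tree for the code a=0, b=10, c=11
example_tree = ["a", ["b", "c"]]
print(decode_symbols(example_tree, "010110"))
```

This prints `['a', 'b', 'c', 'a']`. Faster decoders typically replace the bit-by-bit walk with lookup tables indexed by several bits at once, but the tree is the easiest form to reason about.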

## Overview of next steps

At this point, we have all of the data tables loaded and are almost ready to start decoding image scans (which contain the "actual data" we care about). However, before doing this, we are going to want a high-level overview of what to expect next. This is because information in the JPEG bitstream is "all mixed up":

```text
... 0101100011000111000101010111001010110001011001011111110100101010111000010001000011011100100010000101 ...
    \_/\__/\___/\__/\__/\___/\_/\____/\___/\_/\__/\__/\__/\___/\_/\____/\_/\____/ ...
     |  |    |   |  ... (more Huffman codewords and coefficient values)
     |  |    |   +- AC coefficient value
     |  |    +----- AC Huffman codeword (color component 0)
     |  +---------- DC coefficient value
     +------------- DC Huffman codeword (color component 0)

    \_____/\_______/\_______/\_______/\______/\______/\_______/\_______/\_______/ ...
       |       |        |        |       |       |        |        |        |
       |       +--------+--------+------ | ------+--------+--------+--------+-- AC coefficients
       +---------------------------------+------------------------------------- DC coefficients
    \________________________________/\_________________________________________/ ...
                    |                                       |                       |
                    |                                       |                       +- 8x8 block for color component 2  \
                    |                                       +------------------------- 8x8 block for color component 1  +--\
                    +----------------------------------------------------------------- 8x8 block for color component 0  /  |
                                                                                                                           |
                                                               minimum coded unit, repeats until entire image is encoded --/
```

In addition to being "all mixed up" and interleaved, each 8x8 block can contain a different number of "AC coefficients" due to _run-length encoding_, and each color component can use a different Huffman code.

In order to make sense of all of this, the following is the procedure for _encoding_ a JPEG using the _baseline DCT_ mode:
1. Do transformations related to color space and chroma subsampling. This eventually gives us three separate color components. This step is "more abstracted away" from the bitstream handling we are about to do, and so we can ignore it and come back to it later.
2. Break each color component up into 8x8 pixel blocks.
3. Perform a _discrete cosine transform_ (DCT) on each of these blocks. There is a lot of mathematical theory about how the DCT works, but it fundamentally just _changes each block of numbers (64 pixels) into a different block of numbers (64 coefficients)_. The only bit of the theory we need at this point is to know that _one_ of the numbers is referred to as the "DC" coefficient, and the remaining 63 are referred to as the "AC" coefficients.
4. Take the output coefficients and perform _quantization_ on the numbers. This takes all of the numbers and makes them smaller (as well as less accurate due to rounding). This step is _lossy_ and throws away information from the original image. The idea here is to hopefully try to only throw away information which is less perceptible to humans.
5. Perform what I will call "some miscellaneous transformations" on these blocks of numbers.
    1. The one "DC coefficient" in each block is stored as the _difference from the previous DC coefficient_ (i.e. the DC coefficient of the current 8x8 block minus the DC coefficient of the previous 8x8 block of the same color component). This step is called _differential DC encoding_. The hope is that these differences will tend to be smaller in magnitude compared to the absolute values of the coefficients. This hope is justified based on the interpretation of what the DCT does as well as observations of what "typical" image data looks like.
    2. The 63 "AC coefficients" are rearranged into a zigzag order compared to the "normal" order in which the outputs of a DCT would be written. This helps to cluster together the coefficients which contribute to the image in a "similar" way (again, this "similarity" is explained based on the interpretation of what the DCT does).
    3. Perform _run-length encoding_ on any 0s that might appear in the AC coefficients. Run-length encoding replaces _runs_ of 0s (i.e. several 0s in a row) with one single symbol representing "there are $n$ 0s here." The combination of the DCT and quantization steps was designed to try to generate a lot of 0s for this step to encode.
6. Put the numbers through the entropy coding sub-step (Huffman coding in this case).
    1. The scale (number of bits) of the DC coefficient difference is encoded using the DC Huffman table for this color component, and then the DC coefficient difference is stored directly (using the minimum number of bits necessary). The differences themselves are _not_ Huffman encoded.
    2. The run length (number of 0s which come before this coefficient) and scale of each AC coefficient are packed together and then encoded using the AC Huffman table for this color component. The coefficient value is again stored directly.

To restate the result of this procedure once again, the final bitstream in a JPEG image contains a *mix* of Huffman codewords (from two different code tables) and data values (DC coefficient differences, AC coefficients). The exact number of values can also vary across 8x8 blocks due to run-length encoding. Even though decoding is the opposite of encoding, because these different pieces are all mixed together, the separation between the steps isn't very "clean". In this demo, we will be handling many of the steps all at once.
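
As a rough illustration of steps 5 and 6 from the encoder's point of view, the sketch below turns one quantized, zigzag-ordered block into the list of symbols that would then be handed to the Huffman coder. The names `sketch_bit_size` and `sketch_block_to_symbols` and the tuple layout are made up for this example. The symbol values `0xF0` (16 zeros) and `0x00` (end of block) are the same ones the decoder later in this notebook looks for.

```python
def sketch_bit_size(value):
    """The "scale": how many bits are needed to store |value| (0 for value 0)."""
    return abs(value).bit_length()

def sketch_block_to_symbols(zz_coeffs, prev_dc):
    """Turn 64 quantized, zigzag-ordered coefficients into (kind, symbol, value) tuples.

    `symbol` is what would be Huffman coded; `value` is what would be stored
    directly afterwards (None when nothing follows the symbol).
    """
    # differential DC encoding
    dc_diff = zz_coeffs[0] - prev_dc
    symbols = [("DC", sketch_bit_size(dc_diff), dc_diff)]

    # run-length encoding of zeros in the 63 AC coefficients
    run = 0
    for coeff in zz_coeffs[1:]:
        if coeff == 0:
            run += 1
            continue
        # a run can only be 0..15 inside one symbol, so emit "16 zeros" symbols first
        while run >= 16:
            symbols.append(("AC", 0xF0, None))
            run -= 16
        symbols.append(("AC", (run << 4) | sketch_bit_size(coeff), coeff))
        run = 0

    # any zeros left over at the end collapse into one end-of-block symbol
    if run > 0:
        symbols.append(("AC", 0x00, None))
    return symbols

example_block = [12, 3, 0, 0, -2] + [0] * 59
for sym in sketch_block_to_symbols(example_block, prev_dc=10):
    print(sym)
```

For this made-up block, the 59 trailing zeros cost only a single end-of-block symbol, which is exactly the payoff that the DCT and quantization steps are trying to set up.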

One final tiny piece of information we need is the _restart interval_, the number of MCUs after which certain state (in this case, the DC predictor used for differential DC encoding) is reset. This is stored by itself in its own marker segment (DRI, marker `0xffdd`).
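
To make the effect of the restart interval on the DC values concrete, here is a minimal sketch of differential DC encoding and decoding for a single color component, assuming one DC value per MCU. The function names are hypothetical, and using `0` to mean "never reset" is a choice made for this example (the notebook code uses `-1` for a missing restart interval).

```python
def sketch_dc_to_diffs(dc_values, restart_interval=0):
    """Differentially encode DC values, resetting the predictor to 0
    at the start of every restart interval (0 means never reset)."""
    diffs = []
    pred = 0
    for i, dc in enumerate(dc_values):
        if restart_interval and i % restart_interval == 0:
            pred = 0
        diffs.append(dc - pred)
        pred = dc
    return diffs

def sketch_diffs_to_dc(diffs, restart_interval=0):
    """Undo sketch_dc_to_diffs."""
    dc_values = []
    pred = 0
    for i, diff in enumerate(diffs):
        if restart_interval and i % restart_interval == 0:
            pred = 0
        pred += diff
        dc_values.append(pred)
    return dc_values

dcs = [5, 7, 7, 4, 10]
print(sketch_dc_to_diffs(dcs))                      # no resets
print(sketch_dc_to_diffs(dcs, restart_interval=2))  # predictor reset every 2 entries
```

With resets, each interval can be decoded without knowing anything about the intervals before it, at the cost of storing a full-size DC value at the start of each one.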

_Aside:_ The above explanation has not made a clear distinction between 8x8 pixel blocks and _minimum coded units_. This distinction is important when chroma subsampling is used, but it is much less critical for our example where the chroma is not subsampled.

In [21]:
def get_restart_interval(jpeg):
    rst_interval = -1
    for seg in jpeg.marker_segs:
        if seg.marker == 0xffdd:
            rst_interval = struct.unpack(">H", seg.data)[0]
    return rst_interval

In [None]:
restart_interval = get_restart_interval(parsed_jpeg)
print(f"Restart interval: {restart_interval} MCUs")

## Actually decoding scan data

The following code helps with processing the bitstream bit-by-bit, since the data does not align with byte boundaries:

In [23]:
# bit manip helpers

def get_bit(data, pos):
    (byte_pos, bit_pos) = pos
    bit_val = 1 if data[byte_pos] & (1 << (7 - bit_pos)) else 0
    if bit_pos != 7:
        bit_pos += 1
    else:
        bit_pos = 0
        byte_pos += 1
    return (bit_val, (byte_pos, bit_pos))

def get_bits(data, pos, nbits):
    val = 0
    for _i in range(nbits):
        (b, pos) = get_bit(data, pos)
        val = (val << 1) | b
    return (val, pos)

The following code pulls one "thing" from the bitstream, either a Huffman codeword or a coefficient. It pulls exactly as many bits as are required.

In [24]:
# traverse the Huffman table to get one symbol
def decode_huff(data, pos, huff_table):
    while True:
        (b, pos) = get_bit(data, pos)
        huff_table = huff_table[b]
        if isinstance(huff_table, int):
            return (huff_table, pos)

# grab a group of bits for a coefficient, converting negative numbers back
# see F.1.2.1.1
def decode_coeff(data, pos, nbits):
    if nbits == 0:
        return (0, pos)
    
    (val, pos) = get_bits(data, pos, nbits)
    if val & (1 << (nbits - 1)):
        # msb is 1 --> positive value
        return (val, pos)
    else:
        # negative value
        val = -(1 << nbits) + val + 1
        return (val, pos)
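
To see both kinds of "thing" being pulled out of a bitstream, here is a tiny self-contained sketch. It uses a string of `'0'`/`'1'` characters instead of real bytes to keep things short, and the Huffman tree is a made-up three-symbol code, not one from a real JPEG. The nested-list tree layout (index by bit, stop at an `int` leaf) mirrors what `decode_huff` above expects.

```python
def sketch_decode_huff(bits, pos, tree):
    """Walk a nested-list Huffman tree one bit at a time until an int leaf."""
    node = tree
    while not isinstance(node, int):
        node = node[int(bits[pos])]
        pos += 1
    return node, pos

def sketch_extend(raw, nbits):
    """Map an nbits-wide raw value back to a signed coefficient."""
    if nbits == 0:
        return 0
    if raw & (1 << (nbits - 1)):
        return raw                      # msb is 1 --> positive value
    return raw - (1 << nbits) + 1       # msb is 0 --> negative value

def sketch_decode_values(bits, tree):
    """Alternate between "Huffman codeword giving a bit count" and "that many raw bits"."""
    pos = 0
    values = []
    while pos < len(bits):
        nbits, pos = sketch_decode_huff(bits, pos, tree)
        raw = int(bits[pos:pos + nbits], 2) if nbits else 0
        pos += nbits
        values.append(sketch_extend(raw, nbits))
    return values

# hypothetical code: '0' -> 0 bits follow, '10' -> 2 bits follow, '11' -> 3 bits follow
sketch_tree = [0, [2, 3]]

# with 2 bits, the four raw patterns cover -3, -2, 2, 3
print([sketch_extend(raw, 2) for raw in range(4)])

# codeword '10' + raw '01', codeword '11' + raw '101', codeword '0'
print(sketch_decode_values("10" + "01" + "11" + "101" + "0", sketch_tree))
```

Note how values whose magnitude would fit in fewer bits (such as -1, 0, and 1 in the 2-bit case) never appear: they would have been stored with a smaller bit count instead.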

The following code handles Huffman decompression, run-length decompression of AC coefficients, and decoding the differential DC encoding. This is all done at once.

In [25]:
# decode one 8x8 block, known as a "data unit"
# see F.1.2.1 and F.1.2.2
def decode_block(data, pos, dc_huff_table, ac_huff_table, pred=0):
    # initialize output to all 0s
    coeffs = [0] * 64

    # decode DC difference
    (dc_sym, pos) = decode_huff(data, pos, dc_huff_table)
    assert dc_sym in range(0, 12), "invalid DC coeff scale"
    (dc_diff, pos) = decode_coeff(data, pos, dc_sym)
    coeffs[0] = dc_diff + pred

    # decode AC coefficients, including runs of 0s
    ac_i = 1
    while ac_i < 64:
        (ac_sym, pos) = decode_huff(data, pos, ac_huff_table)
        if ac_sym == 0xf0:
            # 16 zeros
            ac_i += 16
        elif ac_sym == 0x00:
            # end-of-block marker
            break
        else:
            rrrr = ac_sym >> 4
            ssss = ac_sym & 0xf
            assert ssss in range(1, 11), "invalid AC coeff scale"
            # a run of rrrr zeros
            ac_i += rrrr
            # followed by one AC coefficient
            (coeffs[ac_i], pos) = decode_coeff(data, pos, ssss)
            ac_i += 1
    
    return (coeffs, pos)

# decode one MCU, consisting of a data unit for each component (interleaved)
# see E.2.5
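def decode_mcu(data, pos, sof_sos_info, dc_huff_tables, ac_huff_tables, preds=None):
    # a mutable default ([0, 0, 0]) would be shared between calls, so create it here instead
    if preds is None:
        preds = [0, 0, 0]
    coeffs = [None] * 3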
    for component_i in range(3):
        dc_huff = dc_huff_tables[sof_sos_info.component_info[component_i].dc_idx]
        ac_huff = ac_huff_tables[sof_sos_info.component_info[component_i].ac_idx]

        (coeffs[component_i], pos) = decode_block(data, pos, dc_huff, ac_huff, preds[component_i])
        # make sure to update the DC predictor after decoding a block
        preds[component_i] = coeffs[component_i][0]
    
    return (coeffs, pos)

# decode a restart interval
# see E.2.4
def decode_restart_interval(data, sof_sos_info, dc_huff_tables, ac_huff_tables, restart_interval=-1):
    # if there is no restart interval, calculate how many MCUs are in the image
    if restart_interval == -1:
        mcu_w = divroundup(sof_sos_info.x, 8)
        mcu_h = divroundup(sof_sos_info.y, 8)
        restart_interval = mcu_w * mcu_h

    # remove byte stuffing (see F.1.2.3)
    data_ = b''
    i = 0
    while i < len(data):
        if data[i] != 0xff:
            data_ += data[i:i+1]
            i += 1
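        else:
            if i + 1 < len(data) and data[i+1] == 0x00:
                # remove stuffing
                data_ += b'\xff'
                i += 2
            else:
                # we must be at the end of the data (only 0xff fill bytes remain)
                for j in range(i, len(data)):
                    assert data[j] == 0xff
                # stop scanning; without this the loop would never advance past i
                break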
    data = data_

    # output
    mcus = [None] * restart_interval
    # current position
    pos = (0, 0)
    # reset DC predictors
    preds = [0, 0, 0]

    for i in range(restart_interval):
        (mcus[i], pos) = decode_mcu(data, pos, sof_sos_info, dc_huff_tables, ac_huff_tables, preds)
    
    # for debugging, check to make sure all the trailing bits are 1
    if pos[0] != len(data):
        assert pos[0] == len(data) - 1
        last_byte = data[pos[0]]
        for biti in range(pos[1], 8):
            assert last_byte & (1 << (7 - biti))
    
    return mcus
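
The byte stuffing handled inside `decode_restart_interval` can also be looked at in isolation. The following is a minimal sketch of both directions, with hypothetical function names. It collects output in a `bytearray`, since repeatedly doing `+=` on a `bytes` object copies the data every time. Unlike the code above, this sketch does not try to detect or drop trailing `0xff` fill bytes.

```python
def sketch_stuff(entropy_bytes):
    """Encoder side: follow every 0xFF data byte with a 0x00."""
    return bytes(entropy_bytes).replace(b"\xff", b"\xff\x00")

def sketch_unstuff(stuffed):
    """Decoder side: drop the 0x00 which follows each 0xFF."""
    out = bytearray()
    i = 0
    while i < len(stuffed):
        out.append(stuffed[i])
        if stuffed[i] == 0xFF and i + 1 < len(stuffed) and stuffed[i + 1] == 0x00:
            # skip over the stuffed zero byte
            i += 2
        else:
            i += 1
    return bytes(out)

raw = b"\x12\xff\x00\xff\xff\x34"
stuffed = sketch_stuff(raw)
print(stuffed.hex(" "))
print(sketch_unstuff(stuffed) == raw)
```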

In [26]:
def decode_jpeg_scan(scan_data, sof_sos_info, dc_huff_tables, ac_huff_tables, restart_interval=-1):
    mcus = []
    for i in range(len(scan_data)):
        # print(f"Decoding {i+1}/{len(scan_data)}")
        this_mcus = decode_restart_interval(scan_data[i], sof_sos_info, dc_huff_tables, ac_huff_tables, restart_interval)
        mcus += this_mcus
    print(f"Decoded {len(mcus)} MCUs in total")
    return mcus

In [None]:
decode_jpeg_mcus = decode_jpeg_scan(parsed_jpeg.scan_data, parsed_sof_sos_info, huff_tables[0], huff_tables[1], restart_interval)

Well, this implementation certainly isn't winning any benchmarks. However, we *have* successfully decoded all of the MCUs. This means that all of the steps related to "entropy coding" have been taken care of!

### Dequantization

The next step in the decoding process is to undo the _quantization_ procedure. This is spelled out in A.3.4, but it simply boils down to... multiplying each coefficient by the corresponding value in the _quantization tables_. Since both the coefficients and the quantization table are currently still in zigzag order, we can do this step *without* any shuffling or rearranging.

In [28]:
def dequantize_jpeg(mcus, sof_sos_info, quant_tables):
    for mcu_i in range(len(mcus)):
        for component_i in range(3):
            q_table = quant_tables[sof_sos_info.component_info[component_i].quant_idx]
            for coeff_i in range(64):
                mcus[mcu_i][component_i][coeff_i] *= q_table[coeff_i]

In [29]:
dequantize_jpeg(decode_jpeg_mcus, parsed_sof_sos_info, quant_tables)

At this point, we have dequantized data that is ready to go into the inverse DCT transformation.
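
To get a feel for why quantization is the lossy step, here is a small sketch of a quantize/dequantize round trip on a handful of numbers. The coefficient values and the quantization table entries are invented for this example, and `sketch_quantize` is an assumption about what an encoder does (divide and round), not code from this notebook.

```python
def sketch_quantize(coeffs, q_table):
    """Encoder side: divide by the table entry and round to an integer."""
    return [round(c / q) for c, q in zip(coeffs, q_table)]

def sketch_dequantize(levels, q_table):
    """Decoder side: multiply back by the same table entry."""
    return [lvl * q for lvl, q in zip(levels, q_table)]

sketch_coeffs = [-415.4, 37.0, 4.0, -2.0]
sketch_q_table = [16, 11, 10, 16]

levels = sketch_quantize(sketch_coeffs, sketch_q_table)
restored = sketch_dequantize(levels, sketch_q_table)
print(levels)    # small integers, with the small coefficients turned into 0
print(restored)  # close to, but not the same as, the original coefficients
```

The decoder can only ever recover multiples of the table entry, so whatever was rounded away during encoding is gone for good.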

But... is there any quick and dirty way to check our work along the way? It feels like we should be at a natural abstraction boundary, as we have just turned a stream of bits into nice blocks of numbers. It would certainly not be fun to wade through all of the remaining calculations just to end up with the wrong output with no clue where the errors could have come from.

In fact, there is a way to check! However, the explanation of _why_ this quick and dirty trick works requires a deeper understanding of the DCT which will come later (we are plotting just the DC terms of the luma component, giving us a downscaled version of the original image).

In [None]:
def demo_quick_check_dequantize():
    ylist = []
    h_mcus = divroundup(parsed_sof_sos_info.y, 8)
    w_mcus = divroundup(parsed_sof_sos_info.x, 8)
    for y in range(h_mcus):
        xlist = []
        for x in range(w_mcus):
            xlist.append(decode_jpeg_mcus[y * w_mcus + x][0][0])
        ylist.append(xlist)
    plt.pcolormesh(ylist[::-1])
demo_quick_check_dequantize()
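
Without getting into the full explanation yet, the arithmetic behind this trick can at least be spot-checked: plugging $u = v = 0$ into the forward DCT formula from A.3.3 turns both cosines into 1, which leaves the sum of all 64 samples divided by 8. In other words, the DC term is 8 times the block's average value. The sketch below checks this numerically on an arbitrary made-up block; `sketch_fdct_coeff` is a hypothetical helper written just for this check.

```python
from math import cos, pi, sqrt

def sketch_fdct_coeff(block, v, u):
    """One forward DCT coefficient of a 64-sample block (row-major), per A.3.3."""
    cu = 1 / sqrt(2) if u == 0 else 1
    cv = 1 / sqrt(2) if v == 0 else 1
    total = 0.0
    for y in range(8):
        for x in range(8):
            total += block[y * 8 + x] * cos((2*x+1)*u*pi/16) * cos((2*y+1)*v*pi/16)
    return cu * cv * total / 4

# arbitrary sample values, nothing special about them
sketch_block = [(3 * i) % 17 - 8 for i in range(64)]

sketch_dc = sketch_fdct_coeff(sketch_block, 0, 0)
sketch_mean = sum(sketch_block) / 64
print(sketch_dc, 8 * sketch_mean)
```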

### The inverse DCT, as naively as possible

We are finally at the point where we have to deal with inverting the _discrete cosine transform_ (DCT). Formulas are given in section A.3.3.

We also handle zigzag reordering at the same time.

In [31]:
DCT_BASIS_IDX_TO_ZIGZAG = [
    0, 1, 5, 6, 14, 15, 27, 28,
    2, 4, 7, 13, 16, 26, 29, 42,
    3, 8, 12, 17, 25, 30, 41, 43,
    9, 11, 18, 24, 31, 40, 44, 53,
    10, 19, 23, 32, 39, 45, 52, 54,
    20, 22, 33, 38, 46, 51, 55, 60,
    21, 34, 37, 47, 50, 56, 59, 61,
    35, 36, 48, 49, 57, 58, 62, 63,
]
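
Rather than taking the table on faith, it can be regenerated by walking the anti-diagonals of the 8x8 grid and alternating direction on each one. The following is a sketch of one way to do that (the function name is made up); its output should match `DCT_BASIS_IDX_TO_ZIGZAG` above.

```python
def sketch_build_zigzag():
    """Map each natural (row-major) position to its index in zigzag order."""
    table = [None] * 64
    zz = 0
    # d = row + col identifies one anti-diagonal
    for d in range(15):
        rows = range(max(0, d - 7), min(d, 7) + 1)
        # even diagonals are walked bottom-left to top-right, odd ones the other way
        if d % 2 == 0:
            rows = reversed(rows)
        for row in rows:
            col = d - row
            table[row * 8 + col] = zz
            zz += 1
    return table

sketch_zigzag = sketch_build_zigzag()
for row in range(8):
    print(sketch_zigzag[row * 8:row * 8 + 8])
```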

For the first attempt, we can try to copy the formula as directly as possible:

In [32]:
# compute the IDCT on an 8x8 data unit, blindly according to the formula
def idct_block(output, mcu_x, mcu_y, mcu_w, coeffs):
    px_xbase = mcu_x * 8
    px_ybase = mcu_y * 8
    px_w = mcu_w * 8

    for y in range(8):
        for x in range(8):
            s_yx = 0
            # summation symbols are just for loops
            for v in range(8):
                for u in range(8):
                    # constant
                    if u == 0 and v == 0:
                        # 1/4 * 1/sqrt(2) * 1/sqrt(2) = 1/4 * 1/2 = 1/8
                        a = 1 / 8
                    elif u == 0 or v == 0:
                        # 1/4 * 1/sqrt(2) * 1
                        a = 1 / (4 * sqrt(2))
                    else:
                        a = 1 / 4

                    # undo the zigzag ordering (implicitly, without having to store the unscrambled array)
                    S_vu = coeffs[DCT_BASIS_IDX_TO_ZIGZAG[v * 8 + u]]

                    s_yx += a * S_vu * cos(((2*x+1)*u/16)*pi) * cos(((2*y+1)*v/16)*pi)

            output[(px_ybase + y) * px_w + (px_xbase + x)] = s_yx

Although the above code does technically work, attempting to run it will take an unreasonable amount of time to complete. (You can test this by running the above code cell to replace the `idct_block` function and then rerunning the cell below which computes `decoded_jpeg_components`.)

In order to both speed it up and later experiment with some interesting mathematical properties of the DCT, we are going to take advantage of the _linearity_ of the DCT.

### Linearity

A function is called _linear_ if:
1. for any two inputs $x$ and $y$, $f(x+y) = f(x) + f(y)$. This property is called _additivity_.
2. for any input $x$ and scale factor $a$, $f(ax) = af(x)$. This property is called _homogeneity_.

_Note:_ The meaning of linearity being used here (_linear maps_) does *not* refer to _linear polynomials_ of the form $y=mx+b$ which are typically taught in the compulsory education system. Such a polynomial is only a linear map when $b=0$.

A linear function has a very useful property of obeying the _superposition principle_. This principle means that the _same_ result will be obtained from putting some input into a function as would be obtained by breaking apart that input (into pieces which sum together), applying the function to each piece individually, and then adding together those smaller results.
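
A quick way to build intuition for these two properties is to spot-check them on a few numbers. The sketch below does this for two toy functions: $f(x) = 3x$, and $g(x) = 3x + 1$ (a "linear polynomial" which is nevertheless not a linear map). The helper name is hypothetical. Keep in mind that passing on a handful of samples does not *prove* linearity, while a single failure is enough to disprove it.

```python
def sketch_looks_linear(f, samples, scales):
    """Spot-check additivity and homogeneity of f on the given integers."""
    additive = all(f(x + y) == f(x) + f(y) for x in samples for y in samples)
    homogeneous = all(f(a * x) == a * f(x) for x in samples for a in scales)
    return additive and homogeneous

def sketch_f(x):
    return 3 * x

def sketch_g(x):
    return 3 * x + 1

sketch_samples = [-2, 0, 5]
sketch_scales = [0, 2, -3]
print(sketch_looks_linear(sketch_f, sketch_samples, sketch_scales))
print(sketch_looks_linear(sketch_g, sketch_samples, sketch_scales))
```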

### Why does linearity matter?

The input that we want to put in to the IDCT operation is a set of 64 coefficients (denoted symbolically as $S_{vu}$). Instead of having to compute the entire formula involving summing up cosines over and over again, we can apply linearity and the superposition principle as follows:

1. Precompute the IDCT for the case when $S_{00} = 1$ and every other $S_{vu} = 0$. This gives us an output $s_{yx}$ which we can store as an 8x8 array of numbers.
2. Precompute the IDCT for the case when $S_{01} = 1$ and every other $S_{vu} = 0$. This gives us an output $s_{yx}$ which we can store as an 8x8 array of numbers.
3. Repeat the computation for every combination of $S_{vu}$, which gives us an 8x8 array where each item is itself an 8x8 array of numbers.
4. For each block of actual data we want to compute the IDCT of:
    1. Initialize a temporary 8x8 array of numbers to all 0s
    2. For element $S_{00}$ of the actual data input, look for the precomputed data where $S_{00} = 1$ and every other $S_{vu} = 0$. Scale this precomputed data by the value of $S_{00}$ in the actual data input (relying on the homogeneity property), and then add it to the temporary array (relying on the additivity property).
    3. Repeat for every element of the input.

These precomputed values can be called the IDCT "basis functions."
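
Before doing this in two dimensions, the steps above can be sketched in one dimension, where there are only 8 inputs and 8 basis functions. This assumes the 1-D analogue of the A.3.3 formula (a factor of $\frac{1}{2} C_u$ in place of $\frac{1}{4} C_u C_v$), and all of the names here are made up for the example.

```python
from math import cos, pi, sqrt

def sketch_idct_1d(S):
    """Direct 8-point inverse DCT, summing cosines for every output sample."""
    out = []
    for x in range(8):
        total = 0.0
        for u in range(8):
            c = 1 / sqrt(2) if u == 0 else 1
            total += (c / 2) * S[u] * cos((2 * x + 1) * u * pi / 16)
        out.append(total)
    return out

# steps 1-3: precompute the output for each input which is 1 in one spot and 0 elsewhere
sketch_basis_1d = [
    sketch_idct_1d([1 if i == u else 0 for i in range(8)])
    for u in range(8)
]

def sketch_idct_1d_superposition(S):
    """Step 4: scale each precomputed output by its coefficient and add them all up."""
    out = [0.0] * 8
    for u in range(8):
        for x in range(8):
            out[x] += S[u] * sketch_basis_1d[u][x]
    return out

sketch_S = [10, -3, 0, 2.5, 0, 0, 1, -7]
direct = sketch_idct_1d(sketch_S)
via_basis = sketch_idct_1d_superposition(sketch_S)
print(max(abs(a - b) for a, b in zip(direct, via_basis)))
```

The printed difference should be zero or within floating-point noise, since both routes add up the same products.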

### In code?

Instead of proving the linearity property with mathematical symbols (a reasonably-common exercise which can be found in many signal processing textbooks), we are going to instead perform refactoring on the above code until the linearity property becomes readily apparent.

To rephrase, our goal right now is to take the above code and refactor it so that it is faster. One thing we can notice is that there are only $8 \times 8 \times 8 \times 8 = 4096$ possible values for the $\text{constant} \times \text{cosine}_1 \times \text{cosine}_2$ part of the calculation, because each of `x`/`y`/`u`/`v` only ever takes integer values in $[0, 8)$. We can try precomputing them:

In [26]:
def precompute_idct():
    result = []
    for y in range(8):
        for x in range(8):
            for v in range(8):
                for u in range(8):
                    # constant
                    if u == 0 and v == 0:
                        # 1/4 * 1/sqrt(2) * 1/sqrt(2) = 1/4 * 1/2 = 1/8
                        a = 1 / 8
                    elif u == 0 or v == 0:
                        # 1/4 * 1/sqrt(2) * 1
                        a = 1 / (4 * sqrt(2))
                    else:
                        a = 1 / 4
                    
                    # constant * cosine_1 * cosine_2
                    val = a * cos(((2*x+1)*u/16)*pi) * cos(((2*y+1)*v/16)*pi)

                    result.append(val)
    return result
idct_demo_precompute = precompute_idct()

# compute the IDCT on an 8x8 data unit, with some precomputation
def idct_block(output, mcu_x, mcu_y, mcu_w, coeffs):
    px_xbase = mcu_x * 8
    px_ybase = mcu_y * 8
    px_w = mcu_w * 8

    for y in range(8):
        for x in range(8):
            s_yx = 0
            # summation symbols are just for loops
            for v in range(8):
                for u in range(8):
                    # undo the zigzag ordering (implicitly, without having to store the unscrambled array)
                    S_vu = coeffs[DCT_BASIS_IDX_TO_ZIGZAG[v * 8 + u]]

                    # NOTE that we are now looking up precomputed values here!
                    s_yx += S_vu * idct_demo_precompute[y * 8 * 8 * 8 + x * 8 * 8 + v * 8 + u]

            output[(px_ybase + y) * px_w + (px_xbase + x)] = s_yx

This version of the code is indeed faster as we were hoping! (It is still not exactly _fast_. It turns out that image processing is quite computationally expensive!)

In order to see "those thingies" featured in the meme at the very beginning of this notebook, we can plot the 4096-element array as an $8 \times 8$ grid of $8 \times 8$ values. In order to make it visually match, we just need to shuffle some indices around:

In [None]:
def demo_show_idct_basis_biglist():
    fig, axs = plt.subplots(8, 8)
    for v in range(8):
        for u in range(8):
            shuffled_basis_func = []
            for y in range(8):
                row = []
                for x in range(8):
                    row.append(idct_demo_precompute[y * 8 * 8 * 8 + x * 8 * 8 + v * 8 + u])
                shuffled_basis_func.append(row)
            axs[v, u].pcolormesh(shuffled_basis_func)
demo_show_idct_basis_biglist()

So far, we have made use of the fact that multiplication is _commutative_, but we haven't actually needed the linearity property yet. If all we wanted to do was to make the code somewhat faster, we could stop here.

To help prepare for later linear algebra hax, we want to rearrange the 4096-element array into a $64 \times 64$ nested array ordered more like the above plots. We want the first index to correspond to one specific subplot and the second index to correspond to one pixel within the subplot. In order to do this, we swap the `xy` and `uv` loop order:

In [None]:
def compute_idct_basis():
    result = [None] * 64
    # xy and vu loop order have been changed, but it still loops over all the same things
    for v in range(8):
        for u in range(8):
            # intermediate array
            basis_func = [None] * 64

            # constant
            # moved out of xy loop, since the value does not depend on xy
            if u == 0 and v == 0:
                # 1/4 * 1/sqrt(2) * 1/sqrt(2) = 1/4 * 1/2 = 1/8
                a = 1 / 8
            elif u == 0 or v == 0:
                # 1/4 * 1/sqrt(2) * 1
                a = 1 / (4 * sqrt(2))
            else:
                a = 1 / 4

            # now we finally get the xy loop
            for y in range(8):
                for x in range(8):
                    # constant * cosine_1 * cosine_2
                    val = a * cos(((2*x+1)*u/16)*pi) * cos(((2*y+1)*v/16)*pi)
                    basis_func[y*8 + x] = val

            # save it
            result[v*8 + u] = basis_func
    return result
idct_basis = compute_idct_basis()

The variable `idct_basis` now contains something much closer to "those thingies," and we can once again plot them with a little bit of code (although we are forced to do yet more rearranging so that the plots show up the way we want):

In [None]:
def demo_show_idct_basis():
    fig, axs = plt.subplots(8, 8)
    for v in range(8):
        for u in range(8):
            basis_func = idct_basis[v*8 + u]
            shuffled_basis_func = []
            for y in range(8):
                row = []
                for x in range(8):
                    row.append(basis_func[y*8 + x])
                shuffled_basis_func.append(row)
            axs[v, u].pcolormesh(shuffled_basis_func)
demo_show_idct_basis()

A refactoring step which *does* actually require the linearity property is to swap the `xy` and `uv` loop order in the IDCT computation itself just as we swapped them in the precomputation. This groups together the memory lookups which depend on `uv` and separates them from the ones which depend on `xy`, which doesn't speed up Python but can speed up vectorized native code implementations:

In [35]:
# compute the IDCT on an 8x8 data unit, with rearranged precomputation
def idct_block(output, mcu_x, mcu_y, mcu_w, coeffs):
    px_xbase = mcu_x * 8
    px_ybase = mcu_y * 8
    px_w = mcu_w * 8

    # xy and vu loop order have been changed, but it still loops over all the same things
    for v in range(8):
        for u in range(8):
            # undo the zigzag ordering (implicitly, without having to store the unscrambled array)
            S_vu = coeffs[DCT_BASIS_IDX_TO_ZIGZAG[v * 8 + u]]
            # look up the appropriate basis function corresponding to this coefficient
            this_coeff_basis_func = idct_basis[v * 8 + u]

            for y in range(8):
                for x in range(8):
                    # sum is rearranged a bit
                    output[(px_ybase + y) * px_w + (px_xbase + x)] += S_vu * this_coeff_basis_func[y * 8 + x]

This final version of the `idct_block` function takes advantage of linearity as described. Convince yourself that swapping the `xy` and `uv` loop order results in the exact same values being summed in either case!
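
If staring at the loops isn't convincing enough, the claim can also be checked numerically. The following self-contained sketch computes the IDCT of one random block with both loop orders and compares the results. It is a stripped-down stand-in for the `idct_block` versions above: the function names are hypothetical, the coefficients are taken in natural (not zigzag) order, and nothing is precomputed.

```python
import random
from math import cos, pi, sqrt

def sketch_weight(y, x, v, u):
    """constant * cosine_1 * cosine_2 for one (pixel, coefficient) pair."""
    cu = 1 / sqrt(2) if u == 0 else 1
    cv = 1 / sqrt(2) if v == 0 else 1
    return cu * cv / 4 * cos((2*x+1)*u*pi/16) * cos((2*y+1)*v*pi/16)

def sketch_idct_xy_outer(S):
    """For each pixel, sum over all coefficients."""
    out = [0.0] * 64
    for y in range(8):
        for x in range(8):
            for v in range(8):
                for u in range(8):
                    out[y*8 + x] += S[v*8 + u] * sketch_weight(y, x, v, u)
    return out

def sketch_idct_vu_outer(S):
    """For each coefficient, add its scaled basis function to every pixel."""
    out = [0.0] * 64
    for v in range(8):
        for u in range(8):
            for y in range(8):
                for x in range(8):
                    out[y*8 + x] += S[v*8 + u] * sketch_weight(y, x, v, u)
    return out

sketch_rng = random.Random(0)
sketch_coeffs = [sketch_rng.randint(-50, 50) for _ in range(64)]
out_a = sketch_idct_xy_outer(sketch_coeffs)
out_b = sketch_idct_vu_outer(sketch_coeffs)
print(max(abs(a - b) for a, b in zip(out_a, out_b)))
```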

Now that we have an implementation of the IDCT, we can move on to performing IDCT on all color components of the entire image. This will take a while. Real implementations gain additional speedups by exploiting mathematical symmetries and patterns in the DCT coefficients (using _fast cosine transform_ algorithms), but we will not be doing that in this proof-of-concept. By *not* doing this, we keep the ability to operate with _arbitrary linear transforms_, which we will make use of for demos later in the notebook.

In [37]:
# compute all IDCTs for an entire image component
# also happens to rearrange from a list of MCUs into the final 2-D grid
def idct_component(mcus, sof_sos_info, component_i):
    mcu_h = divroundup(sof_sos_info.y, 8)
    mcu_w = divroundup(sof_sos_info.x, 8)
    # initialize the output to zeros
    img = [0] * (mcu_w * 8) * (mcu_h * 8)
    for y_mcu in range(mcu_h):
        for x_mcu in range(mcu_w):
            # process one MCU
            mcu_i = y_mcu * mcu_w + x_mcu
            coeffs = mcus[mcu_i][component_i]
            idct_block(img, x_mcu, y_mcu, mcu_w, coeffs)
            if mcu_i and mcu_i % 5000 == 0:
                print(f"IDCT done on MCU {mcu_i}/{len(mcus)}")
    return img

# compute all IDCTs for an entire image (all three components)
def idct_image(mcus, sof_sos_info):
    components = [None] * 3
    for component_i in range(3):
        print(f"Working on component {component_i}")
        components[component_i] = idct_component(mcus, sof_sos_info, component_i)
    return components

In [None]:
decoded_jpeg_components = idct_image(decode_jpeg_mcus, parsed_sof_sos_info)

Let's see what it gives us as a result:

In [None]:
def demo_plot_components():
    for component_i in range(3):
        ylist = []
        h = divroundup(parsed_sof_sos_info.y, 8) * 8
        w = divroundup(parsed_sof_sos_info.x, 8) * 8
        for y in range(h):
            xlist = []
            for x in range(w):
                xlist.append(decoded_jpeg_components[component_i][y * w + x])
            ylist.append(xlist)
        plt.pcolormesh(ylist[::-1])
        plt.show()
demo_plot_components()

### Color space conversion

We now have three image components, and we want to convert this into "real" colors rather than the false-color representation used by the above plots. This necessitates doing _color space conversion_.

JPEG itself does not specify *how* colors are encoded into the three image components. This particular image contains an [ICC profile](https://en.wikipedia.org/wiki/ICC_profile) specifying how this should be done, and this ICC profile is stored in the APP2 marker segment.

In [None]:
def demo_show_dump_icc():
    for seg in parsed_jpeg.marker_segs:
        if seg.marker == 0xffe2:
            print(seg.data[:32])
            with open("test.icc", 'wb') as f:
                f.write(seg.data[len(b'ICC_PROFILE\x00\x00\x00'):])
demo_show_dump_icc()

If we open this ICC profile in some software which understands it (e.g. in the builtin macOS viewer), we find that it is describing [sRGB](https://en.wikipedia.org/wiki/SRGB), an average model of computer monitors in the late 1990s and the default color space for anything which does not have an ICC profile.

The colors in the JPEG are encoded using Y'CbCr, where Y' corresponds to brightness (black and white), Cb corresponds to "blueness", and Cr corresponds to "redness". Y' is also referred to as _luma_ and Cb and Cr as _chroma_. Color management is a very complicated topic which is *very* out of scope for this tutorial (also, I the author don't understand it well), so we will simply blindly copy Y'CbCr <-> sRGB conversion formulas from Wikipedia.

In [32]:
def convert_color_space(components, w):
    h = len(components[0]) // w
    im = Image.new('RGB', (w, h))
    im_data = im.load()

    for y in range(h):
        for x in range(w):
            y_ = components[0][y * w + x]
            cb = components[1][y * w + x]
            cr = components[2][y * w + x]

            # adjust bias (A.3.1)
            y_ += 128
            cb += 128
            cr += 128

            # convert to RGB
            r = 298.082/256 * y_ + 408.582/256 * cr - 222.921
            g = 298.082/256 * y_ - 100.291/256 * cb - 208.120/256 * cr + 135.576
            b = 298.082/256 * y_ + 516.412/256 * cb - 276.836

            # round and clamp
            r = clamp(r)
            g = clamp(g)
            b = clamp(b)

            im_data[(x, y)] = (r, g, b)

    return im

In [None]:
final_decoded_image = convert_color_space(decoded_jpeg_components, divroundup(parsed_sof_sos_info.x, 8) * 8)
display(final_decoded_image)
final_decoded_image.save("our_own_decode.png")

**SUCCESS!!!** We have successfully decoded the image!

## Encoding, very suboptimally

Now that we have successfully done some decoding, let's try to invert all of the steps and encode our own JPEG. We will start with this (much smaller) image of Tux:

![Tux the penguin](imgs/Tux2.png)

This image is just under 40000 bytes.

For our first attempt, we are going to _intentionally_ make some suboptimal choices. This will help highlight where choices _can_ even be made, and it will also give us something to compare against.

### Loading the source image

The first thing we need to do is to load the source image and turn it into a grid of pixel values. Since this is not a tutorial about the PNG format, we are going to use an existing software library to do this. At the same time, we will also perform color space conversion to Y'CbCr and pad the image dimensions out to a multiple of 8. JPEG can only really encode images whose dimensions align with the size of its minimum coded units, so everything else has to be padded with filler data. The recommendation is to pad by repeating the edge-most pixels, but we will simply pad with white (which, in this case, does the same thing).

In [34]:
def load_png_and_color_convert(filename):
    # load the image, making sure we can get RGB pixel values
    im = Image.open(filename)
    (real_w, real_h) = im.size
    rgb_im = im.convert('RGB')
    im_data = rgb_im.load()

    # pad the size
    pad_w = divroundup(real_w, 8) * 8
    pad_h = divroundup(real_h, 8) * 8
    print(f"Padding size ({real_w}, {real_h}) -> ({pad_w}, {pad_h})")

    # output image components
    y_img = [0] * (pad_w * pad_h)
    cb_img = [0] * (pad_w * pad_h)
    cr_img = [0] * (pad_w * pad_h)

    for y in range(pad_h):
        for x in range(pad_w):
            # get pixel, or padding in white
            if x < real_w and y < real_h:
                (r, g, b) = im_data[x, y]
            else:
                (r, g, b) = (255, 255, 255)
            
            # convert
            y_ = clamp(16 + 65.481/255*r + 128.553/255*g + 24.966/255*b)
            cb = clamp(128 - 37.797/255*r - 74.203/255*g + 112.0/255*b)
            cr = clamp(128 + 112.0/255*r - 93.786/255*g - 18.214/255*b)
            
            # perform level shift (A.3.1)
            y_ -= 128
            cb -= 128
            cr -= 128

            y_img[y * pad_w + x] = y_
            cb_img[y * pad_w + x] = cb
            cr_img[y * pad_w + x] = cr
    
    return (y_img, cb_img, cr_img, (real_w, real_h), pad_w)
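
For images whose edges are not a solid color, padding with white would not match the "repeat the edge-most pixels" recommendation. A minimal sketch of that alternative is shown below, operating on a flat row-major list of values for a single component. The function name and the `multiple` parameter are made up for this example (the demo uses a multiple of 2 just to keep the printed output small).

```python
def sketch_pad_replicate(pixels, w, h, multiple=8):
    """Pad a w*h row-major list up to a multiple of `multiple` in each dimension,
    repeating the right-most column and bottom-most row into the padding."""
    pad_w = (w + multiple - 1) // multiple * multiple
    pad_h = (h + multiple - 1) // multiple * multiple
    out = []
    for y in range(pad_h):
        # rows past the bottom edge reuse the last real row
        src_y = min(y, h - 1)
        for x in range(pad_w):
            # columns past the right edge reuse the last real column
            src_x = min(x, w - 1)
            out.append(pixels[src_y * w + src_x])
    return (out, pad_w, pad_h)

(sketch_padded, sketch_pw, sketch_ph) = sketch_pad_replicate([1, 2, 3, 4, 5, 6, 7, 8, 9], 3, 3, multiple=2)
for row in range(sketch_ph):
    print(sketch_padded[row * sketch_pw:(row + 1) * sketch_pw])
```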

In [None]:
(to_encode_y, to_encode_cb, to_encode_cr, to_encode_sz, to_encode_w) = load_png_and_color_convert("imgs/Tux2.png")
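
As a quick sanity check on the color math, we can push a few individual pixels through the forward conversion and then back again. This sketch reuses the same constants as `load_png_and_color_convert` and `convert_color_space`, but leaves out the level shift (it cancels out in a round trip). The function names are hypothetical. With these particular constants, black and white land on Y' values of 16 and 235 rather than 0 and 255.

```python
def sketch_clamp(val):
    return max(0, min(255, round(val)))

def sketch_rgb_to_ycbcr(r, g, b):
    y_ = 16 + 65.481/255*r + 128.553/255*g + 24.966/255*b
    cb = 128 - 37.797/255*r - 74.203/255*g + 112.0/255*b
    cr = 128 + 112.0/255*r - 93.786/255*g - 18.214/255*b
    return (sketch_clamp(y_), sketch_clamp(cb), sketch_clamp(cr))

def sketch_ycbcr_to_rgb(y_, cb, cr):
    r = 298.082/256 * y_ + 408.582/256 * cr - 222.921
    g = 298.082/256 * y_ - 100.291/256 * cb - 208.120/256 * cr + 135.576
    b = 298.082/256 * y_ + 516.412/256 * cb - 276.836
    return (sketch_clamp(r), sketch_clamp(g), sketch_clamp(b))

for rgb in [(255, 255, 255), (0, 0, 0), (200, 30, 90)]:
    ycbcr = sketch_rgb_to_ycbcr(*rgb)
    print(rgb, "->", ycbcr, "->", sketch_ycbcr_to_rgb(*ycbcr))
```

Because the intermediate values are rounded to integers, an arbitrary color may come back off by a small amount even before any DCT or quantization is involved.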

In [None]:
def demo_plot_to_encode_components():
    # reshuffle to array of arrays
    pad_h = len(to_encode_y) // to_encode_w
    def conv_one(img):
        ylist = []
        for y in range(pad_h):
            xlist = []
            for x in range(to_encode_w):
                xlist.append(img[y * to_encode_w + x])
            ylist.append(xlist)
        return ylist[::-1]
    # plot
    plt.pcolormesh(conv_one(to_encode_y))
    plt.show()
    plt.pcolormesh(conv_one(to_encode_cb))
    plt.show()
    plt.pcolormesh(conv_one(to_encode_cr))
    plt.show()
demo_plot_to_encode_components()

The solid colors in this image as well as the bright colors of the fursuits in the pride parade _really_ show how the Cb and Cr color components correspond to different shades of colors.

### Performing the forward DCT

The forward DCT is performed in a very similar way to performing the inverse DCT. The same property of linearity applies here as well, as does the existence of fast algorithms which we will not be using.

We start by precomputing the basis functions that will be used in the forwards direction. Note the different ordering of the `for` loops compared to the IDCT.

In [37]:
def compute_fdct_basis():
    ret = [None] * 64
    for y in range(8):
        for x in range(8):
            basis_func = [None] * 64

            for v in range(8):
                for u in range(8):
                    # constant
                    if u == 0 and v == 0:
                        a = 1 / 8
                    elif u == 0 or v == 0:
                        a = 1 / (4 * sqrt(2))
                    else:
                        a = 1 / 4
                
                    # cosines
                    val = a * cos(((2*x+1)*u/16)*pi) * cos(((2*y+1)*v/16)*pi)
                    basis_func[v*8 + u] = val
            
            # save it
            ret[y*8 + x] = basis_func
    return ret
fdct_basis = compute_fdct_basis()

In [None]:
def demo_show_fdct_basis():
    fig, axs = plt.subplots(8, 8)
    for y in range(8):
        for x in range(8):
            basis_func = fdct_basis[y*8 + x]
            shuffled_basis_func = []
            for v in range(8):
                row = []
                for u in range(8):
                    row.append(basis_func[v*8 + u])
                shuffled_basis_func.append(row)
            axs[y, x].pcolormesh(shuffled_basis_func)
demo_show_fdct_basis()

These are differently-shaped patterns compared to when we were performing the IDCT, and there's clearly still some kind of "stripey checkerboard" structure, but the details are out of scope for this discussion. (Each subplot corresponds to the frequencies that would be found in a source image containing only one nonzero pixel at a particular spot. These frequencies depend on how the finite input data gets turned into a periodic function.)
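
One thing that *can* be checked without any further theory is that the forward and inverse transforms really do undo each other. The sketch below runs one made-up block through a direct forward DCT and then a direct inverse DCT, both built from the same $\text{constant} \times \text{cosine}_1 \times \text{cosine}_2$ weights, and compares the result with the input. The names are hypothetical, the coefficients stay in natural (not zigzag) order, and no quantization is involved.

```python
from math import cos, pi, sqrt

def sketch_dct_weight(y, x, v, u):
    """The weight shared by the forward and inverse formulas in A.3.3."""
    cu = 1 / sqrt(2) if u == 0 else 1
    cv = 1 / sqrt(2) if v == 0 else 1
    return cu * cv / 4 * cos((2*x+1)*u*pi/16) * cos((2*y+1)*v*pi/16)

def sketch_fdct(pixels):
    """Forward DCT: for each (v, u), sum over all pixels."""
    return [
        sum(pixels[y*8 + x] * sketch_dct_weight(y, x, v, u)
            for y in range(8) for x in range(8))
        for v in range(8) for u in range(8)
    ]

def sketch_idct(coeffs):
    """Inverse DCT: for each (y, x), sum over all coefficients."""
    return [
        sum(coeffs[v*8 + u] * sketch_dct_weight(y, x, v, u)
            for v in range(8) for u in range(8))
        for y in range(8) for x in range(8)
    ]

# arbitrary level-shifted sample values in [-128, 127]
sketch_pixels = [(7 * i) % 256 - 128 for i in range(64)]
sketch_roundtrip = sketch_idct(sketch_fdct(sketch_pixels))
print(max(abs(a - b) for a, b in zip(sketch_pixels, sketch_roundtrip)))
```

The only difference between the two directions is which pair of indices is summed over, which is the same loop-order difference visible between `compute_idct_basis` and `compute_fdct_basis`.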

We can now compute the FDCT and rearrange the coefficients into zigzag order.

In [39]:
# compute the FDCT on an 8x8 data unit
def fdct_block(input, xbase, ybase, input_w):
    # output coefficient array
    coeffs = [0] * 64
    for y in range(8):
        for x in range(8):
            # get the pixel value
            px = input[(ybase + y) * input_w + (xbase + x)]
            # select the basis function
            basis_func = fdct_basis[y*8 + x]
            # accumulate it into the output, relying on the superposition principle
            for coeff_i in range(64):
                # implicit zigzag reordering happens here
                coeffs[DCT_BASIS_IDX_TO_ZIGZAG[coeff_i]] += px * basis_func[coeff_i]
    
    return coeffs

# compute all FDCTs for an entire image component
# also chunk them into blocks (which will later get interleaved to form MCUs)
def fdct_component(input, input_w):
    input_h = len(input) // input_w
    mcu_w = divroundup(input_w, 8)
    mcu_h = divroundup(input_h, 8)
    mcus = []
    for y_mcu in range(mcu_h):
        for x_mcu in range(mcu_w):
            # process one unit
            mcus.append(fdct_block(input, x_mcu * 8, y_mcu * 8, input_w))

    return mcus

# compute all FDCTs for an entire image (all three components)
def fdct_image(y_, cb, cr, input_w):
    # compute
    y_units = fdct_component(y_, input_w)
    cb_units = fdct_component(cb, input_w)
    cr_units = fdct_component(cr, input_w)

    # debug assertion
    assert len(y_units) == len(cb_units)
    assert len(y_units) == len(cr_units)

    # combine into a list of (y, cb, cr) tuples, one per MCU
    return list(zip(y_units, cb_units, cr_units))

In [40]:
to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)

At this point, since we have a decoder implemented already, we can check our work by running the output through the decoder. We just have to make up a `JpegSof0SosInfo` containing the appropriate dimensions.

In [None]:
def demo_check_work_at_fdct_step():
    info = JpegSof0SosInfo(to_encode_w, len(to_encode_y) // to_encode_w, None)
    check_idct = idct_image(to_encode_mcus, info)
    check_decoded_im = convert_color_space(check_idct, to_encode_w)
    display(check_decoded_im)
demo_check_work_at_fdct_step()

So far so good!

At this point, it's useful to consider what implementation _choices_ we have already made. Even though the DCT step doesn't involve flexible parameters the way later steps such as quantization do, we _have_ already made several design choices.

One design choice we have made is the _precision_ to use for intermediate values. Even though input sample values and the output of the quantization step are all integers, values within the DCT computation step are (conceptually) real numbers. For simplicity, we have used Python's built-in floating point numbers, which are [IEEE 754 binary64](https://en.wikipedia.org/wiki/Double-precision_floating-point_format). Because of the later quantization step, this level of precision is not necessary, and implementations which do not use floating point are possible and can perform better.

Another design choice we have made is the padding to a multiple of 8 pixels, as explained before. Even though the extra pixels themselves will get discarded when decoding, they can affect visible pixels as the image quality is decreased (made more lossy).

So far, the design choices we have made have been rather minor, affecting things such as performance or only small parts of the image. The next choice we are about to make will be way more significant.

### Quantization

The next step after performing the DCT is to perform _quantization_. Quantization is simple to _compute_ -- just divide each coefficient by the corresponding entry in a _quantization table_, and then round to the closest integer. Except... how do you pick a quantization table? The standard gives two _examples_ in Table K.1 and Table K.2, but it very explicitly does not specify how else it is possible to construct these.

In fact, this is a significant source of variability between different pieces of software, cameras, and other devices which output JPEG! Not only are there several patents in this area, the variability is notable enough to be [useful for forensics](https://dfrws.org/sites/default/files/session-files/2008_USA_pres-using_jpeg_quantization_tables_to_identify_imagery_processed_by_software.pdf)!

The values in the quantization tables strongly affect the trade-off between quality and output size. In fact, software which offers a "quality" slider typically implements it as different choices or rescaling of the quantization tables.
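
As a rough sketch of what such a slider might do, the following rescales a base table by a percentage derived from a 1 to 100 quality value. The formula follows the convention that, as far as I am aware, the IJG libjpeg library uses; the function names are my own, and other software is free to do something entirely different.

```python
# Assumption: this mirrors the IJG-style quality mapping; names are hypothetical
def quality_to_scale_percent(quality):
    quality = max(1, min(100, quality))
    if quality < 50:
        return 5000 // quality
    return 200 - 2 * quality

def scale_quant_table(base_table, quality):
    scale = quality_to_scale_percent(quality)
    scaled = []
    for q in base_table:
        val = (q * scale + 50) // 100
        # entries must stay within 1..255 to fit in an 8-bit table
        scaled.append(max(1, min(255, val)))
    return scaled

# first row of the Annex K example luma table
base_row = [16, 11, 10, 16, 24, 40, 51, 61]
for quality in (10, 50, 90, 100):
    print(quality, scale_quant_table(base_row, quality))
```

Under this mapping, quality 50 leaves the base table unchanged and quality 100 collapses every entry to 1.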

There is also the choice of _how many_ quantization tables to use. The standard allows for up to four tables to be used at once, but implementations can choose to use fewer. Many implementations use one table for luma and a second table shared by both chroma components (our original image does this). It is also possible to use different tables for each chroma component.

For this example, we will build it to support using three separate tables (one for each component), but we will start by hardcoding these tables to all 1s. We will tweak them later to see how they affect the image.

In [42]:
def quantize_image(mcus, tables):
    assert len(tables) == 3, "need three quantization tables"
    for i in range(len(mcus)):
        for component_i in range(3):
            for coeff_i in range(64):
                mcus[i][component_i][coeff_i] = round(mcus[i][component_i][coeff_i] / tables[component_i][coeff_i])

In [43]:
encode_quant_tables = [[1] * 64, [1] * 64, [1] * 64]
quantize_image(to_encode_mcus, encode_quant_tables)

### Miscellaneous operations

At this point, we need to perform the rest of the "miscellaneous" operations described in the overview, namely implementing the DC value predictor (a subtraction) and run-length encoding the zeros in the AC coefficients. Unlike when decoding, the way in which we have structured our code makes it straightforward to do this as a distinct step (we have not yet interleaved all of the resulting bits).
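
To make the run-length step concrete, here is a simplified, standalone sketch. It only produces `(run, value)` pairs and does not build the tokens used by the real code; the name `rle_ac` is hypothetical. A run of 16 zeros becomes the special pair `(15, 0)` (ZRL), and trailing zeros become a single `(0, 0)` (EOB).

```python
# Hypothetical standalone sketch of the zero run-length step for 63 AC coefficients
def rle_ac(ac_coeffs):
    # find the index just past the last nonzero coefficient
    end = len(ac_coeffs)
    while end > 0 and ac_coeffs[end - 1] == 0:
        end -= 1

    pairs = []
    run = 0
    for coeff in ac_coeffs[:end]:
        if coeff == 0:
            run += 1
            if run == 16:
                # ZRL: stands for 16 zeros in a row
                pairs.append((15, 0))
                run = 0
        else:
            pairs.append((run, coeff))
            run = 0

    # EOB is only needed if there were trailing zeros
    if end < len(ac_coeffs):
        pairs.append((0, 0))
    return pairs

example_ac = [5, 0, 0, -3] + [0] * 20 + [1] + [0] * 38
print(rle_ac(example_ac))
```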

We will also prepare the magnitude encoding for coefficients so that we have a list of symbols (to be entropy coded) followed by raw bits (to be appended directly).
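
The magnitude encoding can be sanity-checked with a round trip. The following is a sketch under my own naming (`encode_magnitude` and `decode_magnitude` are hypothetical and separate from the notebook's functions): the encoder computes the bit count and raw bits, and the decoder inverts them by looking at the leading bit.

```python
# Hypothetical standalone helpers for illustrating magnitude coding
def encode_magnitude(coeff):
    # number of bits needed to represent |coeff|
    ssss = abs(coeff).bit_length()
    if coeff >= 0:
        bits = coeff
    else:
        # negative values are stored as (coeff - 1), truncated to ssss bits
        bits = (coeff - 1) & ((1 << ssss) - 1)
    return (ssss, bits)

def decode_magnitude(ssss, bits):
    if ssss == 0:
        return 0
    # a leading 1 bit means positive, a leading 0 bit means negative
    if bits >> (ssss - 1):
        return bits
    return bits - (1 << ssss) + 1

for coeff in (0, 1, -1, 5, -5, 1023, -1023):
    (ssss, bits) = encode_magnitude(coeff)
    print(coeff, ssss, format(bits, "b").zfill(ssss) if ssss else "")
```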

In this demonstration, we will not support using a restart interval. All MCUs will be coded in one go.

In [44]:
# returns (ssss, bits)
def encode_coeff_mag(coeff):
    if coeff == 0:
        return (0, 0)

    for i in range(1, 12):
        mask = (1 << i) - 1
        if abs(coeff) <= mask:
            if coeff > 0:
                return (i, coeff & mask)
            else:
                return (i, (coeff - 1) & mask)
    
    assert False, "coefficient out of range"

In [45]:
# every token is a symbol (which will be Huffman coded), followed by a string of _nbits_ bits
JpegEncodingToken = namedtuple('JpegEncodingToken', ['sym', 'bits', 'nbits'])

def encode_misc_ops(mcus):
    # mcus is a list with one entry per MCU, each holding three arrays of 64 coefficients
    # [[mcu0_y, mcu0_cb, mcu0_cr], [mcu1_y, mcu1_cb, mcu1_cr], ...]

    # output will be a 3-array of arrays of tuples, where the second element is an array
    # (toks_y, toks_cb, toks_cr)
    # where each toks_X is
    # [(dc0, [ac0_0, ac0_1, ...]), (dc1, [ac1_0, ac1_1, ...]), ...]
    # each one will be a JpegEncodingToken
    output = [[], [], []]

    for mcu_i in range(len(mcus)):
        for component_i in range(3):
            # get previous DC coefficient
            if mcu_i == 0:
                dc_pred = 0
            else:
                dc_pred = mcus[mcu_i - 1][component_i][0]
            
            # this block's coefficients
            coeffs = mcus[mcu_i][component_i]

            # encode DC coefficient
            dc_diff = coeffs[0] - dc_pred
            (dc_ssss, dc_bits) = encode_coeff_mag(dc_diff)
            dc_tok = JpegEncodingToken(dc_ssss, dc_bits, dc_ssss)

            # RLE AC coefficients
            ac_toks = []
            ac_i = 1
            while ac_i < 64:
                # if this is all 0s, signal EOB
                is_eob = True
                for i in range(ac_i, 64):
                    if coeffs[i] != 0:
                        is_eob = False
                        break
                if is_eob:
                    ac_toks.append(JpegEncodingToken(0, 0, 0))
                    break
                
                # otherwise count the 0s, up to a max of 15
                num_zeros = 0
                while coeffs[ac_i] == 0 and num_zeros < 15:
                    ac_i += 1
                    num_zeros += 1
                
                # encode the zeros, as well as the next coefficient (which might still be 0)
                (ac_ssss, ac_bits) = encode_coeff_mag(coeffs[ac_i])
                assert ac_ssss in range(0, 11), "AC coefficient out of range"
                ac_rs = (num_zeros << 4) | ac_ssss
                ac_toks.append(JpegEncodingToken(ac_rs, ac_bits, ac_ssss))
                ac_i += 1

            output[component_i].append((dc_tok, ac_toks))
    
    return output

In [46]:
to_encode_toks = encode_misc_ops(to_encode_mcus)

### Faking entropy coding

The final step before we can start outputting bits is to perform Huffman compression. However, for expository purposes, we will start by intentionally constructing a very bad code without using the Huffman coding technique at all. This will help show that *any* algorithm which generates codes of the correct form can be used.

As a design choice, we will make the luma component use one pair of tables while the chroma components share the other pair of tables. Baseline JPEG allows only two pairs of (DC, AC) tables in total, so it isn't possible to give each chroma component separate tables (it *is* possible under _extended DCT processes_ though).

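Specifically, we will encode every possible DC coefficient scale value using 4 bits and every possible AC coefficient RS value using 8 bits. Fixed-length codes like this are always prefix-free. As long as we assign the codewords sequentially, the resulting code is compatible with the code table representation that JPEG uses. (Because there are fewer than 16 DC symbols and fewer than 256 AC symbols in use, the sequential assignment never reaches a codeword consisting of all 1s, which JPEG does not allow.)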
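
The prefix-free claim is easy to check by brute force. This is only an illustrative sketch (the helper `is_prefix_free` is hypothetical), representing codewords as strings of `0` and `1`:

```python
# Hypothetical brute-force check: no codeword may be a prefix of a different codeword
def is_prefix_free(codewords):
    for a in codewords:
        for b in codewords:
            if a != b and b.startswith(a):
                return False
    return True

# 12 DC symbols, each given a sequential 4-bit codeword
dc_codewords = [format(sym, "04b") for sym in range(12)]
# 176 AC symbols (16 run lengths x 11 size values), each given a sequential 8-bit codeword
ac_codewords = [format(i, "08b") for i in range(16 * 11)]

print(is_prefix_free(dc_codewords), is_prefix_free(ac_codewords))
# a variable-length counterexample: "0" is a prefix of "01"
print(is_prefix_free(["0", "01"]))
```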

In [47]:
def make_bad_huff_tables():
    dc_table = []
    for sym in range(0, 12):
        dc_table.append((sym, 4))

    ac_table = []
    codeword = 0
    for rs in range(256):
        ssss = rs & 0xf
        if ssss > 10:
            ac_table.append(None)
        else:
            ac_table.append((codeword, 8))
            codeword += 1

    # we will allow for using separate luma/chroma huffman tables, but we don't actually use that here
    dc_tables = [dc_table, dc_table]
    ac_tables = [ac_table, ac_table]
    return (dc_tables, ac_tables)

def make_bad_huff_tables_dht():
    # data needed for the DHT segment (codes for a given length)
    dc_syms_per_len = []
    for _i in range(16):
        dc_syms_per_len.append([])
    for sym in range(0, 12):
        dc_syms_per_len[3].append(sym)

    ac_syms_per_len = []
    for _i in range(16):
        ac_syms_per_len.append([])
    for rs in range(256):
        ssss = rs & 0xf
        if ssss > 10:
            continue
        ac_syms_per_len[7].append(rs)
    
    dc_tables = [dc_syms_per_len, dc_syms_per_len]
    ac_tables = [ac_syms_per_len, ac_syms_per_len]
    return (dc_tables, ac_tables)

In [48]:
encode_huff_tables = make_bad_huff_tables()
encode_huff_tables_for_dht = make_bad_huff_tables_dht()

The following code helps pack a bitstream into bytes:

In [49]:
# helpers for packing bits
def add_bit(wip, bit):
    (wip_bytes, bits_in_last_byte) = wip

    if bits_in_last_byte == 8:
        wip_bytes.append(bit << 7)
        bits_in_last_byte = 1
    else:
        wip_bytes[-1] |= bit << (7 - bits_in_last_byte)
        bits_in_last_byte += 1
    
    return (wip_bytes, bits_in_last_byte)

def add_bits(wip, bits, nbits):
    for biti in range(nbits):
        b = 1 if bits & (1 << (nbits - 1 - biti)) else 0
        wip = add_bit(wip, b)
    return wip

def pad_out(wip):
    for _i in range(wip[1], 8):
        wip = add_bit(wip, 1)
    return wip

The following code performs Huffman coding, pads the result out to a whole number of bytes, and then applies byte stuffing:

In [50]:
def encode_huffman(toks, dc_tables, ac_tables):
    wip = ([], 8)

    # debugging sanity check
    assert len(toks) == 3
    assert len(toks[0]) == len(toks[1])
    assert len(toks[0]) == len(toks[2])

    for tok_i in range(len(toks[0])):
        for component_i in range(3):
            # select the right tables given the component
            if component_i == 0:
                dc_table = dc_tables[0]
                ac_table = ac_tables[0]
            else:
                dc_table = dc_tables[1]
                ac_table = ac_tables[1]
            
            # get the tokens to be encoded
            (dc, ac) = toks[component_i][tok_i]

            # DC first
            dc_scale = dc_table[dc.sym]
            wip = add_bits(wip, *dc_scale)
            wip = add_bits(wip, dc.bits, dc.nbits)

            # now AC
            for ac_i in ac:
                ac_rs = ac_table[ac_i.sym]
                wip = add_bits(wip, *ac_rs)
                wip = add_bits(wip, ac_i.bits, ac_i.nbits)
    
    # pad out to a full byte
    wip = pad_out(wip)

    # perform byte stuffing (F.1.2.3)
    ret = b''
    for b in wip[0]:
        ret += bytes([b])
        if b == 0xff:
            ret += b'\x00'
    return ret
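
To see padding and byte stuffing in isolation, here is a compact standalone sketch. It uses a string of bits for clarity rather than the incremental helpers above, and the name `pack_and_stuff` is hypothetical.

```python
# Hypothetical standalone sketch: pack (value, nbits) fields, pad with 1s, stuff 0xff bytes
def pack_and_stuff(fields):
    # each field is written most significant bit first; values are assumed to fit in nbits
    bitstring = "".join(format(value, f"0{nbits}b") for (value, nbits) in fields if nbits > 0)
    # pad the final partial byte with 1 bits
    bitstring += "1" * (-len(bitstring) % 8)

    out = bytearray()
    for i in range(0, len(bitstring), 8):
        byte = int(bitstring[i:i + 8], 2)
        out.append(byte)
        # a 0xff data byte is followed by 0x00 so it cannot be mistaken for a marker
        if byte == 0xff:
            out.append(0x00)
    return bytes(out)

print(pack_and_stuff([(0b101, 3), (0b1, 1)]).hex())
print(pack_and_stuff([(0xff, 8), (0b0, 1)]).hex())
```

In the second call, the first byte happens to be `0xff`, so a `0x00` is inserted after it before the padded final byte.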

We can finally encode the main bitstream containing most of the image data:

In [None]:
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")

At this point, we just have to write various file headers and we will have a complete JPEG!

In [52]:
def pack_final_jpeg(bitstream, w, h, quant_tables, huff_tables):
    ret = b''

    # SOI
    ret += b'\xff\xd8'

    # DQT
    assert len(quant_tables) == 3, "expected 3 quantization tables"
    ret += struct.pack(">HH", 0xffdb, 65 * 3 + 2)
    for i in range(3):
        ret += bytes([i])
        ret += bytes(quant_tables[i])

    # DHT
    dht = b''
    for Tc in range(2):
        for Th in range(2):
            dht += bytes([(Tc << 4) | Th])
            t = huff_tables[Tc][Th]
            for len_ in range(16):
                dht += bytes([len(t[len_])])
            for len_ in range(16):
                dht += bytes(t[len_])
    ret += struct.pack(">HH", 0xffc4, len(dht) + 2)
    ret += dht

    # SOF
    ret += struct.pack(">HHBHHBBBBBBBBBB", 0xffc0, 17, 8, h, w, 3, 1, 0x11, 0, 2, 0x11, 1, 3, 0x11, 2)

    # SOS
    ret += struct.pack(">HHBBBBBBBBBB", 0xffda, 12, 3, 1, 0x00, 2, 0x11, 3, 0x11, 0, 63, 0)

    ret += bitstream
    
    # EOI
    ret += b'\xff\xd9'
    
    return ret

In [None]:
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("our_own_encode.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

Well, we've made a JPEG file! It's... larger than the original PNG.

![our result](our_own_encode.jpg)
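
If the file fails to open, one way to sanity-check the headers without a full decoder is to walk the marker segments and compare their lengths against expectations. The following is a minimal sketch (the function name `list_jpeg_segments` and the synthetic test data are made up for illustration):

```python
import struct

# Hypothetical helper: list (marker, payload_length) for each segment up to and including SOS
def list_jpeg_segments(data):
    assert data[0:2] == b"\xff\xd8", "missing SOI"
    segments = []
    pos = 2
    while pos + 4 <= len(data):
        (marker, length) = struct.unpack(">HH", data[pos:pos + 4])
        assert marker >> 8 == 0xff, "expected a marker"
        # the length field counts itself (2 bytes) but not the marker
        segments.append((marker, length - 2))
        pos += 2 + length
        if marker == 0xffda:
            # entropy-coded data follows SOS, so stop here
            break
    return segments

# synthetic stream: SOI, one 65-byte table segment, an SOS header, two data bytes, EOI
sample = (
    b"\xff\xd8"
    + struct.pack(">HH", 0xffdb, 67) + bytes(65)
    + struct.pack(">HH", 0xffda, 12) + bytes(10)
    + b"\x12\x34\xff\xd9"
)
for (marker, payload_len) in list_jpeg_segments(sample):
    print(f"0x{marker:04x} {payload_len}")
```

Running this on `our_encoded_jpeg` should list the DQT, DHT, SOF0, and SOS segments in that order.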

## Entropy coding a bit more seriously

One immediate improvement we can make is to the Huffman coding step. Annex K gives both an _example_ of an algorithm which can be used to compute a Huffman code and an example precomputed code. Let's see how well the precomputed code does.

In [54]:
# given a list of values at each length, assign them codewords systematically
def reassign_huff_codes(syms_per_len, print_dbg=False):
    table = [None] * 256
    min_code_len = 0
    for (i, syms) in enumerate(syms_per_len):
        if len(syms) != 0:
            min_code_len = i + 1
            break
    assert min_code_len > 0, "no codewords!"

    code_wip = 0
    for code_len in range(min_code_len, 17):
        code_vals_of_len = syms_per_len[code_len - 1]
        for code_val in code_vals_of_len:
            if print_dbg:
                print(f"{code_wip:0{code_len}b} = 0x{code_val:02x}")
            # skip 256 (used for reserving the all-1s code)
            if code_val != 256:
                table[code_val] = (code_wip, code_len)
            code_wip += 1
        code_wip <<= 1
    
    return table
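
As a small worked example of this systematic assignment, the following simplified stand-in returns bit strings instead of `(code, length)` pairs (the name `canonical_codes` is hypothetical). It is applied to the list of symbols per length from the Annex K example luminance DC table:

```python
# Hypothetical simplified stand-in: assign sequential codewords, shortest lengths first
def canonical_codes(syms_per_len):
    codes = {}
    code = 0
    for (i, syms) in enumerate(syms_per_len):
        code_len = i + 1
        for sym in syms:
            codes[sym] = format(code, f"0{code_len}b")
            code += 1
        # moving to the next length appends a 0 bit to the running code
        code <<= 1
    return codes

dc_luma_example = [
    [], [0], [1, 2, 3, 4, 5], [6], [7], [8], [9], [10], [11],
    [], [], [], [], [], [], [],
]
example_codes = canonical_codes(dc_luma_example)
for sym in sorted(example_codes):
    print(sym, example_codes[sym])
```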

In [55]:
def make_annex_k_example_huff_tables():
    dc_luma_syms_per_len = [
        [],                 # 1
        [0],                # 2
        [1, 2, 3, 4, 5],    # 3
        [6],                # 4
        [7],                # 5
        [8],                # 6
        [9],                # 7
        [10],               # 8
        [11],               # 9
        [], [], [], [], [], [], [],
    ]
    dc_chroma_syms_per_len = [
        [],                 # 1
        [0, 1, 2],          # 2
        [3],                # 3
        [4],                # 4
        [5],                # 5
        [6],                # 6
        [7],                # 7
        [8],                # 8
        [9],                # 9
        [10],               # 10
        [11],               # 11
        [], [], [], [], [],
    ]
    ac_luma_syms_per_len = [
        [],
        [0x01, 0x02],
        [0x03],
        [0x00, 0x04, 0x11],
        [0x05, 0x12, 0x21],
        [0x31, 0x41],
        [0x06, 0x13, 0x51, 0x61],
        [0x07, 0x22, 0x71],
        [0x14, 0x32, 0x81, 0x91, 0xA1],
        [0x08, 0x23, 0x42, 0xB1, 0xC1],
        [0x15, 0x52, 0xD1, 0xF0],
        [0x24, 0x33, 0x62, 0x72],
        [],
        [],
        [0x82],
        [
            0x09, 0x0A, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x34, 0x35, 0x36,
            0x37, 0x38, 0x39, 0x3A, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, 0x53, 0x54, 0x55, 0x56,
            0x57, 0x58, 0x59, 0x5A, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x73, 0x74, 0x75, 0x76,
            0x77, 0x78, 0x79, 0x7A, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8A, 0x92, 0x93, 0x94, 0x95,
            0x96, 0x97, 0x98, 0x99, 0x9A, 0xA2, 0xA3, 0xA4, 0xA5, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xB2, 0xB3,
            0xB4, 0xB5, 0xB6, 0xB7, 0xB8, 0xB9, 0xBA, 0xC2, 0xC3, 0xC4, 0xC5, 0xC6, 0xC7, 0xC8, 0xC9, 0xCA,
            0xD2, 0xD3, 0xD4, 0xD5, 0xD6, 0xD7, 0xD8, 0xD9, 0xDA, 0xE1, 0xE2, 0xE3, 0xE4, 0xE5, 0xE6, 0xE7,
            0xE8, 0xE9, 0xEA, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA
        ]
    ]
    ac_chroma_syms_per_len = [
        [],
        [0x00, 0x01],
        [0x02],
        [0x03, 0x11],
        [0x04, 0x05, 0x21, 0x31],
        [0x06, 0x12, 0x41, 0x51],
        [0x07, 0x61, 0x71],
        [0x13, 0x22, 0x32, 0x81],
        [0x08, 0x14, 0x42, 0x91, 0xA1, 0xB1, 0xC1],
        [0x09, 0x23, 0x33, 0x52, 0xF0],
        [0x15, 0x62, 0x72, 0xD1],
        [0x0A, 0x16, 0x24, 0x34],
        [],
        [0xE1],
        [0x25, 0xF1],
        [
            0x17, 0x18, 0x19, 0x1A, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x43,
            0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4A, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5A, 0x63,
            0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6A, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7A, 0x82,
            0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8A, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99,
            0x9A, 0xA2, 0xA3, 0xA4, 0xA5, 0xA6, 0xA7, 0xA8, 0xA9, 0xAA, 0xB2, 0xB3, 0xB4, 0xB5, 0xB6, 0xB7,
            0xB8, 0xB9, 0xBA, 0xC2, 0xC3, 0xC4, 0xC5, 0xC6, 0xC7, 0xC8, 0xC9, 0xCA, 0xD2, 0xD3, 0xD4, 0xD5,
            0xD6, 0xD7, 0xD8, 0xD9, 0xDA, 0xE2, 0xE3, 0xE4, 0xE5, 0xE6, 0xE7, 0xE8, 0xE9, 0xEA, 0xF2, 0xF3,
            0xF4, 0xF5, 0xF6, 0xF7, 0xF8, 0xF9, 0xFA
        ]
    ]

    dc_luma_lut = reassign_huff_codes(dc_luma_syms_per_len)
    dc_chroma_lut = reassign_huff_codes(dc_chroma_syms_per_len)
    ac_luma_lut = reassign_huff_codes(ac_luma_syms_per_len)
    ac_chroma_lut = reassign_huff_codes(ac_chroma_syms_per_len)

    return (
        [
            [dc_luma_lut, dc_chroma_lut],
            [ac_luma_lut, ac_chroma_lut],
        ], [
            [dc_luma_syms_per_len, dc_chroma_syms_per_len],
            [ac_luma_syms_per_len, ac_chroma_syms_per_len],
        ]
    )

In [None]:
(encode_huff_tables, encode_huff_tables_for_dht) = make_annex_k_example_huff_tables()
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("our_own_encode_better_huff.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

Wow, a reduction of almost 36% compared to the initial attempt, and very slightly smaller than the original PNG too. This intuitively does make some sense: we have compressed the file with a quality of "100%" (which is slightly lossy due to rounding to the nearest integer but is otherwise maximum quality) and then applied some "vaguely similar if you squint at it from afar" lossless compression (RLE+Huffman in the case of JPEG and LZ77+Huffman in the case of PNG).

## Huffman coding, for real this time

At this point we are going to *finally* stop avoiding understanding the Huffman coding technique, in order to eke out a bit more compression.

As described in the explanation for decoding, we wish to construct a _prefix code_ which is, in some sense, "optimal." Huffman coding happens to be an algorithm which is optimal under the following assumptions:

* symbol-by-symbol coding
* where each symbol is _independently and identically distributed_
* and where the probability distribution is known

Let's think about _why_ these assumptions are important. Let's suppose we are trying to compress the data `abababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababababab` (composed of symbols `a` and `b`). The probability of each symbol is exactly 0.5, but the symbol probabilities are clearly not independent (the pattern is always `ab`, there is never `aa` nor `bb`). One optimal symbol-by-symbol prefix code for encoding this information is

| Symbol | Code |
| ------ | ---- |
| `a`    | 0    |
| `b`    | 1    |

If we use this code, the output will consist of 200 bits. But surely there are better ways to express this pattern? Something like "repeat 100 copies of `ab`" perhaps? It seems as if that should only require a handful of bytes? Something like "this is a repeat" (and not something else), "repeat 100 times", "repeat the following 2 bytes", `ab`. This feels like it should only use somewhere around 5 bytes or 40 bits, if each of those pieces of information takes up one byte? If you follow this chain of thinking, then _congratulations_, you have stumbled into _dictionary coders_ such as the LZ (Lempel-Ziv) family of algorithms!

One of the most popular _lossless_ data compression algorithms today is _DEFLATE_, which is built around the idea of using LZ77 and Huffman coding together. This is the compression used in ZIP files as well as PNG. However, we cannot use that idea here for JPEG (the spec simply doesn't allow it).
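
This is easy to try using the `zlib` module from the Python standard library, which implements DEFLATE. The sketch below assumes the example input is 100 repetitions of `ab`; the exact compressed size depends on the zlib version and includes a few bytes of container overhead, so no specific number is claimed here.

```python
import zlib

data = b"ab" * 100

# one bit per symbol with the code from the table above
symbol_by_symbol_bytes = len(data) // 8

# DEFLATE (LZ77 + Huffman), wrapped in the zlib container format
deflated = zlib.compress(data, 9)

print(symbol_by_symbol_bytes, len(deflated))
```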

Returning to the code we are constructing here, the key idea that David A. Huffman came up with was that the _expected value_ of the lengths of the codewords can be minimized (becoming optimal under our assumptions) by constructing a particular _binary tree_ from the bottom up. However, we cannot directly use Huffman's algorithm here, because:

* JPEG codewords have a maximum length
* Codewords consisting of all-1s cannot be used

The last issue is easy to deal with, and Annex K of the specification suggests a trick: add a dummy symbol (which will never be used) with the lowest possible probability. This dummy symbol will always end up with one of the longest codewords. If codewords are assigned in _canonical_ order, we can easily make sure that, *if* an all-1s codeword were to get used, it would be assigned to this dummy symbol. We can then just delete the dummy symbol and not use it, and then the all-1s codeword will never be used.
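
Here is a tiny illustration of the dummy symbol trick, with made-up code lengths for three real symbols plus a dummy. The helper `canonical_codes_from_lengths` is hypothetical and assumes the dummy is listed last among the symbols of the longest length.

```python
# Hypothetical helper: canonical assignment from a dict of symbol -> codeword length
def canonical_codes_from_lengths(lengths):
    # stable sort, so symbols with equal lengths keep their insertion order
    ordered = sorted(lengths.items(), key=lambda item: item[1])
    codes = {}
    code = 0
    prev_len = ordered[0][1]
    for (sym, code_len) in ordered:
        code <<= (code_len - prev_len)
        codes[sym] = format(code, f"0{code_len}b")
        code += 1
        prev_len = code_len
    return codes

# made-up lengths; the dummy symbol shares the longest length
with_dummy = canonical_codes_from_lengths({"a": 1, "b": 2, "c": 3, "dummy": 3})
print(with_dummy)

# drop the dummy: the all-1s codeword disappears with it
without_dummy = {sym: code for (sym, code) in with_dummy.items() if sym != "dummy"}
print(without_dummy)
```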

The first issue is harder to deal with. Annex K suggests an algorithm which adjusts the output of the normal Huffman coding algorithm. It makes modifications to the code to replace longer codewords with shorter ones. However, it is not clear if this algorithm continues to produce optimal codes or _why_ it would continue to do so if it does.
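
To see that the length limit is a real concern and not just a theoretical one, here is a sketch of the plain (unlimited) Huffman algorithm using `heapq`, applied to deliberately skewed, Fibonacci-like frequencies. The function name and the input data are made up for this demonstration.

```python
import heapq

# Hypothetical sketch of the classic Huffman algorithm, returning only codeword lengths
def huffman_code_lengths(freqs):
    # heap entries are (weight, tie_breaker, symbols_in_subtree)
    heap = [(freq, sym, [sym]) for (sym, freq) in enumerate(freqs)]
    heapq.heapify(heap)
    lengths = [0] * len(freqs)
    tie_breaker = len(freqs)
    while len(heap) > 1:
        (w0, _, syms0) = heapq.heappop(heap)
        (w1, _, syms1) = heapq.heappop(heap)
        # every symbol under the merged node becomes one bit longer
        for sym in syms0 + syms1:
            lengths[sym] += 1
        heapq.heappush(heap, (w0 + w1, tie_breaker, syms0 + syms1))
        tie_breaker += 1
    return lengths

# Fibonacci-like frequencies: 1, 1, 2, 3, 5, ... (20 symbols)
fib_freqs = [1, 1]
while len(fib_freqs) < 20:
    fib_freqs.append(fib_freqs[-1] + fib_freqs[-2])

fib_lengths = huffman_code_lengths(fib_freqs)
print(max(fib_lengths))
```

With only 20 symbols, the longest codeword already exceeds the 16 bits that JPEG allows.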

We will not be doing that. We will instead use a generalization called the [package-merge algorithm](https://en.wikipedia.org/wiki/Package-merge_algorithm), which is more complex but guaranteed to be optimal (under our assumptions). I don't know why the standard does not suggest this algorithm, but the package-merge algorithm was first published in 1990, only two years before the JPEG standard was released. It's possible that the algorithm wasn't around long enough to make it into the standard. If anybody knows for sure, feel free to chime in!

The following is a direct translation of the Wikipedia explanation of the package-merge algorithm into code:

In [57]:
def compute_one_optimal_huff_code(sym_freqs):
    # each coin/package will be stored as a tuple (total value/weight, [symbol, symbol, ...])

    # prepare a "master" list of "coins" for each symbol, each with value/weight corresponding to its frequency
    # this will be reused for each "denomination"/width (for a total of L coins for each symbol)
    master_coin_list = []
    num_syms = 0
    for sym in range(len(sym_freqs)):
        freq = sym_freqs[sym]
        if freq:
            num_syms += 1
            master_coin_list.append((freq, [sym]))
    assert num_syms, "no syms!"
    # sort by total value
    master_coin_list.sort(key=lambda x: x[0])

    # this will be the output from the previous iteration to merge in
    packaged = []

    # go through "denominations"/widths from smallest up to largest
    for denom_i in reversed(range(1, 17)):
        # merge in
        package_merge_iter_input = sorted(packaged + master_coin_list, key=lambda x: x[0])

        # package in pairs
        packaged = []
        for i in range(len(package_merge_iter_input) // 2):
            c0 = package_merge_iter_input[i * 2]
            c1 = package_merge_iter_input[i * 2 + 1]
            cout = (c0[0] + c1[0], c0[1] + c1[1])
            packaged.append(cout)
    
    # at this point, the "denomination" is 1
    # grab the top n-1
    selected_result_coins = packaged[:num_syms - 1]

    # count how many "coins" were used for each symbol
    sym_coins = [0] * len(sym_freqs)
    for (_weight, syms) in selected_result_coins:
        for sym in syms:
            sym_coins[sym] += 1

    # shuffle into a list of symbols per bit length
    syms_per_len = []
    for _i in range(16):
        syms_per_len.append([])
    # skip the dummy symbol
    for sym in range(len(sym_freqs) - 1):
        coins = sym_coins[sym]
        # sanity checks for debugging
        if coins == 0:
            assert sym_freqs[sym] == 0
        if sym_freqs[sym] == 0:
            assert coins == 0
        else:
            # sym is actually used, and it is encoded with a bit string of length `coins`
            syms_per_len[coins - 1].append(sym)
    
    return syms_per_len

def compute_optimal_huffman_codes(toks):
    # initialize empty frequency counts
    dc_freqs = [[0] * 13, [0] * 13]
    ac_freqs = [[0] * 257, [0] * 257]
    # add a count for the dummy symbol needed to avoid all-1s
    dc_freqs[0][-1] = 1
    dc_freqs[1][-1] = 1
    ac_freqs[0][-1] = 1
    ac_freqs[1][-1] = 1

    # go through the tokens, get the actual frequencies
    for tok_i in range(len(toks[0])):
        for component_i in range(3):
            # select the right tables given the component
            if component_i == 0:
                luma_chroma = 0
            else:
                luma_chroma = 1
            
            # get the tokens to be encoded
            (dc, ac) = toks[component_i][tok_i]

            # DC first
            assert dc.sym in range(0, 12)
            dc_freqs[luma_chroma][dc.sym] += 1

            # now AC
            for ac_i in ac:
                assert ac_i.sym in range(0, 256)
                ac_freqs[luma_chroma][ac_i.sym] += 1

    # compute the huffman code lengths
    dc_luma_syms_per_len = compute_one_optimal_huff_code(dc_freqs[0])
    dc_chroma_syms_per_len = compute_one_optimal_huff_code(dc_freqs[1])
    ac_luma_syms_per_len = compute_one_optimal_huff_code(ac_freqs[0])
    ac_chroma_syms_per_len = compute_one_optimal_huff_code(ac_freqs[1])

    # compute the canonical code values
    dc_luma_lut = reassign_huff_codes(dc_luma_syms_per_len, True)
    dc_chroma_lut = reassign_huff_codes(dc_chroma_syms_per_len, True)
    ac_luma_lut = reassign_huff_codes(ac_luma_syms_per_len, True)
    ac_chroma_lut = reassign_huff_codes(ac_chroma_syms_per_len, True)

    return (
        [
            [dc_luma_lut, dc_chroma_lut],
            [ac_luma_lut, ac_chroma_lut],
        ], [
            [dc_luma_syms_per_len, dc_chroma_syms_per_len],
            [ac_luma_syms_per_len, ac_chroma_syms_per_len],
        ]
    )

In [None]:
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("our_own_encode_better_huff_2.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

Wow, better tables save us over 2 KiB! This is a reduction of almost 40% against the initial "not compressing at all" attempt, and a 6% improvement over the example tables given in Annex K. This really goes to show that Huffman codes are built around _assumptions_ about probability distributions. Because our example image doesn't match the Annex K example's assumed distribution exactly, we were able to gain some noticeable improvement by building our own code.
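
The effect of matching a code to the data can be shown with a toy calculation. The frequencies and code lengths below are invented for illustration and have nothing to do with our actual image:

```python
# total output size is just frequency times codeword length, summed over all symbols
def total_bits(freqs, code_lengths):
    return sum(freq * length for (freq, length) in zip(freqs, code_lengths))

# hypothetical symbol counts for a 4-symbol alphabet
observed_freqs = [70, 20, 5, 5]

# a code that would suit a uniform distribution (every codeword is 2 bits)
uniform_lengths = [2, 2, 2, 2]
# the lengths Huffman's algorithm gives for the skewed counts above
matched_lengths = [1, 2, 3, 3]
# a code built for the opposite skew
mismatched_lengths = [3, 3, 2, 1]

print(total_bits(observed_freqs, uniform_lengths))
print(total_bits(observed_freqs, matched_lengths))
print(total_bits(observed_freqs, mismatched_lengths))
```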

## Trading off size and quality

All of the tweaking we have been doing so far has been on the _lossless_ side of the JPEG compression. We have not actually done any work on the _lossy_ side of the compression algorithm, and we are still using the all-1s quantization table.

Let's blindly copy the examples from Annex K (even though they are not a perfect match for our image, as we haven't used 2:1 horizontal subsampling like they assume) just to see what happens.

In [59]:
ZIGZAG_TO_IDCT_BASIS_IDX = [
    0,
    1, 8,
    16, 9, 2,
    3, 10, 17, 24,
    32, 25, 18, 11, 4,
    5, 12, 19, 26, 33, 40,
    48, 41, 34, 27, 20, 13, 6,
    7, 14, 21, 28, 35, 42, 49, 56,
    57, 50, 43, 36, 29, 22, 15,
    23, 30, 37, 44, 51, 58,
    59, 52, 45, 38, 31,
    39, 46, 53, 60,
    61, 54, 47,
    55, 62,
    63,
]

In [60]:
def make_annex_k_example_quant_tables():
    luma = [
        16, 11, 10, 16, 24, 40, 51, 61,
        12, 12, 14, 19, 26, 58, 60, 55,
        14, 13, 16, 24, 40, 57, 69, 56,
        14, 17, 22, 29, 51, 87, 80, 62,
        18, 22, 37, 56, 68, 109, 103, 77,
        24, 35, 55, 64, 81, 104, 113, 92,
        49, 64, 78, 87, 103, 121, 120, 101,
        72, 92, 95, 98, 112, 100, 103, 99,
    ]
    chroma = [
        17, 18, 24, 47, 99, 99, 99, 99,
        18, 21, 26, 66, 99, 99, 99, 99,
        24, 26, 56, 99, 99, 99, 99, 99,
        47, 66, 99, 99, 99, 99, 99, 99,
        99, 99, 99, 99, 99, 99, 99, 99,
        99, 99, 99, 99, 99, 99, 99, 99,
        99, 99, 99, 99, 99, 99, 99, 99,
        99, 99, 99, 99, 99, 99, 99, 99,
    ]
    luma_zz = [luma[ZIGZAG_TO_IDCT_BASIS_IDX[i]] for i in range(64)]
    chroma_zz = [chroma[ZIGZAG_TO_IDCT_BASIS_IDX[i]] for i in range(64)]
    return [luma_zz, chroma_zz, chroma_zz]

In [61]:
encode_quant_tables = make_annex_k_example_quant_tables()

In [None]:
to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)
quantize_image(to_encode_mcus, encode_quant_tables)
to_encode_toks = encode_misc_ops(to_encode_mcus)
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("our_own_encode_annex_k_quant.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

Wow, this is a size reduction of more than 80% against our previous best! However, as expected from _lossy_ compression, zooming in on the result will show _compression artifacts_ on the outline of Tux.

![annex k quantization](our_own_encode_annex_k_quant.jpg)

So, what _exactly_ have we done here?

Remember that the quantization operation is a division followed by rounding to the nearest integer. For example, entry 00 of the luma quantization table is 16. This means that every coefficient strictly between -8 and 8 turns into 0 (because anything smaller in magnitude than 8/16 = 0.5 rounds to 0; what happens at exactly 0.5 depends on the tie-breaking rule). Likewise, values in (-24, -8) and (8, 24) turn into -1 and +1 respectively, values in (-40, -24) and (24, 40) turn into -2 and +2, and so on.
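
As a minimal sketch of this arithmetic (the helpers `quantize` and `dequantize` are made up for this illustration and are not the notebook's `quantize_image`), dividing by a table entry of 16 and rounding collapses whole ranges of values onto a single integer, and multiplying back cannot recover the difference:

```python
def quantize(value, step):
    # divide by the quantization table entry, then round to the nearest integer
    # (Python's round() sends exact ties to the nearest even integer)
    return round(value / step)

def dequantize(level, step):
    # the decoder can only multiply back; the remainder is lost
    return level * step

step = 16
samples = [-25, -23, -9, -7, 0, 7, 9, 23, 25]
levels = [quantize(v, step) for v in samples]
restored = [dequantize(q, step) for q in levels]
for v, q, r in zip(samples, levels, restored):
    print(f"{v:4d} -> {q:2d} -> {r:4d}")
```

Every sample between -7 and 7 comes back as 0, while 9 and 23 both come back as 16.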

As values in the quantization table get larger, the range of values which shrink down into each integer gets wider, and the amount of information which gets "thrown away" increases. We can try pushing this to an extreme.

In [None]:
# horrible, horrible quality
encode_quant_tables = [[150] * 64] * 3

to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)
quantize_image(to_encode_mcus, encode_quant_tables)
to_encode_toks = encode_misc_ops(to_encode_mcus)
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("our_own_encode_horrible_quant.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

We've gained an *additional* >60% size improvement! We've reduced the original PNG image by almost 94%! The file is now only about 2 KiB!

But, as a tradeoff:

![low quality quantization](our_own_encode_horrible_quant.jpg)

_\*mmm\*_ crunchy!

### Controllable quality

As explained previously, _how_ to derive "good" quantization tables is complicated and often subjective. Software derived from or inspired by the [Independent JPEG Group](https://ijg.org/)'s code often derives tables by rescaling the Annex K example tables depending on the desired quality as can be seen [here](https://github.com/libjpeg-turbo/libjpeg-turbo/blob/e0e18dea5433e600ea92d60814f13efa40a0d7dd/src/jcparam.c#L132) in libjpeg-turbo. Implementing this will be left as an exercise for the reader.
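
For readers who would like a head start on that exercise, here is one possible sketch. It follows the quality-to-scale mapping commonly described for IJG-derived code (quality 50 leaves the table unchanged); the function name `scale_quant_table` and the clamp to 8-bit entries are assumptions of this sketch, so compare against the linked source before relying on it:

```python
def scale_quant_table(base_table, quality):
    # clamp quality to the 1..100 range
    quality = max(1, min(100, quality))
    # map quality to a percentage scale factor (quality 50 -> 100%)
    if quality < 50:
        scale = 5000 // quality
    else:
        scale = 200 - quality * 2
    scaled = []
    for entry in base_table:
        val = (entry * scale + 50) // 100
        # assumption: 8-bit table entries, and 0 is not a usable divisor
        scaled.append(max(1, min(255, val)))
    return scaled

print(scale_quant_table([16, 11, 10, 99], 75))
print(scale_quant_table([16, 11, 10, 99], 25))
```

Because the scaling is applied entry by entry, it does not matter whether the table is in natural or zig-zag order when it is scaled.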

## Altering the DCT, using linear algebra

Now that we have a working (even if computationally-focused) implementation, we will explore what the tools found in _linear algebra_ can help us understand about the DCT.

If we have a linear function where the inputs and outputs are both finite-dimensional, we can choose to represent the function using a _matrix_. A matrix is simply a two-dimensional array of numbers along with specific procedures for performing calculations (just like how we have procedures for doing arithmetic with everyday base-10 Arabic numerals). Notably, there is a procedure for adding two matrices of the same size (add each corresponding element) and for multiplying two matrices with appropriately-matching sizes (adding up products of the appropriate elements; this procedure can be found in any introductory linear algebra reference).

The DCT (both forward and inverse) takes in a set of 64 numbers (which is finite), outputs 64 numbers (again, finite), and is linear. We can therefore represent the forward and inverse transforms as 64 x 64 matrices. Showing this symbolically is once again very procedural, definition-heavy, and not particularly enlightening. It mostly consists of yet more rearranging of calculations. However, with the rearranging we did earlier in our code, the way in which we compute `idct_basis` and `fdct_basis` already *is* a matrix representation of the transforms:

In [None]:
def demo_dct_as_matrix():
    print(f"idct_basis is a list of {len(idct_basis)} items")
    print(f"each item of idct_basis is a list of {len(idct_basis[0])} numbers")
    print()

    idct_basis_np = np.array(idct_basis)
    print(f"idct_basis as a NumPy array is: {idct_basis_np}")
    print(f"idct_basis dimensions are: {idct_basis_np.shape}")
    print()

    fdct_basis_np = np.array(fdct_basis)
    print(f"fdct_basis as a NumPy array is: {fdct_basis_np}")
    print(f"fdct_basis dimensions are: {fdct_basis_np.shape}")
    print()
demo_dct_as_matrix()

What is the point of doing this? A common tool used in mathematics is to start by giving names to particular objects. Once that is done, we often hope to find useful properties that apply more generally across all objects of a similar type (rather than just the specific object in front of us). These properties might not be so easily found without having "taken a step back" and looking from a different perspective. This is very similar to the power of using _abstractions_ when programming.

One example of this type of more-general property is that a _square_ matrix (one where the two dimensions are the same) has an _inverse_ if and only if the value of its _determinant_ is nonzero. The determinant is a function which turns a square matrix into one single value (a _scalar_) according to a particular formula.

We can check the determinant of one of our DCT matrices by asking NumPy to compute it:

In [None]:
def demo_idct_det():
    idct_np_array = np.array(idct_basis)
    print(f"The determinant of the IDCT as a matrix is: {np.linalg.det(idct_np_array)}")
demo_idct_det()

Since the determinant is nonzero, we know that the IDCT has an inverse, i.e. a function which we can apply before or after the IDCT in order to get the input we started with. We were of course already assuming this, and we also "know" that the inverse of the IDCT is the FDCT, but it is good to see that the theory checks out.

In order to be _even more_ sure, we can try to actually combine together the FDCT and the IDCT to see what we get. In order to do this, we need a few more general properties about matrix arithmetic.

### Some properties of matrix arithmetic

When we apply two matrices to a given input, one way we can compute it is as follows:

Suppose we want to find $ABx$, where $A$ and $B$ are matrices and $x$ is the input as a _vector_ (which in this case can be thought of as just an $n \times 1$ matrix). We can apply $B$ first, get an intermediate result, and then apply $A$ to that result:

```text
⠀      _____  Bx  _____
x ---> | B | ---> | A | ---> ABx
       ‾‾‾‾‾      ‾‾‾‾‾

```

We can also do the computation by first computing $AB$ (which will itself be a matrix) and then applying that to the input:

```text
⠀    _____
A -> | B | -> AB
     ‾‾‾‾‾
     ______
x -> | AB | -> ABx
     ‾‾‾‾‾‾
```

What we *cannot* do (in general, but it does work in special situations) is to apply $A$ to $x$ first and then apply $B$:

```text
⠀      _____  Ax  _____
x ---> | A | ---> | B | ---> BAx != ABx
       ‾‾‾‾‾      ‾‾‾‾‾

```

In mathematical terms, matrix multiplication is _associative_ but not _commutative_.
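
A small numeric sketch of both statements, using arbitrary 2 x 2 matrices chosen only for illustration:

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[0, 1], [1, 0]])   # swaps the two entries of a vector
x = np.array([[5], [6]])         # a 2 x 1 column vector

step_by_step = A @ (B @ x)       # apply B first, then A
combined = (A @ B) @ x           # precompute AB, then apply it
wrong_order = B @ (A @ x)        # apply A first, then B

print("A(Bx) =", step_by_step.ravel())
print("(AB)x =", combined.ravel())
print("B(Ax) =", wrong_order.ravel())
```

The first two results agree (associativity), while the third differs (no commutativity in general).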

When matrices $A$ and $B$ are _inverses_ of each other, $AB = BA = I$ where $I$ is the _identity matrix_. The identity matrix is a matrix that happens to have the value 1 on the diagonal from the upper-left to the lower-right and 0 everywhere else. In this situation, $A$ and $B$ happen to commute with each other (but *only* with each other and not necessarily with any other arbitrary matrices).

We can write some code to check whether the FDCT matrix multiplied by the IDCT matrix indeed gives us an identity matrix (as well as the other way around):

In [None]:
def demo_dct_identity():
    idct_basis_np = np.array(idct_basis)
    fdct_basis_np = np.array(fdct_basis)
    
    print(f"IDCT * FDCT is: {idct_basis_np @ fdct_basis_np}")
    plt.pcolormesh(idct_basis_np @ fdct_basis_np)
    plt.show()
    
    print(f"FDCT * IDCT is: {fdct_basis_np @ idct_basis_np}")
    plt.pcolormesh(fdct_basis_np @ idct_basis_np)
    plt.show()
demo_dct_identity()

Annoyingly, we don't get back _exactly_ the identity matrix due to tiny rounding errors (on the order of $10^{-16}$), but the plots should help to show that the result is indeed visually very close to the identity matrix.
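
A tolerance-based comparison is the usual way to handle this kind of rounding noise. The sketch below is self-contained: it builds its own orthonormal 64 x 64 DCT matrix (`fdct_standin`, a hypothetical stand-in which may differ in layout or scaling from the notebook's `fdct_basis`) and compares the product against the identity with `np.allclose`:

```python
import numpy as np
from math import cos, pi, sqrt

def dct_1d_matrix(n=8):
    # orthonormal 1-D DCT-II matrix: row u is the u-th cosine basis function
    m = np.zeros((n, n))
    for u in range(n):
        scale = sqrt(1 / n) if u == 0 else sqrt(2 / n)
        for x in range(n):
            m[u, x] = scale * cos(pi * u * (2 * x + 1) / (2 * n))
    return m

d8 = dct_1d_matrix()
fdct_standin = np.kron(d8, d8)   # 64 x 64 separable 2-D transform
idct_standin = fdct_standin.T    # orthonormal, so the transpose inverts it

product = fdct_standin @ idct_standin
max_error = np.abs(product - np.eye(64)).max()
print(f"largest deviation from the identity: {max_error:.3e}")
print("identity within tolerance:", np.allclose(product, np.eye(64)))
```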

### Orthonormal matrices

The DCT matrices are actually _even more_ special than just being invertible -- they are also _orthonormal_ matrices (many linear algebra references call these _orthogonal_ matrices).

One way to determine if a matrix is orthonormal is to check whether the inverse of the matrix is equal to its _transpose_ (a matrix with the rows and columns swapped). (Another property of matrices is that inverses, if they exist, must be unique. This implies that this property is a biconditional and works both ways.)

In [None]:
def demo_idct_fdct():
    idct_np_array = np.array(idct_basis)
    fdct_np_array = np.array(fdct_basis)
    difference = idct_np_array.T - fdct_np_array
    print(difference)
demo_idct_fdct()

Instead of checking this with a math-y formula, we could have also shown this property by just looking at the code for `compute_idct_basis` and `compute_fdct_basis` and noting just how similar they are (differing only in some `uv` and `xy` swaps).

Now that we have explored some of these linear algebra properties, a question that might be asked is... can we use _other_ invertible matrices? The result won't be compatible with JPEG, but... _can we_, at least according to the mathematics? The answer is... yes!

### Substituting the DCT with something else

Let's suppose we are, um, an Independent Photographic Novice Hacker and we want to replace the JPEG DCT step with a different transform. One immediate obvious choice is to use the identity matrix.

In [68]:
fdct_basis = np.identity(64)
idct_basis = np.identity(64)

In [None]:
# demo this with Annex K 50% quality quantization tables
encode_quant_tables = make_annex_k_example_quant_tables()
to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)
quantize_image(to_encode_mcus, encode_quant_tables)
to_encode_toks = encode_misc_ops(to_encode_mcus)
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("ipnh_identity.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

Hmm, that clearly didn't compress nearly as well as using the DCT, but we got _some_ output.

If we try to pretend that our IPNH file is actually a JPEG, we get the following:

![](ipnh_identity.jpg)

And... huh. The human brain's ability to correct errors is quite impressive! What is happening here is that we have colors which have _not_ been transformed by the DCT, but a standard JPEG decoder is performing an IDCT on them anyway. This isn't very intuitively meaningful, but you can see hints of the DCT basis functions showing through.

Let's decode the result back (again, using the identity transform rather than the IDCT):

In [None]:
with open('ipnh_identity.jpg', 'rb') as f:
    jpeg_data = f.read()
parsed_jpeg = parse_jpeg_segments(jpeg_data)
parsed_sof_sos_info = parse_sof_sos(parsed_jpeg)
quant_tables = parse_dqt(parsed_jpeg)
huff_tables = parse_dht(parsed_jpeg)
decode_jpeg_mcus = decode_jpeg_scan(parsed_jpeg.scan_data, parsed_sof_sos_info, huff_tables[0], huff_tables[1])
dequantize_jpeg(decode_jpeg_mcus, parsed_sof_sos_info, quant_tables)
decoded_jpeg_components = idct_image(decode_jpeg_mcus, parsed_sof_sos_info)
final_decoded_image = convert_color_space(decoded_jpeg_components, divroundup(parsed_sof_sos_info.x, 8) * 8)
display(final_decoded_image)

We can *clearly* see the 8x8 block structure in this image, as well as how the quantization tables treat the pixels in the upper-left corner of each block differently from the rest of the block.

Can we do better (worse)? Let's try an entirely random transform! In order to pick one, we can use SciPy to generate an orthonormal matrix with determinant 1 by asking for a member of the _special orthogonal group_ (a fancy name for the set of all matrices with this property).

In [71]:
fdct_basis = sp.stats.special_ortho_group.rvs(64)
idct_basis = fdct_basis.T

In [None]:
# demo this with Annex K 50% quality quantization tables
encode_quant_tables = make_annex_k_example_quant_tables()
to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)
quantize_image(to_encode_mcus, encode_quant_tables)
to_encode_toks = encode_misc_ops(to_encode_mcus)
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("ipnh_random.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

In [None]:
with open('ipnh_random.jpg', 'rb') as f:
    jpeg_data = f.read()
parsed_jpeg = parse_jpeg_segments(jpeg_data)
parsed_sof_sos_info = parse_sof_sos(parsed_jpeg)
quant_tables = parse_dqt(parsed_jpeg)
huff_tables = parse_dht(parsed_jpeg)
decode_jpeg_mcus = decode_jpeg_scan(parsed_jpeg.scan_data, parsed_sof_sos_info, huff_tables[0], huff_tables[1])
dequantize_jpeg(decode_jpeg_mcus, parsed_sof_sos_info, quant_tables)
decoded_jpeg_components = idct_image(decode_jpeg_mcus, parsed_sof_sos_info)
final_decoded_image = convert_color_space(decoded_jpeg_components, divroundup(parsed_sof_sos_info.x, 8) * 8)
display(final_decoded_image)

This *also* did not compress particularly well, but once again it _worked_ in the sense that we were able to decode the result back into something vaguely resembling the input. Because we have chosen a random transform, the quantization error vaguely looks like "noise" across the entire 8x8 block.

Can we break the rules even harder? What if we use an invertible transform that isn't even orthonormal?

In [74]:
fdct_basis = np.random.rand(64, 64)
# XXX it is *extremely* unlikely that a random matrix isn't invertible
# (so we don't bother to check)
idct_basis = sp.linalg.inv(fdct_basis)

In [None]:
# demo this with Annex K 50% quality quantization tables
encode_quant_tables = make_annex_k_example_quant_tables()
to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)
quantize_image(to_encode_mcus, encode_quant_tables)
to_encode_toks = encode_misc_ops(to_encode_mcus)
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("ipnh_random_non_ortho.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

In [None]:
with open('ipnh_random_non_ortho.jpg', 'rb') as f:
    jpeg_data = f.read()
parsed_jpeg = parse_jpeg_segments(jpeg_data)
parsed_sof_sos_info = parse_sof_sos(parsed_jpeg)
quant_tables = parse_dqt(parsed_jpeg)
huff_tables = parse_dht(parsed_jpeg)
decode_jpeg_mcus = decode_jpeg_scan(parsed_jpeg.scan_data, parsed_sof_sos_info, huff_tables[0], huff_tables[1])
dequantize_jpeg(decode_jpeg_mcus, parsed_sof_sos_info, quant_tables)
decoded_jpeg_components = idct_image(decode_jpeg_mcus, parsed_sof_sos_info)
final_decoded_image = convert_color_space(decoded_jpeg_components, divroundup(parsed_sof_sos_info.x, 8) * 8)
display(final_decoded_image)

This works even worse than before and yields even worse quality, but we *have* once again managed to get back something vaguely resembling the original input. Hopefully by this point you are convinced that _any_ invertible transform can (somewhat) be used.
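
One contributing factor to the quality loss can be sketched numerically. An orthonormal transform preserves the length of every vector, so rounding error introduced in coefficient space comes back with the same overall size in pixel space. A general invertible matrix gives no such guarantee, and its inverse can amplify the error. The sketch below uses stand-in matrices and a made-up error vector (none of these names come from the notebook's pipeline):

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in orthonormal transform: Q from a QR decomposition of a random matrix
q, _ = np.linalg.qr(rng.standard_normal((64, 64)))
# stand-in non-orthonormal transform, similar in spirit to np.random.rand(64, 64)
m = rng.random((64, 64))
m_inv = np.linalg.inv(m)

# pretend rounding error of up to half a step in each of the 64 coefficients
err = rng.uniform(-0.5, 0.5, size=64)

print("error length before the inverse transform:", np.linalg.norm(err))
print("after the orthonormal inverse:            ", np.linalg.norm(q.T @ err))
print("after the non-orthonormal inverse:        ", np.linalg.norm(m_inv @ err))
# the spectral norm bounds how much any error vector can be stretched
print("worst-case amplification:                 ", np.linalg.norm(m_inv, 2))
```

The forward transform also changes how large the coefficients are relative to the quantization table, so this is only part of the story.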

Knowing this, let's try an example that is once again a "structured" transform (just not a DCT) -- a [Hadamard transform](https://en.wikipedia.org/wiki/Hadamard_transform):

In [None]:
fdct_basis = sp.linalg.hadamard(64) / 8
idct_basis = fdct_basis.T

demo_show_idct_basis()

A Hadamard transform looks "vaguely like" a DCT, except that it uses square waves (alternating between only two distinct values) and is rearranged in a slightly different order.
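
To make the "square waves in a different order" remark concrete, here is a sketch that builds the matrix by hand with Sylvester's construction (which SciPy's documentation names as the method behind `hadamard`) and counts the sign changes along each row. The helper names `sylvester_hadamard` and `sign_changes` are assumptions of this sketch:

```python
import numpy as np

def sylvester_hadamard(n):
    # n must be a power of two; start from [[1]] and keep doubling
    h = np.array([[1]])
    h2 = np.array([[1, 1], [1, -1]])
    while h.shape[0] < n:
        h = np.kron(h2, h)
    return h

def sign_changes(row):
    # number of sign flips along a row, the square-wave analogue of frequency
    return int(np.sum(row[:-1] != row[1:]))

h8 = sylvester_hadamard(8)
print([sign_changes(r) for r in h8])   # not sorted from low to high

h64 = sylvester_hadamard(64) / 8       # divide by sqrt(64) for unit-length rows
print("distinct values:", np.unique(h64))
print("orthonormal:", np.allclose(h64 @ h64.T, np.eye(64)))
```

Unlike the DCT basis, the rows do not come out ordered by increasing "frequency", which is one reason the zig-zag-ordered quantization table is not an ideal fit for this transform.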

In [None]:
# demo this with Annex K 50% quality quantization tables
encode_quant_tables = make_annex_k_example_quant_tables()
to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)
quantize_image(to_encode_mcus, encode_quant_tables)
to_encode_toks = encode_misc_ops(to_encode_mcus)
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("ipnh_hadamard.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

In [None]:
with open('ipnh_hadamard.jpg', 'rb') as f:
    jpeg_data = f.read()
parsed_jpeg = parse_jpeg_segments(jpeg_data)
parsed_sof_sos_info = parse_sof_sos(parsed_jpeg)
quant_tables = parse_dqt(parsed_jpeg)
huff_tables = parse_dht(parsed_jpeg)
decode_jpeg_mcus = decode_jpeg_scan(parsed_jpeg.scan_data, parsed_sof_sos_info, huff_tables[0], huff_tables[1])
dequantize_jpeg(decode_jpeg_mcus, parsed_sof_sos_info, quant_tables)
decoded_jpeg_components = idct_image(decode_jpeg_mcus, parsed_sof_sos_info)
final_decoded_image = convert_color_space(decoded_jpeg_components, divroundup(parsed_sof_sos_info.x, 8) * 8)
display(final_decoded_image)

This "structured" transform gives a size result that isn't quite as good as using the DCT and with subjectively slightly worse compression artifacts, but it's much better than the examples using random matrices.

A natural question to ask after this observation is "are there yet more ways to come up with some kind of 'structured' transform, possibly one which works even better?" In other words, we want to find some transform which is, by some metric, "most similar to" or "most aligned with" the input image data. If we rummage around various "linear algebra techniques from academia," one technique that pops out is [principal component analysis](https://en.wikipedia.org/wiki/Principal_component_analysis).

PCA is a data processing technique that takes some input data (often something like "outputs from running an experiment") and finds the _principal components_ which best fit the data. These components line up with the _variance_ in the data. In the following demo, we treat each 8x8 block in the original image as one "result of an experiment" and try to find the components which "explain" most of the input:

In [None]:
def demo_pca_transform():
    global fdct_basis
    global idct_basis
    (input_w, input_h) = to_encode_sz

    # shuffle each 8x8 block into a 64-element list, and accumulate them
    X = []
    def get_mcu_block(input, xbase, ybase):
        output = []
        for y in range(8):
            for x in range(8):
                px = input[(ybase + y) * input_w + (xbase + x)]
                output.append(px)
        return output
    def blocks_for_component(input):
        mcu_w = divroundup(input_w, 8)
        mcu_h = divroundup(input_h, 8)
        blocks = []
        for y_mcu in range(mcu_h):
            for x_mcu in range(mcu_w):
                blocks.append(get_mcu_block(input, x_mcu * 8, y_mcu * 8))
        return blocks

    X += blocks_for_component(to_encode_y)
    X += blocks_for_component(to_encode_cb)
    X += blocks_for_component(to_encode_cr)
    
    # do the PCA
    pca = PCA(n_components=64)
    pca.fit(np.array(X))

    # get the result
    pca_result = pca.components_
    # shuffle things into zigzag order to fit quantization better
    idct_shuffled = [None] * 64
    for i in range(64):
        basis_func = pca_result[i]
        idct_shuffled[ZIGZAG_TO_IDCT_BASIS_IDX[i]] = list(basis_func)

    # set the result
    idct_basis = idct_shuffled
    fdct_basis = np.array(idct_basis).T

    # plot the result (IDCT only as it is more intuitive)
    demo_show_idct_basis()
demo_pca_transform()

From the plots, we see some "vague stripe and checkerboard" patterns, especially in the upper-left. Unlike the DCT, they aren't aligned with the x/y axes but are all slightly diagonal. Intuitively, this makes some sense -- the image does contain diagonal curves.
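
The reason a plain transpose is enough to turn the PCA result into a forward/inverse pair is that principal components form an orthonormal set. The following sketch shows this on a small synthetic dataset, computing PCA via a singular value decomposition in plain NumPy instead of scikit-learn (the data and variable names are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
# fake "blocks": 500 samples with 4 features, deliberately correlated
data = rng.standard_normal((500, 4)) @ rng.standard_normal((4, 4))

centered = data - data.mean(axis=0)
# rows of vt are the principal directions, ordered by decreasing variance
_, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
variances = singular_values ** 2 / (len(data) - 1)

coeffs = centered @ vt.T     # forward transform: project onto the components
restored = coeffs @ vt       # inverse transform: just the transpose

print("variance per component:", variances)
print("components orthonormal:", np.allclose(vt @ vt.T, np.eye(4)))
print("round trip recovers data:", np.allclose(restored, centered))
```

Note that PCA works on mean-centered data, whereas the demo above applies the components directly to the pixel values.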

However, theory is one thing. How well does this actually work?

In [None]:
# demo this with Annex K 50% quality quantization tables
encode_quant_tables = make_annex_k_example_quant_tables()
to_encode_mcus = fdct_image(to_encode_y, to_encode_cb, to_encode_cr, to_encode_w)
quantize_image(to_encode_mcus, encode_quant_tables)
to_encode_toks = encode_misc_ops(to_encode_mcus)
(encode_huff_tables, encode_huff_tables_for_dht) = compute_optimal_huffman_codes(to_encode_toks)
to_encode_bitstream = encode_huffman(to_encode_toks, *encode_huff_tables)
print(f"Bitstream encoded to {len(to_encode_bitstream)} bytes")
our_encoded_jpeg = pack_final_jpeg(to_encode_bitstream, *to_encode_sz, encode_quant_tables, encode_huff_tables_for_dht)
print(f"Final compressed output is {len(our_encoded_jpeg)} bytes")
with open("ipnh_pca.jpg", 'wb') as f:
    f.write(our_encoded_jpeg)

In [None]:
with open('ipnh_pca.jpg', 'rb') as f:
    jpeg_data = f.read()
parsed_jpeg = parse_jpeg_segments(jpeg_data)
parsed_sof_sos_info = parse_sof_sos(parsed_jpeg)
quant_tables = parse_dqt(parsed_jpeg)
huff_tables = parse_dht(parsed_jpeg)
decode_jpeg_mcus = decode_jpeg_scan(parsed_jpeg.scan_data, parsed_sof_sos_info, huff_tables[0], huff_tables[1])
dequantize_jpeg(decode_jpeg_mcus, parsed_sof_sos_info, quant_tables)
decoded_jpeg_components = idct_image(decode_jpeg_mcus, parsed_sof_sos_info)
final_decoded_image = convert_color_space(decoded_jpeg_components, divroundup(parsed_sof_sos_info.x, 8) * 8)
display(final_decoded_image)

Dang. Not bad, but worse than the Hadamard transform, and we _continue_ to be unable to beat the DCT. Understanding _why_ we are struggling requires a different perspective on the DCT beyond pure linear algebra.

## The DCT and spatial frequencies

The DCT is often described as a "Fourier-related" transform that converts between the spatial (or time) domain and the frequency domain. But what does this actually mean?

"Fourier-related" transforms are an entire family of operations which break down input into sinusoidal (sine, cosine, or some combination of both) functions. Some of them work on abstract and/or idealized mathematical functions whereas others are much more computation-friendly. Different transforms also make different assumptions about and have different restrictions on their input.

Why do we care so much about sinusoid functions? Sinusoid functions often arise naturally in physical phenomena because of their association with movement along a circle, and because the _derivative_ and _integral_ of a sinusoid remain sinusoids. Sinusoid functions also have the useful property that _linear combinations_ of sinusoids of a given frequency always result in a sinusoid of the same frequency. This "non-mixing" of frequencies can greatly simplify the mathematical modeling of physical phenomena.
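
The "non-mixing" property can be checked numerically. For any weights $a$ and $b$, the identity $a\cos(\omega t) + b\sin(\omega t) = R\cos(\omega t - \varphi)$ holds with $R = \sqrt{a^2 + b^2}$ and $\varphi = \operatorname{atan2}(b, a)$. The sketch below picks arbitrary values and compares both sides:

```python
from math import atan2, cos, sin, sqrt, pi

a, b = 3.0, 4.0          # arbitrary weights for the cosine and the sine
omega = 2 * pi * 5       # both inputs share this (arbitrary) frequency

amplitude = sqrt(a * a + b * b)
phase = atan2(b, a)

worst = 0.0
for i in range(1000):
    t = i / 1000
    mixed = a * cos(omega * t) + b * sin(omega * t)
    single = amplitude * cos(omega * t - phase)
    worst = max(worst, abs(mixed - single))
print(f"largest difference between the two forms: {worst:.3e}")
```

The combination is a single sinusoid at the same frequency; only the amplitude and phase change.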

### JPEG and the DCT

JPEG performs a 2-dimensional DCT on 8x8 blocks. A 2-dimensional transform is generated by multiplying sinusoids aligned with the $x$ axis by sinusoids aligned with the $y$ axis:

In [None]:
def demo_2d_freqs():
    one_d_freqs = []
    for freq in range(8):
        this_freq_wave = []
        for x in range(8):
            y = cos(pi*freq/8*(x+0.5))
            this_freq_wave.append(y)
        one_d_freqs.append(this_freq_wave)

    fig, axs = plt.subplots(9, 9)
    # plot along x
    for freq in range(8):
        axs[0, freq+1].pcolormesh([one_d_freqs[freq]])
    # plot along y
    for freq in range(8):
        axs[freq+1, 0].pcolormesh(np.array([one_d_freqs[freq]]).T)
    # plot products
    for freq_x in range(8):
        for freq_y in range(8):
            freq_x_wave = np.array([one_d_freqs[freq_x]])
            freq_y_wave = np.array([one_d_freqs[freq_y]])
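            # rows (vertical direction) follow the y wave and columns
            # (horizontal direction) follow the x wave, matching the headers
            freq_product = freq_y_wave.T @ freq_x_wave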
            axs[freq_y+1, freq_x+1].pcolormesh(freq_product)
demo_2d_freqs()

The 8x8 block size was chosen as a tradeoff between speed (larger DCTs are more computationally expensive) and quality metrics. Larger DCTs do get used in more complex formats.

For JPEG and lossy image compression in general, the DCT is very useful because the human visual system tends to be more sensitive to gradual changes (lower spatial frequencies) and less sensitive to rapid changes and fine detail (higher spatial frequencies). In the above plots, lower frequencies are towards the top-left and higher frequencies are towards the bottom-right.

This (_subjective_!) behavior of the human visual system is why the example JPEG quantization table contains smaller numbers in the upper-left and larger numbers (throwing away more data) in the lower-right -- the data which is thrown away is data which the eyes are less sensitive to. This is also why JPEG arranges the coefficients in a zig-zag order -- it orders the frequencies from lower (and thus more visible) to higher.
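
The zig-zag order can be generated rather than typed in by hand: walk the anti-diagonals of the 8x8 grid (where row plus column is constant, a rough proxy for overall frequency) and alternate the direction on each one. This is a sketch with a hypothetical helper name, `zigzag_order`:

```python
def zigzag_order(n=8):
    def key(idx):
        row, col = divmod(idx, n)
        diagonal = row + col
        # odd diagonals run from top-right to bottom-left (row increasing),
        # even diagonals run the other way (row decreasing)
        return (diagonal, row if diagonal % 2 else -row)
    return sorted(range(n * n), key=key)

order = zigzag_order()
print(order[:10])   # [0, 1, 8, 16, 9, 2, 3, 10, 17, 24]
```

The result lines up with the `ZIGZAG_TO_IDCT_BASIS_IDX` table used earlier.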

One very notable entry in the above plots is the one exactly in the upper-left. Because $\cos(0)=1$, that entry contains a constant value across the entire 8x8 block. When decomposing the source data into frequencies, this particular $(0, 0)$ frequency will contribute evenly across the source block, and so the amount of it which needs to be added is proportional to the average pixel value across the block. This property explains why the "quick-and-dirty" check of the decode (before any IDCT computations were done) works -- using only the average of each 8x8 block is a simple way to downscale the image by a factor of 8. This particular frequency gets referred to as the _DC coefficient_ because a frequency of 0 does not change and stays a fixed value, just like the voltage in a _direct current_ electrical circuit. The other coefficients are _AC coefficients_ because they vary sinusoidally like _alternating current_.
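
A quick numeric sketch of the DC/average relationship, using a self-built orthonormal 8-point DCT matrix and a random block. The scaling is an assumption of this sketch: with this normalization the DC coefficient comes out as 8 times the block average, and the factor may differ under another normalization. JPEG's level shift (subtracting 128 before the transform) is left out here.

```python
import numpy as np

n = 8
u = np.arange(n).reshape(-1, 1)      # frequency index, one per row
x = np.arange(n).reshape(1, -1)      # sample position, one per column
d8 = np.sqrt(2 / n) * np.cos(np.pi * u * (2 * x + 1) / (2 * n))
d8[0, :] = np.sqrt(1 / n)            # the constant (frequency 0) row

rng = np.random.default_rng(0)
block = rng.integers(0, 256, size=(8, 8)).astype(float)

coeffs = d8 @ block @ d8.T           # separable 2-D DCT of the block
dc = coeffs[0, 0]
print("DC coefficient:", dc)
print("8 * block mean:", 8 * block.mean())
# every other 1-D basis function sums to zero, so it cannot carry the average
print("AC rows sum to zero:", np.allclose(d8[1:].sum(axis=1), 0))
```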

The *entire* combination of breaking the image into frequencies, throwing away more of the high-frequency information, getting 0s out as a result of this throwing-away step, and finally the lossless compression results in (to recycle some ancient buzzwords) synergy! This is why hacking on only the DCT without fine-tuning everything else did not yield any improvements!

For an example of a principled and properly-engineered lossy compression algorithm which _does_ improve upon JPEG, consider looking into formats such as [JPEG XL](https://en.wikipedia.org/wiki/JPEG_XL).

### Why a _cosine_ transform?

The discrete _Fourier_ transform decomposes the input data into a combination of both sine *and* cosine functions. This is usually represented by using complex numbers. However, if the input is an _even function_ (a function which is symmetric around the y axis), it can be represented using _only_ cosines (and so we won't need complex numbers). Is there a way to make sure our data is always an even function? What does it even mean for _discrete_ data consisting of a _finite_ set of numbers to be an even function?

These discrete transforms handle this by making an _assumption_ that the input data repeats endlessly. Exactly _how_ it repeats corresponds to different transforms. For example, repeating end-to-end corresponds to the discrete Fourier transform. The discrete cosine transform instead repeats _mirrored_ copies of the data:

<img src="imgs/dct.png" width="500" alt="Example of DCT mirroring when repeating data"/>

<small>Originally created by English Wikipedia user Stevenj</small>

This repetition results in a periodic, infinite-length, even function.
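
The mirroring can be checked numerically. In the sketch below (variable names are made up for illustration), a random length-8 signal is followed by its mirror image and fed to an ordinary discrete Fourier transform. After compensating for the half-sample offset of the symmetry point, the imaginary (sine) parts vanish and the real parts match a directly-computed, unnormalized DCT:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
x = rng.standard_normal(n)

# direct (unnormalized) DCT-II: cosines sampled at half-sample offsets
k = np.arange(n).reshape(-1, 1)
pos = np.arange(n).reshape(1, -1)
dct_direct = (np.cos(np.pi * k * (2 * pos + 1) / (2 * n)) * x).sum(axis=1)

# DFT of the data followed by its mirror image (length 2n)
mirrored = np.concatenate([x, x[::-1]])
spectrum = np.fft.fft(mirrored)[:n]
# shift by half a sample so the point of symmetry sits at the origin
shifted = spectrum * np.exp(-1j * np.pi * np.arange(n) / (2 * n))

print("imaginary parts vanish:", np.allclose(shifted.imag, 0))
print("matches the direct DCT:", np.allclose(shifted.real / 2, dct_direct))
```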

This _assumption_ which is being made implicitly is the reason why different choices of JPEG input padding (when the input is not a multiple of 8 pixels in width/height) matter -- this padding is mixed with the real image data and affects what frequencies come out.

### Advanced -- sampling

In the above demo, we used 8 equally-spaced frequencies. But how did we arrive at this? Can't frequencies theoretically be any real number?

First of all, when something which is continuous gets sampled and turned into discrete data, there is an upper limit on "useful" frequencies. Anything above this cannot be captured accurately. For example, this includes very tiny objects which end up smaller than individual pixels on a camera's sensor. The precise limit is given by the [Nyquist-Shannon sampling theorem](https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem) which states that the sampling frequency must be more than twice the highest frequency you want to accurately reproduce. Data which is above this limit becomes indistinguishable from data at a lower frequency (an _alias_ frequency). In photography, this can result in artifacts such as [moiré](https://en.wikipedia.org/wiki/Moir%C3%A9_pattern).
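
Aliasing is easy to reproduce. In this sketch (the sample rate and frequencies are arbitrary illustrative values), a cosine at frequency 7 sampled 8 times per unit length produces exactly the same samples as a cosine at frequency 1:

```python
import numpy as np

sample_rate = 8                   # samples per unit length
n = np.arange(16)                 # sample indices
low = 1                           # below the limit of sample_rate / 2
high = sample_rate - low          # 7, above the limit

low_samples = np.cos(2 * np.pi * low * n / sample_rate)
high_samples = np.cos(2 * np.pi * high * n / sample_rate)
print("identical samples:", np.allclose(low_samples, high_samples))
```

Once sampled, nothing in the data distinguishes the two, which is why content above the limit has to be filtered out before sampling rather than after.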

Between 0 and this upper limit, the "proper" choice is to space the frequencies evenly, provided the input data consists of evenly-spaced samples in space (i.e. a grid of pixels, each of which is a sample, with equal spacing between the pixels, i.e. a 1:1 pixel aspect ratio). This can be derived from the (continuous) Fourier transform via manipulations involving multiplying by a "train" of [Dirac delta functions](https://en.wikipedia.org/wiki/Dirac_delta_function), which is once again a typical homework problem in a university-level signals and systems course.

Using frequencies which are _not_ evenly spaced corresponds to nonuniform sampling in the time/spatial domain. [Non-uniform discrete Fourier transforms](https://en.wikipedia.org/wiki/Non-uniform_discrete_Fourier_transform) can be used in these situations.

### Advanced -- continuous functions

All of the theory described so far has been focused on discrete data as that is the type of data which can be easily represented by a computer. The theory for _continuous_ Fourier-like transforms is significantly more complicated! Even most introductory textbooks will skip over some of this and only explain how to deal with functions which are "nice."

For quite some time, Fourier transforms didn't actually sit on solid theoretical foundations! Frequency-domain methods were so useful that they were being used well before _mathematical analysis_ was developed enough to justify them to modern standards of rigor. (So don't feel bad if you run into difficulties trying to understand this!)

The best way to learn about this would probably be from university-level coursework and associated materials.

# Conclusion

We built a bare-minimum JPEG encoder from the ground up and explored a lot of the mathematics and signal processing which underlies it. We attempted various "improvements" to the algorithm but were ultimately unsuccessful. Hopefully you learned something, especially if you learned enough to demystify other more advanced lossy compression techniques.

Questions? Comments? Corrections? Feel free to send feedback!