
#### Overview

Modern error-correcting codes are built on a rich theory of finite fields & cyclic linear block codes.
This notebook covers two of the most important constructions, each with a basic working implementation:
1. Hamming codes
2. Reed-Solomon encoding + Berlekamp-Welch decoding

To represent arithmetic operations in encoding & decoding procedures, we will use the simplest finite fields, with prime order, which are directly accessible using modular arithmetic on integers.
We will also use basic utilities for linear transformations acting on vectors of these field elements.
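
As a minimal sketch of why prime-order fields are convenient, the snippet below (independent of the classes defined in this notebook) checks that every non-zero residue modulo a prime has a multiplicative inverse, while a composite modulus does not give that guarantee. It relies on Python's built-in three-argument `pow` with exponent `-1` (Python 3.8+); the helper name `field_inverse` is hypothetical.

```python
P = 929  # a prime; the same modulus is used for the Reed-Solomon cells

def field_inverse(a, p=P):
    """Multiplicative inverse of a modulo a prime p (hypothetical helper)."""
    if a % p == 0:
        raise ZeroDivisionError("0 has no inverse in a field")
    return pow(a, -1, p)  # built-in modular inverse, Python 3.8+

# Every non-zero element has an inverse, so division is always defined.
all_invertible = all(a * field_inverse(a) % P == 1 for a in range(1, P))

# With a composite modulus such as 6, some non-zero residues have no inverse.
try:
    pow(2, -1, 6)
    composite_ok = True
except ValueError:
    composite_ok = False
```

This is the property that lets the elimination routines divide by any non-zero pivot.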


In [None]:
class ModInt(object):

    def __init__(self, val = 0, mod = 2):
        assert isinstance(val, int)
        self._mod = mod
        self._val = val % mod

    def __add__(self, other):
        assert isinstance(other, ModInt)
        assert self._mod == other._mod
        return ModInt(self._val + other._val, self._mod)

    def __neg__(self):
        return ModInt(-self._val, self._mod)

    def __sub__(self, other):
        assert isinstance(other, ModInt)
        assert self._mod == other._mod
        return ModInt(self._val - other._val, self._mod)

    def __mul__(self, other):
        assert isinstance(other, ModInt)
        assert self._mod == other._mod
        return ModInt(self._val * other._val, self._mod)

    def is_zero(self):
        return self._val == 0

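    def __inv(self): # Extended Euclid's algorithm; requires gcd(val, mod) == 1 (always true for prime mod)
        assert not self.is_zero()
        m0, r0 = 0, self._mod
        m1, r1 = 1, self._val
        # Invariant: r0 == m0 * val (mod mod) and r1 == m1 * val (mod mod)
        while r1 > 1:
            q = r0 // r1
            m0, r0, m1, r1 = m1, r1, m0 - q * m1, r0 - q * r1
        assert r1 == 1, "value is not invertible for this modulus"
        return ModInt(m1, self._mod)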

    def __truediv__(self, other):
        assert isinstance(other, ModInt)
        assert self._mod == other._mod
        return self * other.__inv()

    def __eq__(self, other):
        assert isinstance(other, ModInt)
        assert self._mod == other._mod
        return self._val == other._val

    def __repr__(self):
        return f"{self._val} mod {self._mod}"

    def __str__(self):
        return self.__repr__()

    def __int__(self):
        return self._val


In [None]:
from functools import reduce
from copy import deepcopy

# let T . X = Y, where T is a linear transformation acting on vector X

# dot(T, X) = Y
def dot(T, X):
    return list(
        reduce(
            lambda a, b: a + b,
            map(
                lambda a: a[0] * a[1],
                zip(v, X)))
        for v in T)

# inv(T, Y) = X, solved by Gauss-Jordan elimination; T must be square
def inv(T, Y):
    [T, Y] = deepcopy([T, Y]) # work on copies so the caller's data is untouched

    for i in range(len(T)):

        # If the pivot is zero, add a lower row with a non-zero entry in column i
        if T[i][i].is_zero():
            for j in range(i + 1, len(T)):
                if T[j][i].is_zero(): continue
                for k in range(len(T)): T[i][k] += T[j][k]
                Y[i] += Y[j]
                break
        if T[i][i].is_zero(): raise Exception("transformation is not invertible")

        # Scale row i so that the pivot becomes 1
        scaling_factor = T[i][i]
        for k in range(len(T)): T[i][k] /= scaling_factor
        Y[i] /= scaling_factor

        # Eliminate column i from every other row
        for j in range(len(T)):
            if j == i: continue
            scaling_factor = T[j][i] / T[i][i]
            for k in range(len(T)): T[j][k] -= scaling_factor * T[i][k]
            Y[j] -= scaling_factor * Y[i]

    return Y


#### Hamming Codes

There are various ways to look at Hamming codes.
Geometrically, they map low-dimensional binary vectors into a higher-dimensional space such that any two valid codewords differ in at least 3 positions. A word with a single flipped bit can then be moved to the closest valid codeword.

The additional bits can be viewed as parity bits placed at the (1-based) indices whose binary representation has exactly one set bit, i.e. the powers of two. Each parity bit holds the XOR-sum of all other positions whose index has the same bit set.

The whole process can be summarized by a generator matrix, which encodes the input, & a parity-check matrix, which acts on the (possibly erroneous) code to give a syndrome vector. The syndrome is the binary representation of the (1-based) index of the flipped bit; an all-zero syndrome means no error was detected.
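
The parity-bit description above can also be sketched without any matrices, using plain integers and XOR. This is an illustrative variant, not the implementation used in the next cell; the function names `hamming74_encode` and `hamming74_syndrome` are hypothetical, and positions are 1-based with data bits at 3, 5, 6, 7.

```python
def hamming74_encode(d):
    """Place 4 data bits at positions 3, 5, 6, 7 and fill parity at 1, 2, 4."""
    code = [0] * 8  # index 0 is unused so that positions are 1-based
    for pos, bit in zip((3, 5, 6, 7), d):
        code[pos] = bit
    for p in (1, 2, 4):
        # parity bit p covers every other position whose index has bit p set
        code[p] = sum(code[i] for i in range(1, 8) if i & p and i != p) % 2
    return code[1:]

def hamming74_syndrome(code):
    """XOR of the 1-based indices of all set bits; 0 means no error detected."""
    s = 0
    for i, bit in enumerate(code, start=1):
        if bit:
            s ^= i
    return s

word = hamming74_encode([1, 0, 1, 1])
clean = hamming74_syndrome(word)       # no bit flipped yet
word[4] ^= 1                           # flip the bit at 1-based position 5
flipped_at = hamming74_syndrome(word)  # syndrome points at the flipped position
```

Taking the XOR of the indices of all set bits computes every parity check at once: bit `k` of the result is the parity of all positions whose index has bit `k` set.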


In [None]:
# Hamming(7,4) implementation

mod2 = lambda i : ModInt(i, 2)

generator = list(map(lambda l: list(map(mod2, l)), [
    [1,1,0,1], # 001 p[1,2,4]
    [1,0,1,1], # 010 p[1,3,4]
    [1,0,0,0], # 011 d[1]
    [0,1,1,1], # 100 p[2,3,4]
    [0,1,0,0], # 101 d[2]
    [0,0,1,0], # 110 d[3]
    [0,0,0,1], # 111 d[4]
]))

parity_checker = list(map(lambda l: list(map(mod2, l)), [
    [1,0,1,0,1,0,1], # p001 : **1
    [0,1,1,0,0,1,1], # p010 : *1*
    [0,0,0,1,1,1,1], # p100 : 1**
]))

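for d in range(16):
    # the 4 data bits of d, most significant first
    data = list(map(lambda x: mod2(int(x > 0)), [d&8,d&4,d&2,d&1]))
    for e in range(-1, 7): # e = -1 leaves the code intact, otherwise bit e (0-based) is flipped
        code = dot(generator, data)
        if e >= 0: code[e] = mod2(1) - code[e]
        syndrome = dot(parity_checker, code)
        # Any single erroneous bit flip is identified by the syndrome vector,
        # which encodes the 1-based index of the flipped bit (0 if there is none)
        assert e == sum(int(s)<<i for (i, s) in enumerate(syndrome)) - 1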


#### Reed-Solomon Encoding

Reed-Solomon codes are particularly useful for errors that occur in bursts.
Instead of looking at parity bits over a stream of binary values, we treat the data as a stream of elements of a finite field.
These elements are then taken as the coefficients of a polynomial.

We evaluate this polynomial at more distinct points than there are elements in the original data block, & the list of evaluations becomes our code.
This redundancy lets us lose some of the evaluations & still reconstruct all the polynomial coefficients: a polynomial of degree less than $k$ is determined by its values at any $k$ distinct points.
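
A minimal sketch of this idea with plain integers modulo 929, independent of the `ModInt` helpers: encode by evaluating the polynomial, then rebuild the coefficients from any 3 surviving evaluations. It uses Lagrange interpolation instead of the matrix inversion used in the next cell; the function names are hypothetical.

```python
P = 929  # prime modulus, matching the cells in this notebook

def poly_eval(coeffs, x, p=P):
    """Evaluate a polynomial (lowest-degree coefficient first) with Horner's rule."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % p
    return acc

def poly_mul_linear(poly, root, p=P):
    """Multiply poly by (x - root)."""
    out = [0] * (len(poly) + 1)
    for i, c in enumerate(poly):
        out[i] = (out[i] - root * c) % p
        out[i + 1] = (out[i + 1] + c) % p
    return out

def interpolate(points, p=P):
    """Lagrange interpolation: coefficients of the unique polynomial of
    degree < len(points) through the given (x, y) pairs."""
    result = [0] * len(points)
    for j, (xj, yj) in enumerate(points):
        basis, denom = [1], 1
        for m, (xm, _) in enumerate(points):
            if m != j:
                basis = poly_mul_linear(basis, xm, p)
                denom = denom * (xj - xm) % p
        scale = yj * pow(denom, -1, p) % p  # modular inverse, Python 3.8+
        for i, c in enumerate(basis):
            result[i] = (result[i] + scale * c) % p
    return result

data = [123, 456, 789]
code = [poly_eval(data, x) for x in range(7)]
# Pretend only the symbols at x = 1, 4, 6 survived.
recovered = interpolate([(x, code[x]) for x in (1, 4, 6)])
```

Any other choice of 3 distinct evaluation points would work equally well, which is what makes the 4 extra symbols usable as erasure protection.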


In [None]:
# RS(7,3) over FiniteField(929) and evaluation points [0..7)

mod929 = lambda i: ModInt(i, 929)

generator = [[mod929(a**i) for i in range(3)] for a in range(7)]

# encoded polynomial is (123 + 456x + 789x^2) computed using generator matrix over [0..7)
data = list(map(mod929, [123, 456, 789]))
code = dot(generator, data)

for i in range(7):
    for j in range(i + 1, 7):
        for k in range(j + 1, 7):
            # Any 3 out of 7 code elements are sufficient to recover all 3 data elements
            assert data == inv(
                [generator[i], generator[j], generator[k]],
                [code[i], code[j], code[k]])


##### Berlekamp-Welch Decoding

Pure erasure correction assumes prior knowledge of which evaluation points are missing.
Reed-Solomon codes can also (in theory) correct errors at unknown locations by decoding every sufficiently large combination of code points & picking the result by majority vote. This is not efficient in practice.
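
The brute-force approach can be sketched as follows for $RS(7, 3)$ with 2 corrupted symbols. This is an illustrative sketch with hypothetical names (`lagrange_eval`, `vote_decode`): every 3-subset of received symbols is interpolated, and the candidate codeword produced by the largest number of subsets wins. Here the 10 subsets drawn from the 5 intact symbols all agree, while any wrong candidate can be produced by at most 4 subsets.

```python
from collections import Counter
from itertools import combinations

P = 929  # prime modulus, matching the cells in this notebook

def lagrange_eval(points, x, p=P):
    """Value at x of the polynomial interpolating the given (x, y) pairs."""
    total = 0
    for j, (xj, yj) in enumerate(points):
        num, den = 1, 1
        for m, (xm, _) in enumerate(points):
            if m != j:
                num = num * (x - xm) % p
                den = den * (xj - xm) % p
        total = (total + yj * num * pow(den, -1, p)) % p
    return total

def vote_decode(received, k=3, p=P):
    """Interpolate every k-subset of received symbols and return the codeword
    that the largest number of subsets agree on."""
    n = len(received)
    votes = Counter()
    for subset in combinations(range(n), k):
        pts = [(x, received[x]) for x in subset]
        votes[tuple(lagrange_eval(pts, x, p) for x in range(n))] += 1
    return list(votes.most_common(1)[0][0])

sent = [(123 + 456 * x + 789 * x * x) % P for x in range(7)]
received = list(sent)
received[2] = (received[2] + 17) % P   # two corrupted symbols
received[5] = (received[5] + 400) % P
decoded = vote_decode(received)
```

The number of subsets grows combinatorially with the code length, which is why this approach does not scale.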

An efficient error correction algorithm was discovered using the following idea:

Let $C'[x]$ denote the received values at the evaluation points, possibly with errors.
We want to compute the correct polynomial $C(x)$ from them.
For this, we introduce a monic polynomial $E(x)$, the error locator, which is zero exactly at the error locations & non-zero everywhere else.

Then the following equation holds at every evaluation point, since either $C'[x] = C(x)$ or $E(x) = 0$:

$$
E(x) \cdot C'[x] = E(x) \cdot C(x)
$$

Let $Q(x) = E(x) \cdot C(x)$, & expand for $RS(7, 3)$ assuming 2 errors. $E(x)$ is monic of degree 2 & $C(x)$ has degree at most 2, so $Q(x)$ has degree at most 4:

$$
(e_0 + e_1x + x^2) \cdot C'[x] - \sum_{i=0}^{4} q_i x^i = 0
$$

$$
C'[x] e_0 + x C'[x] e_1 - \sum_{i=0}^{4} q_i x^i = -x^2 C'[x]
$$

Once we substitute the evaluation points $x$ & the known evaluation results $C'[x]$, we get 7 linear equations in the 7 unknowns $e_0, e_1, q_0, \dots, q_4$, which can again be solved using Gaussian elimination. If there are fewer than 2 erroneous values, the system does not have a unique solution. In that case we can fall back to $E(x)$ & $Q(x)$ of lower degree & repeat the process.

Once $E(x)$ & $Q(x)$ are known, we can recover the original data either as the quotient $\frac{Q(x)}{E(x)}$ or by simply erasing the $C'[x]$ values at evaluation points $x$ where $E(x) = 0$ & inverting the transformation.
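
The full procedure, including the fallback to lower degrees and the recovery of the data as $\frac{Q(x)}{E(x)}$, can be sketched with plain integers as below. This is a minimal sketch under stated assumptions, not the notebook's implementation: the names `solve_mod`, `poly_divmod` and `berlekamp_welch` are hypothetical, evaluation points are assumed to be $0, 1, \dots, n-1$, and the solver differs from `inv` in that it accepts more equations than unknowns and returns `None` when the solution is not unique.

```python
def solve_mod(rows, rhs, p):
    """Gauss-Jordan elimination mod prime p; returns the unique solution or None."""
    m = len(rows[0])
    aug = [[v % p for v in r] + [b % p] for r, b in zip(rows, rhs)]
    rank = 0
    for col in range(m):
        pivot = next((r for r in range(rank, len(aug)) if aug[r][col]), None)
        if pivot is None:
            return None  # free variable: no unique solution
        aug[rank], aug[pivot] = aug[pivot], aug[rank]
        inv_pivot = pow(aug[rank][col], -1, p)
        aug[rank] = [v * inv_pivot % p for v in aug[rank]]
        for r in range(len(aug)):
            if r != rank and aug[r][col]:
                f = aug[r][col]
                aug[r] = [(a - f * b) % p for a, b in zip(aug[r], aug[rank])]
        rank += 1
    if any(row[m] for row in aug[rank:]):
        return None  # leftover equations are inconsistent
    return [aug[i][m] for i in range(m)]

def poly_divmod(num, den, p):
    """Polynomial long division (lowest-degree coefficient first); den must be monic."""
    num = list(num)
    quot = [0] * (len(num) - len(den) + 1)
    for i in range(len(quot) - 1, -1, -1):
        quot[i] = num[i + len(den) - 1] % p
        for j, d in enumerate(den):
            num[i + j] = (num[i + j] - quot[i] * d) % p
    return quot, num[:len(den) - 1]

def berlekamp_welch(received, k, p):
    """Recover k data coefficients from symbols received at x = 0..n-1."""
    n = len(received)
    for e in range((n - k) // 2, -1, -1):  # try the largest error count first
        # Unknowns: e_0..e_{e-1} of the monic E(x), then q_0..q_{e+k-1} of Q(x)
        rows = [[received[x] * x**i for i in range(e)] +
                [-(x**i) for i in range(e + k)] for x in range(n)]
        rhs = [-(x**e) * received[x] for x in range(n)]
        sol = solve_mod(rows, rhs, p)
        if sol is None:
            continue  # fall back to a lower degree
        E, Q = sol[:e] + [1], sol[e:]
        quotient, remainder = poly_divmod(Q, E, p)
        if not any(remainder):
            return quotient
    return None

P = 929
data = [123, 456, 789]
sent = [sum(c * x**i for i, c in enumerate(data)) % P for x in range(7)]

two_errors = list(sent)
two_errors[1] = (two_errors[1] + 5) % P
two_errors[6] = (two_errors[6] + 321) % P
one_error = list(sent)
one_error[3] = (one_error[3] + 77) % P

decoded_two = berlekamp_welch(two_errors, 3, P)
decoded_one = berlekamp_welch(one_error, 3, P)   # falls back to a degree-1 E(x)
decoded_none = berlekamp_welch(sent, 3, P)       # falls back to E(x) = 1
```

In all three cases the decoder is expected to return the original coefficients; inputs with more than 2 errors are outside what this sketch handles.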



In [None]:

# RS(7,3) over FiniteField(929) and evaluation points [0..7)

mod929 = lambda i: ModInt(i, 929)

generator = [[mod929(a**i) for i in range(3)] for a in range(7)]

# encoded polynomial is (123 + 456x + 789x^2) computed using generator matrix over [0..7)
data = list(map(mod929, [123, 456, 789]))

from random import randrange

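for er0 in range(7):
    for er1 in range(er0 + 1, 7):
        # Corrupt two distinct positions with random non-zero offsets
        code = dot(generator, data)
        code[er0] += mod929(randrange(1, 929))
        code[er1] += mod929(randrange(1, 929))

        # One equation per evaluation point x, in the unknowns [e0, e1, q0, ..., q4]:
        #   C'[x]*e0 + x*C'[x]*e1 - sum(q_i * x^i) = -x^2 * C'[x]
        t = [
            list(map(mod929, [
                int(code[x]),
                int(code[x]) * x
            ] + [
                -x**i for i in range(5)
            ]))
            for x in range(7)
        ]
        y = [mod929(-x*x*int(code[x])) for x in range(7)]

        # The first two unknowns are the coefficients of E(x) = e0 + e1*x + x^2
        e = inv(t, y)[:2]
        for x in map(mod929, range(7)):
            # E(x) vanishes exactly at the two corrupted positions
            assert (e[0] + x*e[1] + x*x).is_zero() == (int(x) == er0 or int(x) == er1)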
