In [None]:
import numpy as np

class SimpleBpDecoder:
    """
    Simple belief propagation (sum-product) decoder for binary LDPC codes.

    Both modes share one message-passing core:

    - Codeword decoding: estimate ``c`` with ``H @ c = 0 (mod 2)`` from received
      hard bits plus a crossover probability, or directly from LLRs.
    - Syndrome decoding: estimate an error pattern ``y`` with
      ``H @ y = syndrome (mod 2)``.

    LLR convention: positive values favour bit 0, negative values favour bit 1.
    """

    def __init__(self, H, max_iter=50, tol=1e-6):
        """
        Initialize decoder with parity-check matrix H.

        Args:
            H: dense (m, n) binary parity-check matrix (array-like); it is
                converted with ``np.array(H, dtype=np.uint8)``, so sparse
                matrices should be densified first (e.g. ``H.toarray()``).
            max_iter: maximum number of BP iterations. With 0, ``decode``
                returns the plain hard decision on the prior LLRs.
            tol: stored as ``self.tol`` for API compatibility; the decoder
                does not currently use it.
        """
        self.H = np.array(H, dtype=np.uint8)
        self.m, self.n = self.H.shape
        # Tanner-graph adjacency: checks touching each variable (var_nodes[j])
        # and variables participating in each check (check_nodes[i]).
        self.var_nodes = [np.where(self.H[:, j])[0] for j in range(self.n)]
        self.check_nodes = [np.where(self.H[i, :])[0] for i in range(self.m)]
        self.max_iter = max_iter
        self.tol = tol

    def decode(self, received=None, error_rate=None, llr=None, syndrome=None):
        """
        Decode using belief propagation.

        - Codeword decoding (``syndrome is None``): provide ``received`` bits and
          ``error_rate``, or ``llr`` directly. Returns the decoded codeword.
        - Syndrome decoding: provide ``syndrome`` and ``error_rate`` (or ``llr``).
          Returns an estimated error pattern ``y`` aiming for
          ``H @ y = syndrome (mod 2)``; ``received`` is ignored in this mode.

        ``error_rate`` may be a scalar or a length-``n`` sequence of per-bit
        probabilities, each strictly inside (0, 1). When ``llr`` (length ``n``)
        is given it takes precedence over ``error_rate``/``received``.

        If no iteration satisfies the target syndrome within ``max_iter``
        iterations, the last hard decision is returned and may violate checks.

        Returns:
            ``uint8`` array of length ``n``.

        Raises:
            ValueError: on missing inputs, error rates outside (0, 1), or
                length mismatches.
        """
        if syndrome is not None:
            target = np.array(syndrome, dtype=np.uint8)
            if target.shape != (self.m,):
                raise ValueError(f"Syndrome length must be {self.m}.")
            if llr is None:
                if error_rate is None:
                    raise ValueError("Provide 'syndrome' with either llr or error_rate.")
                # No received word here: every bit is a priori "no error" (bit 0).
                prior = np.broadcast_to(self._channel_llr(error_rate), (self.n,)).astype(float)
            else:
                prior = self._checked_llr(llr)
            return self._run_bp(prior, target)

        if llr is None:
            if received is None or error_rate is None:
                raise ValueError("Provide either llr or both 'received' and 'error_rate'.")
            channel = self._channel_llr(error_rate)
            bits = np.array(received, dtype=np.uint8)
            if bits.shape != (self.n,):
                raise ValueError(f"received length must be {self.n}.")
            # Compute the +1/-1 sign in float: in uint8, 1 - 2 wraps to 255 and
            # every received 1 would look like a very confident 0.
            prior = channel * (1.0 - 2.0 * bits)
        else:
            prior = self._checked_llr(llr)
        return self._run_bp(prior, np.zeros(self.m, dtype=np.uint8))

    def _channel_llr(self, error_rate):
        """Return log((1 - p) / p) for a scalar or length-n crossover probability p."""
        p = np.asarray(error_rate, dtype=float)
        if p.ndim != 0 and p.shape != (self.n,):
            raise ValueError(f"error_rate must be a scalar or have length {self.n}.")
        # NaN fails both comparisons, so it is rejected as well.
        if not np.all((p > 0) & (p < 1)):
            raise ValueError("error_rate must be in (0,1).")
        return np.log((1 - p) / p)

    def _checked_llr(self, llr):
        """Return ``llr`` as a float array after checking it is 1-D of length n."""
        values = np.array(llr, dtype=float)
        if values.ndim != 1 or values.shape[0] != self.n:
            raise ValueError(f"LLR length must be {self.n}.")
        return values

    def _run_bp(self, prior, target):
        """
        Sum-product message passing towards ``H @ x = target (mod 2)``.

        Returns the first hard decision that satisfies ``target``; otherwise the
        decision from the final iteration (or from ``prior`` when max_iter == 0).
        """
        edges = [(i, j) for i in range(self.m) for j in self.check_nodes[i]]
        msg_vc = {edge: prior[edge[1]] for edge in edges}
        msg_cv = dict.fromkeys(edges, 0.0)
        # Fallback so max_iter == 0 yields a decision instead of an unbound name.
        estimate = (prior < 0).astype(np.uint8)
        for _ in range(self.max_iter):
            # Check-node update (tanh rule). A check whose target bit is 1 must
            # have odd parity, which flips the sign of every outgoing message.
            for i in range(self.m):
                sign = -1.0 if target[i] else 1.0
                for j in self.check_nodes[i]:
                    prod = 1.0
                    for jj in self.check_nodes[i]:
                        if jj != j:
                            prod *= np.tanh(msg_vc[(i, jj)] / 2.0)
                    prod *= sign
                    # Keep arctanh finite; caps |message| at 2*arctanh(0.999999).
                    prod = np.clip(prod, -0.999999, 0.999999)
                    msg_cv[(i, j)] = 2.0 * np.arctanh(prod)
            # Variable-node update: prior plus extrinsic (other checks') messages.
            for j in range(self.n):
                for i in self.var_nodes[j]:
                    msg_vc[(i, j)] = prior[j] + sum(
                        msg_cv[(ii, j)] for ii in self.var_nodes[j] if ii != i
                    )
            # Posterior LLRs and hard decision (a posterior of exactly 0 gives bit 0).
            posterior = np.zeros(self.n)
            for j in range(self.n):
                posterior[j] = prior[j] + sum(msg_cv[(i, j)] for i in self.var_nodes[j])
            estimate = (posterior < 0).astype(np.uint8)
            if np.array_equal(self.H.dot(estimate) % 2, target):
                return estimate
        return estimate

In [52]:
import numpy as np
import ldpc.codes
from ldpc import BpDecoder

H = ldpc.codes.rep_code(5)  # parity-check matrix of the length-5 repetition code (4 x 5)
n = H.shape[1]  # the codeword length

bpd = BpDecoder(
    H,  # the parity-check matrix
    error_rate=0.1,  # the error rate on each bit
    max_iter=n,  # the maximum iteration depth for BP
    bp_method="product_sum",  # BP method; the other option is "minimum_sum"
)

In [61]:
# Example bit vector whose syndrome is computed and decoded in the following cells.
received_vector = np.array([1, 0, 1, 1, 0])

In [62]:
# Syndrome of the example vector under the parity-check matrix H.
s = H.dot(received_vector) % 2
print("Syndrome: ", s)

# bpd.decode is given the syndrome, so its result is an estimated error
# pattern (not a codeword), despite the variable name.
decoded_codeword = bpd.decode(s)
print('Estimated error pattern y:', decoded_codeword)

Syndrome:  [1 1 0 1]
Estimated error pattern y: [0 1 0 0 1]


In [63]:
# Dense view of the parity-check matrix for inspection.
H.toarray()

array([[1, 1, 0, 0, 0],
       [0, 1, 1, 0, 0],
       [0, 0, 1, 1, 0],
       [0, 0, 0, 1, 1]], dtype=uint8)

In [64]:
# Recompute the syndrome so this cell stands on its own, then decode it with the
# hand-written decoder using the same prior (error_rate) and iteration budget
# (max_iter=n) as the ldpc `bpd` decoder, so the two results are comparable.
s = H.dot(received_vector) % 2
print("Syndrome: ", s)
decoder = SimpleBpDecoder(H.toarray(), max_iter=n)
y_est = decoder.decode(syndrome=s, error_rate=0.1)
print('Estimated error pattern y:', y_est)
# The estimate must reproduce the observed syndrome.
assert np.array_equal(H.dot(y_est) % 2, s)

Syndrome:  [1 1 0 1]
Estimated error pattern y: [0 1 0 0 1]
