In [28]:
import numpy as np
import reedsolo
from itertools import combinations

# ================================================================
# RS PARAMETERS
# ================================================================
n = 5      # RS symbols (bytes)
k = 3      # message symbols
rs = reedsolo.RSCodec(n - k)
t = (n - k) // 2   # = 1 symbol correctable



def construct_segments(P_X, M):
    """
    Given probability vector P_X and an integer M, 
    ignore the first element and allocate M segments 
    proportionally to the remaining probabilities.
    
    Returns:
        segments: np.array of length len(P_X)-1 containing integer allocations
    """
    # Remove first element
    P = np.array(P_X[1:], dtype=float)
    # Normalize remaining probabilities
    P = P / P.sum()
    # Initial proportional allocation
    raw = P * M
    # Round to integers
    segs = np.floor(raw).astype(int)
    # Correct rounding to ensure sum = M
    shortfall = M - segs.sum()
    # Distribute remaining segments to bins with largest fractional parts
    fractional_parts = raw - np.floor(raw)
    order = np.argsort(-fractional_parts)  # descending
    for idx in order[:shortfall]:
        segs[idx] += 1
    
    return segs

def allocate_bits_proportional_to_entropy(segments_per_bin, bitstream):
    """
    segments_per_bin: list[int] of length 11.
    bitstream: string of bits whose length = total_bits.

    Returns:
        mapping: dict[(bin_idx, seg_idx)] -> bit_subsequence (string)
    """
    import numpy as np

    segments_per_bin = np.array(segments_per_bin, dtype=int)
    total_bits = len(bitstream)

    # Representative entropies (midpoints): [1.5, 2.5, ..., 11.5]
    num_bins = len(segments_per_bin)
    left_edges = np.arange(1, num_bins + 1)
    right_edges = left_edges + 1
    bin_entropies = (left_edges + right_edges) / 2.0

    # Build segment list and weights
    weights = []
    keys = []
    for b, count in enumerate(segments_per_bin):
        for s in range(count):
            weights.append(bin_entropies[b])
            keys.append((b, s))

    weights = np.array(weights, dtype=float)
    if len(weights) == 0:
        return {}

    # Allocate lengths using same proportional logic as original
    W = weights.sum()
    ideal = total_bits * weights / W

    lengths = np.floor(ideal).astype(int)
    remainder = total_bits - lengths.sum()

    if remainder > 0:
        frac = ideal - lengths
        order = np.argsort(-frac)  # highest fractional part first
        for idx in order[:remainder]:
            lengths[idx] += 1

    # Now partition the bitstream into these lengths
    mapping = {}
    pos = 0
    for key, L in zip(keys, lengths):
        subseq = bitstream[pos : pos + L]
        mapping[key] = subseq
        pos += L

    return mapping


In [None]:
def get_segment_reliability(bit_length):
    """Linear reliability = 1 / (bit-length)."""
    return 1.0 / bit_length


def compute_symbol_reliability(segment_bits):
    """
    Given your segment_bits dictionary:
      - assign each bit reliability = 1 / (segment length)
      - fill bits sequentially into 48-bit block
      - convert to 6 RS symbol reliabilities (sum of each 8 bits)
    """

    TOTAL_BITS = 48       # RS(6,3) codeword = 6 bytes = 48 bits
    BITS_PER_SYMBOL = 8
    NUM_SYMBOLS = 6

    reliability_bits = np.zeros(TOTAL_BITS, dtype=float)

    bit_pos = 0
    for seg, bits in segment_bits.items():
        L = len(bits)              # segment length
        rel = 1.0 / L              # reliability per bit
        for i in range(L):
            if bit_pos < TOTAL_BITS:
                reliability_bits[bit_pos] = rel
            bit_pos += 1

    # if fewer than 48 bits assigned, fill rest with reliability 1.0
    if bit_pos < TOTAL_BITS:
        reliability_bits[bit_pos:TOTAL_BITS] = 1.0

    # ---- Convert bit reliabilities → symbol reliabilities ----
    symbol_rel = np.zeros(NUM_SYMBOLS)
    for i in range(NUM_SYMBOLS):
        start = i * BITS_PER_SYMBOL
        end   = (i+1) * BITS_PER_SYMBOL
        symbol_rel[i] = np.sum(reliability_bits[start:end])  # SUM, not mean

    return symbol_rel



# ================================================================
# BIT ↔ SYMBOL UTILITIES
# ================================================================
def bits_to_bytes(bits24):
    """Convert 24 bits → 3 bytes."""
    bits24 = np.array(bits24, dtype=np.uint8)
    if bits24.size != 24:
        raise ValueError("Message must be exactly 24 bits.")
    return np.packbits(bits24).astype(np.uint8)


def bytes_to_bits(bts):
    """Convert bytes → bits array."""
    return np.unpackbits(np.frombuffer(bts, dtype=np.uint8))


# ================================================================
# ENCODING
# ================================================================
def rs_encode_bits(bits24):
    """
    Input: 24 bits (numpy array of 0/1)
    Output: 48 bits (encoded RS(6,3) codeword)
    """
    msg_bytes = bits_to_bytes(bits24)   # 3 bytes
    encoded = rs.encode(msg_bytes.tobytes())  # 6 bytes
    return np.unpackbits(np.frombuffer(encoded, dtype=np.uint8))


# ================================================================
# OSD-1 + REED-SOLOMON DECODER
# ================================================================
def rs_osd1_decode(rs, received_symbols, reliability_symbols):
    """
    OSD-1 (Chase-1) soft-decision decoding for RS.
    """
    n = len(received_symbols)
    least_rel = np.argsort(reliability_symbols)

    # Order-0 candidate
    candidates = [received_symbols]

    # OSD-1 flips for L least reliable positions
    L = 2
    for pos in least_rel[:L]:
        cand = received_symbols.copy()
        cand[pos] ^= 0xFF   # flip all bits in symbol
        candidates.append(cand)

    # Try all candidates
    best_msg = None
    best_metric = 9e99

    for cand in candidates:
        try:
            decoded = rs.decode(bytes(cand))[0]  # message bytes
        except reedsolo.ReedSolomonError:
            continue

        # soft metric: weighted differences
        dist = 0
        for i in range(n):
            if cand[i] != received_symbols[i]:
                dist += 1.0 / (reliability_symbols[i] + 1e-9)

        if dist < best_metric:
            best_metric = dist
            best_msg = decoded

    return best_msg


# ================================================================
# DECODING FROM RECEIVED BITS
# ================================================================
def rs_decode_bits(received_bits48, segment_bits):
    """
    Input: 48 bits (received)
    Output: 24 bits (decoded message)
    """
    rec_bits = np.array(received_bits48, dtype=np.uint8)
    if rec_bits.size != 48:
        raise ValueError("Received codeword must be 48 bits.")

    # Convert to 6 symbols
    rec_symbols = np.packbits(rec_bits).astype(np.uint8)

    bitstream = "".join(random.choice("01") for _ in range(BITSTREAM_LEN)) 

    segments_per_bin = construct_segments(P_X, 15) 

    segment_bits = allocate_bits_proportional_to_entropy( segments_per_bin, bitstream ) 

    # Compute symbol-level reliability from segment definitions
    reliability = compute_symbol_reliability(segment_bits)

    print("Symbol reliabilities:", reliability)

    # OSD + RS decode
    decoded_bytes = rs_osd1_decode(rs, rec_symbols, reliability)
    if decoded_bytes is None:
        raise RuntimeError("Decoding failed.")

    # Convert to bits
    return np.unpackbits(np.frombuffer(decoded_bytes, dtype=np.uint8))

In [26]:
rng = np.random.default_rng(1)
msg_bits = rng.integers(0, 2, size=24, dtype=np.uint8)
print("Message bits (24):", msg_bits)

# ----------- 2. Encode to 48-bit RS codeword ------------------
code_bits = rs_encode_bits(msg_bits)
print("Encoded bits (48):", code_bits)

# ----------- 3. Inject BIT errors -----------------------------
noisy = code_bits.copy()
flip_positions = [34, 35, 36]      # try staying inside one symbol region
for p in flip_positions:
    noisy[p] ^= 1

print("Received bits difference:", noisy-code_bits)

# ----------- 4. Decode via OSD-1 + RS ------------------------
decoded_bits = rs_decode_bits(noisy, segment_bits)
print("Decoded bits:", decoded_bits)

print("Correct decode? ", np.array_equal(decoded_bits, msg_bits))

Message bits (24): [1 1 0 0 1 1 0 1 0 1 0 1 1 1 0 1 1 0 1 0 1 1 1 0]
Encoded bits (48): [1 1 0 0 1 1 0 1 0 1 0 1 1 1 0 1 1 0 1 0 1 1 1 0 0 0 0 1 1 1 1 0 0 0 1 0 0
 0 0 0 0 0 0 0 0 0 0 0]
Received bits difference: [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 255   1
   1   0   0   0   0   0   0   0   0   0   0   0]
Symbol reliabilities: [4.  3.  2.5 2.  1.9 1.6]
Decoded bits: [1 1 0 0 1 1 0 1 0 1 0 1 1 1 0 1 1 0 1 0 1 1 1 0]
Correct decode?  True
