In [73]:
import numpy as np
import reedsolo
from itertools import combinations
import random

# ================================================================
# RS PARAMETERS
# ================================================================
n = 5      # RS symbols (bytes)
k = 3      # message symbols
rs = reedsolo.RSCodec(n - k)
t = (n - k) // 2   # = 1 symbol correctable

P_X = np.array([0.042, 0.02, 0.024, 0.04, 0.07, 0.121, 0.139, 0.144, 0.14, 0.12, 0.068, 0.072])

def construct_segments(P_X, M):
    """
    Given probability vector P_X and an integer M, 
    ignore the first element and allocate M segments 
    proportionally to the remaining probabilities.
    
    Returns:
        segments: np.array of length len(P_X)-1 containing integer allocations
    """
    # Remove first element
    P = np.array(P_X[1:], dtype=float)  
    # Normalize remaining probabilities
    P = P / P.sum()
    # Initial proportional allocation
    raw = P * M
    # Round to integers
    segs = np.floor(raw).astype(int)
    # Correct rounding to ensure sum = M
    shortfall = M - segs.sum()
    # Distribute remaining segments to bins with largest fractional parts
    fractional_parts = raw - np.floor(raw)
    order = np.argsort(-fractional_parts)  # descending

    for idx in order[:shortfall]:
        segs[idx] += 1
    
    return segs


def allocate_bits_proportional_to_entropy(segments_per_bin, bitstream):
    """
    segments_per_bin: list[int] of length 11.
    bitstream: string of bits whose length = total_bits.

    Returns:
        mapping: dict[(bin_idx, seg_idx)] -> bit_subsequence (string)
    """
    import numpy as np

    segments_per_bin = np.array(segments_per_bin, dtype=int)
    total_bits = len(bitstream)

    # Representative entropies (midpoints): [1.5, 2.5, ..., 11.5]
    num_bins = len(segments_per_bin)
    left_edges = np.arange(1, num_bins + 1)
    right_edges = left_edges + 1
    bin_entropies = (left_edges + right_edges) / 2.0

    # Build segment list and weights
    weights = []
    keys = []
    for b, count in enumerate(segments_per_bin):
        for s in range(count):
            weights.append(bin_entropies[b])
            keys.append((b, s))

    weights = np.array(weights, dtype=float)
    if len(weights) == 0:
        return {}

    # Allocate lengths using same proportional logic as original
    W = weights.sum()
    ideal = total_bits * weights / W

    lengths = np.floor(ideal).astype(int)
    remainder = total_bits - lengths.sum()

    if remainder > 0:
        frac = ideal - lengths
        order = np.argsort(-frac)  # highest fractional part first
        for idx in order[:remainder]:
            lengths[idx] += 1

    # Now partition the bitstream into these lengths
    mapping = {}
    pos = 0
    for key, L in zip(keys, lengths):
        subseq = bitstream[pos : pos + L]
        mapping[key] = subseq
        pos += L

    return mapping


In [74]:
def get_segment_reliability(bit_length):
    """Linear reliability = 1 / (bit-length)."""
    return 1.0 / bit_length


def compute_symbol_reliability(segment_bits):
    """
    Given your segment_bits dictionary:
      - assign each bit reliability = 1 / (segment length)
      - fill bits sequentially into 48-bit block
      - convert to 6 RS symbol reliabilities (sum of each 8 bits)
    """

    TOTAL_BITS = 40       # RS(6,3) codeword = 6 bytes = 48 bits
    BITS_PER_SYMBOL = 8
    NUM_SYMBOLS = 5

    reliability_bits = np.zeros(TOTAL_BITS, dtype=float)

    bit_pos = 0
    for seg, bits in segment_bits.items():
        L = len(bits)              # segment length
        rel = 1.0 / L              # reliability per bit
        print(L)
        for i in range(L):
            if bit_pos < TOTAL_BITS:
                reliability_bits[bit_pos] = rel
            bit_pos += 1

    # if fewer than 48 bits assigned, fill rest with reliability 1.0
    if bit_pos < TOTAL_BITS:
        reliability_bits[bit_pos:TOTAL_BITS] = 1.0

    # ---- Convert bit reliabilities → symbol reliabilities ----
    symbol_rel = np.zeros(NUM_SYMBOLS)
    for i in range(NUM_SYMBOLS):
        start = i * BITS_PER_SYMBOL
        end   = (i+1) * BITS_PER_SYMBOL
        symbol_rel[i] = np.sum(reliability_bits[start:end])  # SUM, not mean

    return symbol_rel



# ================================================================
# BIT ↔ SYMBOL UTILITIES
# ================================================================
def bits_to_bytes(bits24):
    """Convert 24 bits → 3 bytes."""
    bits24 = np.array(bits24, dtype=np.uint8)
    if bits24.size != 24:
        raise ValueError("Message must be exactly 24 bits.")
    return np.packbits(bits24).astype(np.uint8)


def bytes_to_bits(bts):
    """Convert bytes → bits array."""
    return np.unpackbits(np.frombuffer(bts, dtype=np.uint8))


# ================================================================
# ENCODING
# ================================================================
def rs_encode_bits(bits24):
    """
    Input: 24 bits (numpy array of 0/1)
    Output: 48 bits (encoded RS(6,3) codeword)
    """
    msg_bytes = bits_to_bytes(bits24)   # 3 bytes
    encoded = rs.encode(msg_bytes.tobytes())  # 6 bytes
    return np.unpackbits(np.frombuffer(encoded, dtype=np.uint8))


# ================================================================
# OSD-1 + REED-SOLOMON DECODER
# ================================================================
def rs_osd1_decode(rs, received_symbols, reliability_symbols):
    """
    OSD-1 (Chase-1) soft-decision decoding for RS.
    """
    n = len(received_symbols)
    least_rel = np.argsort(reliability_symbols)

    # Order-0 candidate
    candidates = [received_symbols]

    # OSD-1 flips for L least reliable positions
    L = 2
    for pos in least_rel[:L]:
        cand = received_symbols.copy()
        cand[pos] ^= 0xFF   # flip all bits in symbol
        candidates.append(cand)

    # Try all candidates
    best_msg = None
    best_metric = 9e99

    for cand in candidates:
        try:
            decoded = rs.decode(bytes(cand))[0]  # message bytes
        except reedsolo.ReedSolomonError:
            continue

        # soft metric: weighted differences
        dist = 0
        for i in range(n):
            if cand[i] != received_symbols[i]:
                dist += 1.0 / (reliability_symbols[i] + 1e-9)

        if dist < best_metric:
            best_metric = dist
            best_msg = decoded

    return best_msg


# ================================================================
# DECODING FROM RECEIVED BITS
# ================================================================
def rs_decode_bits(received_bits48, segment_bits):
    """
    Input: 48 bits (received)
    Output: 24 bits (decoded message)
    """
    rec_bits = np.array(received_bits48, dtype=np.uint8)
    if rec_bits.size != 40:
        raise ValueError("Received codeword must be 48 bits.")

    # Convert to 6 symbols
    rec_symbols = np.packbits(rec_bits).astype(np.uint8)

    # Compute symbol-level reliability from segment definitions
    reliability = compute_symbol_reliability(segment_bits)

    print("Symbol reliabilities:", reliability)

    # OSD + RS decode
    decoded_bytes = rs_osd1_decode(rs, rec_symbols, reliability)
    if decoded_bytes is None:
        raise RuntimeError("Decoding failed.")

    # Convert to bits
    return np.unpackbits(np.frombuffer(decoded_bytes, dtype=np.uint8))

In [75]:
rng = np.random.default_rng(1)
bitstream = "".join(random.choice("01") for _ in range(24)) 

segments_per_bin = construct_segments(P_X, 15) 

print("Message bits (24):", msg_bits)

# ----------- 2. Encode to 48-bit RS codeword ------------------
code_bits = rs_encode_bits(msg_bits)
print("Encoded bits (48):", code_bits)

segment_bits = allocate_bits_proportional_to_entropy( segments_per_bin, code_bits) 

# ----------- 3. Inject BIT errors -----------------------------
noisy = code_bits.copy()
flip_positions = [34, 35, 36]      # try staying inside one symbol region
for p in flip_positions:
    noisy[p] ^= 1

print("Received bits difference:", noisy-code_bits)

# ----------- 4. Decode via OSD-1 + RS ------------------------
decoded_bits = rs_decode_bits(noisy, segment_bits)
print("Decoded bits:", decoded_bits)

print("Correct decode? ", np.array_equal(decoded_bits, msg_bits))

Message bits (24): [1 1 0 0 1 1 0 1 0 1 0 1 1 1 0 1 1 0 1 0 1 1 1 0]
Encoded bits (48): [1 1 0 0 1 1 0 1 0 1 0 1 1 1 0 1 1 0 1 0 1 1 1 0 0 0 0 1 1 1 1 0 0 0 1 0 0
 0 0 0]
Received bits difference: [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 255   1
   1   0   0   0]
1
1
2
2
2
2
2
3
3
3
3
4
4
4
4
Symbol reliabilities: [5.         3.33333333 2.66666667 2.         2.        ]
Decoded bits: [1 1 0 0 1 1 0 1 0 1 0 1 1 1 0 1 1 0 1 0 1 1 1 0]
Correct decode?  True


In [76]:
def get_topK_candidates(segment_stat, seg_id, K=3):
    """
    Return the top-K most likely RS-symbol candidates for segment seg_id=(n,k).
    Each bitstring from 'counts' is treated as big-endian bits.
    """
    if seg_id not in segment_stat:
        return []  # unknown segment → no candidates

    counts = segment_stat[seg_id]['counts']

    # Sort bitstrings descending by count
    sorted_bits = sorted(counts.items(), key=lambda x: -x[1])

    topK = []
    for bits, cnt in sorted_bits[:K]:
        # Convert bitstring → integer symbol (GF(256))
        sym = int(bits, 2)
        topK.append(sym)

    return topK

def get_symbol_candidates_from_stats(segment_stats, symbol_index, bits_per_symbol=8, K=3):
    """
    Generate K^S candidate RS symbols for symbol_index, using only segment_stats.
    Segments appear in chronological order. No external mapping needed.
    """

    # total number of bits in full message (concatenate segment lengths)
    total_bits = sum(info["L"] for info in segment_stats.values())

    start = symbol_index * bits_per_symbol
    end   = (symbol_index + 1) * bits_per_symbol

    if start >= total_bits:
        return []

    # clamp end for last symbol
    end = min(end, total_bits)
    sym_len = end - start

    # ----- find segments that overlap this symbol -----
    contributing = []   # (seg_id, local_start, take_len)
    current_bitpos = 0

    for seg_id, info in segment_stats.items():
        L = info["L"]
        seg_start = current_bitpos
        seg_end   = current_bitpos + L

        overlap_start = max(seg_start, start)
        overlap_end   = min(seg_end,   end)

        if overlap_end > overlap_start:
            local_start = overlap_start - seg_start
            take_len    = overlap_end   - overlap_start
            contributing.append((seg_id, local_start, take_len))

        current_bitpos += L
        if current_bitpos >= end:
            break

    if not contributing:
        return []

    # ----- for each contributing segment, get top K bitstring slices -----
    seg_bit_cands = []

    for seg_id, local_start, take_len in contributing:
        stat = segment_stats.get(seg_id)
        if stat is None:
            return []

        counts = stat["counts"]    # {bitstring: count}
        topK = sorted(counts.items(), key=lambda x: -x[1])[:K]

        slices = []
        for bits, cnt in topK:
            if len(bits) >= local_start + take_len and cnt > 0:
                slices.append(bits[local_start : local_start + take_len])

        if not slices:
            return []

        seg_bit_cands.append(slices)

    # ----- Cartesian product => K^S candidate full-symbol bitstrings -----
    candidates = []
    for combo in itertools.product(*seg_bit_cands):
        full_bits = "".join(combo)

        # pad if last symbol shorter than 8 bits
        if len(full_bits) < bits_per_symbol:
            full_bits = full_bits.ljust(bits_per_symbol, "0")

        candidates.append(int(full_bits, 2))

    return candidates

In [72]:
import itertools
seg_stat = {(3, 0): {'L': 2, 'trials': 12, 'counts': {'00': 0, '01': 4, '10': 2, '11': 6}, 'best_bits': '11'}, (4, 0): {'L': 3, 'trials': 22, 'counts': {'000': 2, '001': 12, '010': 1, '011': 3, '100': 1, '101': 0, '110': 2, '111': 1}, 'best_bits': '001'}, (5, 0): {'L': 3, 'trials': 28, 'counts': {'000': 0, '001': 1, '010': 3, '011': 0, '100': 4, '101': 13, '110': 3, '111': 4}, 'best_bits': '101'}, (6, 0): {'L': 4, 'trials': 21, 'counts': {'0000': 0, '0001': 0, '0010': 1, '0011': 0, '0100': 1, '0101': 9, '0110': 2, '0111': 0, '1000': 1, '1001': 3, '1010': 0, '1011': 1, '1100': 1, '1101': 1, '1110': 1, '1111': 0}, 'best_bits': '0101'}, (6, 1): {'L': 4, 'trials': 19, 'counts': {'0000': 2, '0001': 1, '0010': 1, '0011': 0, '0100': 1, '0101': 0, '0110': 1, '0111': 0, '1000': 0, '1001': 1, '1010': 1, '1011': 1, '1100': 0, '1101': 7, '1110': 2, '1111': 1}, 'best_bits': '1101'}, (7, 0): {'L': 4, 'trials': 11, 'counts': {'0000': 0, '0001': 1, '0010': 0, '0011': 0, '0100': 0, '0101': 0, '0110': 0, '0111': 1, '1000': 1, '1001': 1, '1010': 6, '1011': 0, '1100': 0, '1101': 0, '1110': 0, '1111': 1}, 'best_bits': '1010'}, (7, 1): {'L': 4, 'trials': 16, 'counts': {'0000': 0, '0001': 2, '0010': 0, '0011': 0, '0100': 0, '0101': 1, '0110': 0, '0111': 0, '1000': 1, '1001': 0, '1010': 2, '1011': 0, '1100': 1, '1101': 0, '1110': 8, '1111': 1}, 'best_bits': '1110'}, (8, 0): {'L': 5, 'trials': 27, 'counts': {'00000': 0, '00001': 0, '00010': 1, '00011': 8, '00100': 2, '00101': 0, '00110': 0, '00111': 1, '01000': 1, '01001': 0, '01010': 1, '01011': 1, '01100': 1, '01101': 1, '01110': 0, '01111': 0, '10000': 0, '10001': 1, '10010': 0, '10011': 3, '10100': 0, '10101': 2, '10110': 1, '10111': 0, '11000': 0, '11001': 0, '11010': 0, '11011': 0, '11100': 0, '11101': 2, '11110': 0, '11111': 1}, 'best_bits': '00011'}, (9, 0): {'L': 5, 'trials': 12, 'counts': {'00000': 0, '00001': 0, '00010': 0, '00011': 0, '00100': 0, '00101': 1, '00110': 0, '00111': 0, '01000': 0, '01001': 0, '01010': 0, '01011': 0, '01100': 1, '01101': 0, '01110': 0, '01111': 0, '10000': 0, '10001': 1, '10010': 0, '10011': 0, '10100': 0, '10101': 0, '10110': 0, '10111': 1, '11000': 4, '11001': 0, '11010': 0, '11011': 0, '11100': 2, '11101': 1, '11110': 0, '11111': 1}, 'best_bits': '11000'}, (10, 0): {'L': 6, 'trials': 6, 'counts': {'000000': 0, '000001': 0, '000010': 0, '000011': 0, '000100': 0, '000101': 0, '000110': 0, '000111': 0, '001000': 0, '001001': 0, '001010': 0, '001011': 0, '001100': 0, '001101': 0, '001110': 0, '001111': 0, '010000': 0, '010001': 0, '010010': 0, '010011': 0, '010100': 0, '010101': 1, '010110': 0, '010111': 0, '011000': 0, '011001': 0, '011010': 0, '011011': 0, '011100': 0, '011101': 0, '011110': 0, '011111': 0, '100000': 1, '100001': 0, '100010': 1, '100011': 0, '100100': 1, '100101': 0, '100110': 1, '100111': 0, '101000': 0, '101001': 0, '101010': 0, '101011': 0, '101100': 0, '101101': 0, '101110': 0, '101111': 0, '110000': 0, '110001': 0, '110010': 0, '110011': 0, '110100': 0, '110101': 0, '110110': 0, '110111': 0, '111000': 0, '111001': 0, '111010': 0, '111011': 0, '111100': 1, '111101': 0, '111110': 0, '111111': 0}, 'best_bits': '010101'}}
print(get_topK_candidates(seg_stat, (4,0), K=3))  # Example usage

candidate = get_symbol_candidates_from_stats(seg_stat, 4, bits_per_symbol=8, K=2)
print("Symbol candidates for symbol 1:", candidate)

[1, 3, 0]
Symbol candidates for symbol 1: [21, 32, 21, 32]
