- 设置工作目录到项目根目录，确保相对路径（`models/` 等）可用。
- 固定随机种子，选择计算设备（GPU优先）。


In [1]:
import os
import sys
import numpy as np
import torch
from matplotlib import pyplot as plt

# Ensure working directory is project root (adjust if needed)
project_root = os.path.abspath(os.path.join(os.path.dirname("..")))
os.chdir(project_root)
print(f"CWD: {os.getcwd()}")

# Global seed and device
seed = np.random.RandomState(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


CWD: z:\WorkSpace\osd-cnn\notebooks
Device: cuda


- 生成/加载校验矩阵 `H` 与生成矩阵 `G`，打印 `(n, k)` 与码率。
- 设置全局参数：`USE_APP_LLR_FOR_OSD`、`max_iter_nms`、训练SNR、batch大小等。
- 指定模型保存路径 `models/dia_cnn.pth`。


In [2]:
import torch.nn as nn
import torch.optim as optim
from pyldpc import make_ldpc, encode, decode, get_message
from torch.utils.data import DataLoader, TensorDataset

# LDPC code params
n = 128
d_v = 4
d_c = 8
H, G = make_ldpc(n, d_v, d_c, seed=seed, systematic=True, sparse=True)
n_code, k_info = G.shape
print(f"LDPC: n={n_code}, k={k_info}, R={k_info/n_code:.2f}")

# Global configs
USE_APP_LLR_FOR_OSD = False
max_iter_nms = 12
snr_train_db = 2.7
epochs = 20
batch_size = 256

# Model path
model_dir = os.path.join("models")
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, "dia_cnn.pth")


LDPC: n=128, k=67, R=0.52


- 构建Tanner图邻接，执行规范化最小和（NMS）迭代。
- 支持综合校验早停。`alpha` 可按论文调整。


In [3]:
import numpy as np

def decode_with_trajectory(H, y, snr, max_iter):
    """
    规范化最小和（NMS）解码，返回：
      - final_codeword: 最终硬判决码字（0/1）
      - llrs_trajectory: 形状 (n, max_iter + 1) 的后验LLR轨迹
    """
    m_checks, n_vars = H.shape

    # --- Tanner图邻接表构建 (这部分是正确的) ---
    if hasattr(H, 'nonzero'):
        rows, cols = H.nonzero()
    else: # For dense numpy array
        rows, cols = np.nonzero(H)
    check_to_vars = [[] for _ in range(m_checks)]
    var_to_checks = [[] for _ in range(n_vars)]
    for r, c in zip(rows, cols):
        check_to_vars[r].append(c)
        var_to_checks[c].append(r)

    # --- 初始LLR计算 (这部分是正确的) ---
    R = float(k_info) / float(n_code)
    EbN0_linear = 10.0 ** (float(snr) / 10.0)
    # 假设 y 是 BPSK (+1/-1) + 噪声, 转换为 LLR
    # pyldpc.encode返回的就是这种y
    channel_llr = 4.0 * R * EbN0_linear * y.astype(float)
    
    llrs_trajectory = np.zeros((n_vars, max_iter + 1), dtype=float)
    llrs_trajectory[:, 0] = channel_llr

    # --- 消息初始化 (这部分是正确的) ---
    v2c_msgs = np.zeros_like(H.toarray(), dtype=float)
    for r, c in zip(rows, cols):
        v2c_msgs[r, c] = channel_llr[c]
    c2v_msgs = np.zeros_like(v2c_msgs)
    
    alpha = 0.8  # 规范化因子

    # --- 迭代更新 ---
    a_posteriori = channel_llr.copy()
    for it in range(1, max_iter + 1):
        # --- 校验节点更新 (Check Node Update) ---
        for c in range(m_checks):
            connected_vars = check_to_vars[c]
            if not connected_vars: continue

            # 获取所有连接到此校验节点的消息
            incoming_msgs = [v2c_msgs[c, v] for v in connected_vars]
            
            # --- 关键Bug修正 ---
            for i, v_target in enumerate(connected_vars):
                # 提取除当前目标外的所有其它消息
                other_msgs = incoming_msgs[:i] + incoming_msgs[i+1:]
                
                # 计算符号乘积
                sign_prod = np.prod(np.sign(other_msgs))
                
                # 找到最小绝对值
                min_abs = min(np.abs(other_msgs)) if other_msgs else 0.0
                
                # 更新校验到变量的消息
                c2v_msgs[c, v_target] = alpha * sign_prod * min_abs

        # --- 变量节点更新 (Variable Node Update) ---
        for v in range(n_vars):
            connected_checks = var_to_checks[v]
            if not connected_checks: continue

            # 累加来自所有校验节点的消息
            sum_c2v = sum(c2v_msgs[c, v] for c in connected_checks)
            
            # 更新后验LLR
            a_posteriori[v] = channel_llr[v] + sum_c2v
            
            # 更新变量到校验节点的消息（外信息）
            for c in connected_checks:
                v2c_msgs[c, v] = a_posteriori[v] - c2v_msgs[c, v]

        llrs_trajectory[:, it] = a_posteriori

        hard_bits = (a_posteriori < 0).astype(int)
        if np.all((H.dot(hard_bits) % 2) == 0):
            break

    final_bits = (a_posteriori < 0).astype(int)
    return final_bits, llrs_trajectory

- 发送全零码字，通过AWGN信道多次采样；
- 对每次采样运行 NMS 并记录LLR轨迹；
- 仅当NMS失败时，收集每个比特的轨迹作为训练样本，以提升训练有效性。

In [4]:
def is_valid_codeword(H, c):
    if c is None:
        return False
    return np.all((H.dot(c) % 2) == 0)


def generate_training_data_from_failures(n_failures_target, snr_db, max_iter_nms):
    print(f"Generating training data from NMS failures at SNR={snr_db} dB...")
    X_train_list = []
    y_train_list = []

    # All-zero message
    v_message = np.zeros((k_info, 1))
    true_codeword = encode(G, v_message, snr_db).flatten()

    n_failures_found = 0
    n_sims_run = 0
    while n_failures_found < n_failures_target:
        # Generate one noisy observation for the all-zero codeword
        y_noisy = encode(G, v_message, snr_db, seed=seed).flatten()

        decoded_word, llr_trajectory = decode_with_trajectory(H, y_noisy, snr_db, max_iter=max_iter_nms)
        n_sims_run += 1
        if not is_valid_codeword(H, decoded_word):
            n_failures_found += 1
            for i in range(n_code):
                X_train_list.append(llr_trajectory[i, :])
                y_train_list.append(true_codeword[i])
        if n_sims_run % 5000 == 0:
            print(f"  Sims run: {n_sims_run}, Failures: {n_failures_found}/{n_failures_target}")

    X_train = np.array(X_train_list).reshape(-1, 1, max_iter_nms + 1)
    y_train = np.array(y_train_list).reshape(-1, 1, 1)
    print(f"Generated {X_train.shape[0]} samples from {n_failures_found} frame failures.\n")
    return X_train, y_train


- 一维卷积堆叠提取轨迹时序特征，输出为单一logit（bit为1的概率的对数几率）。
- 输入形状为 `(batch, 1, T)`，其中 `T = max_iter_nms + 1`。

In [5]:
class DIA_CNN_Model(nn.Module):
    def __init__(self, trajectory_length):
        super(DIA_CNN_Model, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.Conv1d(in_channels=32, out_channels=1, kernel_size=3, padding='same')
        )
        self.flatten = nn.Flatten()
        self.dense = nn.Linear(1 * trajectory_length, 1)

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.flatten(x)
        x = self.dense(x)
        return x

# Instantiate
max_iter_nms = max_iter_nms
model = DIA_CNN_Model(trajectory_length=max_iter_nms + 1).to(device)
print("Model ready")


Model ready


- 如果存在 `models/dia_cnn.pth` 则直接加载并跳过训练。
- 否则：仅从NMS失败帧采样生成训练数据并训练，训练后保存权重，随后切换到推理模式。

In [6]:
if os.path.exists(model_path):
    print(f"Found existing model at {model_path}. Loading...")
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    print("Loaded.\n")
else:
    print("No saved model. Generating failures and training...")
    X_train_np, y_train_np = generate_training_data_from_failures(
        n_failures_target=4000, snr_db=snr_train_db, max_iter_nms=max_iter_nms
    )
    X_train = torch.from_numpy(X_train_np).float()
    y_train = torch.from_numpy(y_train_np).float()

    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    for epoch in range(20):
        epoch_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.view(-1, 1).to(device)
            optimizer.zero_grad()
            logits = model(inputs)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch+1}/20, Loss: {epoch_loss/len(train_loader):.4f}")

    torch.save(model.state_dict(), model_path)
    model.eval()
    print(f"Saved model to {model_path}.\n")


No saved model. Generating failures and training...
Generating training data from NMS failures at SNR=2.7 dB...


AttributeError: 'numpy.ndarray' object has no attribute 'toarray'

- `osd_rescue_order3`：按可靠度对最不可靠比特进行阶数1/2/3翻转救援（受 `L` 和 `max_triplets` 控制）。
- `standard_osd_decoder`：可配置使用信道LLR或NMS最后一轮后验LLR作为可靠度。
- `cnn_enhanced_decoder`：利用CNN输出的比特错误概率构造可靠度并执行OSD救援。

In [None]:
from itertools import combinations

def osd_rescue_order3(H, base_decision, reliability_vector, L=12, max_triplets=100):
    idx_sorted = np.argsort(reliability_vector)
    least_reliable = idx_sorted[:min(L, len(idx_sorted))]
    if is_valid_codeword(H, base_decision):
        return base_decision
    # Order-1
    for i in least_reliable:
        cand = base_decision.copy(); cand[i] ^= 1
        if is_valid_codeword(H, cand):
            return cand
    # Order-2
    for i, j in combinations(least_reliable, 2):
        cand = base_decision.copy(); cand[i] ^= 1; cand[j] ^= 1
        if is_valid_codeword(H, cand):
            return cand
    # Order-3 (limited)
    triplets_checked = 0
    for i, j, k in combinations(least_reliable, 3):
        cand = base_decision.copy(); cand[i] ^= 1; cand[j] ^= 1; cand[k] ^= 1
        if is_valid_codeword(H, cand):
            return cand
        triplets_checked += 1
        if triplets_checked >= max_triplets:
            break
    return base_decision


def baseline_nms_decoder(H, y, snr):
    return decode(H, y, snr, maxiter=max_iter_nms)


def standard_osd_decoder(H, y, snr, predecoded=None, llr_traj=None):
    decoded = predecoded if predecoded is not None else baseline_nms_decoder(H, y, snr)
    if is_valid_codeword(H, decoded):
        return decoded
    reliability_vec = np.abs(llr_traj[:, -1]) if (USE_APP_LLR_FOR_OSD and llr_traj is not None) else np.abs(y)
    return osd_rescue_order3(H, decoded, reliability_vec)


def cnn_enhanced_decoder(H, y, snr, model, dev, predecoded=None, llr_traj=None):
    if predecoded is None or llr_traj is None:
        decoded, llr_trajectory = decode_with_trajectory(H, y, snr, max_iter=max_iter_nms)
    else:
        decoded, llr_trajectory = predecoded, llr_traj
    if is_valid_codeword(H, decoded):
        return decoded
    model.eval()
    with torch.no_grad():
        traj_tensor = torch.from_numpy(llr_trajectory.T).float().reshape(n_code, 1, -1).to(dev)
        logits = model(traj_tensor)
        error_probs = torch.sigmoid(logits).cpu().numpy().flatten()
    reliability_metric = 1.0 - error_probs
    return osd_rescue_order3(H, decoded, reliability_metric)

- 批量生成 `n_trials` 帧的接收向量，逐帧运行一次 NMS，收集失败帧。
- 标准OSD：仅对失败帧进行救援（可靠度可选信道LLR/后验LLR）。
- CNN-OSD：对失败帧的比特轨迹做一次（分块）前向，得到错误概率并执行OSD救援。
- 统计三条曲线的BER并绘图。

In [None]:
import time

snrs_db = np.arange(2.0, 4.5, 0.5)
n_trials = 20000
ber_baseline, ber_std_osd, ber_cnn_osd = [], [], []

v_messages = seed.randint(2, size=(k_info, n_trials))

for snr in snrs_db:
    start_time = time.time()
    total_errors_baseline, total_errors_std_osd, total_errors_cnn_osd = 0, 0, 0

    y_noisy_batch = encode(G, v_messages, snr, seed=seed)

    failed_indices, failed_trajectories, failed_decoded_words = [], [], []

    # Stage 1: NMS and Standard OSD
    for i in range(n_trials):
        y_col, v_col = y_noisy_batch[:, i], v_messages[:, i]
        d_nms, llr_traj = decode_with_trajectory(H, y_col, snr, max_iter=max_iter_nms)

        errors_this = np.count_nonzero(get_message(G, d_nms) != v_col)
        total_errors_baseline += errors_this

        if not is_valid_codeword(H, d_nms):
            failed_indices.append(i)
            failed_trajectories.append(llr_traj)
            failed_decoded_words.append(d_nms)

            reliability_vec = np.abs(llr_traj[:, -1]) if USE_APP_LLR_FOR_OSD else np.abs(y_col)
            d_std_osd = osd_rescue_order3(H, d_nms, reliability_vec)
            total_errors_std_osd += np.count_nonzero(get_message(G, d_std_osd) != v_col)
        else:
            total_errors_std_osd += errors_this

    # Start CNN-OSD from baseline errors
    total_errors_cnn_osd = total_errors_baseline

    # Stage 2: CNN rescue on failed frames
    if len(failed_indices) > 0:
        print(f"SNR={snr:.2f} dB, NMS failures: {len(failed_indices)}/{n_trials}")
        batch_trajectories = np.array(failed_trajectories)
        flat_trajectories = batch_trajectories.reshape(-1, 1, max_iter_nms + 1)

        # Optional: chunking to avoid OOM
        with torch.no_grad():
            logits_list = []
            chunk = 32768  # tune if needed
            for start in range(0, flat_trajectories.shape[0], chunk):
                end = min(start + chunk, flat_trajectories.shape[0])
                batch_tensor = torch.from_numpy(flat_trajectories[start:end]).float().to(device)
                logits_list.append(model(batch_tensor).cpu())
            all_logits = torch.cat(logits_list, dim=0)
            all_error_probs = torch.sigmoid(all_logits).numpy().flatten()

        error_probs_per_frame = all_error_probs.reshape(len(failed_indices), n_code)
        for idx, i in enumerate(failed_indices):
            v_col = v_messages[:, i]
            base_decision = failed_decoded_words[idx]

            errors_base_frame = np.count_nonzero(get_message(G, base_decision) != v_col)
            total_errors_cnn_osd -= errors_base_frame

            reliability_metric = 1.0 - error_probs_per_frame[idx]
            d_cnn_osd = osd_rescue_order3(H, base_decision, reliability_metric)
            total_errors_cnn_osd += np.count_nonzero(get_message(G, d_cnn_osd) != v_col)

    ber_baseline.append(total_errors_baseline / (k_info * n_trials))
    ber_std_osd.append(total_errors_std_osd / (k_info * n_trials))
    ber_cnn_osd.append(total_errors_cnn_osd / (k_info * n_trials))

    elapsed = time.time() - start_time
    print(f"SNR={snr:.2f} dB finished in {elapsed:.1f}s.  NMS:{ber_baseline[-1]:.8e}  StdOSD:{ber_std_osd[-1]:.8e}  CNN-OSD:{ber_cnn_osd[-1]:.8e}")

plt.figure(figsize=(12, 8))
plt.semilogy(snrs_db, ber_baseline, 'o:', label='Baseline NMS')
plt.semilogy(snrs_db, ber_std_osd, 's--', label='NMS + Standard OSD')
plt.semilogy(snrs_db, ber_cnn_osd, '^-', label='NMS + CNN-Enhanced OSD')
plt.xlabel('SNR (dB)'); plt.ylabel('BER'); plt.grid(True, which='both', ls='--', alpha=0.5); plt.legend(); plt.ylim(1e-5, 1)
plt.show()
