In [1]:
import copy
import json
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import PolarDecDataset
from models.wrappers.mamba_32bits import MambaPolarDecoder

# Seed every RNG this notebook may draw from (Python, NumPy, PyTorch) so that a
# Restart & Run All repeats the same shuffling and weight initialisation.
# Best effort only: some GPU kernels remain non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

CONFIG_NO = 5  # experiment id; appears in the evaluation and checkpoint directory names
N = 32         # sequence length handed to the datasets (seq_length) and the model (seq_len)


  from .autonotebook import tqdm as notebook_tqdm


Device: cuda


In [None]:

# SNR values that train() adds to the training mix one at a time, in this order.
# Whole numbers are deliberately kept as ints: the values are interpolated into
# log lines and checkpoint file names (e.g. "snr_7", not "snr_7.0").
snr_curriculum = [
    7.8, 7.5, 7.3, 7,
    6.8, 6.5, 6.3, 6,
    5.8, 5.5, 5.2, 5,
    4.8, 4.5, 4.2, 4,
    3.8, 3.5, 3.2, 3,
    2.8, 2.5, 2.2, 2,
    1.8, 1.5, 1.2, 1,
    0.8, 0.5, 0.3, 0,
]

# SNR values the model is evaluated on after each curriculum stage.
test_snr_list = [10, 5, 0]

# Dataset sizes and the mini-batch size shared by both loaders.
num_train_samples = 200_000
num_test_samples = 4_000
batch_size = 32


In [None]:
# Training set: starts on the three SNRs listed here; train() later replaces
# `train_dataset.snr_list` with a growing list as the curriculum advances.
train_dataset = PolarDecDataset(
    seq_length=N,
    num_samples=num_train_samples,
    snr_list=[10, 9, 8],
)

# Held-out set built on the fixed evaluation SNRs.
test_dataset = PolarDecDataset(
    seq_length=N,
    num_samples=num_test_samples,
    snr_list=test_snr_list,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)


In [None]:
# Build the decoder with the hyper-parameters of this experiment, then move it
# to the selected device. seq_len is tied to the dataset sequence length N.
model = MambaPolarDecoder(
    seq_len=N,
    d_model=64,
    d_state=32,
    d_conv=6,
    expand=2,
    num_layer_encoder=1,
    num_layers_bimamba_block=20,
)
model = model.to(device)


In [5]:
# Warm-start: restore weights from an earlier run's checkpoint before fine-tuning.
ckpt_path = "./checkpoints/config_2_firstbiblock20/model_best_block_epoch_5_snr_8.pt"

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so this must only ever be pointed at checkpoint files you trust.
ckpt = torch.load(ckpt_path, weights_only=False, map_location=device)

# Default strict loading: any key mismatch with the model raises here.
model.load_state_dict(ckpt["state_dict"])

print(f"Loaded pretrained model from: {ckpt_path}")

Loaded pretrained model from: ./checkpoints/config_2_firstbiblock20/model_best_block_epoch_5_snr_8.pt


In [6]:
# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Halve the learning rate (down to 1e-6) when the monitored metric has not
# decreased for more than `patience` scheduler steps. train() steps this once
# per curriculum stage with the mean validation BER.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3,
    min_lr=1e-6,
)

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
loss_fn = nn.BCEWithLogitsLoss()


In [7]:
def calc_save_ber(
    model,
    device,
    snr_values,
    msg_bit_sizes,
    num_samples=3200,
    batch_size=32,
    config_no=0,
    epoch=0,
    save_dir="src/evaluation",
    seq_length=None,
):
    """Measure the bit error rate of `model` on fresh test sets and save it as JSON.

    For every SNR in `snr_values` and every message size in `msg_bit_sizes` a new
    `PolarDecDataset` is built, the model is run over it without gradients, and
    hard decisions (`logits > 0`) are compared with the targets. Errors are
    counted over all positions: those where `frozen == 1` (treated as message
    bits in this notebook) plus those where `frozen == 0`.

    Args:
        model: Decoder called as `model(llrs, frozen, snr)`; switched to eval
            mode and left in eval mode on return.
        device: Device the batches are moved to.
        snr_values: Iterable of SNR values to evaluate.
        msg_bit_sizes: Iterable of message sizes (`fixed_msg_bit_size`); must
            not be empty.
        num_samples: Samples generated per (SNR, message size) pair.
        batch_size: Evaluation batch size.
        config_no: Experiment id; results go to `<save_dir>/config_<config_no>/`.
        epoch: Zero-based epoch index; the file is named `ber_epoch_<epoch+1>.json`.
        save_dir: Root directory for evaluation results (created if missing).
        seq_length: Sequence length for the generated datasets. Defaults to the
            notebook-level constant `N` when left as `None`.

    Returns:
        dict: `{"<snr>_snr": {"<msg_bits>": {"average_net_bit_error_rate": float},
        ..., "overall_ber": float}}`, where `overall_ber` is the unweighted
        mean over the message sizes.

    Raises:
        ValueError: If `msg_bit_sizes` is empty, or a generated test set
            produced no bits to score.
    """
    # Materialise once: a generator would be exhausted after the first SNR and
    # every later SNR would silently evaluate nothing.
    msg_bit_sizes = list(msg_bit_sizes)
    if not msg_bit_sizes:
        raise ValueError("msg_bit_sizes must contain at least one message size")
    if seq_length is None:
        seq_length = N  # fall back to the notebook-wide sequence length

    model.eval()
    config_dir = os.path.join(save_dir, f"config_{config_no}")
    os.makedirs(config_dir, exist_ok=True)  # also creates save_dir itself

    eval_results = {}

    for snr in snr_values:
        print(f"\nEvaluating for SNR = {snr} dB\n")
        snr_key = f"{snr}_snr"
        eval_results[snr_key] = {}
        ber_list = []

        for msg_bits in msg_bit_sizes:
            total_msg_errors = 0
            total_frozen_errors = 0
            total_msg_bits = 0
            total_frozen_bits = 0

            test_set = PolarDecDataset(
                snr_list=[snr],
                num_samples=num_samples,
                fixed_msg_bit_size=msg_bits,
                seq_length=seq_length,
            )
            loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

            with torch.no_grad():
                for llrs, frozen, snr_t, target in loader:
                    llrs = llrs.to(device).float()
                    frozen = frozen.to(device).long()
                    target = target.to(device).long()

                    logits = model(llrs, frozen, snr_t.to(device))
                    predicted = (logits > 0).long()  # hard decision at logit 0

                    mask_msg = frozen == 1
                    mask_frozen = frozen == 0

                    total_msg_errors += (target[mask_msg] != predicted[mask_msg]).sum().item()
                    total_frozen_errors += (target[mask_frozen] != predicted[mask_frozen]).sum().item()
                    total_msg_bits += mask_msg.sum().item()
                    total_frozen_bits += mask_frozen.sum().item()

            total_bits = total_msg_bits + total_frozen_bits
            if total_bits == 0:
                raise ValueError(
                    f"no bits were evaluated for SNR {snr}, msg_bits {msg_bits} "
                    f"(num_samples={num_samples})"
                )
            total_errors = total_msg_errors + total_frozen_errors
            avg_net_ber = total_errors / total_bits

            ber_list.append(avg_net_ber)
            eval_results[snr_key][str(msg_bits)] = {
                "average_net_bit_error_rate": avg_net_ber
            }

        eval_results[snr_key]["overall_ber"] = sum(ber_list) / len(ber_list)

    json_file = os.path.join(config_dir, f"ber_epoch_{epoch+1}.json")
    with open(json_file, "w") as f:
        json.dump(eval_results, f, indent=4)

    return eval_results


In [8]:
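# NOTE: disabled legacy helper, kept for reference only. It is not called
# anywhere in this notebook; BER is computed by calc_save_ber instead.
# def compute_ber(preds, target):
#     pred_bits = (torch.sigmoid(preds) > 0.5).int()
#     total_bits = target.numel()
#     error_bits = (pred_bits != target.int()).sum().item()
#     return error_bits / total_bits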

In [9]:
def train_one_epoch(batches=1000):
    """Run up to `batches` optimisation steps and return the mean training loss.

    Uses the notebook-level `model`, `train_loader`, `optimizer`, `loss_fn` and
    `device`. The loss is computed only on positions where `frozen == 1`
    (treated as message bits in this notebook). Gradients are clipped to a
    total norm of 1.0 before every optimizer step.

    Args:
        batches: Maximum number of mini-batches to process. The pass stops
            early if `train_loader` runs out first.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches
        actually processed.

    Raises:
        ValueError: If `batches` is smaller than 1, or `train_loader` yields
            no batch at all.
    """
    if batches < 1:
        raise ValueError(f"batches must be a positive integer, got {batches!r}")

    model.train()
    running_loss = 0.0
    steps_done = 0

    for channel, frozen, snr, target in train_loader:
        channel = channel.float().to(device)
        frozen = frozen.int().to(device)
        target = target.float().to(device)

        optimizer.zero_grad()
        out = model(channel, frozen, snr.float().to(device))

        # Score only the positions flagged with frozen == 1.
        mask_msg = frozen == 1
        loss = loss_fn(out[mask_msg], target[mask_msg])

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        steps_done += 1

        if steps_done == batches:
            break

    if steps_done == 0:
        # Previously surfaced as an UnboundLocalError on the loop index.
        raise ValueError("train_loader yielded no batches; cannot compute an average loss")

    return running_loss / steps_done


In [None]:
def train(epochs_per_snr=6, batches_per_epoch=1000):
    """Fine-tune the notebook-level `model` through the SNR curriculum.

    For each value in `snr_curriculum` the SNR is added to the active training
    set, `epochs_per_snr` short epochs are run, the weights of the epoch with
    the lowest training loss are restored, the model is evaluated with
    `calc_save_ber`, the LR scheduler is stepped on the mean validation BER,
    and a checkpoint is written to `./checkpoints/config_<CONFIG_NO>/`.

    Evaluation JSON files are written to
    `src/evaluation/snr_<snr>/config_<CONFIG_NO>/`. Each stage gets its own
    directory because the file name only encodes the best epoch index, so a
    shared directory would let later stages overwrite earlier results.

    Args:
        epochs_per_snr: Number of epochs run for every curriculum stage.
        batches_per_epoch: Mini-batches per epoch (forwarded to `train_one_epoch`).

    Raises:
        ValueError: If `epochs_per_snr` is smaller than 1.
        RuntimeError: If no epoch of a stage produced a finite training loss,
            so there is no state to restore or save.
    """
    if epochs_per_snr < 1:
        raise ValueError(f"epochs_per_snr must be at least 1, got {epochs_per_snr!r}")

    # Already trained up to SNR 8
    active_snrs = [10, 9, 8]

    for snr in snr_curriculum:
        print(f"\nTraining with new SNR {snr}")

        if snr not in active_snrs:
            active_snrs.append(snr)

        # The dataset holds a reference to this list, which grows every stage.
        train_dataset.snr_list = active_snrs
        print("Active SNRs:", active_snrs)

        best_loss_in_block = float("inf")
        best_epoch_in_block = 0
        best_state_dict = None

        for epoch in range(epochs_per_snr):
            avg_loss = train_one_epoch(batches=batches_per_epoch)
            print(f"SNR {snr} | Epoch {epoch+1}/{epochs_per_snr} | Loss {avg_loss:.6f}")

            # NaN/inf losses never compare as smaller, so they are never selected.
            if avg_loss < best_loss_in_block:
                best_loss_in_block = avg_loss
                best_epoch_in_block = epoch
                best_state_dict = copy.deepcopy(model.state_dict())

        if best_state_dict is None:
            raise RuntimeError(
                f"no epoch produced a finite training loss at SNR {snr}; "
                "training has diverged and there is no state to restore"
            )

        # Roll back to the best epoch of this stage. The optimizer's internal
        # state is left as it was after the final epoch.
        model.load_state_dict(best_state_dict)

        val_results = calc_save_ber(
            model,
            device,
            snr_values=test_snr_list,
            msg_bit_sizes=[8, 16, 24],
            num_samples=3200,
            batch_size=32,
            config_no=CONFIG_NO,
            epoch=best_epoch_in_block,
            save_dir=os.path.join("src/evaluation", f"snr_{snr}"),
        )

        # Mean BER over the test SNRs. Cast to a plain float so the checkpoint
        # holds only Python/tensor types and no NumPy scalar.
        avg_vloss = float(np.mean(
            [val_results[f"{s}_snr"]["overall_ber"] for s in test_snr_list]
        ))

        scheduler.step(avg_vloss)

        model_dir = f"./checkpoints/config_{CONFIG_NO}"
        os.makedirs(model_dir, exist_ok=True)
        model_path = f"{model_dir}/model_best_block_epoch_{best_epoch_in_block+1}_snr_{snr}.pt"

        torch.save({
            "comments": "Removed the snr as input entirely. Use SNR linear if needed.",
            "model_config": {
                "d_model": model.d_model,
                "num_layer_encoder": model.num_layer_encoder,
                "num_layers_bimamba_block": model.num_layers_bimamba_block,
                "seq_len": model.seq_len,
                "d_state": model.d_state,
                "d_conv": model.d_conv,
                "expand": model.expand,
            },
            "epoch": best_epoch_in_block + 1,
            "train_loss": best_loss_in_block,
            # Key name kept for compatibility; the value is the mean validation BER.
            "val_loss": avg_vloss,
            "state_dict": best_state_dict
        }, model_path)

        print(f"Saved checkpoint: {model_path}")

    print("\nTraining finished.")


In [11]:
# Launch the curriculum run defined above (long-running); progress is printed
# for every epoch and a checkpoint is saved after every SNR stage.
train()


Training with new SNR 7.8
Active SNRs: [10, 9, 8, 7.8]
SNR 7.8 | Epoch 1/6 | Loss 0.042051
SNR 7.8 | Epoch 2/6 | Loss 0.042796
SNR 7.8 | Epoch 3/6 | Loss 0.043522
SNR 7.8 | Epoch 4/6 | Loss 0.042956
SNR 7.8 | Epoch 5/6 | Loss 0.042632
SNR 7.8 | Epoch 6/6 | Loss 0.043355

Evaluating for SNR = 10 dB


Evaluating for SNR = 5 dB


Evaluating for SNR = 0 dB

Saved checkpoint: ./checkpoints/config_5/model_best_block_epoch_1_snr_7.8.pt

Training with new SNR 7.5
Active SNRs: [10, 9, 8, 7.8, 7.5]
SNR 7.5 | Epoch 1/6 | Loss 0.049680
SNR 7.5 | Epoch 2/6 | Loss 0.051021
SNR 7.5 | Epoch 3/6 | Loss 0.049545
SNR 7.5 | Epoch 4/6 | Loss 0.052217
SNR 7.5 | Epoch 5/6 | Loss 0.048095
SNR 7.5 | Epoch 6/6 | Loss 0.047716

Evaluating for SNR = 10 dB


Evaluating for SNR = 5 dB


Evaluating for SNR = 0 dB

Saved checkpoint: ./checkpoints/config_5/model_best_block_epoch_6_snr_7.5.pt

Training with new SNR 7.3
Active SNRs: [10, 9, 8, 7.8, 7.5, 7.3]
SNR 7.3 | Epoch 1/6 | Loss 0.054672
SNR 7.3 | Epoch 2/6 | Los

In [None]:
import torch
from torch.utils.data import DataLoader
from dataset import PolarDecDataset
from models.wrappers.mamba_32bits import MambaPolarDecoder


# Prefer the GPU when one is visible.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# One shared loss module: binary cross-entropy applied to raw logits.
bce = torch.nn.BCEWithLogitsLoss()


def eccm_loss(logits, target_tensor):
    """Return the mean BCE-with-logits loss of `logits` against `target_tensor`.

    The logits are used as they are, with no reshaping; the caller is expected
    to pass them already shaped like the targets (presumably (B, seq_len)).
    """
    loss = bce(logits, target_tensor)
    return loss


def get_estimated_codeword(logits):
    """Turn logits into hard bit decisions.

    A position becomes 1.0 when its sigmoid probability is strictly above 0.5
    and 0.0 otherwise. The result is a float tensor shaped like `logits`.
    """
    return torch.sigmoid(logits).gt(0.5).float()


# Small smoke-test set: exactly one batch of 32 samples at a single SNR of 10.
# NOTE: this rebinds `test_dataset` / `test_loader` from the training cells.
test_dataset = PolarDecDataset(
    seq_length=32,
    num_samples=32,
    snr_list=[10],
)

test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

ckpt_path = "./checkpoints/config_5_loadedfromconfig2snr8andmicrotrained_bestoneyet/model_best_block_epoch_5_snr_0.pt"
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only point this at checkpoint files you trust.
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)


# Rebuild the architecture from the hyper-parameters stored in the checkpoint,
# so the model always matches the saved weights.
# NOTE: this rebinds `model`; the optimizer created in the training cells still
# refers to the previous model's parameters.
model = MambaPolarDecoder(
    d_model=ckpt['model_config']['d_model'],
    num_layer_encoder=ckpt['model_config']['num_layer_encoder'],
    num_layers_bimamba_block=ckpt['model_config']['num_layers_bimamba_block'],
    seq_len=ckpt['model_config']['seq_len'],
    d_state=ckpt['model_config']['d_state'],
    d_conv=ckpt['model_config']['d_conv'],
    expand=ckpt['model_config']['expand'],
).to(device)

# Accept either key name, or a bare state dict.
state_dict = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))

# strict=False tolerates key mismatches. Report them instead of discarding the
# result, because missing keys mean those layers keep their random weights.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("WARNING: checkpoint and model keys do not match exactly.")
    print("  Missing keys (left at initial values):", list(load_result.missing_keys))
    print("  Unexpected keys (ignored):", list(load_result.unexpected_keys))
model.eval()

print("Model loaded.")


# Forward pass on a single batch (the loader holds exactly one).
llrs, frozen_tensor, snr_tensor, target_tensor = next(iter(test_loader))

# Move the inputs to the model's device; the frozen mask is cast to int64.
llrs = llrs.to(device)
frozen_tensor = frozen_tensor.to(device=device, dtype=torch.long)
target_tensor = target_tensor.to(device)

# NOTE(review): `snr_tensor` is unpacked but not passed to the model, whereas
# the training loop calls model(channel, frozen, snr). Confirm that the SNR
# argument is optional and that omitting it here is intended.
# NOTE(review): the loss below covers every position, while training scores
# only positions with frozen == 1. Confirm this is the intended comparison.
with torch.no_grad():
    logits = model(llrs, frozen_tensor)
    loss = eccm_loss(logits, target_tensor)

print(f"Loss: {loss.item()}")


# Show the first sample of the batch: target bits next to the hard decisions.
actual_codeword = target_tensor[0]
decoded_codeword = get_estimated_codeword(logits[0])

print("\nActual bits:")
print("".join(str(int(bit)) for bit in actual_codeword.cpu().tolist()))

print("\nPredicted bits:")
print("".join(str(bit) for bit in decoded_codeword.int().cpu().tolist()))


# Bit error rate over every position of the batch (frozen positions included).
batch_hard = get_estimated_codeword(logits)
bit_errors = torch.ne(batch_hard, target_tensor).sum().item()
total_bits = target_tensor.numel()
ber = bit_errors / total_bits

print(f"\nBatch BER: {ber}")
print(f"Total bit errors: {bit_errors} / {total_bits}")


# Sanity check: all four tensors should share the same (batch, seq_len) shape.
print("\nShapes:")
print(f"LLRs: {llrs.shape}")
print(f"Frozen: {frozen_tensor.shape}")
print(f"Logits: {logits.shape}")
print(f"Target: {target_tensor.shape}")



Device: cuda
Model loaded.
Loss: 0.020280059427022934

Actual bits:
11111100001010001000000010000000

Predicted bits:
11111100001010001000000010000000

Batch BER: 0.0009765625
Total bit errors: 1 / 1024

Shapes:
LLRs: torch.Size([32, 32])
Frozen: torch.Size([32, 32])
Logits: torch.Size([32, 32])
Target: torch.Size([32, 32])
