In [1]:
import copy
import json
import math
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import PolarDecDataset
from models.wrappers.mamba_32bits import MambaPolarDecoder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed every RNG the notebook may draw from so data generation, weight init
# and DataLoader shuffling are reproducible across Restart & Run All.
# NOTE(review): assumes PolarDecDataset uses Python `random`, NumPy or torch
# RNGs — confirm it has no private, unseeded generator.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNGs of all devices, CUDA included

N = 32          # sequence length; passed as seq_length (dataset) and seq_len (model)
CONFIG_NO = 2   # run ID used in the config_<CONFIG_NO> checkpoint/evaluation dirs

Using device: cuda


In [None]:
# SNR curriculum in dB. After pretraining, train() runs one 5-epoch stage per
# entry; a repeated value (6 and 0 here) gives that SNR an extra stage.
# A longer schedule used previously continued past 0 dB:
#   [10, 9, 8, 7, 6, 6, 5, 4, 3, 0, 0, -3, -6, -10]
# and was validated on [10, 6, 0, -3, -6, -10].
snr_curriculum = [10, 9, 8, 7, 6, 6, 5, 4, 3, 0, 0]

# SNRs (dB) at which validation BER is measured.
test_snr_list = [10, 5, 0]

# Dataset sizes and mini-batch size.
num_train_samples = 300_000
num_test_samples = 4_000
batch_size = 32

In [None]:
# The training set starts at the 10 dB pretraining SNR; train() later reassigns
# `train_dataset.snr_list` as the curriculum widens.
# NOTE(review): that only matters if PolarDecDataset reads snr_list lazily
# (e.g. in __getitem__) rather than pre-generating samples — TODO confirm.
train_dataset = PolarDecDataset(
    seq_length=N,
    num_samples=num_train_samples,
    snr_list=[10],
)
test_dataset = PolarDecDataset(
    seq_length=N,
    num_samples=num_test_samples,
    snr_list=test_snr_list,
)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)

In [None]:

# Decoder hyper-parameters (kept in one dict so they are easy to tweak / log).
model_kwargs = dict(
    d_model=64,
    num_layer_encoder=1,
    num_layers_bimamba_block=22,
    seq_len=N,
    d_state=32,
    d_conv=6,
    expand=2,
)
model = MambaPolarDecoder(**model_kwargs).to(device)


In [6]:
# Adam at a fixed initial LR, no weight decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=0)

# Tolerate 3 non-improving steps, then halve the LR (never below 1e-6).
# NOTE(review): no `scheduler.step(...)` call is visible in this notebook, so
# the LR currently stays at 1e-4 — confirm whether it should be stepped on
# the validation BER.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.5,
    patience=3,
    min_lr=1e-6,
)

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
def calc_save_ber(
    model,
    device,
    snr_values,
    msg_bit_sizes,
    num_samples=3200,
    batch_size=32,
    config_no=0,
    epoch=0,
    save_dir="src/evaluation",
    seq_length=32,
):
    """Measure bit-error rates of ``model`` on freshly generated data and save them as JSON.

    For every (SNR, message size) pair a new ``PolarDecDataset`` with
    ``num_samples`` frames is built and decoded. Positions with ``frozen == 1``
    are counted as message bits, positions with ``frozen == 0`` as frozen bits.

    Args:
        model: Decoder called as ``model(llrs, frozen, snr)``. Its output is
            treated as logits: a bit is predicted as 1 where the logit is > 0.
        device: Device every batch is moved to.
        snr_values: SNRs to evaluate; each becomes a ``"<snr>_snr"`` result key.
        msg_bit_sizes: Values passed as ``fixed_msg_bit_size`` to the dataset.
        num_samples: Frames generated per (SNR, message size) pair.
        batch_size: Evaluation batch size.
        config_no: Results are written below ``<save_dir>/config_<config_no>/``.
        epoch: Zero-based epoch; the file is named ``ber_epoch_<epoch+1>.json``.
        save_dir: Root directory for evaluation output.
        seq_length: Frame length passed to ``PolarDecDataset`` (formerly
            hard-coded; the default of 32 keeps the old behaviour).

    Returns:
        ``{"<snr>_snr": {"<msg_bits>": {...metrics...}, "overall_ber": float}}``,
        where ``overall_ber`` is the mean net BER over ``msg_bit_sizes``. The
        same dict is written to disk. Rates with no bits to count are 0.0.

    Note:
        ``model`` is left in eval mode.
    """
    model.eval()
    config_dir = os.path.join(save_dir, f"config_{config_no}")
    os.makedirs(config_dir, exist_ok=True)  # also creates save_dir

    eval_results = {}

    for snr in snr_values:
        print(f"\nEvaluating for SNR = {snr} dB\n")
        snr_key = f"{snr}_snr"
        eval_results[snr_key] = {}
        ber_list = []

        for msg_bits in msg_bit_sizes:
            print(f"  Message bit size = {msg_bits}")
            total_msg_errors = 0
            total_frozen_errors = 0
            total_msg_bits = 0
            total_frozen_bits = 0

            test_set = PolarDecDataset(
                snr_list=[snr],
                num_samples=num_samples,
                fixed_msg_bit_size=msg_bits,
                seq_length=seq_length,
            )
            test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

            with torch.no_grad():
                for llrs, frozen_tensor, snr_tensor, target_tensor in test_loader:
                    # Feed the model the same dtypes train_one_epoch uses
                    # (float llrs, int frozen mask, float snr).
                    llrs = llrs.to(device).float()
                    frozen_tensor = frozen_tensor.to(device).int()
                    snr_tensor = snr_tensor.to(device).float()
                    target_tensor = target_tensor.to(device).long()

                    logits = model(llrs, frozen_tensor, snr_tensor)
                    predicted = (logits > 0).long()

                    mask_msg = (frozen_tensor == 1)
                    mask_frozen = (frozen_tensor == 0)

                    msg_target = target_tensor[mask_msg]
                    msg_pred = predicted[mask_msg]
                    frozen_target = target_tensor[mask_frozen]
                    frozen_pred = predicted[mask_frozen]

                    total_msg_errors += (msg_target != msg_pred).sum().item()
                    total_frozen_errors += (frozen_target != frozen_pred).sum().item()
                    total_msg_bits += msg_target.numel()
                    total_frozen_bits += frozen_target.numel()

            total_bits = total_msg_bits + total_frozen_bits
            total_errors = total_msg_errors + total_frozen_errors

            # Empty evaluations (e.g. num_samples=0) report 0.0 instead of raising,
            # matching the existing convention for the msg/frozen rates.
            avg_net_ber = total_errors / total_bits if total_bits > 0 else 0.0
            avg_msg_ber = total_msg_errors / total_msg_bits if total_msg_bits > 0 else 0.0
            avg_frozen_ber = total_frozen_errors / total_frozen_bits if total_frozen_bits > 0 else 0.0

            print(f"    Net BER     : {avg_net_ber:.6e}")
            print(f"    Msg BER     : {avg_msg_ber:.6e}")
            print(f"    Frozen BER  : {avg_frozen_ber:.6e}\n")

            ber_list.append(avg_net_ber)
            eval_results[snr_key][str(msg_bits)] = {
                "average_net_bit_error_rate": avg_net_ber,
                "average_msg_bit_error_rate": avg_msg_ber,
                "average_frozen_bit_error_rate": avg_frozen_ber,
                "batch_size": batch_size,
                "num_samples": num_samples,
                "total_bits": total_bits,
                "total_error_bits": total_errors,
                "total_msg_bits": total_msg_bits,
                "total_frozen_bits": total_frozen_bits
            }

        # Mean of the per-message-size net BERs (0.0 when msg_bit_sizes is empty).
        eval_results[snr_key]["overall_ber"] = sum(ber_list) / len(ber_list) if ber_list else 0.0
        print(f"  Overall BER for SNR {snr}: {eval_results[snr_key]['overall_ber']:.6e}")

    json_file = os.path.join(config_dir, f"ber_epoch_{epoch+1}.json")
    with open(json_file, "w") as f:
        json.dump(eval_results, f, indent=4)
    print(f"\nBER results saved to {json_file}")

    return eval_results


In [8]:
# def compute_ber(preds, target):
#     pred_bits = (torch.sigmoid(preds) > 0.5).int()
#     total_bits = target.numel()
#     error_bits = (pred_bits != target.int()).sum().item()
#     return error_bits / total_bits

In [None]:
def train_one_epoch(batches=1000):
    """Run one training epoch over ``train_loader`` and return its mean loss.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader`` and ``device``. The loss is computed only on positions
    where ``frozen == 1`` (message bits). Gradients are clipped to norm 1.0.

    Args:
        batches: Logging interval. Every ``batches`` steps the mean loss of
            that window is printed; a trailing partial window is reported at
            the end of the epoch.

    Returns:
        Mean loss over *all* batches of the epoch (previously only the last
        logging window was averaged), or 0.0 if the loader yields nothing.
    """
    model.train()
    window_loss = 0.0
    window_count = 0
    epoch_loss = 0.0
    num_batches = 0

    for i, (channel, frozen, snr, target) in enumerate(train_loader):
        channel = channel.float().to(device)
        frozen = frozen.int().to(device)
        target = target.float().to(device)

        optimizer.zero_grad()
        out = model(channel, frozen, snr.float().to(device))

        # Only message positions contribute to the loss.
        mask_msg = (frozen == 1)
        loss = loss_fn(out[mask_msg], target[mask_msg])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        window_loss += batch_loss
        window_count += 1
        epoch_loss += batch_loss
        num_batches += 1

        if (i + 1) % batches == 0:
            print(f'  Batch {i+1} Average Loss: {window_loss / window_count:.6f}')
            window_loss = 0.0
            window_count = 0

    # Report the trailing partial window; decided by batch count, not by the
    # accumulated loss value.
    if window_count > 0:
        print(f'  Last Batch {num_batches} Average Loss: {window_loss / window_count:.6f}')

    return epoch_loss / num_batches if num_batches else 0.0


In [None]:
def _snapshot_state(module):
    """Return a deep copy of ``module.state_dict()``.

    ``state_dict()`` returns references to the module's live tensors. Keeping
    it directly would silently follow every later optimizer step, and loading
    it back would be a no-op. The copy freezes the weights as they are now.
    """
    return copy.deepcopy(module.state_dict())


def _train_block(epoch_indices, describe):
    """Train one epoch per index in ``epoch_indices`` and keep the best epoch.

    ``describe(epoch)`` returns the header printed before each epoch.
    Returns ``(best_loss, best_epoch, best_state_dict)``, where the best epoch
    is the one with the lowest training loss. The first epoch is always
    accepted, so a NaN loss cannot leave the snapshot as ``None``.
    """
    best_loss, best_epoch, best_state = float("inf"), 0, None
    for epoch in epoch_indices:
        print(describe(epoch))
        avg_loss = train_one_epoch(batches=1000)
        if best_state is None or avg_loss < best_loss:
            best_loss, best_epoch = avg_loss, epoch
            best_state = _snapshot_state(model)
    return best_loss, best_epoch, best_state


def _mean_val_ber(net, epoch):
    """Validate ``net`` with ``calc_save_ber`` and return the mean overall BER.

    The mean is taken over ``test_snr_list``. ``epoch`` is forwarded to
    ``calc_save_ber``, which uses it in its JSON file name. The value is a
    plain Python float, so checkpoints storing it stay loadable with
    ``torch.load(..., weights_only=True)``.
    """
    val_results = calc_save_ber(
        net,
        device,
        snr_values=test_snr_list,
        msg_bit_sizes=[8, 16, 24],
        num_samples=3200,
        batch_size=32,
        config_no=CONFIG_NO,
        epoch=epoch,
        save_dir="src/evaluation",
    )
    return float(np.mean([val_results[f"{s}_snr"]["overall_ber"] for s in test_snr_list]))


def _save_checkpoint(net, model_path, epoch_index, train_loss, val_loss, state_dict):
    """Save ``state_dict`` with ``net``'s hyper-parameters and metrics to ``model_path``.

    ``epoch_index`` is zero-based and stored one-based under ``"epoch"``.
    The parent directory is created if needed.
    """
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save({
        "comments": "Removed the snr as input entirely. Use SNR linear if needed.",
        "model_config": {
            "d_model": net.d_model,
            "num_layer_encoder": net.num_layer_encoder,
            "num_layers_bimamba_block": net.num_layers_bimamba_block,
            "seq_len": net.seq_len,
            "d_state": net.d_state,
            "d_conv": net.d_conv,
            "expand": net.expand,
        },
        "epoch": epoch_index + 1,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "state_dict": state_dict,
    }, model_path)


def train():
    """Pretrain the notebook-level ``model`` at 10 dB, then run the SNR curriculum.

    Stage 1 trains 10 epochs at 10 dB in two 5-epoch blocks. After each block:
        * the weights of the block's lowest-training-loss epoch are restored;
        * they are validated with ``calc_save_ber``;
        * they are saved as
          ``./checkpoints/config_<CONFIG_NO>/model_best_block_epoch_<e>.pt``.

    Stage 2 handles each SNR in ``snr_curriculum`` in turn:
        * the model is validated;
        * the SNR is added to the active training SNRs (if new);
        * the model is trained for 5 epochs;
        * the best epoch is restored, validated and saved as
          ``model_best_block_epoch_<e>_snr_<snr>.pt``.

    Uses notebook globals ``model``, ``device``, ``train_dataset``,
    ``train_one_epoch``, ``calc_save_ber``, ``snr_curriculum``,
    ``test_snr_list`` and ``CONFIG_NO``.

    NOTE(review): repeated curriculum SNRs (e.g. 6, 6) can overwrite each
    other's checkpoint, and all validations reuse ``ber_epoch_<e>.json``
    names, so earlier stage results may be overwritten.

    NOTE(review): no LR-scheduler ``step()`` is called here — confirm intent.
    """
    model_dir = f"./checkpoints/config_{CONFIG_NO}"
    pretrain_epochs = 10
    block_epochs = 5
    active_snrs = [10]

    # ---- Stage 1: pretraining at 10 dB ------------------------------------
    print("\nPretraining at SNR = 10 dB")
    train_dataset.snr_list = [10]

    # Every 5 epochs → restore the block's best weights, evaluate and save.
    for block_start in range(0, pretrain_epochs, block_epochs):
        best_loss, best_epoch, best_state = _train_block(
            range(block_start, block_start + block_epochs),
            lambda e: f"\nEpoch {e+1}/{pretrain_epochs}",
        )
        print(f"\n{block_epochs}-epoch block finished. Best training loss: {best_loss:.6e}")

        model.load_state_dict(best_state)
        avg_vloss = _mean_val_ber(model, best_epoch)

        model_path = f"{model_dir}/model_best_block_epoch_{best_epoch+1}.pt"
        _save_checkpoint(model, model_path, best_epoch, best_loss, avg_vloss, best_state)
        print(f"Saved checkpoint: {model_path}")

    # ---- Stage 2: SNR curriculum ------------------------------------------
    for snr in snr_curriculum:
        print(f"\nValidation BEFORE adding new SNR {snr}")
        avg_vloss = _mean_val_ber(model, 0)
        print(f"Validation BER before adding new SNR: {avg_vloss:.6e}")

        if snr not in active_snrs:
            active_snrs.append(snr)

        # NOTE(review): only effective if PolarDecDataset reads snr_list lazily
        # (e.g. in __getitem__) — TODO confirm.
        train_dataset.snr_list = active_snrs
        print(f"\nTraining with SNRs: {active_snrs}")

        best_loss, best_epoch, best_state = _train_block(
            range(block_epochs),
            lambda e: f"\nEpoch {e+1}/{block_epochs} for SNR {snr}",
        )

        model.load_state_dict(best_state)
        avg_vloss = _mean_val_ber(model, best_epoch)

        model_path = f"{model_dir}/model_best_block_epoch_{best_epoch+1}_snr_{snr}.pt"
        _save_checkpoint(model, model_path, best_epoch, best_loss, avg_vloss, best_state)
        print(f"Saved checkpoint for SNR {snr}: {model_path}")

    print("\nTraining finished.")


In [11]:
# Run the full schedule defined in train(): 10 pretraining epochs at 10 dB, then one 5-epoch stage per snr_curriculum entry.
train()


Pretraining at SNR = 10 dB

Epoch 1/10
  Batch 1000 Average Loss: 0.368104
  Batch 2000 Average Loss: 0.261773
  Batch 3000 Average Loss: 0.244201
  Batch 4000 Average Loss: 0.214342
  Batch 5000 Average Loss: 0.199647
  Batch 6000 Average Loss: 0.197878
  Last Batch 6250 Average Loss: 0.195672

Epoch 2/10
  Batch 1000 Average Loss: 0.190168
  Batch 2000 Average Loss: 0.184750
  Batch 3000 Average Loss: 0.184472
  Batch 4000 Average Loss: 0.184246
  Batch 5000 Average Loss: 0.149441
  Batch 6000 Average Loss: 0.131992
  Last Batch 6250 Average Loss: 0.134169

Epoch 3/10
  Batch 1000 Average Loss: 0.132968
  Batch 2000 Average Loss: 0.132606
  Batch 3000 Average Loss: 0.132774
  Batch 4000 Average Loss: 0.132740
  Batch 5000 Average Loss: 0.131265
  Batch 6000 Average Loss: 0.131896
  Last Batch 6250 Average Loss: 0.130979

Epoch 4/10
  Batch 1000 Average Loss: 0.131574
  Batch 2000 Average Loss: 0.132026
  Batch 3000 Average Loss: 0.132431
  Batch 4000 Average Loss: 0.133209
  Batch 5