In [1]:
import os
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from dataset import PolarDecDataset  # your dataset class
from models.wrappers.mamba_32bits import MambaPolarDecoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sequence length handed to the dataset (`seq_length`) and the model (`seq_len`).
N = 32
# Tag embedded in the evaluation and checkpoint directory names (`config_<CONFIG_NO>`).
CONFIG_NO = 1.1

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# SNR values (dB), in descending order, that `train()` walks through one stage
# at a time. A repeated value simply yields one more training stage at that SNR.
snr_curriculum = [10, 10, 9, 9, 8, 7, 6, 6, 5, 4, 3, 0, 0, -3, -6, -10]

# SNR values (dB) every evaluation pass is run on.
test_snr_list = [10, 6, 0, -3, -6, -10]

# Dataset sizes and the mini-batch size shared by both data loaders.
num_train_samples = 100_000  # large dataset for low BER
num_test_samples = 320_000
batch_size = 32

In [4]:
# Training set: starts with the 10 dB SNR only; `train()` reassigns
# `train_dataset.snr_list` while the curriculum progresses.
train_dataset = PolarDecDataset(snr_list=[10], num_samples=num_train_samples, seq_length=N)

# Test set built over every SNR listed in `test_snr_list`.
test_dataset = PolarDecDataset(snr_list=test_snr_list, num_samples=num_test_samples, seq_length=N)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)

In [5]:
# Instantiate the decoder first, then move it to the selected device.
model = MambaPolarDecoder(
    d_model=64, num_layer_encoder=1, num_layers_bimamba_block=12,
    seq_len=N, d_state=32, d_conv=6, expand=2,
)
model = model.to(device)

In [6]:
# Adam with a fixed base learning rate and no weight decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=0)

# Plateau scheduler: halves the learning rate (never below 1e-6) once the
# monitored value has stopped decreasing for more than `patience` steps.
# NOTE(review): no cell visible here calls `scheduler.step(...)`, so the
# learning rate presumably stays at 1e-4 — confirm whether that is intended.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.5,
    patience=3,
    min_lr=1e-6,
)

In [7]:
def calculate_loss(frozen, target, pred):
    """Mean binary cross-entropy on logits over positions where `frozen` is not 1.

    Args:
        frozen: Tensor of per-position flags; positions equal to 1 are left out.
        target: Reference bits, indexed with the same mask as `pred`
            (the training loop passes it as float).
        pred: Raw logits; the sigmoid is applied inside the loss.

    Returns:
        Scalar tensor holding the mean loss over the selected positions.

    NOTE(review): `calc_save_ber` counts `frozen == 1` positions as *message*
    bits, which would mean this loss covers only the frozen positions —
    confirm the flag convention used by `PolarDecDataset`.
    """
    keep = frozen.ne(1)
    return nn.functional.binary_cross_entropy_with_logits(pred[keep], target[keep])

def compute_ber(pred, target):
    """Fraction of positions where the hard decision on `pred` differs from `target`.

    Args:
        pred: Tensor of logits; a position decodes to 1 when its sigmoid exceeds 0.5.
        target: Tensor of reference bits, compared after casting to int.

    Returns:
        Python float: number of mismatching positions divided by `target.numel()`.
    """
    hard_decisions = torch.sigmoid(pred).gt(0.5).int()
    mismatches = hard_decisions.ne(target.int()).sum().item()
    return mismatches / target.numel()

In [12]:
def calc_save_ber(
    model,
    device,
    snr_values,
    msg_bit_sizes=[8,16,24],
    seq_length=32,
    num_samples=3200,
    batch_size=32,
    config_no=0,
    epoch=0,
    save_dir="src/evaluation",
):
    """Measure bit error rates per SNR and message size and dump them to JSON.

    For every SNR in `snr_values` and every size in `msg_bit_sizes` a fresh
    `PolarDecDataset` is built, the model's logits are thresholded at 0, and
    errors are counted separately for positions where the frozen tensor is 1
    (reported as "msg") and where it is 0 (reported as "frozen").

    NOTE(review): `calculate_loss` masks with `frozen != 1`, i.e. it trains on
    the positions this function reports as "frozen" — confirm which flag
    convention `PolarDecDataset` actually uses.

    Args:
        model: Module called as `model(llrs, frozen, snr)` and returning logits.
        device: Device the batches are moved to.
        snr_values: Iterable of SNR values; each gets a `"<snr>_snr"` entry.
        msg_bit_sizes: Message sizes passed as `fixed_msg_bit_size`. The
            default list is shared between calls but never mutated here.
        seq_length: Sequence length passed to the dataset.
        num_samples: Number of samples generated per (SNR, message size) pair.
        batch_size: Evaluation batch size.
        config_no: Tag used for the `config_<config_no>` output sub-directory.
        epoch: Zero-based epoch index; the file name uses `epoch + 1`.
        save_dir: Root directory for the JSON reports.

    Returns:
        dict keyed by `"<snr>_snr"`; each value maps `str(msg_size)` to its
        metrics and `"overall_ber"` to the mean net BER over the message sizes
        (0.0 when `msg_bit_sizes` is empty).
    """
    import json  # not part of the notebook's top-level import cell

    config_dir = os.path.join(save_dir, f"config_{config_no}")
    # Also creates `save_dir` itself when it is missing.
    os.makedirs(config_dir, exist_ok=True)

    eval_results = {}
    # Remember the current mode so evaluation does not silently leave the
    # model in eval mode for whatever the caller runs next.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for each_snr in snr_values:
                print(f"\nEvaluating for SNR = {each_snr} dB\n")
                snr_key = f"{each_snr}_snr"
                eval_results[snr_key] = {}
                ber_list = []

                for msg_size in msg_bit_sizes:
                    total_msg_bit_errors = 0
                    total_frozen_bit_errors = 0
                    total_msg_bits = 0
                    total_frozen_bits = 0

                    test_set = PolarDecDataset(
                        snr_list=[each_snr],
                        num_samples=num_samples,
                        fixed_msg_bit_size=msg_size,
                        seq_length=seq_length
                    )

                    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

                    for llrs, frozen_tensor, snr_tensor, target_tensor in test_loader:
                        llrs = llrs.to(device).float()
                        frozen_tensor = frozen_tensor.to(device).long()
                        target_tensor = target_tensor.to(device).long()

                        logits = model(llrs, frozen_tensor, snr_tensor.to(device))
                        # A logit above 0 corresponds to a sigmoid above 0.5.
                        predicted = (logits > 0).long()

                        mask_msg = (frozen_tensor == 1)
                        mask_frozen = (frozen_tensor == 0)

                        msg_target = target_tensor[mask_msg]
                        msg_pred = predicted[mask_msg]

                        frozen_target = target_tensor[mask_frozen]
                        frozen_pred = predicted[mask_frozen]

                        total_msg_bit_errors += (msg_target != msg_pred).sum().item()
                        total_frozen_bit_errors += (frozen_target != frozen_pred).sum().item()

                        total_msg_bits += msg_target.numel()
                        total_frozen_bits += frozen_target.numel()

                    total_bits = total_msg_bits + total_frozen_bits
                    total_error_bits = total_msg_bit_errors + total_frozen_bit_errors

                    # All three rates fall back to 0.0 when nothing was counted
                    # (e.g. `num_samples == 0`) instead of dividing by zero.
                    avg_net_ber = total_error_bits / total_bits if total_bits > 0 else 0.0
                    avg_msg_ber = total_msg_bit_errors / total_msg_bits if total_msg_bits > 0 else 0.0
                    avg_frozen_ber = total_frozen_bit_errors / total_frozen_bits if total_frozen_bits > 0 else 0.0

                    print(f"    Net BER     : {avg_net_ber:.6e}")
                    print(f"    Msg BER     : {avg_msg_ber:.6e}")
                    print(f"    Frozen BER  : {avg_frozen_ber:.6e}\n")

                    eval_results[snr_key][str(msg_size)] = {
                        "average_net_bit_error_rate": avg_net_ber,
                        "average_msg_bit_error_rate": avg_msg_ber,
                        "average_frozen_bit_error_rate": avg_frozen_ber,
                        "batch_size": batch_size,
                        "num_samples": num_samples,
                        "total_bits": total_bits,
                        "total_error_bits": total_error_bits,
                        "total_msg_bits": total_msg_bits,
                        "total_frozen_bits": total_frozen_bits,
                    }

                    ber_list.append(avg_net_ber)

                # Unweighted mean over message sizes; 0.0 when none were requested.
                overall_ber = sum(ber_list) / len(ber_list) if ber_list else 0.0
                eval_results[snr_key]["overall_ber"] = overall_ber
                print(f"  Overall BER for SNR {each_snr}: {overall_ber:.6e}")
    finally:
        model.train(was_training)

    json_file_name = os.path.join(config_dir, f"epoch_{epoch+1}_snr_{'_'.join(map(str, snr_values))}.json")
    with open(json_file_name, "w") as f:
        json.dump(eval_results, f, indent=4)
    print(f"\nBER results saved to {json_file_name}")

    return eval_results


In [8]:
# def compute_ber(preds, target):
#     pred_bits = (torch.sigmoid(preds) > 0.5).int()
#     total_bits = target.numel()
#     error_bits = (pred_bits != target.int()).sum().item()
#     return error_bits / total_bits

In [None]:
def train_one_epoch(batches=1000, print_every=None):
    """Run one pass over `train_loader` and return the latest windowed mean loss.

    Uses the notebook-level `model`, `train_loader`, `optimizer`, `device` and
    `calculate_loss`. Gradients are clipped to a total norm of 1.0 before every
    optimizer step.

    Args:
        batches: Number of batches per reporting window; the mean loss of each
            full window is printed.
        print_every: Keyword alias for `batches` (the name `train()` uses).
            When given, it takes precedence.

    Returns:
        Mean loss of the last *complete* window. If the epoch is shorter than
        one window, the mean over all batches seen; 0.0 if the loader is empty.

    Raises:
        ValueError: If the window size is not positive.
    """
    if print_every is not None:
        batches = print_every
    if batches <= 0:
        raise ValueError("batches must be a positive integer")

    model.train()
    running_loss = 0.0
    last_loss = 0.0
    pending = 0          # batches accumulated since the last report
    reported = False     # whether at least one full window was completed

    for i, batch in enumerate(train_loader):
        channel, frozen, snr, target = batch
        channel = channel.float().to(device)
        frozen  = frozen.int().to(device)
        target  = target.float().to(device)

        optimizer.zero_grad()
        out = model(channel, frozen, snr.float().to(device))
        loss = calculate_loss(frozen, target, out)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        pending += 1

        if (i + 1) % batches == 0:
            last_loss = running_loss / batches
            print(f'  Batch {i+1} Average Loss: {last_loss:.6f}')
            running_loss = 0.0
            pending = 0
            reported = True

    # Without this, a short epoch would report a loss of 0.0, which callers
    # tracking the "best" loss would mistake for a perfect epoch.
    if not reported and pending > 0:
        last_loss = running_loss / pending

    return last_loss



In [None]:
def _snapshot_weights():
    """Return an independent copy of the current `model.state_dict()`."""
    import copy  # local import: `copy` is not in the notebook's import cell

    return copy.deepcopy(model.state_dict())


def _evaluate_on_test_snrs(epoch):
    """Run `calc_save_ber` on `test_snr_list`; return the mean of the per-SNR overall BERs."""
    val_results = calc_save_ber(
        model,
        device,
        snr_values=test_snr_list,
        msg_bit_sizes=[8,16,24],
        num_samples=3200,
        batch_size=32,
        config_no=CONFIG_NO,
        epoch=epoch,
        save_dir="src/evaluation"
    )
    return np.mean([val_results[f"{s}_snr"]["overall_ber"] for s in test_snr_list])


def _checkpoint_best_of_block(best_weights, best_epoch, best_loss, file_name):
    """Evaluate and save the block's best weights, then resume from the end-of-block weights.

    `best_weights` may be None (no epoch ever improved on the initial loss, e.g.
    every loss was NaN); the end-of-block weights are saved in that case.
    Returns the path of the written checkpoint.
    """
    end_of_block_weights = _snapshot_weights()
    if best_weights is None:
        best_weights = end_of_block_weights

    # Swap the best weights in so the BER report, the checkpoint metadata and
    # the saved tensors all describe the same epoch. `load_state_dict` copies
    # into the existing parameters, so optimizer state stays attached.
    model.load_state_dict(best_weights)
    try:
        avg_vloss = _evaluate_on_test_snrs(best_epoch)

        model_dir = f"./checkpoints/config_{CONFIG_NO}"
        os.makedirs(model_dir, exist_ok=True)
        model_path = f"{model_dir}/{file_name}"

        torch.save({
            "comments": "Removed the snr as input entirely. (even if used in future, use as snr linear, not in db)",
            "model_config": {
                "d_model": model.d_model,
                "num_layer_encoder": model.num_layer_encoder,
                "num_layers_bimamba_block": model.num_layers_bimamba_block,
                "seq_len": model.seq_len,
                "d_state": model.d_state,
                "d_conv": model.d_conv,
                "expand": model.expand,
            },
            "epoch": best_epoch + 1,
            "train_loss": best_loss,
            "val_loss": avg_vloss,
            "state_dict": model.state_dict()
        }, model_path)
    finally:
        # Training continues from where the block ended, as before.
        model.load_state_dict(end_of_block_weights)
    return model_path


def train():
    """Pretrain at 10 dB, then walk the SNR curriculum, checkpointing every 5-epoch block.

    Stage 1 runs 10 epochs on `snr_list = [10]` in two blocks of 5 epochs.
    Stage 2 iterates over `snr_curriculum`: it evaluates the current model,
    adds the stage's SNR to the active list, and trains 5 more epochs.

    After each block, the weights of the epoch with the lowest training loss
    (as returned by `train_one_epoch`) are evaluated on `test_snr_list` and
    written to `./checkpoints/config_<CONFIG_NO>/`. Training itself always
    continues from the end-of-block weights.

    Uses the notebook-level `model`, `train_dataset`, `snr_curriculum`,
    `test_snr_list`, `CONFIG_NO`, `train_one_epoch` and `calc_save_ber`.

    NOTE(review): the notebook-level `scheduler` is never stepped here, so the
    learning rate stays at its initial value — confirm whether that is intended.
    NOTE(review): whether reassigning `train_dataset.snr_list` changes the
    samples served depends on `PolarDecDataset` — verify it reads the attribute
    lazily.
    NOTE(review): `snr_curriculum` repeats some SNRs, so two stages can produce
    the same checkpoint/JSON file name and overwrite each other.
    """
    active_snrs = [10]

    # --- Pretraining at 10 dB ---
    print("\nSNR = 10 dB")
    train_dataset.snr_list = [10]

    for block_start in range(0, 10, 5):
        best_loss_in_block = float('inf')
        best_epoch_in_block = block_start
        best_weights = None

        for epoch in range(block_start, block_start + 5):
            print(f"Epoch {epoch+1}/10")
            avg_loss = train_one_epoch(1000)

            # The first epoch of a block always seeds the "best" record.
            if epoch == block_start or avg_loss < best_loss_in_block:
                best_loss_in_block = avg_loss
                best_epoch_in_block = epoch
                best_weights = _snapshot_weights()

        print(f"\nSaving model after 5 epochs (best training loss in block: {best_loss_in_block:.6e})")
        model_path = _checkpoint_best_of_block(
            best_weights,
            best_epoch_in_block,
            best_loss_in_block,
            f"model_best_block_epoch_{best_epoch_in_block+1}.pt",
        )
        print(f"Saved model checkpoint for best loss in block: {model_path}")

    # --- Curriculum Training ---
    for snr in snr_curriculum:
        print(f"\nValidation BEFORE adding new SNR {snr}")
        avg_vloss = _evaluate_on_test_snrs(0)
        print(f"Validation BER before adding new SNR: {avg_vloss:.6e}")

        if snr not in active_snrs:
            active_snrs.append(snr)
        train_dataset.snr_list = active_snrs
        print(f"\nTraining with SNRs: {active_snrs}")

        # 5-epoch block for this SNR
        best_loss_in_block = float('inf')
        best_epoch_in_block = 0
        best_weights = None
        for epoch in range(5):
            print(f"Epoch {epoch+1}/5 for SNR {snr}")
            avg_loss = train_one_epoch(1000)

            # Track best training loss together with the weights that produced it.
            if avg_loss < best_loss_in_block:
                best_loss_in_block = avg_loss
                best_epoch_in_block = epoch
                best_weights = _snapshot_weights()

        print(f"\nSaving model for SNR {snr} with best training loss in block: {best_loss_in_block:.6e}")
        model_path = _checkpoint_best_of_block(
            best_weights,
            best_epoch_in_block,
            best_loss_in_block,
            f"model_best_block_epoch_{best_epoch_in_block+1}_snr_{snr}.pt",
        )
        print(f"Saved model checkpoint for best loss in block for SNR {snr}: {model_path}")

    print("\nTraining finished.")


In [11]:
train()


 SNR = 10 dB
Epoch 1/10
Batch 1000/3125, Loss: 0.000012
Batch 2000/3125, Loss: 0.000007
Batch 3000/3125, Loss: 0.000005


KeyboardInterrupt: 