In [None]:
import copy
import json
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import PolarDecDataset
from models.wrappers.mamba_32bits import MambaPolarDecoder

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Experiment id: selects ./checkpoints/config_<CONFIG_NO>/ and the
# src/evaluation/config_<CONFIG_NO>/ output directories.
CONFIG_NO = 6
# Code length: passed as seq_length to the datasets and seq_len to the model.
N = 32

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch incl. CUDA)
# so that Restart-&-Run-All reproduces the same synthetic data and training run.
# NOTE(review): PolarDecDataset's RNG source is not visible here; if it uses a
# private generator, seed that one as well.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# SNR (dB) schedule for curriculum fine-tuning. Each entry starts one training
# block in train() and adds that SNR to the active training mix; a repeated
# value simply gives that SNR an extra block.
snr_curriculum = [10, 9, 9, 8, 7, 6, 6, 5, 4, 3, 2, 1, 0, 0]

# SNRs (dB) evaluated after every curriculum block.
test_snr_list = [10, 5, 0]

num_train_samples = 200_000
num_test_samples = 4_000
batch_size = 32

In [None]:
# The training set starts at SNR 10 dB only; train() replaces `snr_list` as the
# curriculum progresses. Construction order (train, then test) is kept as-is in
# case dataset construction draws from a shared RNG.
train_dataset = PolarDecDataset(
    seq_length=N,
    num_samples=num_train_samples,
    snr_list=[10],
)
test_dataset = PolarDecDataset(
    seq_length=N,
    num_samples=num_test_samples,
    snr_list=test_snr_list,
)

# NOTE(review): test_dataset / test_loader are not used by train(); validation
# goes through calc_save_ber, which builds its own per-SNR datasets.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Decoder hyperparameters. They must describe the same architecture as the
# checkpoint loaded next, because load_state_dict is strict by default.
mamba_hparams = dict(
    d_model=96,
    num_layer_encoder=1,
    num_layers_bimamba_block=24,
    seq_len=N,
    d_state=48,
    d_conv=9,
    expand=2,
)
model = MambaPolarDecoder(**mamba_hparams).to(device)


In [None]:
# Warm-start weights. The filename follows train()'s checkpoint naming scheme
# (best epoch 5 of the SNR-10 block under config 6).
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, so only
# load checkpoints from a trusted source. It is presumably needed because train()
# stores a NumPy scalar (np.mean result) under "val_loss".
ckpt_path = "./checkpoints/config_6/model_best_block_epoch_5_snr_10.pt"
ckpt = torch.load(
    ckpt_path,
    map_location=device,
    weights_only=False,
)

pretrained_weights = ckpt["state_dict"]
model.load_state_dict(pretrained_weights)
print("Loaded pretrained model from:", ckpt_path)

In [None]:
# BCE on raw logits; train_one_epoch applies it to message positions only.
loss_fn = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Halve the LR (floor 1e-6) once more than 3 consecutive scheduler steps pass
# without improvement. train() steps this once per curriculum block with the
# mean validation BER, so "min" mode is the right direction.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3,
    min_lr=1e-6,
)


In [None]:
def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _count_bit_errors(model, loader, device):
    """Decode every batch in `loader` with `model` and tally hard-decision errors.

    Positions where ``frozen == 1`` are counted as message bits; positions where
    ``frozen == 0`` are counted as frozen bits.

    Returns:
        (msg_errors, frozen_errors, msg_bits, frozen_bits) as Python ints.
    """
    msg_errors = frozen_errors = msg_bits = frozen_bits = 0
    with torch.no_grad():
        for llrs, frozen, snr_t, target in loader:
            llrs = llrs.to(device).float()
            frozen = frozen.to(device).long()
            target = target.to(device).long()

            # Feed SNR as float, the same dtype train_one_epoch uses.
            logits = model(llrs, frozen, snr_t.float().to(device))
            # Hard decision: positive logit (sigmoid > 0.5) -> bit 1.
            predicted = (logits > 0).long()

            mask_msg = frozen == 1
            mask_frozen = frozen == 0

            msg_errors += (target[mask_msg] != predicted[mask_msg]).sum().item()
            frozen_errors += (target[mask_frozen] != predicted[mask_frozen]).sum().item()
            msg_bits += mask_msg.sum().item()
            frozen_bits += mask_frozen.sum().item()
    return msg_errors, frozen_errors, msg_bits, frozen_bits


def calc_save_ber(
    model,
    device,
    snr_values,
    msg_bit_sizes,
    num_samples=3200,
    batch_size=32,
    config_no=0,
    epoch=0,
    save_dir="src/evaluation",
    seq_length=None,
):
    """Measure bit error rate per SNR and message size, and save it as JSON.

    For each SNR in `snr_values` and each size in `msg_bit_sizes`, a fresh
    PolarDecDataset of `num_samples` frames is generated and decoded.

    Args:
        model: decoder called as ``model(llrs, frozen, snr)`` returning logits.
        device: torch device the batches are moved to.
        snr_values: SNRs (dB) to evaluate.
        msg_bit_sizes: message lengths passed as ``fixed_msg_bit_size``.
        num_samples: frames generated per (SNR, message size) pair.
        batch_size: evaluation batch size.
        config_no: selects the ``config_<config_no>`` output sub-directory.
        epoch: zero-based epoch; the file is named ``ber_epoch_<epoch+1>.json``.
        save_dir: root output directory (created if missing).
        seq_length: code length for PolarDecDataset; defaults to the
            notebook-level ``N``.

    Returns:
        Dict keyed by ``"<snr>_snr"``. Each value maps ``str(msg_bits)`` to a dict
        with "average_net_bit_error_rate" (errors over all positions),
        "message_bit_error_rate" (frozen == 1 positions only) and
        "frozen_bit_error_rate" (frozen == 0 positions only). It also holds
        "overall_ber", the unweighted mean of the net BERs across message sizes.
        A rate is NaN when there is nothing to average over.

    Side effects:
        Writes ``<save_dir>/config_<config_no>/ber_epoch_<epoch+1>.json``,
        overwriting any existing file with that name. The model's train/eval
        mode is restored before returning.
    """
    if seq_length is None:
        seq_length = N

    config_dir = os.path.join(save_dir, f"config_{config_no}")
    os.makedirs(config_dir, exist_ok=True)  # also creates save_dir itself

    was_training = model.training
    model.eval()
    eval_results = {}
    try:
        for snr in snr_values:
            print(f"\nEvaluating for SNR = {snr} dB\n")
            snr_key = f"{snr}_snr"
            eval_results[snr_key] = {}
            net_bers = []

            for msg_bits in msg_bit_sizes:
                test_set = PolarDecDataset(
                    snr_list=[snr],
                    num_samples=num_samples,
                    fixed_msg_bit_size=msg_bits,
                    seq_length=seq_length,
                )
                loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
                msg_err, frz_err, msg_tot, frz_tot = _count_bit_errors(model, loader, device)

                net_ber = _ratio(msg_err + frz_err, msg_tot + frz_tot)
                net_bers.append(net_ber)
                eval_results[snr_key][str(msg_bits)] = {
                    "average_net_bit_error_rate": net_ber,
                    "message_bit_error_rate": _ratio(msg_err, msg_tot),
                    "frozen_bit_error_rate": _ratio(frz_err, frz_tot),
                }

            eval_results[snr_key]["overall_ber"] = _ratio(sum(net_bers), len(net_bers))
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    json_file = os.path.join(config_dir, f"ber_epoch_{epoch+1}.json")
    with open(json_file, "w") as f:
        json.dump(eval_results, f, indent=4)

    return eval_results


In [None]:
# def compute_ber(preds, target):
#     pred_bits = (torch.sigmoid(preds) > 0.5).int()
#     total_bits = target.numel()
#     error_bits = (pred_bits != target.int()).sum().item()
#     return error_bits / total_bits

In [None]:
def train_one_epoch(batches=1000, max_grad_norm=0.8):
    """Run up to `batches` optimisation steps over `train_loader` and return the mean loss.

    Uses the notebook-level `model`, `train_loader`, `optimizer`, `loss_fn` and
    `device`. The loss is computed only on positions where ``frozen == 1``.
    Gradients are clipped to total L2 norm `max_grad_norm` before each step.

    Args:
        batches: maximum number of mini-batches to process. The pass stops early
            once this many have been seen, even if the loader is longer.
        max_grad_norm: clipping threshold passed to ``clip_grad_norm_``.

    Returns:
        Mean per-batch loss over the batches actually processed.

    Raises:
        ValueError: if ``batches < 1`` or the loader yields no batches.
    """
    if batches < 1:
        raise ValueError(f"batches must be >= 1, got {batches}")

    model.train()
    running_loss = 0.0
    num_batches = 0

    for channel, frozen, snr, target in train_loader:
        channel = channel.float().to(device)
        frozen = frozen.int().to(device)
        target = target.float().to(device)

        optimizer.zero_grad()
        out = model(channel, frozen, snr.float().to(device))

        # Only message positions (frozen == 1) contribute to the loss.
        mask_msg = frozen == 1
        loss = loss_fn(out[mask_msg], target[mask_msg])

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        if num_batches == batches:
            break

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return running_loss / num_batches


In [None]:
def train(epochs_per_block=6, batches_per_epoch=1000):
    """Fine-tune `model` over `snr_curriculum`, checkpointing after each block.

    For every entry of `snr_curriculum` (a "block"):
      1. add its SNR to the active set and assign that set to
         ``train_dataset.snr_list``;
      2. call train_one_epoch `epochs_per_block` times (each capped at
         `batches_per_epoch` mini-batches), keeping a copy of the weights from
         the epoch with the lowest training loss;
      3. reload those weights, evaluate with calc_save_ber on `test_snr_list`
         (message sizes 8/16/24, 3200 frames each), and step the LR scheduler
         with the mean of the per-SNR "overall_ber" values;
      4. save a checkpoint under ./checkpoints/config_<CONFIG_NO>/.

    Args:
        epochs_per_block: training epochs per curriculum block.
        batches_per_epoch: mini-batches per epoch, forwarded to train_one_epoch.

    Raises:
        RuntimeError: if no epoch in a block yields a loss below +inf (e.g. every
            loss was NaN, or ``epochs_per_block == 0``), leaving no weights to restore.
    """
    # Start from SNR 10 only, matching the initial train_dataset.snr_list and
    # the snr_10 warm-start checkpoint. Each block can only grow this set.
    active_snrs = [10]
    model_dir = f"./checkpoints/config_{CONFIG_NO}"

    for snr in snr_curriculum:
        print(f"\nTraining with new SNR {snr}")

        if snr not in active_snrs:
            active_snrs.append(snr)

        train_dataset.snr_list = active_snrs
        print("Active SNRs:", active_snrs)

        best_loss_in_block = float("inf")
        best_epoch_in_block = 0
        best_state_dict = None

        for epoch in range(epochs_per_block):
            avg_loss = train_one_epoch(batches=batches_per_epoch)
            print(f"SNR {snr} | Epoch {epoch+1}/{epochs_per_block} | Loss {avg_loss:.6f}")

            # NaN never compares less-than, so a diverged epoch is never "best".
            if avg_loss < best_loss_in_block:
                best_loss_in_block = avg_loss
                best_epoch_in_block = epoch
                best_state_dict = copy.deepcopy(model.state_dict())

        if best_state_dict is None:
            raise RuntimeError(
                f"No usable epoch in the SNR {snr} block (losses were NaN/inf or "
                f"epochs_per_block={epochs_per_block}); no weights to restore."
            )
        model.load_state_dict(best_state_dict)

        # NOTE(review): calc_save_ber names its JSON only by epoch, so blocks whose
        # best epoch coincides overwrite each other's ber_epoch_<k>.json.
        val_results = calc_save_ber(
            model,
            device,
            snr_values=test_snr_list,
            msg_bit_sizes=[8, 16, 24],
            num_samples=3200,
            batch_size=32,
            config_no=CONFIG_NO,
            epoch=best_epoch_in_block,
            save_dir="src/evaluation"
        )

        # Mean BER across the test SNRs. It is stored under "val_loss" below but
        # is a bit error rate, not a loss value.
        avg_vloss = np.mean(
            [val_results[f"{s}_snr"]["overall_ber"] for s in test_snr_list]
        )

        scheduler.step(avg_vloss)

        os.makedirs(model_dir, exist_ok=True)
        model_path = f"{model_dir}/model_best_block_epoch_{best_epoch_in_block+1}_snr_{snr}.pt"

        torch.save({
            "comments": "Removed the snr as input entirely. Use SNR linear if needed.",
            "model_config": {
                "d_model": model.d_model,
                "num_layer_encoder": model.num_layer_encoder,
                "num_layers_bimamba_block": model.num_layers_bimamba_block,
                "seq_len": model.seq_len,
                "d_state": model.d_state,
                "d_conv": model.d_conv,
                "expand": model.expand,
            },
            "epoch": best_epoch_in_block + 1,
            "train_loss": best_loss_in_block,
            "val_loss": avg_vloss,
            "state_dict": best_state_dict
        }, model_path)

        print(f"Saved checkpoint: {model_path}")

    print("\nTraining finished.")


In [None]:
# Run the full SNR curriculum. After every block this writes a BER JSON under
# src/evaluation/config_<CONFIG_NO>/ and a checkpoint under
# ./checkpoints/config_<CONFIG_NO>/.
train()