In [1]:
import os
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import json
from dataset import PolarDecDataset 
from models.wrappers.mamba_32bits import MambaPolarDecoder
import copy
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

N = 32  
CONFIG_NO = 7

snr_curriculum = [10,10,9,8,7,6,6,5,4,3,2,1,0]  

test_snr_list = [10, 5, 0]

num_train_samples = 200000   
num_test_samples  = 4000 
batch_size = 32

train_dataset = PolarDecDataset(
    snr_list=[10],
    num_samples=num_train_samples,
    seq_length=N
)

test_dataset = PolarDecDataset(
    snr_list=test_snr_list,
    num_samples=num_test_samples,
    seq_length=N
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

model = MambaPolarDecoder(
    d_model=64,
    num_layer_encoder=1,
    num_layers_bimamba_block=24,
    seq_len=N,
    d_state=32,
    d_conv=6,
    expand=2,
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-2
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=50,
    eta_min=1e-5
)

loss_fn = nn.BCEWithLogitsLoss()


def calc_save_ber(
    model,
    device,
    snr_values,
    msg_bit_sizes,
    num_samples=3200,
    batch_size=32,
    config_no=0,
    epoch=0,
    save_dir="src/evaluation"
):
    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    config_dir = os.path.join(save_dir, f"config_{config_no}")
    os.makedirs(config_dir, exist_ok=True)

    eval_results = {}

    for snr in snr_values:
        snr_key = f"{snr}_snr"
        eval_results[snr_key] = {}
        ber_list = []

        for msg_bits in msg_bit_sizes:
            total_msg_errors = 0
            total_frozen_errors = 0
            total_msg_bits = 0
            total_frozen_bits = 0

            test_set = PolarDecDataset(
                snr_list=[snr],
                num_samples=num_samples,
                fixed_msg_bit_size=msg_bits,
                seq_length=N
            )
            loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

            with torch.no_grad():
                for llrs, frozen, snr_t, target in loader:
                    llrs = llrs.to(device).float()
                    frozen = frozen.to(device).long()
                    target = target.to(device).long()

                    # -------- FIX: SNR must be float --------
                    logits = model(llrs, frozen, snr_t.float().to(device))

                    predicted = (logits > 0).long()

                    mask_msg = frozen == 1
                    mask_frozen = frozen == 0

                    total_msg_errors += (target[mask_msg] != predicted[mask_msg]).sum().item()
                    total_frozen_errors += (target[mask_frozen] != predicted[mask_frozen]).sum().item()
                    total_msg_bits += mask_msg.sum().item()
                    total_frozen_bits += mask_frozen.sum().item()

            total_bits = total_msg_bits + total_frozen_bits
            total_errors = total_msg_errors + total_frozen_errors
            avg_net_ber = total_errors / total_bits

            ber_list.append(avg_net_ber)
            eval_results[snr_key][str(msg_bits)] = {
                "average_net_bit_error_rate": avg_net_ber
            }

        eval_results[snr_key]["overall_ber"] = sum(ber_list) / len(ber_list)

    json_file = os.path.join(config_dir, f"ber_epoch_{epoch+1}.json")
    with open(json_file, "w") as f:
        json.dump(eval_results, f, indent=4)

    return eval_results


def train_one_epoch(batches=1000):
    model.train()
    running_loss = 0.0
    last_loss = 0.0

    for i, batch in enumerate(train_loader):
        channel, frozen, snr, target = batch

        channel = channel.float().to(device)
        frozen  = frozen.long().to(device)
        target  = target.float().to(device)
        snr     = snr.float().to(device)

        optimizer.zero_grad()

        out1 = model(channel, frozen, snr)
        refined_channel = channel + out1.detach()
        out = model(refined_channel, frozen, snr)

        mask_msg = (frozen == 1)
        mask_frozen = (frozen == 0)

        msg_pred = out[mask_msg]
        msg_target = target[mask_msg]

        frozen_pred = out[mask_frozen]
        frozen_target = target[mask_frozen]

        loss_msg = loss_fn(msg_pred, msg_target)
        loss_frozen = loss_fn(frozen_pred, frozen_target)

        loss = 0.8 * loss_msg + 0.2 * loss_frozen

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()

        if (i + 1) % batches == 0:
            last_loss = running_loss / batches
            print(f"  Batch {i+1} Average Loss: {last_loss:.6f}")
            running_loss = 0.0

    return last_loss


def train():
    active_snrs = [10]

    for snr in snr_curriculum:
        print(f"\nTraining with new SNR {snr}")

        if snr not in active_snrs:
            active_snrs.append(snr)

        train_dataset.snr_list = active_snrs
        print("Active SNRs:", active_snrs)

        best_loss_in_block = float("inf")
        best_epoch_in_block = 0
        best_state_dict = None

        for epoch in range(6):
            avg_loss = train_one_epoch(batches=1000)

            scheduler.step()  

            print(f"SNR {snr} | Epoch {epoch+1}/6 | Loss {avg_loss:.6f}")

            if avg_loss < best_loss_in_block:
                best_loss_in_block = avg_loss
                best_epoch_in_block = epoch
                best_state_dict = copy.deepcopy(model.state_dict())

        model.load_state_dict(best_state_dict)

        val_results = calc_save_ber(
            model,
            device,
            snr_values=test_snr_list,
            msg_bit_sizes=[8,16,24],
            num_samples=3200,
            batch_size=32,
            config_no=CONFIG_NO,
            epoch=best_epoch_in_block,
            save_dir="src/evaluation"
        )

        avg_vloss = np.mean([val_results[f"{s}_snr"]["overall_ber"] for s in test_snr_list])

        

        model_dir = f"./checkpoints/config_{CONFIG_NO}"
        os.makedirs(model_dir, exist_ok=True)
        model_path = f"{model_dir}/model_best_block_epoch_{best_epoch_in_block+1}_snr_{snr}.pt"

        torch.save({
            "comments": "Removed the snr as input entirely. Use SNR linear if needed.",
            "model_config": {
                "d_model": model.d_model,
                "num_layer_encoder": model.num_layer_encoder,
                "num_layers_bimamba_block": model.num_layers_bimamba_block,
                "seq_len": model.seq_len,
                "d_state": model.d_state,
                "d_conv": model.d_conv,
                "expand": model.expand,
            },
            "epoch": best_epoch_in_block + 1,
            "train_loss": best_loss_in_block,
            "val_loss": avg_vloss,
            "state_dict": best_state_dict
        }, model_path)

        print(f"Saved checkpoint: {model_path}")

    print("\nTraining finished.")


Using device: cuda


In [3]:
train()


Training with new SNR 10
Active SNRs: [10]
  Batch 1000 Average Loss: 0.364664
  Batch 2000 Average Loss: 0.326791
  Batch 3000 Average Loss: 0.318122
  Batch 4000 Average Loss: 0.305829
  Batch 5000 Average Loss: 0.290144
  Batch 6000 Average Loss: 0.286862
SNR 10 | Epoch 1/6 | Loss 0.286862
  Batch 1000 Average Loss: 0.284535
  Batch 2000 Average Loss: 0.283934
  Batch 3000 Average Loss: 0.272894
  Batch 4000 Average Loss: 0.261047
  Batch 5000 Average Loss: 0.260239
  Batch 6000 Average Loss: 0.262642
SNR 10 | Epoch 2/6 | Loss 0.262642
  Batch 1000 Average Loss: 0.263161
  Batch 2000 Average Loss: 0.263615
  Batch 3000 Average Loss: 0.265575
  Batch 4000 Average Loss: 0.264108
  Batch 5000 Average Loss: 0.264196
  Batch 6000 Average Loss: 0.262312
SNR 10 | Epoch 3/6 | Loss 0.262312
  Batch 1000 Average Loss: 0.264629
  Batch 2000 Average Loss: 0.268814
  Batch 3000 Average Loss: 0.267710
  Batch 4000 Average Loss: 0.268085
  Batch 5000 Average Loss: 0.265145
  Batch 6000 Average Lo

KeyboardInterrupt: 