# Evaluation for trained model to decode 32 bits polar code 

In [1]:
import sys, pathlib, pandas as pd
import json
import re, matplotlib.pyplot as plt, os

p = pathlib.Path.cwd()
while p != p.parent:
    if (p / "models").exists():
        project_root = p
        break
    p = p.parent

sys.path.insert(0, str(project_root))

In [2]:
import torch
from dataset import PolarDecDataset 
from models.wrappers.mamba_32bits import MambaPolarDecoder

from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

  from .autonotebook import tqdm as notebook_tqdm


'cuda'

In [3]:
N = 32

In [4]:
ckpt_path = "../checkpoints/config_6/model_best_block_epoch_6_snr_9.pt" 
#src/checkpoints/config_2/model_best_block_epoch_5_snr_8.pt

In [5]:
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
ckpt

{'comments': 'Removed the snr as input entirely. Use SNR linear if needed.',
 'model_config': {'d_model': 96,
  'num_layer_encoder': 1,
  'num_layers_bimamba_block': 24,
  'seq_len': 32,
  'd_state': 48,
  'd_conv': 9,
  'expand': 2},
 'epoch': 6,
 'train_loss': 0.041174855762161316,
 'val_loss': np.float64(0.09596137152777777),
 'state_dict': OrderedDict([('residual_scale',
               tensor(1.2590, device='cuda:0')),
              ('discrete_embedding.weight',
               tensor([[-1.1328e-01,  8.5577e-02,  1.2146e-01,  1.2322e-01, -1.2709e-01,
                         1.1338e-01, -1.0043e-01, -9.4671e-02, -4.9095e-02, -1.1862e-01,
                        -1.1027e-01, -1.0255e-01,  1.0762e-01,  1.1323e-01, -1.1905e-01,
                         1.1651e-01,  9.5961e-02, -4.5677e-02, -1.0715e-01, -1.0594e-01,
                         1.0771e-01,  1.1218e-01, -1.1535e-01,  1.1774e-01, -8.4351e-02,
                        -1.2078e-01, -1.0212e-01, -1.0764e-01, -1.1470e-01,  1.0785e

In [6]:
model = MambaPolarDecoder(
    d_model=ckpt['model_config']['d_model'],
    num_layer_encoder=ckpt['model_config']['num_layer_encoder'],
    num_layers_bimamba_block=ckpt['model_config']['num_layers_bimamba_block'],
    seq_len=ckpt['model_config']['seq_len'],
    d_state=ckpt['model_config']['d_state'],
    d_conv=ckpt['model_config']['d_conv'],
    expand=ckpt['model_config']['expand']
).to(device)
model

MambaPolarDecoder(
  (discrete_embedding): Embedding(2, 96)
  (linear_embedding1): Linear(in_features=1, out_features=96, bias=True)
  (input_layer): Sequential(
    (0): Linear(in_features=192, out_features=96, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=96, out_features=96, bias=True)
  )
  (encoder_layers): ModuleList(
    (0): BiMambaEncoder(
      (layers): ModuleList(
        (0-23): 24 x BiMambaBlock(
          (pre_ln_f): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mamba_f): Mamba(
            (in_proj): Linear(in_features=96, out_features=384, bias=False)
            (conv1d): Conv1d(192, 192, kernel_size=(9,), stride=(1,), padding=(8,), groups=192)
            (act): SiLU()
            (x_proj): Linear(in_features=192, out_features=102, bias=False)
            (dt_proj): Linear(in_features=6, out_features=192, bias=True)
            (out_proj): Linear(in_features=192, out_features=96, bias=False)
          )
          (post_ln_

### Metrics

In [7]:


# load model
state_dict = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

In [8]:
import json
import torch
from torch.utils.data import DataLoader

def calc_save_ber(
    model,
    device,
    msg_bit_sizes: list,
    snr_db: list,
    num_samples=32000,
    batch_size=32,
    json_file_name=None,
):
    """
    Calculate BER over a test dataset and save detailed BER stats
    for message and frozen bits.
    """

    eval_results = {}

    for each_snr_val in snr_db:
        print(f"\nEvaluating for SNR = {each_snr_val} dB\n")

        snr_key = f"{each_snr_val}_snr"
        eval_results[snr_key] = {}
        ber_list = []

        for each_msg_bit_size in msg_bit_sizes:
            print(f"  Message bit size = {each_msg_bit_size}")

            total_msg_bit_errors = 0
            total_frozen_bit_errors = 0
            total_msg_bits = 0
            total_frozen_bits = 0

            
            test_set = PolarDecDataset(
                snr_list=[each_snr_val],      
                num_samples=num_samples,
                fixed_msg_bit_size=each_msg_bit_size,
                seq_length=32,
            )

            test_loader = DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False,
            )

            model.eval()
            with torch.no_grad():
                for llrs, frozen_tensor, snr_tensor, target_tensor in test_loader:

                    llrs = llrs.to(device).float()
                    frozen_tensor = frozen_tensor.to(device).long()
                    target_tensor = target_tensor.to(device).long()

                    # forward pass
                    logits = model(llrs, frozen_tensor, snr_tensor.to(device))

                    # hard decision
                    predicted = (logits > 0).long()

                    # masks
                    mask_msg = (frozen_tensor == 1)
                    mask_frozen = (frozen_tensor == 0)

                    msg_target = target_tensor[mask_msg]
                    msg_pred = predicted[mask_msg]

                    frozen_target = target_tensor[mask_frozen]
                    frozen_pred = predicted[mask_frozen]

                    total_msg_bit_errors += (msg_target != msg_pred).sum().item()
                    total_frozen_bit_errors += (frozen_target != frozen_pred).sum().item()

                    total_msg_bits += msg_target.numel()
                    total_frozen_bits += frozen_target.numel()

            total_bits = total_msg_bits + total_frozen_bits
            total_error_bits = total_msg_bit_errors + total_frozen_bit_errors

            avg_net_ber = total_error_bits / total_bits
            avg_msg_ber = total_msg_bit_errors / total_msg_bits if total_msg_bits > 0 else 0.0
            avg_frozen_ber = (
                total_frozen_bit_errors / total_frozen_bits if total_frozen_bits > 0 else 0.0
            )

            print(f"    Net BER     : {avg_net_ber:.6e}")
            print(f"    Msg BER     : {avg_msg_ber:.6e}")
            print(f"    Frozen BER  : {avg_frozen_ber:.6e}\n")

            ber_list.append(avg_net_ber)

            eval_results[snr_key][str(each_msg_bit_size)] = {
                "average_net_bit_error_rate": avg_net_ber,
                "average_msg_bit_error_rate": avg_msg_ber,
                "average_frozen_bit_error_rate": avg_frozen_ber,
                "batch_size": batch_size,
                "num_samples": num_samples,
                "total_bits": total_bits,
                "total_error_bits": total_error_bits,
                "total_msg_bits": total_msg_bits,
                "total_frozen_bits": total_frozen_bits,
            }

        eval_results[snr_key]["overall_ber"] = sum(ber_list) / len(ber_list)

    if json_file_name:
        with open(f"{json_file_name}.json", "w") as f:
            json.dump(eval_results, f, indent=4)
        print(f"Results saved to {json_file_name}.json")

    return eval_results


In [9]:
calc_save_ber(model, device, msg_bit_sizes=[8, 16,24], snr_db=[10], json_file_name="checkrightnow")



Evaluating for SNR = 10 dB

  Message bit size = 8
    Net BER     : 2.290723e-02
    Msg BER     : 2.989453e-02
    Frozen BER  : 2.057812e-02

  Message bit size = 16
    Net BER     : 5.688672e-02
    Msg BER     : 2.601563e-03
    Frozen BER  : 1.111719e-01

  Message bit size = 24
    Net BER     : 3.386230e-02
    Msg BER     : 2.112500e-02
    Frozen BER  : 7.207422e-02

Results saved to checkrightnow.json


{'10_snr': {'8': {'average_net_bit_error_rate': 0.0229072265625,
   'average_msg_bit_error_rate': 0.02989453125,
   'average_frozen_bit_error_rate': 0.020578125,
   'batch_size': 32,
   'num_samples': 32000,
   'total_bits': 1024000,
   'total_error_bits': 23457,
   'total_msg_bits': 256000,
   'total_frozen_bits': 768000},
  '16': {'average_net_bit_error_rate': 0.05688671875,
   'average_msg_bit_error_rate': 0.0026015625,
   'average_frozen_bit_error_rate': 0.111171875,
   'batch_size': 32,
   'num_samples': 32000,
   'total_bits': 1024000,
   'total_error_bits': 58252,
   'total_msg_bits': 512000,
   'total_frozen_bits': 512000},
  '24': {'average_net_bit_error_rate': 0.0338623046875,
   'average_msg_bit_error_rate': 0.021125,
   'average_frozen_bit_error_rate': 0.07207421875,
   'batch_size': 32,
   'num_samples': 32000,
   'total_bits': 1024000,
   'total_error_bits': 34675,
   'total_msg_bits': 768000,
   'total_frozen_bits': 256000},
  'overall_ber': 0.037885416666666664}}