In [None]:
from dataset import PolarDecDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
from models.wrappers.mamba_32bits import MambaPolarDecoder

In [None]:
# Sequence length handed to both datasets (seq_length) and to the model (seq_len).
N = 32
# Run identifier: checkpoints from this notebook are written to ./checkpoints/config_<CONFIG_NO>.
CONFIG_NO = 26

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression so the notebook displays the selected device.
device

### Dataset

In [None]:
# Training and evaluation sets are built with the same SNR setting (snr_db=-10) and
# sequence length N; only the number of samples differs (100000 vs 3200).
dataset = PolarDecDataset(snr_db=-10, num_samples=100000, seq_length=N)
test_set = PolarDecDataset(snr_db=-10, num_samples=3200, seq_length=N)
# NOTE(review): num_samples presumably counts polar blocks (one sample per block) — confirm against PolarDecDataset.

In [None]:
# Both loaders yield batches of 32 samples; every other DataLoader option is left at its default.
# NOTE(review): shuffle is not enabled for the training loader — confirm that a fixed sample order is intended.
train_dataloader = DataLoader(dataset, batch_size = 32)
test_dataloader = DataLoader(test_set, batch_size = 32)

## Model

In [None]:
# Build the decoder and move it to the selected device.
# The weights of a saved checkpoint are loaded into this model in the next cell, so these
# hyperparameters have to be compatible with that checkpoint's state_dict.
model = MambaPolarDecoder(
    d_model=32,
    num_layer_encoder=1,
    num_layers_bimamba_block=10,
    seq_len=N,  # same N that is used as seq_length for the datasets
    d_state=32,
    d_conv=6,
    expand=2
).to(device)

In [None]:
# Warm-start: load previously saved weights, remapping stored tensors onto the current device.
# NOTE(review): the path points at config_25 while this run saves under config_{CONFIG_NO};
# presumably this continues training from the earlier run — confirm.
# NOTE(security): depending on the PyTorch version's `weights_only` default, torch.load may
# unpickle arbitrary objects — only load checkpoints from a trusted source.
checkpoint_path = "./checkpoints/config_25/model_epoch_5.pt"
ckpt = torch.load(checkpoint_path, map_location=device)
# The checkpoint is indexed like a dict; the model weights are stored under 'state_dict'.
model.load_state_dict(ckpt['state_dict'])

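## Minor modification to the loss function: optional restriction to non-frozen positions

`calculate_loss` averages the loss over every bit position by default; it calculates the loss only at non-frozen positions when called with `reliable_only=True`.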

In [None]:
def calculate_loss(frozen_bit_prior, target_vector, predicted_vector,  reliable_only=False):
    """Binary cross-entropy (on logits) for polar code decoding, optionally ignoring frozen bits.

    Args:
        frozen_bit_prior: Tensor expected to be shaped like ``target_vector``. Positions
            whose value equals 1 are excluded from the loss when ``reliable_only`` is set.
        target_vector: Target bits. Any bool/integer/float dtype is accepted; it is cast
            to the dtype of ``predicted_vector`` before the loss is evaluated.
        predicted_vector: Raw logits (no sigmoid applied), shaped like ``target_vector``.
        reliable_only: If True, average the loss only over positions where
            ``frozen_bit_prior != 1``. If False (default), every position contributes.

    Returns:
        Scalar tensor with the mean loss. When ``reliable_only`` is True and no position
        is selected, a zero that is still attached to the autograd graph is returned.
    """
    # BCE-with-logits rejects integer targets, so align the target dtype with the logits.
    target_vector = target_vector.to(predicted_vector.dtype)

    if reliable_only:
        mask = frozen_bit_prior != 1
        if not mask.any():
            # The mean over an empty selection would be NaN; return a graph-connected
            # zero instead so callers can still call backward() and accumulate .item().
            return predicted_vector.sum() * 0.0
        target_vector = target_vector[mask]
        predicted_vector = predicted_vector[mask]

    return torch.nn.functional.binary_cross_entropy_with_logits(predicted_vector, target_vector)

In [None]:
# Adam over all model parameters with a very small learning rate (1e-6) and no weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6, weight_decay=0)
# Plateau scheduler in 'min' mode: the learning rate is multiplied by 0.5 once the value
# passed to scheduler.step() has stopped decreasing for more than `patience` (4) steps,
# and it is never lowered below 1e-8.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=4, min_lr=1e-8
)


In [None]:
def train_one_epoch(epoch_index, log_interval=1000):
    """Run one optimisation pass over ``train_dataloader``.

    Relies on the module-level ``model``, ``optimizer``, ``train_dataloader``,
    ``device`` and ``calculate_loss``.

    Args:
        epoch_index: Index of the current epoch. Not used inside the function;
            kept so existing callers keep working.
        log_interval: Number of batches per progress report (default 1000).

    Returns:
        Mean loss over the most recent complete window of ``log_interval`` batches.
        If the epoch ended before a single window was completed, the mean over all
        batches that were processed is returned instead (0.0 for an empty loader).

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError("log_interval must be at least 1")

    running_loss = 0.0
    last_loss = 0.0
    batches_in_window = 0
    completed_window = False
    batches_seen = 0

    for i, data in enumerate(train_dataloader):
        # Each batch unpacks into channel values, frozen-bit prior, SNR and target bits.
        channel_tensor, frozen_tensor, snr_tensor, target_tensor = data
        ip1 = channel_tensor.float().to(device)
        ip2 = frozen_tensor.int().to(device)
        ip3 = snr_tensor.float().to(device)
        # BCE-with-logits needs floating-point targets.
        op = target_tensor.float().to(device)

        optimizer.zero_grad()
        outputs = model(ip1, ip2, ip3)

        loss = calculate_loss(ip2, op, outputs)
        loss.backward()

        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1
        batches_seen = i + 1

        if batches_in_window == log_interval:
            last_loss = running_loss / log_interval
            print('  batch {} loss: {}\n'.format(i + 1, last_loss))
            running_loss = 0.
            batches_in_window = 0
            completed_window = True

    # Short epochs never fill a window; report the partial mean rather than 0.
    if not completed_window and batches_in_window > 0:
        last_loss = running_loss / batches_in_window
        print('  batch {} loss: {}\n'.format(batches_seen, last_loss))

    return last_loss


In [None]:
import os

def train(epochs=50):
    """Train ``model`` for ``epochs`` epochs, validating and checkpointing after each one.

    Relies on the module-level ``model``, ``scheduler``, ``test_dataloader``, ``device``,
    ``CONFIG_NO``, ``train_one_epoch`` and ``calculate_loss``.

    After every epoch the mean validation loss over ``test_dataloader`` is computed,
    passed to ``scheduler.step`` and, if it is the lowest seen so far, a checkpoint is
    written to ``./checkpoints/config_<CONFIG_NO>/model_epoch_<epoch>.pt``.

    Args:
        epochs: Number of epochs to run (default 50).

    Raises:
        ValueError: If ``test_dataloader`` yields no batches, so no validation loss
            can be computed.
    """
    best_vloss = float('inf')

    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch + 1))

        # Training
        model.train(True)
        avg_loss = train_one_epoch(epoch)

        # Validation
        running_vloss = 0.0
        num_vbatches = 0
        model.eval()

        with torch.no_grad():
            for vdata in test_dataloader:
                vchannel_tensor, vfrozen_tensor, vsnr_tensor, vtarget_tensor = vdata
                vfrozen = vfrozen_tensor.int().to(device)
                voutputs = model(
                    vchannel_tensor.float().to(device),
                    vfrozen,
                    vsnr_tensor.float().to(device)
                )
                vloss = calculate_loss(
                    vfrozen,
                    # BCE-with-logits needs floating-point targets.
                    vtarget_tensor.float().to(device),
                    voutputs
                )
                # Accumulate a Python float so the average (and the value stored in
                # the checkpoint) is not a device tensor.
                running_vloss += vloss.item()
                num_vbatches += 1

        if num_vbatches == 0:
            raise ValueError("test_dataloader yielded no batches; cannot compute a validation loss")

        avg_vloss = running_vloss / num_vbatches
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        scheduler.step(avg_vloss)

        # Save checkpoint if validation improves
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_dir = f'./checkpoints/config_{CONFIG_NO}'
            os.makedirs(model_dir, exist_ok=True)
            # The file name uses the zero-based epoch index, while the 'epoch' entry
            # stored inside the checkpoint is one-based.
            model_path = f'{model_dir}/model_epoch_{epoch}.pt'

            torch.save({
                # NOTE(review): this note says the SNR input was removed, yet the SNR
                # tensor is still passed to the model above — confirm the note is current.
                "comments": "Removed the snr as input entirely. (even if used in future, use as snr linear, not in db)",
                'model_config': {
                    "d_model": model.d_model,
                    "num_layer_encoder": model.num_layer_encoder,
                    "num_layers_bimamba_block": model.num_layers_bimamba_block,
                    "seq_len": model.seq_len,
                    "d_state": model.d_state,
                    "d_conv": model.d_conv,
                    "expand": model.expand,
                },
                'epoch': epoch + 1,
                'train_loss': avg_loss,
                'val_loss': avg_vloss,
                'state_dict': model.state_dict()
            }, model_path)

    print("Training completed. Model available to use")


In [None]:
# Run the training loop for 20 epochs (overrides the function's default of 50).
train(epochs=20)