In [1]:
from dataset import PolarDecDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
from models.wrappers.mamba_32bits import MambaPolarDecoder

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([1, 3])
tensor([[-0.7251, -0.7145, -0.4168]], device='cuda:0',
       grad_fn=<SqueezeBackward1>)


In [2]:
N = 32

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

### Dataset

In [4]:
dataset = PolarDecDataset(snr_db=6, num_samples=1000000)
test_set = PolarDecDataset(snr_db=6, num_samples=1000000)


In [5]:
train_dataloader = DataLoader(dataset, batch_size = 32)
test_dataloader = DataLoader(test_set, batch_size = 32)

## Model

In [6]:
model = MambaPolarDecoder(
    d_model=16,
    num_layer_encoder=1,
    num_layers_bimamba_block=4,
    seq_len=N,
    d_state=16,
    d_conv=4,
    expand=2
).to(device)
model

MambaPolarDecoder(
  (discrete_embedding): Embedding(2, 16)
  (linear_embedding1): Linear(in_features=1, out_features=16, bias=True)
  (linear_embedding2): Linear(in_features=1, out_features=16, bias=True)
  (encoder_layers): ModuleList(
    (0): BiMambaEncoder(
      (layers): ModuleList(
        (0-3): 4 x BiMambaBlock(
          (pre_ln_f): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
          (mamba_f): Mamba(
            (in_proj): Linear(in_features=16, out_features=64, bias=False)
            (conv1d): Conv1d(32, 32, kernel_size=(4,), stride=(1,), padding=(3,), groups=32)
            (act): SiLU()
            (x_proj): Linear(in_features=32, out_features=33, bias=False)
            (dt_proj): Linear(in_features=1, out_features=32, bias=True)
            (out_proj): Linear(in_features=32, out_features=16, bias=False)
          )
          (post_ln_f): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
          (ffn_f): Sequential(
            (0): Linear(in_features=1

## Minor modification to the Loss Function: Calculates loss only at non frozen positions

In [7]:
def calculate_loss_for_reliable_bits_only(frozen_bit_prior, target_vector, predicted_vector, loss_fn):
    """
    frozen_bit_prior: tensor of shape (seq_len,) with 1 for frozen, 0 for message bits
    target_vector: tensor of shape (seq_len,)
    predicted_vector: tensor of shape (seq_len,)
    loss_fn: PyTorch loss function
    """
    mask = (frozen_bit_prior != 1) 
    
   
    reliable_target = target_vector[mask]
    reliable_predicted = predicted_vector[mask]

   # print(f"Length of reliable bits: {len(reliable_target)}")

    return loss_fn(reliable_predicted, reliable_target)

In [8]:
loss_fn = torch.nn.BCEWithLogitsLoss()

## Sanity Check

In [9]:
channel_tensor, frozen_tensor, snr_tensor, target_tensor= next(iter(train_dataloader))
ip1 = channel_tensor.float().to(device)
ip2 = frozen_tensor.int().to(device)
ip3 = snr_tensor.float().to(device)

predicted = model(ip1, ip2, ip3) #works

loss = calculate_loss_for_reliable_bits_only(ip2, target_tensor.to(device), predicted.to(device), loss_fn) #works


print(f"Channel Observation Vector: {ip1.shape}\nFrozen Tensor: {ip2.shape}\n")
print(f"Predicted Channel Input Vector(logits): {predicted.shape}\n\n")

print(f"Predicted (sigmoid): {torch.sigmoid(predicted)}\n\n")
pred = (torch.sigmoid(predicted) > 0.5).long()

print(f"Predicted bits:{''.join(map(str, pred.cpu().tolist()))}\n")
#print(f"Actual bits: {''.join(str(int(i)) for i in target_tensor)}\n")
print(f"Loss: {loss}")

Channel Observation Vector: torch.Size([32, 32])
Frozen Tensor: torch.Size([32, 32])

Predicted Channel Input Vector(logits): torch.Size([32, 32])


Predicted (sigmoid): tensor([[0.3038, 0.3824, 0.3246,  ..., 0.2908, 0.2673, 0.2454],
        [0.4233, 0.3338, 0.3723,  ..., 0.2789, 0.3020, 0.2686],
        [0.3510, 0.3955, 0.3669,  ..., 0.3993, 0.3035, 0.3530],
        ...,
        [0.3107, 0.3234, 0.3105,  ..., 0.2531, 0.2631, 0.2532],
        [0.3796, 0.3411, 0.3047,  ..., 0.3040, 0.3243, 0.2731],
        [0.3264, 0.3162, 0.4459,  ..., 0.3246, 0.3154, 0.3451]],
       device='cuda:0', grad_fn=<SigmoidBackward0>)


Predicted bits:[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0][0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [10]:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [11]:
def train_one_epoch(epoch_index):

    running_loss = 0
    last_loss = 0


    for i, data in enumerate(train_dataloader):
        channel_tensor, frozen_tensor, snr_tensor, target_tensor = data
        ip1 = channel_tensor.float().to(device)
        ip2 = frozen_tensor.int().to(device)
        ip3 = snr_tensor.float().to(device)
        op = target_tensor.to(device)
        optimizer.zero_grad()
        outputs = model(ip1,ip2 ,ip3 ).to(device)

    #    B, L, C = outputs.shape
    #    output_logits = outputs.view(B*L, C).to(device)
    #    target_flattened = shifted.view(B*L).to(device).long()


    #    loss = loss_fn(output_logits, target_flattened)
        
        loss = calculate_loss_for_reliable_bits_only(ip2, op, outputs, loss_fn)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i%1000 == 999:
            last_loss = running_loss/1000
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.
    return last_loss


In [12]:
def train( epochs=50):

    best_vloss = 1_000_000.

    for epoch in range(epochs):
     print('EPOCH {}:'.format(epoch + 1))
 
   
     model.train(True)
     avg_loss = train_one_epoch(epoch)


     running_vloss = 0.0
    
     model.eval()

   
     with torch.no_grad():
        for i, vdata in enumerate(test_dataloader):
            vchannel_tensor, vfrozen_tensor, vsnr_tensor, vtarget_tensor = vdata
            voutputs = model(vchannel_tensor.float().to(device), vfrozen_tensor.int().to(device), vsnr_tensor.float().to(device))
          #  B, L, C = voutputs.shape
          #  vloss = loss_fn(voutputs.view(B*L, C).to(device), vlabels.view(B*L).to(device))
            
            vloss = calculate_loss_for_reliable_bits_only(vfrozen_tensor.to(device), vtarget_tensor.to(device), voutputs.to(device), loss_fn)
            running_vloss += vloss

     avg_vloss = running_vloss / (i + 1)
     print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    
    

    
     if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = './checkpoints/model_epoch_{}'.format(epoch)
        torch.save(model.state_dict(), model_path)

     
    print("Training completed. Model available to use")





In [13]:
train(epochs=30)

EPOCH 1:
  batch 1000 loss: 0.3939847068488598
  batch 2000 loss: 0.3761376224756241
  batch 3000 loss: 0.3793273172080517
  batch 4000 loss: 0.37432436341047287
  batch 5000 loss: 0.3664797570109367
  batch 6000 loss: 0.36705457988381385
  batch 7000 loss: 0.3666394959688187
  batch 8000 loss: 0.35937456831336023
  batch 9000 loss: 0.3599668442606926
  batch 10000 loss: 0.3585135807991028
  batch 11000 loss: 0.3588571336567402
  batch 12000 loss: 0.36129762449860575
  batch 13000 loss: 0.35614758130908014
  batch 14000 loss: 0.3397854159474373
  batch 15000 loss: 0.3305469160676002
  batch 16000 loss: 0.3245115993618965
  batch 17000 loss: 0.3175469451248646
  batch 18000 loss: 0.31661248591542246
  batch 19000 loss: 0.31736321476101875
  batch 20000 loss: 0.3180165456533432
  batch 21000 loss: 0.31625718903541566
  batch 22000 loss: 0.31554092955589297
  batch 23000 loss: 0.31466167032718656
  batch 24000 loss: 0.6543365717232227
  batch 25000 loss: 0.6932442944645881
  batch 26000 l

RuntimeError: Parent directory ./checkpoints does not exist.

Input: tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 1])
Output: tensor([0, 1, 0, 0, 0, 0, 0, 0, 1, 0])

Predicted (logits): tensor([[-12.3343,  12.1840, -12.3455, -12.3638, -12.3470, -12.3511, -12.3607,
         -12.3202,  12.1719,  -0.0603]], device='cuda:0',
       grad_fn=<SqueezeBackward1>)
Predicted (sigmoid): tensor([[4.3982e-06, 9.9999e-01, 4.3493e-06, 4.2704e-06, 4.3428e-06, 4.3248e-06,
         4.2838e-06, 4.4606e-06, 9.9999e-01, 4.8492e-01]], device='cuda:0',
       grad_fn=<SigmoidBackward0>)
Predicted bits:0100000010
