In [1]:
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataset import PolarDecDataset
from models.wrappers.mamba_32bits import MambaPolarDecoder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Sequence length handed to the decoder as `seq_len` when the model is built.
N = 32
# Identifier of this experiment configuration; it selects the checkpoint
# sub-directory (./checkpoints/config_<CONFIG_NO>/) that train() writes to.
CONFIG_NO = 2

In [3]:
# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

### Dataset

In [4]:
# Both splits are built with the same SNR and the same number of samples.
SNR_DB = 10
NUM_SAMPLES = 1_000_000

dataset = PolarDecDataset(snr_db=SNR_DB, num_samples=NUM_SAMPLES)
test_set = PolarDecDataset(snr_db=SNR_DB, num_samples=NUM_SAMPLES)


In [5]:
# Training and evaluation batches share one batch size.
BATCH_SIZE = 32

train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE)

## Model

In [6]:
# Decoder hyper-parameters gathered in one mapping so they can be inspected
# next to the instantiated model.
model_config = dict(
    d_model=64,
    num_layer_encoder=1,
    num_layers_bimamba_block=32,
    seq_len=N,
    d_state=32,
    d_conv=4,
    expand=2,
)
model = MambaPolarDecoder(**model_config).to(device)
model

MambaPolarDecoder(
  (discrete_embedding): Embedding(2, 64)
  (linear_embedding1): Linear(in_features=1, out_features=64, bias=True)
  (linear_embedding2): Linear(in_features=1, out_features=64, bias=True)
  (encoder_layers): ModuleList(
    (0): BiMambaEncoder(
      (layers): ModuleList(
        (0-31): 32 x BiMambaBlock(
          (pre_ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (mamba_f): Mamba(
            (in_proj): Linear(in_features=64, out_features=256, bias=False)
            (conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
            (act): SiLU()
            (x_proj): Linear(in_features=128, out_features=68, bias=False)
            (dt_proj): Linear(in_features=4, out_features=128, bias=True)
            (out_proj): Linear(in_features=128, out_features=64, bias=False)
          )
          (post_ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (ffn_f): Sequential(
            (0): Linear(in_f

## Minor modification to the loss function: compute the loss only at non-frozen positions

In [7]:
def calculate_loss_for_reliable_bits_only(frozen_bit_prior, target_vector, predicted_vector, loss_fn):
    """
    Evaluate ``loss_fn`` on the non-frozen (message-bit) positions only.

    The three tensors must share one shape; in this notebook they are batched
    as (batch_size, seq_len). Boolean-mask indexing flattens the selected
    entries, so ``loss_fn`` receives two 1-D tensors of equal length.

    frozen_bit_prior: tensor with 1 at frozen positions; every other value is
        treated as a message bit
    target_vector: tensor of ground-truth bits (any numeric dtype)
    predicted_vector: tensor of model outputs (raw logits when ``loss_fn`` is
        BCEWithLogitsLoss)
    loss_fn: PyTorch loss function, called as ``loss_fn(predicted, target)``

    Returns the value produced by ``loss_fn`` for the selected entries.
    """
    mask = frozen_bit_prior != 1

    reliable_predicted = predicted_vector[mask]
    # Match the target dtype to the predictions: BCEWithLogitsLoss needs
    # floating-point targets, so integer-typed bit labels would otherwise raise.
    reliable_target = target_vector[mask].to(reliable_predicted.dtype)

    return loss_fn(reliable_predicted, reliable_target)

In [8]:
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss),
# so the model output is passed to it without an explicit sigmoid.
loss_fn = torch.nn.BCEWithLogitsLoss()

## Sanity Check

In [43]:
# Pull a single batch and push it through the model and the masked loss once.
llr, frozen_tensor, snr_tensor, target_tensor = next(iter(train_dataloader))
ip1 = llr.float().to(device)
ip2 = frozen_tensor.int().to(device)
ip3 = snr_tensor.float().to(device)

predicted = model(ip1, ip2, ip3)

loss = calculate_loss_for_reliable_bits_only(
    frozen_bit_prior=ip2,
    target_vector=target_tensor.to(device),
    predicted_vector=predicted.to(device),
    loss_fn=loss_fn,
)

print(f"Channel Observation Vector: {ip1.shape}\nFrozen Tensor: {ip2.shape}\n")
print(f"Predicted Channel Input Vector(logits): {predicted.shape}\n\n")

# Turn the logits into probabilities, then threshold the first sample at 0.5.
probabilities = torch.sigmoid(predicted)
print(f"Predicted (sigmoid): {probabilities}\n\n")
pred = (probabilities > 0.5).long()[0]

predicted_bits = ''.join(str(bit) for bit in pred.cpu().tolist())
actual_bits = ''.join(map(str, map(int, target_tensor[0])))

print(f"Predicted bits:{predicted_bits}\n")
print(f"Actual bits: {actual_bits}\n")
print(f"Loss: {loss}")

Channel Observation Vector: torch.Size([32, 32])
Frozen Tensor: torch.Size([32, 32])

Predicted Channel Input Vector(logits): torch.Size([32, 32])


Predicted (sigmoid): tensor([[1.0000e+00, 9.9997e-01, 1.0000e+00,  ..., 9.9986e-01, 9.9977e-01,
         1.0000e+00],
        [7.0380e-05, 3.9573e-05, 8.5431e-01,  ..., 1.0000e+00, 9.9999e-01,
         9.9997e-01],
        [9.9998e-01, 1.0000e+00, 9.9948e-01,  ..., 9.9993e-01, 1.0000e+00,
         9.9745e-01],
        ...,
        [5.6786e-03, 2.2419e-04, 9.9998e-01,  ..., 9.9931e-01, 9.9986e-01,
         9.9999e-01],
        [9.9997e-01, 4.7640e-03, 9.9833e-01,  ..., 9.9836e-01, 9.9982e-01,
         9.9991e-01],
        [8.9023e-01, 3.7809e-03, 1.0441e-04,  ..., 9.9979e-01, 1.0000e+00,
         9.9931e-01]], device='cuda:0', grad_fn=<SigmoidBackward0>)


Predicted bits:11110101100111110111111101111111

Actual bits: 01000010111000000000000010000000

Loss: 6.650700569152832


In [44]:

# Adam with a small weight-decay term.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=0.001,
)

# Halve the learning rate once the monitored (validation) loss has failed to
# improve for more than `patience` consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)


In [45]:
def train_one_epoch(epoch_index):
    """
    Run one optimisation pass over ``train_dataloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``, ``device``
    and ``train_dataloader``. ``train()`` switches the model to training mode
    before calling this. The mean loss is printed every 1000 batches.

    epoch_index: index of the current epoch (not used inside the function;
        kept so existing callers keep working)

    Returns the mean loss of the last completed 1000-batch window. When the
    loader yields fewer than 1000 batches, the mean over the batches actually
    seen is returned instead of 0. An empty loader returns 0.
    """
    report_every = 1000
    running_loss = 0.
    last_loss = 0.
    num_batches = 0

    for num_batches, data in enumerate(train_dataloader, start=1):
        channel_tensor, frozen_tensor, snr_tensor, target_tensor = data
        ip1 = channel_tensor.float().to(device)
        ip2 = frozen_tensor.int().to(device)
        ip3 = snr_tensor.float().to(device)
        op = target_tensor.to(device)

        optimizer.zero_grad()
        outputs = model(ip1, ip2, ip3)

        # Only the non-frozen positions contribute to the loss.
        loss = calculate_loss_for_reliable_bits_only(ip2, op, outputs, loss_fn)
        loss.backward()

        # Clip the global gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        if num_batches % report_every == 0:
            last_loss = running_loss / report_every
            print('  batch {} loss: {}'.format(num_batches, last_loss))
            running_loss = 0.

    # No full reporting window completed: running_loss was never reset, so it
    # still holds the sum over every batch seen.
    if 0 < num_batches < report_every:
        last_loss = running_loss / num_batches
    return last_loss


In [None]:
def train(epochs=50):
    """
    Train the notebook-level ``model`` and checkpoint it whenever the
    validation loss improves.

    Each epoch runs ``train_one_epoch``, evaluates the masked loss over
    ``test_dataloader`` without gradients, feeds the mean validation loss to
    ``scheduler`` and, on a new best value, saves the model configuration,
    losses and ``state_dict`` under ``./checkpoints/config_<CONFIG_NO>/``.

    epochs: number of epochs to run

    Raises ValueError if ``test_dataloader`` yields no batches.
    """
    best_vloss = 1_000_000.

    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch + 1))

        model.train(True)
        avg_loss = train_one_epoch(epoch)

        running_vloss = 0.0
        num_vbatches = 0

        model.eval()
        with torch.no_grad():
            for vdata in test_dataloader:
                vchannel_tensor, vfrozen_tensor, vsnr_tensor, vtarget_tensor = vdata
                # Same integer frozen mask for the model and the loss, as in
                # train_one_epoch.
                vfrozen = vfrozen_tensor.int().to(device)
                voutputs = model(vchannel_tensor.float().to(device), vfrozen, vsnr_tensor.float().to(device))

                vloss = calculate_loss_for_reliable_bits_only(vfrozen, vtarget_tensor.to(device), voutputs, loss_fn)
                # Accumulate a plain float so the average, the printed value
                # and the checkpointed loss are not device tensors.
                running_vloss += vloss.item()
                num_vbatches += 1

        if num_vbatches == 0:
            raise ValueError("test_dataloader yielded no batches; cannot compute a validation loss")

        avg_vloss = running_vloss / num_vbatches
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        scheduler.step(avg_vloss)

        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            # NOTE: the file name uses the 0-based epoch index, while the
            # stored 'epoch' field below is 1-based.
            model_path = f'./checkpoints/config_{CONFIG_NO}/model_epoch_{epoch}.pt'
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save({
                'model_config': {
                    "d_model": model.d_model,
                    "num_layer_encoder": model.num_layer_encoder,
                    "num_layers_bimamba_block": model.num_layers_bimamba_block,
                    "seq_len": model.seq_len,
                    "d_state": model.d_state,
                    "d_conv": model.d_conv,
                    "expand": model.expand,
                },
                'epoch': epoch + 1,
                'train_loss': avg_loss,
                'val_loss': avg_vloss,
                'state_dict': model.state_dict()
            }, model_path)

    print("Training completed. Model available to use")





In [47]:
# Start training with 30 epochs instead of the default 50. The captured output
# below ends with a KeyboardInterrupt during the first epoch.
train(epochs=30)

EPOCH 1:
  batch 1000 loss: 0.5336742365956306
  batch 2000 loss: 0.3722417163848877
  batch 3000 loss: 0.327930050522089
  batch 4000 loss: 0.3215761846899986
  batch 5000 loss: 0.31960532972216604
  batch 6000 loss: 0.31790875414013864
  batch 7000 loss: 0.31538455495238304
  batch 8000 loss: 0.3153826283812523
  batch 9000 loss: 0.3144086654186249
  batch 10000 loss: 0.31567151948809624
  batch 11000 loss: 0.31377769938111305
  batch 12000 loss: 0.3152754058539867
  batch 13000 loss: 0.3129089913368225
  batch 14000 loss: 0.31123180663585664
  batch 15000 loss: 0.3063631791770458
  batch 16000 loss: 0.30533461801707745
  batch 17000 loss: 0.3048073563277721
  batch 18000 loss: 0.3055128723680973
  batch 19000 loss: 0.30529735863208773
  batch 20000 loss: 0.3047070904374123
  batch 21000 loss: 0.30450183027982713
  batch 22000 loss: 0.30586443743109704
  batch 23000 loss: 0.3029697656780481
  batch 24000 loss: 0.3041873051673174
  batch 25000 loss: 0.3047487173229456
  batch 26000 lo

KeyboardInterrupt: 