In [1]:
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataset import PolarDecDataset
from models.wrappers.mamba_32bits import MambaPolarDecoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sequence length; passed to the decoder as `seq_len` when the model is built.
N = 32
# Configuration index; used to build the checkpoint directory name
# (./checkpoints/config_<CONFIG_NO>/) when training saves a model.
CONFIG_NO = 5

In [3]:
# Prefer the GPU when CUDA is usable; otherwise keep everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

## Dataset

In [4]:
# Training set and held-out set are built independently, both with snr_db=10.
dataset = PolarDecDataset(num_samples=100_000, snr_db=10)
test_set = PolarDecDataset(num_samples=3_200, snr_db=10)


In [5]:
# Both loaders serve batches of 32 samples with DataLoader's default settings.
train_dataloader = DataLoader(dataset=dataset, batch_size=32)
test_dataloader = DataLoader(dataset=test_set, batch_size=32)

## Model

In [20]:
# Instantiate the decoder for length-N sequences, then move its parameters to
# the selected device. The bare `model` on the last line displays the module tree.
model = MambaPolarDecoder(
    seq_len=N,
    d_model=32,
    num_layer_encoder=2,
    num_layers_bimamba_block=8,
    d_state=16,
    d_conv=4,
    expand=2,
)
model = model.to(device)
model

MambaPolarDecoder(
  (discrete_embedding): Embedding(2, 32)
  (linear_embedding1): Linear(in_features=1, out_features=32, bias=True)
  (linear_embedding2): Linear(in_features=1, out_features=32, bias=True)
  (linear_input_layer): Linear(in_features=96, out_features=32, bias=True)
  (encoder_layers): ModuleList(
    (0-1): 2 x BiMambaEncoder(
      (layers): ModuleList(
        (0-7): 8 x BiMambaBlock(
          (pre_ln_f): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (mamba_f): Mamba(
            (in_proj): Linear(in_features=32, out_features=128, bias=False)
            (conv1d): Conv1d(64, 64, kernel_size=(4,), stride=(1,), padding=(3,), groups=64)
            (act): SiLU()
            (x_proj): Linear(in_features=64, out_features=34, bias=False)
            (dt_proj): Linear(in_features=2, out_features=64, bias=True)
            (out_proj): Linear(in_features=64, out_features=32, bias=False)
          )
          (post_ln_f): LayerNorm((32,), eps=1e-05, elementwise

In [8]:
# Optional cell: run it (after editing the path) only to resume from a
# previously saved checkpoint; skip it when training from scratch.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.

model_path = "./checkpoints/config_3/model_epoch_29.pt"
checkpoint = torch.load(model_path, map_location=device)

# Checkpoints may store the weights under "model_state_dict" or "state_dict",
# or be a bare state dict themselves.
if "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
elif "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)

<All keys matched successfully>

## Minor modification to the loss function: optionally computes the loss only at non-frozen positions (`reliable_only=True`)

In [7]:
def calculate_loss(frozen_bit_prior, target_vector, predicted_vector,  reliable_only=False):
    """
    Binary cross-entropy (with logits) between predicted logits and target bits.

    frozen_bit_prior: tensor marking frozen positions with 1; every other value
        is treated as a message (non-frozen) bit. Only read when
        reliable_only=True, where it is used as a boolean-mask index into
        target_vector and predicted_vector.
    target_vector: tensor of target bits. Any dtype is accepted; it is cast to
        the dtype of predicted_vector because the loss needs a floating-point
        target.
    predicted_vector: tensor of raw logits (no sigmoid applied).
    reliable_only: when True, only non-frozen positions contribute to the loss;
        when False (the default), every position is used.

    Returns a scalar tensor: the mean loss over the selected positions, or a
    zero that stays attached to the autograd graph when no position is selected.
    """
    if reliable_only:
        mask = frozen_bit_prior != 1
        target_vector = target_vector[mask]
        predicted_vector = predicted_vector[mask]

    if predicted_vector.numel() == 0:
        # A mean-reduced loss over zero elements is NaN and would poison the
        # optimiser step; an empty sum is 0 and still supports backward().
        return predicted_vector.sum()

    # Integer/bool targets make BCE-with-logits raise, so align the dtype here.
    target_vector = target_vector.to(dtype=predicted_vector.dtype)

    return F.binary_cross_entropy_with_logits(predicted_vector, target_vector)

## Sanity Check

In [25]:
# Pull a single batch and push it through the model to confirm that the
# forward pass and the loss both run end to end.
llr, frozen_tensor, snr_tensor, target_tensor = next(iter(train_dataloader))
ip1 = llr.float().to(device)
ip2 = frozen_tensor.int().to(device)
ip3 = snr_tensor.float().to(device)

predicted = model(ip1, ip2, ip3)

loss = calculate_loss(
    frozen_bit_prior=ip2,
    target_vector=target_tensor.to(device),
    predicted_vector=predicted.to(device),
)

# Hard decisions for the first sample of the batch: sigmoid above 0.5 -> bit 1.
pred = torch.sigmoid(predicted).gt(0.5).long()[0]

print(f"Predicted bits:{''.join(str(bit) for bit in pred.cpu().tolist())}\n")
print(f"Actual bits: {''.join(map(str, (int(bit) for bit in target_tensor[0])))}\n")
print(f"Loss: {loss}")

Predicted bits:01111101101111100111000100100000

Actual bits: 11010101011011100011000010100000

Loss: 0.4689653217792511


In [26]:

# Adam with weight decay. The scheduler halves the learning rate (factor=0.5)
# when the value passed to scheduler.step() stops decreasing, with patience=3.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=5e-4,
    weight_decay=0.001,
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
)


In [27]:
def train_one_epoch(epoch_index, log_every=1000):
    """
    Run one optimisation pass over `train_dataloader` and return a windowed mean loss.

    Relies on the notebook-level `model`, `optimizer`, `device`,
    `train_dataloader` and `calculate_loss`.

    epoch_index: index of the current epoch; accepted for the caller's
        interface but not read here.
    log_every: number of batches per reporting window. After each full window
        the mean loss of that window is printed. Defaults to 1000.

    Returns the mean loss of the most recent *completed* window. If the loader
    yields fewer than `log_every` batches, the mean over all batches seen is
    returned instead of a misleading 0; an empty loader returns 0.0.
    """
    running_loss = 0.0
    last_loss = 0.0
    batches_in_window = 0
    windows_reported = 0

    for i, data in enumerate(train_dataloader):
        channel_tensor, frozen_tensor, snr_tensor, target_tensor = data
        ip1 = channel_tensor.float().to(device)
        ip2 = frozen_tensor.int().to(device)
        ip3 = snr_tensor.float().to(device)
        op = target_tensor.to(device)

        optimizer.zero_grad()
        outputs = model(ip1, ip2, ip3).to(device)

        loss = calculate_loss(ip2, op, outputs)
        loss.backward()

        # Cap the total gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1
        if batches_in_window == log_every:
            last_loss = running_loss / log_every
            print('  batch {} loss: {}\n'.format(i + 1, last_loss))
            running_loss = 0.0
            batches_in_window = 0
            windows_reported += 1

    # With a short loader no window ever completes; fall back to the mean over
    # what was seen so the caller does not log a training loss of 0.
    if windows_reported == 0 and batches_in_window > 0:
        last_loss = running_loss / batches_in_window

    return last_loss


In [29]:
def train( epochs=50):
    """
    Train the notebook-level `model` for `epochs` epochs, validating after each one.

    Each epoch:
      1. runs `train_one_epoch` with the model in training mode;
      2. evaluates on `test_dataloader` in eval mode under `torch.no_grad()`;
      3. steps `scheduler` with the mean validation loss;
      4. saves a checkpoint under ./checkpoints/config_<CONFIG_NO>/ whenever
         that loss is the lowest seen so far.

    Relies on the notebook-level `model`, `scheduler`, `device`,
    `test_dataloader`, `CONFIG_NO`, `train_one_epoch` and `calculate_loss`.

    epochs: number of epochs to run (default 50).

    Returns None.
    Raises ValueError if `test_dataloader` yields no batches.
    """
    best_vloss = float('inf')

    for epoch in range(epochs):
        print('EPOCH {}:'.format(epoch + 1))

        model.train(True)
        avg_loss = train_one_epoch(epoch)

        model.eval()
        running_vloss = 0.0
        num_vbatches = 0

        with torch.no_grad():
            for vdata in test_dataloader:
                vchannel_tensor, vfrozen_tensor, vsnr_tensor, vtarget_tensor = vdata
                vfrozen = vfrozen_tensor.int().to(device)
                voutputs = model(
                    vchannel_tensor.float().to(device),
                    vfrozen,
                    vsnr_tensor.float().to(device),
                )

                vloss = calculate_loss(vfrozen, vtarget_tensor.to(device), voutputs.to(device))
                # Accumulate a Python float so neither the best-loss tracker nor
                # the saved checkpoint ends up holding a device tensor.
                running_vloss += vloss.item()
                num_vbatches += 1

        if num_vbatches == 0:
            raise ValueError("test_dataloader yielded no batches; cannot compute a validation loss")

        avg_vloss = running_vloss / num_vbatches
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        scheduler.step(avg_vloss)

        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            # The file name uses the 0-based epoch index, while the 'epoch'
            # field stored inside the checkpoint is 1-based.
            model_path = f'./checkpoints/config_{CONFIG_NO}/model_epoch_{epoch}.pt'
            # torch.save does not create missing parent directories.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            torch.save({
                "comments": "New lightweight configuration used. Used 2 encoder layers instead of 1. Now each encoder layer has 8 bimamba blocks",
                'model_config': {
                    "d_model": model.d_model,
                    "num_layer_encoder": model.num_layer_encoder,
                    "num_layers_bimamba_block": model.num_layers_bimamba_block,
                    "seq_len": model.seq_len,
                    "d_state": model.d_state,
                    "d_conv": model.d_conv,
                    "expand": model.expand,
                },
                'epoch': epoch + 1,
                'train_loss': avg_loss,
                'val_loss': avg_vloss,
                'state_dict': model.state_dict()
            }, model_path)

    print("Training completed. Model available to use")





In [30]:
# Kick off the full 50-epoch training run.
train(epochs=50)

EPOCH 1:
  batch 1000 loss: 0.21762221300601958

  batch 2000 loss: 0.18520276387035847

  batch 3000 loss: 0.17971488001942634

LOSS train 0.17971488001942634 valid 0.17469437420368195
EPOCH 2:
  batch 1000 loss: 0.17638983738422392

  batch 2000 loss: 0.17330959343910218

  batch 3000 loss: 0.17274228969216346

LOSS train 0.17274228969216346 valid 0.16979952156543732
EPOCH 3:
  batch 1000 loss: 0.17128088223934174

  batch 2000 loss: 0.16987204115092755

  batch 3000 loss: 0.16895946999639272

LOSS train 0.16895946999639272 valid 0.16394193470478058
EPOCH 4:
  batch 1000 loss: 0.16740853187441826

  batch 2000 loss: 0.16658798415213824

  batch 3000 loss: 0.164609590344131

LOSS train 0.164609590344131 valid 0.165303573012352
EPOCH 5:
  batch 1000 loss: 0.16448793176561594

  batch 2000 loss: 0.1641144694760442

  batch 3000 loss: 0.16233833894133567

LOSS train 0.16233833894133567 valid 0.15812137722969055
EPOCH 6:
  batch 1000 loss: 0.1625214083045721

  batch 2000 loss: 0.16258501