In [None]:
import torch
import matplotlib.pyplot as plt
import random
from torch.utils.data import DataLoader
import komm
import numpy as np

In [None]:
# Hyper-parameters and experiment configuration.
SNR = 2                         # channel SNR in dB (converted with 10**(SNR / 10) where the channel is built)
NUM_SYMBOLS = 4                 # random bytes per message; each byte expands to 8 bits
INPUT_SIZE = 2*NUM_SYMBOLS*8    # model input: two received copies of a message's bits, concatenated
HIDDEN_SIZE = 256
NUM_LAYERS = 1
NUM_EPOCHS = 20
SEED = 0                        # makes Restart-&-Run-All reproducible

MOD = "ask"  # modulation used for training; one of "qam", "ask", "psk", "pam"

# Seed every RNG the notebook draws from: `random` (payload bytes), NumPy (presumably the
# source of komm's AWGN noise — confirm for the installed komm version) and torch (weight
# init, DataLoader shuffling).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Pick the training modulation from MOD; every option uses a 16-point constellation.
_MODULATION_CLASSES = {
    "qam": komm.QAModulation,
    "ask": komm.ASKModulation,
    "psk": komm.PSKModulation,
    "pam": komm.PAModulation,
}
if MOD not in _MODULATION_CLASSES:
    # Previously an unknown MOD left `mod` undefined — or, in a live kernel, silently stale
    # from an earlier run — and the failure surfaced much later in the data cell.
    raise ValueError(f"Unsupported MOD {MOD!r}; expected one of {sorted(_MODULATION_CLASSES)}")
mod = _MODULATION_CLASSES[MOD](16)

# AWGN channel at SNR dB, converted to a linear ratio for komm.
awgn = komm.AWGNChannel(snr=10**(SNR / 10),signal_power=5.0)

In [None]:
# One modulator per scheme, used by the SNR-sweep cells to compare modulations.
# The SNR comes from the config constant: a separate hard-coded value here rebuilt `awgn`
# and silently overrode the channel built from SNR, so editing SNR had no effect.
snr = SNR
qam = komm.QAModulation(16)
ask = komm.ASKModulation(16)
psk = komm.PSKModulation(16)
pam = komm.PAModulation(16)
awgn = komm.AWGNChannel(snr=10**(snr / 10),signal_power=5.0)

In [None]:
# Dataset / batching sizes. Message length follows the config constant so that editing
# NUM_SYMBOLS actually changes the generated data (it was previously a duplicate literal).
batch_size_train = 64
batch_size_test = 8
num_symbols = NUM_SYMBOLS
num_messages_train = 2048
num_messages_test = 256

In [None]:
# Build the training and validation sets. Each message is `num_symbols` random bytes,
# expanded MSB-first into bits, modulated with `mod`, passed twice through the `awgn`
# channel (two independent noisy copies) and hard-demodulated back to bits.
train_bytes = [[random.randint(0, 255) for _ in range(num_symbols)] for _ in range(num_messages_train)]
train_bits = [[int(bit) for byte in message for bit in format(byte, '08b')] for message in train_bytes]
train_signal = [mod.modulate(bits) for bits in train_bits]
train_noisy_signal1 = [awgn(signal) for signal in train_signal]
train_noisy_signal2 = [awgn(signal) for signal in train_signal]
train_received_bits1 = [mod.demodulate(noisy) for noisy in train_noisy_signal1]
train_received_bits2 = [mod.demodulate(noisy) for noisy in train_noisy_signal2]

train_bits, train_received_bits1, train_received_bits2 = (
    torch.tensor(rows, dtype=torch.float32)
    for rows in (train_bits, train_received_bits1, train_received_bits2)
)
# Stacked along dim 1: [:, 0] sent bits, [:, 1] received copy 1, [:, 2] received copy 2.
train_data = torch.stack((train_bits, train_received_bits1, train_received_bits2), dim=1)

test_bytes = [[random.randint(0, 255) for _ in range(num_symbols)] for _ in range(num_messages_test)]
test_bits = [[int(bit) for byte in message for bit in format(byte, '08b')] for message in test_bytes]
test_signal = [mod.modulate(bits) for bits in test_bits]
test_noisy_signal1 = [awgn(signal) for signal in test_signal]
test_noisy_signal2 = [awgn(signal) for signal in test_signal]
test_received_bits1 = [mod.demodulate(noisy) for noisy in test_noisy_signal1]
test_received_bits2 = [mod.demodulate(noisy) for noisy in test_noisy_signal2]

test_bits, test_received_bits1, test_received_bits2 = (
    torch.tensor(rows, dtype=torch.float32)
    for rows in (test_bits, test_received_bits1, test_received_bits2)
)
test_data = torch.stack((test_bits, test_received_bits1, test_received_bits2), dim=1)

In [None]:
# Training batches are reshuffled every epoch so the optimiser does not see the same fixed
# batch composition and order each time; the validation loader keeps a fixed order.
train_dataloader = DataLoader(train_data, batch_size=batch_size_train, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_data, batch_size=batch_size_test, shuffle=False, num_workers=4)

In [None]:
# Model sizes. input_size is derived from the data actually generated (two copies of
# num_symbols * 8 bits); the other sizes come from the config constants instead of
# duplicated literals.
input_size = 2*num_symbols*8
hidden_size = HIDDEN_SIZE
num_layers = NUM_LAYERS

In [None]:
class AutoencoderModel(torch.nn.Module):
    """Bidirectional-LSTM decoder that estimates a bit frame from received copies of it.

    Args:
        input_size: number of input features per sample (in this notebook, the two
            received copies of a frame concatenated).
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.

    ``forward(x)`` accepts either a 2-D batch ``(batch, input_size)``, where each row is an
    independent sample, or a 3-D tensor in nn.LSTM's default ``(seq, batch, input_size)``
    layout. The output has last dimension ``input_size // 2`` and values in [0, 1].
    """

    def __init__(self,input_size, hidden_size, num_layers):
        super(AutoencoderModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.bilstm_block1 = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bidirectional=True)
        self.dense1 = torch.nn.Linear(2*hidden_size, input_size // 2)
        # NOTE(review): dense2 is never used in forward(); kept only so existing
        # state_dicts / optimizer param groups keep the same layout.
        self.dense2 = torch.nn.Linear(hidden_size, input_size // 2)

    def forward(self,x):
        # nn.LSTM reads a 2-D tensor as ONE unbatched sequence (seq_len, features). Passing a
        # (batch, features) tensor directly therefore turned the batch into time steps, letting
        # the bidirectional LSTM mix information across independent messages and making each
        # prediction depend on the batch it was in. Give every row its own length-1 sequence.
        rows_as_batch = x.dim() == 2
        if rows_as_batch:
            x = x.unsqueeze(0)  # (1, batch, input_size): seq_len=1, batch_first=False
        x, _ = self.bilstm_block1(x)
        x = torch.nn.functional.relu(x)
        x = self.dense1(x)
        # Targets are bits in {0, 1}. Sigmoid keeps the output in that range so round()
        # yields a valid bit; the previous ReLU was unbounded above (round(1.7) == 2) and
        # could get stuck at 0 with zero gradient.
        x = torch.sigmoid(x)
        if rows_as_batch:
            x = x.squeeze(0)
        return x
        

In [None]:
# Instantiate the decoder using the input_size / hidden_size / num_layers values defined earlier.
model = AutoencoderModel(input_size, hidden_size, num_layers)

In [None]:
# Adam (lr = 1e-3) over all model parameters; mean-squared error between outputs and target bits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

In [None]:
def BER(pred,target):
    """Bit error rate: fraction of elements of ``pred`` that differ from ``target``.

    Both arguments are tensors compared element-wise; the denominator is the number of
    elements in ``target``. Returns a Python float.
    """
    mismatches = (pred != target).sum().item()
    total = target.numel()
    return mismatches / total

In [None]:
def train_loop(batch, criterion, optimizer, model):
    """Run one optimisation step on a single batch.

    Args:
        batch: tensor whose dim 1 holds (sent bits, received copy 1, received copy 2),
            as stacked by the data cells.
        criterion: loss function applied to (prediction, sent bits).
        optimizer: optimizer over ``model``'s parameters.
        model: network mapping the concatenated received copies to bit estimates.

    Returns:
        ``(loss, ber)`` for this batch as Python floats, measured before the update.
    """
    # test_loop switches the model to eval mode and never switches back; restore training
    # mode so dropout/batch-norm style layers (if ever added) behave correctly here.
    model.train()
    y, x1, x2 = torch.unbind(batch, dim=1)
    x = torch.concat([x1, x2], dim=1)
    optimizer.zero_grad()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    ber = BER(torch.round(y_pred.detach()), y)
    loss.backward()
    optimizer.step()

    return loss.item(), ber

In [None]:
def test_loop(batch, criterion, model):
    """Evaluate ``model`` on one batch without tracking gradients.

    ``batch`` is unbound along dim 1 into (sent bits, received copy 1, received copy 2);
    the two received copies are concatenated as the model input. The model is switched to
    eval mode and left there. Returns ``(loss, ber)`` for the batch as Python floats.
    """
    with torch.no_grad():
        model.eval()
        target, rx1, rx2 = torch.unbind(batch, dim=1)
        prediction = model(torch.concat([rx1, rx2], dim=1))
        error_rate = BER(torch.round(prediction.detach()), target)
        batch_loss = criterion(prediction, target).item()
    return batch_loss, error_rate

In [None]:
# Train for NUM_EPOCHS epochs. Each training step is paired with one validation batch, so the
# reported validation figures are averaged over a model that changes during the epoch
# (not an end-of-epoch evaluation).
num_epochs = NUM_EPOCHS
train_losses = []
test_losses = []
train_bers = []
test_bers = []

# zip() stops at the shorter loader: unequal lengths would silently skip training batches
# every epoch, and an empty loader would make the per-epoch averages divide by zero.
if len(train_dataloader) == 0 or len(train_dataloader) != len(test_dataloader):
    raise ValueError(
        f"paired train/validation loop needs equal, non-zero batch counts; got "
        f"{len(train_dataloader)} train vs {len(test_dataloader)} validation batches"
    )

for epoch in range(num_epochs):
    train_l = 0.0
    test_l = 0.0
    train_b = 0.0
    test_b = 0.0
    count = 0
    for train_batch, test_batch in zip(train_dataloader, test_dataloader):

        train_loss, train_ber = train_loop(batch=train_batch, criterion=criterion, optimizer=optimizer, model=model)
        test_loss, test_ber = test_loop(batch=test_batch, criterion=criterion, model=model)

        train_l += train_loss
        test_l += test_loss
        train_b += train_ber
        test_b += test_ber
        count += 1

    train_losses.append(train_l / count)
    test_losses.append(test_l / count)
    train_bers.append(train_b / count)
    test_bers.append(test_b / count)

    print(f"Epoch {epoch + 1} : Training Loss : {train_l / count} Training BER : {train_b / count}  Validation Loss : {test_l / count} Validation BER : {test_b / count}")

In [None]:
# Per-epoch mean training vs. validation loss (epochs numbered from 1).
train_epoch_axis = list(range(1, len(train_losses) + 1))
test_epoch_axis = list(range(1, len(test_losses) + 1))
plt.plot(train_epoch_axis, train_losses, label="Training Loss")
plt.plot(test_epoch_axis, test_losses, label="Validation Loss", color="red", marker="o")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Per-epoch mean training vs. validation bit error rate (epochs numbered from 1).
train_epoch_axis = list(range(1, len(train_bers) + 1))
test_epoch_axis = list(range(1, len(test_bers) + 1))
plt.plot(train_epoch_axis, train_bers, label="Training BER")
plt.plot(test_epoch_axis, test_bers, label="Validation BER", color="red", marker="o")
plt.xlabel("Epochs")
plt.ylabel("BER")
plt.legend()
plt.show()

In [None]:
# Per-batch BER over the training set for three estimates of the sent bits: the model's
# rounded output and each of the two hard-demodulated received copies on their own.
noise_pred_train, noise_x1_train, noise_x2_train = [], [], []
with torch.no_grad():
    model.eval()
    for x in train_dataloader:
        y, x1, x2 = torch.unbind(x, dim=1)
        x = torch.concat([x1, x2], dim=1)
        y_pred = torch.round(model(x))
        for history, estimate in ((noise_pred_train, y_pred), (noise_x1_train, x1), (noise_x2_train, x2)):
            history.append(BER(estimate, y))

In [None]:
# Per-batch BER of the model output vs. the first received copy alone.
frame_axis_pred = list(range(1, len(noise_pred_train) + 1))
frame_axis_copy1 = list(range(1, len(noise_x1_train) + 1))
plt.plot(frame_axis_pred, noise_pred_train, label="Predicted Noise", color="red", marker="o")
plt.plot(frame_axis_copy1, noise_x1_train, label="First Part Noise")
plt.xlabel("Frame")
plt.ylabel("Noise1")
plt.legend()
plt.show()

In [None]:
# Per-batch BER of the model output vs. the second received copy alone.
plt.plot(list(range(1, len(noise_pred_train) + 1)), noise_pred_train, label="Predicted Noise", color='red', marker='o')
plt.plot(list(range(1, len(noise_x2_train) + 1)), noise_x2_train, label="Second Part Noise")
plt.xlabel("Frame")
plt.ylabel("Noise2")  # was a copy-pasted "Noise1" although this figure shows the second copy
plt.legend()
plt.show()

In [None]:
def data_gen(snr,data_bits,num_messages,mod,signal_power=5.0):
    """Simulate two independent noisy receptions of each message over an AWGN channel.

    Args:
        snr: channel SNR in dB (converted to a linear ratio for komm.AWGNChannel).
        data_bits: indexable collection of per-message bit vectors; the first
            ``num_messages`` entries are used.
        num_messages: number of messages to simulate; must not exceed ``len(data_bits)``.
        mod: modulation object providing ``modulate`` / ``demodulate``.
        signal_power: value passed as ``signal_power`` to komm.AWGNChannel. Defaults to the
            previously hard-coded 5.0. NOTE(review): one fixed value is used for every
            modulation, whose constellations may differ in average power, so the nominal
            SNR may not be the true one; komm documents "measured" as an alternative —
            confirm against the installed version.

    Returns:
        float32 tensor stacked along dim 1 as (sent bits, received copy 1, received copy 2).

    Raises:
        IndexError: if ``num_messages`` exceeds the number of available messages.
    """
    if num_messages > len(data_bits):
        raise IndexError(f"num_messages={num_messages} exceeds the {len(data_bits)} messages in data_bits")
    awgn = komm.AWGNChannel(snr=10**(snr / 10), signal_power=signal_power)
    # Only the first num_messages rows are simulated, so the clean bits must be sliced the
    # same way; converting all of data_bits made torch.stack fail on mismatched shapes.
    sent_bits = data_bits[:num_messages]
    data_signal = [mod.modulate(bits) for bits in sent_bits]
    data_noisy_signal1 = [awgn(signal) for signal in data_signal]
    data_noisy_signal2 = [awgn(signal) for signal in data_signal]
    data_received_bits1 = [mod.demodulate(noisy) for noisy in data_noisy_signal1]
    data_received_bits2 = [mod.demodulate(noisy) for noisy in data_noisy_signal2]

    def _to_float_tensor(rows):
        # np.asarray first: building a tensor from a list of ndarrays is very slow in torch.
        return torch.tensor(np.asarray(rows), dtype=torch.float32)

    return torch.stack(
        (_to_float_tensor(sent_bits), _to_float_tensor(data_received_bits1), _to_float_tensor(data_received_bits2)),
        dim=1,
    )

In [None]:
# BER of the trained model vs. channel SNR, re-simulating the *training* messages through
# every modulation. Modulations are generated in the original order (QAM, PAM, PSK, ASK).
snr_grid = np.arange(-4, 30, 1)
train_bits_np = train_bits.detach().numpy()  # loop-invariant: convert once, not 4x per SNR
sweep_modulations = {"QAM": qam, "PAM": pam, "PSK": psk, "ASK": ask}
ber_by_mod = {name: [] for name in sweep_modulations}
with torch.no_grad():
    model.eval()
    for snr in snr_grid:
        for name, modulation in sweep_modulations.items():
            sent, received1, received2 = torch.unbind(
                data_gen(snr, train_bits_np, num_messages_train, modulation), dim=1
            )
            estimate = model(torch.concat([received1, received2], dim=1))
            ber_by_mod[name].append(BER(torch.round(estimate), sent))
ber_QAM = ber_by_mod["QAM"]
ber_PAM = ber_by_mod["PAM"]
ber_ASK = ber_by_mod["ASK"]
ber_PSK = ber_by_mod["PSK"]

for name, color, marker in (("QAM", 'red', 'o'), ("PAM", 'blue', 's'), ("ASK", 'green', '^'), ("PSK", 'orange', '*')):
    plt.plot(snr_grid, ber_by_mod[name], label=name, color=color, marker=marker)
plt.yscale('logit')
plt.xlabel("SNR (dB)")
plt.ylabel("TrainBER")
plt.legend()
plt.show()

In [None]:
# BER of the trained model vs. channel SNR, re-simulating the *validation* messages through
# every modulation. Modulations are generated in the original order (QAM, PAM, PSK, ASK).
snr_grid = np.arange(-4, 30, 1)
test_bits_np = test_bits.detach().numpy()  # loop-invariant: convert once, not 4x per SNR
sweep_modulations = {"QAM": qam, "PAM": pam, "PSK": psk, "ASK": ask}
ber_by_mod = {name: [] for name in sweep_modulations}
with torch.no_grad():
    model.eval()
    for snr in snr_grid:
        for name, modulation in sweep_modulations.items():
            sent, received1, received2 = torch.unbind(
                data_gen(snr, test_bits_np, num_messages_test, modulation), dim=1
            )
            estimate = model(torch.concat([received1, received2], dim=1))
            ber_by_mod[name].append(BER(torch.round(estimate), sent))
ber_QAM = ber_by_mod["QAM"]
ber_PAM = ber_by_mod["PAM"]
ber_ASK = ber_by_mod["ASK"]
ber_PSK = ber_by_mod["PSK"]

for name, color, marker in (("QAM", 'red', 'o'), ("PAM", 'blue', 's'), ("ASK", 'green', '^'), ("PSK", 'orange', '*')):
    plt.plot(snr_grid, ber_by_mod[name], label=name, color=color, marker=marker)
plt.yscale('logit')
plt.xlabel("SNR (dB)")
plt.ylabel("TestBER")
plt.legend()
plt.show()