In [1]:
import sys
sys.path.append('..')
from Comms_System import Comms_System, butter_lowpass, SNR_plot_new
import numpy as np
import matplotlib.pyplot as plt
import torch
from NetworkPytorch import joint_train_loop
from DE_Pytorch import DE
from scipy import signal
import scipy
import torchaudio
from scipy.stats import norm

In [2]:
def get_data(num_symbols=None):
    """Encode a symbol sequence as (network input, class-index targets).

    With ``num_symbols=None`` (the original behaviour) the module-level ``CS`` /
    ``symbol_set`` / ``symbol_seq`` globals are encoded and ``(X_tx, y)`` is returned.

    With an integer ``num_symbols`` a fresh random sequence of that length is drawn
    from ``symbol_set``, a new ``Comms_System`` is built for it (same ``m``, ``beta``
    and ``norm_h`` as the setup cell) and ``(X_tx, y, comms_system)`` is returned,
    so the caller can evaluate decisions against the matching sequence. This is the
    form the later sweep cells call.

    Returns:
        X_tx: float tensor of ``comms_system.upsample()`` viewed as (1, 1, -1).
        y: int64 tensor of indices into ``symbol_set`` (CrossEntropyLoss targets).
        comms_system: only when ``num_symbols`` is given.
    """
    if num_symbols is None:
        comms_system, seq = CS, symbol_seq
    else:
        seq = np.random.choice(symbol_set, num_symbols, replace=True)
        comms_system = Comms_System(symbol_set=symbol_set, symbol_seq=seq,
                                    num_samples=m, beta=0.35, norm_h=False)

    X_tx = torch.Tensor(comms_system.upsample()).view(1, 1, -1)

    class_idx = {v: i for i, v in enumerate(symbol_set)}
    # CrossEntropyLoss expects integer class indices, so build an int64 tensor
    # (torch.Tensor(...) would silently produce float32 targets).
    y = torch.tensor([class_idx[v] for v in seq], dtype=torch.long)

    if num_symbols is None:
        return X_tx, y
    return X_tx, y, comms_system

In [3]:
# Reproducibility: seed the numpy RNG (symbol draws) and the torch RNG
# (network weight initialisation and torch-side noise).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

symbol_set = [3, 1, -1, -3]  # all symbols that we use
num_symbols = 10000
symbol_seq = np.random.choice(symbol_set, num_symbols, replace=True)
m = 8  # samples per symbol
CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, beta=0.35, norm_h=False)

# NOTE(review): get_data() encodes the global CS, so both calls below produce the
# same sequence -- the "test" tensors are identical to the training tensors. The
# evaluation cells compare decisions against CS's own sequence, so a truly
# independent test set needs its own Comms_System.
Xtrain, ytrain = get_data()
Xtest, ytest = get_data()

# 1D convolutional transmitter and receiver, optimised jointly.
# TODO: double-check padding and the overall design of NN_tx / NN_rx --
# NN_tx pads by len(CS.h) - 1, not by its own kernel size (64) - 1.
NN_tx = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 64, padding=len(CS.h) - 1))
NN_rx = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 64), torch.nn.Conv1d(1, 4, 8, stride=8))
criterion = torch.nn.CrossEntropyLoss()
params = list(NN_tx.parameters()) + list(NN_rx.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

In [None]:
# Jointly train the transmitter and receiver networks end-to-end with backpropagation
# through the lowpass channel (noise at 10 dB SNR during training).
epoch_losses = joint_train_loop(NN_tx, NN_rx, Xtrain, ytrain, optimizer, criterion, lowpass=True,
                                sample_rate=CS.m, epochs=2000, cutoff_freq=2, v=True, use_cuda=False, SNRdb=10)

fig, ax = plt.subplots(figsize=(13, 8))
ax.plot(epoch_losses.to('cpu'))
ax.set(title='Joint training loss', xlabel='Epoch', ylabel='Loss')
plt.show()

In [None]:
# Evaluate

def transmit_joint(upsampled, classes, SNRdb=10, cutoff_freq=2, v=False):
    """Send ``upsampled`` through the trained transmitter, channel and receiver.

    Pipeline: ``NN_tx`` -> Butterworth lowpass (order 4, ``cutoff_freq``, sample
    rate ``CS.m``) -> normalisation to unit mean power per sample -> additive
    Gaussian noise -> ``NN_rx`` -> argmax over the receiver's output channels.

    Args:
        upsampled: input tensor for ``NN_tx`` (as produced by ``get_data``).
        classes: numpy array of symbol values, indexed by receiver class index.
        SNRdb: SNR per symbol in dB. After normalisation each sample has unit
            power, so a symbol of ``CS.m`` samples carries energy ``CS.m``.
        cutoff_freq: lowpass cutoff passed to ``butter_lowpass``.
        v: if True, print the noise standard deviation.

    Returns:
        Array of decided symbols, one per receiver output position.
    """
    samples_per_symbol = CS.m
    SNR = 10 ** (SNRdb / 10)
    sigma = np.sqrt(samples_per_symbol / SNR)
    if v:
        print("sigma:", sigma)

    b, a = butter_lowpass(cutoff_freq, CS.m, 4)
    # Evaluation only: no gradients are needed through the filter coefficients.
    b = torch.tensor(b, dtype=torch.float32)
    a = torch.tensor(a, dtype=torch.float32)

    with torch.no_grad():
        Tx = NN_tx(upsampled)
        # Lowpass channel. NOTE: torchaudio's lfilter clamps its output to [-1, 1]
        # by default (clamp=True); kept as-is -- confirm it matches training.
        Tx = torchaudio.functional.lfilter(Tx, a, b)
        # Normalise to unit mean power per sample.
        Tx = Tx / torch.sqrt(torch.mean(torch.square(Tx)))
        # AWGN channel.
        Tx = Tx + torch.normal(0.0, sigma, Tx.shape)
        # Receiver output is (batch, classes, positions); take batch 0 -> (positions, classes).
        output = NN_rx(Tx)[0].T

    return classes[output.argmax(dim=1).cpu().numpy()]

In [None]:
# Accuracy of the jointly trained networks at the default SNR
# (Xtest is encoded from CS, so decisions are scored against CS's sequence).
joint_decisions = transmit_joint(Xtest, classes=np.array(CS.symbol_set))
symbol_error_rate = CS.evaluate(joint_decisions)[1]
print("Accuracy:", 1 - symbol_error_rate)

In [None]:
# Same evaluation through Comms_System's own joint-transmission path, for comparison.
# The signal is normalised to unit power per sample, so one symbol of CS.m samples
# carries energy 1 * CS.m.
sigma = CS.SNRdb_to_sigma(SNRdb=10, energy=1 * CS.m)
joint_decisions = CS.transmit_joint(sigma, cutoff_freq=2)
1 - CS.evaluate(joint_decisions)[1]

In [None]:
# SNR sweep: Euclidean (nearest-symbol) detector vs. theory vs. the jointly trained networks.
symbol_set = [3, 1, -1, -3]  # all symbols that we use
num_symbols = 1000
symbol_seq = np.random.choice(symbol_set, num_symbols, replace=True)
m = 8
CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, beta=0.35, norm_h=False)

# Independent test sequence for the networks. get_data() in this notebook takes no
# arguments, so the test system is built explicitly here.
num_test_symbols = 10000
CStest = Comms_System(symbol_set=symbol_set,
                      symbol_seq=np.random.choice(symbol_set, num_test_symbols, replace=True),
                      num_samples=m, beta=0.35, norm_h=False)
Xtest = torch.Tensor(CStest.upsample()).view(1, 1, -1)
test_classes = np.array(CStest.symbol_set)

SNRdbs = np.linspace(0, 18, 50)
avg_symbol_energy = np.mean(np.array(symbol_seq) ** 2)
print('Avg symbol energy', avg_symbol_energy)
gain_factor = np.max(np.convolve(CS.h, CS.h))
print('gain', gain_factor)

sigmas = []
euclid_error_rates = []
network_error_rates = []
for SNRdb in SNRdbs:
    # Euclidean detector: noise level from the raw symbol energy, including the pulse gain.
    sigma_euclid = CS.SNRdb_to_sigma(SNRdb, avg_symbol_energy, use_gain=True)
    euclid_decisions = CS.transmission(noise_level=sigma_euclid, norm_signal=False, v=False)
    sigmas.append(sigma_euclid)
    euclid_error_rates.append(CS.evaluate(euclid_decisions)[1])

    # Joint networks: the signal is normalised to unit power per sample, so with
    # 8 samples per symbol the symbol energy is 1 * 8; transmit_joint derives its
    # own noise level from SNRdb on that basis.
    with torch.no_grad():
        joint_decisions = transmit_joint(Xtest, test_classes, SNRdb=SNRdb, cutoff_freq=2)
    network_error_rates.append(CStest.evaluate(joint_decisions)[1])

sigmas = np.array(sigmas)
# 1.5 = 2 * (1 - 1/M) for M = 4 amplitude levels; 1 - Phi(x) is the Gaussian Q-function.
error_theory = 1.5 * (1 - norm.cdf(np.sqrt(gain_factor / sigmas ** 2)))
euclid_error_rates = np.array(euclid_error_rates)
network_error_rates = np.array(network_error_rates)

In [None]:
# Symbol error rate vs. SNR for the Euclidean detector, theory and the joint networks.
fig, ax = plt.subplots(figsize=(18, 11))
ax.set_title('Noise Plot', fontsize=24)
ax.set_xlabel('SNR (dB)', fontsize=20)
ax.set_ylabel('$P_e$', fontsize=20)
for curve in (euclid_error_rates, error_theory, network_error_rates):
    ax.semilogy(SNRdbs, curve)
ax.legend(['Euclid', 'Theory', 'Joint Training'], fontsize=16)
plt.show()

In [None]:
# Sweep the noise standard deviation directly and compare the Euclidean detector,
# the joint networks and theory on a fresh test sequence. get_data() in this
# notebook takes no arguments, so the test system is built explicitly here.
test_symbol_seq = np.random.choice(symbol_set, 10000, replace=True)
CStest = Comms_System(symbol_set=symbol_set, symbol_seq=test_symbol_seq,
                      num_samples=m, beta=0.35, norm_h=False)
Xtest = torch.Tensor(CStest.upsample()).view(1, 1, -1)
test_classes = np.array(CStest.symbol_set)

sigmas = np.linspace(0.25, 1.75, 50)
avg_symbol_energy = np.mean(np.array(CStest.symbol_seq) ** 2)
gain_factor = np.max(np.convolve(CStest.h, CStest.h))
print(gain_factor)

SNRs = avg_symbol_energy / sigmas ** 2
SNRsDB = 10 * np.log10(SNRs)

error_rates_joint = []
error_rates_euclid = []
for sigma, snr_db in zip(sigmas, SNRsDB):
    # NOTE(review): this cell previously called `evaluate_yo(sigma, Xtest, ytest, CStest, 2)`,
    # which is not defined anywhere in the notebook. As a stand-in, the networks are
    # evaluated with transmit_joint at the same SNR that is plotted on the x-axis
    # (avg_symbol_energy / sigma**2) -- confirm this matches the intended evaluation.
    with torch.no_grad():
        joint_decisions = transmit_joint(Xtest, test_classes, SNRdb=snr_db, cutoff_freq=2)
    error_rates_joint.append(CStest.evaluate(joint_decisions)[1])

    received_symbols_euclid = CStest.transmission(mode='euclidean', noise_level=sigma)
    error_rates_euclid.append(CStest.evaluate(received_symbols_euclid)[1])

error_rates_joint = np.array(error_rates_joint)
error_rates_euclid = np.array(error_rates_euclid)
# 1.5 = 2 * (1 - 1/M) for M = 4 amplitude levels; 1 - Phi(x) is the Gaussian Q-function.
error_theory = 1.5 * (1 - norm.cdf(np.sqrt(gain_factor / sigmas ** 2)))

In [None]:
# Three curves are plotted, so each gets its own label; the previous two-entry legend
# mislabelled the joint-network curve as "Theory" and left the theory curve unlabelled.
fig, ax = plt.subplots(figsize=(13, 8))
ax.set_title('Decision-Making Noise Plots', fontsize=24)
ax.set_xlabel('SNR (dB)', fontsize=20)
ax.set_ylabel('$P_e$', fontsize=20)
ax.semilogy(SNRsDB, error_rates_euclid, label='Euclidean')
ax.semilogy(SNRsDB, error_rates_joint, label='Sender AND Receiver Network')
ax.semilogy(SNRsDB, error_theory, label='Theory')
ax.legend(fontsize=16)
plt.show()