In [None]:
import sys
sys.path.append('..')
from Comms_System import Comms_System, butter_lowpass, SNR_plot
import numpy as np
import matplotlib.pyplot as plt
import torch
from NetworkPytorch import train_loop
from DE_Pytorch import DE
from torchsummary import summary
from scipy import signal
from scipy.stats import norm

In [None]:
def get_data(num_symbols, sigma, lowpass=None, symbol_set=None, m=8):
    """Simulate a noisy pulse-shaped transmission and return it with its labels.

    A random symbol sequence is drawn from ``symbol_set``, upsampled by
    ``Comms_System``, optionally low-pass filtered, convolved with the pulse
    ``CS.h`` and corrupted with additive Gaussian noise.

    Parameters
    ----------
    num_symbols : int
        Number of symbols to draw (with replacement).
    sigma : float
        Standard deviation of the additive Gaussian noise.
    lowpass : optional
        Cutoff forwarded to ``butter_lowpass``; ``None`` (default) skips the
        low-pass stage entirely.
    symbol_set : sequence, optional
        Symbol alphabet. Defaults to ``[3, 1, -1, -3]``. The position of a
        symbol in this sequence is its class index.
    m : int, optional
        Forwarded to ``Comms_System`` as ``num_samples`` (default 8).

    Returns
    -------
    X : torch.Tensor
        Float tensor reshaped to ``(1, 1, -1)`` holding the noisy signal.
    y : torch.Tensor
        Tensor with one class index per transmitted symbol.
    """
    if symbol_set is None:
        symbol_set = [3, 1, -1, -3]  # all symbols that we use

    symbol_seq = np.random.choice(symbol_set, num_symbols, replace=True)
    CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, norm_h=False)

    upsampled = CS.upsample(v=False)
    if lowpass is not None:
        # 4th-order Butterworth applied to the upsampled symbols before pulse shaping.
        b, a = butter_lowpass(lowpass, CS.m, 4)
        upsampled = signal.lfilter(b, a, upsampled)

    Tx = np.convolve(upsampled, CS.h)
    # Tx = Tx / np.sqrt(np.mean(np.square(Tx)))
    Tx = Tx + np.random.normal(0.0, sigma, Tx.shape)  # add gaussian noise

    # Reshape and cast to float so PyTorch's Conv1d accepts it.
    X = torch.tensor(Tx).view(1, 1, -1).float()

    # Map every transmitted symbol to its position in the alphabet.
    class_idx = {v: i for i, v in enumerate(symbol_set)}
    y_idx = np.array([class_idx[v] for v in symbol_seq])
    # NOTE(review): torch.Tensor(...) yields float labels, kept for backward
    # compatibility; presumably the training code casts them before the
    # cross-entropy loss -- confirm in train_loop / DE.
    y = torch.Tensor(y_idx)

    return X, y

def make_net():
    """Build the two-layer 1D convolutional receiver with all gradients disabled.

    The first layer is a single-channel convolution with kernel size 64; the
    second maps 1 channel to 4 with kernel size 8 and stride 8. Every
    parameter has ``requires_grad`` switched off before the model is returned.
    """
    layers = [
        torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=64),
        torch.nn.Conv1d(in_channels=1, out_channels=4, kernel_size=8, stride=8),
    ]
    model = torch.nn.Sequential(*layers)
    for weight in model.parameters():
        weight.requires_grad = False
    return model

In [None]:
# Seed the random number generators so that Restart & Run All reproduces the
# same data and network initialisation.
# NOTE(review): if DE draws from Python's `random` module, seed that as well.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create Data
Xtrain, ytrain = get_data(num_symbols=10000, sigma=2, lowpass=None)
Xtest, ytest = get_data(num_symbols=10000, sigma=2, lowpass=None)

# Create 1D Convolutional Neural Network with PyTorch and define optimizer and loss
# (same layer layout as make_net(), but with gradients enabled for Adam)
NN = torch.nn.Sequential(torch.nn.Conv1d(1, 1, 64), torch.nn.Conv1d(1, 4, 8, stride=8))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(NN.parameters(), lr=1e-3)

# Differential-evolution trainer; make_net supplies the population members.
D = DE(objective_function=torch.nn.CrossEntropyLoss(), population_function=make_net, 
       X=Xtrain, y=ytrain, Xtest=Xtest, ytest=ytest, F=0.55, cr=0.85, use_cuda=False)
#summary(make_net(), input_size=(1, 1, 8063))

In [None]:
# DE Training
# Run the evolution and keep whatever it returns as best_agent.
best_agent = D.evolution(num_epochs=10000, verbose=True, print_epoch=500)
#best_agent, opt_agent = D.early_stop_training(patience=500, measure='accuracy')

D.evaluate()

# Test accuracy: the predicted class is the arg-max over the channel axis of
# the best agent's output on the held-out signal.
cpu_agent = D.best_agent.to('cpu')
predicted_idx = cpu_agent(Xtest).argmax(axis=1)
acc = torch.sum(predicted_idx == ytest) / len(ytest)
print('Accuracy:', acc.item())

In [None]:
# `learned_filter` was never defined in the notebook, so this cell raised a
# NameError on a fresh kernel.
# NOTE(review): assumes the filter of interest is the weight of the first
# Conv1d (kernel size 64) of the DE best agent -- confirm this is the intended
# source (rather than NN[0]).
learned_filter = D.best_agent[0].weight.detach().flatten()

# Scale the filter to unit energy (L2 norm of 1).
norm_learned_filter = learned_filter / torch.sqrt(torch.sum(torch.square(learned_filter)))
norm_learned_filter

In [None]:
# Backprop Training
# Train NN on the same train/test split handed to the DE trainer, using the
# optimizer and loss defined in the setup cell.
testcosts, traincosts = train_loop(
    model=NN,
    optimizer=optimizer,
    cost=criterion,
    Xtrain=Xtrain,
    ytrain=ytrain,
    Xtest=Xtest,
    ytest=ytest,
    epochs=3000,
    eval=True,
    plot_iteration=300,
    use_cuda=False,
)

In [None]:
#symbol_set = [7, 5, 3, 1, -1, -3, -5, -7] # all symbols that we use
# Must be the same alphabet, in the same order, as the one in get_data():
# the network's four output channels are indexed by position in this list.
# (Previously [3, 1, -1, 3], which duplicated 3 and never transmitted -3.)
symbol_set = [3, 1, -1, -3]
num_symbols = 100000
symbol_seq = np.random.choice(symbol_set, num_symbols, replace=True)
m = 8
CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, norm_h=False)
#sigma = 1 # corresponds roughly to SNR 16 (old sigma=1)
#sigma = 0.7 # corresponds roughly to SNR 10 (old sigma=2)
#sigma = 1.06 # corresponds roughly to SNR 6.4 (old sigma=3)
SNR = 16
sigma = CS.SNR_to_sigma(SNR)  # use the SNR constant instead of repeating the literal
print(sigma)


# Decode the same noisy channel with the Euclidean-distance receiver and with
# the trained network.
euclid_decisions = CS.transmission(noise_level=sigma, mode='euclidean', norm_signal=False)
conv_decisions = CS.transmission(noise_level=sigma, mode='network', norm_signal=False, model=NN)

# NOTE(review): presumably CS.evaluate(...)[1] is the accuracy, so these print
# the error rates -- confirm against Comms_System.evaluate.
print(1 - CS.evaluate(euclid_decisions)[1])
print(1 - CS.evaluate(conv_decisions)[1])

In [None]:
# Sweep with the trained network as the receiver; only the first and the last
# two of the seven returned values are kept.
SNRsDB, _, _, _, _, error_rates_conv, error_theory = SNR_plot(
    10000,
    lowpass=None,
    conv_model=NN,
    norm_h=False,
    norm_signal=False,
    use_gain=True,
)

In [None]:
# Error probability versus SNR on a logarithmic y-axis: network receiver
# against the theoretical curve.
fig, ax = plt.subplots(figsize=(18, 11))
ax.set_title('Noise Plot', fontsize=24)
ax.set_xlabel('SNR (dB)', fontsize=20)
ax.set_ylabel('$P_e$', fontsize=20)

num = 0  # index of the first SNR point to draw
snr_axis = SNRsDB[num:]
ax.semilogy(snr_axis, error_rates_conv[num:])
ax.semilogy(snr_axis, error_theory[num:])

legend = ['Receiver Network', 'Theory']
ax.legend(legend, fontsize=16)
#plt.savefig('EveythingOldWithoutGain')
plt.show()