In [None]:
import sys
sys.path.append('..')
from Comms_System import Comms_System, SNR_plot, network_sender_receiver
from filters import butter_lowpass, ideal_lowpass
import numpy as np
import matplotlib.pyplot as plt
import torch
#from NetworkPytorch import joint_train_loop
from scipy import signal
import scipy
import torchaudio
from scipy.stats import norm
from commpy import rcosfilter
#from DE_Pytorch_Joint import DE
from DE_Pytorch_Joint_Minibatch import DE as DE_mini
import random

In [None]:
def get_data(comms_system=None, symbols=None, seq=None):
    """Build the (network input, class-index target) pair for one dataset.

    Parameters
    ----------
    comms_system : object with an ``upsample()`` method, optional
        Source of the transmitted signal. Defaults to the notebook-level ``CS``.
    symbols : sequence, optional
        Symbol alphabet; a symbol's position in this sequence is its class
        index. Defaults to the notebook-level ``symbol_set``.
    seq : sequence, optional
        Transmitted symbols, one target per entry. Defaults to the
        notebook-level ``symbol_seq``.
        NOTE(review): this presumably must be the same sequence
        ``comms_system`` was constructed from — confirm against Comms_System.

    Returns
    -------
    X_tx : torch.Tensor (float32), shape (1, 1, L)
        Upsampled signal reshaped to (batch=1, channels=1, length) for Conv1d.
    y : torch.Tensor (int64), shape (len(seq),)
        Class index of every symbol in ``seq``.

    Raises
    ------
    KeyError
        If ``seq`` contains a value that is not in ``symbols``.
    """
    # Resolve defaults lazily: CS / symbol_set / symbol_seq are created in a
    # later cell, so they cannot be bound as default values at definition time.
    if comms_system is None:
        comms_system = CS
    if symbols is None:
        symbols = symbol_set
    if seq is None:
        seq = symbol_seq

    upsampled = comms_system.upsample()
    X_tx = torch.tensor(np.asarray(upsampled), dtype=torch.float32).view(1, 1, -1)

    class_idx = {symbol: i for i, symbol in enumerate(symbols)}
    y_idx = np.array([class_idx[s] for s in seq], dtype=np.int64)
    # CrossEntropyLoss expects integer (int64) class indices, not floats.
    y = torch.from_numpy(y_idx)

    return X_tx, y

In [None]:
def make_nets(filter_len=64, num_classes=4, samples_per_symbol=8):
    """Create one untrained (transmitter, receiver) network pair.

    Passed as ``pop_fun`` to ``DE_mini``. The defaults reproduce the
    original fixed architecture, so calling it with no arguments is unchanged.

    Parameters
    ----------
    filter_len : int
        Number of taps in the tx filter and in the rx front-end filter.
    num_classes : int
        Output channels of the receiver's final layer (one per symbol class).
    samples_per_symbol : int
        Kernel size and stride of the receiver's final layer, so it emits one
        output per symbol period.

    Returns
    -------
    (NN_tx, NN_rx) : tuple of torch.nn.Sequential
        All parameters of both networks have ``requires_grad=False``.
    """
    # padding = filter_len - 1 performs a "full" convolution; the original
    # notes this corresponds to padding=len(CS.h) - 1.
    NN_tx = torch.nn.Sequential(
        torch.nn.Conv1d(1, 1, filter_len, padding=filter_len - 1))
    NN_rx = torch.nn.Sequential(
        torch.nn.Conv1d(1, 1, filter_len),
        torch.nn.Conv1d(1, num_classes, samples_per_symbol, stride=samples_per_symbol))
    # Disable autograd on EVERY parameter of both networks. The previous
    # zip() over the two parameter lists stopped after NN_tx's 2 parameters,
    # leaving the receiver's second Conv1d weight and bias trainable.
    for net in (NN_tx, NN_rx):
        net.requires_grad_(False)
    return (NN_tx, NN_rx)

In [None]:
# Reproducibility: seed every RNG this notebook touches.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

symbol_set = [3, 1, -1, -3] # all symbols that we use
num_symbols = 100000
m = 8  # passed to Comms_System as num_samples
classes = np.array(symbol_set)

# get_data() reads the CS / symbol_seq globals, so draw an independent symbol
# sequence (and system) for each split. Previously both calls reused the same
# sequence, making Xtest/ytest identical to Xtrain/ytrain. The test split is
# built first so that CS and symbol_seq end up holding the training system,
# as later cells expect.
symbol_seq = np.random.choice(symbol_set, num_symbols, replace=True)
CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, beta=0.35)
Xtest, ytest = get_data()

symbol_seq = np.random.choice(symbol_set, num_symbols, replace=True)
CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, beta=0.35)
Xtrain, ytrain = get_data()

D = DE_mini(objective_function=torch.nn.CrossEntropyLoss(), pop_size=20, pop_fun=make_nets,
            X=Xtrain, y=ytrain, Xtest=Xtest, ytest=ytest, F=0.6, cr=0.85, use_cuda=False,
            SNRdb=10, lowpass='butter', cutoff_freq=0.675, noise='constant')

In [None]:
# Run the differential-evolution search over the population built above.
# verbose=True with print_epoch=1 presumably reports progress every epoch.
evolution_settings = dict(num_epochs=300, batch_size=10000, verbose=True, print_epoch=1)
D.evolution(**evolution_settings)

In [None]:


# NOTE(review): NN_tx / NN_rx are not defined anywhere in this notebook; they
# must hold the trained (tx, rx) pair produced by D.evolution(). TODO: assign
# them explicitly from `D` (confirm DE_mini's attribute for the best
# individual) so the notebook survives Restart & Run All.
joint_models = [NN_tx, NN_rx]
snr_settings = dict(num_symbols=200000, joint_cutoff=0.675, joint_models=joint_models,
                    range=[0, 19], num_SNRs=50)

# The results cell compares the joint networks under both low-pass variants,
# so evaluate each one. Previously only 'ideal' was run, the Euclidean curve
# was bound to `euclid_er3`, and the plot read undefined names
# (`euclid_er2`, `joint_er_butter`).
SNRdbs, euclid_er2, network_er, NN_er, block_er, joint_er_ideal, error_theory = \
    SNR_plot(lowpass='ideal', **snr_settings)

butter_results = SNR_plot(lowpass='butter', **snr_settings)
SNRdbs_butter, joint_er_butter = butter_results[0], butter_results[5]
assert np.allclose(SNRdbs, SNRdbs_butter), "both evaluations must use the same SNR grid"

In [None]:
# Symbol-error probability vs. SNR for the Euclidean baseline and the joint
# networks under each low-pass filter.
fig, ax = plt.subplots(figsize=(13, 8))
ax.set_xlabel('SNR (dB)', fontsize=20)
ax.set_ylabel('$P_e$', fontsize=24)

start = 0  # index of the first SNR point to display
error_curves = (
    (euclid_er2, 'Euclidean'),
    (joint_er_ideal, 'Joint Networks (Ideal LPF)'),
    (joint_er_butter, 'Joint Networks (Butter LPF)'),
)
for errors, curve_label in error_curves:
    ax.semilogy(SNRdbs[start:], errors[start:], label=curve_label, linewidth=3)

ax.legend(fontsize=14)
ax.grid(True)
plt.show()
# To export: fig.savefig('JointPerformanceDE5', bbox_inches='tight', transparent=True)
#plt.savefig('JointPerformanceDE5', bbox_inches='tight', transparent=True)

## Plot Full Response

In [None]:
lowpass = 'butter'
cutoff = 0.675


def unit_energy(x):
    """Return ``x`` as a float array scaled so that sum(x**2) == 1."""
    x = np.asarray(x, dtype=float)
    return x / np.sqrt(np.sum(np.square(x)))


CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, beta=0.35)
fs = CS.m  # sampling rate the low-pass filters are designed with

# The first parameter of each network is its first Conv1d weight; take its
# taps (biases are ignored). .copy() detaches the array from the live weights.
learned_tx_filter = next(NN_tx.parameters()).detach().cpu().numpy()[0, 0].copy()
learned_rx_filter = next(NN_rx.parameters()).detach().cpu().numpy()[0, 0].copy()

if lowpass == 'butter':
    b, a = butter_lowpass(cutoff_freq=cutoff, sampling_rate=fs, order=10)
    total_sender_response = scipy.signal.filtfilt(b, a, learned_tx_filter)
elif lowpass == 'ideal':
    total_sender_response = ideal_lowpass(learned_tx_filter, cutoff, fs)
else:
    # Without this, a typo would silently reuse a stale total_sender_response.
    raise ValueError(f"lowpass must be 'butter' or 'ideal', got {lowpass!r}")

total_sender_response = unit_energy(total_sender_response)
full_response = unit_energy(np.convolve(total_sender_response, learned_rx_filter))
raised = unit_energy(np.convolve(CS.h, CS.h))
#raised = unit_energy(rcosfilter(N=127, alpha=0.35, Ts=1, Fs=m)[1])

# The original hard-coded `-full_response`, which only fits a run whose learned
# response came out inverted. Instead, flip it so its largest-magnitude tap has
# the same sign as the raised-cosine peak.
learned_peak_sign = np.sign(full_response[np.argmax(np.abs(full_response))])
raised_peak_sign = np.sign(raised[np.argmax(np.abs(raised))])
aligned_response = full_response * learned_peak_sign * raised_peak_sign

figsize = (13, 8)
fig, ax = plt.subplots(figsize=figsize)
ax.set_title('Full Response of System (Time)', fontsize=16)
ax.plot(aligned_response, label='Learned Full Response')
ax.plot(raised, label='Raised Cosine')
ax.set_xlabel('Sample', fontsize=16)
ax.set_ylabel('Amplitude', fontsize=16)
ax.grid(True)
ax.legend(fontsize=14)
plt.show()

# Magnitude spectra are sign-independent, so the unaligned response is used.
fig, ax = plt.subplots(figsize=figsize)
ax.set_title('Full Response (Frequency)', fontsize=20)
ax.magnitude_spectrum(full_response, Fs=fs, scale='dB', sides='twosided', color='C1', label='Learned Full Response')
ax.magnitude_spectrum(raised, Fs=fs, scale='dB', sides='twosided', color='C2', label='Raised Cosine')
ax.set_xlabel('Frequency', fontsize=16)
ax.set_ylabel('Magnitude (dB)', fontsize=16)
ax.set_ylim([-130, 0])
ax.set_xlim([-2, 2])
ax.legend(fontsize=14)
ax.grid(True)
plt.show()

## Plot Sender Response and Weights

In [None]:
figsize = (13, 8)
fs = CS.m  # same sampling rate the low-pass filters were designed with

# Show which low-pass variant produced this response on the figure itself,
# rather than only printing it.
fig, ax = plt.subplots(figsize=figsize)
ax.set_title(f'Full Sender Response (Time, {lowpass} LPF)', fontsize=20)
ax.plot(total_sender_response)
ax.set_xlabel('Sample', fontsize=16)
ax.set_ylabel('Amplitude', fontsize=16)
plt.show()

fig, ax = plt.subplots(figsize=figsize)
ax.set_title(f'Full Sender Response (Frequency, {lowpass} LPF)', fontsize=20)
ax.magnitude_spectrum(total_sender_response, Fs=fs, scale='dB', sides='twosided', color='C1')
ax.set_xlabel('Frequency', fontsize=16)
ax.set_ylabel('Magnitude (dB)', fontsize=16)
ax.set_ylim([-130, 0])
plt.show()

# Raw learned taps of the first Conv1d in each network.
for taps, title in ((learned_tx_filter, 'Learned Sender Filter'),
                    (learned_rx_filter, 'Learned Receiver Filter')):
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title(title)
    ax.plot(taps)
    ax.set_xlabel('Tap index')
    ax.set_ylabel('Weight')
    plt.show()