In [None]:
import sys
sys.path.append('..')
from Comms_System import Comms_System, SNR_plot
import numpy as np
import matplotlib.pyplot as plt
from ML_components import load_params
from Network import NeuralNetwork
import scipy
import torch
from filters import butter_lowpass

In [None]:
# Build the communication system whose parameters (e.g. CS.m, the samples per
# symbol) are used by the filter-inspection cells below.
# The random symbol sequence is drawn from a seeded Generator so that
# Restart & Run All reproduces the same sequence every time.
SEED = 0
rng = np.random.default_rng(SEED)
symbol_set = [3, 1, -1, -3] # all symbols that we use
num_symbols = 10000
symbol_seq = rng.choice(symbol_set, num_symbols, replace=True)
m = 8  # samples per symbol (passed to Comms_System as num_samples)
# NOTE(review): beta is presumably the pulse-shaping roll-off factor — confirm in Comms_System
CS = Comms_System(symbol_set=symbol_set, symbol_seq=symbol_seq, num_samples=m, beta=0.35)


In [None]:
def ideal_lpf(Tx, cutoff_freq, fs=8):
    """Apply an ideal (brick-wall) low-pass filter along the last axis of ``Tx``.

    The signal is transformed with a real FFT. Every frequency bin strictly
    above ``cutoff_freq`` is zeroed, and the result is transformed back to the
    time domain with the original length.

    Args:
        Tx: real floating-point tensor. Filtering is applied along its last
            dimension; any leading dimensions are treated as a batch.
        cutoff_freq: cutoff frequency, in the same units as ``fs``. A bin at
            exactly ``cutoff_freq`` is kept.
        fs: sampling rate of ``Tx``. The default of 8 matches the value
            previously hard-coded here.

    Returns:
        Real tensor with the same shape as ``Tx``.
    """
    n = Tx.shape[-1]
    Tx_freq = torch.fft.rfft(Tx, dim=-1)
    # The frequency grid must match the transformed axis and live on the same
    # device as the data, so that boolean masking works on CPU and GPU alike.
    xf = torch.fft.rfftfreq(n, d=1 / fs, device=Tx.device)
    Tx_freq[..., xf > cutoff_freq] = 0
    return torch.fft.irfft(Tx_freq, n=n, dim=-1)

In [None]:
# Cutoff frequencies of the ideal LPF channel, one per jointly trained model pair.
# `cutoffs` holds the suffix used in the saved model file names ('06' -> 0.6), and
# `cutoff_ints` holds the matching numeric cutoff. The two lists are zipped below.
cutoffs = ['2', '1', '06', '05']
cutoff_ints = [2, 1, 0.6, 0.5]
# zip() silently truncates, so a length mismatch would quietly skip models.
assert len(cutoffs) == len(cutoff_ints), 'cutoffs and cutoff_ints must pair up one-to-one'
figsize = (6.4, 4.8)  # figure size in inches (matplotlib's default)

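# Inspect Learned Filters from Joint Training

For each LPF cutoff frequency, load the jointly trained sender and receiver models and plot their learned filters in the time and frequency domains.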

### Sender Weights

In [None]:
# Plot each jointly trained sender filter (unit-energy normalised) in time and frequency.
path = '../Joint_Models/'
for cutoff, c_int in zip(cutoffs, cutoff_ints):
    print('Cutoff Frequency =', c_int)
    # Loads a fully pickled module, so only load trusted files. map_location keeps
    # this working on CPU-only machines even if the model was saved from a GPU.
    # NOTE(review): recent torch versions default to weights_only=True, which may
    # refuse a pickled nn.Module; pass weights_only=False there if loading fails.
    net = torch.load(path + 'SenderIdeal' + cutoff, map_location='cpu')
    # Assumes the first parameter is a conv weight shaped (out_ch, in_ch, taps) — TODO confirm
    learned_filter = list(net.parameters())[0].detach()[0][0]
    learned_filter = learned_filter / torch.sqrt(torch.sum(torch.square(learned_filter))) # normalize to unit energy

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Sender Filter (Time)', fontsize=20)
    ax.plot(learned_filter)
    ax.set_xlabel('Sample', fontsize=16)
    ax.set_ylabel('Amplitude', fontsize=16)
    plt.show()

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Sender Filter (Frequency)', fontsize=20)
    ax.magnitude_spectrum(learned_filter, Fs=CS.m, color='C1', sides='twosided', scale='dB')
    ax.set_xlabel('Frequency', fontsize=16)
    ax.set_ylabel('Magnitude (dB)', fontsize=16)
    ax.set_ylim(-95, -20)
    plt.show()
    print('_________________________________________________________')
    print()

### Receiver Weights

In [None]:
# Plot each jointly trained receiver filter (unit-energy normalised) in time and frequency.
path = '../Joint_Models/'
for cutoff, c_int in zip(cutoffs, cutoff_ints):
    print('Cutoff Frequency =', c_int)
    # Loads a fully pickled module, so only load trusted files. map_location keeps
    # this working on CPU-only machines even if the model was saved from a GPU.
    # NOTE(review): recent torch versions default to weights_only=True, which may
    # refuse a pickled nn.Module; pass weights_only=False there if loading fails.
    net = torch.load(path + 'ReceiverIdeal' + cutoff, map_location='cpu')
    # Assumes the first parameter is a conv weight shaped (out_ch, in_ch, taps) — TODO confirm
    learned_filter = list(net.parameters())[0].detach()[0][0]
    learned_filter = learned_filter / torch.sqrt(torch.sum(torch.square(learned_filter))) # normalize to unit energy

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Receiver Filter (Time)', fontsize=20)
    ax.plot(learned_filter)
    ax.set_xlabel('Sample', fontsize=16)
    ax.set_ylabel('Amplitude', fontsize=16)
    plt.show()

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Receiver Filter (Frequency)', fontsize=20)
    ax.magnitude_spectrum(learned_filter, Fs=CS.m, color='C1', sides='twosided', scale='dB')
    ax.set_xlabel('Frequency', fontsize=16)
    ax.set_ylabel('Magnitude (dB)', fontsize=16)
    ax.set_ylim(-120, 0)
    plt.show()
    print('_________________________________________________________')
    print()

### Total Impulse Response of the Sender (i.e., the sender filter convolved with the ideal LPF)

In [None]:
# Plot the sender filter after the ideal LPF channel (unit-energy normalised).
path = '../Joint_Models/'
for cutoff, c_int in zip(cutoffs, cutoff_ints):
    print('Cutoff Frequency =', c_int)
    # Full pickled module: trusted files only. map_location allows CPU-only loading.
    # NOTE(review): recent torch versions default to weights_only=True — pass
    # weights_only=False if loading the pickled module fails.
    net = torch.load(path + 'SenderIdeal' + cutoff, map_location='cpu')
    # Assumes the first parameter is a conv weight shaped (out_ch, in_ch, taps) — TODO confirm
    learned_filter = list(net.parameters())[0].detach()[0][0]
    # ideal_lpf is linear and its output is renormalised below, so the raw filter
    # does not need normalising first.
    # NOTE(review): zeroing FFT bins of a finite-length filter is a circular
    # operation; the result has the filter's length rather than a full linear convolution.
    total_sender_response = ideal_lpf(learned_filter, c_int)
    total_sender_response = total_sender_response / torch.sqrt(torch.sum(torch.square(total_sender_response))) # normalize

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Full Sender Response (Time)', fontsize=20)
    ax.plot(total_sender_response)
    ax.set_xlabel('Sample', fontsize=16)
    ax.set_ylabel('Amplitude', fontsize=16)
    plt.show()

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Full Sender Response (Frequency)', fontsize=20)
    # ideal_lpf filters assuming a sample rate of 8, which equals CS.m here (m = 8).
    ax.magnitude_spectrum(total_sender_response, Fs=CS.m, sides='twosided', scale='dB', color='C1')
    ax.set_xlabel('Frequency', fontsize=16)
    ax.set_ylabel('Magnitude (dB)', fontsize=16)
    ax.set_ylim(-130, 0)
    plt.show()
    print('_________________________________________________________')
    print()

### Total Response of the Whole System (sender filter → ideal LPF → receiver filter)

In [None]:
# Plot the cascade sender filter -> ideal LPF channel -> receiver filter (unit-energy normalised).
path = '../Joint_Models/'
for cutoff, c_int in zip(cutoffs, cutoff_ints):
    print('Cutoff Frequency =', c_int)
    # Full pickled modules: trusted files only. map_location allows CPU-only loading.
    # NOTE(review): recent torch versions default to weights_only=True — pass
    # weights_only=False if loading the pickled modules fails.
    tx_net = torch.load(path + 'SenderIdeal' + cutoff, map_location='cpu')
    rx_net = torch.load(path + 'ReceiverIdeal' + cutoff, map_location='cpu')

    # Assumes the first parameter of each net is a conv weight shaped (out_ch, in_ch, taps) — TODO confirm
    learned_tx_filter = list(tx_net.parameters())[0].detach()[0][0]
    learned_rx_filter = list(rx_net.parameters())[0].detach()[0][0].numpy()

    # Sender filter passed through the ideal LPF channel, normalised to unit energy.
    total_sender_response = ideal_lpf(learned_tx_filter, c_int)
    total_sender_response = total_sender_response / torch.sqrt(torch.sum(torch.square(total_sender_response))) # normalize
    total_sender_response = total_sender_response.numpy()

    # Cascade with the receiver filter (full linear convolution), then renormalise.
    full_response = np.convolve(total_sender_response, learned_rx_filter)
    full_response = full_response / np.sqrt(np.sum(np.square(full_response))) # normalize

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Full Response of System (Time)', fontsize=16)
    ax.plot(full_response)
    ax.set_xlabel('Sample', fontsize=16)
    ax.set_ylabel('Amplitude', fontsize=16)
    plt.show()

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title('Full Response of System (Frequency)', fontsize=16)
    # ideal_lpf filters assuming a sample rate of 8, which equals CS.m here (m = 8).
    ax.magnitude_spectrum(full_response, Fs=CS.m, scale='dB', sides='twosided', color='C1')
    ax.set_xlabel('Frequency', fontsize=16)
    ax.set_ylabel('Magnitude (dB)', fontsize=16)
    ax.set_ylim(-130, 0)
    plt.show()
    print('_________________________________________________________')
    print()