In [7]:
%load_ext autoreload
%autoreload 2

import os
# Make the working directory independent of where the kernel currently is:
# first jump to entry 0 of the `_dh` list found in the notebook globals
# (presumably IPython's directory history, whose first entry is the directory
# the kernel started in -- TODO confirm), then step up one level.
# The relative 'dataset' and 'outputs' paths used by later cells are resolved
# against the directory reached here.
os.chdir(globals()['_dh'][0])
os.chdir('..')
# Debug aid: uncomment to display the resulting working directory.
# print(os.path.abspath(os.curdir))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import rfcutils

In [9]:
# Select the test set and the (signal-of-interest, interference) pairing to inspect.
#testset_identifier = 'TestSet1Example'
testset_identifier = 'TestSet1Mixture'
soi_type, interference_sig_type = 'OFDMQPSK', 'CommSignal2'

# Mixtures and their metadata are loaded first so they remain available even
# if the ground-truth file below is missing.
all_sig_mixture = np.load(os.path.join('dataset', f'{testset_identifier}_testmixture_{soi_type}_{interference_sig_type}.npy'))
meta_data = np.load(os.path.join('dataset', f'{testset_identifier}_testmixture_{soi_type}_{interference_sig_type}_metadata.npy'))

# Ground truth (reference mixtures, all_sig1 and all_bits1). The original note
# on this load read "For TestSet1Example", i.e. the file is not expected to
# exist for every test set.
groundtruth_path = os.path.join('dataset', f'GroundTruth_{testset_identifier}_Dataset_{soi_type}_{interference_sig_type}.pkl')
try:
    # NOTE(review): pickle.load can execute arbitrary code; only open
    # ground-truth files obtained from a trusted source.
    # The `with` block closes the file handle even if unpickling fails.
    with open(groundtruth_path, 'rb') as groundtruth_file:
        all_sig_mixture_groundtruth, all_sig1, all_bits1 = pickle.load(groundtruth_file)
except FileNotFoundError as err:
    # Same exception type as before, but with a hint on how to proceed.
    raise FileNotFoundError(
        f"Ground-truth file {groundtruth_path!r} not found for testset_identifier="
        f"{testset_identifier!r}. The cells below need all_sig1/all_bits1: switch to "
        f"'TestSet1Example' or place the ground-truth file in the 'dataset' folder."
    ) from err

FileNotFoundError: [Errno 2] No such file or directory: 'dataset\\GroundTruth_TestSet1Mixture_Dataset_OFDMQPSK_CommSignal2.pkl'

In [None]:
# Position along axis 0 of the example that the plotting and demodulation
# cells below pull out of all_sig1 / all_sig_mixture / all_bits1.
example_idx = 0

In [None]:
# One figure per array: first the selected example from all_sig1 (green),
# then the same example from all_sig_mixture (purple). Both figures are
# zoomed to the identical x-range 4000..4640.
for waveforms, line_color in ((all_sig1, 'tab:green'), (all_sig_mixture, 'tab:purple')):
    plt.figure()
    plt.plot(waveforms[example_idx], color=line_color)
    plt.xlim([4000, 4640])
    plt.show()

In [None]:
# Module-level signal length; read by the 'OFDMQPSK' branch of
# get_soi_generation_fn when it builds the resource grid.
# (presumably a length in samples -- TODO confirm against rfcutils)
sig_len = 40960

def get_soi_generation_fn(soi_sig_type):
    """Return the (generator, demodulator) pair for a signal-of-interest type.

    Parameters
    ----------
    soi_sig_type : str
        One of 'QPSK', 'QAM16', 'QPSK2' or 'OFDMQPSK'.

    Returns
    -------
    generate_soi : callable
        ``generate_soi(n, s_len)`` forwards to the matching
        ``rfcutils.generate_*_signal(n, s_len // k)`` with ``k`` equal to
        16 (QPSK, QAM16), 4 (QPSK2) or 80 (OFDMQPSK).
        (``k`` looks like a samples-per-symbol factor -- TODO confirm.)
    demod_soi : callable
        The matching ``rfcutils`` demodulator. For 'OFDMQPSK' it is a wrapper
        around ``rfcutils.ofdm_demod`` that supplies the resource grid.

    Raises
    ------
    ValueError
        If ``soi_sig_type`` is not one of the supported names. ``ValueError``
        is a subclass of ``Exception``, so handlers written against the
        previous bare ``Exception`` keep working.
    """
    if soi_sig_type == 'QPSK':
        generate_soi = lambda n, s_len: rfcutils.generate_qpsk_signal(n, s_len//16)
        demod_soi = rfcutils.qpsk_matched_filter_demod
    elif soi_sig_type == 'QAM16':
        generate_soi = lambda n, s_len: rfcutils.generate_qam16_signal(n, s_len//16)
        demod_soi = rfcutils.qam16_matched_filter_demod
    elif soi_sig_type == 'QPSK2':
        generate_soi = lambda n, s_len: rfcutils.generate_qpsk2_signal(n, s_len//4)
        demod_soi = rfcutils.qpsk2_matched_filter_demod
    elif soi_sig_type == 'OFDMQPSK':
        generate_soi = lambda n, s_len: rfcutils.generate_ofdm_signal(n, s_len//80)
        # The demodulator needs the resource grid, which is the fourth value
        # returned by generate_ofdm_signal. One throw-away signal is generated
        # to obtain it.
        # NOTE(review): this uses the module-level `sig_len`, not the `s_len`
        # later passed to generate_soi -- keep the two consistent.
        _, _, _, RES_GRID = rfcutils.generate_ofdm_signal(1, sig_len//80)
        demod_soi = lambda s: rfcutils.ofdm_demod(s, RES_GRID)
    else:
        # Name the offending value so a typo is obvious from the traceback.
        raise ValueError(f"SOI Type not recognized: {soi_sig_type!r}")
    return generate_soi, demod_soi

# Bind the generator/demodulator pair for the current value of `soi_type`;
# `demod_soi` is what the demodulation check in the next cell calls.
generate_soi, demod_soi = get_soi_generation_fn(soi_type)

In [None]:
# Use a length-1 slice rather than a plain index so that axis 0 (the mixture
# index) is retained in the arrays handed to the demodulator.
example_slice = slice(example_idx, example_idx + 1)
bits_demod, syms_demod = demod_soi(all_sig1[example_slice])
# Cell output: whether the demodulated bits match the stored bits for this example.
np.allclose(bits_demod, all_bits1[example_slice])

### Baseline Curves with TestSet1Example

In [None]:
# SINR grid for the baseline curves (-30 to 0 in steps of 3) and the number of
# consecutive examples that run_demod_test scores per grid point.
all_sinr = np.arange(-30, 0.1, 3)
n_per_batch = 100

def run_demod_test(sig1_est, bit1_est, soi_type, interference_sig_type, testset_identifier):
    """Score SOI and bit estimates against the stored ground truth.

    The ground truth is read from
    ``dataset/GroundTruth_{testset_identifier}_Dataset_{soi_type}_{interference_sig_type}.pkl``,
    relative to the current working directory. Examples are scored in
    ``len(all_sinr)`` consecutive batches of ``n_per_batch`` along axis 0;
    both values are module-level settings.
    (Batch ``i`` is presumably the set generated at ``all_sinr[i]`` -- TODO
    confirm against the dataset layout; this function never reads the SINR
    values themselves.)

    Parameters
    ----------
    sig1_est : np.ndarray
        Estimated SOI waveforms, 2-D with the example index on axis 0.
    bit1_est : np.ndarray
        Estimated bits, 2-D with the example index on axis 0.
    soi_type, interference_sig_type, testset_identifier : str
        Used only to build the ground-truth file name.

    Returns
    -------
    mse_mean : np.ndarray
        One value per batch: ``10*log10`` of the batch-averaged squared error.
    ber_mean : np.ndarray
        One value per batch: batch-averaged bit error rate.

    Raises
    ------
    FileNotFoundError
        If the ground-truth file does not exist.
    AssertionError
        If a batch of estimates differs in shape from the ground-truth batch.
    """
    groundtruth_path = os.path.join('dataset', f'GroundTruth_{testset_identifier}_Dataset_{soi_type}_{interference_sig_type}.pkl')
    # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
    # The `with` block closes the file handle; the first stored item (the
    # ground-truth mixtures) is not needed for scoring.
    with open(groundtruth_path, 'rb') as groundtruth_file:
        _, all_sig1, all_bits1 = pickle.load(groundtruth_file)

    def eval_mse(all_sig_est, all_sig_soi):
        """Mean squared error per example (averaged over axis 1)."""
        assert all_sig_est.shape == all_sig_soi.shape, 'Invalid SOI estimate shape'
        return np.mean(np.abs(all_sig_est - all_sig_soi)**2, axis=1)

    def eval_ber(bit_est, bit_true):
        """Fraction of mismatching bits per example."""
        # Validate before comparing: with mismatched shapes the comparison
        # itself may fail or broadcast, hiding the real problem.
        assert bit_est.shape == bit_true.shape, 'Invalid bit estimate shape'
        return np.sum((bit_est != bit_true).astype(np.float32), axis=1) / bit_true.shape[1]

    all_mse, all_ber = [], []
    for idx in range(len(all_sinr)):
        batch = slice(idx*n_per_batch, (idx+1)*n_per_batch)
        all_mse.append(eval_mse(sig1_est[batch], all_sig1[batch]))
        all_ber.append(eval_ber(bit1_est[batch], all_bits1[batch]))

    all_mse, all_ber = np.array(all_mse), np.array(all_ber)
    # Average within each batch, then express the MSE in dB.
    mse_mean = 10*np.log10(np.mean(all_mse, axis=-1))
    ber_mean = np.mean(all_ber, axis=-1)
    return mse_mean, ber_mean

In [None]:
testset_identifier = 'TestSet1Example'
for soi_type in ['QPSK', 'OFDMQPSK']:
    for interference_sig_type in ['EMISignal1', 'CommSignal2', 'CommSignal3', 'CommSignal5G1']:

        # Score the saved outputs of every baseline for this (SOI, interference) pair.
        all_mse, all_ber = {}, {}
        for id_string in ['Default_NoMitigation', 'Default_LMMSE', 'Default_TF_UNet', 'Default_Torch_WaveNet']:
            soi_est_path = os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_soi_{soi_type}_{interference_sig_type}.npy')
            bits_est_path = os.path.join('outputs', f'{id_string}_{testset_identifier}_estimated_bits_{soi_type}_{interference_sig_type}.npy')
            sig1_est = np.load(soi_est_path)
            bit1_est = np.load(bits_est_path)
            mse_mean, ber_mean = run_demod_test(sig1_est, bit1_est, soi_type, interference_sig_type, testset_identifier)
            all_mse[id_string], all_ber[id_string] = mse_mean, ber_mean

        # MSE curves: linear axis, upper y-limit capped at 3.
        plt.figure()
        for method_name, mse_curve in all_mse.items():
            plt.plot(all_sinr, mse_curve, 'x--', label=method_name)
        plt.title(f'MSE - {soi_type} + {interference_sig_type}')
        plt.xlabel('SINR [dB]')
        plt.ylabel('MSE [dB]')
        plt.gca().set_ylim(top=3)
        plt.grid()
        plt.legend()
        plt.show()

        # BER curves: logarithmic y-axis limited to [1e-4, 1].
        plt.figure()
        for method_name, ber_curve in all_ber.items():
            plt.semilogy(all_sinr, ber_curve, 'x--', label=method_name)
        plt.title(f'BER - {soi_type} + {interference_sig_type}')
        plt.xlabel('SINR [dB]')
        plt.ylabel('BER')
        plt.ylim([1e-4, 1])
        plt.grid()
        plt.legend()
        plt.show()