In [1]:
%load_ext autoreload
%autoreload 2

import os
os.chdir(globals()['_dh'][0])
os.chdir('..')
print(os.path.abspath(os.curdir))

/home/gridsan/glcf411/ofdm_neural_architecture


In [2]:
from tqdm import tqdm
import numpy as np
import random
import matplotlib.pyplot as plt

from pytorch_lightning import Trainer, seed_everything

In [3]:
import torch
from torch import optim
from asteroid.engine import System
from torch.utils.data import TensorDataset, DataLoader

# Seed every RNG before the tone indices are shuffled so the tone assignment is reproducible.
np.random.seed(42)
random.seed(42)
seed_everything(42, workers=True)

n_test = 1000         # number of test examples to generate
n_sc = 28             # total number of tones; split in half between sigtype 1 and sigtype 2
nfft = 64             # a tone with index k has frequency osfactor * k / nfft cycles per sample
osfactor = 1
sig_len = 4096 + 160  # number of samples generated per tone

# Tone indices 1..n_sc in random order: the first half feed bank 1, the second half bank 2.
cos_idx = np.arange(n_sc) + 1
np.random.shuffle(cos_idx)


def _tone_bank(freq_idx):
    """Return exp(j*2*pi*osfactor*k/nfft*t) with one row per index k and t = 0..sig_len-1."""
    # Operation order matches the original expression exactly (bit-identical results).
    return np.exp(1j*2*np.pi*osfactor*freq_idx.reshape(-1,1)/nfft*(np.arange(sig_len).reshape(1,-1)))


half = n_sc // 2
# "C" banks use the negated indices, i.e. the conjugate (negative-frequency) tones.
cos_waves1, cos_waves1C = _tone_bank(cos_idx[:half]), _tone_bank(-cos_idx[:half])
cos_waves2, cos_waves2C = _tone_bank(cos_idx[half:]), _tone_bank(-cos_idx[half:])

def generate_sig(coeff=1, sigtype=1):
    """Draw one random multi-tone signal from a precomputed tone bank.

    Every tone in the chosen bank gets a Gaussian amplitude (from ``np.random.randn``)
    scaled by ``coeff``. The same amplitude multiplies both the positive-frequency tone
    and its negative-frequency counterpart, so for a real ``coeff`` the summed signal is
    real up to rounding. All components are divided by sqrt(n_sc).

    Args:
        coeff: multiplier applied to the random amplitudes.
        sigtype: 1 selects bank 1 (``cos_waves1``/``cos_waves1C``); any other value
            selects bank 2 (``cos_waves2``/``cos_waves2C``).

    Returns:
        Tuple ``(sig, sig_comp, syms)``:
        - ``sig``: the sum of all components over axis 0;
        - ``sig_comp``: positive-frequency rows stacked above negative-frequency rows;
        - ``syms``: the ``(n_tones, 1)`` column of random amplitudes.
    """
    if sigtype == 1:
        pos_bank, neg_bank = cos_waves1, cos_waves1C
    else:
        pos_bank, neg_bank = cos_waves2, cos_waves2C

    syms = coeff * np.random.randn(pos_bank.shape[0], 1)
    sig_comp = np.vstack((syms * pos_bank, syms * neg_bank)) / np.sqrt(n_sc)
    return sig_comp.sum(axis=0), sig_comp, syms

def add_interference(sig, coeff=4, sigtype=2):
    """Return ``sig`` plus a freshly drawn interfering multi-tone signal.

    The interferer is produced by ``generate_sig``, so each call consumes fresh draws
    from ``np.random``.

    Args:
        sig: the clean signal; it must broadcast against the generated interferer.
        coeff: amplitude multiplier of the interferer. The default of 4 reproduces the
            original hard-coded value; lower it to sweep the interference level.
        sigtype: tone bank used for the interferer. The default of 2 keeps it on the
            tones not used by the default clean signal (``sigtype=1``).

    Returns:
        ``sig + interference``.
    """
    interference, _, _ = generate_sig(coeff=coeff, sigtype=sigtype)
    return sig + interference

# Re-seed with a separate seed so this test-set draw is reproducible on its own.
seed_everything(420123, workers=True)
np.random.seed(420123)
random.seed(420123)

window_len = sig_len - 160  # samples kept from each generated signal
idx = 0  # start of the kept window (the commented-out alternative was np.random.randint(64))

all_sig, all_sig_noisy = [], []
for _ in tqdm(range(n_test)):
    # Clean signal first, then its interferer: this fixes the order of RNG draws.
    sig, sig_comp, syms = generate_sig()
    sig_noisy = add_interference(sig)
    all_sig.append(sig[idx:idx + window_len])
    all_sig_noisy.append(sig_noisy[idx:idx + window_len])

# Shape (n_test, 1, window_len); keep only the real part of the complex-valued sums.
all_sig = np.array(all_sig).reshape(-1, 1, window_len).real
all_sig_noisy = np.array(all_sig_noisy).reshape(-1, 1, window_len).real

tensor_x = torch.Tensor(all_sig_noisy)  # network input: clean signal + interference
tensor_y = torch.Tensor(all_sig)        # target: clean signal

test_dataset = TensorDataset(tensor_x, tensor_y)
test_loader = DataLoader(test_dataset, batch_size=16, num_workers=40)

Global seed set to 42
Global seed set to 420123
100%|██████████| 1000/1000 [00:02<00:00, 485.23it/s]


In [4]:
import glob
from waveunet import Waveunet
from asteroid.models import SuDORMRFImprovedNet, DPTNet, DPRNNTasNet, ConvTasNet

import warnings

# Models to evaluate, in reporting order: (checkpoint folder name, zero-argument factory).
# Factories defer construction so only the model currently being evaluated is built.
model_specs = [
    ('waveunet_longksz_20filters_5depth', lambda: Waveunet(n_src=1, n_first_filter=20, depth=5)),
    ('waveunet0', lambda: Waveunet(n_src=1, long_kernel_size=15, n_first_filter=1)),
    ('sudormrf', lambda: SuDORMRFImprovedNet(n_src=1)),
    ('dptnet', lambda: DPTNet(n_src=1)),
    ('dprnntasnet', lambda: DPRNNTasNet(n_src=1)),
    ('convtasnet', lambda: ConvTasNet(n_src=1)),
    ('waveunet_longksz_20filters_12depth', lambda: Waveunet(n_src=1, n_first_filter=20, depth=12)),
]


def latest_checkpoint(folder_name):
    """Return the most recently modified ``*.ckpt`` file directly inside ``folder_name``.

    This is the newest file by mtime, not necessarily the best-scoring checkpoint.
    ``folder_name`` is expected to end with a path separator.

    Raises:
        FileNotFoundError: if the folder contains no ``.ckpt`` file.
    """
    ckpt_paths = glob.glob(folder_name + "*.ckpt")
    if not ckpt_paths:
        raise FileNotFoundError(f"No .ckpt files found in {folder_name!r}")
    # max() keeps the first of equally recent files, like a stable newest-first sort.
    return max(ckpt_paths, key=lambda path: os.stat(path).st_mtime)


all_residual = {}
for model_name, build_model in model_specs:
    model = build_model()
    filename = latest_checkpoint(f"models/case1/{model_name}/sinsep_4096/")

    # Loss and optimizer are only needed to construct System; nothing is trained here.
    loss = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)
    system = System(model, optimizer, loss, test_loader, test_loader)

    ckpt = torch.load(filename, map_location=torch.device('cpu'))
    # strict=False tolerates key mismatches. Left unchecked, a mismatch would silently
    # evaluate partially random weights, so report it.
    load_result = system.load_state_dict(ckpt['state_dict'], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"{model_name}: {len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected state_dict keys when loading {filename}"
        )

    system.eval()
    with torch.no_grad():
        sig_est = system(tensor_x)
    residual = (sig_est - tensor_y).cpu().numpy().squeeze()
    all_residual[model_name] = residual
    # Mean squared error over every example and sample, in dB. Averaging all elements
    # keeps this valid whatever rank squeeze() leaves.
    mse_db = 10 * np.log10(np.mean(np.abs(residual) ** 2))
    print("Model ", model_name, ": MSE ", f"{mse_db:.03f}", " ", filename)

Model  waveunet_longksz_20filters_5depth : MSE  -65.526   models/case1/waveunet_longksz_20filters_5depth/sinsep_4096/epoch=1999-step=11250000.ckpt
Model  waveunet0 : MSE  -57.246   models/case1/waveunet0/sinsep_4096/epoch=191-step=1080000.ckpt
Model  sudormrf : MSE  -37.023   models/case1/sudormrf/sinsep_4096/epoch=1978-step=11131875.ckpt
Model  dptnet : MSE  -36.825   models/case1/dptnet/sinsep_4096/epoch=1078-step=6069375.ckpt
Model  dprnntasnet : MSE  -41.425   models/case1/dprnntasnet/sinsep_4096/epoch=1995-step=11227500.ckpt
Model  convtasnet : MSE  -40.790   models/case1/convtasnet/sinsep_4096/epoch=1846-step=10389375.ckpt
Model  waveunet_longksz_20filters_12depth : MSE  -58.055   models/case1/waveunet_longksz_20filters_12depth/sinsep_4096/epoch=851-step=4792500.ckpt


In [6]:
import pickle

# Persist the per-model residual arrays. The path is relative to the directory chosen by the
# setup cell. Create the folder if it is missing, and close the file deterministically instead
# of leaking the handle returned by a bare open().
os.makedirs('tmp_output', exist_ok=True)
with open('tmp_output/case1_residual_outputs.pkl', 'wb') as residual_file:
    pickle.dump(all_residual, residual_file)