In [1]:
%load_ext autoreload
%autoreload 2

import os
# Make relative paths (models/..., tmp_output/...) resolve against the parent
# of the kernel's start-up directory. IPython stores its directory history in
# _dh; _dh[0] is the directory the kernel started in (presumably the notebook's
# own folder).
os.chdir(globals()['_dh'][0])
os.chdir(os.pardir)
print(os.path.abspath(os.curdir))

/home/gridsan/glcf411/ofdm_neural_architecture


In [2]:
from tqdm import tqdm
import numpy as np
import random
import matplotlib.pyplot as plt

from pytorch_lightning import Trainer, seed_everything

In [3]:
# Signal configuration shared by the cells below.
nfft = 64              # number of frequency bins in the synthesis grid
sig_len = 4096 + 160   # generated length; the test windows below keep only the first 4096 samples
osfactor = 1           # multiplies every bin frequency (1 keeps the bin spacing at 1/nfft)

# Row k holds exp(1j*2*pi*osfactor*k*t/nfft) for t = 0..sig_len-1, shape (nfft, sig_len).
# These are complex exponentials, despite the name.
cos_waves = np.exp(
    1j * 2 * np.pi * osfactor * np.arange(nfft)[:, np.newaxis] / nfft
    * np.arange(sig_len)[np.newaxis, :]
)

n_sc = 28              # used by the signal generator as the number of active bins (1..n_sc)

In [4]:
import torch
from torch import optim
from asteroid.engine import System
from torch.utils.data import TensorDataset, DataLoader

# Number of synthetic test examples to generate.
n_test = 1000

# nfft, sig_len, osfactor, cos_waves and n_sc come from the signal-configuration
# cell above. They are deliberately not re-declared here: a second copy would
# silently override any edit made in that cell.
def generate_sig(coeff=1):
    """Draw one random multicarrier signal with a mirrored (real-valued) spectrum.

    Every one of the ``nfft`` bins receives a 4-PAM symbol from
    {-3, -1, 1, 3}/sqrt(5), scaled by ``coeff``. Bin 0 and all bins above
    ``n_sc`` are then cleared. Bins nfft/2+1..nfft-1 are overwritten with the
    mirror image of bins 1..nfft/2-1, so bin nfft-k equals bin k and the
    synthesized waveform is real up to rounding.

    All ``nfft`` symbols are drawn even though only bins 1..n_sc survive, so each
    call advances the global NumPy random stream by ``nfft`` draws.

    Args:
        coeff: amplitude multiplier applied to every symbol.

    Returns:
        (sig, sig_comp, syms):
        - sig: the summed waveform, one value per column of ``cos_waves`` (complex dtype).
        - sig_comp: the per-bin components, ``syms`` broadcast against ``cos_waves``.
        - syms: the (nfft, 1) symbol column.
    """
    pam = 2*np.random.randint(4, size=(nfft, 1)) - 3   # integers in {-3, -1, 1, 3}
    syms = coeff*pam/np.sqrt(5)                        # unit average power when coeff == 1

    syms[0, :] = 0              # no DC bin
    syms[n_sc + 1:, :] = 0      # only bins 1..n_sc carry data
    half = nfft // 2
    syms[half + 1:, :] = syms[half - 1:0:-1, :]   # bin nfft-k := bin k

    sig_comp = syms/np.sqrt(2*n_sc)*cos_waves
    return sig_comp.sum(axis=0), sig_comp, syms

def add_noise(sig, noise_pow=0.01):
    """Return ``sig`` plus zero-mean white Gaussian noise of variance ``noise_pow``.

    The noise is real-valued (drawn with ``np.random.randn``), so for a complex
    ``sig`` only the real part is perturbed. One noise sample is drawn per
    element along the first axis (``len(sig)``).
    """
    std = np.sqrt(noise_pow)
    return sig + std*np.random.randn(len(sig))

def add_interference(sig):
    """Return ``sig`` plus an independent random signal from ``generate_sig(coeff=4)``.

    Only the summed waveform of the interferer is used; its per-bin components
    and symbols are discarded.
    """
    interferer = generate_sig(coeff=4)[0]
    return sig + interferer

def reconstruct_sig(syms):
    """Synthesize a waveform from a symbol column using ``cos_waves``.

    Each bin is scaled by sqrt(2)/sqrt(n_sc) before summing over bins.

    NOTE(review): this gain is twice the 1/sqrt(2*n_sc) used by ``generate_sig``.
    That matches ``generate_sig`` only if ``syms`` holds just the positive-frequency
    half and the caller takes the real part of the result. Confirm against the callers.

    Returns:
        (sig, sig_comp): the summed waveform and the per-bin components.
    """
    sig_comp = syms * np.sqrt(2) / np.sqrt(n_sc) * cos_waves
    return np.sum(sig_comp, axis=0), sig_comp

# Build the fixed test set. Each example pairs a clean signal with the same
# signal plus a 4x-amplitude interferer. Only the first window_len samples
# are kept, and only their real part.
window_len = sig_len - 160

# Seed every RNG source so the test set is reproducible.
seed_everything(420123, workers=True)
np.random.seed(420123)
random.seed(420123)

all_sig = []
all_sig_noisy = []
for _ in tqdm(range(n_test)):
    sig, sig_comp, syms = generate_sig()
    idx = 0  # start offset of the kept window (fixed)
    sig_noisy = add_interference(sig)
    all_sig.append(sig[idx:idx + window_len])
    all_sig_noisy.append(sig_noisy[idx:idx + window_len])

# Shape (n_test, 1, window_len): one input channel per example.
all_sig = np.array(all_sig).reshape(-1, 1, window_len).real
all_sig_noisy = np.array(all_sig_noisy).reshape(-1, 1, window_len).real

# Model input is the corrupted mixture; the target is the clean signal.
tensor_x = torch.Tensor(all_sig_noisy)
tensor_y = torch.Tensor(all_sig)

test_dataset = TensorDataset(tensor_x, tensor_y)
test_loader = DataLoader(test_dataset, batch_size=16, num_workers=40)

Global seed set to 420123
100%|██████████| 1000/1000 [00:01<00:00, 620.14it/s]


In [5]:
import glob
from waveunet import Waveunet
from asteroid.models import SuDORMRFImprovedNet, DPTNet, DPRNNTasNet, ConvTasNet

# Evaluate each trained architecture on the same test set and report its
# residual power in dB: 10*log10 of the mean squared error over all examples
# and samples.
import warnings

# (model_name, factory) pairs. model_name is also the checkpoint sub-folder.
# Factories are called lazily, so only one model is built per iteration.
model_specs = [
    ('waveunet_longksz_20filters_5depth', lambda: Waveunet(n_src=1, n_first_filter=20, depth=5)),
    ('waveunet0', lambda: Waveunet(n_src=1, long_kernel_size=15, n_first_filter=1)),
    ('sudormrf', lambda: SuDORMRFImprovedNet(n_src=1)),
    ('dptnet', lambda: DPTNet(n_src=1)),
    ('dprnntasnet', lambda: DPRNNTasNet(n_src=1)),
    ('convtasnet', lambda: ConvTasNet(n_src=1)),
    ('waveunet_longksz_20filters_12depth', lambda: Waveunet(n_src=1, n_first_filter=20, depth=12)),
]

all_residual = {}
for model_name, build_model in model_specs:
    model = build_model()

    # Use the most recently modified .ckpt in the run folder. Filter before
    # stat-ing so non-checkpoint files are ignored.
    folder_name = f"models/case4/{model_name}/sinsep_4096/"
    ckpt_files = [fname for fname in glob.glob(folder_name + "*") if fname.endswith('.ckpt')]
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt files found in {folder_name}")
    filename = max(ckpt_files, key=os.path.getmtime)

    loss = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)
    system = System(model, optimizer, loss, test_loader, test_loader)

    # torch.load unpickles the file, so only point it at trusted checkpoints.
    ckpt = torch.load(filename, map_location=torch.device('cpu'))
    # With strict=False, any parameter the checkpoint lacks silently keeps its
    # random initialisation. Report mismatches instead of hiding them.
    load_result = system.load_state_dict(ckpt['state_dict'], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"{model_name}: checkpoint {filename} does not match the model "
            f"(missing keys: {load_result.missing_keys}, "
            f"unexpected keys: {load_result.unexpected_keys})"
        )

    system.eval()
    with torch.no_grad():
        sig_est = system(tensor_x)
    residual = (sig_est - tensor_y).cpu().numpy().squeeze()
    all_residual[model_name] = residual
    mse_db = 10*np.log10(np.mean(np.abs(residual)**2))
    print("Model ", model_name, ": MSE ", f"{mse_db:.03f}", " ", filename)

Model  waveunet_longksz_20filters_5depth : MSE  -41.156   models/case4/waveunet_longksz_20filters_5depth/sinsep_4096/epoch=1911-step=10755000.ckpt
Model  waveunet0 : MSE  -4.665   models/case4/waveunet0/sinsep_4096/epoch=290-step=1636875.ckpt
Model  sudormrf : MSE  -11.495   models/case4/sudormrf/sinsep_4096/epoch=313-step=1766250.ckpt
Model  dptnet : MSE  -2.432   models/case4/dptnet/sinsep_4096/epoch=60-step=343125.ckpt
Model  dprnntasnet : MSE  -0.542   models/case4/dprnntasnet/sinsep_4096/epoch=18-step=106875.ckpt
Model  convtasnet : MSE  -0.913   models/case4/convtasnet/sinsep_4096/epoch=16-step=95625.ckpt
Model  waveunet_longksz_20filters_12depth : MSE  -46.023   models/case4/waveunet_longksz_20filters_12depth/sinsep_4096/epoch=1999-step=11250000.ckpt


In [6]:
import pickle
# Persist the per-model residual arrays (keyed by model name) for later analysis.
# The with-block closes and flushes the file deterministically.
os.makedirs('tmp_output', exist_ok=True)
with open('tmp_output/case4_residual_outputs.pkl', 'wb') as fh:
    pickle.dump(all_residual, fh)

In [7]:
from waveunet import Waveunet
# Kernel-size sweep: evaluate the 20-filter, depth-5 Wave-U-Net trained with
# each long-kernel size on the same test set, and report residual power in dB.
# all_residual is re-bound here, now keyed by kernel size, because the next
# cell pickles it under that name.
import warnings

all_residual = {}
all_k_sz = [15, 21, 31, 51, 63, 65, 71, 81, 91, 101, 151, 201]
for k_sz in all_k_sz:
    model = Waveunet(n_src=1, long_kernel_size=k_sz, n_first_filter=20, depth=5)
    folder_name = f"models/case4_ksizes/waveunet_20filters_5depth_ksz{k_sz}/sinsep_4096/"

    # os.listdir returns entries in arbitrary order, so taking the first .ckpt
    # was non-deterministic when a folder held several checkpoints. Pick the
    # most recently modified one explicitly.
    ckpt_paths = [folder_name + fname for fname in os.listdir(folder_name) if fname.endswith('.ckpt')]
    if not ckpt_paths:
        raise FileNotFoundError(f"No .ckpt files found in {folder_name}")
    path_name = max(ckpt_paths, key=os.path.getmtime)

    loss = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)
    system = System(model, optimizer, loss, test_loader, test_loader)

    # torch.load unpickles the file, so only point it at trusted checkpoints.
    ckpt = torch.load(path_name, map_location=torch.device('cpu'))
    # With strict=False, any parameter the checkpoint lacks silently keeps its
    # random initialisation. Report mismatches instead of hiding them.
    load_result = system.load_state_dict(ckpt['state_dict'], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"KSz {k_sz}: checkpoint {path_name} does not match the model "
            f"(missing keys: {load_result.missing_keys}, "
            f"unexpected keys: {load_result.unexpected_keys})"
        )

    system.eval()
    with torch.no_grad():
        sig_est = system(tensor_x)
    residual = (sig_est - tensor_y).cpu().numpy().squeeze()
    all_residual[k_sz] = residual
    mse_db = 10*np.log10(np.mean(np.abs(residual)**2))
    print("KSz ", k_sz, ": MSE ", f"{mse_db:.03f}", " ", path_name)

KSz  15 : MSE  -6.030   models/case4_ksizes/waveunet_20filters_5depth_ksz15/sinsep_4096/epoch=344-step=1940625.ckpt
KSz  21 : MSE  -5.621   models/case4_ksizes/waveunet_20filters_5depth_ksz21/sinsep_4096/epoch=277-step=1563750.ckpt
KSz  31 : MSE  -6.183   models/case4_ksizes/waveunet_20filters_5depth_ksz31/sinsep_4096/epoch=493-step=2778750.ckpt
KSz  51 : MSE  -16.319   models/case4_ksizes/waveunet_20filters_5depth_ksz51/sinsep_4096/epoch=1997-step=11238750.ckpt
KSz  63 : MSE  -41.380   models/case4_ksizes/waveunet_20filters_5depth_ksz63/sinsep_4096/epoch=1999-step=11250000.ckpt
KSz  65 : MSE  -42.824   models/case4_ksizes/waveunet_20filters_5depth_ksz65/sinsep_4096/epoch=1997-step=11238750.ckpt
KSz  71 : MSE  -42.099   models/case4_ksizes/waveunet_20filters_5depth_ksz71/sinsep_4096/epoch=1992-step=11210625.ckpt
KSz  81 : MSE  -42.690   models/case4_ksizes/waveunet_20filters_5depth_ksz81/sinsep_4096/epoch=1996-step=11233125.ckpt
KSz  91 : MSE  -44.727   models/case4_ksizes/waveunet_20f

In [8]:
import pickle
# Persist the kernel-size sweep residuals (keyed by kernel size) for later analysis.
# The with-block closes and flushes the file deterministically.
os.makedirs('tmp_output', exist_ok=True)
with open('tmp_output/case4_ksz_residual_outputs.pkl', 'wb') as fh:
    pickle.dump(all_residual, fh)