# Load model and checkpoint

In [None]:
import torch
import pickle
from EaBNet import make_eabnet_with_postnet
%reload_ext autoreload
%autoreload 2

In [None]:
# Training-time arguments; make_eabnet_with_postnet(args) rebuilds the network from them.
args_path = 'data/experiments/eabnet/train_bs8/args.pickle'
# NOTE(review): pickle.load can execute arbitrary code from the file -- only open trusted artifacts.
with open(args_path, 'rb') as args_file:
    args = pickle.load(args_file)

# Checkpoint weights, mapped onto the CPU so loading does not require a GPU.
ckp_path = 'data/experiments/eabnet/train_2nd_stage_with_postnet/checkpoints/351864.pth'
checkpoint = torch.load(ckp_path, map_location='cpu')
model_ckp = checkpoint['model_state_dict']

model = make_eabnet_with_postnet(args)
# strict=True: any missing or unexpected parameter key raises instead of being ignored.
model.load_state_dict(model_ckp, strict=True)

# Make a noisy dataset sample

In [None]:
from dataset.mcse_dataset import generate_random_noisy_for_speech, load_audio_and_random_crop
import json
import os
from scipy.io import wavfile
import scipy.signal as signal
import matplotlib.pyplot as plt
import IPython

In [None]:
# List every recording in the full-band noise corpus, in sorted (stable) order.
noise_root = 'data/datasets/datasets_fullband/noise_fullband'
noise_records = sorted(os.listdir(noise_root))
print(len(noise_records))

## Resample to 16 kHz

In [None]:
# Load the demo utterance and, if needed, resample it to 16 kHz (rewriting the file in place).
audio_name = '00032'
audio_dir = './demo'
audio_path = os.path.join(audio_dir, audio_name + '.wav')
fs, audio = wavfile.read(audio_path)
print(fs, audio.shape)
# Only single-channel recordings are supported here.
if audio.ndim != 1:
    raise ValueError(f'expected mono audio, got shape {audio.shape}')

resample_fs = 16000
if fs != resample_fs:
    import numpy as np

    resampled = signal.resample(audio, int(resample_fs * len(audio) / fs))
    if np.issubdtype(audio.dtype, np.integer):
        # FFT-based resampling can overshoot full scale (Gibbs ringing); round and clip
        # before casting back so near-full-scale samples do not wrap around.
        info = np.iinfo(audio.dtype)
        resampled = np.clip(np.round(resampled), info.min, info.max)
    audio = resampled.astype(audio.dtype)
    # NOTE: overwrites the source file; presumably the sample generator later re-reads
    # it from disk via speech_root / target_speech. Re-running this cell is a no-op
    # once the file is already at resample_fs.
    wavfile.write(audio_path, resample_fs, audio)
    fs, audio = wavfile.read(audio_path)
    print(fs, audio.shape)

In [None]:
# Called below as:
# generate_random_noisy_for_speech(opt, clip_seconds, target_speech, all_noises,
#                                  speech_root, noise_root, speech_start_sec=..., specific=...)
with open('dataset/mcse_dataset_settings.json', 'r', encoding='utf-8') as f:
    opt = json.load(f)

target_speech = audio_name + '.wav'
speech_root = audio_dir
noise_root = 'demo/noise'
# os.listdir order is arbitrary; sort so the noise pool order (and presumably any
# seeded random choice made from it) is reproducible across machines.
all_noises = sorted(os.listdir(noise_root))
clip_seconds = 2

# Fixed overrides for otherwise-randomized mixing parameters.
# NOTE(review): 'noise_snr_list' presumably holds one SNR (dB) per noise source -- confirm.
specific = {
    'noisy_dBFS': -30,
    'noise_snr_list': [-1.5, -1.5]
}

sample = generate_random_noisy_for_speech(opt, clip_seconds, target_speech, all_noises, speech_root, noise_root,
                                          speech_start_sec=0, specific=specific)
meta = sample['meta']
room = sample['room']
freefield = sample['freefield']
noisy = sample['noisy']
clean = sample['clean']

# Plot the room geometry from directly above (elev=90) so x/y positions are readable.
fig, ax = room.plot(img_order=0)
print(ax)
ax.view_init(elev=90, azim=-90)
ax.set_xlabel('x')
ax.set_ylabel('y')
print(meta)
plt.show()

def show_audio(audio, fs, name):
    """Print a waveform's shape, plot it, and embed an inline audio player.

    Args:
        audio: Waveform array to plot and play (must expose ``.shape``).
        fs: Sample rate in Hz given to the audio player.
        name: Label used in the printed shape line.
    """
    shape_line = f'{name} shape={audio.shape}'
    print(shape_line)
    plt.plot(audio)
    plt.show()
    player = IPython.display.Audio(audio, rate=fs)
    IPython.display.display(player)
    
# Inspect the generated pair: first channel of the noisy mixture, then the clean target.
print(noisy.shape)
print(clean.shape)
for label, wav in (('noisy[0]', noisy[0]), ('clean', clean)):
    show_audio(wav, fs, label)

# Model inference

In [None]:
from train_distributed import prepare_data

device = 'cuda'
model = model.to(device)
# Inference mode: disables dropout and makes batch-norm use running statistics
# (if the network contains such layers).
model.eval()


def to_tensor(x):
    """Convert an array to a float32 tensor with a leading batch axis of size 1."""
    return torch.tensor(x, dtype=torch.float)[None]


noisy_stft, target_stft = prepare_data(to_tensor(noisy), to_tensor(clean), device, args)

with torch.no_grad():
    output = model(noisy_stft)

esti_stft = output['esti_stft']
print(esti_stft.shape)

# STFT settings from the training args; durations are multiplied by the sample rate
# to get sizes in samples. (wav_len is not used in this cell.)
sr = args.sr
wav_len = int(args.wav_len * sr)
win_size = int(args.win_size * sr)
win_shift = int(args.win_shift * sr)
fft_num = args.fft_num

# Move the real/imag axis last for view_as_complex and put frequency before time for istft.
# NOTE(review): assumes esti_stft is (batch, 2, frames, freq) -- TODO confirm against the model.
esti_stft = esti_stft.permute(0, 3, 2, 1)
print(esti_stft.shape)
esti_wav = torch.istft(
    torch.view_as_complex(esti_stft.contiguous()),
    n_fft=fft_num,
    hop_length=win_shift,
    win_length=win_size,
    window=torch.hann_window(win_size).to(device),
)
esti_wav = esti_wav.cpu().numpy()  # (batch, samples)

# Play back at the model's sample rate instead of a hard-coded 16 kHz.
show_audio(esti_wav[0], sr, 'esti_wav')