## 进行本地模型测试
### 安装依赖包
#### 需要提前在目录下编译并安装 libsndfile：
```bash
cd SageMaker/libsndfile-1.0.28/
./configure --prefix=/usr \
            --disable-static \
            --docdir=/usr/share/doc/libsndfile-1.0.28 &&
make && sudo make install
```

In [18]:
# Install the pinned package versions this notebook was run with.
# NOTE(review): `!pip` runs in a shell that may resolve a different
# interpreter than the kernel; `%pip install ...` targets the active kernel.
# NOTE(review): librosa, scipy and soundfile are imported further down but
# are not pinned here — TODO confirm the intended versions.
!pip install nvidia-ml-py3
!pip install torch==1.5.1
!pip install numpy==1.16.4
!pip install numba==0.48
!pip install pytorch-lightning==0.7.6
!pip install Cython==0.29.20
!pip install asteroid==0.3.0
!pip install sagemaker-inference==1.3.2.post0
!pip install PyYAML==5.3.1

You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p36/bin/python -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/home/ec2-user/a

In [1]:
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import sys
import os
import time
import librosa
from scipy.signal import lfilter
import numpy as np
import scipy.io.wavfile
import soundfile as sf

from asteroid.data import LibriMix
from asteroid.engine.system import System
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from asteroid import ConvTasNet
from asteroid.models import ConvTasNet
from asteroid.models import DPRNNTasNet

import fnmatch, os, warnings


### 定义帮助函数

In [7]:
def remove_pad(inputs, inputs_lengths):
    """Strip right-padding from a batch of (multi-channel) signals.

    Args:
        inputs: torch.Tensor of shape [B, C, T] or [B, T]; B is the batch size.
        inputs_lengths: torch.Tensor (or sequence) with B entries, the number
            of valid samples of each batch item.

    Returns:
        A list of B numpy arrays. Item b has shape [C, inputs_lengths[b]]
        for 3-D input, or [inputs_lengths[b]] for 2-D input.

    Raises:
        ValueError: if ``inputs`` is neither 2-D nor 3-D.
    """
    dim = inputs.dim()
    if dim not in (2, 3):
        raise ValueError(
            "remove_pad expects a [B, C, T] or [B, T] tensor, got {} dims".format(dim)
        )
    results = []
    for item, length in zip(inputs, inputs_lengths):
        n = int(length)
        # Slice the time (last) axis; detach() lets tensors that still
        # require grad be converted to numpy.
        trimmed = item[..., :n].detach().cpu()
        results.append(trimmed.numpy())
    return results

def write(inputs, filename, sr=16000):
    """Write audio samples to a WAV file without rescaling.

    Replaces ``librosa.output.write_wav``, which was removed in librosa 0.8.
    Like that function, samples are written as-is, so float input produces a
    floating-point WAV file.

    Args:
        inputs: array of samples, shape [T] (mono) or [C, T] (channels first).
        filename: destination path.
        sr: sample rate in Hz.
    """
    data = np.asarray(inputs)
    if data.ndim == 2:
        # scipy expects [T, C]; callers pass channels-first arrays.
        data = np.ascontiguousarray(data.T)
    scipy.io.wavfile.write(filename, sr, data)

def wavread(filename):
    """Read a WAV file and return float samples plus the sample rate.

    Integer PCM is scaled by 2**(bits-1), so full scale maps to [-1, 1) and
    round-trips with ``wavwrite`` (which multiplies by 2**15). Unsigned PCM,
    such as 8-bit WAV with silence at 128, is re-centred before scaling.
    Floating-point WAV data is returned unchanged.

    Returns:
        (x, fs): samples as returned by scipy (shape [T] or [T, C]) and the
        sample rate in Hz.
    """
    fs, x = scipy.io.wavfile.read(filename)
    if np.issubdtype(x.dtype, np.unsignedinteger):
        # Unsigned PCM: mid-scale is silence, e.g. 128 for uint8.
        half = (int(np.iinfo(x.dtype).max) + 1) / 2
        x = (x.astype(np.float64) - half) / half
    elif np.issubdtype(x.dtype, np.integer):
        # Signed PCM: full scale is 2**(bits-1), i.e. -iinfo.min.
        x = x / -int(np.iinfo(x.dtype).min)
    return x, fs

def wavwrite(filename, s, fs):
    """Write samples to a WAV file as 16-bit PCM.

    int16 input is written unchanged. Any other input is treated as float
    audio in [-1, 1): it is scaled by 2**15, clipped to the int16 range
    (warning if any sample had to be clipped), and truncated to int16.

    Args:
        filename: destination path.
        s: samples (array-like).
        fs: sample rate in Hz.
    """
    s = np.asarray(s)
    if s.dtype != np.int16:
        info = np.iinfo(np.int16)
        scaled = s * 2**15
        # Check for clipping BEFORE the int16 cast. After the cast every
        # value is in range by construction, and out-of-range values would
        # wrap around (e.g. 1.0 -> -32768) instead of saturating.
        if np.any(scaled > info.max) or np.any(scaled < info.min):
            warnings.warn('Warning: clipping detected when writing {}'.format(filename))
        s = np.clip(scaled, info.min, info.max).astype(np.int16)
    scipy.io.wavfile.write(filename, fs, s)

def asl_meter(x, fs, nbits=16):
    '''Measure the Active Speech Level (ASL) of x following ITU-T P.56.

    If x is integer, it will be scaled to (-1, 1) according to nbits.

    Args:
        x: 1-D signal (float, or integer PCM with ``nbits`` bits).
        fs: sample rate in Hz.
        nbits: bit depth; the meter uses nbits-1 activity thresholds
            2**-(nbits-1) ... 2**-1.

    Returns:
        The active speech level in dB (relative to a full scale of 1.0),
        or -100 when no threshold satisfies the margin.
    '''
    if np.issubdtype(x.dtype, np.integer):
        x = x / 2**(nbits-1)

    # Constants
    T = 0.03                # Time constant of smoothing in seconds
    g = np.exp(-1/(T*fs))   # Per-sample smoothing coefficient
    H = 0.20                # Time of hangover in seconds
    I = int(np.ceil(H*fs))  # Hangover in samples
    M = 15.9                # Margin between threshold and ASL in dB

    n_thr = nbits - 1
    a = np.zeros(n_thr)                       # Activity count per threshold
    c = 0.5**np.arange(n_thr, 0, step=-1)     # Threshold levels
    h = np.full(n_thr, I)                     # Hangover count (starts expired)
    p = 0.0
    q = 0.0

    sq = np.sum(x**2)   # Total signal energy
    c_dB = 20*np.log10(c)

    for sample in np.abs(x):
        # Two cascaded first-order smoothers produce the envelope q.
        p = g * p + (1-g) * sample
        q = g * q + (1-g) * p

        for j in range(n_thr):
            if q >= c[j]:
                a[j] += 1
                h[j] = 0
            elif h[j] < I:
                # Still inside the hangover window after the last crossing.
                a[j] += 1
                h[j] += 1

    # Active level per threshold: energy per active sample, in dB.
    a_dB = np.full(n_thr, -100.0)
    active = a != 0
    a_dB[active] = 10*np.log10(sq/a[active])

    delta = a_dB - c_dB
    idx = np.where(delta <= M)[0]

    if len(idx) == 0:
        return -100

    idx = idx[0]
    if idx > 0:
        # A lower neighbour (idx-1) exists for every idx >= 1 (0-based),
        # so interpolate between the two thresholds bracketing the margin.
        return bin_interp(a_dB[idx], a_dB[idx-1], c_dB[idx], c_dB[idx-1], M)
    return a_dB[idx]

def bin_interp(upcount, lwcount, upthr, lwthr, margin, tol=0.1):
    """Bisect between two (activity level, threshold) points for the ASL.

    All values are in dB. The search runs between the upper point
    (upcount, upthr) and the lower point (lwcount, lwthr) for the activity
    level at which ``level - threshold == margin``, to within ``tol``.

    Returns:
        The interpolated activity level in dB.
    """
    n_iter = 1
    if abs(upcount - upthr - margin) < tol:
        midcount = upcount
    elif abs(lwcount - lwthr - margin) < tol:
        midcount = lwcount
    else:
        midcount = (upcount + lwcount)/2
        midthr = (upthr + lwthr)/2
        diff = midcount - midthr - margin
        while abs(diff) > tol:
            n_iter += 1
            if n_iter > 20:
                # Loosen the tolerance by 10% per extra iteration so the
                # search is guaranteed to terminate.
                tol *= 1.1
            if diff > tol:
                midcount = (upcount + midcount)/2
                midthr = (upthr + midthr)/2
            elif diff < -tol:
                midcount = (lwcount + midcount)/2
                midthr = (lwthr + midthr)/2
            diff = midcount - midthr - margin
    return midcount


def rms_energy(x):
    """Return the mean power of ``x`` in dB.

    A 1e-12 offset is added to the energy so an all-zero input does not
    produce log10(0).
    """
    energy = np.dot(x, x) + 1e-12
    mean_power = energy / len(x)
    return 10 * np.log10(mean_power)

def preprocess_wav(filename, asl_level=None):
    """Load a WAV file for separation, optionally normalising its level.

    Args:
        filename: path of the WAV file, read via ``wavread``.
        asl_level: target active speech level in dB (e.g. -26.0). When None
            (the default) the samples are returned unscaled and the
            expensive ``asl_meter`` pass is skipped.

    Returns:
        (x, fs): float samples and the sample rate in Hz.
    """
    x, fs = wavread(filename)
    if asl_level is None:
        return x, fs
    # asl_meter gives the current level in dB; rescale so it becomes asl_level.
    current_level = asl_meter(x, fs)
    return x / 10**(current_level/20) * 10**(asl_level/20), fs

def pad_list(xs, pad_value):
    """Stack variable-length tensors into one right-padded batch tensor.

    Args:
        xs: non-empty sequence of tensors; dim 0 is the (variable) length and
            all trailing dims must match those of xs[0].
        pad_value: fill value for the padded positions.

    Returns:
        A tensor of shape [len(xs), max_len, *xs[0].shape[1:]] with the same
        dtype and device as xs[0].
    """
    longest = max(t.size(0) for t in xs)
    template = xs[0]
    batch = template.new_full((len(xs), longest, *template.shape[1:]), pad_value)
    for row, t in enumerate(xs):
        batch[row, :t.size(0)] = t
    return batch

def separate_process(x, model):
    """Run the separation model on one audio segment.

    Args:
        x: numpy array of samples, assumed mono / 1-D (its first axis is
            used as the segment length).
        model: separation network called as ``model(mixture)`` on a [1, T]
            float tensor and returning [B, C, T] source estimates.

    Returns:
        A numpy array [C, len(x)] of estimated sources for this segment,
        assuming the model output is at least len(x) samples long.
    """
    # Run on whatever device the model's parameters live on instead of
    # assuming CUDA, so CPU-only inference also works.
    device = next((prm.device for prm in model.parameters()), torch.device('cpu'))

    # Batch of one: [1, T] float tensor plus its true length.
    mixture = pad_list([torch.from_numpy(x).float()], 0)
    mix_lengths = torch.tensor([x.shape[0]])

    with torch.no_grad():
        estimate_source = model(mixture.to(device))  # [B, C, T]
        # Remove padding and convert to numpy.
        flat_estimate = remove_pad(estimate_source, mix_lengths)
    return flat_estimate[0]

### 本地加载和推理实验

In [6]:
# Load the trained separation model from the experiment directory.
exp_dir = 'split_weight'
model_path = os.path.join(exp_dir, 'best_model.pth')
model = DPRNNTasNet.from_pretrained(model_path)
model.eval()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Separate the recording in segments of this many seconds.
instance_time = 10
filename = './audio_raw_1-20_clip_0.wav'
# filename = './data/tt/mix/1-20-t-1_1-20-t-1_0dB.wav'

x, fs = preprocess_wav(filename)
instance_seg = instance_time * fs  # segment length in samples

print('start timer')
Timer1 = time.time()
s1_parts, s2_parts = [], []
# Stepping by instance_seg covers every full segment plus the trailing
# partial one (slicing past the end is clamped).
for start in range(0, len(x), instance_seg):
    flat_estimate = separate_process(x[start:start + instance_seg], model)
    s1_parts.append(flat_estimate[0])
    s2_parts.append(flat_estimate[1])
# Concatenate once at the end rather than re-copying the output per segment.
s1 = np.concatenate(s1_parts) if s1_parts else np.array([])
s2 = np.concatenate(s2_parts) if s2_parts else np.array([])
print('finish {}'.format(time.time() - Timer1))
# Write at the input file's sample rate rather than write()'s fixed default.
write(s1, './complete_split_s1.wav', fs)
write(s2, './complete_split_s2.wav', fs)

start timer
tensor([[[-0.0006, -0.0022, -0.0024,  ...,  0.0016,  0.0015,  0.0008],
         [-0.0004, -0.0010, -0.0012,  ...,  0.0009,  0.0009,  0.0005]]],
       device='cuda:0')
tensor([[[ 7.0338e-04,  1.2980e-03,  8.9964e-04,  ..., -4.0627e-04,
           1.8699e-03,  1.2702e-03],
         [ 4.3443e-05,  2.1753e-04,  1.2414e-04,  ...,  9.2056e-04,
           9.4701e-04,  6.2036e-04]]], device='cuda:0')
tensor([[[ 0.0010,  0.0054,  0.0080,  ..., -0.0066, -0.0066, -0.0038],
         [ 0.0006,  0.0013,  0.0014,  ..., -0.0011, -0.0011, -0.0008]]],
       device='cuda:0')
tensor([[[-0.0020, -0.0048, -0.0048,  ..., -0.0029, -0.0024, -0.0010],
         [-0.0008, -0.0023, -0.0022,  ..., -0.0005, -0.0004, -0.0003]]],
       device='cuda:0')
tensor([[[ 0.0003,  0.0007,  0.0011,  ...,  0.0017,  0.0020,  0.0013],
         [-0.0010, -0.0024, -0.0024,  ...,  0.0010,  0.0009,  0.0006]]],
       device='cuda:0')
tensor([[[0.0003, 0.0012, 0.0016,  ..., 0.0038, 0.0039, 0.0024],
         [0.0009, 0.00