In [None]:
import math
import time
from itertools import permutations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Data Downloading & Unzip

In [None]:
 # Fetch the dataset archive by its file ID with gdown, then unpack it.
 # The unzip log below shows the wav files landing under min/<split>/<speaker>/.
 !gdown --id 1thV_9B1Noyf2Q91FTVJwGrS-Xh0BjqFT
 !unzip '/content/only2speaker.zip'

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: min/tr/mix/01aa010i_4.4784_209c0109_-4.4784.wav  
  inflating: min/tr/mix/01ao0308_0.88225_209o0101_-0.88225.wav  
  inflating: min/tr/mix/01ao0318_4.2526_209c010m_-4.2526.wav  
  inflating: min/tr/mix/01ao030e_2.2716_209o010c_-2.2716.wav  
  inflating: min/tr/mix/01ac0216_3.7128_209o0102_-3.7128.wav  
  inflating: min/tr/mix/01ao030r_4.8083_209a010y_-4.8083.wav  
  inflating: min/tr/mix/01ac020a_3.0031_209a010h_-3.0031.wav  
  inflating: min/tr/mix/01ac0211_3.1823_209a010p_-3.1823.wav  
  inflating: min/tr/mix/01aa0103_4.3498_209o0109_-4.3498.wav  
  inflating: min/tr/mix/01aa010n_4.0892_209o010h_-4.0892.wav  
  inflating: min/tr/mix/01aa010k_1.9741_209a010c_-1.9741.wav  
  inflating: min/tr/mix/01ao030r_0.072955_209a0103_-0.072955.wav  
  inflating: min/tr/mix/01aa010f_3.325_209a010n_-3.325.wav  
  inflating: min/tr/mix/01ac0214_1.633_209a0102_-1.633.wav  
  inflating: min/tr/mix/01ao0309_4.9094_209o0102_-4

# Preprocessing & Dataset & DataLoader

## Preprocessing directory file & reading it as json

In [None]:
import os
import json
import librosa
def preprocess_one_dir(in_dir, out_dir, filename, sample_rate):
    """Index the wav files of one directory into ``<out_dir>/<filename>.json``.

    Every ``*.wav`` entry directly inside ``in_dir`` is decoded with librosa at
    ``sample_rate`` to measure its length. The JSON file holds a list of
    ``[wav_path, num_samples]`` pairs.

    Args:
        in_dir: directory that contains the wav files.
        out_dir: directory that receives the JSON index (created if missing).
        filename: base name of the JSON file, without the ``.json`` suffix.
        sample_rate: sampling rate passed to ``librosa.load``.
    """
    file_infos = []
    # os.listdir() returns entries in arbitrary order. Sorting makes the index
    # reproducible, so directories holding identically named files are listed
    # in the same order.
    for wav in sorted(os.listdir(in_dir)):
        if not wav.endswith('.wav'):
            continue
        wav_path = os.path.join(in_dir, wav)
        samples, _ = librosa.load(wav_path, sr=sample_rate)
        file_infos.append((wav_path, len(samples)))

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, filename + '.json'), 'w') as f:
        json.dump(file_infos, f, indent=4)

def preprocess():
    """Build the JSON indexes for every split and speaker directory.

    Reads ``min/<split>/<speaker>`` and writes ``out/<split>/<speaker>.json``
    for the splits ``tr``, ``cv``, ``tt`` and the speakers ``mix``, ``s1``,
    ``s2``, always at a sampling rate of 8000.
    """
    splits = ['tr', 'cv', 'tt']
    speakers = ['mix', 's1', 's2']
    for split in splits:
        json_dir = os.path.join('out', split)
        for speaker in speakers:
            wav_dir = os.path.join('min', split, speaker)
            preprocess_one_dir(wav_dir, json_dir, speaker, sample_rate=8000)

## Some useful tools: load wav files into (batched) variables

In [None]:
def load_mixtures_and_sources(batch):
    """Load the mixture and both source waveforms for one minibatch.

    Args:
        batch: ``(mix_infos, s1_infos, s2_infos, sample_rate)`` where each
            ``*_infos`` is a list of ``[wav_path, num_samples]`` entries and
            the three lists are aligned utterance by utterance.

    Returns:
        ``(mixtures, sources)``: ``mixtures`` is a list of 1-D arrays and
        ``sources`` a list of ``[T, 2]`` arrays (column 0 is s1, column 1 is s2).

    Raises:
        AssertionError: if aligned entries do not have the same recorded length.
    """
    mix_infos, s1_infos, s2_infos, sample_rate = batch

    mixtures, sources = [], []
    for mix_info, s1_info, s2_info in zip(mix_infos, s1_infos, s2_infos):
        mix_path, s1_path, s2_path = mix_info[0], s1_info[0], s2_info[0]

        # The three indexes are paired by position only, so a length mismatch
        # means the pairing is wrong.
        assert mix_info[1] == s1_info[1] == s2_info[1], (
            'length mismatch between %s, %s and %s' % (mix_path, s1_path, s2_path))

        mix, _ = librosa.load(mix_path, sr=sample_rate)
        s1, _ = librosa.load(s1_path, sr=sample_rate)
        s2, _ = librosa.load(s2_path, sr=sample_rate)

        mixtures.append(mix)
        # Stack the two speakers along a new last axis -> [T, 2].
        sources.append(np.stack((s1, s2), axis=-1))

    return mixtures, sources


def load_mixtures(batch):
    """Load the mixture waveforms of one evaluation minibatch.

    Args:
        batch: ``(mix_infos, sample_rate)``; each info's first element is the
            wav path.

    Returns:
        ``(mixtures, filenames)``: the decoded waveforms and their paths, in
        the same order as ``mix_infos``.
    """
    mix_infos, sample_rate = batch

    mixtures = []
    filenames = []
    for info in mix_infos:
        path = info[0]
        waveform, _ = librosa.load(path, sr=sample_rate)
        mixtures += [waveform]
        filenames += [path]

    return mixtures, filenames

def pad_list(xs, pad_value):
    """Stack tensors of different first-dimension length into one padded tensor.

    Args:
        xs: non-empty list of tensors that agree on every dimension but the first.
        pad_value: value written wherever an entry is shorter than the longest one.

    Returns:
        Tensor of shape ``[len(xs), max_len, ...]`` with the dtype and device
        of ``xs[0]``.
    """
    count = len(xs)
    longest = max(item.size(0) for item in xs)
    trailing = xs[0].size()[1:]

    padded = xs[0].new(count, longest, *trailing).fill_(pad_value)
    for row, item in enumerate(xs):
        length = item.size(0)
        padded[row, :length] = item
    return padded

    

## Dataset & DataLoader

In [None]:
import torch.utils.data as data
def sort(infos):
    """Return a new list of ``[path, length]`` infos ordered longest first."""
    def length_of(info):
        return int(info[1])

    return sorted(infos, key=length_of, reverse=True)
# Reads the JSON indexes and groups them into minibatches up front:
# each dataset item is one whole minibatch (a list), not a single utterance.
class AudioDataset(data.Dataset):
    """Training/validation dataset whose items are pre-built minibatches.

    ``json_dir`` must contain ``mix.json``, ``s1.json`` and ``s2.json``, each a
    list of ``[wav_path, num_samples]`` entries. Every index is sorted longest
    first and cut into consecutive chunks of ``batch_size``. Item ``i`` is
    ``[mix_infos, s1_infos, s2_infos, sample_rate]`` for chunk ``i``.

    Note: mixtures and sources are paired purely by their position after the
    length sort.
    """

    def __init__(self, json_dir, batch_size, sample_rate = 8000):
        """
        Args:
            json_dir: directory holding the three JSON indexes.
            batch_size: number of utterances per minibatch (must be >= 1).
            sample_rate: sampling rate stored with every minibatch.

        Raises:
            ValueError: if ``batch_size`` is not positive or the three indexes
                do not list the same number of utterances.
        """
        super(AudioDataset, self).__init__()

        # A non-positive batch size would never advance through the index.
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

        def load_sorted(name):
            with open(os.path.join(json_dir, name + '.json'), 'r') as f:
                return sort(json.load(f))

        sorted_mix_infos = load_sorted('mix')
        sorted_s1_infos = load_sorted('s1')
        sorted_s2_infos = load_sorted('s2')

        if not (len(sorted_mix_infos) == len(sorted_s1_infos) == len(sorted_s2_infos)):
            raise ValueError(
                'mix/s1/s2 indexes differ in size: %d / %d / %d' % (
                    len(sorted_mix_infos), len(sorted_s1_infos), len(sorted_s2_infos)))

        # range() yields no chunk at all for an empty index, instead of a
        # single empty minibatch that cannot be collated.
        self.minibatch = [
            [sorted_mix_infos[start:start + batch_size],
             sorted_s1_infos[start:start + batch_size],
             sorted_s2_infos[start:start + batch_size],
             sample_rate]
            for start in range(0, len(sorted_mix_infos), batch_size)
        ]

    def __getitem__(self, index):
        """Return the pre-built minibatch at ``index``."""
        return self.minibatch[index]

    def __len__(self):
        """Number of minibatches."""
        return len(self.minibatch)


class AudioDataLoader(data.DataLoader):
    """DataLoader that always collates with ``_collate_fn``.

    ``_collate_fn`` asserts that it receives exactly one dataset item.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after construction, so a collate_fn supplied by the caller
        # is replaced.
        self.collate_fn = _collate_fn


def _collate_fn(batch):
    """Turn the single pre-built minibatch in ``batch`` into padded tensors.

    Args:
        batch: list holding exactly one dataset item.

    Returns:
        ``(mixtures_pad, ilens, sources_pad)``: zero-padded mixtures stacked
        along a new batch dimension, the unpadded mixture lengths, and the
        zero-padded sources with their last two dimensions swapped.
    """
    assert len(batch) == 1
    mixtures, sources = load_mixtures_and_sources(batch[0])

    lengths = np.array([waveform.shape[0] for waveform in mixtures])

    mixture_tensors = [torch.from_numpy(waveform).float() for waveform in mixtures]
    mixtures_pad = pad_list(mixture_tensors, 0)
    ilens = torch.from_numpy(lengths)

    source_tensors = [torch.from_numpy(pair).float() for pair in sources]
    sources_pad = pad_list(source_tensors, 0)
    # Swap the last two dimensions of the padded sources.
    sources_pad = sources_pad.permute((0, 2, 1)).contiguous()

    return mixtures_pad, ilens, sources_pad

In [None]:
class EvalDataset(data.Dataset):
    """Evaluation dataset of mixtures only; items are pre-built minibatches.

    Item ``i`` is ``[mix_infos, sample_rate]`` for the ``i``-th chunk of the
    mixture index, which is sorted longest first.
    """

    def __init__(self, mix_dir, mix_json, batch_size, sample_rate = 8000):
        """
        Args:
            mix_dir: directory of mixture wav files, or None. When given, it is
                indexed into ``<mix_dir>/mix.json`` and that file is used,
                overriding ``mix_json``.
            mix_json: path to an existing mixture index, or None.
            batch_size: number of utterances per minibatch (must be >= 1).
            sample_rate: sampling rate used for indexing and stored with every
                minibatch.

        Raises:
            AssertionError: if both ``mix_dir`` and ``mix_json`` are None.
            ValueError: if ``batch_size`` is not positive.
        """
        super(EvalDataset, self).__init__()

        assert mix_dir is not None or mix_json is not None, \
            'either mix_dir or mix_json must be given'
        # A non-positive batch size would never advance through the index.
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

        if mix_dir is not None:
            preprocess_one_dir(mix_dir, mix_dir, 'mix', sample_rate = sample_rate)
            mix_json = os.path.join(mix_dir, 'mix.json')

        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)

        sorted_mix_infos = sort(mix_infos)
        # range() yields no chunk for an empty index, instead of a single
        # empty minibatch that cannot be collated.
        self.mini_batch = [
            [sorted_mix_infos[start:start + batch_size], sample_rate]
            for start in range(0, len(sorted_mix_infos), batch_size)
        ]

    def __getitem__(self, index):
        """Return the pre-built minibatch at ``index``."""
        return self.mini_batch[index]

    def __len__(self):
        """Number of minibatches."""
        return len(self.mini_batch)


class EvalDataLoader(data.DataLoader):
    """DataLoader that always collates with ``_collate_fn_eval``.

    ``_collate_fn_eval`` asserts that it receives exactly one dataset item.
    """

    def __init__(self, *args, **kwargs):
        super(EvalDataLoader, self).__init__(*args, **kwargs)
        # A name starting with two underscores is mangled inside a class body
        # (it would be looked up as `_EvalDataLoader__collate_fn_eval`), so the
        # collate function is referenced through its single-underscore name.
        self.collate_fn = _collate_fn_eval


def _collate_fn_eval(batch):
    """Turn the single pre-built evaluation minibatch into padded tensors.

    Args:
        batch: list holding exactly one dataset item.

    Returns:
        ``(mixtures_pad, ilens, filenames)``: zero-padded mixtures stacked
        along a new batch dimension, the unpadded lengths, and the wav paths.
    """
    assert len(batch) == 1
    mixtures, filenames = load_mixtures(batch[0])

    ilens = np.array([mix.shape[0] for mix in mixtures])

    pad_value = 0
    mixtures_pad = pad_list([torch.from_numpy(mix).float() for mix in mixtures], pad_value)

    ilens = torch.from_numpy(ilens)

    return mixtures_pad, ilens, filenames


# Keep the original double-underscore name available at module level.
__collate_fn_eval = _collate_fn_eval
    

# GRU

In [None]:
class GRUModel(nn.Module):
    """Baseline two-speaker separator: MLP -> 4 x (GRU + LayerNorm) -> mask MLP.

    The whole utterance (``original_shape`` samples) is treated as a single
    feature vector. The network predicts a non-negative mask for speaker 1;
    speaker 2 is the remainder of the mixture.
    """

    def __init__(self, original_shape = 0):
        """
        Args:
            original_shape: number of samples per (padded) utterance, i.e. the
                width of the input and output linear layers.
        """
        super(GRUModel, self).__init__()
        self.shape = original_shape

        self.fc_input = nn.Sequential(
            nn.Linear(self.shape, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
        )

        # Four single-layer GRUs, each followed by its own LayerNorm. The
        # stacked nn.GRU that used to be created first was immediately
        # overwritten by this list, so it is no longer built.
        self.gru = nn.ModuleList([torch.nn.GRU(128, 128, 1, batch_first=True) for _ in range(4)])
        self.ln = nn.ModuleList([torch.nn.LayerNorm(128) for _ in range(4)])

        self.fc_output = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, self.shape),
            nn.ReLU(),
        )

    def total_param(self):
        """Return the number of parameters, in millions."""
        return sum([p.numel() for p in self.parameters()]) / 1000.0 / 1000.0

    def forward(self, x):
        """Separate a batch of mixtures.

        Args:
            x: ``[batch, original_shape]`` mixtures, on the same device as the
                model.

        Returns:
            ``[batch, 2, original_shape]`` tensor; index 0 on dim 1 is the
            masked mixture (speaker 1), index 1 the remainder (speaker 2).
        """
        # Use the actual batch size instead of a hard-coded 10, and stay on the
        # device of the input/model instead of forcing CUDA.
        batch_size = x.size(0)
        mix = x.view(batch_size, 1, self.shape)

        out = self.fc_input(x)
        out = out.view(batch_size, 1, out.size(1))  # sequence of length 1

        for gru, norm in zip(self.gru, self.ln):
            out, _ = gru(out)
            out = norm(out)

        mask = self.fc_output(out)
        estimated_s1 = mix * mask
        estimated_s2 = mix - estimated_s1
        final = torch.stack([estimated_s1, estimated_s2], 1)

        return final.view(batch_size, 2, self.shape)



In [None]:
import time

# NOTE: this cell needs `tr_dataloader`, `GRUModel` and `cal_loss` to exist
# already; `cal_loss` is defined in a cell further down the notebook.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model and its optimizer are created once, on the first batch, because the
# width of GRUModel is only known once a batch has been seen. Creating a new
# model for every batch (and an optimizer bound to a different model) would
# never train anything.
model = None
optimizer = None
feature_len = None

print('Training...')
start = time.time()

for epoch in range(0, 10):
    total_loss = 0      # reset per epoch so the average is a per-epoch average
    num_batches = 0

    for i, (padded_mixture, mixture_lengths, padded_source) in enumerate(tr_dataloader):
        padded_mixture = padded_mixture.to(device)
        mixture_lengths = mixture_lengths.to(device)
        padded_source = padded_source.to(device)

        if model is None:
            feature_len = padded_source.size(-1)
            model = GRUModel(feature_len).to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

        # GRUModel has a fixed input width, while batches are padded to their
        # own longest utterance: zero-pad or crop every batch to the width the
        # model was built with.
        current_len = padded_mixture.size(-1)
        if current_len < feature_len:
            extra = feature_len - current_len
            padded_mixture = F.pad(padded_mixture, (0, extra))
            padded_source = F.pad(padded_source, (0, extra))
        elif current_len > feature_len:
            padded_mixture = padded_mixture[..., :feature_len]
            padded_source = padded_source[..., :feature_len]
            mixture_lengths = mixture_lengths.clamp(max = feature_len)

        model.train()
        estimate_source = model(padded_mixture)

        loss, max_snr, estimate_source, reorder_estimate_source = \
            cal_loss(padded_source, estimate_source, mixture_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches = i + 1

    # Report once per epoch, as the message says.
    tr_avg_loss = total_loss / max(num_batches, 1)
    print('-' * 100)
    print('Train Summary | End of Epoch {0} | Time {1: .2f}s | Train Loss {2: .3f}'.format(epoch + 1, time.time() - start, tr_avg_loss))
    print('-' * 100)

Training...
----------------------------------------------------------------------------------------------------
Train Summary | End of Epoch 1 | Time  29.74s | Train Loss  5.425
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
Train Summary | End of Epoch 1 | Time  31.23s | Train Loss  5.336
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
Train Summary | End of Epoch 1 | Time  32.93s | Train Loss  5.193
----------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------
Train Summary | End of Epoch 1 | Time  34.06s | Train Loss  5.068
-----------------

KeyboardInterrupt: ignored

--------------------------**Tasnet**----------------------------------

# Conv - TasNet

In [None]:
class ConvTasNet(nn.Module):

    '''
    Conv-TasNet: Encoder -> TemporalConvNet separator (masks) -> Decoder.

    N : number of filters in autoencoder
    L : Length of the filters (in samples)
    B : number of channels in bottleneck 1 x 1-conv block
    H : number of channels in convolutional blocks
    P : kernel size in convolutional blocks
    X : number of convolutional blocks in each repeat (parameter ``x``)
    R : number of repeats
    C : number of speakers

    norm_type      : forwarded to the separator ('gLN', 'cLN', anything else = batch norm)
    causal         : forwarded to the separator
    mask_nonlinear : forwarded to the separator ('relu' or 'softmax')
    '''

    def __init__(self, N, L, B, H, P, x, R, C, norm_type = 'gLN', causal = False, mask_nonlinear = 'relu'):
        super(ConvTasNet, self).__init__()
        # The signature spells this argument in lower case; everything else
        # calls it X.
        X = x
        # Hyper-parameters
        self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear

        self.encoder = Encoder(L, N)
        # Same attribute name as the one read in forward().
        self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
        self.decoder = Decoder(N, L)

        # Xavier-initialise every weight that has more than one dimension.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def forward(self, mixture):
        '''
        Args:
            mixture: [batch, T] waveforms.

        Returns:
            Estimated sources, right-padded on the last dimension to length T.
        '''
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # Bring the decoder output back to the input length: pad the tail.
        T_origin = mixture.size(-1)
        T_conv = est_source.size(-1)
        est_source = F.pad(est_source, (0, T_origin - T_conv))
        return est_source













# AutoEncoder

In [None]:
# AutoEncoder

class Encoder(nn.Module):
    """Learned analysis filterbank: 1-D convolution followed by ReLU.

    N: number of filters in autoencoder
    L: length of the filters (in samples); consecutive frames are L // 2 apart.

    Input  : mixture [batch_size, samples]
    Output : [batch_size, N, K] with K = (T - L) / (L / 2) + 1
    """

    def __init__(self, L, N):
        super().__init__()
        self.L, self.N = L, N
        self.conv1d_U = nn.Conv1d(in_channels=1, out_channels=N,
                                  kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        # nn.Conv1d wants [batch, channels, length]: add a singleton channel.
        with_channel = mixture.unsqueeze(1)
        return F.relu(self.conv1d_U(with_channel))


class Decoder(nn.Module):
    """Apply the estimated masks and map the result back to waveforms.

    N: number of encoder filters, L: filter length in samples.
    """

    def __init__(self, N, L):
        super().__init__()
        self.N = N
        self.L = L
        # Maps each N-dimensional frame to L samples.
        self.basis_signals = nn.Linear(N, L, bias=False)

    def forward(self, mixture_encoded, est_mask):
        '''
        mixture_encoded shape   [M, N, K]
        est_mask shape          [M, C, N, K]
        '''
        masked = mixture_encoded.unsqueeze(1) * est_mask   # [M, C, N, K]
        masked = masked.transpose(2, 3)                    # [M, C, K, N]

        frames = self.basis_signals(masked)                # [M, C, K, L]
        # Frames overlap with a step of L // 2.
        return overlap_and_add(frames, self.L // 2)


def overlap_and_add(signal, frame_step):
    """Reconstruct a signal from overlapping frames by summing the overlaps.

    Args:
        signal: tensor of shape ``[..., frames, frame_length]``.
        frame_step: hop between the starts of consecutive frames, in samples.

    Returns:
        Tensor of shape ``[..., frame_step * (frames - 1) + frame_length]``.
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Work in sub-frames of gcd(frame_length, frame_step) samples so that every
    # frame starts exactly on a sub-frame boundary.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # reshape (not view) so that non-contiguous inputs are accepted as well.
    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    # Output sub-frame index of every input sub-frame. It is built directly as
    # an integer tensor on the signal's device instead of going through the
    # signal's floating dtype.
    frame = torch.arange(0, output_subframes, device=signal.device)
    frame = frame.unfold(0, subframes_per_frame, subframe_step)
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)

    return result


# Block

In [None]:
class TemporalConvNet(nn.Module):

    '''
    Separator: layer norm -> bottleneck 1x1 conv -> R repeats of X dilated
    temporal blocks -> 1x1 conv producing one mask per speaker.

    N : number of filters in autoencoder
    B : number of channels in bottleneck 1 x 1-conv block
    H : number of channels in convolutional blocks
    P : kernel size in convolutional blocks
    X : number of convolutional blocks in each repeat
    R : number of repeats
    C : number of speakers

    Uses the classes TemporalBlock and ChannelwiselayerNorm.
    '''

    def __init__(self, N, B, H, P, X, R, C, norm_type = 'gLN', causal = False, mask_nonlinear = 'relu'):
        super(TemporalConvNet, self).__init__()
        self.C = C
        self.mask_nonlinear = mask_nonlinear

        # Spelled exactly as the class is defined in this notebook.
        layer_norm = ChannelwiselayerNorm(N)

        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias = False)

        repeats = []
        for r in range(R):
            blocks = []
            for x in range(X):
                # Dilation doubles with every block inside a repeat.
                dilation = 2**x
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2

                blocks += [TemporalBlock(B, H, P, stride = 1,
                                         padding = padding,
                                         dilation = dilation,
                                         norm_type = norm_type,
                                         causal = causal)]
            repeats += [nn.Sequential(*blocks)]

        temporal_con_net = nn.Sequential(*repeats)

        mask_conv1x1 = nn.Conv1d(B, C*N, 1, bias = False)

        self.network = nn.Sequential(layer_norm,
                                     bottleneck_conv1x1,
                                     temporal_con_net,
                                     mask_conv1x1)

    def forward(self, mixture_w):
        '''
        Args:
            mixture_w: [M, N, K] encoded mixture.

        Returns:
            est_mask: [M, C, N, K]

        Raises:
            ValueError: if mask_nonlinear is neither 'softmax' nor 'relu'.
        '''
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)
        score = score.view(M, self.C, N, K)

        if self.mask_nonlinear == 'softmax':
            est_mask = F.softmax(score, dim = 1)
        elif self.mask_nonlinear == 'relu':
            est_mask = F.relu(score)
        else:
            # Without this branch est_mask would be unbound below.
            raise ValueError('Unsupported mask non-linear function: %r' % (self.mask_nonlinear,))
        return est_mask



class TemporalBlock(nn.Module):
    """Residual block: 1x1 conv -> PReLU -> norm -> depthwise-separable conv.

    The inner network maps ``in_channels`` to ``out_channels`` and back to
    ``in_channels``; its output is added to the block input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type = 'gLN', causal = False):
        super().__init__()

        # Layers are created in pipeline order.
        layers = [nn.Conv1d(in_channels, out_channels, 1, bias=False)]
        layers.append(nn.PReLU())
        layers.append(chose_norm(norm_type, out_channels))
        layers.append(DepthwiseSeparableConv(out_channels, in_channels, kernel_size,
                                             stride, padding, dilation, norm_type, causal))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Skip connection around the whole inner network.
        return self.net(x) + x


def chose_norm(norm_type, channel_size):
    """Build the normalisation layer selected by ``norm_type``.

    Args:
        norm_type: ``"gLN"`` for GlobalLayerNorm, ``"cLN"`` for
            ChannelwiselayerNorm; any other value selects ``nn.BatchNorm1d``.
        channel_size: number of channels the layer normalises.

    Returns:
        The freshly constructed ``nn.Module``.
    """
    # "gLN" is spelled exactly like the default used by every caller.
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    elif norm_type == "cLN":
        return ChannelwiselayerNorm(channel_size)
    else:
        return nn.BatchNorm1d(channel_size)


# Depthwise Conv + Pointwise Conv


In [None]:

class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv -> (Chomp1d if causal) -> PReLU -> norm -> pointwise 1x1 conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, norm_type = "gLN", causal = False):
        super().__init__()

        # groups = in_channels: one filter per input channel.
        layers = [nn.Conv1d(in_channels, in_channels, kernel_size,
                            stride=stride, padding=padding,
                            dilation=dilation, groups=in_channels, bias=False)]
        if causal:
            # Drop the trailing `padding` time steps added by the convolution.
            layers.append(Chomp1d(padding))
        layers.append(nn.PReLU())
        layers.append(chose_norm(norm_type, in_channels))
        layers.append(nn.Conv1d(in_channels, out_channels, 1, bias=False))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Chomp1d(nn.Module):
    """Remove the last ``chomp_size`` steps of the last dimension of a 3-D input."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: 3-D tensor; the last dimension is trimmed.

        Returns:
            Contiguous tensor without the last ``chomp_size`` steps.
        """
        # `x[:, :, :-0]` is an empty slice, so a zero chomp must be handled
        # explicitly.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()
        


# Norm Layer

In [None]:
class ChannelwiselayerNorm(nn.Module):

    ''' channel wise layer normalization (cLN)

    Each time step is normalised over the channel dimension, then scaled and
    shifted by the per-channel parameters gamma and beta.
    '''

    def __init__(self, channel_size, eps = 1e-8):
        '''
        Args:
            channel_size: number of channels N.
            eps: constant added to the variance for numerical stability.
        '''
        super(ChannelwiselayerNorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))   # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        ''' Set gamma to 1 and beta to 0 (identity scale and shift). '''
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        '''
        Args: y: [batch size, channel size, length]

        Returns: cLN_y: [M, N, K]
        '''
        mean = torch.mean(y, dim = 1, keepdim = True)                    # [M, 1, K]
        var = torch.mean((y - mean) ** 2, dim = 1, keepdim = True)       # [M, 1, K], biased
        cLN_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta

        return cLN_y


class GlobalLayerNorm(nn.Module):

    ''' global layer normalization (gLN)

    Each utterance is normalised over both the channel and the time dimension,
    then scaled and shifted by the per-channel parameters gamma and beta.
    '''

    def __init__(self, channel_size, eps = 1e-8):
        '''
        Args:
            channel_size: number of channels N.
            eps: constant added to the variance for numerical stability.
        '''
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  #[1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))   #[1, N, 1]
        # torch.Tensor(...) is uninitialised memory, so the parameters must be
        # reset here.
        self.reset_parameters()

    def reset_parameters(self):
        ''' Set gamma to 1 and beta to 0 (identity scale and shift). '''
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        '''
        Args: y: [M, N, K]

        Returns: gLN_y: [M, N, K]
        '''
        mean = y.mean(dim = 1, keepdim = True).mean(dim = 2, keepdim = True)   # [M, 1, 1]
        var = (torch.pow(y - mean, 2)).mean(dim = 1, keepdim = True).mean(dim = 2, keepdim = True)
        gLN_y = self.gamma * (y - mean) / torch.pow(var + self.eps, 0.5) + self.beta

        return gLN_y



# PIT Loss function

In [None]:
EPS = 1e-8   # small constant shared by the SI-SNR computations


def get_mask(source, source_lengths):
    """Build a ``[B, 1, T]`` mask: 1 on real samples, 0 on padding.

    Args:
        source: ``[B, C, T]`` tensor; only its size, dtype and device are used.
        source_lengths: per-utterance number of valid samples, indexable by
            batch position.
    """
    batch, _, steps = source.size()
    mask = source.new_ones((batch, 1, steps))

    for b in range(batch):
        valid = source_lengths[b]
        mask[b, :, valid:].zero_()

    return mask

def cal_loss(source, estimate_source, source_lengths, PIT = True):
    """Negative mean SI-SNR loss, with or without permutation-invariant training.

    Args:
        source: reference sources.
        estimate_source: estimated sources, same size as ``source``.
        source_lengths: valid length of every utterance.
        PIT: when true, use the best speaker permutation per utterance.

    Returns:
        ``(loss, snr, estimate_source, reorder_estimate_source)``. Without PIT
        the last two entries are the same tensor.
    """
    if not PIT:
        si_snr = cal_si_snr(source, estimate_source, source_lengths)
        loss = 0 - torch.mean(si_snr)
        return loss, si_snr, estimate_source, estimate_source

    max_snr, perms, max_snr_idx = cal_si_snr_with_pit(source, estimate_source, source_lengths)
    loss = 0 - torch.mean(max_snr)
    reordered = reorder_source(estimate_source, perms, max_snr_idx)
    return loss, max_snr, estimate_source, reordered


def cal_si_snr(source, estimate_source, source_lengths):
    """SI-SNR of every estimate against the reference at the same index.

    Args:
        source: ``[B, C, T]`` references.
        estimate_source: ``[B, C, T]`` estimates. Note that this tensor is
            masked in place (padding set to zero).
        source_lengths: ``[B]`` valid lengths.

    Returns:
        ``[B, C, 1]`` SI-SNR values in dB.
    """
    assert source.size() == estimate_source.size()

    B, C, T = source.size()

    mask = get_mask(source, source_lengths)
    estimate_source *= mask          # in place: the caller's tensor is masked too

    # Remove the per-utterance mean, computed over the valid samples only.
    valid = source_lengths.view(-1, 1, 1).float()
    target_mean = torch.sum(source, dim=2, keepdim=True) / valid
    estimate_mean = torch.sum(estimate_source, dim=2, keepdim=True) / valid

    s_target = (source - target_mean) * mask
    s_estimate = (estimate_source - estimate_mean) * mask

    # Project the estimate onto the target; the rest counts as noise.
    dot = torch.sum(s_estimate * s_target, dim=2, keepdim=True)
    target_energy = torch.sum(s_target ** 2, dim=2, keepdim=True) + EPS
    projection = dot * s_target / target_energy
    noise = s_estimate - projection

    ratio = torch.sum(projection ** 2, dim=2, keepdim=True) / (
        torch.sum(noise ** 2, dim=2, keepdim=True) + EPS)
    si_snr_db = 10 * torch.log10(ratio + EPS)

    return torch.mean(si_snr_db, dim=-1, keepdim=True)

def cal_si_snr_with_pit(source, estimate_source, source_lengths):

    '''
    SI-SNR with permutation-invariant training: score every estimate against
    every reference and keep, per utterance, the best speaker permutation.

    Args:
        source          [B, C, T]
        estimate_source [B, C, T]  (masked in place: padding set to zero)
        source_lengths  [B]

    Returns:
        max_snr     [B, 1]  SI-SNR of the best permutation, averaged over speakers
        perms       [C!, C] all speaker permutations
        max_snr_idx [B]     row of perms chosen for every utterance
    '''

    assert source.size() == estimate_source.size()
    B, C, T = source.size()
    mask = get_mask(source, source_lengths)                                       # [B, 1, T]  original = 1, padded = 0
    estimate_source *= mask                                                       # [B, C, T]

    # Zero-mean both signals using the valid samples only.
    num_samples = source_lengths.view(-1, 1, 1).float()                           # [B, 1, 1]
    mean_target = torch.sum(source, dim = 2, keepdim = True) / num_samples        # [B, C, 1]
    mean_estimate = torch.sum(estimate_source, dim = 2, keepdim = True) / num_samples  # [B, C, 1]
    zero_mean_target = source - mean_target                                       # [B, C, T]
    zero_mean_estimate = estimate_source - mean_estimate

    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Broadcast to every (estimate, target) pair.
    s_target = torch.unsqueeze(zero_mean_target, dim = 1)                         # [B, 1, C, T]
    s_estimate = torch.unsqueeze(zero_mean_estimate, dim = 2)                     # [B, C, 1, T]

    pair_wise_dot = torch.sum(s_estimate * s_target, dim = 3, keepdim = True)     # [B, C, C, 1]
    # + EPS keeps an all-zero target from dividing by zero (as in cal_si_snr).
    s_target_energy = torch.sum(s_target ** 2, dim = 3, keepdim = True) + EPS     # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot * s_target / s_target_energy                   # [B, C, C, T]

    e_noise = s_estimate - pair_wise_proj                                         # [B, C, C, T]

    pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim = 3) / (torch.sum(e_noise ** 2, dim = 3) + EPS)  # [B, C, C]
    pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS)                   # [B, C, C]

    # One-hot encode every permutation so that a single einsum sums the
    # pairwise scores it selects.
    perms = source.new_tensor(list(permutations(range(C))), dtype = torch.long)   # [C!, C]
    index = torch.unsqueeze(perms, 2)                                             # [C!, C, 1]
    perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)    # [C!, C, C]

    snr_set = torch.einsum("bij,pij->bp", pair_wise_si_snr, perms_one_hot)        # [B, C!]
    max_snr_idx = torch.argmax(snr_set, dim = 1)                                  # [B]
    max_snr, _ = torch.max(snr_set, dim = 1, keepdim = True)                      # [B, 1]
    max_snr /= C                                                                  # average over speakers
    return max_snr, perms, max_snr_idx
    




def reorder_source(source, perms, max_snr_idx):
    '''
    Reorder the speaker dimension of every utterance by its chosen permutation.

    source:      [B, C, T]
    perms:       [C!, C]
    max_snr_idx: [B]  row of perms to apply to each utterance

    Returns a new tensor shaped like source where entry [b, c] is
    source[b, perms[max_snr_idx[b]][c]].
    '''
    batch_size, num_speakers, *_ = source.size()
    chosen = perms.index_select(0, max_snr_idx)        # [B, C]

    reordered = torch.zeros_like(source)
    for b in range(batch_size):
        permutation_b = chosen[b]
        for c in range(num_speakers):
            reordered[b, c] = source[b, permutation_b[c]]
    return reordered
    



# Training

In [None]:
def run_one_epoch(epoch, cross_valid = False):
    """Run one pass over the training or validation loader.

    Reads these module-level names: ``tr_loader``, ``cv_loader``, ``model``,
    ``optimizer``, ``use_cuda``, ``pit``, ``max_norm``, ``print_freq``,
    ``visdom_epoch``, ``visdom_id`` and ``vis``.

    Args:
        epoch: zero-based epoch number (used for logging and plot titles).
        cross_valid: when true, iterate ``cv_loader`` without updating the
            model; otherwise train on ``tr_loader``.

    Returns:
        The average loss over the batches of the pass.

    Raises:
        ValueError: if the loader yields no batch.
    """
    start = time.time()
    total_loss = 0
    num_batches = 0

    data_loader = cv_loader if cross_valid else tr_loader

    plot_batches = visdom_epoch and not cross_valid
    if plot_batches:
        vis_opts_epoch = dict(title = visdom_id + 'epoch' + str(epoch), ylabel = 'Loss', xlabel = 'Epoch')
        vis_window_epoch = None
        vis_iters = torch.arange(1, len(data_loader) + 1)
        vis_iters_loss = torch.Tensor(len(data_loader))

    for i, batch in enumerate(data_loader):
        padded_mixture, mixture_lengths, padded_source = batch
        if use_cuda:
            padded_mixture = padded_mixture.cuda()
            mixture_lengths = mixture_lengths.cuda()
            padded_source = padded_source.cuda()

        # No autograd graph is needed for validation.
        with torch.set_grad_enabled(not cross_valid):
            estimate_source = model(padded_mixture)
            loss, max_snr, estimate_source, reorder_estimate_source = \
                cal_loss(padded_source, estimate_source, mixture_lengths, pit)

        if not cross_valid:
            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

        total_loss += loss.item()
        num_batches = i + 1

        if i % print_freq == 0:
            print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                      'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                          epoch + 1, i + 1, total_loss / (i + 1),
                          loss.item(), 1000 * (time.time() - start) / (i + 1)),
                      flush=True)

        if plot_batches:
            vis_iters_loss[i] = loss.item()
            if i % print_freq == 0:
                x_axis = vis_iters[:i+1]
                y_axis = vis_iters_loss[:i+1]
                if vis_window_epoch is None:
                    vis_window_epoch = vis.line(X = x_axis, Y = y_axis, opts = vis_opts_epoch)
                else:
                    vis.line(X = x_axis, Y = y_axis, win = vis_window_epoch, update = 'replace')

    # An empty loader used to fail with an unbound loop variable.
    if num_batches == 0:
        raise ValueError('data loader produced no batches')
    return total_loss / num_batches



def train():
    """Train and validate ``model`` for epochs ``start_epoch`` .. ``epochs - 1``.

    Reads these module-level names: ``model``, ``optimizer``, ``start_epoch``,
    ``epochs``, ``checkpoint``, ``save_folder``, ``model_path``, ``half_lr``,
    ``early_stop``, ``tr_loss``, ``cv_loss``, ``visdom`` and, when ``visdom``
    is set, ``vis``, ``vis_epochs`` and ``vis_opts``.

    Per epoch: train, optionally save a checkpoint, validate, optionally halve
    the learning rate / stop early, record both losses, save the model when the
    validation loss improves, and optionally plot.
    """
    # State carried from one epoch to the next.
    prev_val_loss = float('inf')
    best_val_loss = float('inf')
    val_no_impv = 0
    halving = False
    vis_window = None

    for epoch in range(start_epoch, epochs):
        print('Traning...')
        model.train()
        start = time.time()
        tr_avg_loss = run_one_epoch(epoch)
        print('*' * 100)
        print('Train Summary | End of Epoch {0} | Time {1:.2f}s | '
                  'Train Loss {2:.3f}'.format(
                      epoch + 1, time.time() - start, tr_avg_loss))
        print('*' * 100)

        if checkpoint:
            os.makedirs(save_folder, exist_ok=True)
            file_path = os.path.join(
                save_folder, 'epoch%d.pth.tar' % (epoch + 1)
            )
            torch.save(_serialize_checkpoint(epoch + 1), file_path)
            print('Saving checkpoint model to %s' % file_path)

        print('Cross Validation')
        model.eval()
        val_loss = run_one_epoch(epoch, cross_valid = True)
        print('*' * 100)
        print('Valid Summary | End of Epoch {0} | Time {1:.2f}s | '
                  'Valid Loss {2:.3f}'.format(
                      epoch + 1, time.time() - start, val_loss))
        print('*' * 100)

        # Halve the learning rate after 3 epochs without improvement and stop
        # after 10 (if early stopping is enabled).
        if half_lr:
            if val_loss >= prev_val_loss:
                val_no_impv += 1
                if val_no_impv >= 3:
                    halving = True
                if val_no_impv >= 10 and early_stop:
                    print('No improvement, early stopping')
                    break
            else:
                val_no_impv = 0

        if halving:
            optim_state = optimizer.state_dict()
            optim_state['param_groups'][0]['lr'] = \
                optim_state['param_groups'][0]['lr'] / 2.0
            optimizer.load_state_dict(optim_state)
            print('Learning rate adjusted to: {lr:.6f}'.format(
                    lr=optim_state['param_groups'][0]['lr']))
            halving = False
        prev_val_loss = val_loss

        tr_loss[epoch] = tr_avg_loss
        cv_loss[epoch] = val_loss

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            os.makedirs(save_folder, exist_ok=True)
            file_path = os.path.join(save_folder, model_path)
            torch.save(_serialize_checkpoint(epoch + 1), file_path)
            print("Find better validated model, saving to %s" % file_path)

        if visdom:
            x_axis = vis_epochs[0:epoch + 1]
            y_axis = torch.stack(
                (tr_loss[0:epoch + 1], cv_loss[0:epoch + 1]), dim = 1
            )

            if vis_window is None:
                vis_window = vis.line(
                    X = x_axis,
                    Y = y_axis,
                    opts = vis_opts
                )
            else:
                vis.line(
                    X = x_axis.unsqueeze(0).expand(y_axis.size(1),
                        x_axis.size(0)).transpose(0, 1),
                    Y = y_axis,
                    win = vis_window,
                    update = 'replace'
                )


def _serialize_checkpoint(epoch):
    """Build the object stored in a checkpoint file for ``epoch``.

    Uses the network's own ``serialize`` method when it has one; otherwise
    falls back to a plain dict of state dicts and loss histories.
    """
    # Unwrap DataParallel (or any wrapper exposing `.module`) if present.
    net = getattr(model, 'module', model)
    if hasattr(net, 'serialize'):
        return net.serialize(net, optimizer, epoch, tr_loss = tr_loss, cv_loss = cv_loss)
    return {
        'state_dict': net.state_dict(),
        'optim_dict': optimizer.state_dict(),
        'epoch': epoch,
        'tr_loss': tr_loss,
        'cv_loss': cv_loss,
    }



# Main

In [None]:
# ---------------------------------------------------------------------------
# Configuration. The original cell left these assignments empty; every value
# below is an initial setting to tune, not a validated choice.
# ---------------------------------------------------------------------------
JSON_ROOT = 'out'        # preprocess() writes out/<split>/<speaker>.json
SAMPLE_RATE = 8000       # same rate preprocess() uses
BATCH_SIZE = 4           # utterances per pre-built minibatch

# ConvTasNet positional arguments, in signature order: N, L, B, H, P, X, R, C.
MODEL_ARGS = (256, 20, 256, 512, 3, 8, 4, 2)
NORM_TYPE = 'gLN'

OPTIMIZER_NAME = 'adam'  # 'sgd' or 'adam'
lr = 1e-3
momentum = 0.0
l2 = 0.0

# Module-level names read by train() and run_one_epoch().
use_cuda = torch.cuda.is_available()
pit = True
max_norm = 5
print_freq = 10
start_epoch = 0
epochs = 30
half_lr = True
early_stop = True
checkpoint = False
save_folder = 'exp'
model_path = 'final.pth.tar'
visdom = False
visdom_epoch = False
visdom_id = 'Conv-TasNet'
vis = None
tr_loss = torch.zeros(epochs)
cv_loss = torch.zeros(epochs)

# Build the JSON indexes on first use; this decodes every wav file once.
if not os.path.exists(os.path.join(JSON_ROOT, 'tr', 'mix.json')):
    preprocess()

tr_dataset = AudioDataset(os.path.join(JSON_ROOT, 'tr'), BATCH_SIZE, sample_rate = SAMPLE_RATE)
cv_dataset = AudioDataset(os.path.join(JSON_ROOT, 'cv'), BATCH_SIZE, sample_rate = SAMPLE_RATE)
# batch_size = 1: every dataset item already is a whole minibatch, and the
# collate function asserts that it receives exactly one item.
tr_dataloader = AudioDataLoader(tr_dataset, batch_size = 1, shuffle = True)
cv_dataloader = AudioDataLoader(cv_dataset, batch_size = 1)
# run_one_epoch() looks the loaders up under these names.
tr_loader, cv_loader = tr_dataloader, cv_dataloader

# (The unused `data = {}` was dropped: it rebound the `data` alias of
# torch.utils.data that the dataset cells rely on.)

model = ConvTasNet(*MODEL_ARGS, norm_type = NORM_TYPE)

model = torch.nn.DataParallel(model)
if use_cuda:
    model.cuda()

if OPTIMIZER_NAME == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(),
                                lr = lr,
                                momentum = momentum,
                                weight_decay = l2)
elif OPTIMIZER_NAME == 'adam':
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr = lr,
                                 weight_decay = l2)
else:
    # Training cannot proceed without an optimizer.
    raise ValueError('No optimizer provided')

run = train()