In [None]:
# Mount Google Drive inside the Colab VM at /content/drive; data_generator()
# later reads the audio files from /content/drive/MyDrive/wavfile/.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import numpy as np
import scipy.fftpack
from scipy.linalg import toeplitz
from scipy.signal import fftconvolve
import itertools
import warnings
import os
import glob
import torch
import torch.nn as nn
import tqdm
import librosa
import time
import random
import soundfile

In [None]:
## FUNCTIONS FOR EVALUATION OF MODEL, TO CALCULATE SDR, SIR AND SAR, TAKEN  FROM THE BSS EVAL TOOLBOX, AS MENTIONED IN OUR PAPER
def bss_eval_sources(reference_sources, estimated_sources):
    """Compute SDR, SIR and SAR for every estimate against its best-matching reference.

    Every (estimate, reference) pairing is scored with a 512-tap distortion
    filter, then the permutation of estimates that maximises the mean SIR is
    selected.

    Parameters
    ----------
    reference_sources : np.ndarray
        True sources, one per row (a 1-D array is treated as a single source).
    estimated_sources : np.ndarray
        Estimated sources, one per row (a 1-D array is treated as a single source).

    Returns
    -------
    tuple
        ``(sdr, sir, sar, perm)``; ``perm[j]`` is the index of the estimate
        assigned to reference ``j``. Four empty arrays are returned when
        either input is empty.
    """
    # Promote single-source inputs to a (1, nsampl) matrix.
    if estimated_sources.ndim == 1:
        estimated_sources = estimated_sources[None, :]
    if reference_sources.ndim == 1:
        reference_sources = reference_sources[None, :]

    if reference_sources.size == 0 or estimated_sources.size == 0:
        return np.array([]), np.array([]), np.array([]), np.array([])

    nsrc = estimated_sources.shape[0]
    sdr = np.empty((nsrc, nsrc))
    sir = np.empty((nsrc, nsrc))
    sar = np.empty((nsrc, nsrc))

    # Score every estimate (row index) against every reference (column index).
    for jest, jtrue in itertools.product(range(nsrc), repeat=2):
        components = _bss_decomp_mtifilt(reference_sources,
                                         estimated_sources[jest],
                                         jtrue, 512)
        sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \
            _bss_source_crit(*components)

    # Exhaustively search for the estimate -> reference assignment with the
    # highest mean SIR.
    columns = np.arange(nsrc)
    candidates = list(itertools.permutations(list(range(nsrc))))
    mean_sir = np.empty(len(candidates))
    for k, candidate in enumerate(candidates):
        mean_sir[k] = np.mean(sir[candidate, columns])
    best = candidates[np.argmax(mean_sir)]

    selection = (best, columns)
    return (sdr[selection], sir[selection], sar[selection], np.asarray(best))

def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen):
    """Decompose one estimated source into target, spatial, interference and artifact terms.

    Parameters
    ----------
    reference_sources : np.ndarray
        2-D array of true sources, one per row.
    estimated_source : np.ndarray
        1-D estimated signal.
    j : int
        Row of ``reference_sources`` treated as the target source.
    flen : int
        Length of the distortion filter passed to ``_project``.

    Returns
    -------
    tuple
        ``(s_true, e_spat, e_interf, e_artif)``; ``s_true`` is the target
        reference zero-padded by ``flen - 1`` samples.
    """
    n_est = estimated_source.size
    zero_tail = np.zeros(flen - 1)

    # Target reference, padded to the length that _project returns.
    s_true = np.hstack((reference_sources[j], zero_tail))

    # Projection onto the target alone gives the spatial distortion term.
    target_projection = _project(reference_sources[j, np.newaxis, :],
                                 estimated_source, flen)
    e_spat = target_projection - s_true

    # Projection onto all references; what the target does not explain is
    # interference.
    full_projection = _project(reference_sources, estimated_source, flen)
    e_interf = full_projection - s_true - e_spat

    # Whatever remains of the estimate after removing the above is artifact.
    e_artif = -s_true - e_spat - e_interf
    e_artif[:n_est] += estimated_source
    return (s_true, e_spat, e_interf, e_artif)


def _bss_source_crit(s_true, e_spat, e_interf, e_artif):
    """Turn the four decomposition terms into SDR, SIR and SAR (in dB).

    Returns
    -------
    tuple
        ``(sdr, sir, sar)`` as computed by ``_safe_db`` from energy ratios.
    """
    s_filt = s_true + e_spat
    target_energy = np.sum(s_filt**2)

    distortion_energy = np.sum((e_interf + e_artif)**2)
    interference_energy = np.sum(e_interf**2)
    artifact_energy = np.sum(e_artif**2)

    sdr = _safe_db(target_energy, distortion_energy)
    sir = _safe_db(target_energy, interference_energy)
    sar = _safe_db(np.sum((s_filt + e_interf)**2), artifact_energy)
    return (sdr, sir, sar)


def _safe_db(num, den):
    if den == 0:
        return np.Inf
    return 10 * np.log10(num / den)


def _project(reference_sources, estimated_source, flen):
    nsrc = reference_sources.shape[0]
    nsampl = reference_sources.shape[1]
    reference_sources = np.hstack((reference_sources,
                                   np.zeros((nsrc, flen - 1))))
    estimated_source = np.hstack((estimated_source, np.zeros(flen - 1)))
    n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.)))
    sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
    sef = scipy.fftpack.fft(estimated_source, n=n_fft)
    G = np.zeros((nsrc * flen, nsrc * flen))
    for i in range(nsrc):
        for j in range(nsrc):
            ssf = sf[i] * np.conj(sf[j])
            ssf = np.real(scipy.fftpack.ifft(ssf))
            ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
                          r=ssf[:flen])
            G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss
            G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T
    D = np.zeros(nsrc * flen)
    for i in range(nsrc):
        ssef = sf[i] * np.conj(sef)
        ssef = np.real(scipy.fftpack.ifft(ssef))
        D[i * flen: (i+1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1]))
    try:
        C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F')
    except np.linalg.linalg.LinAlgError:
        C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F')
    sproj = np.zeros(nsampl + flen - 1)
    for i in range(nsrc):
        sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1]
    return sproj


def bss_eval(mixed_wav, src1_wav, src2_wav, pred_src1_wav, pred_src2_wav):
    """Score two predicted sources against the references, relative to the raw mixture.

    The references and the mixture are truncated to the length of
    ``pred_src1_wav`` before scoring.

    Parameters
    ----------
    mixed_wav : np.ndarray
        Mixture waveform.
    src1_wav, src2_wav : np.ndarray
        Reference waveforms of the two sources.
    pred_src1_wav, pred_src2_wav : np.ndarray
        Predicted waveforms of the two sources.

    Returns
    -------
    tuple
        ``(nsdr, sir, sar, n_samples)`` where ``nsdr`` is the SDR of the
        predictions minus the SDR obtained by using the mixture itself as
        both estimates, and ``n_samples`` is the length used for scoring.
    """
    # A local name that does not shadow the builtin len().
    n_samples = pred_src1_wav.shape[0]
    src1_wav = src1_wav[:n_samples]
    src2_wav = src2_wav[:n_samples]
    mixed_wav = mixed_wav[:n_samples]

    references = np.array([src1_wav, src2_wav])
    sdr, sir, sar, _ = bss_eval_sources(references,
                                        np.array([pred_src1_wav, pred_src2_wav]))
    # Baseline: how well the unprocessed mixture already matches each source.
    sdr_mixed, _, _, _ = bss_eval_sources(references,
                                          np.array([mixed_wav, mixed_wav]))
    nsdr = sdr - sdr_mixed
    return nsdr, sir, sar, n_samples

In [None]:
## RANDOMISING TRAINING DATA AND ENSURING THAT BOTH MALE AND FEMALE AUDIO SAMPLES GO INTO TRAINING

# Speaker identifiers; data_generator() matches them as file-name prefixes.
male = ["abjones","bobon","bug","davidson","fdps","geniusturtle","jmzen","Kenshin","khair","leon","stool"] #LIST OF ALL MALE VOICES
male=np.array(male)
female =["amy", "Ani", "annar","ariel","heycat","tammy","titon","yifen"] #LIST OF ALL FEMALE VOICES
female=np.array(female)

# Fixed seed so the same speakers end up in the train/test sets on every run.
# Without it, evaluating a saved checkpoint in a new session may test on
# speakers the checkpoint was trained on.
SPLIT_SEED = 0
TRAIN_FRACTION = 0.8


def _split_speakers(names, rng, train_fraction=TRAIN_FRACTION):
    """Randomly split ``names`` into (train, test) lists.

    ``int(train_fraction * len(names))`` names go to the train list and the
    remainder to the test list. ``rng`` must provide ``permutation``.
    """
    n_train = int(train_fraction * len(names))
    order = rng.permutation(len(names))
    return list(names[order[:n_train]]), list(names[order[n_train:]])


# Dedicated generators keep the split independent of the global RNG state.
_split_rng = np.random.RandomState(SPLIT_SEED)
_shuffle_rng = random.Random(SPLIT_SEED)

# Split each gender separately so both appear in the train and the test set.
male_train, male_test = _split_speakers(male, _split_rng)
female_train, female_test = _split_speakers(female, _split_rng)

print(male_train)
print(male_test)
print(female_train)
print(female_test)
test_data = male_test+female_test
_shuffle_rng.shuffle(test_data)
print(test_data)
train_d = male_train+female_train
_shuffle_rng.shuffle(train_d)
print(train_d)

['Kenshin', 'bobon', 'davidson', 'jmzen', 'fdps', 'abjones', 'leon', 'bug']
['khair', 'geniusturtle', 'stool']
['Ani', 'tammy', 'titon', 'ariel', 'heycat', 'annar']
['amy', 'yifen']
['stool', 'amy', 'khair', 'yifen', 'geniusturtle']
['bug', 'leon', 'Kenshin', 'davidson', 'fdps', 'heycat', 'Ani', 'abjones', 'jmzen', 'annar', 'ariel', 'titon', 'tammy', 'bobon']


In [None]:
#functions to generate data for training and testing, calculate short time fourier transform and inverse stft

def data_generator(train_d,train):
    """Yield waveforms and normalised spectrograms for the selected recordings.

    Every file in ``/content/drive/MyDrive/wavfile/`` whose name starts with
    one of the identifiers in ``train_d`` is a training file. With
    ``train=True`` only those files are yielded, with ``train=False`` only the
    remaining ones.

    Parameters
    ----------
    train_d : sequence of str
        File-name prefixes of the training recordings (any length).
    train : bool
        Select the training files (True) or all other files (False).

    Yields
    ------
    tuple
        ``(channel0, channel1, mono_mix, left_mag, right_mag, mixed_mag,
        max_value, mixed_phase)``. The first three are waveforms at the
        file's own sampling rate; the magnitudes and the phase come from the
        8 kHz resampled audio, and the magnitudes are divided by
        ``max_value``, the peak of the mixture magnitude.
    """
    # str.startswith accepts a tuple of prefixes, so any number of
    # identifiers is supported.
    prefixes = tuple(str(name) for name in train_d)
    # Sorted so files are visited in the same order on every run.
    for wav in sorted(glob.glob('/content/drive/MyDrive/wavfile/*')):
        f = os.path.split(wav)[1]
        if f.startswith(prefixes) != train:
            continue
        origin_source, sampling_rate = librosa.load(wav, sr=None, mono=False)

        # Keyword arguments work across librosa versions; newer releases
        # make orig_sr/target_sr keyword-only.
        source_resampled = librosa.resample(origin_source, orig_sr=sampling_rate, target_sr=8000)
        mixed_source_origin = librosa.to_mono(source_resampled)
        resampled_left = source_resampled[0]
        resampled_right = source_resampled[1]

        # One STFT of the mixture serves both its magnitude and its phase.
        mixed_spectrum = spectrum(mixed_source_origin)
        mixed_magnitude_spectrum = np.abs(mixed_spectrum)
        right_magnitude_spectrum = np.abs(spectrum(np.asfortranarray(resampled_right)))
        left_magnitude_spectrum = np.abs(spectrum(np.asfortranarray(resampled_left)))
        max_value = np.max(mixed_magnitude_spectrum)
        # The phase is not used in training; it is kept to rebuild waveforms
        # from the predicted masks.
        mixed_phase = np.angle(mixed_spectrum)

        # Normalise all spectrograms by the mixture's peak magnitude.
        mixed_magnitude_spectrum = mixed_magnitude_spectrum / max_value
        right_magnitude_spectrum = right_magnitude_spectrum / max_value
        left_magnitude_spectrum = left_magnitude_spectrum / max_value
        yield origin_source[0, :], origin_source[1, :], librosa.to_mono(origin_source), left_magnitude_spectrum, right_magnitude_spectrum, mixed_magnitude_spectrum, max_value, mixed_phase

def spectrum(wav):
    """Return the complex STFT of ``wav`` (FFT size 1024, hop length 256)."""
    stft_matrix = librosa.stft(wav, n_fft = 1024, hop_length = 256)
    return stft_matrix

def to_wav(mag, phase):
    """Rebuild a waveform from a magnitude and a phase spectrogram.

    The two are combined into a complex spectrogram and inverted with
    librosa's inverse STFT (hop length 256).
    """
    complex_spectrogram = mag * np.exp(1.j * phase)
    waveform = librosa.istft(complex_spectrogram, hop_length = 256)
    return np.array(waveform)

In [None]:
# defining our model 

class Conv(nn.Module):
    """Convolution followed by batch normalisation and ReLU.

    The padding is ``(kernel_size - 1) // 2``, so with ``stride=1`` and an
    odd kernel the spatial size of the input is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        convolution = nn.Conv2d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=same_padding)
        normalisation = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        self.model = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x``."""
        activated = self.model(x)
        return activated

class Hourglass(nn.Module): #one hourglass module
    """One recursive hourglass module.

    The main path pools by 2, widens to ``output_channels + add_channels``,
    runs a nested ``Hourglass`` (or a plain ``Conv`` once ``depth`` is 0),
    narrows back to ``output_channels`` and upsamples by 2. Its result is
    added to a ``Conv`` skip branch computed at the input resolution.
    """

    def __init__(self, depth, output_channels, add_channels):
        super().__init__()
        next_channels = output_channels + add_channels

        downsample = nn.MaxPool2d(kernel_size=2)
        widen = Conv(output_channels, next_channels)
        # Recurse until depth reaches 0, where a single Conv is the bottleneck.
        if depth:
            inner = Hourglass(depth - 1, next_channels, add_channels)
        else:
            inner = Conv(next_channels, next_channels)
        narrow = Conv(next_channels, output_channels)
        upsample = nn.UpsamplingNearest2d(scale_factor=2)

        self.model = nn.Sequential(downsample, widen, inner, narrow, upsample)
        self.skip = Conv(output_channels, output_channels)

    def forward(self, x):
        """Return the sum of the skip branch and the pooled/upsampled branch."""
        skip_out = self.skip(x)
        main_out = self.model(x)
        return skip_out + main_out

class StackedHourglassNet(nn.Module): #stacking the hourglasses to build the final structure
    """A stem of convolutions followed by ``num_stacks`` hourglass stages.

    Every stage produces its own prediction with ``output_channels``
    channels. Between stages, the prediction is mapped back to ``channels``
    channels and added to a convolved copy of the stage features.
    """

    def __init__(self, num_stacks, channels, output_channels, add_channels):
        super().__init__()
        self.num_stacks = num_stacks

        # Stem: single-channel input -> `channels` feature maps.
        stem_layers = [Conv(1, 64, kernel_size=7, stride=1),
                       Conv(64, 128),
                       Conv(128, 128),
                       Conv(128, channels)]
        self.prepare = nn.Sequential(*stem_layers)

        stages = []
        for _ in range(num_stacks):
            stages.append(nn.Sequential(
                Hourglass(depth=4, output_channels=channels, add_channels=add_channels),
                Conv(channels, channels),
                Conv(channels, channels, 1)))
        self.hourglass = nn.ModuleList(stages)

        # One prediction head per stage; the last stage needs no re-injection.
        self.output = nn.ModuleList([Conv(channels, output_channels) for _ in range(num_stacks)])
        self.next = nn.ModuleList([Conv(channels, channels) for _ in range(num_stacks - 1)])
        self.merge = nn.ModuleList([Conv(output_channels, channels) for _ in range(num_stacks - 1)])

    def forward(self, x):
        """Return the predictions of all stages stacked along a new first axis."""
        x = self.prepare(x)
        predicts = []
        last_stage = self.num_stacks - 1
        for i, (stage, head) in enumerate(zip(self.hourglass, self.output)):
            x = stage(x)
            prediction = head(x)
            predicts.append(prediction)
            if i != last_stage:
                # Feed this stage's prediction and features into the next stage.
                x = self.merge[i](prediction) + self.next[i](x)
        return torch.stack(predicts, 0)

def get_model(): # Function to create an object of class model
    """Build a StackedHourglassNet from the module-level hyper-parameters.

    Reads NUM_STACKS, FIRST_DEPTH_CHANNELS, OUTPUT_CHANNELS and
    NEXT_DEPTH_ADD_CHANNELS at call time, so they must be defined before the
    first call.
    """
    hyperparameters = dict(num_stacks=NUM_STACKS,
                           channels=FIRST_DEPTH_CHANNELS,
                           output_channels=OUTPUT_CHANNELS,
                           add_channels=NEXT_DEPTH_ADD_CHANNELS)
    return StackedHourglassNet(**hyperparameters)

def get_train_data(frames_per_sample=64):
    """Load the training spectrograms and list every valid excerpt position.

    Parameters
    ----------
    frames_per_sample : int, optional
        Width, in spectrogram frames, of the excerpts that will be cut from
        the returned positions. Defaults to 64, the width used by ``train``.

    Returns
    -------
    tuple
        ``(train_data, train_pos)``. ``train_data[k]`` is the
        ``(left_mag, right_mag, mixed_mag)`` triple of the k-th training
        file. ``train_pos`` is a list of ``(k, start_frame)`` pairs; only
        start frames that leave room for a full ``frames_per_sample``-wide
        excerpt are listed.
    """
    train_data = []
    train_pos = []
    generator = data_generator(train_d,train=True)
    for cnt, (_, _, _, left_mag, right_mag, mixed_mag, _, _) in enumerate(generator):
        train_data.append((left_mag, right_mag, mixed_mag))
        # The last valid start is n_frames - frames_per_sample; later starts
        # would give a truncated excerpt that cannot fill a training batch.
        # Files shorter than one excerpt contribute no positions.
        n_starts = max(mixed_mag.shape[-1] - frames_per_sample + 1, 0)
        train_pos.extend((cnt, start) for start in range(n_starts))
    return train_data, train_pos

In [None]:
#TOTAL NUMBER OF TRAINABLE PARAMETERS
# NOTE(review): get_model() reads NUM_STACKS and the channel constants, which
# this notebook assigns in a later cell; on a fresh kernel that cell has to
# run before this one.
net = get_model()
total_params_SH1 = 0
for p in net.parameters():
    # Frozen parameters (requires_grad == False) are not counted.
    if p.requires_grad:
        total_params_SH1 += p.numel()
print(total_params_SH1)

8868038


In [None]:
# Hyper-parameters shared by get_model(), train() and test().
ITERATIONS = 1500  # number of optimisation steps run by train()
BATCH_SIZE = 4  # excerpts per training batch; also windows per forward pass in test()
NUM_STACKS = 1  # number of hourglass stages (num_stacks in get_model())
FIRST_DEPTH_CHANNELS = 64  # feature channels at the outermost hourglass level
OUTPUT_CHANNELS = 2  # one output channel per source; train() fills them from the left/right targets
NEXT_DEPTH_ADD_CHANNELS = 64  # channels added at each deeper hourglass level
TRAIN_SAVE_POINT = 50  # save a checkpoint and print the mean loss every this many iterations
TEST_STEP = 20 # print the running metrics every this many test files
TOTAL_TEST = 230  # test() stops after this many files
# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 



def train(checkpoint_dir='Model1'):
    """Train a freshly built model and save periodic checkpoints.

    Runs ITERATIONS optimisation steps with Adam (lr 0.001). Each batch holds
    BATCH_SIZE random excerpts of 512 frequency bins x 64 frames. Each stage
    output of the network is multiplied with the mixture magnitude, and the
    loss is the sum over stages of the mean absolute error against the
    stacked left/right target magnitudes.

    Parameters
    ----------
    checkpoint_dir : str, optional
        Directory for ``checkpoint_<iteration>.pt`` files, written every
        TRAIN_SAVE_POINT iterations. Created if it does not exist.

    Raises
    ------
    ValueError
        If no training file is long enough to cut a single excerpt.
    """
    n_bins, n_frames = 512, 64

    net = get_model()
    net.to(device)
    training_data, training_indices_array = get_train_data()

    # Keep only start frames from which a full excerpt can be cut; a shorter
    # slice cannot be copied into the fixed-size batch tensor.
    valid_positions = [
        (file_idx, start) for file_idx, start in training_indices_array
        if start + n_frames <= training_data[file_idx][2].shape[-1]
    ]
    if not valid_positions:
        raise ValueError('no training file is at least {} frames long'.format(n_frames))

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Plain float accumulator: it starts at exactly 0 and does not keep the
    # autograd graphs of past iterations alive.
    running_loss = 0.0
    optimiser = torch.optim.Adam(net.parameters(), lr=0.001)
    for i in tqdm.tqdm(range(ITERATIONS)):
        # Build the batch on the CPU, then move it to the device once.
        mixture = torch.empty(BATCH_SIZE, 1, n_bins, n_frames)
        target = torch.empty(BATCH_SIZE, 2, n_bins, n_frames)
        for j in range(BATCH_SIZE):
            file_idx, start = valid_positions[np.random.randint(len(valid_positions))]
            left_mag, right_mag, mixed_mag = training_data[file_idx]
            mixture[j, 0, :, :] = torch.from_numpy(mixed_mag[:n_bins, start:start + n_frames])
            target[j, 0, :, :] = torch.from_numpy(left_mag[:n_bins, start:start + n_frames])
            target[j, 1, :, :] = torch.from_numpy(right_mag[:n_bins, start:start + n_frames])
        mixture = mixture.to(device)
        target = target.to(device)

        optimiser.zero_grad()
        predicted = net(mixture)

        # predicted has one entry per hourglass stage; every stage is supervised.
        loss = sum(torch.mean(torch.abs(predicted[x].mul(mixture) - target))
                   for x in range(predicted.shape[0]))
        # A mean absolute error is never negative, so this test only fails
        # for a NaN loss, in which case the update is skipped.
        if loss >= 0:
            loss.backward()
            running_loss += loss.item()
            optimiser.step()
        if (i + 1) % TRAIN_SAVE_POINT == 0:
            torch.save(net.state_dict(), '{}/checkpoint_{}.pt'.format(checkpoint_dir, i + 1))
            print("loss is {}".format(running_loss / TRAIN_SAVE_POINT))
            running_loss = 0.0
    

def test(model='Model1/checkpoint_1500.pt'):
    """Evaluate a checkpoint on the non-training files and print running metrics.

    For every file yielded by ``data_generator(train_d, train=False)`` the
    mixture spectrogram is processed in overlapping 64-frame windows (hop 32).
    Except at the start and end of the file, only the central 32 frames of
    each window are kept. The predicted masks are applied to the mixture
    magnitude, inverted with the mixture phase, resampled from 8 kHz to
    16 kHz and scored with ``bss_eval``. Length-weighted averages are printed
    as GNSDR / GSIR / GSAR every TEST_STEP files; evaluation stops after
    TOTAL_TEST files.

    Parameters
    ----------
    model : str, optional
        Path of the checkpoint to load.
    """
    n_bins, window, hop, margin = 512, 64, 32, 16

    models = get_model()
    # NOTE(review): strict=False silently ignores keys that do not match the
    # current architecture; confirm NUM_STACKS matches the checkpoint.
    models.load_state_dict(torch.load(model, map_location=device),strict=False)
    models.to(device)
    # Inference mode: BatchNorm uses its running statistics, so a window's
    # output does not depend on what else is in the batch buffer.
    models.eval()

    # Zero-initialised so rows that are not refreshed never hold garbage.
    batch = np.zeros((BATCH_SIZE, 1, n_bins, window), dtype=np.float32)

    def predict(batch_array):
        """Run the model and return the last stage's output as a NumPy array."""
        with torch.no_grad():
            return models(torch.from_numpy(batch_array).to(device))[-1].cpu().numpy()

    total_length = 0.
    cnt = 0
    gnsdr = 0.
    gsir = 0.
    gsar = 0.
    for left, right, mix, _, _, mix_magnitude, max_value, mix_phase in data_generator(train_d,train=False):
        source_length = mix_magnitude.shape[-1]
        if source_length < window:
            warnings.warn('skipping a test file shorter than {} frames'.format(window))
            continue
        start_ind = 0
        predicts_left = np.zeros((n_bins, source_length), dtype=np.float32)
        predicts_right = np.zeros((n_bins, source_length), dtype=np.float32)

        while start_ind + window < source_length:
            if start_ind and start_ind + (BATCH_SIZE - 1) * hop + window < source_length:
                # Enough frames remain for a full batch of overlapping windows.
                for i in range(BATCH_SIZE):
                    lo = start_ind + i * hop
                    batch[i, 0, :, :] = mix_magnitude[0:n_bins, lo:lo + window]
                output = predict(batch)
                for i in range(BATCH_SIZE):
                    lo = start_ind + i * hop
                    predicts_left[:, lo + margin: lo + window - margin] = output[i, 0, :, margin:window - margin]
                    predicts_right[:, lo + margin: lo + window - margin] = output[i, 1, :, margin:window - margin]
                start_ind += BATCH_SIZE * hop
            else:
                batch[0, 0, :, :] = mix_magnitude[0:n_bins, start_ind:start_ind + window]
                output = predict(batch)
                if start_ind == 0:
                    # The very first window is kept in full.
                    predicts_left[:, 0:window] = output[0, 0, :, :]
                    predicts_right[:, 0:window] = output[0, 1, :, :]
                else:
                    predicts_left[:, start_ind + margin: start_ind + window - margin] = output[0, 0, :, margin:window - margin]
                    predicts_right[:, start_ind + margin: start_ind + window - margin] = output[0, 1, :, margin:window - margin]
                start_ind += hop

        # Final window, aligned with the end of the file, fills the frames
        # that are still empty. If the loop never ran (file of exactly
        # `window` frames) that is the whole file.
        batch[0, 0, :, :] = mix_magnitude[0:n_bins, source_length - window:source_length]
        output = predict(batch)
        tail_start = start_ind + margin if start_ind else 0
        tail_length = source_length - tail_start
        predicts_left[:, tail_start:source_length] = output[0, 0, :, window - tail_length:window]
        predicts_right[:, tail_start:source_length] = output[0, 1, :, window - tail_length:window]

        # Clean the masks, apply them to the mixture and undo the normalisation.
        predicts_left =np.nan_to_num(predicts_left)
        predicts_right =np.nan_to_num(predicts_right)
        predicts_left[np.where(predicts_left < 0)] = 0
        predicts_right[np.where(predicts_right < 0)] = 0
        predicts_left = predicts_left * mix_magnitude[0:n_bins, :] * max_value
        predicts_right = predicts_right * mix_magnitude[0:n_bins, :] * max_value
        predicts_left_wav = to_wav(predicts_left, mix_phase[0:n_bins, :])
        predicts_right_wav = to_wav(predicts_right, mix_phase[0:n_bins, :])
        predicts_left_wav = librosa.resample(predicts_left_wav, orig_sr=8000, target_sr=16000)
        predicts_right_wav = librosa.resample(predicts_right_wav, orig_sr=8000, target_sr=16000)
        sdr, sir, sar, lens = bss_eval(mix, left, right, predicts_left_wav,
                                              predicts_right_wav)
        # The running total has its own variable, so the per-file tail
        # length above can no longer overwrite the denominator.
        total_length = total_length + lens
        gnsdr = gnsdr + sdr * lens
        gsir = gsir + sir * lens
        gsar = gsar + sar * lens
        cnt += 1
        if cnt % TEST_STEP == 0:
            print('GNSDR: ', gnsdr / total_length)
            print('GSIR: ', gsir / total_length)
            print('GSAR: ', gsar / total_length)

        if cnt == TOTAL_TEST:
            break


# Checkpoint directory for the 1-stack model; exist_ok lets the cell be re-run.
os.makedirs('Model1', exist_ok=True)

# New Section

In [None]:
# Train the 1-stack model; train() writes its checkpoints under Model1/.
train()

  3%|▎         | 52/1500 [00:06<03:16,  7.38it/s]

loss of 50 is tensor([-9.1422e+33], device='cuda:0', grad_fn=<DivBackward0>)


  7%|▋         | 102/1500 [00:12<03:00,  7.73it/s]

loss of 100 is 0.007760205771774054


 10%|█         | 152/1500 [00:19<02:54,  7.72it/s]

loss of 150 is 0.006792802829295397


 13%|█▎        | 202/1500 [00:25<02:49,  7.67it/s]

loss of 200 is 0.006208804901689291


 17%|█▋        | 252/1500 [00:31<02:43,  7.63it/s]

loss of 250 is 0.005664951168000698


 20%|██        | 302/1500 [00:38<02:36,  7.65it/s]

loss of 300 is 0.005553694441914558


 23%|██▎       | 352/1500 [00:44<02:30,  7.63it/s]

loss of 350 is 0.005484114866703749


 27%|██▋       | 402/1500 [00:51<02:26,  7.50it/s]

loss of 400 is 0.005225309636443853


 30%|███       | 452/1500 [00:57<02:19,  7.52it/s]

loss of 450 is 0.005576816853135824


 33%|███▎      | 502/1500 [01:03<02:14,  7.42it/s]

loss of 500 is 0.005300813354551792


 37%|███▋      | 552/1500 [01:10<02:05,  7.53it/s]

loss of 550 is 0.0048544565215706825


 40%|████      | 602/1500 [01:16<02:00,  7.44it/s]

loss of 600 is 0.004580163862556219


 43%|████▎     | 652/1500 [01:23<01:53,  7.50it/s]

loss of 650 is 0.004732952918857336


 47%|████▋     | 702/1500 [01:30<01:47,  7.41it/s]

loss of 700 is 0.004712476395070553


 50%|█████     | 752/1500 [01:36<01:44,  7.14it/s]

loss of 750 is 0.004698533099144697


 53%|█████▎    | 802/1500 [01:43<01:33,  7.43it/s]

loss of 800 is 0.004441540688276291


 57%|█████▋    | 852/1500 [01:49<01:28,  7.35it/s]

loss of 850 is 0.0044045522809028625


 60%|██████    | 902/1500 [01:56<01:21,  7.34it/s]

loss of 900 is 0.004813037812709808


 63%|██████▎   | 952/1500 [02:03<01:14,  7.34it/s]

loss of 950 is 0.004298831801861525


 67%|██████▋   | 1002/1500 [02:09<01:08,  7.31it/s]

loss of 1000 is 0.004038430750370026


 70%|███████   | 1052/1500 [02:16<01:01,  7.30it/s]

loss of 1050 is 0.004518741276115179


 73%|███████▎  | 1102/1500 [02:23<00:54,  7.33it/s]

loss of 1100 is 0.004611147567629814


 77%|███████▋  | 1152/1500 [02:29<00:48,  7.24it/s]

loss of 1150 is 0.004254947416484356


 80%|████████  | 1202/1500 [02:36<00:40,  7.33it/s]

loss of 1200 is 0.0039976718835532665


 83%|████████▎ | 1252/1500 [02:43<00:33,  7.33it/s]

loss of 1250 is 0.003943032119423151


 87%|████████▋ | 1302/1500 [02:50<00:32,  6.06it/s]

loss of 1300 is 0.003961499780416489


 90%|█████████ | 1352/1500 [02:56<00:20,  7.30it/s]

loss of 1350 is 0.004004259593784809


 93%|█████████▎| 1402/1500 [03:03<00:13,  7.22it/s]

loss of 1400 is 0.004023166839033365


 97%|█████████▋| 1452/1500 [03:10<00:06,  7.28it/s]

loss of 1450 is 0.004071845207363367


100%|██████████| 1500/1500 [03:17<00:00,  7.61it/s]

loss of 1500 is 0.0036650157999247313





In [None]:
# Evaluate the default checkpoint (Model1/checkpoint_1500.pt) on the files not selected for training.
test()

  0%|          | 0/230 [00:00<?, ?it/s]





  9%|▊         | 20/230 [00:40<06:54,  1.97s/it]

GNSDR:  [ 9.29741592 10.42193195]
GSIR:  [12.58453554 16.66428616]
GSAR:  [12.57899469 11.90913728]


 17%|█▋        | 40/230 [01:21<06:42,  2.12s/it]

GNSDR:  [ 9.70749218 11.37613688]
GSIR:  [13.09621566 18.36202007]
GSAR:  [12.84738225 12.59647486]


 26%|██▌       | 60/230 [02:01<04:58,  1.76s/it]

GNSDR:  [10.28430351 11.68311888]
GSIR:  [13.7723883  18.17143087]
GSAR:  [13.29402505 13.05509648]


 35%|███▍      | 80/230 [02:44<05:54,  2.36s/it]

GNSDR:  [10.04313206 11.1565374 ]
GSIR:  [13.62193581 17.37148438]
GSAR:  [12.95262842 12.63493689]


 43%|████▎     | 100/230 [03:24<04:57,  2.29s/it]

GNSDR:  [10.04616764 10.98417414]
GSIR:  [13.73086824 17.11728164]
GSAR:  [12.86229138 12.50103796]


 52%|█████▏    | 120/230 [04:06<03:45,  2.05s/it]

GNSDR:  [ 9.83632429 10.59295268]
GSIR:  [13.51770058 16.48298477]
GSAR:  [12.65608534 12.24096361]


 61%|██████    | 140/230 [04:52<03:02,  2.03s/it]

GNSDR:  [ 9.70593504 10.46370629]
GSIR:  [13.3423235  16.42841189]
GSAR:  [12.56302139 12.06448865]


 70%|██████▉   | 160/230 [05:34<02:25,  2.08s/it]

GNSDR:  [ 9.36049741 10.08767785]
GSIR:  [12.88312197 16.15539319]
GSAR:  [12.34921475 11.65705907]


 78%|███████▊  | 180/230 [06:15<01:37,  1.95s/it]

GNSDR:  [9.22316346 9.93416299]
GSIR:  [12.69611405 16.05380933]
GSAR:  [12.27546722 11.49451346]


 87%|████████▋ | 200/230 [06:59<01:04,  2.15s/it]

GNSDR:  [9.14837977 9.86141017]
GSIR:  [12.64298148 16.05023346]
GSAR:  [12.1842329  11.39280715]


 96%|█████████▌| 220/230 [07:39<00:18,  1.80s/it]

GNSDR:  [9.08708454 9.79673213]
GSIR:  [12.57614819 15.98145621]
GSAR:  [12.13160477 11.32328618]


100%|██████████| 230/230 [07:58<00:00,  2.08s/it]


#STACKED HOURGLASS - 2 

In [None]:
# Rebuild the experiment with two stacked hourglass stages; get_model() reads
# NUM_STACKS at call time, so models created from here on have two stages.
NUM_STACKS = 2 
# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train(checkpoint_dir='Model2'):
    """Train a freshly built model and save periodic checkpoints.

    Runs ITERATIONS optimisation steps with Adam (lr 0.001). Each batch holds
    BATCH_SIZE random excerpts of 512 frequency bins x 64 frames. Each stage
    output of the network is multiplied with the mixture magnitude, and the
    loss is the sum over stages of the mean absolute error against the
    stacked left/right target magnitudes.

    Parameters
    ----------
    checkpoint_dir : str, optional
        Directory for ``checkpoint_<iteration>.pt`` files, written every
        TRAIN_SAVE_POINT iterations. Created if it does not exist.

    Raises
    ------
    ValueError
        If no training file is long enough to cut a single excerpt.
    """
    n_bins, n_frames = 512, 64

    net = get_model()
    net.to(device)
    training_data, training_indices_array = get_train_data()

    # Keep only start frames from which a full excerpt can be cut; a shorter
    # slice cannot be copied into the fixed-size batch tensor.
    valid_positions = [
        (file_idx, start) for file_idx, start in training_indices_array
        if start + n_frames <= training_data[file_idx][2].shape[-1]
    ]
    if not valid_positions:
        raise ValueError('no training file is at least {} frames long'.format(n_frames))

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Plain float accumulator: it starts at exactly 0 and does not keep the
    # autograd graphs of past iterations alive.
    running_loss = 0.0
    optimiser = torch.optim.Adam(net.parameters(), lr=0.001)
    for i in tqdm.tqdm(range(ITERATIONS)):
        # Build the batch on the CPU, then move it to the device once.
        mixture = torch.empty(BATCH_SIZE, 1, n_bins, n_frames)
        target = torch.empty(BATCH_SIZE, 2, n_bins, n_frames)
        for j in range(BATCH_SIZE):
            file_idx, start = valid_positions[np.random.randint(len(valid_positions))]
            left_mag, right_mag, mixed_mag = training_data[file_idx]
            mixture[j, 0, :, :] = torch.from_numpy(mixed_mag[:n_bins, start:start + n_frames])
            target[j, 0, :, :] = torch.from_numpy(left_mag[:n_bins, start:start + n_frames])
            target[j, 1, :, :] = torch.from_numpy(right_mag[:n_bins, start:start + n_frames])
        mixture = mixture.to(device)
        target = target.to(device)

        optimiser.zero_grad()
        predicted = net(mixture)

        # predicted has one entry per hourglass stage; every stage is supervised.
        loss = sum(torch.mean(torch.abs(predicted[x].mul(mixture) - target))
                   for x in range(predicted.shape[0]))
        # A mean absolute error is never negative, so this test only fails
        # for a NaN loss, in which case the update is skipped.
        if loss >= 0:
            loss.backward()
            running_loss += loss.item()
            optimiser.step()
        if (i + 1) % TRAIN_SAVE_POINT == 0:
            torch.save(net.state_dict(), '{}/checkpoint_{}.pt'.format(checkpoint_dir, i + 1))
            print("loss is {}".format(running_loss / TRAIN_SAVE_POINT))
            running_loss = 0.0

def test(model='Model2/checkpoint_1500.pt'):
    """Evaluate a checkpoint on the non-training files and print running metrics.

    For every file yielded by ``data_generator(train_d, train=False)`` the
    mixture spectrogram is processed in overlapping 64-frame windows (hop 32).
    Except at the start and end of the file, only the central 32 frames of
    each window are kept. The predicted masks are applied to the mixture
    magnitude, inverted with the mixture phase, resampled from 8 kHz to
    16 kHz and scored with ``bss_eval``. Length-weighted averages are printed
    as GNSDR / GSIR / GSAR every TEST_STEP files; evaluation stops after
    TOTAL_TEST files.

    Parameters
    ----------
    model : str, optional
        Path of the checkpoint to load.
    """
    n_bins, window, hop, margin = 512, 64, 32, 16

    models = get_model()
    # NOTE(review): strict=False silently ignores keys that do not match the
    # current architecture; confirm NUM_STACKS matches the checkpoint.
    models.load_state_dict(torch.load(model, map_location=device),strict=False)
    models.to(device)
    # Inference mode: BatchNorm uses its running statistics, so a window's
    # output does not depend on what else is in the batch buffer.
    models.eval()

    # Zero-initialised so rows that are not refreshed never hold garbage.
    batch = np.zeros((BATCH_SIZE, 1, n_bins, window), dtype=np.float32)

    def predict(batch_array):
        """Run the model and return the last stage's output as a NumPy array."""
        with torch.no_grad():
            return models(torch.from_numpy(batch_array).to(device))[-1].cpu().numpy()

    total_length = 0.
    cnt = 0
    gnsdr = 0.
    gsir = 0.
    gsar = 0.
    for left, right, mix, _, _, mix_magnitude, max_value, mix_phase in data_generator(train_d,train=False):
        source_length = mix_magnitude.shape[-1]
        if source_length < window:
            warnings.warn('skipping a test file shorter than {} frames'.format(window))
            continue
        start_ind = 0
        predicts_left = np.zeros((n_bins, source_length), dtype=np.float32)
        predicts_right = np.zeros((n_bins, source_length), dtype=np.float32)

        while start_ind + window < source_length:
            if start_ind and start_ind + (BATCH_SIZE - 1) * hop + window < source_length:
                # Enough frames remain for a full batch of overlapping windows.
                for i in range(BATCH_SIZE):
                    lo = start_ind + i * hop
                    batch[i, 0, :, :] = mix_magnitude[0:n_bins, lo:lo + window]
                output = predict(batch)
                for i in range(BATCH_SIZE):
                    lo = start_ind + i * hop
                    predicts_left[:, lo + margin: lo + window - margin] = output[i, 0, :, margin:window - margin]
                    predicts_right[:, lo + margin: lo + window - margin] = output[i, 1, :, margin:window - margin]
                start_ind += BATCH_SIZE * hop
            else:
                batch[0, 0, :, :] = mix_magnitude[0:n_bins, start_ind:start_ind + window]
                output = predict(batch)
                if start_ind == 0:
                    # The very first window is kept in full.
                    predicts_left[:, 0:window] = output[0, 0, :, :]
                    predicts_right[:, 0:window] = output[0, 1, :, :]
                else:
                    predicts_left[:, start_ind + margin: start_ind + window - margin] = output[0, 0, :, margin:window - margin]
                    predicts_right[:, start_ind + margin: start_ind + window - margin] = output[0, 1, :, margin:window - margin]
                start_ind += hop

        # Final window, aligned with the end of the file, fills the frames
        # that are still empty. If the loop never ran (file of exactly
        # `window` frames) that is the whole file.
        batch[0, 0, :, :] = mix_magnitude[0:n_bins, source_length - window:source_length]
        output = predict(batch)
        tail_start = start_ind + margin if start_ind else 0
        tail_length = source_length - tail_start
        predicts_left[:, tail_start:source_length] = output[0, 0, :, window - tail_length:window]
        predicts_right[:, tail_start:source_length] = output[0, 1, :, window - tail_length:window]

        # Clean the masks, apply them to the mixture and undo the normalisation.
        predicts_left =np.nan_to_num(predicts_left)
        predicts_right =np.nan_to_num(predicts_right)
        predicts_left[np.where(predicts_left < 0)] = 0
        predicts_right[np.where(predicts_right < 0)] = 0
        predicts_left = predicts_left * mix_magnitude[0:n_bins, :] * max_value
        predicts_right = predicts_right * mix_magnitude[0:n_bins, :] * max_value
        predicts_left_wav = to_wav(predicts_left, mix_phase[0:n_bins, :])
        predicts_right_wav = to_wav(predicts_right, mix_phase[0:n_bins, :])
        predicts_left_wav = librosa.resample(predicts_left_wav, orig_sr=8000, target_sr=16000)
        predicts_right_wav = librosa.resample(predicts_right_wav, orig_sr=8000, target_sr=16000)
        sdr, sir, sar, lens = bss_eval(mix, left, right, predicts_left_wav,
                                              predicts_right_wav)
        # The running total has its own variable, so the per-file tail
        # length above can no longer overwrite the denominator.
        total_length = total_length + lens
        gnsdr = gnsdr + sdr * lens
        gsir = gsir + sir * lens
        gsar = gsar + sar * lens
        cnt += 1
        if cnt % TEST_STEP == 0:
            print('GNSDR: ', gnsdr / total_length)
            print('GSIR: ', gsir / total_length)
            print('GSAR: ', gsar / total_length)

        if cnt == TOTAL_TEST:
            break

# Checkpoint directory for the 2-stack model; exist_ok lets the cell be re-run.
os.makedirs('Model2', exist_ok=True)

In [None]:
#TOTAL NUMBER OF TRAINABLE PARAMETERS
# Built with the NUM_STACKS value currently in effect.
net = get_model()
total_params_SH2 = 0
for p in net.parameters():
    # Frozen parameters (requires_grad == False) are not counted.
    if p.requires_grad:
        total_params_SH2 += p.numel()
print(total_params_SH2)

17475276


In [None]:
# Train the 2-stack model; train() writes its checkpoints under Model2/.
train()

  3%|▎         | 50/1500 [00:10<06:10,  3.92it/s]

loss of 50 is tensor([0.0195], device='cuda:0', grad_fn=<DivBackward0>)


  7%|▋         | 100/1500 [00:21<06:08,  3.80it/s]

loss of 100 is 0.015661755576729774


 10%|█         | 150/1500 [00:31<05:40,  3.97it/s]

loss of 150 is 0.01224567275494337


 13%|█▎        | 200/1500 [00:42<05:30,  3.93it/s]

loss of 200 is 0.011883318424224854


 17%|█▋        | 250/1500 [00:53<05:22,  3.88it/s]

loss of 250 is 0.012043341062963009


 20%|██        | 300/1500 [01:04<05:08,  3.90it/s]

loss of 300 is 0.011263801716268063


 23%|██▎       | 350/1500 [01:15<04:57,  3.87it/s]

loss of 350 is 0.010137338191270828


 27%|██▋       | 400/1500 [01:26<04:46,  3.84it/s]

loss of 400 is 0.010054368525743484


 30%|███       | 450/1500 [01:37<05:35,  3.13it/s]

loss of 450 is 0.009947076439857483


 33%|███▎      | 500/1500 [01:48<04:20,  3.83it/s]

loss of 500 is 0.010157236829400063


 37%|███▋      | 550/1500 [01:59<04:07,  3.84it/s]

loss of 550 is 0.009244334883987904


 40%|████      | 600/1500 [02:10<03:55,  3.83it/s]

loss of 600 is 0.009110729210078716


 43%|████▎     | 650/1500 [02:21<03:43,  3.80it/s]

loss of 650 is 0.009565181098878384


 47%|████▋     | 700/1500 [02:32<03:31,  3.77it/s]

loss of 700 is 0.008802094496786594


 50%|█████     | 750/1500 [02:44<03:18,  3.78it/s]

loss of 750 is 0.008578644134104252


 53%|█████▎    | 800/1500 [02:55<03:15,  3.58it/s]

loss of 800 is 0.008322295732796192


 57%|█████▋    | 850/1500 [03:06<02:51,  3.79it/s]

loss of 850 is 0.008967826142907143


 60%|██████    | 900/1500 [03:18<02:38,  3.79it/s]

loss of 900 is 0.009264418855309486


 63%|██████▎   | 950/1500 [03:29<02:25,  3.77it/s]

loss of 950 is 0.008302843198180199


 67%|██████▋   | 1000/1500 [03:40<02:15,  3.69it/s]

loss of 1000 is 0.008184941485524178


 70%|███████   | 1050/1500 [03:52<01:58,  3.79it/s]

loss of 1050 is 0.0075713531114161015


 73%|███████▎  | 1100/1500 [04:03<01:47,  3.73it/s]

loss of 1100 is 0.007960546761751175


 77%|███████▋  | 1150/1500 [04:14<01:33,  3.74it/s]

loss of 1150 is 0.00811872910708189


 80%|████████  | 1200/1500 [04:26<01:20,  3.71it/s]

loss of 1200 is 0.0080467713996768


 83%|████████▎ | 1250/1500 [04:37<01:09,  3.58it/s]

loss of 1250 is 0.007549282629042864


 87%|████████▋ | 1300/1500 [04:49<00:53,  3.75it/s]

loss of 1300 is 0.00805914681404829


 90%|█████████ | 1350/1500 [05:00<00:40,  3.75it/s]

loss of 1350 is 0.0071518453769385815


 93%|█████████▎| 1400/1500 [05:12<00:26,  3.71it/s]

loss of 1400 is 0.007266509812325239


 97%|█████████▋| 1450/1500 [05:23<00:13,  3.74it/s]

loss of 1450 is 0.007946201600134373


100%|██████████| 1500/1500 [05:34<00:00,  4.48it/s]

loss of 1500 is 0.0073224701918661594





In [None]:
# Evaluate the default checkpoint (Model2/checkpoint_1500.pt) on the files not selected for training.
test()

  9%|▊         | 20/230 [00:47<07:21,  2.10s/it]

GNSDR:  [ 9.80038162 10.70581223]
GSIR:  [13.92219722 17.90606632]
GSAR:  [12.35646391 11.90790399]


 17%|█▋        | 40/230 [01:28<06:58,  2.20s/it]

GNSDR:  [10.2807994  11.76483737]
GSIR:  [14.51503801 19.34933434]
GSAR:  [12.73278651 12.84248022]


 26%|██▌       | 60/230 [02:11<05:18,  1.88s/it]

GNSDR:  [10.81713491 12.04350591]
GSIR:  [15.09797193 19.07112211]
GSAR:  [13.18366846 13.26820841]


 35%|███▍      | 80/230 [02:55<06:06,  2.44s/it]

GNSDR:  [10.39018532 11.41933706]
GSIR:  [14.6088455  18.26053922]
GSAR:  [12.78269194 12.72324377]


 43%|████▎     | 100/230 [03:38<05:10,  2.39s/it]

GNSDR:  [10.32795374 11.24362138]
GSIR:  [14.56788984 18.01217075]
GSAR:  [12.70881958 12.56851562]


 52%|█████▏    | 120/230 [04:21<03:51,  2.10s/it]

GNSDR:  [10.1185801  10.92782104]
GSIR:  [14.27041494 17.48634681]
GSAR:  [12.56779111 12.36979024]


 61%|██████    | 140/230 [05:09<03:16,  2.18s/it]

GNSDR:  [ 9.89296359 10.73810313]
GSIR:  [13.91844425 17.45360738]
GSAR:  [12.44916856 12.13580855]


 70%|██████▉   | 160/230 [05:54<02:31,  2.16s/it]

GNSDR:  [ 9.5134581  10.36237065]
GSIR:  [13.37553193 17.23238807]
GSAR:  [12.23535929 11.71269808]


 78%|███████▊  | 180/230 [06:36<01:41,  2.04s/it]

GNSDR:  [ 9.33180972 10.13206078]
GSIR:  [13.14165889 16.96677621]
GSAR:  [12.09741074 11.51331324]


 87%|████████▋ | 200/230 [07:22<01:06,  2.23s/it]

GNSDR:  [ 9.22554019 10.02901378]
GSIR:  [13.03318357 16.88618576]
GSAR:  [11.98986417 11.39522551]


 96%|█████████▌| 220/230 [08:03<00:18,  1.83s/it]

GNSDR:  [9.14332844 9.93643997]
GSIR:  [12.92863707 16.77771608]
GSAR:  [11.92048093 11.30910882]


100%|██████████| 230/230 [08:23<00:00,  2.19s/it]


# STACKED HOURGLASS - 4

In [None]:
# Constants for the 4-stack hourglass section.
MAX_ITERATIONS = 500
NUM_STACKS = 4
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def train():
    """Train the model from get_model() and checkpoint it under ``Model4/``.

    Each iteration assembles BATCH_SIZE randomly chosen 512x64 magnitude
    windows from get_train_data(): the mixture window is the network input and
    the left/right source windows are the targets. The loss is the mean
    absolute error between ``prediction * mixture`` and the targets, summed
    over every stack output the network returns.

    Every TRAIN_SAVE_POINT iterations the weights are saved and the mean loss
    since the previous save is printed.

    Relies on the notebook globals MAX_ITERATIONS, BATCH_SIZE,
    TRAIN_SAVE_POINT and device.
    """
    checkpoint_dir = 'Model4'
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    net = get_model()
    net.to(device)
    training_data, training_indices_array = get_train_data()
    optimiser = torch.optim.Adam(net.parameters(), lr=0.001)

    # Plain Python accumulators: torch.empty(1) started from uninitialised
    # memory, and adding the loss tensor itself kept every autograd graph alive.
    running_loss = 0.0
    counted_steps = 0

    # The section's MAX_ITERATIONS constant bounds the run; the previous loop
    # bound was an ITERATIONS name that this section never defines.
    for i in tqdm.tqdm(range(MAX_ITERATIONS)):
        # Build the batch on the CPU and move it to the device once.
        mixture = torch.empty(BATCH_SIZE, 1, 512, 64)
        targets = torch.empty(BATCH_SIZE, 2, 512, 64)
        for j in range(BATCH_SIZE):
            indexing, starting_index = training_indices_array[np.random.randint(len(training_indices_array))]
            left_mag, right_mag, mixed_mag = training_data[indexing]
            frames = slice(starting_index, starting_index + 64)
            mixture[j, 0, :, :] = torch.from_numpy(mixed_mag[:512, frames])
            targets[j, 0, :, :] = torch.from_numpy(left_mag[:512, frames])
            targets[j, 1, :, :] = torch.from_numpy(right_mag[:512, frames])
        mixture = mixture.to(device)
        targets = targets.to(device)

        optimiser.zero_grad()
        predicted = net(mixture)

        # Supervise every stack output.
        # assumes net returns the per-stack masks stacked along dim 0 — TODO confirm against get_model()
        loss = sum(torch.mean(torch.abs(stack_pred.mul(mixture) - targets)) for stack_pred in predicted)

        # Skip the update when the loss is NaN/inf so the weights are not corrupted.
        if torch.isfinite(loss):
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            counted_steps += 1

        if (i + 1) % TRAIN_SAVE_POINT == 0:
            torch.save(net.state_dict(), os.path.join(checkpoint_dir, 'checkpoint_{}.pt'.format(i + 1)))
            mean_loss = running_loss / counted_steps if counted_steps else float('nan')
            print("loss is {}".format(mean_loss))
            running_loss = 0.0
            counted_steps = 0
def test(model='Model4/checkpoint_500.pt'):
    """Evaluate a saved checkpoint and print running separation metrics.

    For every item yielded by ``data_generator(train_d, train=False)`` the
    mixture magnitude is processed in 64-frame windows with a hop of 32
    frames. Only the central 32 frames of each window are kept, except for the
    first window (kept whole) and the final window (kept up to the end of the
    track). The predicted masks are clipped at zero, applied to the mixture
    magnitude, rescaled by ``max_value``, converted to audio with to_wav(),
    resampled from 8 kHz to 16 kHz and scored with bss_eval().

    GNSDR / GSIR / GSAR are printed every TEST_STEP tracks as averages
    weighted by the ``lens`` value returned by bss_eval(). Evaluation stops
    after TOTAL_TEST tracks.

    Args:
        model: Path of the state-dict checkpoint to load.
    """
    models = get_model()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    load_report = models.load_state_dict(torch.load(model, map_location=device), strict=False)
    # strict=False skips mismatched parameter names silently, so report them.
    if load_report.missing_keys or load_report.unexpected_keys:
        warnings.warn(
            "checkpoint {} does not fully match the model: missing={}, unexpected={}".format(
                model, load_report.missing_keys, load_report.unexpected_keys))
    models.to(device)
    # Inference mode for BatchNorm/Dropout layers, if the model has any.
    models.eval()

    def predict_masks(batch):
        """Run a float32 numpy batch through the model and return the last stack output."""
        with torch.no_grad():
            # assumes the model output is indexed by stack on its first axis — TODO confirm
            return models(torch.from_numpy(batch).to(device))[-1].cpu().numpy()

    # Zero-initialised so no uninitialised memory can ever reach the network.
    windows = np.zeros((BATCH_SIZE, 1, 512, 64), dtype=np.float32)
    total_length = 0.
    cnt = 0
    gnsdr = 0.
    gsir = 0.
    gsar = 0.
    for left, right, mix, left_magnitude, right_magnitude, mix_magnitude, max_value, mix_phase in data_generator(train_d,train=False):
        source_length = mix_magnitude.shape[-1]
        start_ind = 0
        predicts_left = np.zeros((512, source_length), dtype=np.float32)
        predicts_right = np.zeros((512, source_length), dtype=np.float32)

        while start_ind + 64 < source_length:
            if start_ind and start_ind + (BATCH_SIZE - 1) * 32 + 64 < source_length:
                # A full batch of overlapping windows fits before the end of the track.
                for i in range(BATCH_SIZE):
                    first = start_ind + i * 32
                    windows[i, 0, :, :] = mix_magnitude[0:512, first:first + 64]
                output = predict_masks(windows)
                for i in range(BATCH_SIZE):
                    first = start_ind + i * 32
                    predicts_left[:, first + 16:first + 48] = output[i, 0, :, 16:48]
                    predicts_right[:, first + 16:first + 48] = output[i, 1, :, 16:48]
                start_ind += BATCH_SIZE * 32
            else:
                # Single window: feed only the row that was actually filled.
                windows[0, 0, :, :] = mix_magnitude[0:512, start_ind:start_ind + 64]
                output = predict_masks(windows[:1])
                if start_ind == 0:
                    predicts_left[:, 0:64] = output[0, 0, :, :]
                    predicts_right[:, 0:64] = output[0, 1, :, :]
                else:
                    predicts_left[:, start_ind + 16:start_ind + 48] = output[0, 0, :, 16:48]
                    predicts_right[:, start_ind + 16:start_ind + 48] = output[0, 1, :, 16:48]
                start_ind += 32

        # Final window, aligned to the end of the track, fills the remaining frames.
        windows[0, 0, :, :] = mix_magnitude[0:512, source_length - 64:source_length]
        output = predict_masks(windows[:1])
        # Separate name: reusing the running total here used to reset the
        # denominator of the weighted averages on every track.
        tail_len = source_length - start_ind - 16
        predicts_left[:, start_ind + 16:source_length] = output[0, 0, :, 64 - tail_len:64]
        predicts_right[:, start_ind + 16:source_length] = output[0, 1, :, 64 - tail_len:64]

        # Replace NaNs and clip negative mask values before applying the masks.
        predicts_left = np.maximum(np.nan_to_num(predicts_left), 0)
        predicts_right = np.maximum(np.nan_to_num(predicts_right), 0)
        predicts_left = predicts_left * mix_magnitude[0:512, :] * max_value
        predicts_right = predicts_right * mix_magnitude[0:512, :] * max_value
        predicts_left_wav = to_wav(predicts_left, mix_phase[0:512, :])
        predicts_right_wav = to_wav(predicts_right, mix_phase[0:512, :])
        # Keyword arguments: recent librosa releases reject positional sample rates.
        predicts_left_wav = librosa.resample(predicts_left_wav, orig_sr=8000, target_sr=16000)
        predicts_right_wav = librosa.resample(predicts_right_wav, orig_sr=8000, target_sr=16000)
        sdr, sir, sar, lens = bss_eval(mix, left, right, predicts_left_wav,
                                       predicts_right_wav)
        total_length = total_length + lens
        gnsdr = gnsdr + sdr * lens
        gsir = gsir + sir * lens
        gsar = gsar + sar * lens
        cnt += 1
        if cnt % TEST_STEP == 0:
            print('GNSDR: ', gnsdr / total_length)
            print('GSIR: ', gsir / total_length)
            print('GSAR: ', gsar / total_length)

        if cnt == TOTAL_TEST:
            break

In [None]:
# Count the trainable parameters of the model returned by get_model().
net = get_model()
total_params_SH4 = sum(
    param.numel() for param in net.parameters() if param.requires_grad
)
print(total_params_SH4)

34689752


In [None]:
# Run train() as defined in this section; it saves checkpoints under Model4/.
train()

 10%|█         | 50/500 [00:19<03:44,  2.01it/s]

loss of 50 is tensor([0.0396], device='cuda:0', grad_fn=<DivBackward0>)


 20%|██        | 100/500 [00:38<03:19,  2.01it/s]

loss of 100 is 0.030864646658301353


 30%|███       | 150/500 [00:58<02:53,  2.02it/s]

loss of 150 is 0.027471253648400307


 40%|████      | 200/500 [01:18<02:31,  1.98it/s]

loss of 200 is 0.02457033284008503


 50%|█████     | 250/500 [01:38<02:05,  2.00it/s]

loss of 250 is 0.02325206995010376


 60%|██████    | 300/500 [01:58<01:41,  1.97it/s]

loss of 300 is 0.022393090650439262


 70%|███████   | 350/500 [02:18<01:17,  1.94it/s]

loss of 350 is 0.02138429880142212


 80%|████████  | 400/500 [02:38<00:51,  1.95it/s]

loss of 400 is 0.02080450765788555


 90%|█████████ | 450/500 [02:59<00:25,  1.94it/s]

loss of 450 is 0.020517243072390556


100%|██████████| 500/500 [03:19<00:00,  2.50it/s]

loss of 500 is 0.021114736795425415





In [None]:
# Run test() as defined in this section; its default checkpoint is Model4/checkpoint_500.pt.
test()

  9%|▊         | 20/230 [00:46<07:53,  2.26s/it]

GNSDR:  [8.77572573 9.35972856]
GSIR:  [13.09098554 14.54126495]
GSAR:  [11.25119801 11.35620464]


 17%|█▋        | 40/230 [01:30<07:33,  2.39s/it]

GNSDR:  [ 9.19988281 10.31610165]
GSIR:  [13.44184313 16.37196536]
GSAR:  [11.69999143 11.97317494]


 26%|██▌       | 60/230 [02:16<05:45,  2.03s/it]

GNSDR:  [ 9.66853114 10.40482768]
GSIR:  [14.18354908 15.8919944 ]
GSAR:  [11.99346434 12.26446479]


 35%|███▍      | 80/230 [03:04<06:39,  2.67s/it]

GNSDR:  [ 9.27616072 10.01407079]
GSIR:  [13.51103944 15.68836142]
GSAR:  [11.81366642 11.85461118]


 43%|████▎     | 100/230 [03:49<05:32,  2.56s/it]

GNSDR:  [9.15603459 9.88720403]
GSIR:  [13.27621983 15.70383453]
GSAR:  [11.79036806 11.6557938 ]


 52%|█████▏    | 120/230 [04:35<04:09,  2.27s/it]

GNSDR:  [8.931424   9.58234537]
GSIR:  [12.98708057 15.2968552 ]
GSAR:  [11.62044893 11.41389723]


 61%|██████    | 140/230 [05:26<03:25,  2.29s/it]

GNSDR:  [8.69864192 9.42148711]
GSIR:  [12.58665004 15.39792937]
GSAR:  [11.52714384 11.16580782]


 70%|██████▉   | 160/230 [06:14<02:44,  2.34s/it]

GNSDR:  [8.37893851 9.06931373]
GSIR:  [12.13934455 15.18427573]
GSAR:  [11.33023557 10.75796063]


 78%|███████▊  | 180/230 [07:00<01:50,  2.20s/it]

GNSDR:  [8.21373037 8.89284431]
GSIR:  [11.91034135 15.04945253]
GSAR:  [11.22728797 10.58347742]


 87%|████████▋ | 200/230 [07:50<01:11,  2.40s/it]

GNSDR:  [8.13190496 8.81071897]
GSIR:  [11.86577725 14.99643082]
GSAR:  [11.10396898 10.47688817]


 96%|█████████▌| 220/230 [08:34<00:19,  1.97s/it]

GNSDR:  [8.0546265  8.74925667]
GSIR:  [11.76520215 14.95189058]
GSAR:  [11.04173583 10.40547879]


100%|██████████| 230/230 [08:56<00:00,  2.33s/it]


# CNN



In [None]:
class CNN(nn.Module):
    """Baseline network: five Conv blocks applied one after another.

    The blocks are created with the argument pairs (1, 64), (64, 128),
    (128, 128), (128, 256) and (256, 2) — presumably input/output channel
    counts; confirm against the Conv definition. Only the first block
    overrides Conv's defaults (kernel_size=7, stride=1).
    """

    def __init__(self):
        super().__init__()
        layers = [
            Conv(1, 64, kernel_size=7, stride=1),
            Conv(64, 128),
            Conv(128, 128),
            Conv(128, 256),
            Conv(256, 2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the blocks in order and return the result."""
        return self.model(x)

def get_model():
    """Construct and return a new CNN instance."""
    model = CNN()
    return model

In [None]:
# Constants for the CNN baseline section.
MAX_ITERATIONS = 500
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def train():
    """Train the model from get_model() and checkpoint it under ``ModelCNN/``.

    Each iteration assembles BATCH_SIZE randomly chosen 512x64 magnitude
    windows from get_train_data(): the mixture window is the network input and
    the left/right source windows are the targets. The loss is the mean
    absolute error between ``prediction * mixture`` and the targets.

    Every TRAIN_SAVE_POINT iterations the weights are saved and the mean loss
    since the previous save is printed.

    Relies on the notebook globals MAX_ITERATIONS, BATCH_SIZE,
    TRAIN_SAVE_POINT and device.
    """
    checkpoint_dir = 'ModelCNN'
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    net = get_model()
    net.to(device)
    training_data, training_indices_array = get_train_data()
    optimiser = torch.optim.Adam(net.parameters(), lr=0.001)

    # Plain Python accumulators: torch.empty(1) started from uninitialised
    # memory, and adding the loss tensor itself kept every autograd graph alive.
    running_loss = 0.0
    counted_steps = 0

    # The section's MAX_ITERATIONS constant bounds the run; the previous loop
    # bound was an ITERATIONS name that this section never defines.
    for i in tqdm.tqdm(range(MAX_ITERATIONS)):
        # Build the batch on the CPU and move it to the device once.
        mixture = torch.empty(BATCH_SIZE, 1, 512, 64)
        targets = torch.empty(BATCH_SIZE, 2, 512, 64)
        for j in range(BATCH_SIZE):
            indexing, starting_index = training_indices_array[np.random.randint(len(training_indices_array))]
            left_mag, right_mag, mixed_mag = training_data[indexing]
            frames = slice(starting_index, starting_index + 64)
            mixture[j, 0, :, :] = torch.from_numpy(mixed_mag[:512, frames])
            targets[j, 0, :, :] = torch.from_numpy(left_mag[:512, frames])
            targets[j, 1, :, :] = torch.from_numpy(right_mag[:512, frames])
        mixture = mixture.to(device)
        targets = targets.to(device)

        optimiser.zero_grad()
        predicted = net(mixture)

        # CNN.forward returns a single tensor, so there is no stack axis:
        # indexing predicted[x] would pick batch sample x and broadcast its
        # masks over the whole batch.
        # assumes the output is (BATCH_SIZE, 2, 512, 64), i.e. Conv keeps the spatial size — TODO confirm
        loss = torch.mean(torch.abs(predicted.mul(mixture) - targets))

        # Skip the update when the loss is NaN/inf so the weights are not corrupted.
        if torch.isfinite(loss):
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            counted_steps += 1

        if (i + 1) % TRAIN_SAVE_POINT == 0:
            torch.save(net.state_dict(), os.path.join(checkpoint_dir, 'checkpoint_{}.pt'.format(i + 1)))
            mean_loss = running_loss / counted_steps if counted_steps else float('nan')
            print("loss is {}".format(mean_loss))
            running_loss = 0.0
            counted_steps = 0
def test(model='ModelCNN/checkpoint_500.pt'):
    """Evaluate a saved CNN checkpoint and print running separation metrics.

    For every item yielded by ``data_generator(train_d, train=False)`` the
    mixture magnitude is processed in 64-frame windows with a hop of 32
    frames. Only the central 32 frames of each window are kept, except for the
    first window (kept whole) and the final window (kept up to the end of the
    track). The predicted masks are clipped at zero, applied to the mixture
    magnitude, rescaled by ``max_value``, converted to audio with to_wav(),
    resampled from 8 kHz to 16 kHz and scored with bss_eval().

    GNSDR / GSIR / GSAR are printed every TEST_STEP tracks as averages
    weighted by the ``lens`` value returned by bss_eval(). Evaluation stops
    after TOTAL_TEST tracks.

    Args:
        model: Path of the state-dict checkpoint to load.
    """
    models = get_model()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    load_report = models.load_state_dict(torch.load(model, map_location=device), strict=False)
    # strict=False skips mismatched parameter names silently, so report them.
    if load_report.missing_keys or load_report.unexpected_keys:
        warnings.warn(
            "checkpoint {} does not fully match the model: missing={}, unexpected={}".format(
                model, load_report.missing_keys, load_report.unexpected_keys))
    models.to(device)
    # Inference mode for BatchNorm/Dropout layers, if the model has any.
    models.eval()

    def predict_masks(batch):
        """Run a float32 numpy batch through the model and return its output as numpy."""
        with torch.no_grad():
            # CNN.forward returns one tensor for the whole batch, so it is used
            # as-is; taking [-1] would select the last batch sample instead.
            # assumes the output is (batch, 2, 512, 64) — TODO confirm against Conv
            return models(torch.from_numpy(batch).to(device)).cpu().numpy()

    # Zero-initialised so no uninitialised memory can ever reach the network.
    windows = np.zeros((BATCH_SIZE, 1, 512, 64), dtype=np.float32)
    total_length = 0.
    cnt = 0
    gnsdr = 0.
    gsir = 0.
    gsar = 0.
    for left, right, mix, left_magnitude, right_magnitude, mix_magnitude, max_value, mix_phase in data_generator(train_d,train=False):
        source_length = mix_magnitude.shape[-1]
        start_ind = 0
        predicts_left = np.zeros((512, source_length), dtype=np.float32)
        predicts_right = np.zeros((512, source_length), dtype=np.float32)

        while start_ind + 64 < source_length:
            if start_ind and start_ind + (BATCH_SIZE - 1) * 32 + 64 < source_length:
                # A full batch of overlapping windows fits before the end of the track.
                for i in range(BATCH_SIZE):
                    first = start_ind + i * 32
                    windows[i, 0, :, :] = mix_magnitude[0:512, first:first + 64]
                output = predict_masks(windows)
                for i in range(BATCH_SIZE):
                    first = start_ind + i * 32
                    predicts_left[:, first + 16:first + 48] = output[i, 0, :, 16:48]
                    predicts_right[:, first + 16:first + 48] = output[i, 1, :, 16:48]
                start_ind += BATCH_SIZE * 32
            else:
                # Single window: feed only the row that was actually filled.
                windows[0, 0, :, :] = mix_magnitude[0:512, start_ind:start_ind + 64]
                output = predict_masks(windows[:1])
                if start_ind == 0:
                    predicts_left[:, 0:64] = output[0, 0, :, :]
                    predicts_right[:, 0:64] = output[0, 1, :, :]
                else:
                    predicts_left[:, start_ind + 16:start_ind + 48] = output[0, 0, :, 16:48]
                    predicts_right[:, start_ind + 16:start_ind + 48] = output[0, 1, :, 16:48]
                start_ind += 32

        # Final window, aligned to the end of the track, fills the remaining frames.
        windows[0, 0, :, :] = mix_magnitude[0:512, source_length - 64:source_length]
        output = predict_masks(windows[:1])
        # Separate name: reusing the running total here used to reset the
        # denominator of the weighted averages on every track.
        tail_len = source_length - start_ind - 16
        predicts_left[:, start_ind + 16:source_length] = output[0, 0, :, 64 - tail_len:64]
        predicts_right[:, start_ind + 16:source_length] = output[0, 1, :, 64 - tail_len:64]

        # Replace NaNs and clip negative mask values before applying the masks.
        predicts_left = np.maximum(np.nan_to_num(predicts_left), 0)
        predicts_right = np.maximum(np.nan_to_num(predicts_right), 0)
        predicts_left = predicts_left * mix_magnitude[0:512, :] * max_value
        predicts_right = predicts_right * mix_magnitude[0:512, :] * max_value
        predicts_left_wav = to_wav(predicts_left, mix_phase[0:512, :])
        predicts_right_wav = to_wav(predicts_right, mix_phase[0:512, :])
        # Keyword arguments: recent librosa releases reject positional sample rates.
        predicts_left_wav = librosa.resample(predicts_left_wav, orig_sr=8000, target_sr=16000)
        predicts_right_wav = librosa.resample(predicts_right_wav, orig_sr=8000, target_sr=16000)
        sdr, sir, sar, lens = bss_eval(mix, left, right, predicts_left_wav,
                                       predicts_right_wav)
        total_length = total_length + lens
        gnsdr = gnsdr + sdr * lens
        gsir = gsir + sir * lens
        gsar = gsar + sar * lens
        cnt += 1
        if cnt % TEST_STEP == 0:
            print('GNSDR: ', gnsdr / total_length)
            print('GSIR: ', gsir / total_length)
            print('GSAR: ', gsar / total_length)

        if cnt == TOTAL_TEST:
            break

# Create the checkpoint directory; exist_ok keeps the cell re-runnable
# (os.mkdir raises FileExistsError on the second run).
os.makedirs('ModelCNN', exist_ok=True)

In [None]:
# Count the trainable parameters of the model returned by get_model().
net = get_model()
total_params_CNN = sum(
    param.numel() for param in net.parameters() if param.requires_grad
)
print(total_params_CNN)

In [None]:
# Run train() as defined in this section; it saves checkpoints under ModelCNN/.
train()

 10%|█         | 51/500 [00:22<03:09,  2.37it/s]

loss of 50 is tensor([5.0581e+17], device='cuda:0', grad_fn=<DivBackward0>)


 20%|██        | 101/500 [00:44<02:45,  2.41it/s]

loss of 100 is 0.04187924414873123


 30%|███       | 151/500 [01:05<02:25,  2.40it/s]

loss of 150 is 0.040555041283369064


 40%|████      | 201/500 [01:27<02:03,  2.42it/s]

loss of 200 is 0.03867791220545769


 50%|█████     | 251/500 [01:49<01:43,  2.41it/s]

loss of 250 is 0.040079522877931595


 60%|██████    | 301/500 [02:10<01:22,  2.40it/s]

loss of 300 is 0.03759855031967163


 70%|███████   | 351/500 [02:32<01:01,  2.42it/s]

loss of 350 is 0.03719960153102875


 80%|████████  | 401/500 [02:54<00:40,  2.43it/s]

loss of 400 is 0.03653066232800484


 90%|█████████ | 451/500 [03:15<00:20,  2.42it/s]

loss of 450 is 0.037184394896030426


100%|██████████| 500/500 [03:37<00:00,  2.30it/s]

loss of 500 is 0.03388440981507301





In [None]:
# Run test() as defined in this section; its default checkpoint is ModelCNN/checkpoint_500.pt.
test()


  0%|          | 0/230 [00:00<?, ?it/s][A
  0%|          | 1/230 [00:02<11:21,  2.98s/it][A
  1%|          | 2/230 [00:05<10:31,  2.77s/it][A
  1%|▏         | 3/230 [00:08<10:15,  2.71s/it][A
  2%|▏         | 4/230 [00:12<11:56,  3.17s/it][A
  2%|▏         | 5/230 [00:16<13:25,  3.58s/it][A
  3%|▎         | 6/230 [00:20<14:14,  3.82s/it][A
  3%|▎         | 7/230 [00:23<13:05,  3.52s/it][A
  3%|▎         | 8/230 [00:26<12:24,  3.35s/it][A
  4%|▍         | 9/230 [00:29<11:34,  3.14s/it][A
  4%|▍         | 10/230 [00:32<11:15,  3.07s/it][A
  5%|▍         | 11/230 [00:36<12:09,  3.33s/it][A
  5%|▌         | 12/230 [00:39<11:40,  3.21s/it][A
  6%|▌         | 13/230 [00:42<12:22,  3.42s/it][A
  6%|▌         | 14/230 [00:45<11:39,  3.24s/it][A
  7%|▋         | 15/230 [00:48<11:13,  3.13s/it][A
  7%|▋         | 16/230 [00:51<10:34,  2.97s/it][A
  7%|▋         | 17/230 [00:54<10:30,  2.96s/it][A
  8%|▊         | 18/230 [00:56<09:59,  2.83s/it][A
  8%|▊         | 19/230 [01:0

GNSDR:  [-5.87881972  0.47665181]
GSIR:  [-0.83845632  3.30814261]
GSAR:  [-0.25534847  5.94358982]



  9%|▉         | 21/230 [01:07<10:56,  3.14s/it][A
 10%|▉         | 22/230 [01:11<11:40,  3.37s/it][A
 10%|█         | 23/230 [01:13<11:05,  3.22s/it][A
 10%|█         | 24/230 [01:16<10:42,  3.12s/it][A
 11%|█         | 25/230 [01:19<10:23,  3.04s/it][A
 11%|█▏        | 26/230 [01:22<10:11,  3.00s/it][A
 12%|█▏        | 27/230 [01:26<11:26,  3.38s/it][A
 12%|█▏        | 28/230 [01:29<10:37,  3.16s/it][A
 13%|█▎        | 29/230 [01:32<10:15,  3.06s/it][A
 13%|█▎        | 30/230 [01:34<08:59,  2.70s/it][A
 13%|█▎        | 31/230 [01:37<09:06,  2.75s/it][A
 14%|█▍        | 32/230 [01:40<10:08,  3.08s/it][A
 14%|█▍        | 33/230 [01:43<10:01,  3.05s/it][A
 15%|█▍        | 34/230 [01:47<10:36,  3.25s/it][A
 15%|█▌        | 35/230 [01:51<11:10,  3.44s/it][A
 16%|█▌        | 36/230 [01:54<10:16,  3.18s/it][A
 16%|█▌        | 37/230 [01:56<09:38,  3.00s/it][A
 17%|█▋        | 38/230 [01:59<09:26,  2.95s/it][A
 17%|█▋        | 39/230 [02:01<08:55,  2.80s/it][A
 17%|█▋    

GNSDR:  [-5.5543744   0.08443792]
GSIR:  [-0.27965419  3.05156787]
GSAR:  [-0.17814215  5.73585023]



 18%|█▊        | 41/230 [02:07<08:43,  2.77s/it][A
 18%|█▊        | 42/230 [02:09<08:28,  2.70s/it][A
 19%|█▊        | 43/230 [02:14<09:48,  3.15s/it][A
 19%|█▉        | 44/230 [02:16<09:11,  2.97s/it][A
 20%|█▉        | 45/230 [02:20<10:17,  3.34s/it][A
 20%|██        | 46/230 [02:24<10:59,  3.58s/it][A
 20%|██        | 47/230 [02:29<11:24,  3.74s/it][A
 21%|██        | 48/230 [02:33<11:41,  3.85s/it][A
 21%|██▏       | 49/230 [02:35<10:24,  3.45s/it][A
 22%|██▏       | 50/230 [02:38<09:29,  3.16s/it][A
 22%|██▏       | 51/230 [02:41<10:00,  3.35s/it][A
 23%|██▎       | 52/230 [02:45<09:40,  3.26s/it][A
 23%|██▎       | 53/230 [02:49<10:26,  3.54s/it][A
 23%|██▎       | 54/230 [02:51<09:25,  3.21s/it][A
 24%|██▍       | 55/230 [02:54<08:41,  2.98s/it][A
 24%|██▍       | 56/230 [02:56<08:11,  2.82s/it][A
 25%|██▍       | 57/230 [02:59<08:13,  2.85s/it][A
 25%|██▌       | 58/230 [03:02<07:56,  2.77s/it][A
 26%|██▌       | 59/230 [03:06<09:01,  3.17s/it][A
 26%|██▌   

GNSDR:  [-6.0262153   0.68109619]
GSIR:  [-1.03543039  3.88728926]
GSAR:  [-0.14308602  5.81004849]



 27%|██▋       | 61/230 [03:12<09:18,  3.31s/it][A
 27%|██▋       | 62/230 [03:15<08:35,  3.07s/it][A
 27%|██▋       | 63/230 [03:17<08:05,  2.91s/it][A
 28%|██▊       | 64/230 [03:21<08:50,  3.19s/it][A
 28%|██▊       | 65/230 [03:24<08:08,  2.96s/it][A
 29%|██▊       | 66/230 [03:26<07:41,  2.81s/it][A
 29%|██▉       | 67/230 [03:30<08:47,  3.24s/it][A
 30%|██▉       | 68/230 [03:34<09:26,  3.50s/it][A
 30%|███       | 69/230 [03:37<08:52,  3.31s/it][A
 30%|███       | 70/230 [03:40<08:17,  3.11s/it][A
 31%|███       | 71/230 [03:43<08:08,  3.07s/it][A
 31%|███▏      | 72/230 [03:47<08:43,  3.31s/it][A
 32%|███▏      | 73/230 [03:49<08:07,  3.10s/it][A
 32%|███▏      | 74/230 [03:53<08:40,  3.33s/it][A
 33%|███▎      | 75/230 [03:56<08:21,  3.24s/it][A
 33%|███▎      | 76/230 [03:59<08:06,  3.16s/it][A
 33%|███▎      | 77/230 [04:02<07:37,  2.99s/it][A
 34%|███▍      | 78/230 [04:06<08:15,  3.26s/it][A
 34%|███▍      | 79/230 [04:09<08:00,  3.18s/it][A
 35%|███▍  

GNSDR:  [-6.05540204  0.97843022]
GSIR:  [-1.16653897  4.3620058 ]
GSAR:  [0.04702102 5.82263182]



 35%|███▌      | 81/230 [04:16<08:17,  3.34s/it][A
 36%|███▌      | 82/230 [04:18<07:57,  3.22s/it][A
 36%|███▌      | 83/230 [04:23<08:33,  3.49s/it][A
 37%|███▋      | 84/230 [04:26<08:46,  3.61s/it][A
 37%|███▋      | 85/230 [04:29<08:00,  3.32s/it][A
 37%|███▋      | 86/230 [04:32<07:28,  3.11s/it][A
 38%|███▊      | 87/230 [04:34<07:05,  2.98s/it][A
 38%|███▊      | 88/230 [04:38<07:41,  3.25s/it][A
 39%|███▊      | 89/230 [04:41<07:30,  3.19s/it][A
 39%|███▉      | 90/230 [04:44<07:18,  3.13s/it][A
 40%|███▉      | 91/230 [04:48<07:44,  3.34s/it][A
 40%|████      | 92/230 [04:52<08:00,  3.48s/it][A
 40%|████      | 93/230 [04:56<08:26,  3.70s/it][A
 41%|████      | 94/230 [04:59<07:37,  3.37s/it][A
 41%|████▏     | 95/230 [05:01<07:03,  3.14s/it][A
 42%|████▏     | 96/230 [05:04<06:50,  3.06s/it][A
 42%|████▏     | 97/230 [05:07<06:32,  2.95s/it][A
 43%|████▎     | 98/230 [05:11<07:23,  3.36s/it][A
 43%|████▎     | 99/230 [05:15<07:53,  3.62s/it][A
 43%|████▎ 

GNSDR:  [-6.31055748  1.18041206]
GSIR:  [-1.63136613  4.6318743 ]
GSAR:  [0.19871622 5.86222172]



 44%|████▍     | 101/230 [05:22<07:24,  3.45s/it][A
 44%|████▍     | 102/230 [05:24<06:45,  3.17s/it][A
 45%|████▍     | 103/230 [05:27<06:19,  2.99s/it][A
 45%|████▌     | 104/230 [05:29<05:59,  2.85s/it][A
 46%|████▌     | 105/230 [05:33<06:32,  3.14s/it][A
 46%|████▌     | 106/230 [05:37<07:14,  3.51s/it][A
 47%|████▋     | 107/230 [05:42<07:37,  3.72s/it][A
 47%|████▋     | 108/230 [05:44<06:49,  3.36s/it][A
 47%|████▋     | 109/230 [05:47<06:28,  3.21s/it][A
 48%|████▊     | 110/230 [05:50<06:01,  3.01s/it][A
 48%|████▊     | 111/230 [05:54<06:40,  3.36s/it][A
 49%|████▊     | 112/230 [05:56<06:06,  3.11s/it][A
 49%|████▉     | 113/230 [05:59<05:54,  3.03s/it][A
 50%|████▉     | 114/230 [06:03<06:33,  3.39s/it][A
 50%|█████     | 115/230 [06:06<06:12,  3.24s/it][A
 50%|█████     | 116/230 [06:09<05:57,  3.14s/it][A
 51%|█████     | 117/230 [06:13<06:18,  3.35s/it][A
 51%|█████▏    | 118/230 [06:16<06:00,  3.22s/it][A
 52%|█████▏    | 119/230 [06:18<05:35,  3.02s

GNSDR:  [-6.35204717  1.07761307]
GSIR:  [-1.97442281  4.81791182]
GSAR:  [0.69536917 5.53295237]



 53%|█████▎    | 121/230 [06:24<05:08,  2.83s/it][A
 53%|█████▎    | 122/230 [06:28<05:42,  3.17s/it][A
 53%|█████▎    | 123/230 [06:31<05:30,  3.09s/it][A
 54%|█████▍    | 124/230 [06:34<05:18,  3.00s/it][A
 54%|█████▍    | 125/230 [06:36<05:12,  2.97s/it][A
 55%|█████▍    | 126/230 [06:39<04:53,  2.83s/it][A
 55%|█████▌    | 127/230 [06:43<05:34,  3.25s/it][A
 56%|█████▌    | 128/230 [06:46<05:21,  3.16s/it][A
 56%|█████▌    | 129/230 [06:49<04:58,  2.96s/it][A
 57%|█████▋    | 130/230 [06:51<04:41,  2.82s/it][A
 57%|█████▋    | 131/230 [06:55<05:08,  3.12s/it][A
 57%|█████▋    | 132/230 [06:59<05:29,  3.36s/it][A
 58%|█████▊    | 133/230 [07:01<05:00,  3.10s/it][A
 58%|█████▊    | 134/230 [07:05<05:18,  3.32s/it][A
 59%|█████▊    | 135/230 [07:08<04:51,  3.06s/it][A
 59%|█████▉    | 136/230 [07:11<04:44,  3.03s/it][A
 60%|█████▉    | 137/230 [07:13<04:27,  2.88s/it][A
 60%|██████    | 138/230 [07:17<04:52,  3.18s/it][A
 60%|██████    | 139/230 [07:21<05:17,  3.48s

GNSDR:  [-6.35106487  0.96798518]
GSIR:  [-2.16544579  4.98588949]
GSAR:  [1.00192306 5.23244392]



 61%|██████▏   | 141/230 [07:27<04:58,  3.36s/it][A
 62%|██████▏   | 142/230 [07:32<05:18,  3.62s/it][A
 62%|██████▏   | 143/230 [07:35<05:21,  3.69s/it][A
 63%|██████▎   | 144/230 [07:38<04:48,  3.36s/it][A
 63%|██████▎   | 145/230 [07:41<04:32,  3.20s/it][A
 63%|██████▎   | 146/230 [07:45<04:54,  3.50s/it][A
 64%|██████▍   | 147/230 [07:48<04:25,  3.20s/it][A
 64%|██████▍   | 148/230 [07:52<04:41,  3.43s/it][A
 65%|██████▍   | 149/230 [07:56<04:56,  3.66s/it][A
 65%|██████▌   | 150/230 [07:58<04:27,  3.34s/it][A
 66%|██████▌   | 151/230 [08:02<04:35,  3.49s/it][A
 66%|██████▌   | 152/230 [08:05<04:13,  3.25s/it][A
 67%|██████▋   | 153/230 [08:08<04:02,  3.15s/it][A
 67%|██████▋   | 154/230 [08:12<04:15,  3.36s/it][A
 67%|██████▋   | 155/230 [08:16<04:30,  3.61s/it][A
 68%|██████▊   | 156/230 [08:19<04:12,  3.42s/it][A
 68%|██████▊   | 157/230 [08:21<03:48,  3.13s/it][A
 69%|██████▊   | 158/230 [08:25<04:01,  3.36s/it][A
 69%|██████▉   | 159/230 [08:28<03:41,  3.11s

GNSDR:  [-6.35884967  0.83323741]
GSIR:  [-2.13854785  4.80789913]
GSAR:  [0.93055088 5.1600469 ]



 70%|███████   | 161/230 [08:34<03:48,  3.32s/it][A
 70%|███████   | 162/230 [08:37<03:27,  3.05s/it][A
 71%|███████   | 163/230 [08:41<03:41,  3.30s/it][A
 71%|███████▏  | 164/230 [08:44<03:28,  3.16s/it][A
 72%|███████▏  | 165/230 [08:48<03:45,  3.48s/it][A
 72%|███████▏  | 166/230 [08:52<03:59,  3.74s/it][A
 73%|███████▎  | 167/230 [08:55<03:32,  3.37s/it][A
 73%|███████▎  | 168/230 [08:57<03:19,  3.21s/it][A
 73%|███████▎  | 169/230 [09:02<03:33,  3.49s/it][A
 74%|███████▍  | 170/230 [09:06<03:41,  3.70s/it][A
 74%|███████▍  | 171/230 [09:10<03:41,  3.75s/it][A
 75%|███████▍  | 172/230 [09:14<03:39,  3.79s/it][A
 75%|███████▌  | 173/230 [09:16<03:13,  3.39s/it][A
 76%|███████▌  | 174/230 [09:20<03:18,  3.55s/it][A
 76%|███████▌  | 175/230 [09:23<03:02,  3.32s/it][A
 77%|███████▋  | 176/230 [09:25<02:44,  3.05s/it][A
 77%|███████▋  | 177/230 [09:28<02:39,  3.01s/it][A
 77%|███████▋  | 178/230 [09:31<02:34,  2.96s/it][A
 78%|███████▊  | 179/230 [09:34<02:29,  2.93s

GNSDR:  [-6.31167148  0.73460041]
GSIR:  [-2.03101033  4.66975966]
GSAR:  [0.84885052 5.11172873]



 79%|███████▊  | 181/230 [09:41<02:38,  3.23s/it][A
 79%|███████▉  | 182/230 [09:43<02:29,  3.12s/it][A
 80%|███████▉  | 183/230 [09:46<02:17,  2.93s/it][A
 80%|████████  | 184/230 [09:49<02:13,  2.89s/it][A
 80%|████████  | 185/230 [09:53<02:26,  3.26s/it][A
 81%|████████  | 186/230 [09:57<02:34,  3.52s/it][A
 81%|████████▏ | 187/230 [09:59<02:18,  3.22s/it][A
 82%|████████▏ | 188/230 [10:03<02:24,  3.43s/it][A
 82%|████████▏ | 189/230 [10:08<02:29,  3.65s/it][A
 83%|████████▎ | 190/230 [10:10<02:12,  3.32s/it][A
 83%|████████▎ | 191/230 [10:13<02:00,  3.08s/it][A
 83%|████████▎ | 192/230 [10:17<02:10,  3.42s/it][A
 84%|████████▍ | 193/230 [10:19<01:56,  3.15s/it][A
 84%|████████▍ | 194/230 [10:22<01:46,  2.96s/it][A
 85%|████████▍ | 195/230 [10:25<01:42,  2.94s/it][A
 85%|████████▌ | 196/230 [10:29<01:48,  3.20s/it][A
 86%|████████▌ | 197/230 [10:31<01:42,  3.11s/it][A
 86%|████████▌ | 198/230 [10:34<01:33,  2.93s/it][A
 87%|████████▋ | 199/230 [10:38<01:39,  3.20s

GNSDR:  [-6.32929306  0.78525739]
GSIR:  [-2.06630626  4.71751401]
GSAR:  [0.82717496 5.10445043]



 87%|████████▋ | 201/230 [10:45<01:32,  3.20s/it][A
 88%|████████▊ | 202/230 [10:48<01:35,  3.40s/it][A
 88%|████████▊ | 203/230 [10:52<01:35,  3.55s/it][A
 89%|████████▊ | 204/230 [10:56<01:35,  3.67s/it][A
 89%|████████▉ | 205/230 [11:00<01:33,  3.72s/it][A
 90%|████████▉ | 206/230 [11:03<01:22,  3.44s/it][A
 90%|█████████ | 207/230 [11:07<01:22,  3.57s/it][A
 90%|█████████ | 208/230 [11:10<01:13,  3.36s/it][A
 91%|█████████ | 209/230 [11:14<01:15,  3.61s/it][A
 91%|█████████▏| 210/230 [11:18<01:15,  3.79s/it][A
 92%|█████████▏| 211/230 [11:21<01:01,  3.23s/it]


# STACKED HOURGLASS 4 - 1500 ITERATIONS

In [None]:
# Constants for the 4-stack hourglass, 1500-iteration section.
# NOTE(review): the CNN section above rebinds get_model() to return CNN();
# confirm the stacked-hourglass get_model() is re-run before this section.
MAX_ITERATIONS = 1500
NUM_STACKS = 4
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def train():
    """Train the model from get_model() and checkpoint it under ``Model4(1)/``.

    Each iteration assembles BATCH_SIZE randomly chosen 512x64 magnitude
    windows from get_train_data(): the mixture window is the network input and
    the left/right source windows are the targets. The loss is the mean
    absolute error between ``prediction * mixture`` and the targets, summed
    over every stack output the network returns.

    Every TRAIN_SAVE_POINT iterations the weights are saved and the mean loss
    since the previous save is printed.

    Relies on the notebook globals MAX_ITERATIONS, BATCH_SIZE,
    TRAIN_SAVE_POINT and device.
    """
    checkpoint_dir = 'Model4(1)'
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    net = get_model()
    net.to(device)
    training_data, training_indices_array = get_train_data()
    optimiser = torch.optim.Adam(net.parameters(), lr=0.001)

    # Plain Python accumulators: torch.empty(1) started from uninitialised
    # memory, and adding the loss tensor itself kept every autograd graph alive.
    running_loss = 0.0
    counted_steps = 0

    # The section's MAX_ITERATIONS constant bounds the run; the previous loop
    # bound was an ITERATIONS name that this section never defines.
    for i in tqdm.tqdm(range(MAX_ITERATIONS)):
        # Build the batch on the CPU and move it to the device once.
        mixture = torch.empty(BATCH_SIZE, 1, 512, 64)
        targets = torch.empty(BATCH_SIZE, 2, 512, 64)
        for j in range(BATCH_SIZE):
            indexing, starting_index = training_indices_array[np.random.randint(len(training_indices_array))]
            left_mag, right_mag, mixed_mag = training_data[indexing]
            frames = slice(starting_index, starting_index + 64)
            mixture[j, 0, :, :] = torch.from_numpy(mixed_mag[:512, frames])
            targets[j, 0, :, :] = torch.from_numpy(left_mag[:512, frames])
            targets[j, 1, :, :] = torch.from_numpy(right_mag[:512, frames])
        mixture = mixture.to(device)
        targets = targets.to(device)

        optimiser.zero_grad()
        predicted = net(mixture)

        # Supervise every stack output. Previously only predicted[0] entered
        # the loss, although test() evaluates the last output ([-1]).
        # assumes net returns the per-stack masks stacked along dim 0 — TODO confirm against get_model()
        loss = sum(torch.mean(torch.abs(stack_pred.mul(mixture) - targets)) for stack_pred in predicted)

        # Skip the update when the loss is NaN/inf so the weights are not corrupted.
        if torch.isfinite(loss):
            loss.backward()
            optimiser.step()
            running_loss += loss.item()
            counted_steps += 1

        if (i + 1) % TRAIN_SAVE_POINT == 0:
            torch.save(net.state_dict(), os.path.join(checkpoint_dir, 'checkpoint_{}.pt'.format(i + 1)))
            mean_loss = running_loss / counted_steps if counted_steps else float('nan')
            print("loss is {}".format(mean_loss))
            running_loss = 0.0
            counted_steps = 0
    
def test(model='Model4(1)/checkpoint_1500.pt'):
    """Evaluate a saved checkpoint and print running separation metrics.

    For every clip yielded by ``data_generator(train_d, train=False)`` the
    mixture magnitude is processed in 64-column windows with a hop of 32.
    Only the central columns ``16:48`` of each window are kept, except that
    the first window is kept whole and a final window aligned to the end of
    the clip fills the remaining tail. The last stack output is used as a
    mask on the mixture magnitude, converted back to audio with ``to_wav``
    and the mixture phase, resampled from 8 kHz to 16 kHz, and scored with
    ``bss_eval``.

    Assumes every clip has more than 64 columns - shorter clips are not
    handled here.

    Args:
        model: Path of the ``state_dict`` checkpoint to load.

    Side effects:
        Every ``TEST_STEP`` clips, prints the length-weighted running
        GNSDR / GSIR / GSAR and writes the current predictions and mixture
        as 16 kHz PCM_16 wav files under
        ``/content/drive/MyDrive/test_wavfiles/``. Stops after ``TOTAL_TEST``
        clips.
    """
    models = get_model()
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    # NOTE(review): strict=False silently ignores missing/unexpected keys, so
    # an architecture mismatch would go unnoticed - confirm this is intended.
    models.load_state_dict(torch.load(model, map_location=device), strict=False)
    models.to(device)
    # Inference mode: layers such as batch-norm/dropout (if the model has
    # any) must not depend on whatever else happens to be in the batch.
    models.eval()

    def last_stack_output(batch):
        """Run the network without autograd; return the last stack as numpy."""
        with torch.no_grad():
            stacks = models(torch.from_numpy(batch).to(device))
        return stacks[-1].detach().cpu().numpy()

    window_batch = np.empty((BATCH_SIZE, 1, 512, 64), dtype=np.float32)
    # Cumulative weight for the averages. Kept separate from the per-clip
    # tail length, which previously overwrote it on every clip.
    total_length = 0.
    cnt = 0
    gnsdr = 0.
    gsir = 0.
    gsar = 0.
    for left, right, mix, left_magnitude, right_magnitude, mix_magnitude, max_value, mix_phase in data_generator(train_d, train=False):
        source_length = mix_magnitude.shape[-1]
        start_ind = 0
        predicts_left = np.zeros((512, source_length), dtype=np.float32)
        predicts_right = np.zeros((512, source_length), dtype=np.float32)

        while start_ind + 64 < source_length:
            # Batched path: used after the first window, as long as a full
            # batch of overlapping windows still fits inside the clip.
            if start_ind and start_ind + (BATCH_SIZE - 1) * 32 + 64 < source_length:
                for i in range(BATCH_SIZE):
                    offset = start_ind + i * 32
                    window_batch[i, 0, :, :] = mix_magnitude[0:512, offset:offset + 64]
                output = last_stack_output(window_batch)
                for i in range(BATCH_SIZE):
                    offset = start_ind + i * 32
                    predicts_left[:, offset + 16:offset + 48] = output[i, 0, :, 16:48]
                    predicts_right[:, offset + 16:offset + 48] = output[i, 1, :, 16:48]
                start_ind += BATCH_SIZE * 32
            else:
                window_batch[0, 0, :, :] = mix_magnitude[0:512, start_ind:start_ind + 64]
                # Only row 0 is meaningful; do not feed the stale rows.
                output = last_stack_output(window_batch[:1])
                if start_ind == 0:
                    # The first window has no left neighbour: keep all of it.
                    predicts_left[:, 0:64] = output[0, 0, :, :]
                    predicts_right[:, 0:64] = output[0, 1, :, :]
                else:
                    predicts_left[:, start_ind + 16:start_ind + 48] = output[0, 0, :, 16:48]
                    predicts_right[:, start_ind + 16:start_ind + 48] = output[0, 1, :, 16:48]
                start_ind += 32

        # Final window aligned to the end of the clip fills the columns that
        # the sliding windows above did not cover.
        window_batch[0, 0, :, :] = mix_magnitude[0:512, source_length - 64:source_length]
        output = last_stack_output(window_batch[:1])
        tail_len = source_length - start_ind - 16
        predicts_left[:, start_ind + 16:source_length] = output[0, 0, :, 64 - tail_len:64]
        predicts_right[:, start_ind + 16:source_length] = output[0, 1, :, 64 - tail_len:64]

        # Sanitise the masks (no NaNs, no negative gains), then apply them to
        # the mixture magnitude and rescale by max_value.
        predicts_left = np.maximum(np.nan_to_num(predicts_left), 0)
        predicts_right = np.maximum(np.nan_to_num(predicts_right), 0)
        predicts_left = predicts_left * mix_magnitude[0:512, :] * max_value
        predicts_right = predicts_right * mix_magnitude[0:512, :] * max_value
        predicts_left_wav = to_wav(predicts_left, mix_phase[0:512, :])
        predicts_right_wav = to_wav(predicts_right, mix_phase[0:512, :])
        # Keyword arguments work on both old and new librosa; positional
        # sample rates are rejected by librosa >= 0.10.
        predicts_left_wav = librosa.resample(predicts_left_wav, orig_sr=8000, target_sr=16000)
        predicts_right_wav = librosa.resample(predicts_right_wav, orig_sr=8000, target_sr=16000)
        sdr, sir, sar, lens = bss_eval(mix, left, right, predicts_left_wav,
                                       predicts_right_wav)
        total_length = total_length + lens
        gnsdr = gnsdr + sdr * lens
        gsir = gsir + sir * lens
        gsar = gsar + sar * lens
        cnt += 1
        if cnt % TEST_STEP == 0:
            print('GNSDR: ', gnsdr / total_length)
            print('GSIR: ', gsir / total_length)
            print('GSAR: ', gsar / total_length)
            #SAVING THE AUDIO FILES - FOR DEMO PURPOSE
            accompaniment_path = '/content/drive/MyDrive/test_wavfiles/{}_accompaniments_predict.wav'.format(cnt)
            # soundfile.write does not create missing directories.
            os.makedirs(os.path.dirname(accompaniment_path), exist_ok=True)
            soundfile.write(accompaniment_path, predicts_left_wav, 16000, format='wav', subtype='PCM_16')
            soundfile.write('/content/drive/MyDrive/test_wavfiles/{}_voice_predict.wav'.format(cnt), predicts_right_wav, 16000, format='wav', subtype='PCM_16')
            soundfile.write('/content/drive/MyDrive/test_wavfiles/{}mixed.wav'.format(cnt), mix, 16000, format='wav', subtype='PCM_16')
        if cnt == TOTAL_TEST:
            break


In [None]:
# Run the training loop; a checkpoint is written under 'Model4(1)/' and the
# mean loss is printed every TRAIN_SAVE_POINT iterations.
train()

  3%|▎         | 50/1500 [00:19<12:43,  1.90it/s]

loss of 50 is tensor([0.0393], device='cuda:0', grad_fn=<DivBackward0>)


  7%|▋         | 100/1500 [00:38<15:08,  1.54it/s]

loss of 100 is 0.02935784123837948


 10%|█         | 150/1500 [00:57<12:16,  1.83it/s]

loss of 150 is 0.027406053617596626


 13%|█▎        | 200/1500 [01:17<11:51,  1.83it/s]

loss of 200 is 0.024519337341189384


 17%|█▋        | 250/1500 [01:36<11:32,  1.80it/s]

loss of 250 is 0.02354545332491398


 20%|██        | 300/1500 [01:56<10:51,  1.84it/s]

loss of 300 is 0.022520629689097404


 23%|██▎       | 350/1500 [02:16<10:51,  1.76it/s]

loss of 350 is 0.021910475566983223


 27%|██▋       | 400/1500 [02:36<10:13,  1.79it/s]

loss of 400 is 0.020458854734897614


 30%|███       | 450/1500 [02:56<09:53,  1.77it/s]

loss of 450 is 0.02156914211809635


 33%|███▎      | 500/1500 [03:16<09:15,  1.80it/s]

loss of 500 is 0.01852353662252426


 37%|███▋      | 550/1500 [03:36<09:00,  1.76it/s]

loss of 550 is 0.019005117937922478


 40%|████      | 600/1500 [03:57<10:06,  1.48it/s]

loss of 600 is 0.018992483615875244


 43%|████▎     | 650/1500 [04:18<09:08,  1.55it/s]

loss of 650 is 0.018618136644363403


 47%|████▋     | 700/1500 [04:38<08:15,  1.61it/s]

loss of 700 is 0.018063334748148918


 50%|█████     | 750/1500 [04:59<07:02,  1.77it/s]

loss of 750 is 0.01740381121635437


 53%|█████▎    | 800/1500 [05:19<06:35,  1.77it/s]

loss of 800 is 0.016836626455187798


 57%|█████▋    | 850/1500 [05:39<06:14,  1.73it/s]

loss of 850 is 0.017995329573750496


 60%|██████    | 900/1500 [06:00<05:38,  1.77it/s]

loss of 900 is 0.016789371147751808


 63%|██████▎   | 950/1500 [06:20<05:14,  1.75it/s]

loss of 950 is 0.016724102199077606


 67%|██████▋   | 1000/1500 [06:41<04:43,  1.76it/s]

loss of 1000 is 0.016961731016635895


 70%|███████   | 1050/1500 [07:01<04:12,  1.78it/s]

loss of 1050 is 0.01625293493270874


 73%|███████▎  | 1100/1500 [07:21<03:52,  1.72it/s]

loss of 1100 is 0.01629212498664856


 77%|███████▋  | 1150/1500 [07:42<03:19,  1.75it/s]

loss of 1150 is 0.01641121320426464


 80%|████████  | 1200/1500 [08:02<02:51,  1.75it/s]

loss of 1200 is 0.015350810252130032


 83%|████████▎ | 1250/1500 [08:23<02:20,  1.78it/s]

loss of 1250 is 0.015868697315454483


 87%|████████▋ | 1300/1500 [08:43<01:54,  1.75it/s]

loss of 1300 is 0.015344792045652866


 90%|█████████ | 1350/1500 [09:04<01:25,  1.75it/s]

loss of 1350 is 0.014491170644760132


 93%|█████████▎| 1400/1500 [09:24<00:57,  1.74it/s]

loss of 1400 is 0.014679123647511005


 97%|█████████▋| 1450/1500 [09:44<00:28,  1.72it/s]

loss of 1450 is 0.014274033717811108


100%|██████████| 1500/1500 [10:05<00:00,  2.48it/s]

loss of 1500 is 0.014869502745568752





In [None]:
# Evaluate the default checkpoint ('Model4(1)/checkpoint_1500.pt'); running
# GNSDR / GSIR / GSAR are printed every TEST_STEP clips.
test()

  9%|▊         | 20/230 [00:44<07:33,  2.16s/it]

GNSDR:  [7.77796867 8.73091081]
GSIR:  [10.84189001 16.81503898]
GSAR:  [11.24641794  9.65255729]


 17%|█▋        | 40/230 [01:36<07:47,  2.46s/it]

GNSDR:  [6.9511877  7.65953761]
GSIR:  [10.13924474 15.61238917]
GSAR:  [10.30863059  8.63456765]


 26%|██▌       | 60/230 [02:27<07:54,  2.79s/it]

GNSDR:  [8.57331129 8.98134124]
GSIR:  [12.36467046 16.07386327]
GSAR:  [11.45988794 10.23684241]


 35%|███▍      | 80/230 [03:17<06:16,  2.51s/it]

GNSDR:  [9.10301763 9.10276264]
GSIR:  [13.37360851 15.30408567]
GSAR:  [11.6508171  10.79820057]


 43%|████▎     | 100/230 [04:10<06:09,  2.84s/it]

GNSDR:  [8.73340349 9.34807895]
GSIR:  [12.51486088 15.48056645]
GSAR:  [11.96988627 11.02189892]


 52%|█████▏    | 120/230 [04:58<04:22,  2.39s/it]

GNSDR:  [8.22616844 9.16455648]
GSIR:  [11.7016312  15.01600306]
GSAR:  [11.84744303 11.01697091]


 61%|██████    | 140/230 [05:55<03:36,  2.40s/it]

GNSDR:  [7.74474738 8.84694686]
GSIR:  [10.9639156  14.55852014]
GSAR:  [11.74392168 10.81249395]


 70%|██████▉   | 160/230 [06:46<02:56,  2.52s/it]

GNSDR:  [7.54977433 8.7127773 ]
GSIR:  [10.67447484 14.29822948]
GSAR:  [11.6743563  10.75438326]


 78%|███████▊  | 180/230 [07:37<01:53,  2.27s/it]

GNSDR:  [7.71315877 8.77043738]
GSIR:  [10.98418569 14.33073764]
GSAR:  [11.62516719 10.81902736]


 87%|████████▋ | 200/230 [08:29<01:08,  2.29s/it]

GNSDR:  [7.86381275 8.87897298]
GSIR:  [11.21312622 14.49078417]
GSAR:  [11.62958388 10.86449444]


 96%|█████████▌| 220/230 [09:12<00:20,  2.09s/it]

GNSDR:  [8.03539315 8.95607098]
GSIR:  [11.52434767 14.45645287]
GSAR:  [11.66143523 10.98371314]


100%|██████████| 230/230 [09:34<00:00,  2.50s/it]
