## 1. 데이터 불러오기

In [1]:
# 코드 참조 https://github.com/SunnerLi/SVS-UNet-PyTorch/blob/master/train.py
import torch
import random
import time
import os
import numpy as np
from tqdm import tqdm_notebook
import torch.utils.data as Data

class SpectrogramDataset(Data.Dataset):
    """Paired (mixture, vocals) spectrogram dataset backed by ``.npy`` files.

    Expects ``<path>/mixture`` and ``<path>/vocals`` directories that contain
    files with identical names; only names containing ``'spec'`` are used.
    Each item is a ``(mix, voc)`` pair of float32 tensors with the first row
    of the stored array removed.

    Args:
        path: Root directory holding the ``mixture`` and ``vocals`` folders.
        is_sampling: If True, return a random window of ``segment_length``
            columns (last axis); otherwise return every column.
        segment_length: Width of the random window when sampling
            (default 512, the value previously hard-coded).
    """

    def __init__(self, path, is_sampling, segment_length=512):
        self.path = path
        self.files = [
            name
            for name in sorted(os.listdir(os.path.join(path, 'mixture')))
            if 'spec' in name
        ]
        self.is_sampling = is_sampling
        self.segment_length = segment_length

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(mix, voc)`` as float32 tensors."""
        name = self.files[idx]
        mix = np.load(os.path.join(self.path, 'mixture', name))
        voc = np.load(os.path.join(self.path, 'vocals', name))

        n_cols = mix.shape[-1]
        if self.is_sampling:
            # random.randint is inclusive on both ends, so n_cols - segment_length
            # is the last start that still yields a full window. Inputs narrower
            # than one window are clamped to start 0 and returned whole (this
            # produces a shorter item, which default batching with
            # batch_size > 1 cannot stack).
            start = random.randint(0, max(0, n_cols - self.segment_length))
            stop = start + self.segment_length
        else:
            start, stop = 0, n_cols

        # Row 0 is dropped. Presumably this is the DC bin of an (n_fft/2 + 1)-row
        # STFT, leaving 512 rows for n_fft=1024 — TODO confirm against the
        # preprocessing step. Contiguous copies keep the tensors compact and
        # avoid holding on to the full loaded arrays.
        mix = np.ascontiguousarray(mix[1:, start:stop], dtype=np.float32)
        voc = np.ascontiguousarray(voc[1:, start:stop], dtype=np.float32)

        return torch.from_numpy(mix), torch.from_numpy(voc)

## 2. Baseline

In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable


class BiLSTMEncoder(nn.Module):
    """Two-layer bidirectional LSTM encoder over batch-first sequences.

    Args:
        N: (int) Feature size of each input time step.
        E: (int) Hidden size per direction; outputs carry 2 * E features.
    """

    def __init__(self, N, E):
        super().__init__()
        self._N = N
        self._E = E

        # Attribute name `LSTM` is part of the state_dict keys; keep it stable.
        self.LSTM = nn.LSTM(
            input_size=N,
            hidden_size=E,
            num_layers=2,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )

    def forward(self, X):
        """Encode ``X`` of shape [B, T, N] into per-step features [B, T, 2E].

        Only the per-step outputs are returned; the final (h, c) states are
        discarded.
        """
        outputs, _final_state = self.LSTM(X)
        return outputs

import torch
import torch.nn as nn
from torch.autograd import Variable

class Masker(nn.Module):
    """Map bidirectional encoder features to a mask with values in [0, 1].

    Args:
        E: (int) Per-direction encoding size; the input's last dim is 2 * E.
        N: (int) Output size of the last dim (one mask value per bin).
    """

    def __init__(self, E, N):
        super().__init__()
        self._N = N
        self._E = E

        # Order Linear -> ReLU -> Linear -> Sigmoid fixes the state_dict keys
        # (masker.0.*, masker.2.*); do not reorder.
        layers = [
            nn.Linear(2 * E, E),
            nn.ReLU(),
            nn.Linear(E, N),
            nn.Sigmoid(),
        ]
        self.masker = nn.Sequential(*layers)

    def forward(self, X):
        """Apply the MLP on the last dim: (..., 2E) -> (..., N); used as [B, T, 2E] -> [B, T, N]."""
        mask = self.masker(X)
        return mask
    
class EncoderMasker(nn.Module):
    """Multiply the input by a mask produced from externally supplied modules.

    Holds no parameters of its own: the encoder and masker are passed to
    ``forward`` on every call, so their parameters are not registered here.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X, encoder, masker):
        """Return ``X * mask``.

        Args:
            X: Rank-3 tensor. Axes 1 and 2 are swapped before ``encoder`` sees
               it (presumably [B, N, T] -> [B, T, N]; TODO confirm), and the
               resulting mask is swapped back so it lines up with ``X``.
            encoder: Callable applied to the swapped tensor.
            masker: Callable mapping the encoder output to a mask of the
               swapped shape.
        """
        swapped = X.permute(0, 2, 1)
        features = encoder(swapped)
        mask = masker(features).permute(0, 2, 1)
        return X * mask

## 3. Training

In [None]:
import matplotlib.pyplot as plt
from IPython import display 
import torch.optim as optim
from tqdm import tqdm_notebook

# NOTE(review): the training loader reads the `valid` split; confirm this is
# intentional and not a leftover from debugging.
train_folder = 'data/musdb18/preprocessed/plain/valid'

loader = Data.DataLoader(
    dataset=SpectrogramDataset(train_folder, is_sampling=True),
    batch_size=1, num_workers=0, shuffle=True,
)

# Use the GPU when available (replaces the commented-out .cuda() calls).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

N = 512  # per-step input feature size; must match the dataset's row count
E = 512  # LSTM hidden size per direction

encoder = BiLSTMEncoder(N, E).to(device)
masker = Masker(E, N).to(device)
encoder_masker = EncoderMasker()

criterion = nn.MSELoss()
encoder_optim = optim.Adagrad(encoder.parameters())
masker_optim = optim.Adagrad(masker.parameters())

num_iter = 100
# Refresh the plot about 100 times over the run, but at least once per epoch.
# (int(num_iter / 100.) was 0 for num_iter < 100 and crashed on `% 0`.)
print_step = max(1, num_iter // 100)
# EMA factor for the plotted loss. The old 0.999 factor, over at most
# num_iter points and seeded at 1.2x the first loss, kept the curve pinned
# near that inflated seed.
smoothing = 0.9
loss_trace = []
smooth_loss = None

encoder.train()
masker.train()

# `epoch` instead of `iter` so the builtin iter() is not shadowed at notebook scope.
for epoch in range(num_iter):
    loss_sum = 0.0
    num_batches = 0

    for mix, voc in tqdm_notebook(loader):
        mix, voc = mix.to(device), voc.to(device)

        encoder_optim.zero_grad()
        masker_optim.zero_grad()

        y_hat = encoder_masker(mix, encoder, masker)
        loss = criterion(y_hat, voc)
        loss.backward()
        encoder_optim.step()
        masker_optim.step()

        loss_sum += loss.item()
        num_batches += 1

    # Mean per-batch loss, so the curve does not scale with dataset size.
    epoch_loss = loss_sum / max(num_batches, 1)
    # Update the EMA every epoch, seeded with the first real value.
    if smooth_loss is None:
        smooth_loss = epoch_loss
    else:
        smooth_loss = smoothing * smooth_loss + (1 - smoothing) * epoch_loss

    if (epoch + 1) % print_step == 0:
        display.clear_output(wait=True)
        loss_trace.append(smooth_loss)
        print('print step', print_step, ' iter', epoch)
        plt.plot(loss_trace)
        plt.show()


HBox(children=(IntProgress(value=0, max=14), HTML(value='')))