In [2]:
import numpy as np
import musdb
import IPython.display as ipd
import openunmix as opmux
import torch
import h5py

from tqdm.autonotebook import tqdm

# Setup

In [3]:
musdb_path = "/home/paco/TFM/data/MUSDB18/"
data_path = "/home/paco/TFM/data/"

In [4]:
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f5fe18fa350>

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

In [6]:
print("Num GPUs Available: ", torch.cuda.get_device_name())
print("GPU available:", torch.cuda.is_available())

Num GPUs Available:  GeForce RTX 3060 Laptop GPU
GPU available: True


# Definimos los generadores de datos...

In [7]:
def h5_generator(data, source, target, n_samples=2**13, chunk_size=2**9, batch_size=8, randomize=True):
    # Cogemos n_samples elementos contiguos del dataset, empezando en algún punto al azar
    if(randomize):
        starting_point = len(data[source])-n_samples
        starting_point= 1 if(starting_point < 1) else starting_point
        starting_point = np.random.choice(starting_point)
    else:
        starting_point = 0
    data_source = data[source]
    data_target = data[target]
    
    # Si estamos pidiendo más muestras de las que hay, las devolvemos todas
    if((starting_point+n_samples) > data_source.shape[0]):
        n_samples = data_source.shape[0] - starting_point
    
    # Leemos los datos de un chunk
    for chunk_start in np.arange(starting_point, starting_point+n_samples, chunk_size):
        chunk_source = data_source[chunk_start:(chunk_start+chunk_size)]
        chunk_target = data_target[chunk_start:(chunk_start+chunk_size)]
        # Shuffle de source y target
        random_order = np.random.permutation(len(chunk_source))
        chunk_source = chunk_source[random_order]
        chunk_target = chunk_target[random_order]
        
        for jj in np.arange(0,len(chunk_source), 8):
            # Generamos predictor y target (audio original y la sección del target) y los cargamos como tensores de pytorch en "device" (gpu normalmente)
            x = chunk_source[jj:jj+8]
            y = chunk_target[jj:jj+8]
            # Devolvemos x e y
            yield x, y

In [8]:
class musdb_dataset(torch.utils.data.Dataset):
    def __init__(self,data,target='vocals',sequence_length=6):
        self.sequence_length = sequence_length
        self.target = target
        self.data = data

    def __getitem__(self, index):
        # Cogemos una canción al azar (en principio, si no contemplamos reemplazamiento, el batch size está limitado por el número de canciones en el dataset)
        track = self.data[index]
        # Generamos un trozo al azar y recortamos
        track.chunk_start = np.random.uniform(0, (track.duration - self.sequence_length))
        track.chunk_duration = self.sequence_length
        # Generamos predictor y target (audio original y la sección del target) y los cargamos como tensores de pytorch en "device" (gpu normalmente)
        x = torch.tensor(track.audio.T, dtype=torch.float32, device=device)
        y = torch.tensor(track.targets[self.target].audio.T, dtype=torch.float32, device=device)
        
        # Devolvemos x e y
        return x, y 


    def __len__(self):
        return len(self.data)

## Instanciamos los generadores de datos

In [9]:
batch_size = 256
subtrack_length = 3
source = 'source'
target = 'vocals'
# 8192 de training
train_samples = 2**13
# Todas las muestras de validación (1227)
val_samples = 2**11

In [11]:
train_file = h5py.File(data_path+'train_shuffled.h5', 'r')
val_file = h5py.File(data_path+'val_shuffled.h5', 'r')
test_file = h5py.File(data_path+'test_shuffled.h5', 'r')

In [12]:
# train_dset =  musdb_dataset(data=musdb.DB(root=musdb_path, subsets=['train'], split='train'), target='vocals', sequence_length=subtrack_length)
# val_dset =  musdb_dataset(data=musdb.DB(root=musdb_path, subsets=['train'], split='valid'), target='vocals', sequence_length=subtrack_length)
# test_dset =  musdb_dataset(data=musdb.DB(root=musdb_path, subsets=['test']), target='vocals', sequence_length=subtrack_length)

In [14]:
test_dset =  musdb.DB(root=musdb_path, subsets=['test'])

In [103]:
# Testing
def test_musdb(in_model, in_data, in_target, in_seq_length = 3 * 44100, overlap_step=1, pred_batch_size = 256):
    # Aseguramos que estamos en modo de evaluación
#     in_model = in_model.eval()
    
    # Para cada track en data...
    for track in in_data:
        # Extraemos el audio de la mezcla
        full_X = track.audio
        # Extraemos el audio del target
        true_y = track.targets[in_target].audio
        
        x = []
        x_indices = []

        # Extraemos las muestras a predecir (la primera dimensión es el canal*)
        for idx in np.arange(0,full_X.shape[0], int(overlap_step * in_seq_length)):
            subtrack_padded = np.zeros((in_seq_length, full_X.shape[1]))
            subtrack = full_X[idx:(idx+in_seq_length)]
            subtrack_padded[:subtrack.shape[0]] = subtrack
            x.append(subtrack_padded)
            x_indices.append((idx, (idx+in_seq_length)))
        
        # Pasamos las muestras por el modelo para obtener la predicción
        # Creamos el tensor de salida
        pred_y = torch.zeros(true_y.shape, dtype=torch.float32)
        pred_y_samples = torch.zeros_like(pred_y)
        
        # Pasamos las muestras a tensor de Torch
        x = torch.tensor(x, dtype=torch.float32)
        return x,x_indices
        # Iteramos sobre los batch, realizando la predicción en cada uno
        for idx in np.arange(0,x.shape[0], batch_size):
            # Sacamos el batch
            batch_x = x[idx:idx+batch_size]
            batch_indices = x_indices[idx:idx+batch_size]
            
            # Lo enviamos a device, codificamos con STFT
            batch_pred = model.encoder_stft(batch_x.to(device))
            # Predecimos
            batch_pred = model(batch_pred)
            # Decodificamos con stft
            batch_pred = model.decoder_stft(batch_pred, length=in_seq_length)
            # Lo enviamos a cpu de nuevo
            batch_pred = batch_pred.cpu()
            
            for ii in np.arange(batch_x.shape[0]):
                idx_start = batch_indices[ii][0]
                idx_end = batch_indices[ii][1]

                if(idx_end > pred_y.shape[0]):
                    idx_end = (pred_y.shape[0]-idx_start)
                
                pred_y[idx_start:idx_end] = batch_x[ii][0:(idx_end-idx_start)]
                pred_y_samples[idx_start:idx_end] = (pred_y_samples[idx_start:idx_end]+1)
                
                
        
        

In [39]:
test_dset[0].audio.shape[0]

9256960

In [97]:
torch.tensor([1,2,3]).shape

torch.Size([3])

In [104]:
%%time
t = test_musdb(None, test_dset, 'vocals', subtrack_length * test_dset[0].rate, overlap_step=1)

CPU times: user 4.86 s, sys: 482 ms, total: 5.34 s
Wall time: 5.9 s


In [110]:
[v[1]-v[0] for v in t[1]]

[132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300,
 132300]

In [31]:
len([s.shape[1] for s in t])

35

In [11]:
# train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True)
# val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, shuffle=True)
# test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, shuffle=True)

# Definimos la arquitectura de la red...

### Primero, la transformada de Fourier para codificar/decodificar

In [49]:
sample_idx = 2

In [50]:
sample = (model.decoder_stft(y, length=132300))
# sample = (sample*train_std[...,:sample.shape[-1]]) + train_mean[...,:sample.shape[-1]]
sample = sample[sample_idx].cpu().detach()
ipd.Audio(sample, rate=44100)

In [51]:
sample = (model.decoder_stft(y_pred, length=132300))
# sample = (sample*train_std[...,:sample.shape[-1]]) + train_mean[...,:sample.shape[-1]]
sample = sample[sample_idx].cpu().detach()
ipd.Audio(sample, rate=44100)

In [52]:
sample = (model.decoder_stft(x, length=132300))
# sample = (sample*train_std[...,:sample.shape[-1]]) + train_mean[...,:sample.shape[-1]]
sample = sample[sample_idx].cpu().detach()
ipd.Audio(sample, rate=44100)

In [None]:
# Bloque conv. de https://arxiv.org/pdf/1809.07454.pdf
class DConv(torch.nn.Module):
    def __init__(self, k):
        super().__init__()
        self.conv_init = torch.nn.Conv1d(in_channels = 8, 
                                    out_channels = 8,
                                    kernel_size=1,
                                    groups=1,
                                    padding_mode='circular',)
        
        self.dconv = torch.nn.Conv1d(in_channels = 8,
                                     out_channels = k*8, 
                                     kernel_size=1,
                                     groups=8,
                                     padding_mode='circular',)
        self.conv_final = torch.nn.Conv1d(in_channels = 8,
                                          out_channels = 8,
                                          kernel_size=1,
                                          groups=1,
                                          padding_mode='circular',)
        self.conv_skip =  torch.nn.Conv1d(in_channels = 8,
                                          out_channels = 8,
                                          kernel_size=1,
                                          groups=1,
                                          padding_mode='circular',)
        
        self.activation = torch.nn.ReLU()
        self.norm_conv = torch.nn.BatchNorm1d(8)
        self.norm_dconv = torch.nn.BatchNorm1d(k*8)

    def forward(self, x):
        x = self.conv_init(x)
        x = self.activation(x)
        x = self.norm_conv(x)
        
        x = self.dconv(x)
        x = self.activation(x)
        x = self.norm_dconv(x)
        
        x_skip = self.conv_skip(x)
        x = self.conv_final(x)
        
        return (x, x_skip)

In [None]:
class SepConv_1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn_init = torch.nn.BatchNorm1d(2)
        self.conv_init = torch.nn.Conv1d(in_channels = 2, 
                                         out_channels = 8, 
                                         kernel_size=1,
                                         groups=1,padding_mode='circular',)
        
        self.conv_block_1 = DConv(k=1)
        self.conv_block_1_conv = torch.nn.Conv1d(in_channels = 8, 
                                                 out_channels = 8, 
                                                 kernel_size=1,
                                                 groups=1,padding_mode='circular',)
        
        self.conv_final = torch.nn.Conv1d(in_channels = 8, 
                                         out_channels = 2, 
                                         kernel_size=1,
                                         groups=1,padding_mode='circular',)

    def forward(self, x):
        # Guardamos la entrada sin procesar
        pre_mix = x
        # Las convoluciones esperan los datos en formato (batch, channel, sequence)
        x = self.conv_init(x)
        block_skip = x
        x,skip_1 = self.conv_block_1(x)
        
        x =x+skip_1
        x = self.conv_final(x)
        # Aplicamos x a la entrada como una máscara (producto element-wise)
        x = pre_mix * x
        
        return x

In [None]:
# https://github.com/naplab/Conv-TasNet/blob/master/utility/sdr.py
class SDRLoss(torch.nn.Module):
    def __init__(self, eps=1e-4):
        super(SDRLoss, self).__init__()
        self.eps = eps
            
    def forward(self, estimation, origin):
        """
        batch-wise SDR caculation for one audio file on pytorch Variables.
        estimation: (batch, nsample)
        origin: (batch, nsample)
        mask: optional, (batch, nsample), binary
        """

        origin_power = torch.pow(origin, 2).sum((1,2), keepdim=True)  # (batch, 1)

        scale = torch.sum(origin*estimation, (1,2), keepdim=True) / (origin_power + self.eps)  # (batch, 1)
        
        est_true = scale * origin  # (batch, nsample)
        est_res = estimation - est_true  # (batch, nsample)

        true_power = torch.pow(est_true, 2).sum((1,2))
        res_power = torch.pow(est_res, 2).sum((1,2))
        # La invertimos para que la red pueda minimizar el valor
        sdr = -(10*torch.log10(true_power) - 10*torch.log10(res_power))        
        return torch.nansum(sdr)  # (batch, 1)

In [None]:
# %%time
# Número de épocas a entrenar
n_epochs = 50
init_lr = 1e-4
min_lr = 1e-7
lr_factor = 1e-1

# val_loss_deque = deque([np.inf,]*5)
val_loss_counter = 0

train_samples = train_samples if (train_samples < len(train_file[source])) else len(train_file[source])
val_samples = val_samples if (val_samples < len(val_file[source])) else len(val_file[source])

progress_bar = tqdm(np.arange(n_epochs * train_samples))

# Instanciamos el modelo
# model = RecSepSTFT_1(data_mean = train_mean, data_std = train_std)
model = SepConv_1()

# Cargamos el modelo a GPU/dispositivo, junto con las transformadas
model.to(device)

# Definimos el optimizador y la función de loss
optimizer = torch.optim.RMSprop(model.parameters(),
#                                 momentum=0.2,
#                                 weight_decay=1e-5, 
                                lr=init_lr,)
loss_f = torch.nn.MSELoss()
# loss_f = SDRLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, 
                                                       mode='min', 
                                                       patience=3, 
                                                       factor=lr_factor, 
                                                       threshold=min_lr/lr_factor, # Dividimos entre el factor para usarlo en el early stopping
                                                       verbose=True)

# Iteramos cada época
for epoch in np.arange(n_epochs):
    # Iteramos sobre los datos
    train_loss = 0
    # Ponemos el modelo en modo train
    model.train()
#     for x,y in train_loader:
    for x,y in  h5_generator(train_file, source, target, train_samples, batch_size, True):
        # Anulamos los gradientes
        optimizer.zero_grad()
        # Cargamos los datos del batch en GPU/dispositivo
        x = torch.tensor(x, dtype=torch.float32, device=device)
        y = torch.tensor(y, dtype=torch.float32, device=device)
        
        # Hacemos el forward pass
        y_pred = model(x)
        
        # Calculamos el error entre el y_pred y el y_decoded
        loss = loss_f(y, y_pred)
        # Acumulamos para después
        train_loss+=loss.item()
        # Hacemos el back propagation (ahora los gradientes de cada parámetro se actualizan)
        loss.backward()
        
        # Actualizamos los parámetros haciendo un paso del optimizador
        optimizer.step()
        
        # Actualizamos el progreso de entrenamiento
        progress_bar.update(x.size()[0])
    val_loss = 0
    # Ponemos el modelo en modo test
    model.eval()
#     for x,y in val_loader:
    for x,y in h5_generator(val_file, source, target, val_samples, batch_size, False):
        x = torch.tensor(x, dtype=torch.float32, device=device)
        y = torch.tensor(y, dtype=torch.float32, device=device)
        
        y_pred = model(x)
        
        loss = loss_f(y,y_pred)
        val_loss += loss.item()
    
    # Ahora las estrategias de cambio de lr
    # Llamamos al scheduler de LR
    scheduler.step(val_loss)
    scheduler.verbose
    
    # Si el lr llega al mínimo, paramos el entrenamiento
    if(optimizer.state_dict()['param_groups'][0]['lr'] <= min_lr):
        # Paramos el entrenamiento
        print("Early stopping")
        break    
        
        
    print("Epoch %d" % (epoch+1))
    print("Train loss: %f / %f" % (train_loss, train_loss/train_samples))
    print("Validation loss: %f / %f" % (val_loss, val_loss/val_samples))
    print()

In [None]:
sample_idx = 0

In [None]:
sample = y
# sample = (sample*train_std[...,:sample.shape[-1]]) + train_mean[...,:sample.shape[-1]]
sample = sample[sample_idx].cpu().detach()
ipd.Audio(sample, rate=44100)

In [None]:
sample = y_pred
# sample = (sample*train_std[...,:sample.shape[-1]]) + train_mean[...,:sample.shape[-1]]
sample = sample[sample_idx].cpu().detach()
ipd.Audio(sample, rate=44100)

In [None]:
sample = x
# sample = (sample*train_std[...,:sample.shape[-1]]) + train_mean[...,:sample.shape[-1]]
sample = sample[sample_idx].cpu().detach()
ipd.Audio(sample, rate=44100)