In [37]:
import numpy as np
import musdb
import IPython.display as ipd
import openunmix as opmux
import torch

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

import musdb
import museval

from tqdm.autonotebook import tqdm

# Setup

In [2]:
musdb_path = "/home/paco/TFM/data/MUSDB18/"
data_path = "/home/paco/TFM/data/"

In [3]:
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fe73804a110>

In [4]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hacemos inferencia en cpu, no consigo gestionar bien la memoria en gpu y estoy ahorrando un total de 5 minutos de procesamiento que va a ser una sola vez
device = "cpu"

In [5]:
print("Num GPUs Available: ", torch.cuda.get_device_name())
print("GPU available:", torch.cuda.is_available())

Num GPUs Available:  GeForce RTX 3060 Laptop GPU
GPU available: True


In [6]:
batch_size = 256
subtrack_length = 3
prediction_overlap = 0.25
source = 'source'
targets = ['vocals', 'drums', 'bass', 'other']
target = targets[0]
# 8192 de training
train_samples = 2**13
# Todas las muestras de validación (1227)
val_samples = 2**11

In [7]:
test_dset =  musdb.DB(root=musdb_path, subsets=['test'])

# Definimos la arquitectura de la red...

### Primero, la transformada de Fourier para codificar/decodificar

In [8]:
sftf_window_size = 4096
stft_window_hop = 1024
stft_center = True

In [9]:
class RecSepSTFT_3(torch.nn.Module): # Buenos resultados con lr inicial = 1e-2
    def __init__(self):
        super().__init__()
        # Usamos la norma del complejo como hacen en OpenUnmix
        self.complex_norm = opmux.transforms.ComplexNorm()
        
        self.encoder_stft = opmux.transforms.TorchSTFT(n_fft=sftf_window_size, n_hop=stft_window_hop, center=stft_center)
        self.decoder_stft = opmux.transforms.TorchISTFT(n_fft=sftf_window_size, n_hop=stft_window_hop, center=stft_center)
        
        self.fc1 = torch.nn.Linear(((int(sftf_window_size/2)+1)*2), 256)
        self.bn1 = torch.nn.BatchNorm1d(256)
        
        
        self.lstm = torch.nn.LSTM(input_size=256,
                                  hidden_size=int(256/2),
                                  num_layers=3, 
                                  batch_first=True, 
                                  dropout=0, 
                                  bidirectional=True,
                                  proj_size=0)
        #self.fc_lstm_res = torch.nn.Linear(256*2,1024)
        self.dropout = torch.nn.Dropout(0)
        
        self.fc2 = torch.nn.Linear(256, 256)
        self.bn2 = torch.nn.BatchNorm1d(256)
        
        self.fc3 = torch.nn.Linear(256,256)
        self.bn3 = torch.nn.BatchNorm1d(256)
        
        self.fc4 = torch.nn.Linear(256,(int(sftf_window_size/2)+1)*2)
        self.bn4 = torch.nn.BatchNorm1d((int(sftf_window_size/2)+1)*2)

    def forward(self, x):
        pre_mix = x
        # Calculamos la norma compleja (pasamos a dominio real)
        x = self.complex_norm(x)
        
        # Tenemos una entrada en formato: (batch, canal, feature/banda, secuencia)
        b_size, n_channel, n_feat, seq_len = x.size()

        # Vamos a hacer un primer paso de codificación, para ello tenemos que dejar los datos en forma (batch, secuencia, features)
        # Permutamos los datos a formato (batch, secuencia, canal, banda)
        x = torch.permute(x, (0,3,1,2))
        # Pasamos las dos últimas dimensiones (ahora son canal, banda) a una única ("desenrollamos" las features de cada canal en uno solo)
        x = torch.reshape(x, (b_size, seq_len, n_channel * n_feat))
        # Ponemos una capa fully connected, batch norm, y activación
        x = self.fc1(x)
        # Para el batch norm hay que tener el tensor en formato (batch, features, sequence)
        x = torch.permute(x, (0,2,1))
        # Batch norm
        x = self.bn1(x)
        # Ahora tenemos los datos en formato  (batch, features, sequence)
        # La lstm los necesita en formato (batch, sequence, features)
        x = torch.permute(x, (0,2,1))
        # Activación
        x = torch.nn.functional.relu(x)
        x_skip_1 = x
        
        x = self.dropout(x)
        
        # Nos quedamos con los estados de cada step de la secuencia
        x_skip_lstm = x
        x=self.lstm(x)[0]
        x = x_skip_lstm + x
        x = torch.nn.functional.relu(x)
        
        # Ahora tenemos los datos en formato (batch, sequence, features)
        # Aplicamos fc, bn, activation de nuevo
        x = self.fc2(x)
        x = torch.permute(x, (0,2,1))
        x = self.bn2(x)
        x = torch.permute(x, (0,2,1))
        x = torch.nn.functional.relu(x)
        x_skip_2 = x
        
        # Aplicamos una capa fc más para conseguir que features tenga un tamaño compatible con (canales, bandas_fourier_iniciales)
        x = self.fc3(x)
        x = torch.permute(x, (0,2,1))
        x = self.bn3(x)
        x = torch.permute(x, (0,2,1))
        x = torch.nn.functional.relu(x)
        
        x = x + x_skip_1
        x = x + x_skip_2
        
        x = self.fc4(x)
        x = torch.permute(x, (0,2,1))
        x = self.bn4(x)
        x = torch.permute(x, (0,2,1))
        
        # Redistribuimos y giramos para dejar los datos en formato: (batch, canal, feature/banda, secuencia, complejo)
        x = torch.reshape(x, (b_size, seq_len, n_channel, n_feat, 1))
        x = torch.permute(x, (0,2,3,1,4))
        
        # Aplicamos una sigmoidal para obtener una soft mask, y la aplicamos a la entrada
        # x = torch.sigmoid(x)
        
        # Aplicamos x como una máscara a la stft de la entrada sin procesar
        x = pre_mix * x
        
        return x

In [10]:
# Cargamos el modelo y lo ponemos en modo inferencia
model = torch.load(data_path+('model_%s.pt'%(target))).to(device)
model.eval()

RecSepSTFT_3(
  (complex_norm): ComplexNorm()
  (encoder_stft): TorchSTFT()
  (decoder_stft): TorchISTFT()
  (fc1): Linear(in_features=4098, out_features=256, bias=True)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lstm): LSTM(256, 128, num_layers=3, batch_first=True, bidirectional=True)
  (dropout): Dropout(p=0, inplace=False)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=256, out_features=256, bias=True)
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=256, out_features=4098, bias=True)
  (bn4): BatchNorm1d(4098, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [11]:
# Testing
def test_model(in_model, in_data, in_target, in_seq_length = 3 * 44100, overlap_step=1, pred_batch_size = 128):
    # Aseguramos que estamos en modo de evaluación
    in_model = in_model.eval()

    pred_tracks = []
    # Para cada track en data...
    for track in tqdm(in_data, desc='Estimando %s' % in_target):
        # Extraemos el audio de la mezcla
        full_X = track.audio.T
        # Extraemos el audio del target
        true_y = track.targets[in_target].audio.T
        
        x = []
        x_indices = []

        # Extraemos las muestras a predecir (la primera dimensión es el canal*)
        for idx in np.arange(0,full_X.shape[1], int(overlap_step * in_seq_length)):
            subtrack_padded = np.zeros((full_X.shape[0], in_seq_length))
            subtrack = full_X[:, idx:(idx+in_seq_length)]
            subtrack_padded[:, :subtrack.shape[1]] = subtrack
            x.append(subtrack_padded)
            x_indices.append(idx)
        
        # Pasamos las muestras por el modelo para obtener la predicción
        # Creamos el tensor de salida
        pred_y = torch.zeros(true_y.shape, dtype=torch.float32)
        pred_y_samples = torch.zeros_like(pred_y)
        
        # Pasamos las muestras a tensor de Torch
        x = torch.tensor(x, dtype=torch.float32)

        # Iteramos sobre los batch, realizando la predicción en cada uno
        for idx in np.arange(0,x.shape[0], batch_size):
            # Sacamos el batch
            batch_pred = x[idx:idx+batch_size]
            batch_indices = x_indices[idx:idx+batch_size]
            
            # Lo copiamos a device, codificamos con STFT
            batch_pred = batch_pred.to(device)
            batch_pred = model.encoder_stft(batch_pred)
            # Predecimos
            batch_pred = model(batch_pred)
            # Decodificamos con stft
            batch_pred = model.decoder_stft(batch_pred, length=in_seq_length)
            # Lo copiamos a cpu de nuevo
            batch_pred = batch_pred.detach().cpu()
            
            # Añadimos sobre el array completo de predicción
            for ii in np.arange(batch_pred.shape[0]):
                idx_start = batch_indices[ii]
                idx_end = idx_start+batch_pred.shape[-1]
                # Lo limitamos al final del array
                idx_end = idx_end if (idx_end < pred_y.shape[1]) else pred_y.shape[1]
                    
                pred_partial = batch_pred[ii]
                pred_y[:, idx_start:(idx_start+pred_partial.shape[1])] += pred_partial[:,0:(idx_end-idx_start)]
                pred_y_samples[:, idx_start:(idx_start+pred_partial.shape[-1])] += 1
        torch.cuda.empty_cache()
        pred_y_samples[pred_y_samples == 0] = 1
        pred_y = pred_y / pred_y_samples
        
        pred_tracks.append(pred_y.detach())
        
    return pred_tracks
        

In [12]:
# %%time
# pred_samples = test_model(model, test_dset, target, subtrack_length * test_dset[0].rate, overlap_step=prediction_overlap, pred_batch_size=batch_size)

In [95]:
sample_idx = 7

In [98]:
# sample = test_dset[sample_idx].audio.T
# sample = sample[sample_idx].cpu().detach()
# ipd.Audio(sample, rate=44100)

In [99]:
# sample = estimates['vocals'][sample_idx]
# sample = sample[sample_idx].cpu().detach()
# ipd.Audio(sample.numpy(), rate=44100)

In [16]:
# Ahora lo hacemos para todos los targets
estimates = {}

for target in tqdm(targets):
    # Cargamos el modelo y lo ponemos en modo inferencia
    model = torch.load(data_path+('model_%s.pt'%(target))).to(device)
    model.eval()
    
    # Sacamos las estimaciones de target
    target_preds = test_model(model, test_dset, target, subtrack_length * test_dset[0].rate, overlap_step=prediction_overlap, pred_batch_size=batch_size)
    
    # Las guardamos en estimates
    estimates[target] = target_preds

  0%|          | 0/4 [00:00<?, ?it/s]

Estimando vocals:   0%|          | 0/50 [00:00<?, ?it/s]

Estimando drums:   0%|          | 0/50 [00:00<?, ?it/s]

Estimando bass:   0%|          | 0/50 [00:00<?, ?it/s]

Estimando other:   0%|          | 0/50 [00:00<?, ?it/s]

In [117]:
for idx in tqdm(np.arange(len(test_dset))):
    test_dset.save_estimates(user_estimates={target:(estimates.get(target)[idx].numpy().T) for target in targets}, track=test_dset[idx], estimates_dir = data_path+'MUSDB18_estimates')

  0%|          | 0/50 [00:00<?, ?it/s]

In [30]:
%%time
results = museval.EvalStore(frames_agg='median', tracks_agg='median')

for idx, track in tqdm(enumerate(test_dset), total=len(test_dset)):
    track_estimates = {target:(estimates.get(target)[idx].numpy().T) for target in targets}
    results.add_track(museval.eval_mus_track(track, track_estimates))

  0%|          | 0/50 [00:00<?, ?it/s]

CPU times: user 1h 12min 55s, sys: 13min 31s, total: 1h 26min 27s
Wall time: 1h 24min 12s


In [33]:
results

Aggrated Scores (median over frames, median over tracks)
vocals          ==> SDR:   4.172  SIR:   6.709  ISR:   9.054  SAR:   6.139  
drums           ==> SDR:   3.746  SIR:   6.284  ISR:   6.524  SAR:   5.314  
bass            ==> SDR:   3.400  SIR:   4.411  ISR:   7.096  SAR:   5.205  
other           ==> SDR:   2.560  SIR:   2.768  ISR:   4.989  SAR:   4.495  

In [40]:
comparisons = museval.MethodStore()

In [34]:
comparisons.add_evalstore(results, name="RRSEP")

In [None]:
# comparisons.save(data_path+'compared_results_dataframe')

In [43]:
comparisons = museval.MethodStore()
comparisons.add_sisec18()
comparisons.add_evalstore(results, name="RRSEP")

Downloading SISEC18 Evaluation data...
Done!


In [53]:
agg_df[agg_df.method == 'RRSEP'].groupby('target')['score'].median()

target
bass      4.875210
drums     5.555230
other     3.896905
vocals    6.099680
Name: score, dtype: float64

In [54]:
agg_df.metric.unique()

array(['ISR', 'SAR', 'SDR', 'SIR'], dtype=object)

In [63]:
agg_df = comparisons.agg_frames_scores().reset_index()


In [62]:
agg_df

Unnamed: 0,method,track,target,metric,score
0,2DFT,AM Contra - Heart Peripheral,accompaniment,ISR,12.456640
1,2DFT,AM Contra - Heart Peripheral,accompaniment,SAR,8.595240
2,2DFT,AM Contra - Heart Peripheral,accompaniment,SDR,8.582600
3,2DFT,AM Contra - Heart Peripheral,accompaniment,SIR,22.414620
4,2DFT,AM Contra - Heart Peripheral,vocals,ISR,17.195920
...,...,...,...,...,...
24995,WK,Zeno - Signs,other,SIR,6.536005
24996,WK,Zeno - Signs,vocals,ISR,8.265570
24997,WK,Zeno - Signs,vocals,SAR,4.273505
24998,WK,Zeno - Signs,vocals,SDR,4.499505


In [111]:
all_metrics = ['SDR', 'SIR', 'SAR', 'ISR']
all_targets = ['vocals', 'drums', 'bass', 'other']

for metric in all_metrics:
    for target in all_targets:
        metrics = [metric,]
        selected_targets = [target,]

        agg_df = comparisons.agg_frames_scores().reset_index()

        sns.set()
        sns.set_context("notebook")


        oracles = [
            'IBM1', 'IBM2', 'IRM1', 'IRM2', 'MWF', 'IMSK'
        ]

        # Convert to Pandas Dataframes
        agg_df['oracle'] = agg_df.method.isin(oracles)
        agg_df = agg_df[agg_df.target.isin(selected_targets)].dropna()

        # Get sorting keys (sorted by median of SDR:vocals)
        df_sort_by = agg_df[
            (agg_df.metric == metrics[0]) &
            (agg_df.target == selected_targets[0])
        ]

        methods_by_sdr = df_sort_by.score.groupby(
            df_sort_by.method
        ).median().sort_values().index.tolist()

        # df = df[df.target == "vocals"]
        g = sns.FacetGrid(
            agg_df,
            row="target",
            col="metric",
            row_order=selected_targets,
            col_order=metrics,
            height=6,
            sharex=False,
            aspect=3
        )
        g = (g.map(
            sns.boxplot,
            "score",
            "method",
            "oracle",
            orient='h',
            order=methods_by_sdr[::-1],
            hue_order=[True, False],
            showfliers=False,
            notch=True
        ))

        g.ax.artists[np.where(np.array(methods_by_sdr[::-1]) == 'RRSEP')[0][0]].set_facecolor('orange')

        g.fig.tight_layout()
        plt.subplots_adjust(hspace=0.2, wspace=0.1)
        g.fig.savefig(
            data_path+"plots/boxplot_%s_%s.png" % (metrics[0],selected_targets[0]),
            bbox_inches='tight',
        )

  fig, axes = plt.subplots(nrow, ncol, **kwargs)
