In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchaudio
import pretty_midi
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, recall_score

In [2]:
class FullPipelineDataset(Dataset):
    def __init__(self, wav_folder, separated_folder, midi_folder, sr=22050, n_mels=128, fs_midi=50,
                 fixed_full_length=200, fixed_midi_length=50):
        self.wav_folder = wav_folder
        self.separated_folder = separated_folder
        self.midi_folder = midi_folder
        self.sr = sr
        self.n_mels = n_mels
        self.fs_midi = fs_midi
        self.fixed_full_length = fixed_full_length
        self.fixed_midi_length = fixed_midi_length
        
        self.wav_files = sorted([f for f in os.listdir(wav_folder) if f.lower().endswith('.wav')])
        self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_mels=n_mels)
        
    def pad_or_crop_tensor(self, tensor, target_length):
        current_length = tensor.size(-1)
        if current_length > target_length:
            return tensor[..., :target_length]
        elif current_length < target_length:
            pad_amount = target_length - current_length
            return F.pad(tensor, (0, pad_amount))
        else:
            return tensor

    def midi_to_piano_roll(self, midi_path):
        pm = pretty_midi.PrettyMIDI(midi_path)
        piano_roll = pm.get_piano_roll(fs=self.fs_midi)  # [128, T]
        piano_roll = (piano_roll > 0).astype(np.float32)
        return piano_roll

    def __len__(self):
        return len(self.wav_files)

    def __getitem__(self, idx):
        wav_filename = self.wav_files[idx]
        base_name = os.path.splitext(wav_filename)[0]
        wav_path = os.path.join(self.wav_folder, wav_filename)
        waveform, sr = torchaudio.load(wav_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        full_spec = self.mel_transform(waveform)
        full_spec = full_spec.log2().clamp(min=-10)
        full_spec = self.pad_or_crop_tensor(full_spec, self.fixed_full_length)
        
        instruments = ['vocals', 'bass', 'drums', 'other']
        separated_specs = {}
        sep_folder = os.path.join(self.separated_folder, base_name)
        for inst in instruments:
            sep_path = os.path.join(sep_folder, f"{inst}.wav")
            if not os.path.exists(sep_path):
                sep_spec = torch.zeros(1, self.n_mels, self.fixed_full_length)
            else:
                sep_wave, _ = torchaudio.load(sep_path)
                if sep_wave.shape[0] > 1:
                    sep_wave = sep_wave.mean(dim=0, keepdim=True)
                sep_spec = self.mel_transform(sep_wave)
                sep_spec = sep_spec.log2().clamp(min=-10)
                sep_spec = self.pad_or_crop_tensor(sep_spec, self.fixed_full_length)
            separated_specs[inst] = sep_spec

        midi_piano_rolls = {}
        midi_subfolder = os.path.join(self.midi_folder, base_name)
        for inst in instruments:
            midi_path = os.path.join(midi_subfolder, f"{inst}.mid")
            if not os.path.exists(midi_path):
                raise FileNotFoundError(f"MIDI no encontrado: {midi_path}")
            pr = self.midi_to_piano_roll(midi_path)
            pr_tensor = torch.tensor(pr)
            pr_tensor = self.pad_or_crop_tensor(pr_tensor, self.fixed_midi_length)
            midi_piano_rolls[inst] = pr_tensor

        return full_spec, separated_specs, midi_piano_rolls

In [3]:
# CNN
class FullPipelineCNN(nn.Module):
    def __init__(self, output_dims, input_shape):
        super(FullPipelineCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.bn1   = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2   = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2)

        _, n_mels, tiempo = input_shape
        fm_height = n_mels // 4
        fm_width  = tiempo // 4

        self.fc_common = nn.Linear(32 * fm_height * fm_width, 512)
        self.branch_vocals = nn.Linear(512, output_dims['vocals'])
        self.branch_bass   = nn.Linear(512, output_dims['bass'])
        self.branch_drums  = nn.Linear(512, output_dims['drums'])
        self.branch_other  = nn.Linear(512, output_dims['other'])

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc_common(x))
        vocals = torch.sigmoid(self.branch_vocals(x))
        bass   = torch.sigmoid(self.branch_bass(x))
        drums  = torch.sigmoid(self.branch_drums(x))
        other  = torch.sigmoid(self.branch_other(x))
        return {'vocals': vocals, 'bass': bass, 'drums': drums, 'other': other}


In [4]:
#evaluación
def evaluate_model(model, dataloader, criterion, device='cpu', return_loss=False):
    model.eval()
    total_loss = 0.0
    total_batches = 0
    instruments = ['vocals', 'bass', 'drums', 'other']
    preds_all = {inst: [] for inst in instruments}
    targets_all = {inst: [] for inst in instruments}
    with torch.no_grad():
        for full_spec, separated_specs, midi_piano_rolls in dataloader:
            full_spec = full_spec.to(device)
            outputs = model(full_spec)
            batch_loss = 0.0
            for inst in instruments:
                target = midi_piano_rolls[inst].view(outputs[inst].size()).to(device)
                batch_loss += criterion(outputs[inst], target)
                preds_all[inst].append(outputs[inst].cpu())
                targets_all[inst].append(target.cpu())
            total_loss += batch_loss.item()
            total_batches += 1
    avg_loss = total_loss / total_batches
    print(f"Average Loss: {avg_loss:.4f}")
    for inst in instruments:
        preds = torch.cat(preds_all[inst], dim=0).view(-1).numpy()
        targs = torch.cat(targets_all[inst], dim=0).view(-1).numpy()
        preds_bin = (preds >= 0.5).astype(int)
        acc = accuracy_score(targs, preds_bin)
        rec = recall_score(targs, preds_bin, zero_division=0)
        f1 = f1_score(targs, preds_bin, zero_division=0)
        print(f"{inst.capitalize()} -> Accuracy: {acc:.4f}, Recall: {rec:.4f}, F1 Score: {f1:.4f}")
    if return_loss:
        return avg_loss

In [5]:
# entrenamiento, Early Stopping
def train_model(model, dataloader, optimizer, criterion, num_epochs=10, device='cpu', val_loader=None, early_stop_patience=5, min_delta=0.001):
    model.to(device)
    best_val_loss = float('inf')
    epochs_no_improve = 0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for full_spec, separated_specs, midi_piano_rolls in dataloader:
            full_spec = full_spec.to(device)
            optimizer.zero_grad()
            outputs = model(full_spec)
            loss = 0.0
            for inst in ['vocals', 'bass', 'drums', 'other']:
                target = midi_piano_rolls[inst].view(outputs[inst].size()).to(device)
                loss += criterion(outputs[inst], target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_loss:.4f}")
        if val_loader is not None:
            print("Validation metrics:")
            val_loss = evaluate_model(model, val_loader, criterion, device=device, return_loss=True)
            if best_val_loss - val_loss > min_delta:
                best_val_loss = val_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
            print(f"Validation Loss: {val_loss:.4f} | Best: {best_val_loss:.4f} | Patience: {epochs_no_improve}/{early_stop_patience}")
            if epochs_no_improve >= early_stop_patience:
                print("Early stopping triggered!")
                break
    return model

In [6]:
# Parámetros, ruta
wav_folder = "F:\\TG MINTA\\Audios\\dataset\\wav"
separated_folder = "F:\\TG MINTA\\Audios\\dataset\\separated"
midi_folder = "F:\\TG MINTA\\Audios\\dataset\\midi"

sr = 22050
n_mels = 128
fs_midi = 50
fixed_full_length = 200   # Frames espectrograma
fixed_midi_length = 50    # Frames piano roll
input_shape = (1, n_mels, fixed_full_length)
output_dims = {
    'vocals': 128 * fixed_midi_length,
    'bass': 128 * fixed_midi_length,
    'drums': 128 * fixed_midi_length,
    'other': 128 * fixed_midi_length
}

dataset = FullPipelineDataset(wav_folder, separated_folder, midi_folder,
                              sr, n_mels, fs_midi, fixed_full_length, fixed_midi_length)
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

model = FullPipelineCNN(output_dims, input_shape)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 20

model = train_model(model, train_loader, optimizer, criterion, num_epochs=num_epochs, device=device, val_loader=val_loader, early_stop_patience=5, min_delta=0.001)
print("Test metrics:")
evaluate_model(model, test_loader, criterion, device=device)



Epoch 1/20 - Train Loss: 1.9237
Validation metrics:
Average Loss: 0.2274
Vocals -> Accuracy: 0.9848, Recall: 0.2609, F1 Score: 0.1100
Bass -> Accuracy: 0.9861, Recall: 0.1868, F1 Score: 0.0337
Drums -> Accuracy: 0.9868, Recall: 0.3051, F1 Score: 0.0719
Other -> Accuracy: 0.9816, Recall: 0.1066, F1 Score: 0.0500
Validation Loss: 0.2274 | Best: 0.2274 | Patience: 0/5
Epoch 2/20 - Train Loss: 0.0972
Validation metrics:
Average Loss: 0.0977
Vocals -> Accuracy: 0.9968, Recall: 0.1976, F1 Score: 0.3086
Bass -> Accuracy: 0.9983, Recall: 0.1868, F1 Score: 0.2222
Drums -> Accuracy: 0.9985, Recall: 0.3475, F1 Score: 0.4409
Other -> Accuracy: 0.9952, Recall: 0.0533, F1 Score: 0.0921
Validation Loss: 0.0977 | Best: 0.0977 | Patience: 0/5
Epoch 3/20 - Train Loss: 0.0641
Validation metrics:
Average Loss: 0.0749
Vocals -> Accuracy: 0.9969, Recall: 0.2451, F1 Score: 0.3615
Bass -> Accuracy: 0.9985, Recall: 0.1648, F1 Score: 0.2256
Drums -> Accuracy: 0.9984, Recall: 0.2373, F1 Score: 0.3333
Other -> Ac