In [1]:
import os
import sys

# Walk upwards from the current working directory until a directory that
# contains an entry named "src" is found; that directory is the project root.
project_root = os.getcwd()
while "src" not in os.listdir(project_root):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # os.path.dirname() of a filesystem root returns the root itself, so
        # without this guard the loop would never terminate.
        raise FileNotFoundError(
            f'No directory containing "src" was found in {os.getcwd()!r} '
            "or any of its parent directories"
        )
    project_root = parent_dir

# Make the project root importable; skip the append when the cell is re-run
# so sys.path does not accumulate duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict, Any
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Locations of the pre-processed training and validation data.
data_root = os.path.join(project_root, "data", "processed")
train_dir = os.path.join(data_root, "train")
val_dir = os.path.join(data_root, "val")

# Full paths of every entry in each directory, in sorted order.
train_files = sorted(os.path.join(train_dir, name) for name in os.listdir(train_dir))
test_files = sorted(os.path.join(val_dir, name) for name in os.listdir(val_dir))

# Only the first half of the sorted training files is used; that subset is
# shuffled with a fixed seed and split 60/40 into training and validation lists.
first_half = train_files[: int(0.5 * len(train_files))]
train_set, val_set = train_test_split(
    first_half, test_size=0.4, shuffle=True, random_state=42
)

In [4]:
# Total number of chunks stored across all training files.
cont = sum(len(torch.load(path)) for path in train_set)

print(f"Train set has {cont} chunks")

Train set has 2940 chunks


In [5]:
# Add up the size in bytes of every tensor held by every training chunk.
total_memory = sum(
    tensor.element_size() * tensor.nelement()
    for path in train_set
    for chunk in torch.load(path)
    for tensor in chunk.values()
)

# Bytes -> MB (1 MB = 1024 * 1024 bytes)
total_memory_mb = total_memory / (1024**2)

print(f"Total memory usage: {total_memory_mb:.2f} MB")

Total memory usage: 3675.00 MB


In [6]:
class ChunkDataset(Dataset):
    """In-memory dataset built from the chunks stored in a list of files.

    Every file in ``file_list`` is read with ``torch.load`` when the dataset
    is constructed, and the items each file yields are collected, in file
    order, into one flat list that is indexed directly afterwards.
    """

    def __init__(self, file_list: List[str]):
        # Load all files up front and flatten their chunks into a single list.
        self.data = [chunk for path in file_list for chunk in torch.load(path)]

    def __len__(self):
        """Return the total number of chunks loaded from all files."""
        n_chunks = len(self.data)
        return n_chunks

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the chunk stored at position ``idx``."""
        chunk = self.data[idx]
        return chunk

In [7]:
# Materialise both splits as in-memory datasets.
train_dataset = ChunkDataset(file_list=train_set)
val_dataset = ChunkDataset(file_list=val_set)

# Build the DataLoaders: training batches are reshuffled every epoch,
# validation batches keep a fixed order.
batch_size = 4
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class UNet(nn.Module):
    """Small fully-convolutional network with an encoder and a decoder stage.

    Both stages are stacks of 3x3 convolutions with padding 1, so the spatial
    size of the input is preserved. Note that, despite the name, the network
    contains no down-/up-sampling and no skip connections.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        hidden = 64
        conv_kwargs = dict(kernel_size=3, padding=1)

        # Two conv + ReLU layers lifting the input to `hidden` feature maps.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, **conv_kwargs),
            nn.ReLU(inplace=True),
        )

        # One conv + ReLU layer followed by a linear projection to the output
        # channels (no activation after the last convolution).
        self.decoder = nn.Sequential(
            nn.Conv2d(hidden, hidden, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, **conv_kwargs),
        )

    def forward(self, x):
        """Run the input through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))

In [9]:
# Training function
def train_model(
    model, train_loader, val_loader, criterion, optimizer, epochs=10, device=None
):
    """Train ``model`` and report the average train/validation loss per epoch.

    Each batch must be a mapping with the keys ``"mixture"`` (model input)
    and ``"bass"`` (training target).

    Args:
        model: The ``torch.nn.Module`` to optimise.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.
        optimizer: Optimizer that updates the parameters of ``model``.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, or the CPU if the
            model has no parameters.

    Returns:
        None. One summary line is printed per epoch.
    """
    if device is None:
        # Keep the batches on the same device as the model instead of relying
        # on a module-level variable.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(epochs):
        # Re-enter training mode at the start of every epoch: the validation
        # pass below switches the model to eval mode, which would otherwise
        # stay active for all later epochs.
        model.train()
        train_loss = 0.0
        train_batches = 0
        for batch in train_loader:
            mixture = batch["mixture"].to(device)  # Input
            target = batch["bass"].to(device)  # Desired output (here: the bass stem)

            optimizer.zero_grad()
            output = model(mixture)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_batches += 1

        # Validation
        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch in val_loader:
                mixture = batch["mixture"].to(device)
                target = batch["bass"].to(device)
                output = model(mixture)
                loss = criterion(output, target)
                val_loss += loss.item()
                val_batches += 1

        # max(..., 1) avoids a ZeroDivisionError when a loader yields no batches.
        print(
            f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/max(train_batches, 1):.4f}, Val Loss: {val_loss/max(val_batches, 1):.4f}"
        )

In [10]:
def si_sdr_loss(
    estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """
    Compute the negative SI-SDR (Scale-Invariant Signal-to-Distortion Ratio)
    loss for a batch of signals.

    All dimensions after the first (batch) dimension are flattened and treated
    as one signal per batch item, so inputs of shape (batch, time),
    (batch, 1, time) or e.g. (batch, 1, H, W) are all accepted. A 1-D tensor is
    treated as a batch containing a single signal.

    Args:
        estimate (torch.Tensor): Estimated signal.
        target (torch.Tensor): Reference signal with the same number of
            elements per batch item as ``estimate``.
        eps (float): Small constant that guards the divisions and the log.

    Returns:
        torch.Tensor: Scalar tensor with the negated mean SI-SDR in dB
        (minimising it maximises SI-SDR).
    """
    # Collapse every non-batch dimension into a single signal axis. Reducing
    # over dim=1 without flattening would, for a (batch, 1, ...) tensor,
    # average over the size-1 channel axis and zero out the whole signal.
    estimate = (
        estimate.reshape(1, -1) if estimate.dim() < 2 else estimate.flatten(start_dim=1)
    )
    target = target.reshape(1, -1) if target.dim() < 2 else target.flatten(start_dim=1)

    # Remove the mean of each signal so the metric ignores constant (DC) offsets.
    estimate = estimate - torch.mean(estimate, dim=1, keepdim=True)
    target = target - torch.mean(target, dim=1, keepdim=True)

    # Optimal scale for projecting the estimate onto the target; this is what
    # makes the metric invariant to the gain of the estimate.
    scale = torch.sum(estimate * target, dim=1, keepdim=True) / (
        torch.sum(target**2, dim=1, keepdim=True) + eps
    )
    projection = scale * target

    # Everything in the estimate that is not explained by the projection.
    noise = estimate - projection

    # Energy ratio between the projection and the residual noise.
    ratio = torch.sum(projection**2, dim=1) / (torch.sum(noise**2, dim=1) + eps)

    # SI-SDR in dB (eps keeps the logarithm finite when the ratio is zero).
    si_sdr = 10 * torch.log10(ratio + eps)

    # SI-SDR should be maximised, so the loss is its negated batch mean.
    loss = -torch.mean(si_sdr)
    return loss

In [11]:
# Build the network on the selected device and train it with the SI-SDR loss.
model = UNet()
model.to(device)
criterion = si_sdr_loss
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10,
)

NotImplementedError: Module [UNet] is missing the required "forward" function