In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm 

class SpectrogramSeparator(nn.Module):
    """CNN + Transformer encoder/decoder that maps a mixture spectrogram to
    `n_sources` non-negative source spectrograms.

    Args:
        n_mels: number of frequency bins of the input (height of the spectrogram).
        seq_len: maximum number of time frames supported. It sizes the learnable
            positional encoding and the per-source decoder queries.
        n_sources: number of sources produced per mixture.
        d_model: transformer embedding width.
        nhead: number of attention heads (must divide d_model).
        num_layers: number of layers in the transformer encoder and in the decoder.

    Input:  (B, n_mels, T) with T <= seq_len.
    Output: (B, n_sources, n_mels, T), values >= 0 (final ReLU).
    """

    def __init__(self, n_mels=128, seq_len=20000, n_sources=4, d_model=256, nhead=8, num_layers=4):
        super().__init__()
        self.n_sources = n_sources
        self.seq_len = seq_len
        self.n_mels = n_mels

        # Convolutional encoder. kernel_size=3 with padding=1 keeps (n_mels, T) unchanged.
        # NOTE(review): 128 channels at full (n_mels, T) resolution means large
        # activations for long inputs. Consider downsampling if memory is tight.
        self.conv_encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # Projection of each time frame's (channels x mels) features to d_model
        self.flatten_proj = nn.Linear(128 * n_mels, d_model)

        # Learnable positional encoding, one vector per time frame (up to seq_len)
        self.pos_encoding = nn.Parameter(torch.randn(seq_len, d_model))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model*4,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model*4,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Source-specific learnable queries: one sequence of seq_len vectors per source
        self.source_queries = nn.Parameter(torch.randn(n_sources, seq_len, d_model))

        # Output projection back to mel bins; ReLU keeps magnitudes non-negative
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, n_mels),
            nn.ReLU()
        )

    def forward(self, x):
        """Separate a batch of mixtures.

        Args:
            x: tensor of shape (B, n_mels, T) with T <= seq_len.
        Returns:
            Tensor of shape (B, n_sources, n_mels, T).
        Raises:
            ValueError: if x is not 3-D, has the wrong number of mel bins, or
                has more frames than the model's seq_len.
        """
        if x.dim() != 3 or x.shape[1] != self.n_mels:
            raise ValueError(
                f"expected input of shape (B, {self.n_mels}, T), got {tuple(x.shape)}"
            )
        B, _, T = x.shape
        if T > self.seq_len:
            raise ValueError(f"input has {T} frames but the model supports at most {self.seq_len}")

        # Add channel dimension and apply the CNN: (B, 1, n_mels, T) -> (B, 128, n_mels, T)
        x = self.conv_encoder(x.unsqueeze(1))

        # One token per time frame: (B, T, C, M) -> (B, T, C*M) -> (B, T, d_model)
        x = x.permute(0, 3, 1, 2)
        x = x.reshape(B, T, -1)
        x = self.flatten_proj(x)

        # Add positional encoding for the first T positions
        x = x + self.pos_encoding[:T]

        memory = self.transformer_encoder(x)  # (B, T, d_model)

        S = self.n_sources
        # Slice the learnable queries to the actual number of frames BEFORE
        # expanding. The parameter holds seq_len positions, so reshaping the full
        # (B, S, seq_len, d_model) tensor to (B*S, T, -1) fails unless T == seq_len.
        queries = self.source_queries[:, :T]  # (S, T, d_model)
        queries = queries.unsqueeze(0).expand(B, -1, -1, -1)  # (B, S, T, d_model)
        queries = queries.reshape(B * S, T, -1)  # (B*S, T, d_model)

        # Every source attends to the same encoded mixture
        memory = memory.unsqueeze(1).expand(-1, S, -1, -1)  # (B, S, T, d_model)
        memory = memory.reshape(B * S, T, -1)  # (B*S, T, d_model)

        output = self.transformer_decoder(queries, memory)  # (B*S, T, d_model)
        output = self.output_proj(output)  # (B*S, T, n_mels)

        output = output.reshape(B, S, T, self.n_mels)  # (B, S, T, n_mels)
        return output.permute(0, 1, 3, 2)  # (B, S, n_mels, T)

In [2]:
import torch
from torch.utils.data import Dataset

class SpectrogramDataset(Dataset):
    """Map-style dataset pairing mixture spectrograms with their source spectrograms.

    Args:
        X (np.ndarray): mixture spectrograms, indexed on the first axis,
            e.g. shape (N, 128, 800).
        y (np.ndarray): source spectrograms aligned with X on the first axis,
            e.g. shape (N, 4, 128, 800).

    Both are stored as float32 tensors. `torch.as_tensor` wraps a float32 NumPy
    array without copying, so the tensor shares memory with the array. Other
    dtypes are converted with a copy. This avoids keeping two full copies of
    large spectrogram arrays alive.

    Raises:
        ValueError: if X and y do not contain the same number of examples.
    """

    def __init__(self, X, y):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.float32)
        # Misaligned inputs would otherwise only surface as an IndexError mid-epoch.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same number of examples, got {len(self.X)} and {len(self.y)}"
            )

    def __len__(self):
        """Number of (mixture, sources) pairs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the (mixture, sources) tensor pair at position `idx`."""
        return self.X[idx], self.y[idx]


In [None]:
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

# Configuration
DATA_DIR = Path('../data/processed')
BATCH_SIZE = 8
SEED = 42  # seeds the training-loader shuffle so the epoch order is reproducible

X_train = np.load(DATA_DIR / 'X_train.npy')
y_train = np.load(DATA_DIR / 'y_train.npy')
X_test = np.load(DATA_DIR / 'X_test.npy')
y_test = np.load(DATA_DIR / 'y_test.npy')

# Fail fast on misaligned splits. The targets are expected to be
# (N, n_sources, *mixture_shape), matching the model output.
assert len(X_train) == len(y_train), f"train split misaligned: {len(X_train)} vs {len(y_train)}"
assert len(X_test) == len(y_test), f"test split misaligned: {len(X_test)} vs {len(y_test)}"
assert y_train.shape[2:] == X_train.shape[1:], f"{y_train.shape} vs {X_train.shape}"
assert y_test.shape[2:] == X_test.shape[1:], f"{y_test.shape} vs {X_test.shape}"

# Wrap the arrays in datasets
train_dataset = SpectrogramDataset(X_train, y_train)
test_dataset = SpectrogramDataset(X_test, y_test)

# Build the DataLoaders. Only the training loader is shuffled. The evaluation
# loader keeps a fixed order so evaluation is deterministic.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [4]:
class SI_SDR_Loss(nn.Module):
    """Negative scale-invariant signal-to-distortion ratio (SI-SDR), in dB.

    Each (batch, source) spectrogram is flattened to a vector and both the
    estimate and the reference are zero-meaned. Then:

        a      = (<est, ref> + eps) / (||ref||^2 + eps)
        SI-SDR = 10 * log10( ||a*ref||^2 / (||est - a*ref||^2 + eps) + eps )

    The loss is the negative mean SI-SDR over all (batch, source) pairs, so
    minimizing it maximizes SI-SDR.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, predictions, targets):
        """
        Args:
            predictions: tensor of shape [B, S, F, T] (batch, sources, freq_bins, time)
            targets: tensor of shape [B, S, F, T]
        Returns:
            Scalar tensor: negative mean SI-SDR (dB).
        """
        eps = self.eps
        B, S, _, _ = predictions.shape
        n_rows = B * S

        # One flat vector per (batch, source) pair, centered on zero
        est = predictions.reshape(n_rows, -1)
        ref = targets.reshape(n_rows, -1)
        est = est - est.mean(dim=-1, keepdim=True)
        ref = ref - ref.mean(dim=-1, keepdim=True)

        # Optimal scaling of the reference onto the estimate
        dot = (est * ref).sum(dim=-1, keepdim=True)
        ref_energy = (ref ** 2).sum(dim=-1, keepdim=True)
        projection = (dot + eps) / (ref_energy + eps) * ref
        residual = est - projection

        ratio = (projection ** 2).sum(dim=-1) / ((residual ** 2).sum(dim=-1) + eps)
        si_sdr_db = 10 * torch.log10(ratio + eps)

        return -si_sdr_db.mean()

In [None]:
from pathlib import Path

# Fall back to CPU so the cell still runs on machines without a GPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_PATH = Path('../results/models/model.pth')

model = SpectrogramSeparator().to(DEVICE)
criterion = SI_SDR_Loss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

num_epochs = 16
best_loss = float('inf')


def train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over `loader` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc=desc, unit="batch")
    for X_batch, y_batch in pbar:
        X_batch = X_batch.to(device, dtype=torch.float32)
        y_batch = y_batch.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        output = model(X_batch)
        loss = criterion(output, y_batch)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({"SI-SDR Loss": f"{loss.item():.4f}"})
    return running_loss / max(len(loader), 1)


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss over `loader`, in eval mode and without gradients."""
    model.eval()
    running_loss = 0.0
    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device, dtype=torch.float32)
        y_batch = y_batch.to(device, dtype=torch.float32)
        running_loss += criterion(model(X_batch), y_batch).item()
    return running_loss / max(len(loader), 1)


# torch.save does not create missing parent directories.
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    avg_loss = train_one_epoch(
        model, train_loader, criterion, optimizer, DEVICE,
        desc=f"Epoch {epoch+1}/{num_epochs}",
    )
    # NOTE(review): model selection uses `test_loader` because it is the only
    # held-out split loaded in this notebook. A dedicated validation split would
    # keep the final test score unbiased.
    val_loss = evaluate(model, test_loader, criterion, DEVICE)
    print(f"Epoch {epoch+1}/{num_epochs}: train loss {avg_loss:.4f} | held-out loss {val_loss:.4f}")

    # Keep the checkpoint with the best held-out loss. Training loss alone is
    # not a reliable signal of generalisation.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
            'train_loss': avg_loss,
        }, CHECKPOINT_PATH)

Epoch 1/16:   0%|                                                                            | 0/100 [00:00<?, ?batch/s]