In [1]:
import torch
import torch.nn as nn
import librosa
import numpy as np
import soundfile as sf
from pathlib import Path
import matplotlib.pyplot as plt
import torchaudio
import os

In [2]:
# --- Configuration -----------------------------------------------------------

# Audio / spectrogram settings
SAMPLE_RATE = 44100  # sampling rate
N_FFT = 2048         # FFT size
HOP_LENGTH = 512     # hop between successive STFT frames
N_MELS = 128         # number of mel bands

# Compute device: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Paths and source names
MODEL_PATH = "../results/models/model.pth"  # path to the saved model checkpoint
SOURCES = ["vocals", "drums", "bass", "other"]  # names of the sources to separate
OUTPUT_DIR = "../outputs/"  # folder where results are saved

In [3]:
class SpectrogramSeparator(nn.Module):
    """CNN + Transformer encoder/decoder mapping one mel spectrogram to one per source.

    The input ``(B, n_mels, T)`` goes through a small 2-D convolutional stack,
    each time frame is projected to ``d_model`` and encoded by a Transformer
    encoder. A Transformer decoder then attends to that memory with one set of
    learnable queries per source, and a linear + ReLU head maps each decoded
    frame back to ``n_mels`` values.

    Args:
        n_mels: Number of mel bands of the input (and of every output).
        seq_len: Maximum number of time frames; sizes the learnable positional
            encoding and the per-source queries.
        n_sources: Number of output spectrograms.
        d_model: Transformer embedding size.
        nhead: Number of attention heads.
        num_layers: Number of layers in both the encoder and the decoder.

    Note:
        Attribute names and parameter shapes are part of the checkpoint format
        (``load_state_dict``); do not rename them.
    """

    def __init__(self, n_mels=128, seq_len=800, n_sources=4, d_model=256, nhead=8, num_layers=4):
        super().__init__()
        self.n_sources = n_sources
        self.seq_len = seq_len
        self.n_mels = n_mels

        # Convolutional encoder (3x3 convs with padding=1 keep the spatial size)
        self.conv_encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # Projection of each frame (128 channels x n_mels bins) to d_model
        self.flatten_proj = nn.Linear(128 * n_mels, d_model)

        # Learnable positional encoding, one vector per time frame
        self.pos_encoding = nn.Parameter(torch.randn(seq_len, d_model))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Transformer decoder
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Source-specific queries (learnable), one sequence of queries per source
        self.source_queries = nn.Parameter(torch.randn(n_sources, seq_len, d_model))

        # Output projection; the ReLU keeps every predicted value non-negative
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, n_mels),
            nn.ReLU()
        )

    def forward(self, x):
        """Separate a batch of mel spectrograms.

        Args:
            x: Tensor of shape ``(B, n_mels, T)`` with ``T <= seq_len``.

        Returns:
            Tensor of shape ``(B, n_sources, n_mels, T)``.

        Raises:
            ValueError: If ``x`` is not 3-D, its mel dimension differs from
                ``n_mels``, or ``T`` exceeds ``seq_len``.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, n_mels, time), got {tuple(x.shape)}"
            )
        B, M, T = x.shape
        if M != self.n_mels:
            raise ValueError(f"expected {self.n_mels} mel bands, got {M}")
        if T > self.seq_len:
            raise ValueError(
                f"input has {T} frames but the model supports at most {self.seq_len}"
            )

        # Add channel dimension and apply CNN
        x = x.unsqueeze(1)  # (B, 1, M, T)
        x = self.conv_encoder(x)  # (B, 128, M, T)

        # Prepare for transformer: one token per time frame
        x = x.permute(0, 3, 1, 2)  # (B, T, C, M)
        x = x.reshape(B, T, -1)  # (B, T, C*M)
        x = self.flatten_proj(x)  # (B, T, d_model)

        # Add positional encoding
        x = x + self.pos_encoding[:T]

        # Transformer encoder
        memory = self.transformer_encoder(x)  # (B, T, d_model)

        # Prepare source queries. They are cut to the first T frames, exactly
        # like the positional encoding, so inputs shorter than seq_len work.
        S = self.source_queries.shape[0]
        queries = self.source_queries[:, :T].unsqueeze(0).expand(B, -1, -1, -1)  # (B, S, T, d_model)
        queries = queries.reshape(B * S, T, -1)  # (B*S, T, d_model)

        # Expand memory for each source
        memory = memory.unsqueeze(1).expand(-1, S, -1, -1)  # (B, S, T, d_model)
        memory = memory.reshape(B * S, T, -1)  # (B*S, T, d_model)

        # Transformer decoder
        output = self.transformer_decoder(queries, memory)  # (B*S, T, d_model)

        # Project to mel spectrum
        output = self.output_proj(output)  # (B*S, T, n_mels)

        # Reshape to final format
        output = output.reshape(B, S, T, self.n_mels)  # (B, S, T, n_mels)
        output = output.permute(0, 1, 3, 2)  # (B, S, n_mels, T)

        return output

In [4]:
# Mel-spectrogram transform, built once and moved to the configured device
# (DEVICE falls back to "cpu" when CUDA is unavailable, so this cell no longer
# crashes on CPU-only machines).
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS
).to(DEVICE)

# dB conversion, created once instead of on every call.
_amplitude_to_db = torchaudio.transforms.AmplitudeToDB()


def audio_to_mel_gpu(audio_np):
    """Compute the dB-scaled mel spectrogram of a waveform on ``DEVICE``.

    Despite its name, the function runs on whichever device ``DEVICE`` selects.

    Args:
        audio_np: Waveform as a NumPy array. A 1-D array is treated as a
            single channel.

    Returns:
        np.ndarray: The mel spectrogram in dB, moved back to the CPU. For a
        1-D input the leading channel axis is squeezed away, giving
        ``(n_mels, n_frames)``.
    """
    # Convert the NumPy array to a float32 tensor on the same device as the transform
    audio_tensor = torch.as_tensor(audio_np, dtype=torch.float32, device=DEVICE)

    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)  # Add channel dimension

    mel = mel_transform(audio_tensor)
    mel_db = _amplitude_to_db(mel)
    return mel_db.squeeze(0).cpu().numpy()  # Convert back to NumPy

In [5]:
def audio_to_spectrogram(file_path, window_size=800, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS, device=DEVICE, pad_remainder=False):
    """
    Load an audio file, compute its mono mel spectrogram and split it into windows.

    Files ending in ``.stem.mp4`` are read with ``stempeg`` (only the first
    stem is used); if that fails for any reason, or for any other file, the
    audio is loaded with ``librosa`` at ``SAMPLE_RATE`` in mono.

    Args:
        file_path (str or path-like): Path to the .mp4 / .stem.mp4 file.
        window_size (int): Number of spectrogram frames per window (default: 800).
        n_fft (int): Kept for backward compatibility. NOTE: not used here; the
            spectrogram is computed by ``audio_to_mel_gpu`` with its own settings.
        hop_length (int): Kept for backward compatibility; not used here (see above).
        n_mels (int): Kept for backward compatibility; not used here (see above).
        device (str): Only printed for information; the computation device is
            the one used by ``audio_to_mel_gpu``.
        pad_remainder (bool): When False (default) trailing frames that do not
            fill a whole window are dropped. When True they are zero-padded
            into one extra window so no audio is lost.

    Returns:
        np.ndarray: Array of shape ``(num_windows, n_mels, window_size)``.
        A spectrogram shorter than ``window_size`` is always zero-padded into
        a single window.

    Raises:
        ValueError: If ``window_size`` is not a positive integer.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")

    # Accept pathlib.Path as well as str (the suffix test below needs a string)
    file_path = str(file_path)

    print(f"Traitement du fichier: {file_path}")
    print(f"Utilisation de l'appareil: {device}")

    # Decide whether this is a stem file or a regular audio file
    is_stem = file_path.endswith('.stem.mp4')

    if is_stem:
        # Read the stems with stempeg
        print("Lecture des stems...")
        try:
            # Optional dependency, only needed for .stem.mp4 inputs. Imported
            # here so that a missing package triggers the librosa fallback
            # below instead of a NameError.
            import stempeg

            audio_data, sample_rate = stempeg.read_stems(file_path)
            # Use only the mixture (first stem)
            audio = np.asarray(audio_data[0])

            # stempeg documents its output as (stems, samples, channels), so a
            # single stem is (samples, channels): average over the LAST axis.
            if audio.ndim > 1:
                print(f"Conversion audio stéréo en mono (forme initiale: {audio.shape})")
                audio = audio.mean(axis=-1)

            # The mel transform is configured for SAMPLE_RATE; resample if the
            # stem file uses another rate.
            if sample_rate != SAMPLE_RATE:
                audio = librosa.resample(
                    audio.astype(np.float32), orig_sr=sample_rate, target_sr=SAMPLE_RATE
                )
                sample_rate = SAMPLE_RATE
        except Exception as e:
            # Deliberate best-effort: fall back to a plain audio read
            print(f"Erreur lors de la lecture des stems: {e}")
            print("Tentative de lecture comme fichier audio normal...")
            audio, sample_rate = librosa.load(file_path, sr=SAMPLE_RATE, mono=True)
    else:
        # Read as a regular audio file with librosa (already mono)
        print("Lecture du fichier audio normal...")
        audio, sample_rate = librosa.load(file_path, sr=SAMPLE_RATE, mono=True)

    print(f"Forme audio après conversion mono: {audio.shape}, SR: {sample_rate}")

    # Compute the mel spectrogram
    print("Calcul du spectrogramme mel...")
    mel_spec_db = audio_to_mel_gpu(audio)
    print(f"Forme du spectrogramme: {mel_spec_db.shape}")

    # Split into windows
    print(f"Division en fenêtres de taille {window_size}...")
    num_frames = mel_spec_db.shape[1]
    num_windows, remainder = divmod(num_frames, window_size)

    if num_windows == 0:
        # Shorter than one window: zero-pad up to exactly one window
        print(f"ATTENTION: Spectrogramme trop court pour créer des fenêtres de taille {window_size}")
        padding = window_size - num_frames
        mel_spec_db_padded = np.pad(mel_spec_db, ((0, 0), (0, padding)), mode='constant')
        windows = np.stack([mel_spec_db_padded])
        print(f"Spectrogramme padded à la taille {mel_spec_db_padded.shape}, 1 fenêtre créée")
        return windows

    if pad_remainder and remainder:
        # Keep the tail by padding it to a full window (same zero padding as
        # the short-input case above)
        mel_spec_db = np.pad(
            mel_spec_db, ((0, 0), (0, window_size - remainder)), mode='constant'
        )
        num_windows += 1

    # Truncate to a multiple of window_size, then cut along the time axis
    mel_spec_db_truncated = mel_spec_db[:, :num_windows * window_size]
    windows = np.stack(np.split(mel_spec_db_truncated, num_windows, axis=1))
    print(f"Nombre de fenêtres créées: {len(windows)}, forme: {windows.shape}")

    return windows

In [6]:
# Input track and output folder for this run
file_path = "../inputs/mysong.mp4"
output_dir = "../outputs/"

# Turn the file into a mel spectrogram and cut it into windows of 800 frames
windows = audio_to_spectrogram(
    file_path,
    window_size=800,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    device=DEVICE,
)



Traitement du fichier: ../inputs/mysong.mp4
Utilisation de l'appareil: cuda
Lecture du fichier audio normal...


  audio, sample_rate = librosa.load(file_path, sr=SAMPLE_RATE, mono=True)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Forme audio après conversion mono: (9604096,), SR: 44100
Calcul du spectrogramme mel...
Forme du spectrogramme: (128, 18759)
Division en fenêtres de taille 800...
Nombre de fenêtres créées: 23, forme: (23, 128, 800)


In [7]:
# MODEL_PATH comes from the configuration cell (it was redefined here with the
# same value).
#
# map_location makes a checkpoint saved on a GPU loadable on a CPU-only machine.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

model = SpectrogramSeparator()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
# Inference mode for BatchNorm / dropout; as the last expression it also
# displays the module summary.
model.eval()

SpectrogramSeparator(
  (conv_encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (flatten_proj): Linear(in_features=16384, out_features=256, bias=True)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias

In [8]:
# Run the model on every window. This is inference only, so gradient tracking
# is disabled: otherwise each prediction keeps its whole autograd graph alive.
y = []
with torch.no_grad():
    for window in windows:
        x = torch.from_numpy(window)
        x = torch.unsqueeze(x, 0).to(DEVICE)  # add the batch dimension
        y_pred = model(x)  # call the module itself rather than .forward()
        y.append(y_pred)

In [9]:
# Bring every prediction back to the CPU, then join them along the first
# (window) axis into a single NumPy array.
y_cpu = [prediction.detach().cpu() for prediction in y]
y_numpy = torch.cat(y_cpu, dim=0).numpy()

In [10]:
# y_numpy is (num_windows, num_sources, num_mels, num_frames). Read the sizes
# from the array instead of hard-coding 23 / 4 / 128 / 800, so the cell works
# for any track length or model configuration.
num_windows, num_sources, num_mels, num_frames = y_numpy.shape

# Put the window axis next to the time axis: (num_sources, num_mels, num_windows, num_frames)
transposed = np.transpose(y_numpy, (1, 2, 0, 3))
# Concatenate the windows along time: (num_sources, num_mels, num_windows * num_frames)
merged = transposed.reshape(num_sources, num_mels, num_windows * num_frames)

In [11]:
# Display the shape of the merged array as a quick check.
merged.shape

(4, 128, 18400)

In [21]:
import numpy as np
import IPython.display as ipd
import io
import soundfile as sf


def mel_db_to_audio(mel_db, sr=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH):
    """Estimate a waveform from a dB-scaled mel spectrogram.

    ``merged[i]`` is a ``(n_mels, n_frames)`` spectrogram, not a waveform:
    handing it straight to ``ipd.Audio`` makes every mel band be read as an
    audio channel. It has to be inverted back to the time domain first.

    Args:
        mel_db: Array of shape ``(n_mels, n_frames)``.
        sr: Sampling rate used when the spectrogram was computed.
        n_fft: FFT size used when the spectrogram was computed.
        hop_length: Hop length used when the spectrogram was computed.

    Returns:
        np.ndarray: 1-D waveform estimated with Griffin-Lim (the phase is
        estimated, so the result is an approximation).
    """
    # NOTE(review): assumes the model predicts power-mel values in dB, i.e. the
    # same scale as its input. TODO confirm against the training targets.
    mel_power = librosa.db_to_power(np.asarray(mel_db, dtype=np.float32))
    # htk=True / norm=None select the HTK, un-normalised mel filter bank, which
    # is torchaudio's MelSpectrogram default (librosa defaults to Slaney).
    return librosa.feature.inverse.mel_to_audio(
        mel_power,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        power=2.0,
        htk=True,
        norm=None,
    )


# Invert the first separated source and play it (this step can be slow on long tracks).
source0_audio = mel_db_to_audio(merged[0])
ipd.Audio(source0_audio, rate=SAMPLE_RATE)