In [1]:
import pandas
import numpy 
import torch

In [18]:
pip install tqdm

Collecting tqdm
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Downloading tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1
Note: you may need to restart the kernel to use updated packages.


In [19]:
import tqdm

In [2]:
import torchaudio

In [6]:
pip install torchcodec

Collecting torchcodec
  Downloading torchcodec-0.8.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (9.7 kB)
Downloading torchcodec-0.8.1-cp313-cp313-manylinux_2_28_x86_64.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m5.9 MB/s[0m  [33m0:00:00[0mm eta [36m0:00:01[0m
[?25hInstalling collected packages: torchcodec
Successfully installed torchcodec-0.8.1
Note: you may need to restart the kernel to use updated packages.


In [9]:
pip install soundfile

Collecting soundfile
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting cffi>=1.0 (from soundfile)
  Downloading cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.6 kB)
Collecting pycparser (from cffi>=1.0->soundfile)
  Downloading pycparser-2.23-py3-none-any.whl.metadata (993 bytes)
Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.1 MB/s[0m  [33m0:00:00[0m
[?25hDownloading cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (219 kB)
Downloading pycparser-2.23-py3-none-any.whl (118 kB)
Installing collected packages: pycparser, cffi, soundfile
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [soundfile]/3[0m [cffi]
[1A[2KSuccessfully installed cffi-2.0.0 pycparser-2.23 soundfile-0.13.1
Note: you may need to restart the kernel to use updated packages.


# I - Conversion des données audio en spectrogrammes Mel

In [12]:
import os
import torch
import soundfile as sf
import torchaudio
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
import torch.nn.functional as F
import numpy as np

# -----------------------
# DATASET PYTORCH WAV SANS TORCHAUDIO.LOAD
# -----------------------
class MelDataset(torch.utils.data.Dataset):
    """Normalised log-mel segments computed on the fly from WAV files.

    The constructor scans every immediate subfolder of ``root`` for ``.wav``
    files (one directory level only), then makes a full pass over the audio
    to compute a global per-mel-band mean and standard deviation.

    Args:
        root: Directory whose subfolders contain the ``.wav`` files.
        segment_frames: Number of spectrogram frames in each returned item.
        sample_rate: Target sample rate; files at another rate are resampled.
        n_mels: Number of mel bands.

    Raises:
        RuntimeError: If no ``.wav`` file is found, or if a file is empty.
    """

    def __init__(self, root, segment_frames=200, sample_rate=22050, n_mels=80):
        self.root = root
        self.segment_frames = segment_frames
        self.sample_rate = sample_rate

        self.mel = MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=1024,
            hop_length=256,
            win_length=1024,
            n_mels=n_mels
        )
        self.db = AmplitudeToDB()

        # Collect every wav file. os.listdir returns entries in arbitrary
        # order, so both levels are sorted to keep the index -> file mapping
        # identical from one run (or machine) to the next.
        self.wav_files = []
        for subfolder in sorted(os.listdir(root)):
            sub_path = os.path.join(root, subfolder)
            if os.path.isdir(sub_path):
                for f in sorted(os.listdir(sub_path)):
                    if f.endswith(".wav"):
                        self.wav_files.append(os.path.join(sub_path, f))

        if len(self.wav_files) == 0:
            raise RuntimeError("Aucun fichier .wav trouvé dans root.")

        self._compute_statistics()

    # -----------------------
    # AUDIO LOADING WITH SOUNDFILE
    # -----------------------
    def _load_wav(self, path):
        """Read ``path`` and return a mono, peak-normalised (1, N) float32 tensor."""
        wav, sr = sf.read(path)            # numpy array, (N,) or (N, channels)
        if wav.ndim == 2:                  # multi-channel -> mono
            wav = wav.mean(axis=1)
        wav = wav.astype(np.float32)

        wav = torch.from_numpy(wav).unsqueeze(0)  # (1, N)

        if wav.numel() == 0:
            # Fail with an explicit message instead of an obscure reduction
            # error further down.
            raise RuntimeError(f"Fichier audio vide : {path}")

        # Resample only when the file rate differs from the target rate.
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)

        return self._peak_normalize(wav)

    @staticmethod
    def _peak_normalize(wav):
        """Scale ``wav`` so its largest absolute sample is 1.

        An all-zero (silent) signal is returned unchanged: dividing by a zero
        peak would turn every sample into NaN.
        """
        peak = wav.abs().max()
        if peak > 0:
            wav = wav / peak
        return wav

    def _wav_to_mel(self, wav):
        """Convert a (1, N) waveform into a (frames, mels) dB mel spectrogram."""
        mel = self.mel(wav)
        mel_db = self.db(mel)
        return mel_db.squeeze(0).transpose(0, 1)  # (frames, mels)

    def _compute_statistics(self):
        """Compute the global per-band mean and std over every file.

        Running sums are accumulated in float64 so that only one spectrogram
        is held in memory at a time. The result matches the mean and the
        unbiased std of all frames concatenated together.
        """
        print("Calcul des statistiques globales...")
        total = None
        total_sq = None
        n_frames = 0
        for path in self.wav_files:
            mel = self._wav_to_mel(self._load_wav(path)).to(torch.float64)
            band_sum = mel.sum(dim=0)
            band_sq_sum = (mel * mel).sum(dim=0)
            if total is None:
                total, total_sq = band_sum, band_sq_sum
            else:
                total = total + band_sum
                total_sq = total_sq + band_sq_sum
            n_frames += mel.shape[0]

        if n_frames == 0:
            raise RuntimeError("Aucun fichier .wav trouvé dans root.")

        mean = total / n_frames
        # Unbiased variance (n - 1 denominator), like torch.std's default.
        # The clamp absorbs tiny negative values caused by rounding.
        var = (total_sq - n_frames * mean * mean) / max(n_frames - 1, 1)
        self.mean = mean.to(torch.float32)
        self.std = var.clamp_min(0).sqrt().to(torch.float32)
        print("done.")

    def __len__(self):
        """Number of wav files found under ``root``."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Return a random (segment_frames, mels) crop of file ``idx``.

        The spectrogram is normalised with the global statistics. Clips
        shorter than ``segment_frames`` are zero-padded at the end, i.e.
        padded with the per-band mean in normalised space. The crop position
        is drawn from torch's global RNG, so repeated calls differ unless the
        RNG is seeded.
        """
        wav = self._load_wav(self.wav_files[idx])
        mel = self._wav_to_mel(wav)

        mel = (mel - self.mean) / (self.std + 1e-6)

        if mel.shape[0] < self.segment_frames:
            pad = self.segment_frames - mel.shape[0]
            mel = F.pad(mel, (0, 0, 0, pad))  # pad the frame axis only

        max_start = mel.shape[0] - self.segment_frames
        start = torch.randint(0, max_start + 1, (1,)).item()
        return mel[start:start + self.segment_frames]


In [13]:
# Build the on-the-fly mel dataset. Construction scans the WAV tree and runs
# the global statistics pass, so this is the slow step of the notebook.
wav_root = "/home/onyxia/Dynamical-Variational-Autoencoders/data/data_wav"
dataset = MelDataset(wav_root, segment_frames=200, sample_rate=22050, n_mels=80)


Calcul des statistiques globales...
done.


In [15]:
# Number of wav files found under the root (one dataset item per file).
len(dataset)

2922

In [21]:
# Flat export: one random crop per clip, written as mel_<index>.pt directly
# inside save_dir.
save_dir = "/home/onyxia/Dynamical-Variational-Autoencoders/data/mels_saved"
os.makedirs(save_dir, exist_ok=True)

n_items = len(dataset)
for i in tqdm.tqdm(range(n_items)):
    target_path = os.path.join(save_dir, f"mel_{i}.pt")
    mel = dataset[i]  # (200, 80) with the settings used above
    torch.save(mel, target_path)




100%|██████████| 2922/2922 [03:31<00:00, 13.81it/s]


In [26]:

# Root directory that will hold the shard folders part_0, part_1, part_2, ...
root_save = "/home/onyxia/Dynamical-Variational-Autoencoders/data/mels_saved"
os.makedirs(root_save, exist_ok=True)

MAX_FILES_PER_FOLDER = 1000

# part_0 is created up front so it exists even if the dataset is empty.
folder_idx = 0
file_idx_in_folder = 0
current_folder = os.path.join(root_save, f"part_{folder_idx}")
os.makedirs(current_folder, exist_ok=True)

for i in tqdm.tqdm(range(len(dataset))):
    # The shard number and the position inside it follow directly from i.
    folder_idx, file_idx_in_folder = divmod(i, MAX_FILES_PER_FOLDER)
    current_folder = os.path.join(root_save, f"part_{folder_idx}")

    # First file of a shard: make sure its folder exists.
    if file_idx_in_folder == 0:
        os.makedirs(current_folder, exist_ok=True)

    mel = dataset[i]
    save_path = os.path.join(current_folder, f"mel_{i}.pt")
    torch.save(mel, save_path)

print("Sauvegarde terminée.")


100%|██████████| 2922/2922 [03:30<00:00, 13.90it/s]

Sauvegarde terminée.





# II- Ouverture du dataset

In [27]:
class SavedMelDataset(torch.utils.data.Dataset):
    """Dataset over pre-computed mel tensors stored as ``.pt`` files.

    Only the immediate subfolders of ``root`` (``part_0``, ``part_1``, ...)
    are scanned; ``.pt`` files lying directly in ``root`` are ignored.

    Files are ordered by the numeric index in their name (``mel_<index>.pt``),
    so item ``i`` is the file saved with index ``i`` when the indices are
    contiguous from 0. A plain string sort would place ``mel_10.pt`` before
    ``mel_2.pt``.

    Args:
        root: Directory containing the shard subfolders.

    Raises:
        RuntimeError: If no ``.pt`` file is found in any subfolder.
    """

    def __init__(self, root):
        self.root = root

        self.files = []

        # Walk part_0, part_1, etc. and keep the full path of each tensor.
        for folder in os.listdir(root):
            p = os.path.join(root, folder)
            if os.path.isdir(p):
                for f in os.listdir(p):
                    if f.endswith(".pt"):
                        self.files.append(os.path.join(p, f))

        # Deterministic order that follows the saved numeric index.
        self.files.sort(key=self._sort_key)

        if len(self.files) == 0:
            raise RuntimeError("Aucun fichier .pt trouvé dans les folders.")

    @staticmethod
    def _sort_key(path):
        """Sort by the trailing ``_<int>`` of the file stem, then by path.

        Names without a numeric suffix sort after all numbered files, in
        path order.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem.rsplit("_", 1)[-1]
        if suffix.isascii() and suffix.isdigit():
            index = int(suffix)
        else:
            index = float("inf")
        return (index, path)

    def __len__(self):
        """Number of ``.pt`` files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load and return the tensor stored in file number ``idx``."""
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered file cannot run arbitrary code on load.
        return torch.load(self.files[idx], weights_only=True)


In [28]:
# Reload the exported tensors and check the item count and the shape of one item.
saved_root = "/home/onyxia/Dynamical-Variational-Autoencoders/data/mels_saved"
dataset2 = SavedMelDataset(saved_root)

print(len(dataset2))
first_item = dataset2[0]
print(first_item.shape)


2922
torch.Size([200, 80])
