In [1]:
import pandas
import numpy 
import torch

In [2]:
import torchaudio

In [None]:
import os
import torch
import torchaudio
import numpy as np
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

class VCTKMelDataset(torch.utils.data.Dataset):
    """Random fixed-length, globally normalized log-mel segments from a wav corpus.

    Expects a layout of ``root/<speaker_dir>/<file>.wav`` (one level of
    sub-directories; only files ending in ``.wav`` are used). On construction,
    every file is decoded once to compute per-mel-bin mean/std over all frames
    of the corpus. This can be slow for a large corpus, but memory stays
    O(n_mels) because statistics are accumulated in a streaming fashion.

    Each item is a tensor of shape ``(segment_frames, n_mels)``: a random crop of
    the normalized log-mel spectrogram of one file. Clips shorter than
    ``segment_frames`` are zero-padded at the end.
    """

    def __init__(self, root, segment_frames=200, sample_rate=22050, n_mels=80,
                 n_fft=1024, hop_length=256, win_length=1024):
        """Index the wav files under ``root`` and compute global mel statistics.

        Args:
            root: directory containing one sub-directory per speaker.
            segment_frames: number of mel frames returned per item.
            sample_rate: target sample rate; files are resampled to it if needed.
            n_mels: number of mel bins.
            n_fft, hop_length, win_length: STFT parameters forwarded to
                ``MelSpectrogram`` (defaults match the original hard-coded values).

        Raises:
            RuntimeError: if no ``.wav`` file is found under ``root``.
        """
        self.root = root
        self.segment_frames = segment_frames
        self.sample_rate = sample_rate

        self.mel = MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mels,
        )
        self.db = AmplitudeToDB()

        # Sorted so that the idx -> file mapping is reproducible across runs and
        # machines (os.listdir order is arbitrary).
        self.wav_files = [
            os.path.join(root, spk, f)
            for spk in sorted(os.listdir(root))
            if os.path.isdir(os.path.join(root, spk))
            for f in sorted(os.listdir(os.path.join(root, spk)))
            if f.endswith(".wav")
        ]

        # Kept only for backward compatibility: it was never filled, and the
        # statistics are now accumulated without storing every mel in memory.
        self.all_mels = []
        self._compute_statistics()

    @staticmethod
    def _peak_normalize(wav, eps=1e-8):
        """Scale ``wav`` so its absolute peak is 1; all-zero input stays zero.

        The ``eps`` floor on the divisor avoids 0/0 -> NaN for silent files,
        which would otherwise poison the global mean/std.
        """
        peak = wav.abs().max()
        return wav / peak.clamp_min(eps)

    def _load_wav(self, path):
        """Load ``path`` as a mono, resampled, peak-normalized tensor of shape (1, T)."""
        wav, sr = torchaudio.load(path)
        wav = wav.mean(dim=0, keepdim=True)  # down-mix all channels to mono
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        return self._peak_normalize(wav)

    def _wav_to_mel(self, wav):
        """Convert a (1, T) waveform into a dB mel spectrogram of shape (frames, n_mels)."""
        mel = self.mel(wav)
        mel_db = self.db(mel)
        return mel_db.squeeze(0).transpose(0, 1)  # (frames, mels)

    @staticmethod
    def _streaming_mean_std(mels):
        """Return per-column (mean, unbiased std) over the rows of all ``mels``.

        Equivalent to ``cat = torch.cat(list(mels)); cat.mean(0), cat.std(0)``,
        but merges per-chunk statistics one chunk at a time (Chan et al.'s
        parallel variance combination) in float64, so memory does not grow
        with the corpus size. Chunks with zero rows are ignored. With a single
        row in total, the std is 0 instead of NaN.

        Args:
            mels: iterable of 2-D tensors (rows, n_cols) sharing the same n_cols.

        Returns:
            (mean, std), each of shape (n_cols,), in the dtype of the first
            non-empty chunk.

        Raises:
            RuntimeError: if ``mels`` contains no rows at all.
        """
        count = 0
        mean = m2 = None
        out_dtype = None
        for chunk in mels:
            n_b = chunk.shape[0]
            if n_b == 0:
                continue
            if out_dtype is None:
                out_dtype = chunk.dtype
            chunk64 = chunk.to(torch.float64)
            mean_b = chunk64.mean(dim=0)
            m2_b = ((chunk64 - mean_b) ** 2).sum(dim=0)
            if mean is None:
                count, mean, m2 = n_b, mean_b, m2_b
                continue
            total = count + n_b
            delta = mean_b - mean
            mean = mean + delta * (n_b / total)
            m2 = m2 + m2_b + delta ** 2 * (count * n_b / total)
            count = total

        if mean is None:
            raise RuntimeError("No mel frames available: cannot compute statistics.")
        std = torch.sqrt(m2 / max(count - 1, 1))
        return mean.to(out_dtype), std.to(out_dtype)

    def _compute_statistics(self):
        """Set ``self.mean`` / ``self.std`` (shape (n_mels,)) over every frame of every file."""
        if not self.wav_files:
            raise RuntimeError(f"No .wav files found under {self.root!r}")
        print("Calcul des statistiques globales…")
        # Generator: only one file's mel is alive at a time.
        mels = (self._wav_to_mel(self._load_wav(path)) for path in self.wav_files)
        self.mean, self.std = self._streaming_mean_std(mels)
        print("done.")

    def __len__(self):
        """Number of indexed wav files (one item per file)."""
        return len(self.wav_files)

    def __getitem__(self, idx):
        """Return a random ``(segment_frames, n_mels)`` crop of file ``idx``'s normalized mel."""
        wav = self._load_wav(self.wav_files[idx])
        mel = self._wav_to_mel(wav)

        # Per-bin standardization; the epsilon guards bins with zero variance.
        mel = (mel - self.mean) / (self.std + 1e-6)

        # Clips shorter than one segment are zero-padded at the end (not
        # skipped); after normalization, 0 corresponds to the per-bin mean.
        n_frames = mel.shape[0]
        if n_frames < self.segment_frames:
            pad = self.segment_frames - n_frames
            mel = torch.nn.functional.pad(mel, (0, 0, 0, pad))

        # Uniformly random start so that the crop fits entirely inside the clip.
        max_start = mel.shape[0] - self.segment_frames
        start = torch.randint(0, max_start + 1, (1,)).item()
        return mel[start:start + self.segment_frames]  # (segment_frames, n_mels)
