In [5]:
import lightning as L
import torch
import torchaudio
import librosa
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk
from pathlib import Path
import glob


from utils import *

In [4]:
# Quick inventory of the WaveFake files on disk, split by the two sub-trees
# that are later used as the bonafide and fake classes.
base_dir = Path("/data1/datasets/wavefake/")

# wavs16/*.wav is collected as the bonafide list.
bonafide_pattern = base_dir / "wavs16" / "*.wav"
# generated_audio/*/*.wav (one directory level deeper) is collected as the fake list.
fake_pattern = base_dir / "generated_audio" / "*" / "*.wav"

bonafide_list = glob.glob(str(bonafide_pattern))
fake_list = glob.glob(str(fake_pattern))

# For each list: how many files matched, then a peek at the first ten paths.
for found in (bonafide_list, fake_list):
    print(len(found))
    print(found[:10])

13100
['/data1/datasets/wavefake/wavs16/LJ015-0284.wav', '/data1/datasets/wavefake/wavs16/LJ020-0055.wav', '/data1/datasets/wavefake/wavs16/LJ004-0165.wav', '/data1/datasets/wavefake/wavs16/LJ019-0335.wav', '/data1/datasets/wavefake/wavs16/LJ050-0235.wav', '/data1/datasets/wavefake/wavs16/LJ009-0201.wav', '/data1/datasets/wavefake/wavs16/LJ013-0206.wav', '/data1/datasets/wavefake/wavs16/LJ031-0042.wav', '/data1/datasets/wavefake/wavs16/LJ048-0170.wav', '/data1/datasets/wavefake/wavs16/LJ043-0034.wav']
117983
['/data1/datasets/wavefake/generated_audio/common_voices_prompts_from_conformer_fastspeech2_pwg_ljspeech/gen_6128.wav', '/data1/datasets/wavefake/generated_audio/common_voices_prompts_from_conformer_fastspeech2_pwg_ljspeech/gen_13352.wav', '/data1/datasets/wavefake/generated_audio/common_voices_prompts_from_conformer_fastspeech2_pwg_ljspeech/gen_10907.wav', '/data1/datasets/wavefake/generated_audio/common_voices_prompts_from_conformer_fastspeech2_pwg_ljspeech/gen_1266.wav', '/data1

In [9]:
# Root of the WaveFake corpus. Kept as a Path (not a plain string) so that
# code joining sub-directories with the "/" operator works when this value is
# passed around directly.
base_dir = Path("/data1/datasets/wavefake/")


class WaveFakeDataset(Dataset):
    """WaveFake audio dataset yielding ``(waveform, label)`` pairs.

    Files matching ``<base_dir>/wavs16/*.wav`` are labelled ``1`` (bonafide);
    files matching ``<base_dir>/generated_audio/*/*.wav`` are labelled ``0``
    (fake). Bonafide entries come first in the index, followed by fake ones.
    """

    def __init__(self, base_dir, pad_mode="random", max_len=64000):
        """
        Args:
            base_dir: Root directory of the corpus, as a ``str`` or ``Path``.
            pad_mode: ``"random"`` selects ``pad_random``; any other value
                selects ``pad``. Both helpers are not defined in this
                notebook and are expected to come from ``from utils import *``.
            max_len: Length argument forwarded to the padding helper for
                every sample.
        """
        # Coerce to Path so callers may pass a plain string; parse_protocol
        # builds its glob patterns with the "/" operator.
        self.base_dir = Path(base_dir)
        self.wav_paths = []
        self.labels = []
        self.pad = pad_random if pad_mode == "random" else pad
        self.max_len = max_len
        self.parse_protocol()

    def parse_protocol(self):
        """Scan ``base_dir`` and (re)build ``wav_paths`` and ``labels``.

        The lists are rebuilt from scratch on every call, so calling this
        more than once does not duplicate entries.
        """
        root = Path(self.base_dir)
        # glob returns entries in filesystem order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        bonafide_paths = sorted(glob.glob(str(root / "wavs16" / "*.wav")))
        fake_paths = sorted(
            glob.glob(str(root / "generated_audio" / "*" / "*.wav"))
        )
        self.wav_paths = bonafide_paths + fake_paths
        self.labels = [1] * len(bonafide_paths) + [0] * len(fake_paths)

    def __len__(self):
        """Return the number of indexed audio files (bonafide + fake)."""
        return len(self.wav_paths)

    def __getitem__(self, index):
        """Load one file and return ``(x, y)``.

        Returns:
            x: Tensor built from the padding helper's output for the waveform
               loaded at 16 kHz. (Presumably of length ``max_len``; the batch
               shape recorded in this notebook is consistent with that, but
               the helper itself is not visible here.)
            y: 0-dim ``long`` tensor holding the label (1 bonafide, 0 fake).
        """
        path = self.wav_paths[index]
        # librosa resamples to the requested rate on load; the returned
        # sample rate is therefore not needed.
        waveform, _ = librosa.load(path, sr=16000)
        x = torch.Tensor(self.pad(waveform, self.max_len))
        # Labels are only ever 0 or 1 (see parse_protocol), so the stored
        # value can be turned into a scalar tensor directly.
        y = torch.tensor(self.labels[index], dtype=torch.long)
        return x, y


class WaveFake(L.LightningDataModule):
    """Lightning datamodule exposing WaveFakeDataset via train/val/test loaders.

    NOTE(review): the three datasets are all built from the same ``base_dir``
    with no partitioning, so train, val and test enumerate the same files;
    they differ only in the padding mode ("random" for train/val, "normal"
    for test). Confirm this is intended before using val/test metrics.
    """

    def __init__(self, base_dir, max_len=64000, **dataloaderArgs):
        """
        Args:
            base_dir: Corpus root; converted to ``Path`` here.
            max_len: Forwarded to every ``WaveFakeDataset``.
            **dataloaderArgs: Extra keyword arguments passed unchanged to
                each ``DataLoader`` (e.g. ``batch_size``, ``num_workers``).
        """
        super().__init__()
        self.dataloaderArgs = dataloaderArgs
        self.max_len = max_len
        self.base_dir = Path(base_dir)

    def setup(self, stage: str):
        """Create ``trainset``, ``valset`` and ``testset``.

        ``stage`` is accepted for the Lightning hook signature but is not
        consulted: all three datasets are built on every call.
        """
        pad_mode_by_split = {
            "trainset": "random",
            "valset": "random",
            "testset": "normal",
        }
        for attr_name, mode in pad_mode_by_split.items():
            split = WaveFakeDataset(
                self.base_dir, pad_mode=mode, max_len=self.max_len
            )
            setattr(self, attr_name, split)

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the stored loader kwargs."""
        return DataLoader(dataset, shuffle=shuffle, **self.dataloaderArgs)

    def train_dataloader(self):
        """Shuffled loader over ``trainset``."""
        return self._build_loader(self.trainset, True)

    def val_dataloader(self):
        """Unshuffled loader over ``valset``."""
        return self._build_loader(self.valset, False)

    def test_dataloader(self):
        """Unshuffled loader over ``testset``."""
        return self._build_loader(self.testset, False)

In [13]:
# Smoke test: build the WaveFake datamodule, pull a single training batch,
# and inspect what comes out.
data_module = WaveFake(base_dir=base_dir, max_len=64000, batch_size=16, num_workers=2)
data_module.setup(stage="fit")

# Only the first batch is needed, so take one item from the loader's iterator.
train_loader = data_module.train_dataloader()
x, y = batch = next(iter(train_loader))

# Report the batch shapes and the raw label values.
for caption, value in (
    ("Batch x shape:", x.shape),
    ("Batch y shape:", y.shape),
    ("Batch y:", y),
):
    print(caption, value)

Batch x shape: torch.Size([16, 64000])
Batch y shape: torch.Size([16])
Batch y: tensor([1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0])
