In [81]:
import numpy as np
import pandas as pd
import soundfile as sf

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

from collections import defaultdict
from typing import Union

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import os
from glob import glob
from tqdm.auto import tqdm

from IPython.display import Audio, display

SR = 16_000


def adisplay(audio, rate=SR):
    display(Audio(audio, rate=rate))

In [2]:
meta = pd.read_csv('../datasets/voxceleb/vox2_meta.csv')
meta.columns = [col.strip() for col in meta.columns]
for col in meta:
    meta[col] = meta[col].str.strip()

id2gender = dict(zip(meta['VoxCeleb2 ID'], meta['Gender']))

In [65]:
meta = pd.read_csv('../datasets/voxceleb/vox1_meta.csv', sep='\t')
meta.columns = [col.strip() for col in meta.columns]
for col in meta:
    meta[col] = meta[col].str.strip()

id2gender = dict(zip(meta['VoxCeleb1 ID'], meta['Gender']))

In [69]:
def get_speaker_files(root, speaker):
    speaker_root = os.path.join(root, speaker)
    files = glob(f'{speaker_root}/**/*.wav', recursive=True)
    files = [file.removeprefix(f'{speaker_root}/') for file in files]
    return files

def get_file_duration_info(root, speaker, speaker_file):
    filepath = os.path.join(root, speaker, speaker_file)
    a, sr = sf.read(filepath)
    n_frames = len(a)
    duration = n_frames / sr
    return filepath, n_frames, duration

In [70]:
root = '/Users/alexnikko/prog/bss_coursework/datasets/voxceleb/voxceleb1/test/wav/'
speakers = os.listdir(root)
speakers.remove('.DS_Store')  # because it is macOS

speakers_info = {}
for speaker in tqdm(speakers):
    speaker_info = {}
    speaker_files = get_speaker_files(root, speaker)
    speaker_duration = 0
    for file in speaker_files:
        filepath, n_frames, duration = get_file_duration_info(root, speaker, file)
        speaker_info[file] = {
            'n_frames': n_frames,
            'duration': duration
        }
        speaker_duration += duration
    speaker_info['total_duration'] = speaker_duration
    speakers_info[speaker] = speaker_info

  0%|          | 0/40 [00:00<?, ?it/s]

In [71]:
total_duration = 0
for sp in speakers_info:
    total_duration += speakers_info[sp]['total_duration']
print(f'{total_duration / 60} hours')

672.2544104166666 hours


This is enough for testing

for dataset I need following scheme:

- male speakers: list of strings
- female speakers: list of strings
- speakers_info: mapping from speaker to his files
- I need to filter this files by minimum duration (3 seconds e.g.)
- I need to know how many frames in each file

In [44]:
len(durations)

4874

In [79]:
len(list(filter(lambda x: x >= 5, durations)))

3558

In [72]:
# get root as input
# return find speakers, split to male and female speakers, for each speaker find files and filter them

def prepare_meta(root: str, id2gender: dict[str, str], minimum_duration: float):
    speakers = os.listdir(root)
    # because it is macOS
    if '.DS_Store' in speakers:
        speakers.remove('.DS_Store')
    
    
    sp2files = defaultdict(list)
    for speaker in speakers:
        speaker_files = get_speaker_files(root, speaker)
        for file in speaker_files:
            _, n_frames, duration = get_file_duration_info(root, speaker, file)
            if duration < minimum_duration:
                continue
            sp2files[speaker].append({
                'rel_path': file,
                'n_frames': n_frames
            })
    male_speakers = [speaker for speaker in speakers if id2gender[speaker] == 'm']
    female_speakers = [speaker for speaker in speakers if id2gender[speaker] == 'f']
    return male_speakers, female_speakers, sp2files

In [73]:
male_speakers, female_speakers, sp2files = prepare_meta(root, id2gender, 10)

In [74]:
len(male_speakers), len(female_speakers)

(25, 15)

In [112]:
class VoxcelebDataset(Dataset):
    def __init__(self,
                 root: str,
                 male_speakers: list[str],
                 female_speakers: list[str],
                 sp2files: dict[str, list[dict[str, Union[str, int]]]],
                 frames: int,
                 steps: int,
                 prob_same: float = 0.5):
        super().__init__()
        
        self.root = root
        self.male_speakers = male_speakers
        self.female_speakers = female_speakers
        self.sp2files = sp2files
        self.steps = steps
        self.prob_same = prob_same
        assert 0 <= prob_same <= 1, f'prob_same must be in [0, 1], got {prob_same}'
        self.list_of_choice_for_same = ['female', 'male']
        self.prob_of_choice_for_same = [
            len(female_speakers) / (len(female_speakers) + len(male_speakers)),
            len(male_speakers) / (len(female_speakers) + len(male_speakers)),
        ]
    
    def __len__(self):
        return self.steps
    
    def __getitem__(self, idx: int):
        if np.random.rand() < self.prob_same:
            speakers = (self.female_speakers
                        if np.random.choice(self.list_of_choice_for_same, p=self.prob_of_choice_for_same) == 'female'
                        else self.male_speakers)
            sp1, sp2 = np.random.choice(speakers, size=2, replace=False)
        else:
            sp1 = np.random.choise(self.male_speakers)
            sp2 = np.random.choice(self.female_speakers)
        
        file1 = np.random.choice(self.sp2files[sp1])
        file2 = np.random.choice(self.sp2files[sp2])
        
        rel_path1, rel_path2 = file1['rel_path'], file2['rel_path']
        n_frames1, n_frames2 = file1['n_frames'], file2['n_frames']
        
        start1 = np.random.randint(n_frames1 - self.frames + 1)
        start2 = np.random.randint(n_frames2 - self.frames + 1)
        
        path1, path2 = os.path.join(self.root, sp1, rel_path1), os.path.join(self.root, sp2, rel_path2)
        
        a1 = sf.read(path1, dtype='float32')[0]
        a2 = sf.read(path2, dtype='float32')[0]
        
        mix = a1 + a2
        
        return mix, np.hstack([a1, a2])

In [113]:
dataset = VoxcelebDataset(root, male_speakers, female_speakers, sp2files, frames=3 * SR, steps=100)

In [114]:
dataset[0]

AttributeError: 