In [1]:
import os
from autovc.speaker_encoder.utils import wav_to_mel_spectrogram
from autovc.utils.model_loader import load_model
from autovc.utils.core import retrieve_file_paths
import torch
from torch.utils.data import Dataset

class SpeakerEncoderDataLoader(Dataset):
    """Dataset yielding one mel spectrogram per speaker for every index.

    Despite its name this class is a ``torch.utils.data.Dataset``; call
    :meth:`get_dataloader` to obtain the actual ``DataLoader``.

    Item ``i`` is a tuple with one entry per speaker (in the iteration order
    of ``data_dict``). Speakers with fewer utterances than the longest
    speaker wrap around, so every index below ``len(self)`` is valid.
    """

    def __init__(self, data_dict):
        """Load and convert every wav file of every speaker.

        Args:
            data_dict: Mapping from speaker name to the directory (or list of
                directories) that holds that speaker's recordings.

        Raises:
            ValueError: If no files are found for one or more speakers; such
                a dataset could never be indexed.
        """
        super().__init__()

        # Find wav files in dictionary of data, grouped per speaker. A speaker
        # may own several directories, so their files are merged into a single
        # list; indexing the result by speaker position is then always valid.
        wav_files = []
        for speaker_data_dirs in data_dict.values():
            # Accept a single path as shorthand for a one-element list.
            if isinstance(speaker_data_dirs, (str, os.PathLike)):
                speaker_data_dirs = [speaker_data_dirs]
            speaker_wavs = []
            for data_dir_path in speaker_data_dirs:
                speaker_wavs.extend(retrieve_file_paths(data_dir_path))
            wav_files.append(speaker_wavs)

        # Fail early and explicitly instead of hitting a ZeroDivisionError in
        # __getitem__ later on.
        speakers_without_data = [
            speaker for speaker, wavs in zip(data_dict, wav_files) if not wavs
        ]
        if speakers_without_data:
            raise ValueError(f"No wav files found for speaker(s): {speakers_without_data}")

        # Compute mel spectrograms
        self.datasets = [
            [wav_to_mel_spectrogram(wav) for wav in speaker_wavs]
            for speaker_wavs in wav_files
        ]

        print(f"The datasets are of lengths: {[len(d) for d in self.datasets]}")

    def __getitem__(self, i):
        """Return a tuple holding the ``i``-th utterance of every speaker.

        The index is taken modulo each speaker's dataset length, so shorter
        datasets repeat from the start.
        """
        return tuple(d[i % len(d)] for d in self.datasets)

    def __len__(self):
        """Return the length of the longest speaker dataset (0 if there is none)."""
        return max((len(d) for d in self.datasets), default=0)

    def collate_fn(self, batch):
        """Return the batch unchanged, i.e. as a list of per-item speaker tuples."""
        return batch

    def get_dataloader(self, batch_size=2, shuffle=False,  num_workers=0, pin_memory=False, **kwargs):
        """Wrap this dataset in a ``torch.utils.data.DataLoader``.

        Args:
            batch_size: Number of dataset items per batch.
            shuffle: Whether to reshuffle the items every epoch.
            num_workers: Number of worker processes used for loading.
            pin_memory: Forwarded to the ``DataLoader``.
            **kwargs: Extra ``DataLoader`` keyword arguments. ``collate_fn``
                defaults to :meth:`collate_fn` and ``drop_last`` to ``True``
                unless given explicitly.

        Returns:
            The configured ``DataLoader``.
        """
        # Keep the previous hard-coded values as defaults while still letting
        # callers override them through **kwargs.
        kwargs.setdefault("collate_fn", self.collate_fn)
        kwargs.setdefault("drop_last", True)
        return torch.utils.data.DataLoader(
            self,
            batch_size      = batch_size,
            num_workers     = num_workers,
            shuffle         = shuffle,
            pin_memory      = pin_memory,
            **kwargs,
        )

# Map each speaker name to the list of directories holding that speaker's data.
datadir = {'hilde': ['data/conversions'], 'hague': ['data/conversions2'], 'peter':['data/new']}
# Constructing the dataset converts every file to a mel spectrogram up front
# and prints the resulting per-speaker dataset lengths.
Data = SpeakerEncoderDataLoader(datadir)




The datasets are of lengths: [15, 8, 11]


In [2]:
from autovc.speaker_encoder.model import SpeakerEncoder
# Each batch is a list of two dataset items; every item is a tuple with one
# mel spectrogram per speaker (incomplete trailing batches are dropped).
dataloader = Data.get_dataloader(batch_size = 2)
# Speaker encoder built with its default constructor arguments.
# NOTE(review): `load_model` is imported in the first cell but never called in
# the visible cells, so SE presumably runs with untrained weights -- confirm.
SE = SpeakerEncoder()


def batch_forward(batch, encoder=None, device='cpu'):
    """Embed every utterance of a batch and stack the results per speaker.

    Utterances are pushed through the encoder one at a time.

    Args:
        batch: List of dataset items as produced by the dataloader; every item
            is a tuple with one numpy mel spectrogram per speaker.
        encoder: Callable mapping a spectrogram tensor with a leading batch
            dimension of 1 to its embedding. Defaults to the notebook-level
            ``SE`` so existing ``batch_forward(batch)`` calls keep working.
        device: Device the spectrogram tensors are moved to before encoding.

    Returns:
        Tensor whose first dimension indexes the speaker and whose second
        dimension indexes the item within the batch (the run below prints
        ``[3, 2, 256]`` for 3 speakers and a batch size of 2).
    """
    if encoder is None:
        # Fall back to the global encoder created earlier in the notebook.
        encoder = SE

    embeddings = []
    for utterances in batch:
        # Call the module itself rather than `.forward` so registered hooks run
        # (assumes the encoder is an nn.Module or otherwise callable).
        # Stacking adds the speaker dimension in front of the encoder output.
        embed_speaker = torch.stack([
            encoder(torch.from_numpy(mel).unsqueeze(0).to(device))
            for mel in utterances
        ])
        embeddings.append(embed_speaker)
    # Join the items of the batch along the second dimension.
    return torch.cat(embeddings, dim = 1)
# Smoke test: make five passes over the dataloader and, for every batch, print
# the shape of the stacked embeddings followed by the encoder's similarity
# matrix for them.
# NOTE(review): the captured output shows two tensors per batch although only
# one is printed here -- presumably `similarity_matrix` prints internally; confirm.
for i in range(5):
    for batch in dataloader:
        embeds = batch_forward(batch)
        print(embeds.shape)
        print(SE.similarity_matrix(embeds))

torch.Size([3, 2, 256])
tensor([[[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]]], grad_fn=<CopySlices>)
tensor([[[5.0000, 5.0000, 5.0000],
         [5.0000, 5.0000, 5.0000]],

        [[5.0000, 5.0000, 5.0000],
         [5.0000, 5.0000, 5.0000]],

        [[5.0000, 5.0000, 5.0000],
         [5.0000, 5.0000, 5.0000]]], grad_fn=<AddBackward0>)
torch.Size([3, 2, 256])
tensor([[[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000]]], grad_fn=<CopySlices>)
tensor([[[5.0000, 5.0000, 5.0000],
         [5.0000, 5.0000, 5.0000]],

        [[5.0000, 5.0000, 5.0000],
         [5.0000, 5.0000, 5.0000]],

        [[5.0000, 5.0000, 5.0000],
         [5.0000, 5.0000, 5.0000]]], grad_fn=<AddBackward

KeyboardInterrupt: 

In [14]:
# Display the encoder's `similarity_weight` parameter.
# NOTE(review): this cell's execution count (14) does not follow the previous
# cell (2); restart the kernel and run all cells before sharing.
SE.similarity_weight

Parameter containing:
tensor([10.], requires_grad=True)