In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.prepare_data import create_mfcc_path_and_speaker_id,prepare_xvector_input_data
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch.nn.utils.rnn import pad_sequence

In [33]:
# create_mfcc_path_and_speaker_id("data/mfcc/dev","data/mfcc_and_speaker_id/dev.txt")
# create_mfcc_path_and_speaker_id("data/mfcc/test","data/mfcc_and_speaker_id/test.txt")
# create_mfcc_path_and_speaker_id("data/mfcc/dev-other","data/mfcc_and_speaker_id/dev-other.txt")
# create_mfcc_path_and_speaker_id("data/mfcc/test-other","data/mfcc_and_speaker_id/test-other.txt")
# create_mfcc_path_and_speaker_id("data/mfcc/train-100","data/mfcc_and_speaker_id/train-100.txt")
# create_mfcc_path_and_speaker_id("data/mfcc/train-360","data/mfcc_and_speaker_id/train-360.txt")
# create_mfcc_path_and_speaker_id("data/mfcc/train-500","data/mfcc_and_speaker_id/train-500.txt")

In [34]:
# Read the dev-split manifest. `files` is used below as the list of MFCC array
# paths and `user_ids` as the per-file speaker identifiers.
files, user_ids = prepare_xvector_input_data(
    "data/mfcc_and_speaker_id/dev.txt"
)

In [35]:
class XVectorDataset(Dataset):
    """Dataset of per-utterance MFCC arrays paired with integer speaker labels.

    Speaker identifiers are mapped to contiguous integer class ids; the mapping
    is built from the sorted set of ids so it is deterministic for a given
    collection of speakers.

    Args:
        file_list: Sequence of paths readable by ``np.load`` (one array per item).
        user_ids: Sequence of speaker identifiers aligned index-by-index with
            ``file_list``. Identifiers must be hashable and mutually sortable.

    Raises:
        ValueError: If ``file_list`` and ``user_ids`` differ in length.
    """

    def __init__(self, file_list, user_ids):
        # Fail fast: a length mismatch would otherwise only show up later as an
        # IndexError inside __getitem__ or as classes that have no samples.
        if len(file_list) != len(user_ids):
            raise ValueError(
                "file_list and user_ids must have the same length, got "
                f"{len(file_list)} and {len(user_ids)}"
            )
        self.file_list = file_list
        self.user_ids = user_ids
        # speaker id -> class index, ordered by sorted speaker id
        self.label2id = {label: idx for idx, label in enumerate(sorted(set(user_ids)))}
        # class index for every item, in the same order as file_list
        self.labels = [self.label2id[u] for u in user_ids]

    def __len__(self):
        """Return the number of items (one per entry of ``file_list``)."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(features, label)``.

        Returns:
            A ``float32`` tensor holding the transposed array loaded from disk,
            and the integer class id of its speaker.
        """
        # Stored layout is presumably [n_mfcc, time] (the collated batches are
        # [B, T, 40]) — TODO confirm against the feature-extraction script.
        mfcc = np.load(self.file_list[idx])
        mfcc = mfcc.T  # -> [time, n_mfcc]
        label = self.labels[idx]
        return torch.tensor(mfcc, dtype=torch.float32), label

def collate_fn(batch):
    """Collate variable-length ``(mfcc, label)`` pairs into one padded batch.

    Args:
        batch: Iterable of ``(mfcc, label)`` pairs, where each ``mfcc`` is a
            tensor whose first dimension is time and whose remaining
            dimensions agree across the batch.

    Returns:
        A pair ``(mfccs_padded, labels)``: ``mfccs_padded`` is a
        ``[B, max_T, F]`` tensor in which shorter sequences are right-padded
        with zeros, and ``labels`` is a tuple of the labels in batch order.
    """
    mfccs, labels = zip(*batch)
    # The per-item lengths were previously computed here but never used or
    # returned, so that dead tensor allocation has been dropped.
    mfccs_padded = pad_sequence(mfccs, batch_first=True)  # [B, max_T, F]
    return mfccs_padded, labels


In [36]:
# Wrap the loaded split in a Dataset and serve it in shuffled, padded batches.
dataset = XVectorDataset(file_list=files, user_ids=user_ids)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

In [37]:
# Sanity check: pull a single batch and show what the collate step produced.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape)  # [B, T, F]
    print(y)  # [B]


torch.Size([32, 735, 40])
(28, 11, 25, 18, 29, 28, 8, 29, 1, 24, 34, 13, 1, 19, 19, 19, 9, 10, 4, 4, 24, 37, 25, 32, 17, 14, 25, 9, 12, 3, 7, 28)


In [38]:
class XVector(nn.Module):
    """TDNN-style network that maps a sequence of frames to one fixed-size embedding.

    Five 1-D convolutions over the time axis are followed by statistics pooling
    (per-channel mean and standard deviation over time) and two fully
    connected layers.

    Args:
        input_dim: Number of features per frame (input channels of the first
            convolution).
        embedding_dim: Size of the returned embedding vector.

    Note:
        The convolutions are unpadded and together shorten the sequence by 14
        frames (4 + 4 + 6), so the input needs at least 15 frames.
        Pooling runs over every frame that reaches it; if a batch was
        zero-padded to a common length, the padded frames contribute to the
        statistics as well.
    """

    def __init__(self, input_dim, embedding_dim=256):
        super(XVector, self).__init__()

        # Channels produced by the last frame-level layer; pooling doubles it.
        stats_channels = 1500

        self.tdnn1 = nn.Conv1d(input_dim, 256, kernel_size=5, dilation=1)
        self.tdnn2 = nn.Conv1d(256, 256, kernel_size=3, dilation=2)
        self.tdnn3 = nn.Conv1d(256, 256, kernel_size=3, dilation=3)
        self.tdnn4 = nn.Conv1d(256, 256, kernel_size=1)
        self.tdnn5 = nn.Conv1d(256, stats_channels, kernel_size=1)

        # Statistics pooling (over the whole time axis)
        self.stats_pooling = self._stats_pooling

        # Fully connected layers; the input is mean and std concatenated.
        self.fc1 = nn.Linear(2 * stats_channels, 256)
        self.fc2 = nn.Linear(256, embedding_dim)

    def _stats_pooling(self, x):
        """Concatenate the per-channel mean and std of ``x`` over time.

        Args:
            x: Tensor of shape ``[B, C, T]``.

        Returns:
            Tensor of shape ``[B, 2 * C]``: means first, then standard deviations.
        """
        mean = torch.mean(x, dim=2)
        if x.size(2) > 1:
            std = torch.std(x, dim=2)
        else:
            # torch.std uses the unbiased estimator by default, which is NaN for
            # a single frame; report zero spread so the embedding stays finite.
            std = torch.zeros_like(mean)
        return torch.cat((mean, std), dim=1)

    def forward(self, x):
        """Compute embeddings for a batch of frame sequences.

        Args:
            x: Tensor of shape ``[B, T, F]`` with ``F == input_dim``.

        Returns:
            Tensor of shape ``[B, embedding_dim]``.
        """
        # Conv1d expects channels first: [B, T, F] -> [B, F, T]
        x = x.transpose(1, 2)
        x = F.relu(self.tdnn1(x))
        x = F.relu(self.tdnn2(x))
        x = F.relu(self.tdnn3(x))
        x = F.relu(self.tdnn4(x))
        x = F.relu(self.tdnn5(x))

        x = self.stats_pooling(x)  # [B, 3000]
        x = F.relu(self.fc1(x))
        embeddings = self.fc2(x)
        return embeddings  # x-vector


In [39]:
# Build the network: input_dim=40 matches the last dimension of the batch shape
# printed above (torch.Size([32, 735, 40])); embeddings are 256-dimensional.
model = XVector(input_dim=40, embedding_dim=256)
mfcc = [...]