In [8]:
import functools
import os
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision
from torch.utils.data import DataLoader

from models.resnet import ResNet18, ResNet, BasicBlock
from utils.dataset import LIBRITTS
warnings.filterwarnings("ignore", category=DeprecationWarning) 

Последний этап в рамках инженерии признаков. Преобразуем весь датасет в NumPy-массив, состоящий из векторов длины 8192, используя обученную сеть ResNet18.

In [11]:
BATCH_SIZE = 32      # batch size used by every DataLoader in this notebook
SAMPLE_RATE = 24000  # passed to MelSpectrogram as sample_rate
N_MELS = 128         # passed to MelSpectrogram as n_mels
N_FFT = 1024         # passed to MelSpectrogram as n_fft

In [9]:
# Load the trained model. The file holds the whole pickled nn.Module (it is used
# directly as a module below), not just a state_dict.
# map_location='cpu' lets a checkpoint saved from a GPU session open on a
# CPU-only machine; the model is moved to the selected device afterwards.
# SECURITY: torch.load unpickles arbitrary objects - only load files you trust.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-module pickles; pass weights_only=False there if this file is trusted.
model = torch.load('./parameters/model.pkl', map_location='cpu')

In [10]:
# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Put the model's parameters on the device the batches will be sent to.
model = model.to(device)

In [12]:
# Preprocessing pipeline handed to the datasets as `transforms`:
# mel spectrogram first, then a resize to a fixed 128x128 grid.
sample_transforms = torchvision.transforms.Compose(
    [
        torchaudio.transforms.MelSpectrogram(
            n_mels=N_MELS,
            n_fft=N_FFT,
            sample_rate=SAMPLE_RATE,
            normalized=True,
        ),
        torchvision.transforms.Resize((128, 128)),
    ]
)

In [13]:
# Gender labels saved earlier; passed to the LIBRITTS datasets as gender_labels.
# NOTE(review): the structure of this object is not visible here - presumably a
# speaker-to-gender mapping; confirm against utils/dataset.py.
gender_labels = torch.load('./data/labels/gender_labels')

In [14]:
# "test-clean" subset, wrapped by the project's LIBRITTS dataset class with the
# gender labels as targets and the spectrogram pipeline as transforms.
test_dataset = LIBRITTS(
    transforms=sample_transforms,
    gender_labels=gender_labels,
    url="test-clean",
    root='./data/',
    download=True,
)

In [15]:
# "train-clean-100" subset, built with the same labels and transforms as the
# test dataset; it is split into train/validation parts afterwards.
train_dataset = LIBRITTS(
    transforms=sample_transforms,
    gender_labels=gender_labels,
    url="train-clean-100",
    root='./data/',
    download=True,
)

In [16]:
# Hold out 20% of the training data for validation.
# The split is seeded so the exported train/val feature files are reproducible.
# NOTE(review): for the features to match the split the model was trained on,
# this seed must equal the one used in the training notebook - confirm.
SPLIT_SEED = 42

# This cell rebinds `train_dataset` to a Subset. If the cell is re-run, recover
# the full dataset first so we do not take 80% of an already reduced subset.
full_train_dataset = (
    train_dataset.dataset
    if isinstance(train_dataset, torch.utils.data.Subset)
    else train_dataset
)

len_train_dataset = len(full_train_dataset)
part_train = int(len_train_dataset * 0.8)

train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [part_train, len_train_dataset - part_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [17]:
# Options shared by all three loaders. shuffle=False keeps each dataset's order,
# so the exported rows follow the order of the underlying dataset.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

train_data_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
val_data_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)
test_data_loader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

In [18]:
# Sanity check: pull a single batch to make sure the loaders work.
first_batch = next(iter(train_data_loader), None)
if first_batch is not None:
    data, label = first_batch
    print(f' Data shape: {data.shape}, label shape: {label.shape}')

 Data shape: torch.Size([32, 1, 128, 128]), label shape: torch.Size([32])


In [57]:
def get_feature_dataset(dataloader: DataLoader, name: str, model: nn.Module, device: str ='cuda'):
    """Run every batch through ``model`` and save outputs and labels as .npy files.

    The model is switched to eval mode and called as ``model(x, True)`` for each
    batch. Outputs are stacked along the first axis and written to
    ``./data/features/{name}_x.npy``; labels are flattened into one 1-D array and
    written to ``./data/features/{name}_y.npy``. The output directory is created
    if it does not exist.

    Args:
        dataloader (DataLoader): Iterable yielding ``(x, y)`` tensor pairs.
        name (str): Prefix of the output files, e.g. ``'train'``.
        model (nn.Module): Network used as the feature extractor. The second
            positional argument ``True`` presumably asks it to return the
            feature vector instead of class scores - confirm in models/resnet.py.
            The model is expected to already live on ``device``.
        device (str, optional): Device the input batches are moved to.
            Defaults to 'cuda'.

    Raises:
        ValueError: If ``dataloader`` yields no batches (nothing is written).
    """
    model.eval()
    feature_batches = []
    label_batches = []

    # Inference only: without no_grad every forward pass would build an
    # autograd graph, wasting memory and time.
    with torch.no_grad():
        for x, y in dataloader:
            out = model(x.to(device), True)
            feature_batches.append(out.detach().cpu().numpy())
            label_batches.append(y.detach().cpu().numpy())

    if not feature_batches:
        raise ValueError(f"dataloader for '{name}' produced no batches; nothing to save")

    # Concatenate once at the end; np.append inside the loop re-copies the
    # whole accumulated array on every batch.
    features = np.concatenate(feature_batches, axis=0)
    # axis=None flattens, matching what np.append without an axis did.
    labels = np.concatenate(label_batches, axis=None)

    # Make sure the target directory exists so a long extraction run does not
    # fail at the very last step.
    os.makedirs('./data/features', exist_ok=True)
    np.save(f'./data/features/{name}_x.npy', features)
    np.save(f'./data/features/{name}_y.npy', labels)

In [58]:
# Export features for each split, in the order train -> val -> test.
for split_name, split_loader in (
    ('train', train_data_loader),
    ('val', val_data_loader),
    ('test', test_data_loader),
):
    get_feature_dataset(split_loader, split_name, model, device)

В результате получили новое представление исходного датасета в виде векторов длины 8192.