In [1]:
# Step up one directory before importing; the imports below use the top-level
# `models` and `utils` packages (presumably located in the parent directory — confirm).
# NOTE(review): this is not idempotent — re-running the cell climbs one more level.
cd ..

/home/araxal/coursework


In [2]:
from models.genre_classification.sota_models.CNNSA import CNNSA
import torchaudio
from torchaudio.functional import resample
import torch
from typing import List
import numpy as np
from utils.genre_classification import evaluate, executor
from utils.decade_classification import feature_preparator
from utils.genre_classification.plot_metrics import plot_metrics
from utils.genre_classification import plot_confusion_matrix
from torch import nn

In [3]:
# Reload edited project modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [4]:
# Mini-batch size handed to feature_preparator when the data loaders are built.
BATCH_SIZE = 8
# Number of target classes; passed as `n_class` to the models below.
NUM_CLASSES = 4
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
%%time

def raw_audio_slicer(features, slice_len=22050 * 5):
    """Cut a random fixed-length window out of a waveform along dim 1.

    Args:
        features: 2-D tensor; the window is taken along dim 1
            (presumably laid out as (channels, samples) — confirm with caller).
        slice_len: Window length in samples. The default, 22050 * 5, is
            5 seconds of audio at a 22050 Hz sample rate.

    Returns:
        A new tensor of shape (features.shape[0], slice_len). Inputs with
        fewer than ``slice_len`` samples are zero-padded on the right
        instead of raising.
    """
    time_len = features.shape[1]

    if time_len < slice_len:
        # Clip is too short for a full window: keep all of it and pad with
        # silence so every item still has the same length for batching.
        return torch.nn.functional.pad(features, (0, slice_len - time_len))

    # randint's upper bound is exclusive; the "+ 1" makes the last possible
    # window reachable and keeps the call valid when time_len == slice_len.
    min_idx = np.random.randint(time_len - slice_len + 1)

    # Plain slicing avoids materialising a slice_len-element index tensor;
    # clone() keeps the result independent of the full waveform's storage,
    # matching the copy that index_select used to return.
    return features[:, min_idx:min_idx + slice_len].clone()

def transform(path: List[str]):
    """Load one audio file and turn it into a fixed-length mono excerpt.

    Args:
        path: Sequence of strings; only ``path[0]`` is used, as the location
            of the audio file to load.

    Returns:
        Row 0 of the random window that ``raw_audio_slicer`` cuts from the
        down-mixed waveform after resampling to 22050 Hz.
    """
    audio, native_rate = torchaudio.load(path[0], normalize=True)
    # Average over dim 0 to collapse the file to a single row.
    mono = audio.mean(dim=0, keepdim=True)
    mono_22k = resample(mono, orig_freq=native_rate, new_freq=22050)

    return raw_audio_slicer(mono_22k)[0]

# Build the train / validation / test loaders and the index-to-label lookup
# from the stored feature file. `transform` is handed over so that items are
# converted by it (presumably applied per item when a batch is drawn — confirm
# in utils.decade_classification).
# NOTE(review): the meaning of `normalize` and `external` is defined by
# feature_preparator and is not visible in this notebook.
train_data_loader, val_data_loader, test_data_loader, idx_to_label = feature_preparator(
    'features/decade_classification/external-nn.p',
    BATCH_SIZE,
    normalize=False,
    external=True,
    transform=transform
)

def transform_idx_to_label(x):
    """Return the entry stored under ``x`` in the module-level ``idx_to_label``."""
    label = idx_to_label[x]
    return label

CPU times: user 6.7 ms, sys: 49 µs, total: 6.75 ms
Wall time: 6.69 ms


In [6]:
# Sanity check: draw one training batch and display the shape of its first element.
next(iter(train_data_loader))[0].shape

torch.Size([8, 110250])

In [7]:
# Instantiate CNNSA for 22050 Hz input (f_max is half the sample rate) and
# move it to the configured device.
cnnsa_model = CNNSA(
    sample_rate=22050,
    n_class=NUM_CLASSES,
    f_max=11025,
).to(DEVICE)

# Run the training loop; it hands back the progress records for the
# training and validation sets.
train_progress, val_progress = executor(
    DEVICE,
    cnnsa_model,
    train_dataloader=train_data_loader,
    val_dataloader=val_data_loader,
    epochs=100,
    learning_rate=0.0001,
    weight_decay=0.1,
    evaluate_per_iteration=15,
    early_stop_after=(15, 0.001),
    print_metrics=True,
)

# Plot the loss and accuracy recorded during training.
plot_metrics(
    train_progress,
    val_progress,
    metrics=['loss', 'accuracy'],
)

  0%|          | 0/100 [01:12<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Score the trained CNNSA model on the test loader, also collecting the
# predicted and true labels.
test_loss, test_accuracy, (test_pred, test_true) = evaluate(
    DEVICE,
    cnnsa_model,
    test_data_loader,
    criterion=nn.CrossEntropyLoss(),
    return_pred=True,
)

# Confusion matrix of true vs. predicted labels.
plot_confusion_matrix(
    test_true,
    test_pred,
    idx_to_label,
    transform_idx_to_label,
)

In [None]:
from models.genre_classification.sota_models.HarmonicCNN import HarmonicCNN

# Instantiate HarmonicCNN for 22050 Hz input (f_max is half the sample rate)
# and move it to the configured device.
harmonic_cnn_model = HarmonicCNN(
    sample_rate=22050,
    n_class=NUM_CLASSES,
    f_max=11025,
    n_channels=16,
).to(DEVICE)

# Same training settings as the CNNSA run; note this rebinds
# train_progress / val_progress.
train_progress, val_progress = executor(
    DEVICE,
    harmonic_cnn_model,
    train_dataloader=train_data_loader,
    val_dataloader=val_data_loader,
    epochs=100,
    learning_rate=0.0001,
    weight_decay=0.1,
    evaluate_per_iteration=15,
    early_stop_after=(15, 0.001),
    print_metrics=True,
)

# Plot the loss and accuracy recorded during training.
plot_metrics(
    train_progress,
    val_progress,
    metrics=['loss', 'accuracy'],
)

In [None]:
# Score the trained HarmonicCNN model on the test loader, also collecting
# the predicted and true labels.
test_loss, test_accuracy, (test_pred, test_true) = evaluate(
    DEVICE,
    harmonic_cnn_model,
    test_data_loader,
    criterion=nn.CrossEntropyLoss(),
    return_pred=True,
)

# Confusion matrix of true vs. predicted labels.
plot_confusion_matrix(
    test_true,
    test_pred,
    idx_to_label,
    transform_idx_to_label,
)