In [1]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from config import Wav2Vec2Config
from model import Wav2Vec2ForPreTraining,Wav2Vec2ForSequenceClassification,Wav2Vec2GumbelVectorQuantizer,_compute_mask_indices,Wav2Vec2Encoder,Wav2Vec2FeatureProjection

from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
import random
from tqdm import tqdm
import os

from dataset import AudioDataset

import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'





parent_dir = 'data/mp3_train_files'
file_list = [os.path.join(root, file) 
             for root, _, files in os.walk(parent_dir) 
             for file in files]

random.seed(42)
random.shuffle(file_list)

train_size = int(0.8 * len(file_list))
val_size = int(0.1 * len(file_list))
test_size = len(file_list) - train_size - val_size

train_files = file_list[:train_size]
val_files = file_list[train_size:train_size + val_size]
test_files = file_list[train_size + val_size:]

train_dataset = AudioDataset(train_files)
val_dataset = AudioDataset(val_files)
test_dataset = AudioDataset(test_files)

In [2]:
state_dict = torch.load('weights/pre_train-01.pt')
config = Wav2Vec2Config()
pre_train_model = Wav2Vec2ForPreTraining(config)
pre_train_model.load_state_dict(state_dict)


model = Wav2Vec2ForSequenceClassification(pre_train_model,len(train_dataset.labels()))

In [29]:
id_2_label = train_dataset.id_2_label()

In [58]:
sm = nn.Softmax(dim=-1)

for N in range(50):

    out = model(test_dataset[N][0].unsqueeze(0)).squeeze()

    pred = id_2_label[sm(out).argmax().item()]

    label = test_dataset[N][1]

    print((pred,label))


('Crochet', 'Ishizaka')
('Crochet', 'Rubinstein')
('Crochet', 'Horowitz')
('Crochet', 'Nikolayeva')
('Crochet', 'Tureck')
('Ishizaka', 'Tharaud')
('Ishizaka', 'Gould')
('Crochet', 'Rubinstein')
('Crochet', 'Horowitz')
('Ishizaka', 'Schiff')
('Crochet', 'Gould')
('Crochet', 'Moravec')
('Crochet', 'Nikolayeva')
('Crochet', 'Tureck')
('Ishizaka', 'Moravec')
('Crochet', 'Crochet')
('Crochet', 'Tureck')
('Crochet', 'Crochet')
('Crochet', 'Rubinstein')
('Crochet', 'Rubinstein')
('Crochet', 'Tharaud')
('Crochet', 'Nikolayeva')
('Ishizaka', 'Tureck')
('Crochet', 'Moravec')
('Crochet', 'Richter')
('Crochet', 'Rubinstein')
('Crochet', 'Tureck')
('Crochet', 'Gould')
('Crochet', 'Richter')
('Nikolayeva', 'Gould')
('Crochet', 'Crochet')
('Crochet', 'Tureck')
('Crochet', 'Richter')
('Crochet', 'Nikolayeva')
('Crochet', 'Tureck')
('Ishizaka', 'Richter')
('Crochet', 'Gould')
('Ishizaka', 'Tharaud')
('Crochet', 'Rubinstein')
('Crochet', 'Crochet')
('Crochet', 'Ishizaka')
('Crochet', 'Tureck')
('Crochet

In [50]:
sm(out).argmax()

tensor(3)