In [None]:
import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import Resample, MelSpectrogram

import torchvision
import torchvision.transforms as T
from torchvision.transforms import Resize

import numpy as np

import os
import glob

import matplotlib.pyplot as plt
%matplotlib  inline

In [None]:
# Data locations: the training split and the held-out test split.
TRAIN_AUDIO_FOLDER = "../VC-PRG-1_5/"
TEST_AUDIO_FOLDER = "../VC-PRG-6/"

# Every clip is resampled to SAMPLE_RATE Hz and cut / right-padded to
# NUM_SAMPLES samples, i.e. exactly one second of audio per clip.
SAMPLE_RATE = 22050
NUM_SAMPLES = SAMPLE_RATE

# Optimisation hyperparameters.
BATCH_SIZE = 10
NUM_EPOCHS = 50
LEARNING_RATE = 0.0001

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Folder for the training split, passed to VehicleDataset below; the
# evaluation cell later reassigns `path` to TEST_AUDIO_FOLDER.
path = TRAIN_AUDIO_FOLDER

class VehicleDataset(Dataset):
    """Audio clips with line-count labels, served as 3x224x224 mel-spectrogram tensors.

    Each ``*.wav`` file in ``folder_path`` is paired, in sorted-filename order,
    with a ``*.txt`` file from the same folder; the label is the number of
    lines in that text file.

    Args:
        folder_path: Directory holding the ``.wav`` and ``.txt`` files.
        device: Device the waveform and the transforms are moved to.
        target_sample_rate: Sample rate (Hz) every clip is resampled to.
        num_samples: Clip length in samples after cutting / right-padding.

    Raises:
        ValueError: If the folder holds different numbers of ``.wav`` and
            ``.txt`` files, since the sorted-order pairing would then
            misalign audio and labels.
    """

    def __init__(self, folder_path, device, target_sample_rate, num_samples):
        # Use the folder we were given (not a notebook global) and join it
        # portably so a path without a trailing slash still matches files inside.
        pattern_root = glob.escape(folder_path)
        self.audio_files = sorted(glob.glob(os.path.join(pattern_root, "*.wav")))
        self.label_files = sorted(glob.glob(os.path.join(pattern_root, "*.txt")))
        if len(self.audio_files) != len(self.label_files):
            raise ValueError(
                f"{folder_path!r} has {len(self.audio_files)} .wav files but "
                f"{len(self.label_files)} .txt label files; they must pair 1:1."
            )
        self.device = device
        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples
        # Transforms are built on first use and then reused for every item,
        # instead of being re-created for each sample.
        self._mel_transform = None
        self._resamplers = {}

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, item):
        """Return ``(features, label)`` for clip ``item``; features are (3, 224, 224)."""
        label = self.__get_label(item)
        waveform, sample_rate = self.__load_audio(item)
        waveform = waveform.to(self.device)
        waveform = self.__resample_if_necessary(waveform, sample_rate)
        waveform = self.__mix_down_if_necessary(waveform)
        waveform = self.__cut_if_necessary(waveform)
        waveform = self.__right_pad_if_necessary(waveform)

        mel_spectrogram = self.__get_mel_spectrogram(waveform)
        mel_spectrogram = self.__resize_mel_spectrogram(mel_spectrogram)
        mel_spectrogram = self.__convert_to_rgb(mel_spectrogram)

        return mel_spectrogram, label

    def __get_label(self, item):
        # The label is the line count of the paired text file.
        with open(self.label_files[item], 'r') as f:
            return sum(1 for _ in f)

    def __load_audio(self, item):
        waveform, sample_rate = torchaudio.load(self.audio_files[item])
        return waveform, sample_rate

    def __resample_if_necessary(self, waveform, sample_rate):
        if sample_rate != self.target_sample_rate:
            # One cached resampler per source rate.
            resampler = self._resamplers.get(sample_rate)
            if resampler is None:
                resampler = Resample(sample_rate, self.target_sample_rate).to(self.device)
                self._resamplers[sample_rate] = resampler
            waveform = resampler(waveform)
        return waveform

    def __mix_down_if_necessary(self, waveform):
        # Average channels (dim 0) down to a single mono channel.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

    def __cut_if_necessary(self, waveform):
        if waveform.shape[1] > self.num_samples:
            waveform = waveform[:, :self.num_samples]
        return waveform

    def __right_pad_if_necessary(self, waveform):
        if waveform.shape[1] < self.num_samples:
            num_missing_samples = self.num_samples - waveform.shape[1]
            waveform = F.pad(waveform, (0, num_missing_samples))
        return waveform

    def __get_mel_spectrogram(self, waveform):
        if self._mel_transform is None:
            self._mel_transform = MelSpectrogram(
                sample_rate=self.target_sample_rate,
                n_fft=1024,
                win_length=None,
                hop_length=512,
                n_mels=64
            ).to(self.device)
        return self._mel_transform(waveform)

    def __resize_mel_spectrogram(self, mel_spec):
        # Stretch the (channels, n_mels, frames) spectrogram to 224x224.
        mel_spec = Resize((224, 224), interpolation=T.InterpolationMode.BILINEAR)(mel_spec)
        return mel_spec

    def __convert_to_rgb(self, mel_spec):
        # Replicate a single channel three times to get an RGB-shaped tensor;
        # multi-channel input is returned unchanged.
        if mel_spec.size(0) == 1:
            return mel_spec.repeat(3, 1, 1)
        return mel_spec

    def __normalize(self, mel_spec):
        # NOTE(review): not called from __getitem__, so features are NOT
        # ImageNet-normalized; confirm whether this step was intended.
        mel_spec = T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)(mel_spec)
        return mel_spec

vcd = VehicleDataset(path, device, SAMPLE_RATE, NUM_SAMPLES)

# Sanity-check the first (up to) two samples. Each item is unpacked once so
# its audio is loaded and transformed a single time per index.
for sample_idx in range(min(2, len(vcd))):
    sample_features, sample_label = vcd[sample_idx]
    print("Feature shape: ", sample_features.shape)
    print("Label: ", sample_label)

In [None]:
# Pretrained DeiT-Base (16x16 patches, 224x224 input). Its pretrained head is
# an ImageNet classifier, so replace it with a fresh linear layer whose width
# matches the number of target classes; labels must lie in [0, NUM_CLASSES).
NUM_CLASSES = 10
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, NUM_CLASSES)

In [None]:
train_data = vcd
# Shuffle so batches are not grouped by (sorted) file order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loss = list()
train_acc = list()

for epoch in range(NUM_EPOCHS):
    # Restore training mode: the evaluation cell calls model.eval(), which
    # would otherwise persist if this cell is re-run.
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (features, labels) in enumerate(train_loader):
        features, labels = features.to(device), labels.to(device)

        # CrossEntropyLoss takes integer class indices directly; this is
        # equivalent to one-hot targets but does not require the target width
        # to equal the model's output width.
        logits = model(features)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total += labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        if i % 10 == 0:
            # Mean loss over the batches seen so far in this epoch.
            print(f"Epoch [{epoch + 1} / {NUM_EPOCHS}] loss: {running_loss / (i + 1):.3f}")

    train_loss.append(running_loss / len(train_loader))
    train_acc.append(correct / total * 100.0)

torch.save(model.state_dict(), "vcd_DeiT_model.pth")
print("Finished Training")

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
epochs = range(1, len(train_acc) + 1)

# Axes.set_title / set_xlabel / set_ylabel are methods; assigning strings to
# them replaces the methods and leaves the plots unlabeled, so call Axes.set.
ax1.plot(epochs, train_acc, '-o')
ax1.set(title='Train Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')

ax2.plot(epochs, train_loss, '-o')
ax2.set(title='Train Loss', xlabel='Epoch', ylabel='Loss')
plt.show()

In [None]:
# Point `path` at the held-out split and build the test loader from it.
path = TEST_AUDIO_FOLDER
test_data = VehicleDataset(path, device, SAMPLE_RATE, NUM_SAMPLES)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

correct = 0
total = 0

model.eval()
with torch.no_grad():
    for features, labels in test_loader:
        features, labels = features.to(device), labels.to(device)

        predicted = model(features).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# True division keeps fractional percent (floor division truncated it), and
# an empty test set is reported instead of raising ZeroDivisionError.
if total == 0:
    print("Accuracy: n/a (test set is empty)")
else:
    print(f'Accuracy: {100 * correct / total:.2f} %')