In [1]:
import torchaudio

In [5]:
!ls music_dataset/genres

blues  classical  country  disco  hiphop  jazz	metal  pop  reggae  rock


In [6]:
!ls music_dataset/genres/blues

blues.00000.wav  blues.00025.wav  blues.00050.wav  blues.00075.wav
blues.00001.wav  blues.00026.wav  blues.00051.wav  blues.00076.wav
blues.00002.wav  blues.00027.wav  blues.00052.wav  blues.00077.wav
blues.00003.wav  blues.00028.wav  blues.00053.wav  blues.00078.wav
blues.00004.wav  blues.00029.wav  blues.00054.wav  blues.00079.wav
blues.00005.wav  blues.00030.wav  blues.00055.wav  blues.00080.wav
blues.00006.wav  blues.00031.wav  blues.00056.wav  blues.00081.wav
blues.00007.wav  blues.00032.wav  blues.00057.wav  blues.00082.wav
blues.00008.wav  blues.00033.wav  blues.00058.wav  blues.00083.wav
blues.00009.wav  blues.00034.wav  blues.00059.wav  blues.00084.wav
blues.00010.wav  blues.00035.wav  blues.00060.wav  blues.00085.wav
blues.00011.wav  blues.00036.wav  blues.00061.wav  blues.00086.wav
blues.00012.wav  blues.00037.wav  blues.00062.wav  blues.00087.wav
blues.00013.wav  blues.00038.wav  blues.00063.wav  blues.00088.wav
blues.00014.wav  blues.00039.wav  blues.00064.wa

In [2]:
# https://pytorch.org/audio/main/generated/torchaudio.datasets.GTZAN.html

In [7]:
import os
import random
import torch
import numpy as np
import soundfile as sf
from torch.utils import data

In [8]:
GTZAN_GENRES = ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']

In [9]:
class GTZANDataset(data.Dataset):
    def __init__(self, data_path, split, num_samples, num_chunks):
        self.data_path =  data_path if data_path else ''
        self.split = split
        self.num_samples = num_samples
        self.num_chunks = num_chunks
        self.genres = GTZAN_GENRES
        self._get_song_list()


    def _get_song_list(self):
        list_filename = os.path.join(self.data_path, '%s_filtered.txt' % self.split)
        with open(list_filename) as f:
            lines = f.readlines()
        self.song_list = [line.strip() for line in lines]


    def _adjust_audio_length(self, wav):
        if self.split == 'train':
            random_index = random.randint(0, len(wav) - self.num_samples - 1)
            wav = wav[random_index : random_index + self.num_samples]
        else:
            hop = (len(wav) - self.num_samples) // self.num_chunks
            wav = np.array([wav[i * hop : i * hop + self.num_samples] for i in range(self.num_chunks)])
        return wav

    def __getitem__(self, index):
        line = self.song_list[index]

        # get genre
        genre_name = line.split('/')[0]
        genre_index = self.genres.index(genre_name)

        # get audio
        audio_filename = os.path.join(self.data_path, 'genres', line)
        wav, fs = sf.read(audio_filename)

        # adjust audio length
        wav = self._adjust_audio_length(wav).astype('float32')


        return wav, genre_index

    def __len__(self):
        return len(self.song_list)

def get_dataloader(data_path=None, 
                   split='train', 
                   num_samples=22050 * 29, 
                   num_chunks=1, 
                   batch_size=16, 
                   num_workers=0):
    batch_size = batch_size if (split == 'train') else (batch_size // num_chunks)
    data_loader = data.DataLoader(dataset=GTZANDataset(data_path, 
                                                       split, 
                                                       num_samples, 
                                                       num_chunks),
                                  batch_size=batch_size,
                                  shuffle=True,
                                  drop_last=False,
                                  num_workers=num_workers)
    return data_loader

In [30]:
train_loader = get_dataloader(split='train', data_path="music_dataset")
iter_train_loader = iter(train_loader)
train_wav, train_genre = next(iter_train_loader)

print('training data shape: %s' % str(train_wav.shape))
print(train_genre)

training data shape: torch.Size([16, 639450])
tensor([2, 7, 9, 6, 7, 5, 1, 0, 6, 9, 0, 3, 8, 7, 8, 2])


In [16]:
from torch import nn


class Convo(nn.Module):
    def __init__(self, input_channels, output_channels, shape=3, pooling=2, dropout=0.1):
        super(Convo, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, shape, padding=shape//2)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(pooling)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, wav):
        out = self.conv(wav)
        out = self.bn(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.dropout(out)
        return out

In [36]:
import torchaudio


class MusicClassifier(nn.Module):
    def __init__(self, num_channels=16, 
                       sample_rate=22050, 
                       n_fft=1024, 
                       f_min=0.0, 
                       f_max=11025.0, 
                       num_mels=128, 
                       num_classes=10):
        super(MusicClassifier, self).__init__()

        # mel spectrogram
        self.melspec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, 
                                                            n_fft=n_fft, 
                                                            f_min=f_min, 
                                                            f_max=f_max, 
                                                            n_mels=num_mels)
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
        self.input_bn = nn.BatchNorm2d(1)

        # convolutional layers
        self.layer1 = Convo(1, num_channels, pooling=(2, 3))
        self.layer2 = Convo(num_channels, num_channels, pooling=(3, 4))
        self.layer3 = Convo(num_channels, num_channels * 2, pooling=(2, 5))
        self.layer4 = Convo(num_channels * 2, num_channels * 2, pooling=(3, 3))
        self.layer5 = Convo(num_channels * 2, num_channels * 4, pooling=(3, 4))

        # dense layers
        self.dense1 = nn.Linear(num_channels * 4, num_channels * 4)
        self.dense_bn = nn.BatchNorm1d(num_channels * 4)
        self.dense2 = nn.Linear(num_channels * 4, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, wav):
        # input Preprocessing
        out = self.melspec(wav)
        print(wav.shape, "input shape")
        print(out.shape, "mel shape")
        out = self.amplitude_to_db(out)
        print(out.shape, "after db shape")

        # input batch normalization
        out = out.unsqueeze(1)
        print(out.shape, "a bit unsqueeze")
        out = self.input_bn(out)
        print(out.shape, "shape after batch normalization")

        # convolutional layers
        out = self.layer1(out)
        print(out.shape, "shape after first convo")
        out = self.layer2(out)
        print(out.shape, "shape after second convo")
        out = self.layer3(out)
        print(out.shape, "shape after third convo")
        out = self.layer4(out)
        print(out.shape, "shape after fourth convo")
        out = self.layer5(out)
        print(out.shape, "shape after fifth convo")
        
        # reshape. (batch_size, num_channels, 1, 1) -> (batch_size, num_channels)
        out = out.reshape(len(out), -1)
        print(out.shape, "shape reshape")

        # dense layers
        out = self.dense1(out)
        print(out.shape, "shape after first dense")
        out = self.dense_bn(out)
        print(out.shape, "shape after batch norm")
        out = self.relu(out)
        print(out.shape, "shape after relu")
        out = self.dropout(out)
        print(out.shape, "shape after dropout")
        out = self.dense2(out)
        print(out.shape, "shape after final dense")
        
        print("================batch finished")

        return out


In [37]:
device = "cpu"

music_classifier = MusicClassifier()
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(music_classifier.parameters(), lr=0.001)
valid_losses = []
num_epochs = 30

for epoch in range(num_epochs):
    losses = []

    # Train
    music_classifier.train()
    for (wav, genre_index) in train_loader:
        wav = wav.to(device)
        genre_index = genre_index.to(device)

        # Forward
        out = music_classifier(wav)
        loss = loss_function(out, genre_index)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print('Epoch: [%d/%d], Train loss: %.4f' % (epoch+1, num_epochs, np.mean(losses)))

torch.Size([16, 639450]) input shape
torch.Size([16, 128, 1249]) mel shape
torch.Size([16, 128, 1249]) after db shape
torch.Size([16, 1, 128, 1249]) a bit unsqueeze
torch.Size([16, 1, 128, 1249]) shape after batch normalization
torch.Size([16, 16, 64, 416]) shape after first convo
torch.Size([16, 16, 21, 104]) shape after second convo
torch.Size([16, 32, 10, 20]) shape after third convo
torch.Size([16, 32, 3, 6]) shape after fourth convo
torch.Size([16, 64, 1, 1]) shape after fifth convo
torch.Size([16, 64]) shape reshape
torch.Size([16, 64]) shape after first dense
torch.Size([16, 64]) shape after batch norm
torch.Size([16, 64]) shape after relu
torch.Size([16, 64]) shape after dropout
torch.Size([16, 10]) shape after final dense
torch.Size([16, 639450]) input shape
torch.Size([16, 128, 1249]) mel shape
torch.Size([16, 128, 1249]) after db shape
torch.Size([16, 1, 128, 1249]) a bit unsqueeze
torch.Size([16, 1, 128, 1249]) shape after batch normalization
torch.Size([16, 16, 64, 416]) s

KeyboardInterrupt: 