In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import h5py
import matplotlib.pyplot as plt

# Directory that holds the pre-computed HDF5 dataset file(s).
musdb18_path = '/home/enricguso/datasets/musdb18/'

# Number of consecutive rows of 'track_mag' that make up one excerpt.
tcontext = 120
# Handed to the network constructor; presumably the STFT size used when the
# dataset was generated -- TODO confirm against the preprocessing script.
fft_size = 1024
# Fraction of the data assigned to the training set; the rest is validation.
split = 0.8
# Class names; presumably in label-index order -- TODO confirm.
classes = ('vocals', 'drums', 'bass', 'other')
class musDB_class_dataset(Dataset):
    """Fixed-length excerpts read from an HDF5 file, each paired with a label.

    The file at ``root_dir`` must contain the datasets ``'f_indexes'``,
    ``'track_mag'`` and ``'label'``.  Excerpt ``k`` is rows
    ``[k * tcontext, (k + 1) * tcontext)`` of ``'track_mag'`` and its label
    is ``'label'[k]``.  Training excerpts come first; validation excerpts
    start at excerpt number ``int(train_end_f_ind / tcontext)``.

    NOTE(review): ``'f_indexes'`` appears to hold increasing row offsets
    into ``'track_mag'`` (one per track?) -- confirm against the script that
    generates the file.

    Args:
        root_dir: Path of the HDF5 file (despite the name, a file path).
        tcontext: Number of ``'track_mag'`` rows per excerpt.
        split: Fraction (of the last ``'f_indexes'`` entry) used to locate
            the train/validation boundary.
        mode: ``'train'`` or ``'val'``.
    """

    def __init__(self, root_dir, tcontext, split, mode):
        self.root_dir = root_dir
        self.tcontext = tcontext
        self.split = split
        self.mode = mode
        with h5py.File(self.root_dir, 'r') as db:
            # Read the index table once; every quantity below derives from it.
            f_indexes = db['f_indexes'][...]

        # The train set ends at the first index entry that reaches `split`
        # of the final entry.
        threshold = int(f_indexes[-1]) * self.split
        boundary_pos = np.where(f_indexes >= threshold)[0][0]
        self.train_end_f_ind = f_indexes[boundary_pos]
        # Number of rows left over for the validation set.
        self.val_stftbins = f_indexes[-1] - self.train_end_f_ind
        self.f_indexes = f_indexes

    def __len__(self):
        """Return the number of whole excerpts available in this mode.

        Any mode other than ``'train'`` is sized as the validation set.
        """
        if self.mode == 'train':
            rows = self.train_end_f_ind
        else:
            rows = self.val_stftbins
        return int(rows / self.tcontext)

    def __getitem__(self, idx):
        """Return excerpt ``idx`` as ``{'input': tensor, 'label': label}``.

        ``'input'`` is the excerpt with a new leading axis of size 1.

        Raises:
            ValueError: If ``self.mode`` is neither ``'train'`` nor ``'val'``.
            IndexError: If ``idx`` is outside ``[0, len(self))``.  Without
                this check a training index past the end would silently
                read validation data.
        """
        if self.mode == 'train':
            reader_head = idx
        elif self.mode == 'val':
            # Validation excerpts are stored right after the training ones.
            reader_head = idx + int(self.train_end_f_ind / self.tcontext)
        else:
            raise ValueError(
                "mode must be 'train' or 'val', got %r" % (self.mode,))

        if not 0 <= idx < len(self):
            raise IndexError(
                'index %s is out of range for a dataset of length %d'
                % (idx, len(self)))

        first_row = int(reader_head * self.tcontext)
        last_row = int(reader_head * self.tcontext + self.tcontext)
        # The file is opened per item, so no handle is held between calls.
        with h5py.File(self.root_dir, 'r') as db:
            track_mag = db['track_mag'][first_row:last_row]
            label = db['label'][reader_head]

        # Add the leading axis of size 1 that the network input expects.
        track_mag = torch.from_numpy(np.expand_dims(track_mag, 0))
        return {'input': track_mag, 'label': label}

# Both splits are served from the same HDF5 file; only `mode` differs.
dataset_file = musdb18_path + 'musdb_classify_tcontext_' + str(tcontext) + '.hdf5'
train_dataset = musDB_class_dataset(dataset_file, tcontext, split, mode='train')
val_dataset = musDB_class_dataset(dataset_file, tcontext, split, mode='val')

# Settings shared by both loaders.  drop_last=True discards a trailing
# partial batch, so every batch has exactly 30 items.
loader_options = dict(batch_size=30, num_workers=4, pin_memory=True, drop_last=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

class Net(nn.Module):
    """Six-layer convolutional classifier producing four class scores.

    The input is a batch of single-channel 2-D maps, shape ``(N, 1, H, W)``.
    ``H`` and ``W`` must be such that the six unpadded 5x5 convolutions
    reduce the map to 4 x 5, because the first fully connected layer is
    hard-wired to ``512 * 4 * 5`` features (a 120 x 513 input satisfies
    this).

    Args:
        tcontext: Stored on the instance; not used to size any layer.
        fft_size: Stored on the instance; not used to size any layer.
    """

    def __init__(self, tcontext, fft_size):
        # Initialise nn.Module before assigning any attribute.
        super(Net, self).__init__()
        self.tcontext = tcontext
        self.fft_size = fft_size

        # The first two convolutions downsample only the last axis; the
        # remaining four downsample both spatial axes.
        self.conv1 = nn.Conv2d(1, 16, (5, 5), stride=(1, 2))
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, (5, 5), stride=(1, 2))
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, (5, 5), stride=2)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, (5, 5), stride=2)
        self.bn4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, (5, 5), stride=2)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, (5, 5), stride=2)
        self.bn6 = nn.BatchNorm2d(512)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2)
        # Channel-wise dropout after each of the first three conv blocks.
        self.drop1 = nn.Dropout2d(p=0.5)
        self.drop2 = nn.Dropout2d(p=0.5)
        self.drop3 = nn.Dropout2d(p=0.5)

        self.fc1 = nn.Linear(512 * 4 * 5, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 4)

    def forward(self, x):
        """Return unnormalised class scores for a batch ``x``.

        Returns:
            Tensor of shape ``(N, 4)``.  The result is passed through
            ``squeeze()``, so a batch of one comes back with shape ``(4,)``.

        Raises:
            RuntimeError: If the input size does not reduce to 512 x 4 x 5
                features per sample.
        """
        # Blocks 1-3: conv -> batch norm -> LeakyReLU -> dropout.  Block 3
        # used to apply dropout before the activation; the order is now
        # uniform.  The result is unchanged because LeakyReLU commutes with
        # dropout's zero-or-rescale operation.
        x = self.drop1(self.lrelu(self.bn1(self.conv1(x))))
        x = self.drop2(self.lrelu(self.bn2(self.conv2(x))))
        x = self.drop3(self.lrelu(self.bn3(self.conv3(x))))
        # Blocks 4-6: conv -> batch norm -> LeakyReLU, no dropout.
        x = self.lrelu(self.bn4(self.conv4(x)))
        x = self.lrelu(self.bn5(self.conv5(x)))
        x = self.lrelu(self.bn6(self.conv6(x)))

        # Flatten per sample.  Keeping the batch dimension explicit makes a
        # wrongly sized input fail in fc1 instead of being silently
        # redistributed across extra "batch" rows.
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)

        return x.squeeze()
    
    
# Build the classifier and move it to the GPU when one is available.
net = Net(tcontext, fft_size)
if torch.cuda.is_available():
    net.cuda()
    print('Model sent to GPU')

# Training configuration.
num_epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# One slot per epoch for each tracked metric.
train_loss, val_loss, val_accuracy = (torch.zeros(num_epochs) for _ in range(3))

# Batches go to the GPU whenever one is available, as in the model setup.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(num_epochs):
    print('Epoch '+str(epoch+1)+' starts...')

    # ---- training pass ----
    running_loss = 0.0
    for batch in train_loader:
        data = batch['input'].to(device)
        label = batch['label'].squeeze().to(device)

        optimizer.zero_grad()
        outputs = net(data)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # `criterion` already averages over the samples of a batch, so the epoch
    # loss is the mean of the per-batch values.  Dividing by the batch size
    # as well (as was done before) under-reported the loss by that factor.
    train_loss[epoch] = running_loss / len(train_loader)

    # ---- validation pass ----
    correct = 0
    total = 0
    running_loss = 0.0
    # Switch to eval mode once per epoch rather than once per batch.
    net.eval()
    with torch.no_grad():
        for batch in val_loader:
            data = batch['input'].to(device)
            label = batch['label'].squeeze().to(device)

            outputs = net(data)
            running_loss += criterion(outputs, label).item()
            # The predicted class is the index of the largest score.
            _, predicted = torch.max(outputs, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

    val_loss[epoch] = running_loss / len(val_loader)
    val_accuracy[epoch] = 100*correct/total
    # Back to training mode for the next epoch.
    net.train()
    print('Epoch ' + str(epoch+1) + '. Loss: '+'%.4f' %train_loss[epoch].item()+' (train) | '+'%.4f' %val_loss[epoch].item()+' (val). ACC: '+'%.4f' %val_accuracy[epoch].item()+'%')

Model sent to GPU
Epoch 1 starts...
Epoch 1. Loss: 0.0142 (train) | 0.0188 (val). ACC: 79.8936%
Epoch 2 starts...
