In [9]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv1d
import torchaudio
import utils    
from torch.utils.data import Dataset, DataLoader
import glob
from torchvision import transforms



# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter()


import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
CUDA = torch.cuda.is_available()

# Define device for torch
use_cuda = True
print("CUDA is available:", torch.cuda.is_available())
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

CUDA is available: True


In [10]:
class DemixingAudioDataset(Dataset):
    """Demixing Audio dataset"""

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the audio.
            transform (callable, optional): Optional transform to be applied on a sample.

        Returns: 
            
            sample (dict): No transform 
                    key (str): 'context',
                    value (torch.Tensor): mixture_waveform
                    key (str): 'target',
                    values (tuple): (bass_waveform, drums_waveform ... )

            sample (dict): With transform
                    key (str): 'context',
                    value (list of torch.Tensor): mixture_stft_batch
                    key (str): 'target',
                    values (tuple of list of torch.Tensor): (bass_stft_batch, drums_stft_batch ... )

            
        """
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        size = len(glob.glob("{}/*".format(self.root_dir)))
        return size

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        mixture_waveform, _ = torchaudio.load("{}/{}".format(glob.glob("{}/*".format(self.root_dir))[idx],'mixture.wav'))
        bass_waveform, _ = torchaudio.load("{}/{}".format(glob.glob("{}/*".format(self.root_dir))[idx],'bass.wav'))
        drums_waveform, _ = torchaudio.load("{}/{}".format(glob.glob("{}/*".format(self.root_dir))[idx],'drums.wav'))
        others_waveform, _ = torchaudio.load("{}/{}".format(glob.glob("{}/*".format(self.root_dir))[idx],'other.wav'))
        vocals_waveform, _ = torchaudio.load("{}/{}".format(glob.glob("{}/*".format(self.root_dir))[idx],'vocals.wav'))

        sample = {'context': mixture_waveform, 
                    'target': (bass_waveform,
                                drums_waveform,
                                others_waveform,
                                vocals_waveform)}

        if self.transform:
            sample = self.transform(sample)

        return sample

class SplitAudio(object):
    """Splits audio into chunks

    Args:
        output_size (int): Desired output size of audio frames 

    returns: 

    """

    def __init__(self, output_size):
        assert isinstance(output_size, int)
        self.output_size = output_size

    def __call__(self, sample):

        context, target = sample['context'], sample['target']
        mixture_waveform = context
        bass_waveform, drums_waveform, others_waveform, vocals_waveform = target
        
        mixture_waveform_batch = torch.split(mixture_waveform,self.output_size,dim=1)
        bass_waveform_batch = torch.split(bass_waveform,self.output_size,dim=1)
        drums_waveform_batch = torch.split(drums_waveform,self.output_size,dim=1)
        others_waveform_batch = torch.split(others_waveform,self.output_size,dim=1)
        vocals_waveform_batch = torch.split(vocals_waveform,self.output_size,dim=1)


        return {'context': mixture_waveform_batch, 
                    'target': (bass_waveform_batch,
                                drums_waveform_batch,
                                others_waveform_batch, 
                                vocals_waveform_batch)}

class STFTWaveform(object):
    """Converts waveform to stft 

    """
    # FIXME: Fix reshaping if needed

    def __call__(self, sample):

        context, target = sample['context'], sample['target']
        mixture_waveform = context
        bass_waveform, drums_waveform, others_waveform, vocals_waveform = target
        
        mixture_stft_batch = []
        for waveform in mixture_waveform:
            stft = torch.stft(waveform,n_fft=4096,hop_length=1024,win_length=4096,return_complex=False)
            stft = torch.reshape(stft,(2, 2, 2049, -1)).to(device)
            mixture_stft_batch.append(stft)

        bass_stft_batch = []
        for waveform in bass_waveform:
            stft = torch.stft(waveform,n_fft=4096,hop_length=1024,win_length=4096,return_complex=False)
            stft = torch.reshape(stft,(2, 2, 2049, -1)).to(device)
            bass_stft_batch.append(stft)

        drums_stft_batch = []
        for waveform in drums_waveform:
            stft = torch.stft(waveform,n_fft=4096,hop_length=1024,win_length=4096,return_complex=False)
            stft = torch.reshape(stft,(2, 2, 2049, -1)).to(device)
            drums_stft_batch.append(stft)
        
        others_stft_batch = []
        for waveform in others_waveform:
            stft = torch.stft(waveform,n_fft=4096,hop_length=1024,win_length=4096,return_complex=False)
            stft = torch.reshape(stft,(2, 2, 2049, -1)).to(device)
            others_stft_batch.append(stft)

        vocals_stft_batch = []
        for waveform in vocals_waveform:
            stft = torch.stft(waveform,n_fft=4096,hop_length=1024,win_length=4096,return_complex=False)
            stft = torch.reshape(stft,(2, 2, 2049, -1)).to(device)
            vocals_stft_batch.append(stft)


        return {'context': mixture_stft_batch, 
                    'target': (bass_stft_batch,
                                drums_stft_batch,
                                others_stft_batch, 
                                vocals_stft_batch)}


In [11]:
composed = transforms.Compose([SplitAudio(44100),
                               STFTWaveform()])

train_data = DemixingAudioDataset("raw_data/train",transform=composed)
test_data = DemixingAudioDataset("raw_data/test",transform=composed)

In [17]:
class CNN(nn.Module):

    # FIXME: Fix the model


    def __init__(self, numChannels, classes, fixedFirstLayer = 48):
        super(CNN, self).__init__() 

        # Encoder 
        self.conv1 = nn.Conv2d(in_channels=numChannels, out_channels=fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.relu1 = nn.GELU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv2 = nn.Conv1d(in_channels=fixedFirstLayer, out_channels=2*fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.relu2 = nn.GELU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv3 = nn.Conv1d(in_channels=2*fixedFirstLayer, out_channels=4*fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.relu3 = nn.GELU()
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv4 = nn.Conv1d(in_channels=4*fixedFirstLayer, out_channels=8*fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.relu4 = nn.GELU()
        self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        #Bottleneck layer
        self.conv5 = nn.Conv1d(in_channels=8*fixedFirstLayer, out_channels=16*fixedFirstLayer,kernel_size=(8, 8),stride=(8,8))
        self.relu5 = nn.GELU()

        # Decoder 
        self.de_conv1 = nn.ConvTranspose2d(in_channels=16*fixedFirstLayer, out_channels=8*fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.de_relu1 = nn.GELU()
        self.de_maxpool1 = nn.UpsamplingBilinear2d(kernel_size=(2, 2))

        self.de_conv2 = nn.Conv1d(in_channels=8*fixedFirstLayer, out_channels=4*fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.de_relu2 = nn.GELU()
        self.de_maxpool2 = nn.MaxUnpool2d(kernel_size=(2, 2))

        self.de_conv3 = nn.Conv1d(in_channels=4*fixedFirstLayer, out_channels=2*fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.de_relu3 = nn.GELU()
        self.de_maxpool3 = nn.MaxUnpool2d(kernel_size=(2, 2))

        self.de_conv4 = nn.Conv1d(in_channels=2*fixedFirstLayer, out_channels=fixedFirstLayer,kernel_size=(8, 8),stride=(4,4))
        self.de_relu4 = nn.GELU()
        self.de_maxpool4 = nn.MaxUnpool2d(kernel_size=(2, 2))

        self.de_conv4 = nn.Conv1d(in_channels=fixedFirstLayer, out_channels=classes,kernel_size=(8, 8),stride=(4,4))
        self.de_relu4 = nn.GELU()
        self.de_maxpool4 = nn.MaxUnpool2d(kernel_size=(2, 2))

    def forward(self, inputs):
    
        embeddings = self.embeddings(inputs)

        out = self.linear(embeddings)


        return out

In [18]:
class simple_CNN(nn.Module):

    def __init__(self, numChannels, classes, fixedFirstLayer = 48,):
        super(simple_CNN, self).__init__() 

        # Encoder 
        self.conv1 = nn.ConvTranspose2d(in_channels=numChannels, out_channels=classes, kernel_size=(8, 8),stride=(4,4))
        self.gelu1 = nn.GELU()
        # self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    def forward(self, inputs):
    
        embeddings = self.conv1(inputs)
        out = self.gelu1(embeddings)
        return out

In [14]:
# Create model and pass to CUDA if available.
model = simple_CNN(numChannels = 2, classes = 4)
model = model.to(device)
model.train()

simple_CNN(
  (conv1): ConvTranspose2d(2, 4, kernel_size=(8, 8), stride=(4, 4))
  (gelu1): GELU()
)

In [15]:
# Define training parameters
learning_rate = 0.0001 # Number recommended by wave-u-net paper
epochs = 100
torch.manual_seed(28)
loss_function = nn.MSELoss() # Loss recommended by wave-u-net paper
optimizer = optim.Adam(model.parameters(), lr = learning_rate) # optimiser recommended by wave-u-net paper
 

In [16]:
def train(data, model, epochs, loss_func, optimizer):
    """
    This is a trainer function to train our CNN model.
    """
    # TODO: Use early stopper and tensorboard 

    losses = []
    accuracies = []

    for epoch in range(epochs):
        for sample in data:
            for keys, values in sample.items():
                batch_stft = values[0]
                target_stft = values[1]
                total_loss = 0
                # Forward pass
                model.zero_grad()

                # FIXME: Fix shape of output
                output = model(batch_stft)
                
                print(output.shape)
                output_lst = torch.split(output,1,dim=1)
                # FIXME: Fix loss calculator

                # How do we know which index is which instrument? 
                for i in range(len(target_stft)):
                    loss = loss_func(output_lst[i], target_stft[i])

                # Backward pass and optim
                total_loss += loss.data.item()
                # print(loss)

                loss.backward()
                optimizer.step()
                
                # Loss update
    
    # Display
    if epoch % 10 == 0:
        # FIXME: Create accuracy checker data
        accuracy = check_accuracy(model, data, word2index, index2word)
        print("Accuracy after epoch {} is {}".format(epoch, accuracy))
        accuracies.append(accuracy)
        losses.append(total_loss)
    return losses, accuracies, model

losses, accuracies, model = train(train_data, model, epochs, loss_function, optimizer)

torch.Size([2, 2, 2049, 44])
torch.Size([2, 4, 8200, 180])
torch.Size([2, 1, 8200, 180])
tensor([[[-2.2287e+00,  0.0000e+00,  4.6125e+00,  ...,  0.0000e+00,
          -2.0750e+01,  0.0000e+00],
         [-2.8599e+01,  0.0000e+00, -2.9630e+01,  ...,  0.0000e+00,
          -9.9305e+00,  0.0000e+00],
         [-4.5172e+00, -1.1921e-06,  3.6602e+00,  ...,  4.0574e+00,
          -2.1347e+01,  3.2380e+00],
         ...,
         [-2.2142e-03,  4.2550e-07, -6.8724e-03,  ..., -3.3275e-01,
          -1.2846e-01, -3.9848e-01],
         [-1.5256e-01,  9.9499e-02,  3.9991e-01,  ...,  3.7075e-01,
          -5.8186e-02, -5.0394e-01],
         [-1.5503e-02,  0.0000e+00, -6.5674e-02,  ..., -3.3011e-01,
           6.8115e-02, -7.9959e-01]],

        [[-1.1847e-01, -7.1112e-01, -2.6312e-01,  ...,  1.1993e-02,
          -1.2402e-01, -3.6142e-01],
         [ 3.4067e-02,  3.0629e-07, -9.0276e-02,  ..., -1.0181e-01,
           3.7164e-01, -7.3254e-01],
         [ 4.8378e-01, -1.8970e-02, -1.1900e-01,  ..., 

  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (180) must match the size of tensor b (44) at non-singleton dimension 3