# Data preprocessing

## Dependencies

In [None]:
%pip install librosa

## Imports

In [None]:
import librosa as lr
import numpy as np
import os
import torch
import torchvision.models as models
import random
from typing import Callable
from itertools import combinations

## Constants

In [None]:
# Root folder of the dataset. Asked interactively; an empty answer falls back
# to 'Dataset/'.
DATASET_PATH = input('Dataset path: ') or 'Dataset/'
# os.path.join inserts the separator when the entered path has no trailing
# slash; plain string concatenation would glue the folder names together.
TRAINING_DATASET_PATH = os.path.join(DATASET_PATH, 'IRMAS_Training_Data/')
VALIDATION_DATASET_PATH = os.path.join(DATASET_PATH, 'IRMAS_Validation_Data/')

# Class identifiers. The same strings are used as sub-folder names of the
# training data and are matched against the lines of the validation label
# files; the position of a name in this tuple is its index in the label
# vectors, so the order must not change.
# Presumably the IRMAS instrument codes (cello, clarinet, flute, acoustic
# guitar, electric guitar, organ, piano, saxophone, trumpet, violin, voice)
# -- confirm against the dataset documentation.
CLASSES = (
    'cel', 'cla', 'flu',
    'gac', 'gel',
    'org', 'pia',
    'sax', 'tru',
    'vio', 'voi',
)

## Transform training data

### Transform existing data (without mixing)

In [None]:
def transform_training_data(output_dir: str, transform: Callable,
                            source_dir: str = TRAINING_DATASET_PATH) -> None:
    '''Transform every training `.wav` file and save the result per class.

    For each class name in `CLASSES`, every `.wav` file found in
    `source_dir/<class>` is loaded with librosa, passed to `transform` and
    the returned object is written with `torch.save` to
    `output_dir/<class>/<original file name without extension>.pt`.

    `transform` should be a callable that accepts signal and sample rate and
    returns torch.Tensor.

    Output directories that do not exist yet (including missing parents) are
    created; existing ones are reused.
    '''
    os.makedirs(output_dir, exist_ok=True)

    for c in CLASSES:
        output_path = os.path.join(output_dir, c)
        os.makedirs(output_path, exist_ok=True)

        for f in os.scandir(os.path.join(source_dir, c)):
            stem, ext = os.path.splitext(f.name)
            # anything that is not a wav file is ignored
            if ext != '.wav':
                continue
            signal, sample_rate = lr.load(f.path)
            result = transform(signal, sample_rate)
            torch.save(result, os.path.join(output_path, stem + '.pt'))

### Mix training examples and transform

In [None]:
def _pad_and_sum(signals: list) -> np.ndarray:
    '''Add signals sample by sample, zero-padding shorter ones at the end.'''
    longest = max(s.shape[-1] for s in signals)
    mixed = np.zeros(signals[0].shape[:-1] + (longest,),
                     dtype=signals[0].dtype)
    for s in signals:
        mixed[..., :s.shape[-1]] += s
    return mixed


def mix_and_transform(output_dir: str, transform: Callable,
                      mix_together: int, num_examples: int,
                      source_dir: str = TRAINING_DATASET_PATH) -> None:
    '''Mix together training examples and perform transform.

    mix_together param tells how many input examples will be included in the
    mix; each of them is drawn at random from a different class.
    num_examples is the number of mixes to generate. The class combinations
    are used in round-robin order and the same tuple of source files is never
    used twice.

    Every example is saved with `torch.save` as a `(data, labels)` tuple in
    `output_dir/<zero padded index>.pt`, where `data` is the value returned
    by `transform(signal, sample_rate)` and `labels` is a tensor with one
    entry per class in `CLASSES` (1 for the classes in the mix, 0 otherwise).

    Raises ValueError if mix_together is not between 1 and the number of
    classes, or if the source folders do not hold enough files to build the
    requested number of distinct mixes.
    '''
    if not 1 <= mix_together <= len(CLASSES):
        raise ValueError(
            f'mix_together must be between 1 and {len(CLASSES)}, '
            f'got {mix_together}')

    os.makedirs(output_dir, exist_ok=True)

    # only wav files are candidates; other directory entries cannot be loaded
    class_dirs = []
    for c in CLASSES:
        c_path = os.path.join(source_dir, c)
        class_dirs.append([os.path.join(c_path, f)
                           for f in os.listdir(c_path)
                           if os.path.splitext(f)[1] == '.wav'])
    digits = len(str(num_examples-1))

    label_combinations = list(combinations(range(len(CLASSES)), mix_together))
    past_combinations = [set() for _ in label_combinations]
    num_combinations = len(label_combinations)

    # Fail before doing any work if some combination cannot provide as many
    # distinct file tuples as the round-robin below will ask from it;
    # otherwise the duplicate-rejection loop would never terminate.
    for idx, combination in enumerate(label_combinations):
        capacity = 1
        for label in combination:
            capacity *= len(class_dirs[label])
        needed = num_examples // num_combinations \
            + int(idx < num_examples % num_combinations)
        if needed > capacity:
            names = ', '.join(CLASSES[label] for label in combination)
            raise ValueError(
                f'cannot build {needed} distinct mixes of ({names}): '
                f'only {capacity} file combinations are available')

    for i in range(num_examples):
        combination_idx = i % num_combinations
        combination = label_combinations[combination_idx]
        seen = past_combinations[combination_idx]

        # ensure we don't generate duplicates
        while True:
            sources = tuple(random.choice(class_dirs[label])
                            for label in combination)
            if sources not in seen:
                break
        seen.add(sources)

        # load the source examples; the sample rate of the first one is used
        loaded = [lr.load(path) for path in sources]
        sample_rate = loaded[0][1]
        # clips of different length are zero-padded instead of failing
        signal = _pad_and_sum([source_signal for source_signal, _ in loaded])

        data = transform(signal, sample_rate)
        labels = torch.Tensor([1 if label in combination else 0
                               for label in range(len(CLASSES))])
        torch.save((data, labels), os.path.join(output_dir,
                                                f'{i:0{digits}}.pt'))

## Transform validation data

In [None]:
def read_validation_labels(filename: str) -> torch.Tensor:
    '''Read a label file and return a 0/1 tensor with one entry per class.

    Entry k is 1 when `CLASSES[k]` equals one of the lines of the file
    (compared after removing trailing whitespace) and 0 otherwise.
    '''
    with open(filename, 'r') as label_file:
        present = {entry.rstrip() for entry in label_file}
    return torch.Tensor([int(name in present) for name in CLASSES])

In [None]:
def transform_validation_data(output_dir: str, transform: Callable,
                              source_dir: str = VALIDATION_DATASET_PATH
                              ) -> None:
    '''Transform every validation `.wav` file and save it with its labels.

    For each `<name>.wav` in `source_dir` the audio is loaded with librosa
    and passed to `transform` (a callable that accepts signal and sample
    rate); the labels are read from `<name>.txt` in the same folder. The
    tuple `(data, labels)` is written with `torch.save` to
    `output_dir/<name>.pt`.

    `output_dir` is created (including missing parents) if it does not exist.
    '''
    os.makedirs(output_dir, exist_ok=True)

    files_split = [os.path.splitext(f.name) for f in os.scandir(source_dir)]
    file_names = [stem for stem, ext in files_split if ext == '.wav']
    for filename in file_names:
        wav_file = filename + '.wav'
        label_file = filename + '.txt'

        signal, sample_rate = lr.load(os.path.join(source_dir, wav_file))
        data = transform(signal, sample_rate)
        labels = read_validation_labels(os.path.join(source_dir, label_file))

        # the extension needs its dot, same as for the training outputs
        torch.save((data, labels), os.path.join(output_dir, filename + '.pt'))

## ResNet50 Features

In [None]:
def melspec_resnet50_transform_factory(hop_length: int = 256,
                                       n_mels: int = 256) -> Callable:
    '''Build a transform that maps an audio signal to ResNet50 features.

    A torchvision ResNet50 with the 'DEFAULT' weights is created once per
    factory call; its final fully connected layer is replaced by an identity
    so the network outputs the features that would feed that layer.

    The returned callable accepts signal and sample rate, computes a mel
    spectrogram with the given hop_length and n_mels, min-max normalizes it
    to [0, 1], repeats it on three channels and returns the network output
    for that single image as a torch.Tensor without the batch dimension.
    '''
    backbone = models.resnet50(weights='DEFAULT')
    backbone.fc = torch.nn.Identity()
    # Inference mode for the layers: batch normalization must use the
    # statistics stored in the pretrained weights, not those of the input.
    backbone.eval()

    def melspec_resnet50_transform(signal: np.ndarray, sample_rate: int)\
            -> torch.Tensor:
        melspec = torch.Tensor(lr.feature.melspectrogram(y=signal,
                                                         sr=sample_rate,
                                                         hop_length=hop_length,
                                                         n_mels=n_mels))
        lowest = torch.min(melspec)
        span = torch.max(melspec) - lowest
        if span == 0:
            # constant spectrogram (e.g. silence): avoid 0/0 producing NaN
            normalized = torch.zeros_like(melspec)
        else:
            normalized = (melspec - lowest) / span
        tripled = torch.stack((normalized, normalized, normalized), dim=0)
        # features are only stored, so no autograd graph is needed
        with torch.no_grad():
            resnet_features = backbone(torch.unsqueeze(tripled, dim=0))
        return torch.squeeze(resnet_features, dim=0)
    return melspec_resnet50_transform

### ResNet50 training transform

In [None]:
# Compute the ResNet50 features of every (unmixed) training clip and store
# them, one folder per class, under DATASET_PATH/melspec_resnet50_training/.
melspec_resnet50_training_output = os.path.join(DATASET_PATH, 'melspec_resnet50_training/')
transform_training_data(melspec_resnet50_training_output, melspec_resnet50_transform_factory())

In [None]:
# Generate 10000 mixes of two training clips from different classes and store
# their ResNet50 features under DATASET_PATH/melspec_resnet50_training_mix_2/.
melspec_resnet50_training_output_mix2 = os.path.join(DATASET_PATH, 'melspec_resnet50_training_mix_2/')
mix_and_transform(melspec_resnet50_training_output_mix2,
                  melspec_resnet50_transform_factory(), mix_together=2,
                  num_examples=10000)

In [None]:
# Generate 10000 mixes of three training clips from different classes and
# store their ResNet50 features under
# DATASET_PATH/melspec_resnet50_training_mix_3/.
melspec_resnet50_training_output_mix3 = os.path.join(DATASET_PATH, 'melspec_resnet50_training_mix_3/')
mix_and_transform(melspec_resnet50_training_output_mix3,
                  melspec_resnet50_transform_factory(), mix_together=3,
                  num_examples=10000)

### ResNet50 validation transform

In [None]:
# Compute the ResNet50 features of every validation clip and store them,
# together with the labels, under DATASET_PATH/melspec_resnet50_validation/.
melspec_resnet50_validation_output = os.path.join(DATASET_PATH, 'melspec_resnet50_validation/')
transform_validation_data(melspec_resnet50_validation_output,
                          melspec_resnet50_transform_factory())

## Mel Spectrogram

In [None]:
def melspec_transform_factory(hop_length: int = 256, n_mels: int = 256)\
        -> Callable:
    '''Build a transform that maps an audio signal to its mel spectrogram.

    The returned callable accepts signal and sample rate and returns the
    librosa mel spectrogram, computed with the given hop_length and n_mels,
    converted to a torch.Tensor.
    '''
    # settings fixed when the factory is called, shared by every invocation
    options = {'hop_length': hop_length, 'n_mels': n_mels}

    def melspec_transform(signal: np.ndarray, sample_rate: int)\
            -> torch.Tensor:
        spectrogram = lr.feature.melspectrogram(y=signal, sr=sample_rate,
                                                **options)
        return torch.Tensor(spectrogram)

    return melspec_transform

In [None]:
# Compute the plain mel spectrogram of every (unmixed) training clip and
# store it, one folder per class, under DATASET_PATH/melspec_training/.
melspec_output = os.path.join(DATASET_PATH, 'melspec_training/')
transform_training_data(melspec_output, melspec_transform_factory())