In [9]:
from google.colab import drive
import os

# set path to project folder
gdrive_path='/content/gdrive/MyDrive/7-programming/music_recognition/'

# mount Google Drive
drive.mount('/content/gdrive', force_remount=True)

# navigate to Google Drive folder
os.chdir(gdrive_path)

# check that we are in the right folder
print(sorted(os.listdir()))

Mounted at /content/gdrive
['.git', '.gitignore', 'README.md', 'data', 'data_loader.ipynb']


In [32]:
import torch
import torch.utils.data as data
from torchvision import transforms
from PIL import Image

class PrimusDataset(data.Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.data = []

        # iterate through each subdirectory (corresponding to a sample)
        for sample_dir in os.listdir(data_path):

            sample_dir_path = os.path.join(data_path, sample_dir)

            image_file = None
            semantic_file = None

            # .png-file contains image, .semantic-file contains labels
            for file in os.listdir(sample_dir_path):
                if file.endswith(".png"):
                    image_file = os.path.join(sample_dir_path, file)
                elif file.endswith(".semantic"):
                    semantic_file = os.path.join(sample_dir_path, file)

            # check if a (image, label)-pair could be found
            if image_file and semantic_file:
                self.data.append((image_file, semantic_file))
            else:
                print(f"Couldn't find {'Image in ' + str(sample_dir_path) if not image_file else 'Labels in ' + str(sample_dir_path)}!")


    def __getitem__(self, index):
        # function to transform image to tensor
        to_tensor = transforms.ToTensor()

        # obtain path for image and label at given index
        image_path, labels_path = self.data[index]

        # read image and label
        image = Image.open(image_path).convert('L')
        image = to_tensor(image)
        with open(labels_path, 'r') as file:
            labels = file.read()

        # apply transforms to image if specified
        if self.transform:
            image = self.transform(image)

        return image, labels

    def __len__(self):
        return len(self.data)

In [33]:
import torch
from torch.utils.data import DataLoader, random_split
from prettytable import PrettyTable

def split_data(dataset, ratio=(0.6, 0.2, 0.2)):
    """
    Applies train-validation-test split to a given dataset

    Parameters:
        dataset: Dataset to which split will be applied
        ratio: Ratio represented as tuple of shape (train, val, test)

    Returns:
        train_data: Dataset for Training
            (i.e. fitting models)
        val_data: Dataset for Validation
            (i.e. evaluating performance of models to tune hyperparameters)
        test_data: Dataset for Testing
            (i.e. evaluating final performance)
    """
    # calculate sizes of dataset
    train_size, val_size = int(ratio[0] * len(dataset)), int(ratio[1] * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # apply split
    train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size])

    # print table of dataset sizes
    table = PrettyTable()
    table.field_names = ["Dataset", "# Samples"]
    table.align["Dataset"] = "l"
    table.add_row(["Train", len(train_data)])
    table.add_row(["Validation", len(val_data)])
    table.add_row(["Test", len(test_data)])

    print(table)

    return train_data, val_data, test_data

In [34]:
# load dataset
data_root = os.path.join(gdrive_path, 'data')
dataset = PrimusDataset(data_path=data_root)

# apply train-val-test split
train_data, val_data, test_data = split_data(dataset)

+------------+-----------+
| Dataset    | # Samples |
+------------+-----------+
| Train      |     6     |
| Validation |     2     |
| Test       |     2     |
+------------+-----------+


In [35]:
train_data[0]

(tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]),
 'clef-C1\tkeySignature-EbM\ttimeSignature-C\tmultirest-4\tbarline\trest-half\tnote-G4_eighth\tnote-G4_sixteenth\tnote-G4_sixteenth\tnote-G4_eighth\tnote-A4_eighth\tbarline\tnote-B4_eighth\tnote-Bb4_eighth\trest-eighth\tnote-Bb4_eighth\tnote-Bb4_eighth\tnote-Bb4_eighth\tnote-A4_eighth\tnote-Bb4_eighth\tbarline\tnote-G4_quarter\t')