In [6]:
import pyspark
from pyspark.sql import SparkSession
import numpy as np
import pypianoroll as pp
from tempfile import NamedTemporaryFile

In [7]:
spark = SparkSession.builder.master("local").appName("data").getOrCreate()

In [8]:
# if data is in inside multiple folders, run this code
spark.sql("SET mapreduce.input.fileinputformat.input.dir.recursive=true")

DataFrame[key: string, value: string]

In [9]:
# change the path
data = spark.sparkContext.binaryFiles("lpd/data/*")

In [10]:
# quick check that the data has been downloaded
data.count()

19

In [126]:
def load_data(pair):
  '''
  Function that loads the binary data, splits it into chuncks
  and return (sample,label) pairs
  Hardcoded sequence length = 100
  '''
  X = list()
  labels = list()
  sequence_length = 100
  # write to temp file
  # there should be a better way of doing this
  tracks = []
  with NamedTemporaryFile(suffix='.npz') as tmp:
    tmp.write(pair[1])
    tmp.seek(0)
    tracks = pp.load(tmp.name).tracks
  for j in tracks:
    # be smarter here when choosing which track to keep
    # now it only ignores drum tracks and two blacklisted options
    if not j.is_drum and j.name.lower() not in ["bass","bckvocals"]:
      roll = j.pianoroll
      start = 0
      end = sequence_length
      r_len = roll.shape[0]
      while(end+1<r_len and roll[start:end].sum() != 0):
        current = roll[start:end]
        current = (128-np.argmax(current[:,::-1],axis=1))-1
        b = np.zeros((sequence_length, 128))
        b[np.arange(sequence_length), np.transpose(current)] = 1
        #b[:,0] = 0
        b[:,127] = 0
        X.append(b)
        # we are trying to predict the next note
        label_current = roll[end+1]
        label_idx = (128-np.argmax(label_current[::-1]))-1
        labels_out = np.zeros(128)
        labels_out[label_idx] = 1
        #labels_out[0] = 0
        labels_out[127] = 0
        labels.append(labels_out)
        start = end
        end += sequence_length
      # if you want to take every track that satisfy the conditions
      # comment break
      break
  return zip(X, labels)

In [127]:
# transform the data into (sample, label) pairs
loaded = data.flatMap(load_data)

In [128]:
# check that the shapes are correct
test = loaded.take(1)
print(test[0][0].shape)
print(test[0][1].shape)

(100, 128)
(128,)
