In [14]:
# Environment setup
!pip install pretty_midi music21 numpy pandas matplotlib scikit-learn



In [15]:
# Imports
import os
import zipfile
import numpy as np
import pretty_midi
import pickle

from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

In [16]:
# Config
SEQUENCE_LENGTH = 50 # input for LSTM and CNN
MAX_SEQUENCES_PER_FILE = 50
MIDI_DIR = 'dataset'

In [17]:
# Unzip dataset
def extract_zip(zip_path, extract_to):
  # if not os.path.exists(extract_to):
  #   os.makedirs(extract_to)
  with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)
  print(f"Extracted zip contents to '{extract_to}")

# Check unzipped (debug)
# for root, dirs, files in os.walk('dataset'):
#   for file in files:
#     if file.endswith('.mid') or file. endswith('.midi'):
#       print(os.path.join(root, file))

### Load MIDI Files

In [18]:
#
def load_midi_files(midi_root):
  midi_data = []
  filenames = []
  for composer in os.listdir(midi_root):
    composer_dir = os.path.join(midi_root, composer)
    if os.path.isdir(composer_dir):
      for file in os.listdir(composer_dir):
        if file.endswith('.mid') or file.endswith('.midi'):
          path = os.path.join(composer_dir, file)
          try:
            midi = pretty_midi.PrettyMIDI(path)
            midi_data.append((file, midi))
            filenames.append(file)
          except Exception as e:
            print(f"Skipping {file}: {e}")
  return midi_data, filenames

### Extract Features


In [19]:
def extract_note_sequence(midi):
  notes = []
  for instrument in midi.instruments:
    if not instrument.is_drum:
      for note in instrument.notes:
        notes.append({
            'start': note.start,
            'pitch': note.pitch,
            'duration': note.end - note.start,
            'velocity': note.velocity
        })
  # Sorting by time:
  notes.sort(key=lambda n: n['start'])
  return notes

### Make Sequences (LSTM/RNN)

In [20]:
def make_feature_sequences(notes, seq_len=SEQUENCE_LENGTH):
  sequence = []
  for i in range(len(notes) - seq_len):
    seq = notes[i:i + seq_len]
    sequence.append(seq)
  return sequence

### Make Piano Rolls (CNN)

In [21]:
# Notes to piano rolls
def notes_to_piano_roll(notes, seq_len=SEQUENCE_LENGTH, pitch_range=(21, 109)):
  roll = np.zeros((seq_len, pitch_range[1] - pitch_range[0])) # 50 * 88

  for i, note in enumerate(notes[:seq_len]):
    pitch = note['pitch']
    if pitch_range[0] <= pitch < pitch_range[1]:
      roll[i, pitch - pitch_range[0]] = 1.0
  return roll

### Normalize Features

In [22]:
def normalize(sequences):
  # Pitch: 0-127 -> [0, 1]
  # Duration: scale to [0, 1]
  # Velocity: 0-127 -> [0, 1]
  all_durations = [step[1] for seq in sequences for step in seq]
  max_duration = max(all_durations) if all_durations else 1.0
  normalized = []
  for seq in sequences:
    norm_seq = np.copy(seq)
    norm_seq[:, 0] /= 127.0         # pitch
    norm_seq[:, 1] /= max_duration  # duration
    norm_seq[:, 2] /= 127.0         # velocity
    normalized.append(norm_seq)
  return normalized

### Label Encoding

In [23]:
def extract_labels(filenames):
  composers = [name.split('_')[0] for name in filenames]
  le = LabelEncoder()
  labels = le.fit_transform(composers)
  return labels, le

### Main Pipeline

In [24]:
def preprocess_dataset(midi_root=MIDI_DIR, seq_len=SEQUENCE_LENGTH):
  print("Loading MIDI files")
  midi_data, filenames = load_midi_files(midi_root)

  all_lstm_sequences = []
  all_cnn_sequences = []
  all_labels = []

  print("Processing MIDI files")
  for filename, midi in tqdm(midi_data):
    notes = extract_note_sequence(midi)
    if len(notes) < seq_len:
      continue # skips shorter sequences

    sequences = make_feature_sequences(notes, seq_len)
    piano_rolls = [notes_to_piano_roll(seq) for seq in sequences]

    # CNN input
    all_cnn_sequences.extend(piano_rolls)

    # LSTM input
    dicts_to_array = lambda seq: np.array([[n['pitch'], n['duration'], n['velocity']] for n in seq])
    array_sequences = [dicts_to_array(seq) for seq in sequences]
    normalized_sequences = normalize(array_sequences)
    all_lstm_sequences.extend(normalized_sequences[:MAX_SEQUENCES_PER_FILE])

    # Labels
    all_labels.extend([filename.split('_')[0]] * len(sequences))

  if not all_lstm_sequences:
    print("No valid sequence found. Check MIDI formatting and minimum length.")
    return None, None, None, None

  # Label encoding
  print("Encoding labels")
  le = LabelEncoder()
  y = le.fit_transform(all_labels)

  X_lstm = np.array(all_lstm_sequences, dtype=np.float32)
  X_cnn = np.array(all_cnn_sequences)[..., np.newaxis]

  print(f"Final LSTM shape: {X_lstm.shape},\nFinal CNN shape: {X_cnn.shape},\nLabels shape: {y.shape}")
  return X_lstm, X_cnn, y, le

### Save Preprocessed Data

In [25]:
def save_data(X, y, label_encoder, path):
  with open(path, 'wb') as f:
    pickle.dump((X, y, label_encoder), f)
  print(f"Saved preprocessed data to {path}")

### Run

In [26]:
if __name__ == '__main__':
  zip_path = '/content/train.zip'
  extract_path = '.'
  extracted_path = '/content/train'

  extract_zip(zip_path, extract_to=extract_path)

  # Preprocess extracted data
  X_lstm, X_cnn, y, le = preprocess_dataset(midi_root=extracted_path)

  if X_lstm is not None:
    save_data(X_lstm, y, le, path='lstm_data.pkl')
    save_data(X_cnn, y, le, path='cnn_data.pkl')
  else:
    print("Preprocessing failed. Check dataset format/contents.")

Extracted zip contents to '.
Loading MIDI files




Processing MIDI files


100%|██████████| 369/369 [01:00<00:00,  6.10it/s]


Encoding labels
Final LSTM shape: (18450, 50, 3),
Final CNN shape: (1456959, 50, 88, 1),
Labels shape: (1456959,)
Saved preprocessed data to lstm_data.pkl
Saved preprocessed data to cnn_data.pkl


In [29]:
# Zip large files
def zip_file(input_path, output_path):
  with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(input_path, arcname=os.path.basename(input_path))

zip_file('cnn_data.pkl', 'cnn_data.zip')

In [27]:
with open('lstm_data.pkl', 'rb') as f:
  X, y, le = pickle.load(f)

In [28]:
print('X shape:', X.shape)
print('y shape:', y.shape)

print("X dtype:", X.dtype)
print("First input sample:\n", X[0])   # 50x3 matrix
print("First label (encoded):", y[0])
print("First label (composer):", le.inverse_transform([y[0]])[0])

X shape: (18450, 50, 3)
y shape: (1456959,)
X dtype: float32
First input sample:
 [[0.52755904 0.2946015  0.79527557]
 [0.46456692 0.2946015  0.79527557]
 [0.48818898 0.2946015  0.79527557]
 [0.33858266 0.2946015  0.79527557]
 [0.52755904 0.05086644 0.6929134 ]
 [0.46456692 0.05086644 0.6929134 ]
 [0.48818898 0.05086644 0.6929134 ]
 [0.33858266 0.05086644 0.6929134 ]
 [0.52755904 0.05086644 0.6929134 ]
 [0.46456692 0.05086644 0.6929134 ]
 [0.48818898 0.05086644 0.6929134 ]
 [0.33858266 0.05086644 0.6929134 ]
 [0.52755904 0.05086644 0.6929134 ]
 [0.46456692 0.05086644 0.6929134 ]
 [0.48818898 0.05086644 0.6929134 ]
 [0.33858266 0.05086644 0.6929134 ]
 [0.52755904 0.05086644 0.6929134 ]
 [0.46456692 0.05086644 0.6929134 ]
 [0.48818898 0.05086644 0.6929134 ]
 [0.33858266 0.05086644 0.6929134 ]
 [0.52755904 0.05086644 0.6929134 ]
 [0.46456692 0.05086644 0.6929134 ]
 [0.48818898 0.05086644 0.6929134 ]
 [0.33858266 0.05086644 0.6929134 ]
 [0.54330707 0.10173289 0.6929134 ]
 [0.48818898 0.101