In [7]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import glob
import os
import logging
from composer_one import Composer  # Ensure this file is in the same directory or adjust import
from midi2seq import piano2seq, seq2piano  # Ensure these are available

# Send the logging.info / logging.error calls used throughout this notebook to
# the root logger at INFO level (messages render as "INFO:<text>").
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers (e.g. when this cell is re-run), so changing `level` here needs a
# kernel restart or `force=True` to take effect.
logging.basicConfig(level=logging.INFO)

def process_midi_files(file_paths, seq_length, stride=None):
    """Cut every MIDI file into fixed-length, overlapping training windows.

    Each file is converted with ``piano2seq`` and sliced into windows of
    ``seq_length`` consecutive elements. A new window starts every ``stride``
    elements; the default is half a window, i.e. 50% overlap. Files that fail
    to convert are logged and skipped, so one bad file does not abort the run.

    Args:
        file_paths: Iterable of MIDI file paths.
        seq_length: Number of elements per window; must be >= 1.
        stride: Step between consecutive window starts. Defaults to
            ``seq_length // 2`` (but at least 1).

    Returns:
        torch.LongTensor of shape ``(num_windows, seq_length)``.

    Raises:
        ValueError: If ``seq_length`` or ``stride`` is < 1, or if no window
            could be extracted from any of the files.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")
    if stride is None:
        # seq_length // 2 is 0 when seq_length == 1, and range() rejects a zero step.
        stride = max(1, seq_length // 2)
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    sequences = []
    for file_path in file_paths:
        try:
            seq = piano2seq(file_path)
            # "+ 1" includes the last window ending exactly at len(seq), so a
            # sequence of exactly seq_length elements still yields one window.
            for i in range(0, len(seq) - seq_length + 1, stride):
                sequences.append(seq[i:i + seq_length])
        except Exception as e:
            # Best effort: skip files that fail to convert, but record them.
            logging.error(f"Error processing file {file_path}: {str(e)}")

    if not sequences:
        raise ValueError("No valid sequences were extracted from the MIDI files.")

    # Every window has exactly seq_length elements, so this stacks into a 2-D
    # array (assumes piano2seq returns a flat sequence of integer tokens).
    sequences_array = np.array(sequences)

    # Convert numpy array to a PyTorch int64 tensor
    return torch.from_numpy(sequences_array).long()

In [8]:
# Directory holding the MIDI files. Set the MIDI_DIR environment variable to
# point elsewhere instead of editing this machine-specific default path.
midi_dir = os.environ.get('MIDI_DIR', '/Users/matiwosbirbo/PianoGen/maestro-v1.0.0 2/')

# Try *.midi first, then fall back to *.mid. Sorting makes the order of the
# extracted sequences independent of the filesystem's listing order.
midi_files = sorted(glob.glob(os.path.join(midi_dir, '*.midi')))

if not midi_files:
    midi_files = sorted(glob.glob(os.path.join(midi_dir, '*.mid')))  # Try .mid extension

if not midi_files:
    raise FileNotFoundError(f"No MIDI files found in directory: {midi_dir}")

logging.info(f"Found {len(midi_files)} MIDI files.")

# Number of elements in each training window
seq_length = 512

try:
    training_data = process_midi_files(midi_files, seq_length)
    logging.info(f"Processed {len(training_data)} sequences from MIDI files.")
except ValueError as e:
    # Log for visibility in the notebook output, then stop the run.
    logging.error(str(e))
    raise

INFO:Found 1184 MIDI files.
INFO:Processed 119520 sequences from MIDI files.


In [9]:
# Create DataLoader. A seeded generator makes the shuffle order reproducible
# across Restart & Run All.
batch_size = 32
shuffle_seed = 0
dataset = TensorDataset(training_data)
# NOTE(review): a single-tensor TensorDataset yields 1-tuples, so each batch is
# a 1-element list [tokens]; presumably Composer.train_model unpacks it — confirm.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(shuffle_seed),
)

In [10]:
# Initialize the Composer model without loading trained weights
composer = Composer(load_trained=False)

# Pick the fastest available backend: CUDA, then Apple Silicon GPU (MPS), then CPU.
# torch.backends.mps does not exist on older torch builds, hence the getattr guard.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
    logging.info("Using CUDA GPU")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
    logging.info("Using Apple Silicon GPU (MPS)")
else:
    device = torch.device("cpu")
    logging.info("Using CPU")

# NOTE(review): Composer's constructor appears to log its own device choice
# ("Using MPS (GPU) for computation"); confirm it agrees with `device`.
composer = composer.to(device)

INFO:Using MPS (GPU) for computation
INFO:Using Apple Silicon GPU (MPS)


In [16]:
# Training hyperparameters: passes over the data and optimizer step size
num_epochs, learning_rate = 10, 0.001

# Echo the compute device so the log shows where training runs
print("Model is using device:", device)

# Launch the Composer's training loop over the batched windows
composer.train_model(data_loader, num_epochs, learning_rate)

Model is using device: mps


KeyboardInterrupt: 