In [None]:
import os
import time
import random
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader
import pandas as pd
import pyarrow.parquet as pq
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import numpy as np
from audio_midi_pipeline import process_files


In [None]:
# Run the project's audio/MIDI preprocessing on one recording; the result is a single
# DataFrame holding both the model inputs and the labels side by side.
df = process_files(
    'songs/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav.'
)

# The first 513 columns are the model inputs and every remaining column is a label.
# NOTE(review): presumably 513 spectrogram frequency bins followed by 88 piano-key
# labels -- confirm against audio_midi_pipeline.process_files.
inputs, labels = df.iloc[:, :513], df.iloc[:, 513:]

Reading in Mason's architecture. I had to swap the Dataset class back to the one I used initially, because I'm unfamiliar with the updated data-loading function Mason used: I just copy-pasted my `SpectrogramDataset` in place of the `WeightedSpectrogramIterableDataset` from `model.py`. The code is identical to `model.py` from the
`# Define the CNN model` comment onward.
Only the weights and biases of Mason's trained model are loaded.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

class SpectrogramDataset(Dataset):
    """Map-style dataset pairing one row of input features with one row of labels.

    Parameters:
    - inputs_df: pandas DataFrame of numeric features, one row per frame.
    - labels_df: pandas DataFrame of numeric labels, one row per frame.

    Both frames are copied into float32 tensors once, at construction time.
    """

    def __init__(self, inputs_df, labels_df):
        # torch.tensor(..., dtype=float32) replaces the legacy torch.FloatTensor
        # constructor: it always copies and converts int/float64 columns explicitly.
        self.inputs = torch.tensor(inputs_df.values, dtype=torch.float32)
        self.labels = torch.tensor(labels_df.values, dtype=torch.float32)

    def __len__(self):
        """Number of frames (rows) in the dataset."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return (features reshaped to [1, 1, n_features], labels) for frame `idx`."""
        # The width comes from the data instead of a hard-coded 513, so the dataset
        # also works for other feature counts; for 513 columns the result is
        # identical to the previous view(1, 1, 513).  Layout: [channels, time, freq].
        n_features = self.inputs.shape[1]
        return self.inputs[idx].view(1, 1, n_features), self.labels[idx]
# Function to create DataLoader
def get_data_loader(dataset, batch_size):
    """Build a DataLoader over `dataset` that yields batches of `batch_size` items.

    All other DataLoader options are left at their defaults.
    """
    return DataLoader(dataset, batch_size=batch_size)

# Define the CNN model
class PitchDetectionModel(nn.Module):
    """CNN that maps one spectrogram frame to per-pitch probabilities.

    Expected input shape is [batch_size, 1, 1, 513]; the output is
    [batch_size, num_pitches] with every value passed through a sigmoid.

    The layer order inside `conv_layers` and `fc_layers` fixes the state_dict
    keys, so it must not change if saved weights are to be loaded.
    """

    def __init__(self, num_pitches=88):
        super().__init__()

        def conv_block(in_channels, out_channels, pool):
            # Conv -> BatchNorm -> ReLU, optionally followed by a pooling step that
            # halves the frequency (width) axis only; the kernel is (1, 3) so the
            # height axis is never mixed.
            block = [
                nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3), padding=(0, 1)),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]
            if pool:
                block.append(nn.MaxPool2d((1, 2)))
            return block

        # Three conv blocks; only the first two pool, the third keeps its width.
        self.conv_layers = nn.Sequential(
            *conv_block(1, 32, pool=True),
            *conv_block(32, 64, pool=True),
            *conv_block(64, 128, pool=False),
        )

        # Flattened feature size for a [batch, 1, 1, 513] input:
        #   height stays 1, width 513 -> 256 -> 128 after the two poolings (floor),
        #   and the last conv block outputs 128 channels.
        self.flattened_size = 128 * 1 * 128  # channels * height * width

        # Classifier head: flatten, one hidden layer with dropout, sigmoid output.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flattened_size, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_pitches),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the conv stack and then the classifier on x: [batch, channels, height, width]."""
        features = self.conv_layers(x)
        return self.fc_layers(features)

# Load the saved checkpoint onto the CPU. state_dict() is called on the result below,
# so best_model.pth is expected to contain a whole pickled nn.Module, not a bare
# state_dict.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints from a trusted source.
# NOTE(review): `model` itself stays on the CPU; only `new_model` below is moved to
# `device`.
model = torch.load("best_model.pth", map_location=torch.device('cpu'))

# Extract the state_dict (weights and biases) from the loaded model
state_dict = model.state_dict()

# Now create a fresh model instance from the class defined in this notebook
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
new_model = PitchDetectionModel(num_pitches=88)  # must match the architecture the checkpoint was trained with

# Copy the extracted weights into the new model
new_model.load_state_dict(state_dict)

# Move the new model to the selected device (GPU when available, otherwise CPU)
new_model = new_model.to(device)

# Evaluation mode: switches Dropout/BatchNorm layers to their inference behaviour
new_model.eval()

The code below wraps the inputs and labels in a `Dataset` and a `DataLoader`, so the model can be run in batches over the whole MIDI file and the outputs stacked afterwards.

In [None]:

# Wrap inputs and labels in a Dataset, then batch them in their original (time) order
val_dataset = SpectrogramDataset(inputs, labels)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Print the number of batches
print(f"Number of batches in val_loader: {len(val_loader)}")

# Sanity-check the shapes of the first batch only (nothing is printed if there is none)
first_batch = next(iter(val_loader), None)
if first_batch is not None:
    data, target = first_batch
    print(f"Data shape: {data.shape}")  # Should be [batch_size, 1, 1, 513]
    print(f"Target shape: {target.shape}")  # Should be [batch_size, 88]

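The code below runs the model over every batch and stacks the outputs into a single array of shape (frames, 88), which is then converted to a list.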

In [None]:
def save_model_outputs(model, data_loader, device):
    """
    Run `model` over every batch of `data_loader` and return all predictions stacked.

    Parameters:
    - model: torch module to evaluate; it is switched to evaluation mode as a side effect
    - data_loader: iterable yielding (data, labels) pairs; the labels are ignored
    - device: torch device each data batch is moved to before the forward pass

    Returns:
    - NumPy array with the per-batch outputs stacked vertically, in loader order
    """
    model.eval()

    # No gradients are needed for inference; every batch result is moved back to
    # the CPU and converted to NumPy so it can be stacked afterwards.
    with torch.no_grad():
        per_batch = [
            model(batch.to(device)).cpu().numpy()
            for batch, _ in data_loader
        ]

    # Combine all batch outputs into a single array, e.g. (X, 88) for this model
    return np.vstack(per_batch)

# Example usage
if __name__ == "__main__":
    # Run inference with `new_model`: it holds the same weights as the loaded
    # checkpoint but has been moved to `device`. The raw `model` object was loaded
    # onto the CPU and never moved, so feeding it batches that live on `device`
    # fails with a device mismatch whenever CUDA is available.
    outputs = save_model_outputs(new_model, val_loader, device)

    # Convert outputs to a plain nested list if needed
    outputs_list = outputs.tolist()

    # Print some example outputs
    print("First 5 outputs:", outputs_list[:5])

The code below forces values at or below 50% certainty to off (0) and values above 50% to on (1). The threshold can probably be adjusted; 50% is just a temporary default. This gives us our desired output: a binary array.

In [None]:
# Use `new_model` (same weights, already moved to `device`). The raw `model` object was
# loaded onto the CPU, so running it on batches moved to `device` breaks on CUDA machines.
new_model.eval()  # Set model to evaluation mode

binary_results = []  # Store binary outputs for later processing

with torch.no_grad():
    for data, _ in val_loader:  # Assuming val_loader is already defined
        data = data.to(device)
        # A separate name keeps the stacked `outputs` array from the previous cell intact.
        batch_probabilities = new_model(data)  # Probabilities from the sigmoid output layer
        binary_outputs = (batch_probabilities > 0.5).int()  # Strictly above 0.5 -> 1, otherwise 0
        binary_results.append(binary_outputs.cpu().numpy())  # Save as NumPy array for further use

# Combine all results into a single array
binary_results = np.vstack(binary_results)
print(f"Binary results shape: {binary_results.shape}")  # Should be (X, 88)

The following code checks for repeated active values using three different window sizes. For each window size, it slides a window of that many frames along each note column and checks whether at least 40% of the frames in the window are active. If this condition is met, every frame in that window is marked as active. This approach addresses the issue where values like 1,0,1,1,1,0,1,0,1,1,0,1,1 would generate a separate note for each 1 that follows a 0 in the MIDI file. Now these frames are grouped together and treated as a single note event. The 40% threshold and the window sizes are adjustable and can be tweaked for different results.

In [None]:
import numpy as np

import numpy as np

def fill_zeros_with_ones(binary_results, large_window_size=40, medium_window_size=20, small_window_size=7, 
                          threshold_large=0.4, threshold_medium=0.4, threshold_small=0.4):
    """
    Fill zeros with ones in binary data if at least a certain percentage of the values in the window are 1s.

    For every column, a window of each size slides over the frames of the ORIGINAL
    input. Whenever the sum of a window reaches the required count, all zeros inside
    that window become 1 in the result. The three window sizes are evaluated
    independently of each other and their fills are combined.

    The required count is the ceiling of `window_size * threshold`, so "at least 40%"
    of a 7-frame window means 3 active frames. (Truncating instead would accept 2 of
    7, i.e. 28.6%, and would give a required count of 0 -- filling every frame --
    whenever `window_size * threshold` is below 1.)

    Parameters:
    - binary_results: 2D numpy array with shape (frames, 88)
    - large_window_size: Size of the large window for long stretches
    - medium_window_size: Size of the medium window for medium bursts
    - small_window_size: Size of the small window for short bursts
    - threshold_large: Fraction threshold for large window (default is 0.4, i.e. 40%)
    - threshold_medium: Fraction threshold for medium window (default is 0.4, i.e. 40%)
    - threshold_small: Fraction threshold for small window (default is 0.4, i.e. 40%)

    Returns:
    - modified_results: A new array (the input is not modified) with the gaps filled

    Raises:
    - ValueError: if the input is not 2-dimensional or does not have 88 columns
    """
    # Ensure binary_results is a 2D array with shape (frames, 88)
    if binary_results.ndim != 2:
        raise ValueError("Input array must have 2 dimensions (frames, features)")

    if binary_results.shape[1] != 88:
        raise ValueError(f"Expected 88 columns, but found {binary_results.shape[1]}.")

    print(f"Input shape: {binary_results.shape}")

    num_frames, num_columns = binary_results.shape

    # Prefix sums along the time axis with a leading row of zeros, so the sum of
    # frames [i, i + w) is prefix[i + w] - prefix[i] for all columns at once.
    running_totals = np.cumsum(binary_results, axis=0)
    prefix = np.concatenate(
        [np.zeros((1, num_columns), dtype=running_totals.dtype), running_totals],
        axis=0,
    )

    # covered[f, c] becomes True when frame f lies inside at least one qualifying window.
    covered = np.zeros(binary_results.shape, dtype=bool)

    windows = (
        (large_window_size, threshold_large),
        (medium_window_size, threshold_medium),
        (small_window_size, threshold_small),
    )
    for window_size, threshold in windows:
        # A window that is empty or longer than the data never fits anywhere.
        if window_size <= 0 or window_size > num_frames:
            continue

        # Round up; the tiny epsilon keeps exact products such as 40 * 0.4 from
        # being pushed to the next integer by floating-point noise.
        required_count = int(np.ceil(window_size * threshold - 1e-9))

        # Row i holds the sum of the window starting at frame i.
        window_sums = prefix[window_size:] - prefix[:-window_size]
        qualifies = window_sums >= required_count

        # Spread every qualifying window start over the frames it spans using a
        # difference array: +1 at the start, -1 one past the end, then accumulate.
        delta = np.zeros((num_frames + 1, num_columns), dtype=np.int64)
        delta[:num_frames - window_size + 1] += qualifies
        delta[window_size:] -= qualifies
        covered |= np.cumsum(delta[:-1], axis=0) > 0

    # Work on a copy; only values that were exactly 0 are overwritten.
    modified_results = binary_results.copy()
    modified_results[covered & (binary_results == 0)] = 1

    return modified_results

# Smooth the thresholded predictions with the default window sizes and thresholds;
# `binssss` is the array the MIDI export below is generated from.
binssss = fill_zeros_with_ones(binary_results)


Code below actually generates the midi file from our new binary variable we created called "binssss"

In [None]:
from midiutil import MIDIFile

def create_midi_from_binary_test(binssss, output_file, tempo_bpm=120, time_increment_ms=23.2198,
                                 min_duration_s=0.1):
    """
    Convert binary results to a single-track MIDI file.

    Each column is one piano key (column 0 -> MIDI note 21, A0). A note starts at the
    first frame where its column is 1 and ends at the next frame where it is 0. Notes
    that are still sounding at the last frame are closed at the end of the input, so
    they are written too.

    MIDIUtil expects note times and durations in quarter-note beats, not seconds, so
    all times are converted with `tempo_bpm` before being added to the file.

    Parameters:
    - binssss: 2D numpy array of binary values (0 or 1) representing key states, shape (x, 88)
    - output_file: Path to the output MIDI file
    - tempo_bpm: Tempo written to the file and used for the seconds -> beats conversion
    - time_increment_ms: Duration of one row (frame) in milliseconds
    - min_duration_s: Minimum note length in seconds, to prevent zero-length notes

    Raises:
    - ValueError: if the input does not have shape (x, 88)
    """
    # Validate input shape
    if binssss.ndim != 2 or binssss.shape[1] != 88:
        raise ValueError("Input array must have shape (x, 88)")

    print(f"Original binary results shape: {binssss.shape}")

    # MIDI note numbers for piano (A0 to C8)
    base_midi_note = 21

    # NOTE(review): 23.2198 ms is 0.0232198 s per row; an earlier inline comment said
    # 0.0116099 s (half of that) -- confirm against the spectrogram hop length.
    seconds_per_frame = time_increment_ms / 1000
    beats_per_second = tempo_bpm / 60

    # Create a new MIDI file with one track
    midi_file = MIDIFile(1)
    midi_file.addTempo(0, 0, tempo_bpm)

    # Start time in seconds of every currently sounding note; None means the key is off.
    note_start_times = [None] * 88

    def close_note(key_index, end_time):
        # Write the note that has been sounding on `key_index` and mark the key as off.
        start_time = note_start_times[key_index]
        duration = max(end_time - start_time, min_duration_s)
        midi_file.addNote(
            0,                              # track
            0,                              # channel
            base_midi_note + key_index,     # pitch
            start_time * beats_per_second,  # start, in beats
            duration * beats_per_second,    # duration, in beats
            100                             # velocity
        )
        note_start_times[key_index] = None

    # Walk through the frames and turn 0 -> 1 / 1 -> 0 transitions into notes
    for time_index, frame in enumerate(binssss):
        current_time = time_index * seconds_per_frame

        for key_index, key_state in enumerate(frame):
            note_is_on = note_start_times[key_index] is not None

            if key_state == 1 and not note_is_on:
                note_start_times[key_index] = current_time
            elif key_state == 0 and note_is_on:
                close_note(key_index, current_time)

    # Flush notes that were never switched off before the data ended
    end_of_input = len(binssss) * seconds_per_frame
    for key_index, start_time in enumerate(note_start_times):
        if start_time is not None:
            close_note(key_index, end_of_input)

    # Write the MIDI file
    with open(output_file, "wb") as f:
        midi_file.writeFile(f)

    print("MIDI file created successfully!")

# Write the gap-filled predictions (`binssss`) to a MIDI file in the working directory
create_midi_from_binary_test(binssss, "masons_model_3mapped.mid")


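Code that removes standalone/outlier values. It works like the gap-filling cell above (column by column, with a sliding 5-frame window), except that instead of filling zeros it removes any standalone 1, i.e. a 1 whose neighbouring frames on both sides are 0.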

In [None]:
def remove_standalone_ones(binary_results, window_size=5):
    """
    Remove standalone '1's: a '1' whose previous and next frame are both '0'.

    All comparisons are made against the original input, so removing one value never
    affects the decision for a neighbouring frame. The first and the last frame have
    only one neighbour and are never changed.

    Parameters:
    - binary_results: 2D numpy array with shape (frames, 88)
    - window_size: Size of the sliding window (default is 5 frames). Nothing is removed
      when it is smaller than 3 or larger than the number of frames; otherwise every
      interior frame is checked regardless of its exact value.

    Returns:
    - modified_results: A copy of the input with standalone '1's set to 0

    Raises:
    - ValueError: if the input is not 2-dimensional or does not have 88 columns
    """
    # Ensure binary_results is a 2D array with shape (frames, 88)
    if binary_results.ndim != 2:
        raise ValueError("Input array must have 2 dimensions (frames, features)")

    frame_count, column_count = binary_results.shape
    if column_count != 88:
        raise ValueError(f"Expected 88 columns, but found {binary_results.shape[1]}.")

    print(f"Input shape: {binary_results.shape}")

    # Work on a copy so the caller's array stays untouched
    modified_results = binary_results.copy()

    # A window needs an interior position (size >= 3) and must fit into the data at
    # least once; otherwise no frame is ever inspected.
    any_window_fits = window_size >= 3 and frame_count >= window_size

    for col in range(column_count):
        print(f"Processing column {col + 1} of {binary_results.shape[1]}...")
        if not any_window_fits:
            continue

        track = binary_results[:, col]
        # Interior frames that are 1 while both direct neighbours are 0
        isolated = (track[1:-1] == 1) & (track[:-2] == 0) & (track[2:] == 0)
        # +1 maps positions in the interior slice back to frame indices
        modified_results[np.flatnonzero(isolated) + 1, col] = 0

    return modified_results

# Apply the standalone-1 removal to the gap-filled array; the result is kept under a
# new name so `binssss` itself stays available unchanged.
binssss_no_standalone = remove_standalone_ones(binssss, window_size=5)

Creating another midi file using updated standalone variable

In [None]:
from midiutil import MIDIFile

def create_midi_from_binary_test(binssss, output_file, tempo_bpm=120, time_increment_ms=23.2198,
                                 min_duration_s=0.1):
    """
    Convert binary results to a single-track MIDI file.

    NOTE(review): this function is defined in more than one cell of this notebook;
    the definition executed last wins, so keep the copies identical or drop one.

    Each column is one piano key (column 0 -> MIDI note 21, A0). A note starts at the
    first frame where its column is 1 and ends at the next frame where it is 0. Notes
    that are still sounding at the last frame are closed at the end of the input, so
    they are written too.

    MIDIUtil expects note times and durations in quarter-note beats, not seconds, so
    all times are converted with `tempo_bpm` before being added to the file.

    Parameters:
    - binssss: 2D numpy array of binary values (0 or 1) representing key states, shape (x, 88)
    - output_file: Path to the output MIDI file
    - tempo_bpm: Tempo written to the file and used for the seconds -> beats conversion
    - time_increment_ms: Duration of one row (frame) in milliseconds
    - min_duration_s: Minimum note length in seconds, to prevent zero-length notes

    Raises:
    - ValueError: if the input does not have shape (x, 88)
    """
    # Validate input shape
    if binssss.ndim != 2 or binssss.shape[1] != 88:
        raise ValueError("Input array must have shape (x, 88)")

    print(f"Original binary results shape: {binssss.shape}")

    # MIDI note numbers for piano (A0 to C8)
    base_midi_note = 21

    # NOTE(review): 23.2198 ms is 0.0232198 s per row; an earlier inline comment said
    # 0.0116099 s (half of that) -- confirm against the spectrogram hop length.
    seconds_per_frame = time_increment_ms / 1000
    beats_per_second = tempo_bpm / 60

    # Create a new MIDI file with one track
    midi_file = MIDIFile(1)
    midi_file.addTempo(0, 0, tempo_bpm)

    # Start time in seconds of every currently sounding note; None means the key is off.
    note_start_times = [None] * 88

    def close_note(key_index, end_time):
        # Write the note that has been sounding on `key_index` and mark the key as off.
        start_time = note_start_times[key_index]
        duration = max(end_time - start_time, min_duration_s)
        midi_file.addNote(
            0,                              # track
            0,                              # channel
            base_midi_note + key_index,     # pitch
            start_time * beats_per_second,  # start, in beats
            duration * beats_per_second,    # duration, in beats
            100                             # velocity
        )
        note_start_times[key_index] = None

    # Walk through the frames and turn 0 -> 1 / 1 -> 0 transitions into notes
    for time_index, frame in enumerate(binssss):
        current_time = time_index * seconds_per_frame

        for key_index, key_state in enumerate(frame):
            note_is_on = note_start_times[key_index] is not None

            if key_state == 1 and not note_is_on:
                note_start_times[key_index] = current_time
            elif key_state == 0 and note_is_on:
                close_note(key_index, current_time)

    # Flush notes that were never switched off before the data ended
    end_of_input = len(binssss) * seconds_per_frame
    for key_index, start_time in enumerate(note_start_times):
        if start_time is not None:
            close_note(key_index, end_of_input)

    # Write the MIDI file
    with open(output_file, "wb") as f:
        midi_file.writeFile(f)

    print("MIDI file created successfully!")

# Write the predictions with standalone 1s removed to a second MIDI file, so it can be
# compared with the file generated from `binssss`.
create_midi_from_binary_test(binssss_no_standalone, "masons_model_3mapped_standalone.mid")