In [1]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
from pydantic import BaseModel
import tempfile
import os
import pretty_midi
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import uuid
import pandas as pd

# Import your trained models and utilities
from classical_composer.data import generate_piano_roll, get_fixed_frame_indicies, generate_random_frame_indicies
from classical_composer.frames import extract_frames
from scipy.stats import mode


from classical_composer.features.feature_extractor import FeatureExtractor
from classical_composer.features import (
    frequency_based_features,
    harmonic_features,
    higher_level_features,
    pitch_based_features,
    temporal_features,
    velocity_based_features,
)

In [2]:
from unittest.mock import Mock

# Stand-in for the application's settings object: only the two folder
# attributes read by this notebook are given real values.
settings = Mock()
settings.static_folder = "../static"
settings.resource_folder = "../resources"

# Load the persisted k-means run from the resources folder.
# NOTE(review): pickle.load executes arbitrary code embedded in the file;
# only load model files that come from a trusted source.
kmeans_model_path = f"{settings.resource_folder}/models/kmeans_model_run.pkl"
with open(kmeans_model_path, "rb") as model_file:
    kmeans_model_run = pickle.load(model_file)

In [3]:

def create_piano_roll_image(piano_roll, extracted_frames, frame_indices, fs=100):
    """
    Render the full piano roll with highlighted frame regions, plus one image per frame.

    The overview figure shows the whole piano roll with a semi-transparent
    rectangle and a rotated "Frame N" label for every entry in ``frame_indices``.
    Each extracted frame is then rendered to its own axis-free PNG. All files
    are written to ``settings.static_folder`` and share one freshly generated
    UUID as file-name prefix.

    Args:
        piano_roll: 2D NumPy array representing the piano roll.
        extracted_frames: List of 2D NumPy arrays for the extracted frames,
            in the same order as ``frame_indices``.
        frame_indices: List of tuples (start_index, end_index) for each frame.
        fs: Frames per second for the piano roll (only used in the x-axis label).

    Returns:
        dict: ``{"file_name": <overview PNG path>, "frames": [...]}`` where each
        entry of ``frames`` is
        ``{"indicies": (start, end), "file_name": <frame PNG path>}``.
    """
    n_pitches = piano_roll.shape[0]

    # --- Overview figure: the whole piano roll on a single axes ---
    fig = plt.figure(figsize=(15, 10))
    ax_main = fig.add_subplot(1, 1, 1)
    ax_main.imshow(piano_roll, aspect='auto', origin='lower', cmap='hot')
    ax_main.set_title('Piano Roll with Highlighted Frames')
    ax_main.set_ylabel('Pitch')
    ax_main.set_xlabel(f"Time (1/{fs}s)")

    # One colour per frame, sampled evenly from the tab10 colormap.
    colors = plt.cm.tab10(np.linspace(0, 1, len(frame_indices)))
    for idx, (start, end) in enumerate(frame_indices):
        # Shade the frame's time span; the rectangle covers the bottom
        # tenth of the pitch axis.
        ax_main.add_patch(patches.Rectangle(
            (start, 0),  # Bottom-left corner
            end - start,  # Width
            n_pitches / 10,  # Height
            linewidth=1,
            edgecolor=colors[idx],
            facecolor=colors[idx],
            alpha=0.3  # Semi-transparent shading
        ))

        # Rotated label centred on the frame's time span.
        # NOTE(review): y = -n_pitches lies below the data range of the
        # axes (which starts near 0), so the label may fall outside the
        # visible area - confirm this placement is intended.
        label_x = (start + end) / 2
        label_y = -n_pitches
        ax_main.text(
            label_x, label_y, f'Frame {idx+1}',
            color=colors[idx], ha='center', va='bottom',
            fontsize=10, rotation=90,  # Rotate 90 degrees
            bbox=dict(facecolor='white', alpha=0.7)
        )

    # A random UUID keeps the files of separate calls from colliding.
    file_UUID = uuid.uuid4()
    main_filename = f"{settings.static_folder}/{file_UUID}.png"
    fig.savefig(main_filename)
    plt.close(fig)

    # --- One borderless image per extracted frame ---
    frame_files = []
    for idx, frame in enumerate(extracted_frames):
        frame_filename = f"{settings.static_folder}/{file_UUID}_frame_{idx}.png"
        frame_files.append({
            "indicies": frame_indices[idx],
            "file_name": frame_filename
        })
        frame_fig, frame_ax = plt.subplots(figsize=(5, 3))
        # Draw the frame itself (previously the whole piano roll was drawn
        # here, which made every frame image identical).
        frame_ax.imshow(frame, aspect='auto', origin='lower', cmap='hot')
        frame_ax.axis('off')  # Turn off the axis
        frame_fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # Remove whitespace
        frame_fig.savefig(frame_filename)
        # Close explicitly so figures do not accumulate across many frames.
        plt.close(frame_fig)

    return {
        "file_name": main_filename,
        "frames": frame_files
    }

In [4]:

# Read the MIDI file and build its piano roll. The path is derived from
# settings.resource_folder so it follows the same configuration as the
# model files loaded elsewhere in this notebook.
input_file_path = f"{settings.resource_folder}/mars.mid"
input_piano_roll = generate_piano_roll(input_file_path)

# Sampling configuration shared by the steps below.
fs = 100           # value passed to create_piano_roll_image for its x-axis label
frame_size = 3000  # frame length handed to both index generators

# Frame positions returned by the fixed-index helper.
fixed_frames = get_fixed_frame_indicies(input_piano_roll.shape, frame_size=frame_size)

# Randomly placed frames; the fixed seed keeps the selection reproducible.
random_frames = generate_random_frame_indicies(
    input_piano_roll.shape,
    n_frames=20,
    frame_size=frame_size,
    buffer_size=5,
    random_seed=59,
)

# Combine both sets and normalise every bound to a plain Python int.
# NOTE(review): `+` is assumed to concatenate two lists of (start, end)
# pairs here - confirm the helpers do not return NumPy arrays.
frame_indicies = [(int(start), int(end)) for start, end in fixed_frames + random_frames]
extracted_frames = extract_frames(input_piano_roll, frame_indicies)

# Plot quality check: render the overview and per-frame images, using the
# `fs` defined above instead of repeating the literal.
piano_roll_image = create_piano_roll_image(input_piano_roll, extracted_frames, frame_indicies, fs=fs)
print(piano_roll_image)

{'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0.png', 'frames': [{'indicies': (0, 3000), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_0.png'}, {'indicies': (39503, 42503), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_1.png'}, {'indicies': (7171, 10171), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_2.png'}, {'indicies': (35236, 38236), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_3.png'}, {'indicies': (19470, 22470), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_4.png'}, {'indicies': (12469, 15469), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_5.png'}, {'indicies': (33205, 36205), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_6.png'}, {'indicies': (6211, 9211), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_frame_7.png'}, {'indicies': (1139, 4139), 'file_name': '../static/6c75fba3-6153-4fb2-a5e5-ec5eb89f1bc0_f

In [5]:
# Unpack the artefacts stored with the persisted k-means run.
kmeans = kmeans_model_run["model"]
scaler = kmeans_model_run["scaler"]
cluster_mapping = kmeans_model_run["cluster_mapping"]
threshold = kmeans_model_run["threshold"]

# Columns fed to the scaler and model, in this order.
feature_columns = [
        "pitch_entropy",
        "dominant_pitch",
        "avg_velocity",
        "spectral_bandwidth"
]

# All feature families are computed; only `feature_columns` are used below.
feature_functions = [
        pitch_based_features,
        velocity_based_features,
        temporal_features,
        harmonic_features,
        frequency_based_features,
        higher_level_features,
    ]
featureExtractor = FeatureExtractor(feature_functions)


# Classify every frame independently; -1 marks a frame that is too far
# from its nearest centroid.
predictions = []
for idx, frame in enumerate(extracted_frames):
    # NOTE(review): assumes extract_all_features returns a mutable mapping
    # of scalar feature values - confirm against FeatureExtractor.
    features = featureExtractor.extract_all_features(frame)
    features["frame_id"] = idx
    features_df = pd.DataFrame(features, index=["frame_id"])[feature_columns]

    # Normalize features, then measure the distance to each centroid.
    X = scaler.transform(features_df)
    distances = kmeans.transform(X)
    closest_clusters = np.argmin(distances, axis=1)  # Index of closest cluster

    # Exactly one row was transformed, so work with its scalar cluster id.
    prediction = int(closest_clusters[0])
    if distances[0, prediction] > threshold:
        prediction = -1
    predictions.append(prediction)

# Majority vote over the per-frame predictions. scipy's `mode` returns a
# ModeResult (mode, count); only its `.mode` field is the winning value.
# np.atleast_1d handles SciPy versions that return an array as well as
# those that return a scalar.
if predictions:
    final_result = int(np.atleast_1d(mode(predictions).mode)[0])
    # NOTE(review): assumes cluster_mapping is keyed by integer cluster id;
    # ids without an entry (including -1) fall back to "unknown".
    final_class = cluster_mapping.get(final_result, "unknown")
else:
    # No frames were extracted, so there is nothing to vote on.
    final_result = None
    final_class = "unknown"
print(final_class)

unknown


In [7]:
# TensorFlow is imported in this cell rather than in the notebook's first
# import cell.
import tensorflow as tf

# Load the trained CNN from the configured resources folder, building the
# path from settings.resource_folder like the k-means model load does.
cnn_model = tf.keras.models.load_model(f"{settings.resource_folder}/models/cnn_model.keras")

I0000 00:00:1738070998.556597    5037 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5578 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070 Ti Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6
