In [1]:
import os
import cv2
import ast
import torch
import shutil
import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
def parse_bounding_box(bbox_str):
    """Decode a serialized bounding box such as ``"(x_min, y_min, x_max, y_max)"``.

    The string is evaluated with ``ast.literal_eval`` (Python literals only,
    no arbitrary code execution) and its first four items are returned.

    Returns:
        A 4-tuple of the first four elements of the parsed literal.
    """
    coords = ast.literal_eval(bbox_str)
    return tuple(coords[i] for i in range(4))

def parse_control_point(point_str):
    """Decode a serialized ``"(x, y)"`` control point.

    The string is evaluated with ``ast.literal_eval`` and its first two items
    are returned as an ``(x, y)`` tuple.
    """
    parsed = ast.literal_eval(point_str)
    x_coord, y_coord = (parsed[k] for k in (0, 1))
    return x_coord, y_coord

def find_notehead(noteheads, x_ref, y_ref, x_range_min, x_range_max,
                  min_area=1400, max_area=2400):
    """Return the topmost notehead-sized detection below a reference point.

    Args:
        noteheads: DataFrame of detections with a ``"BoundingBox"`` column
            (string form of ``(x_min, y_min, x_max, y_max)``) and an
            ``"Area"`` column.
        x_ref: Not used by the search; kept so existing callers still work.
        y_ref: A candidate must start strictly below this y (``y_min > y_ref``).
        x_range_min, x_range_max: Inclusive range the candidate's ``x_min``
            must fall in.
        min_area, max_area: Inclusive ``"Area"`` range accepted as a valid
            notehead size (defaults are the previously hard-coded limits).

    Returns:
        ``(x_min, y_min, x_max, y_max)`` of the qualifying candidate with the
        smallest ``y_min`` (ties broken by ``x_min``, then ``x_max``, then
        ``y_max``), or ``None`` when no detection qualifies.
    """
    candidates = []
    for bbox_str, area in zip(noteheads["BoundingBox"], noteheads["Area"]):
        x_min, y_min, x_max, y_max = parse_bounding_box(bbox_str)
        if not (min_area <= area <= max_area):
            continue
        if y_min > y_ref and x_range_min <= x_min <= x_range_max:
            # y_min first so tuple ordering picks the topmost box.
            candidates.append((y_min, x_min, x_max, y_max))

    best = min(candidates, default=None)
    if best is None:
        return None
    y_min, x_min, x_max, y_max = best
    return x_min, y_min, x_max, y_max

def note_mask(img, note, dark_threshold=64, stem_len=32, stem_step=4):
    """Build a binary mask of a notehead's dark pixels plus attached vertical strokes.

    Args:
        img: Image array, either BGR ``(H, W, 3)`` as returned by
            ``cv2.imread`` or already single-channel ``(H, W)``.
        note: ``(x_min, y_min, x_max, y_max)`` bounding box; values are
            truncated with ``int``.
        dark_threshold: Gray values strictly below this count as ink.
        stem_len: Height of the vertical window that must be entirely ink for
            a column to be traced as a stem.
        stem_step: How far the window advances per step while tracing.

    Returns:
        ``uint8`` array of shape ``(H, W)``: 1 for notehead ink and traced
        stem pixels, 0 elsewhere.
    """
    x_min, y_min, x_max, y_max = map(int, note)
    img_gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    height, width = img_gray.shape
    ink = img_gray < dark_threshold
    mask = np.zeros((height, width), dtype=np.uint8)

    # Notehead: copy ink inside the box, clamped to the image so that negative
    # coordinates cannot wrap around through negative indexing.
    bx0, by0 = max(0, x_min), max(0, y_min)
    bx1, by1 = min(width, x_max), min(height, y_max)
    mask[by0:by1, bx0:bx1] = ink[by0:by1, bx0:bx1]

    # Columns around the box's left edge, traced downward from its bottom
    # (presumably a down-stem). Column range is clamped to [0, width).
    for x in range(max(0, x_min - 4), min(width, x_min + 8)):
        y = y_max
        while 0 <= y and y + stem_len <= height and ink[y:y + stem_len, x].all():
            mask[y:y + stem_len, x] = 1
            y += stem_step

    # Columns around the box's right edge, traced upward from its top
    # (presumably an up-stem). Clamping the start at 0 prevents negative
    # columns from indexing the opposite side of the image.
    for x in range(max(0, x_max - 8), min(width, x_max + 4)):
        y = y_min
        while y - stem_len >= 0 and y <= height and ink[y - stem_len:y, x].all():
            mask[y - stem_len:y, x] = 1
            y -= stem_step

    return mask

In [3]:
# Build one 4-channel uint8 tensor per valid slur crop and save it as .pt:
#   [start-note mask, music-symbol mask, staff-line mask, end-note mask]
# Rows that fail are logged to `error_log_path` and their source crop is
# copied into `failed_folder` for inspection.


def _staff_line_mask(binary, line_ratio=0.9, neighbor_ratio=0.2):
    """Return a uint8 mask (same shape as ``binary``) of rows judged to be staff lines.

    A row counts as a staff line when its ink-pixel count exceeds
    ``line_ratio`` times the densest row's count. Rows directly above or
    below such a row are added too when their count exceeds
    ``neighbor_ratio`` times that maximum. Only immediate neighbors of the
    dense rows are added; the expansion is not repeated.
    """
    row_pixel_count = np.count_nonzero(binary, axis=1)
    max_pixel_count = row_pixel_count.max()
    dense = row_pixel_count > line_ratio * max_pixel_count
    next_to_dense = np.zeros_like(dense)
    next_to_dense[1:] |= dense[:-1]   # row just below a dense row
    next_to_dense[:-1] |= dense[1:]   # row just above a dense row
    is_staff_row = dense | (next_to_dense & (row_pixel_count > neighbor_ratio * max_pixel_count))
    staff_line_mask = np.zeros_like(binary, dtype=np.uint8)
    staff_line_mask[is_staff_row, :] = 1
    return staff_line_mask


tensor_folder = "tensors"
os.makedirs(tensor_folder, exist_ok=True)
failed_folder = "failed"
os.makedirs(failed_folder, exist_ok=True)
error_log_path = "tensor_error.log"

control_points_df = pd.read_csv("control_points.csv")
notehead_detection_df = pd.read_csv("notehead_detection.csv")
valid_rows = control_points_df[control_points_df["Invalid Bezier Curve"] == 0]

n_failed = 0
for _, row in tqdm(valid_rows.iterrows(), total=len(valid_rows), desc="Creating Tensors"):
    # Resolve per-row names before the try, so the error handler never uses
    # values left over from a previous iteration.
    crop_file_name = str(row["Crop File Name"])
    img_path = os.path.join("Use", crop_file_name)
    try:
        x_P0, y_P0 = parse_control_point(row["P0"])
        x_P4, y_P4 = parse_control_point(row["P4"])

        noteheads = notehead_detection_df[notehead_detection_df["File Name"] == crop_file_name]
        # Search windows to the left of each endpoint; widths based on notehead size.
        start_note = find_notehead(noteheads, x_P0, y_P0, x_P0 - 65, x_P0)
        end_note = find_notehead(noteheads, x_P4, y_P4, x_P4 - 55, x_P4)
        if start_note is None or end_note is None:
            raise ValueError(f"notehead not found (start={start_note}, end={end_note})")

        slur_erased_path = os.path.join("Erase Slur", f"e_{crop_file_name}")
        img = cv2.imread(img_path)
        slur_erased_img = cv2.imread(slur_erased_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"cannot read image {img_path}")
        if slur_erased_img is None:
            raise FileNotFoundError(f"cannot read image {slur_erased_path}")

        start_note_mask = note_mask(img, start_note)
        end_note_mask = note_mask(img, end_note)

        # Staff lines are found by a per-row ink histogram on the slur-erased image.
        binary = (slur_erased_img < 64).astype(np.uint8)
        staff_line_mask = _staff_line_mask(binary)
        music_symbols_mask = np.logical_and(binary, np.logical_not(staff_line_mask)).astype(np.uint8)

        # Channel order (1-based): 1 start note, 2 music symbols, 3 staff lines, 4 end note.
        tensor = np.stack([start_note_mask, music_symbols_mask, staff_line_mask, end_note_mask], axis=0)
        tensor = torch.tensor(tensor, dtype=torch.uint8)
        tensor_name = os.path.splitext(crop_file_name)[0] + ".pt"
        torch.save(tensor, os.path.join(tensor_folder, tensor_name))

    except Exception as e:
        # Best effort: record the failure and keep processing the remaining rows.
        n_failed += 1
        with open(error_log_path, "a") as error_log:
            error_log.write(f"{crop_file_name}: {e}\n")
        # The source crop itself may be the missing file; copying it then would abort the loop.
        if os.path.isfile(img_path):
            shutil.copy(img_path, os.path.join(failed_folder, crop_file_name))

if n_failed:
    print(f"{len(valid_rows) - n_failed} tensors created; {n_failed} failed (see {error_log_path}).")
else:
    print("All tensors created.")

Creating Tensors: 100%|█████████████████████████████████████████████████████████████| 634/634 [00:05<00:00, 116.90it/s]

All tensors created.



