In [None]:
import os
import random
from sklearn.model_selection import train_test_split

# Paths — placeholders; point these at the extracted dataset.
sheet_music_folder = "path/to/whole_sheet_music"
bbox_folder = "path/to/bounding_boxes"
SEED = 42  # shared by both splits so the partition is reproducible

# List all sheet music files. os.listdir returns entries in filesystem-dependent
# order, so sort first — otherwise random_state alone does not reproduce the split.
sheet_music_files = sorted(f for f in os.listdir(sheet_music_folder) if f.endswith('.png'))
assert sheet_music_files, f"No .png files found in {sheet_music_folder!r}"

# Split into train, val, test (80/10/10): hold out 20%, then halve the holdout.
train_files, holdout_files = train_test_split(sheet_music_files, test_size=0.2, random_state=SEED)
val_files, test_files = train_test_split(holdout_files, test_size=0.5, random_state=SEED)
assert len(train_files) + len(val_files) + len(test_files) == len(sheet_music_files)

print(f"Training: {len(train_files)}, Validation: {len(val_files)}, Test: {len(test_files)}")


In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import os

class MUSCIMADataset(Dataset):
    """Pairs each full sheet-music page with its bounding-box crop images.

    Each item is ``(sheet_image, bbox_images, bbox_labels)``:
    ``bbox_images`` are the crops in ``bbox_folder`` whose filenames start with
    the page's stem followed by ``_``, in sorted filename order, and
    ``bbox_labels`` are the integers parsed from the last ``_``-separated token
    of those filenames (e.g. ``page1_a_3.png`` -> 3).

    NOTE(review): torchvision detection models train on per-image box
    *coordinates* (``[N, 4]`` xyxy), not crop images. This dataset does not
    yield coordinates — confirm where they come from before using it for
    Faster R-CNN training.

    Args:
        sheet_music_files: Page filenames (relative to ``sheet_music_folder``).
        bbox_folder: Directory containing the per-symbol crop images.
        sheet_music_folder: Directory containing the full pages.
        transform: Optional callable applied to the page and to every crop.
    """

    def __init__(self, sheet_music_files, bbox_folder, sheet_music_folder, transform=None):
        self.sheet_music_files = sheet_music_files
        self.bbox_folder = bbox_folder
        self.sheet_music_folder = sheet_music_folder
        self.transform = transform

    def __len__(self):
        return len(self.sheet_music_files)

    def _bbox_files_for(self, sheet_music_file):
        """Return the sorted crop filenames that belong to ``sheet_music_file``."""
        # splitext keeps dotted stems intact ("a.b.png" -> "a.b").
        stem = os.path.splitext(sheet_music_file)[0]
        # Require the "_" separator so "page1" does not also claim "page10_*".
        prefix = stem + "_"
        # Sorted so crop/label order is deterministic across filesystems.
        return sorted(f for f in os.listdir(self.bbox_folder) if f.startswith(prefix))

    @staticmethod
    def _label_from_filename(bbox_file):
        """Parse the integer label from the last ``_`` token, e.g. ``x_3.png`` -> 3."""
        return int(bbox_file.split("_")[-1].split(".")[0])

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file handle."""
        # convert() returns a new in-memory image, so closing the source is safe.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return ``(sheet_image, bbox_images, bbox_labels)`` for page ``idx``."""
        sheet_music_file = self.sheet_music_files[idx]
        sheet_image = self._load_rgb(os.path.join(self.sheet_music_folder, sheet_music_file))

        bbox_images = []
        bbox_labels = []
        for bbox_file in self._bbox_files_for(sheet_music_file):
            bbox_images.append(self._load_rgb(os.path.join(self.bbox_folder, bbox_file)))
            bbox_labels.append(self._label_from_filename(bbox_file))

        # Apply transforms
        if self.transform:
            sheet_image = self.transform(sheet_image)
            bbox_images = [self.transform(img) for img in bbox_images]

        return sheet_image, bbox_images, bbox_labels


In [None]:
import torchvision
from torchvision.models.detection import FasterRCNN

# Build a pretrained Faster R-CNN (ResNet-50 FPN) and swap its box predictor
# for one sized to our own label set.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; kept as-is — confirm against the installed version.
detection_models = torchvision.models.detection
model = detection_models.fasterrcnn_resnet50_fpn(pretrained=True)

# Output classes INCLUDING the background class (placeholder value).
num_classes = 10
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = detection_models.faster_rcnn.FastRCNNPredictor(
    in_features,
    num_classes,
)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Default device:", device)
model = model.to(device)


In [None]:
import torch.optim as optim

# Optimizer — only parameters that will receive gradients (the model builder
# may have frozen part of the backbone; frozen params would be skipped anyway).
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)

num_epochs = 10  # TODO: tune; previously referenced but never defined


def _make_targets(boxes_batch, labels_batch, device):
    """Build torchvision detection targets, one dict per image.

    Args:
        boxes_batch: Per-image box coordinates; ``boxes_batch[i]`` is assumed to
            be ``[N_i, 4]`` xyxy (list of lists or tensor) — TODO confirm format.
        labels_batch: Per-image class ids; ``labels_batch[i]`` has ``N_i`` ints.
        device: Device the target tensors are created on.

    Returns:
        A list of ``{"boxes": float32 [N, 4], "labels": int64 [N]}`` dicts.
    """
    targets = []
    for boxes, labels in zip(boxes_batch, labels_batch):
        targets.append({
            # reshape(-1, 4) turns an image with no annotations into a [0, 4]
            # tensor instead of a flat [0] tensor, which torchvision rejects.
            "boxes": torch.as_tensor(boxes, dtype=torch.float32, device=device).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64, device=device).reshape(-1),
        })
    return targets


def train_one_epoch(model, data_loader, optimizer, device):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Each batch must be ``(images, boxes_batch, labels_batch)``. ``images`` may be
    a single CHW tensor, a stacked BCHW tensor, or a sequence of CHW tensors
    (e.g. from ``collate_fn=lambda b: tuple(zip(*b))``).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, boxes_batch, labels_batch in data_loader:
        if torch.is_tensor(images) and images.dim() == 3:
            # Unbatched sample: wrap everything as a batch of one.
            images, boxes_batch, labels_batch = [images], [boxes_batch], [labels_batch]
        # The detection model takes a list of CHW tensors, not one batched tensor.
        image_list = [img.to(device) for img in images]
        targets = _make_targets(boxes_batch, labels_batch, device)

        # In train mode the model returns a dict of named losses.
        loss_dict = model(image_list, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


if "dataloader" not in globals():
    raise NameError(
        "`dataloader` is not defined: build a DataLoader whose batches are "
        "(images, boxes, labels) with boxes as [N, 4] xyxy coordinates, e.g. "
        "DataLoader(dataset, batch_size=2, collate_fn=lambda b: tuple(zip(*b)))."
    )

# Training loop
for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, dataloader, optimizer, device)
    print(f"Epoch {epoch}, Mean loss: {epoch_loss:.4f}")
