In [None]:
# Extract the dataset archive into the current working directory (creates ./Composer_Dataset/).
# NOTE(review): the data-loading cell hard-codes '/content/Composer_Dataset/...', so this
# presumably expects to be run with /content as the working directory (Colab) - confirm.
!unzip Composer_Dataset.zip

Archive:  Composer_Dataset.zip
   creating: Composer_Dataset/
  inflating: __MACOSX/._Composer_Dataset  
  inflating: Composer_Dataset/.DS_Store  
  inflating: __MACOSX/Composer_Dataset/._.DS_Store  
   creating: Composer_Dataset/NN_midi_files_extended/
  inflating: __MACOSX/Composer_Dataset/._NN_midi_files_extended  
  inflating: Composer_Dataset/NN_midi_files_extended/.DS_Store  
  inflating: __MACOSX/Composer_Dataset/NN_midi_files_extended/._.DS_Store  
   creating: Composer_Dataset/NN_midi_files_extended/test/
  inflating: __MACOSX/Composer_Dataset/NN_midi_files_extended/._test  
   creating: Composer_Dataset/NN_midi_files_extended/train/
  inflating: __MACOSX/Composer_Dataset/NN_midi_files_extended/._train  
   creating: Composer_Dataset/NN_midi_files_extended/dev/
  inflating: __MACOSX/Composer_Dataset/NN_midi_files_extended/._dev  
   creating: Composer_Dataset/NN_midi_files_extended/test/mozart/
  inflating: __MACOSX/Composer_Dataset/NN_midi_files_extended/test/._mozart  
   cr

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import pretty_midi


In [None]:
def extract_features(midi_data):
    """Summarise a parsed MIDI file as six scaled global descriptors.

    Parameters
    ----------
    midi_data : pretty_midi.PrettyMIDI
        Any object exposing ``get_tempo_changes()``, ``key_signature_changes``,
        ``instruments`` and ``get_end_time()`` works.

    Returns
    -------
    numpy.ndarray, shape (6,), dtype float32
        ``[tempo / 300, mean_pitch / 127, min_pitch / 127, max_pitch / 127,
        notes_per_second / 10, (key_number + 1) / 24]``.
        Files without notes yield 0 for the pitch and density entries; files
        without a key signature yield 0 for the key entry.
    """
    # Tempo: unweighted mean of all tempo values, 120 BPM when none is declared.
    tempi = midi_data.get_tempo_changes()[1]
    tempo = np.mean(tempi) if len(tempi) > 0 else 120.0

    # Key signature: only the first one is used; -1 marks "no key signature".
    key_changes = midi_data.key_signature_changes
    key_number = key_changes[0].key_number if key_changes else -1

    # Pitches of every non-drum note, across all instruments.
    pitches = [
        note.pitch
        for instrument in midi_data.instruments
        if not instrument.is_drum
        for note in instrument.notes
    ]

    if pitches:
        avg_pitch, min_pitch, max_pitch = np.mean(pitches), np.min(pitches), np.max(pitches)
    else:
        avg_pitch = min_pitch = max_pitch = 0

    # Note density in notes per second; guarded against zero-length files.
    duration = midi_data.get_end_time()
    note_density = len(pitches) / duration if duration > 0 else 0

    # pretty_midi key numbers span 0-23 (12 major + 12 minor keys), so the
    # shifted value (key_number + 1) lies in 0..24. Dividing by 24 keeps the
    # feature inside [0, 1]; the previous divisor of 12 pushed minor keys up to 2.0.
    return np.array(
        [
            tempo / 300.0,
            avg_pitch / 127.0,
            min_pitch / 127.0,
            max_pitch / 127.0,
            note_density / 10.0,
            (key_number + 1) / 24.0,
        ],
        dtype=np.float32,
    )


In [None]:
# NOTE(review): `pretty_midi` is imported in the imports cell that appears earlier in this
# notebook, so on a fresh runtime this install has to be executed before that cell
# (otherwise Restart & Run All fails at the import).
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m53.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=7fff170a0c5743f791cd51bb3a2cd6af3ae03d90583741fc25a0f6aace8a0068
  Stored in directory: /root/.cache/pip/wheels/e6/95/ac/15ceaeb2823b04d8e638fd1495357adb8d26c00ccac9d7782e
Successfully built pretty_midi
Installing collected packages: mido, pretty_midi
Successf

In [None]:
class MidiDataset(Dataset):
    """Dataset that turns MIDI files into three model inputs plus a label.

    Each item is a tuple ``(piano_roll, note_seq, numeric_features, label)``:

    * ``piano_roll``  - float32 tensor; the binarised output of
      ``get_piano_roll(fs)`` with its time axis cropped / zero-padded to
      ``max_seq_len`` columns.
    * ``note_seq``    - long tensor of length ``max_seq_len`` holding the
      pitches of the non-drum notes in onset order, right-padded with 0.
    * ``numeric_features`` - float32 tensor returned by ``extract_features``.
    * ``label``       - long scalar tensor.

    Parameters
    ----------
    file_paths : sequence of str
        Paths of the MIDI files.
    labels : sequence of int
        Class index for each path (same length as ``file_paths``).
    max_seq_len : int
        Number of piano-roll frames and number of notes kept per file.
    fs : int
        Piano-roll sampling rate passed to ``get_piano_roll``.
    augment : bool
        When true, each item is transposed by a random -4..+4 semitones with
        probability 0.5.
    """

    def __init__(self, file_paths, labels, max_seq_len=500, fs=10, augment=True):
        self.file_paths = file_paths
        self.labels = labels
        self.max_seq_len = max_seq_len
        self.fs = fs
        self.augment = augment

    def __len__(self):
        return len(self.file_paths)

    def _random_transpose(self, midi_data):
        """Shift all melodic notes of `midi_data` in place by one random interval."""
        semitones = np.random.randint(-4, 5)
        for instrument in midi_data.instruments:
            # On drum tracks the pitch number selects a percussion sound, so
            # shifting it is not a transposition. Drum notes are also ignored
            # by the note-sequence and feature code, so they are left alone.
            if instrument.is_drum:
                continue
            for note in instrument.notes:
                # Clamp to the valid MIDI pitch range.
                note.pitch = min(max(note.pitch + semitones, 0), 127)

    def _piano_roll(self, midi_data):
        """Binary piano roll with exactly `max_seq_len` time frames (CNN input)."""
        roll = midi_data.get_piano_roll(fs=self.fs)[:, :self.max_seq_len]
        missing = self.max_seq_len - roll.shape[1]
        if missing > 0:
            roll = np.pad(roll, ((0, 0), (0, missing)), mode='constant')
        # Only note on/off is kept; velocities are discarded.
        return (roll > 0).astype(np.float32)

    def _note_sequence(self, midi_data):
        """Pitches of the first `max_seq_len` melodic notes in time order (LSTM input)."""
        melodic_notes = [
            note
            for instrument in midi_data.instruments
            if not instrument.is_drum
            for note in instrument.notes
        ]
        # Order by onset *before* truncating. Concatenating track after track
        # (the previous behaviour) produced a sequence that was not
        # chronological and, for long pieces, contained only the first track.
        # list.sort is stable, so simultaneous notes keep their track order.
        melodic_notes.sort(key=lambda note: note.start)
        pitches = [note.pitch for note in melodic_notes[:self.max_seq_len]]
        pitches += [0] * (self.max_seq_len - len(pitches))
        return pitches

    def __getitem__(self, idx):
        midi_path = self.file_paths[idx]
        label = self.labels[idx]

        midi_data = pretty_midi.PrettyMIDI(midi_path)

        # Augmentation: random transpose, applied to half of the draws.
        if self.augment and np.random.rand() < 0.5:
            self._random_transpose(midi_data)

        piano_roll = self._piano_roll(midi_data)
        notes = self._note_sequence(midi_data)
        # Computed after augmentation, so pitch statistics reflect the transposition.
        numeric_features = extract_features(midi_data)

        return (
            torch.tensor(piano_roll, dtype=torch.float32),
            torch.tensor(notes, dtype=torch.long),
            torch.tensor(numeric_features, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long)
        )


In [None]:
def load_file_paths(root_dir):
    """Collect MIDI file paths and integer labels for the train/dev/test splits.

    Expected layout: ``root_dir/<split>/<composer>/<file>.mid|.midi`` for each
    split in ``('train', 'dev', 'test')``.

    Parameters
    ----------
    root_dir : str
        Directory containing the three split folders.

    Returns
    -------
    dict
        ``{split: {'files': [...], 'labels': [...], 'composer_to_idx': {...}}}``.
        The composer -> index mapping is identical for every split.

    Raises
    ------
    FileNotFoundError
        If one of the split directories does not exist (raised by ``os.listdir``).
    """
    splits = ['train', 'dev', 'test']

    def list_composers(split_dir):
        # Sub-directories only; stray files such as .DS_Store are ignored.
        return [c for c in sorted(os.listdir(split_dir)) if os.path.isdir(os.path.join(split_dir, c))]

    composers_by_split = {
        split: list_composers(os.path.join(root_dir, split)) for split in splits
    }

    # One mapping shared by all splits. Building it per split (as before) gives
    # the same composer different indices whenever a split lacks a composer
    # folder, which silently corrupts dev/test labels.
    all_composers = sorted(set().union(*composers_by_split.values()))
    composer_to_idx = {composer: idx for idx, composer in enumerate(all_composers)}

    data = {}
    for split in splits:
        file_paths, labels = [], []
        for composer in composers_by_split[split]:
            composer_dir = os.path.join(root_dir, split, composer)
            # Sorted so that the file order does not depend on the filesystem.
            for file in sorted(os.listdir(composer_dir)):
                # Case-insensitive so that e.g. "PIECE.MID" is not skipped.
                if file.lower().endswith(('.mid', '.midi')):
                    file_paths.append(os.path.join(composer_dir, file))
                    labels.append(composer_to_idx[composer])

        data[split] = {
            'files': file_paths,
            'labels': labels,
            'composer_to_idx': dict(composer_to_idx),
        }
    return data

# Locate the dataset and index every split (file paths + labels).
root_dir = '/content/Composer_Dataset/NN_midi_files_extended/'
dataset_info = load_file_paths(root_dir)

# One dataset per split; random transposition is enabled for training only.
train_dataset, dev_dataset, test_dataset = (
    MidiDataset(
        dataset_info[split]['files'],
        dataset_info[split]['labels'],
        augment=(split == 'train'),
    )
    for split in ('train', 'dev', 'test')
)

# Only the training loader shuffles; dev/test keep their file order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
dev_loader, test_loader = (
    DataLoader(split_dataset, batch_size=16)
    for split_dataset in (dev_dataset, test_dataset)
)


In [None]:
class ComposerClassifier(nn.Module):
    """Three-branch classifier over piano roll, note sequence and summary features.

    * CNN branch: two conv/batch-norm/ReLU/max-pool stages over the piano roll,
      followed by a linear projection to 128 units.
    * LSTM branch: pitch embedding -> single-layer LSTM; the final hidden
      state (128 units) is used.
    * The two branch outputs and the numeric feature vector are concatenated
      and passed through a small MLP producing ``num_classes`` logits.

    Parameters
    ----------
    num_classes : int
        Number of output classes.
    max_seq_len : int, default 500
        Number of piano-roll time frames the CNN branch expects. Must equal
        the width of the piano rolls fed to ``forward``.
    num_numeric_features : int, default 6
        Length of the numeric feature vector fed to ``forward``.
    """

    # Piano-roll height and pitch vocabulary size (MIDI pitches 0..127).
    N_PITCHES = 128

    def __init__(self, num_classes, max_seq_len=500, num_numeric_features=6):
        super().__init__()

        # CNN for piano roll
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(16, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 2))
        )
        # The padded 3x3 convolutions keep the spatial size; each 2x2 max-pool
        # halves both axes (rounding down), hence the two `// 4` factors.
        # Deriving the size from `max_seq_len` removes the hard-coded 500.
        cnn_out_features = 32 * (self.N_PITCHES // 4) * (max_seq_len // 4)
        self.cnn_fc = nn.Sequential(
            nn.Linear(cnn_out_features, 128),
            nn.LayerNorm(128),
            nn.ReLU()
        )

        # LSTM for note sequences
        self.embedding = nn.Embedding(self.N_PITCHES, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)

        # Fully connected head over [cnn(128) | lstm(128) | numeric features]
        self.fc = nn.Sequential(
            nn.Linear(128 + 128 + num_numeric_features, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, num_classes)
        )

    def forward(self, piano_roll, note_seq, numeric_features):
        """Return class logits of shape ``(batch, num_classes)``.

        Parameters
        ----------
        piano_roll : float tensor, ``(batch, 128, max_seq_len)``
        note_seq : long tensor, ``(batch, seq_len)`` with values in 0..127
        numeric_features : float tensor, ``(batch, num_numeric_features)``
        """
        # CNN branch: add the channel axis expected by Conv2d.
        x_cnn = self.cnn(piano_roll.unsqueeze(1))
        x_cnn = self.cnn_fc(x_cnn.flatten(1))

        # LSTM branch: keep only the last layer's final hidden state.
        x_embed = self.embedding(note_seq)
        _, (h_n, _) = self.lstm(x_embed)
        x_lstm = h_n[-1]

        # Combine
        x = torch.cat((x_cnn, x_lstm, numeric_features), dim=1)
        return self.fc(x)


In [None]:
def evaluate_accuracy(model, loader, device):
    """Return the classification accuracy of `model` on `loader`, in percent.

    Puts the model in eval mode and runs without gradient tracking.
    Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for piano_roll, note_seq, numeric_features, labels in loader:
            piano_roll, note_seq, numeric_features, labels = (
                piano_roll.to(device), note_seq.to(device),
                numeric_features.to(device), labels.to(device)
            )
            outputs = model(piano_roll, note_seq, numeric_features)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total if total else 0.0


# Seed the RNGs that drive weight init, batch shuffling (torch) and the
# random-transpose augmentation (numpy) so that a re-run is repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(dataset_info['train']['composer_to_idx'])

model = ComposerClassifier(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)
# mode='max': the scheduler watches validation accuracy, halving the LR after
# 3 epochs without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

epochs = 80
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; starting at 0 left "best_model.pth" unwritten (or stale from an
# earlier run) whenever validation accuracy never rose above 0.
best_val_acc = float('-inf')
patience = 15   # early-stopping patience, in epochs
wait = 0        # epochs since the last validation improvement

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for piano_roll, note_seq, numeric_features, labels in train_loader:
        piano_roll, note_seq, numeric_features, labels = (
            piano_roll.to(device), note_seq.to(device),
            numeric_features.to(device), labels.to(device)
        )
        optimizer.zero_grad()
        outputs = model(piano_roll, note_seq, numeric_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean per-batch loss; max(..., 1) avoids a ZeroDivisionError on an empty loader.
    avg_loss = total_loss / max(len(train_loader), 1)

    # Validation
    val_acc = evaluate_accuracy(model, dev_loader, device)
    scheduler.step(val_acc)

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Early stopping: checkpoint on improvement, stop after `patience` stale epochs.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping triggered!")
            break

# Test evaluation with the best checkpoint. map_location keeps the load
# working when the checkpoint was written on a different device.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
test_acc = evaluate_accuracy(model, test_loader, device)

print(f"Test Accuracy: {test_acc:.2f}%")





Epoch [1/80], Loss: 2.2075, Val Acc: 11.43%
Epoch [2/80], Loss: 2.1947, Val Acc: 25.71%
Epoch [3/80], Loss: 2.1782, Val Acc: 31.43%
Epoch [4/80], Loss: 2.1545, Val Acc: 31.43%
Epoch [5/80], Loss: 2.0982, Val Acc: 48.57%
Epoch [6/80], Loss: 1.9946, Val Acc: 51.43%
Epoch [7/80], Loss: 1.7831, Val Acc: 48.57%
Epoch [8/80], Loss: 1.5706, Val Acc: 57.14%
Epoch [9/80], Loss: 1.4144, Val Acc: 57.14%
Epoch [10/80], Loss: 1.2161, Val Acc: 54.29%
Epoch [11/80], Loss: 0.9262, Val Acc: 45.71%
Epoch [12/80], Loss: 0.8118, Val Acc: 60.00%
Epoch [13/80], Loss: 0.6187, Val Acc: 60.00%
Epoch [14/80], Loss: 0.4890, Val Acc: 48.57%
Epoch [15/80], Loss: 0.3033, Val Acc: 51.43%
Epoch [16/80], Loss: 0.2109, Val Acc: 45.71%
Epoch [17/80], Loss: 0.2446, Val Acc: 51.43%
Epoch [18/80], Loss: 0.2697, Val Acc: 51.43%
Epoch [19/80], Loss: 0.1451, Val Acc: 40.00%
Epoch [20/80], Loss: 0.1110, Val Acc: 54.29%
Epoch [21/80], Loss: 0.0829, Val Acc: 51.43%
Epoch [22/80], Loss: 0.1002, Val Acc: 51.43%
Epoch [23/80], Loss