# Dataset

In [None]:
import os

midi_files = [f for f in os.listdir('data/midi_dataset/midis') if os.path.isfile(os.path.join('data/midi_dataset/midis', f))]
print(f"Found {len(midi_files)} MIDI files in the dataset.")

composers = set([ f.split(',')[0] for f in midi_files if f.endswith('.mid') ])
print(f"Found {len(composers)} unique composers in the dataset.")
print(f"Ratio of MIDI files to composers: {len(midi_files) / len(composers):.2f}")

Found 10854 MIDI files in the dataset.
Found 2569 unique composers in the dataset.
Ratio of MIDI files to composers: 4.22


In [5]:
from collections import Counter

# Count files per composer
composer_counts = Counter([f.split(',')[0] for f in midi_files if f.endswith('.mid')])

# Top 10 composers
top_10 = composer_counts.most_common(10)
print("Top 10 composers by file count:")
for composer, count in top_10:
    print(f"{composer}: {count}")

# Last 10 composers (least files)
last_10 = composer_counts.most_common()[:-11:-1]
print("\nLast 10 composers by file count:")
for composer, count in last_10:
    print(f"{composer}: {count}")

Top 10 composers by file count:
Scarlatti: 279
Bach: 246
Liszt: 197
Schubert: 131
Chopin: 102
Mozart: 90
Beethoven: 82
Czerny: 80
Handel: 78
Carbajo: 77

Last 10 composers by file count:
Żołnowski: 1
Łodwigowski: 1
Zwyssig: 1
Zweig: 1
Zwart: 1
Zurluth: 1
Zopff: 1
Zoeller: 1
Ziring: 1
Zintl: 1


In [None]:
num_composers_over = sum(1 for count in composer_counts.values() if count > 20)
print(f"Number of composers with more than 50 files: {num_composers_over}")

Number of composers with more than 50 files: 42


In [11]:
top = composer_counts.most_common(20)
# Sum the file counts for these composers
top_sum = sum(count for _, count in top)
print(f"Total number of files for top composers: {top_sum}")

Total number of files for top composers: 1949


In [13]:
top_composers_files = [f for f in midi_files if f.split(',')[0] in dict(top).keys()]
print(f"Number of files for top 20 composers: {len(top_composers_files)}")

Number of files for top 20 composers: 1949


In [15]:
from sklearn.model_selection import train_test_split
from collections import defaultdict

# Split datasets
# Group files by composer
composer_to_files = defaultdict(list)
for f in top_composers_files:
    composer_name = f.split(',')[0]
    composer_to_files[composer_name].append(f)

train_files = []
test_files = []

for files in composer_to_files.values():
    train, test = train_test_split(files, test_size=0.2, random_state=42)
    train_files.extend(train)
    test_files.extend(test)

In [None]:
print(train_files[:5])
print(test_files[:5])

['Alkan, Charles-Valentin, 3 Improvisations dans le Style brillant, Op.12, ORT-ei5_X8w.mid', 'Alkan, Charles-Valentin, Les mois, Op.74, Pzth1MU5JFY.mid', 'Alkan, Charles-Valentin, 2e verset du 41e Psaume, S7-WAxfY3VM.mid', 'Alkan, Charles-Valentin, Pour Monsieur Gurkhaus, L2ajYW7C75s.mid', 'Alkan, Charles-Valentin, 2 Petites pièces, Op.60, jcvaJLXSC0c.mid']
['Alkan, Charles-Valentin, Etude, WoO, c297e_yjlAQ.mid', 'Alkan, Charles-Valentin, Réconciliation, Op.42, MGwcbrYFsiU.mid', "Alkan, Charles-Valentin, Variations sur 'Ah ! segnata é la mia morte', Op.16 No.4, -sSHCvni-NU.mid", 'Alkan, Charles-Valentin, 3 Petites fantaisies, Op.41, CptNa_Pa7b0.mid', 'Alkan, Charles-Valentin, Salut, Cendre de Pauvre, Op.45, HxD_ReBdT-M.mid']


In [21]:
data_dir = 'data/midi_dataset/midis'
train_files = [os.path.join(data_dir, f) for f in train_files]
test_files = [os.path.join(data_dir, f) for f in test_files]

In [22]:
print(train_files[:5])
print(test_files[:5])

['data/midi_dataset/midis\\Alkan, Charles-Valentin, 3 Improvisations dans le Style brillant, Op.12, ORT-ei5_X8w.mid', 'data/midi_dataset/midis\\Alkan, Charles-Valentin, Les mois, Op.74, Pzth1MU5JFY.mid', 'data/midi_dataset/midis\\Alkan, Charles-Valentin, 2e verset du 41e Psaume, S7-WAxfY3VM.mid', 'data/midi_dataset/midis\\Alkan, Charles-Valentin, Pour Monsieur Gurkhaus, L2ajYW7C75s.mid', 'data/midi_dataset/midis\\Alkan, Charles-Valentin, 2 Petites pièces, Op.60, jcvaJLXSC0c.mid']
['data/midi_dataset/midis\\Alkan, Charles-Valentin, Etude, WoO, c297e_yjlAQ.mid', 'data/midi_dataset/midis\\Alkan, Charles-Valentin, Réconciliation, Op.42, MGwcbrYFsiU.mid', "data/midi_dataset/midis\\Alkan, Charles-Valentin, Variations sur 'Ah ! segnata é la mia morte', Op.16 No.4, -sSHCvni-NU.mid", 'data/midi_dataset/midis\\Alkan, Charles-Valentin, 3 Petites fantaisies, Op.41, CptNa_Pa7b0.mid', 'data/midi_dataset/midis\\Alkan, Charles-Valentin, Salut, Cendre de Pauvre, Op.45, HxD_ReBdT-M.mid']


In [49]:
from src.dataloader.dataset import PianoRollClassificationDataset
import pickle


pitch_to_strip = (24, 84)  # Take  only pitches between 24 and 84 C1 to C6 (https://arxiv.org/pdf/1809.07600)

train_dataset = PianoRollClassificationDataset(
    midi_files=train_files,
    frame_per_second=64,
    verbose=True,
    strip_bounds=True,
    pitch_to_strip=pitch_to_strip
)

out_path = "data/preprocessed_classification_piano_roll.pkl"

pickle.dump(
    train_dataset,
    open(out_path, "wb"),
)
print(f"Dataset saved to {out_path}")

Got 1551 MIDI files in, frame rate set to 64 FPS.


100%|██████████| 1551/1551 [02:43<00:00,  9.48it/s]


{0: 'Alkan', 1: 'Bach', 2: 'Beatty', 3: 'Beethoven', 4: 'Carbajo', 5: 'Chopin', 6: 'Czerny', 7: 'Gottschalk', 8: 'Handel', 9: 'Haydn', 10: 'Liszt', 11: 'Mozart', 12: 'Rebikov', 13: 'Scarlatti', 14: 'Schubert', 15: 'Schumann', 16: 'Scott', 17: 'Scriabin', 18: 'Simpson', 19: 'Zhang'}
Initialized 73448 MIDI tensors with 20 unique composers.
Dataset saved to data/preprocessed_classification_piano_roll.pkl


In [46]:
from importlib import reload
from src import dataloader
reload(dataloader.dataset)
from src.dataloader.dataset import PianoRollClassificationDataset
import pickle


In [50]:
test_dataset = PianoRollClassificationDataset(
    midi_files=test_files,
    frame_per_second=64,
    verbose=True,
    strip_bounds=True,
    pitch_to_strip=pitch_to_strip
)

out_path = "data/preprocessed_classification_piano_roll_test.pkl"

pickle.dump(
    test_dataset,
    open(out_path, "wb"),
)
print(f"Dataset saved to {out_path}")

Got 398 MIDI files in, frame rate set to 64 FPS.


100%|██████████| 398/398 [00:47<00:00,  8.41it/s]


{0: 'Alkan', 1: 'Bach', 2: 'Beatty', 3: 'Beethoven', 4: 'Carbajo', 5: 'Chopin', 6: 'Czerny', 7: 'Gottschalk', 8: 'Handel', 9: 'Haydn', 10: 'Liszt', 11: 'Mozart', 12: 'Rebikov', 13: 'Scarlatti', 14: 'Schubert', 15: 'Schumann', 16: 'Scott', 17: 'Scriabin', 18: 'Simpson', 19: 'Zhang'}
Initialized 20584 MIDI tensors with 20 unique composers.
Dataset saved to data/preprocessed_classification_piano_roll_test.pkl


In [52]:
train_dataset: PianoRollClassificationDataset = pickle.load(open("data/preprocessed_classification_piano_roll.pkl", "rb"))

In [54]:
test_dataset: PianoRollClassificationDataset = pickle.load(open("data/preprocessed_classification_piano_roll_test.pkl", "rb"))

In [53]:
train_dataset.labels_mapping

{0: 'Alkan',
 1: 'Bach',
 2: 'Beatty',
 3: 'Beethoven',
 4: 'Carbajo',
 5: 'Chopin',
 6: 'Czerny',
 7: 'Gottschalk',
 8: 'Handel',
 9: 'Haydn',
 10: 'Liszt',
 11: 'Mozart',
 12: 'Rebikov',
 13: 'Scarlatti',
 14: 'Schubert',
 15: 'Schumann',
 16: 'Scott',
 17: 'Scriabin',
 18: 'Simpson',
 19: 'Zhang'}

In [55]:
test_dataset.labels_mapping

{0: 'Alkan',
 1: 'Bach',
 2: 'Beatty',
 3: 'Beethoven',
 4: 'Carbajo',
 5: 'Chopin',
 6: 'Czerny',
 7: 'Gottschalk',
 8: 'Handel',
 9: 'Haydn',
 10: 'Liszt',
 11: 'Mozart',
 12: 'Rebikov',
 13: 'Scarlatti',
 14: 'Schubert',
 15: 'Schumann',
 16: 'Scott',
 17: 'Scriabin',
 18: 'Simpson',
 19: 'Zhang'}

# VAE

In [56]:
from torch.utils.data import DataLoader

midi_dataloader_train = DataLoader(
    train_dataset,
    batch_size=64,
    pin_memory=True,
    num_workers=0,
    shuffle=True,
)

In [57]:
print(next(iter(midi_dataloader_train)))

[tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0

In [58]:
import torch
import torch.nn as nn

class PrintLayer(nn.Module):
    def __init__(self, name):
        super(PrintLayer, self).__init__()
        self.name = name

    def forward(self, x):
        print(f"{self.name}: {x.shape}")
        return x


class CNNVAE(nn.Module):
    def __init__(self, input_shape=(60, 640), latent_dim=256, print_shapes=False):
        super(CNNVAE, self).__init__()
        
        self.input_shape = input_shape  # (height=pitch, width=time)
        self.latent_dim = latent_dim
        self.print_shapes = print_shapes
        
        ### Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=9, stride=1, padding=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 30x320

            PrintLayer("Conv1 Output") if print_shapes else nn.Identity(),
            
            nn.Conv2d(32, 64, kernel_size=9, stride=1, padding=4),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 15x160

            PrintLayer("Conv2 Output") if print_shapes else nn.Identity(),
            
            nn.Conv2d(64, 128, kernel_size=9, stride=1, padding=4),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 7x80

            PrintLayer("Conv3 Output") if print_shapes else nn.Identity(),
        )
        
        # Calculate dimensions before linear layer
        self._calculate_fc_dim()
        
        # Latent space
        self.fc_mu = nn.Linear(self.fc_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.fc_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, self.fc_dim)
        
        ### Decoder
        self.decoder = nn.Sequential(
            nn.Upsample(size=(160, 15), mode='nearest'),
            nn.Conv2d(128, 64, kernel_size=9, stride=1, padding=4),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            PrintLayer("Decoder Upsample 1 Output") if print_shapes else nn.Identity(),
            
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 32, kernel_size=9, stride=1, padding=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            PrintLayer("Decoder Upsample 2 Output") if print_shapes else nn.Identity(),
            
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(32, 1, kernel_size=9, stride=1, padding=4),
            nn.Sigmoid(),

            PrintLayer("Decoder Output") if print_shapes else nn.Identity(),
        )

    def _calculate_fc_dim(self):
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, *self.input_shape)
            dummy_output = self.encoder(dummy_input)
            self.fc_dim = dummy_output.numel() // dummy_output.shape[0]

    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        z = self.fc_decode(z)
        z = z.view(z.size(0), 128, 80, 7)
        z = self.decoder(z)
        return z

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

In [None]:
import torch.nn.functional as F

class CNNVAELoss(nn.Module):
    def __init__(self, beta=1.0):
        super(CNNVAELoss, self).__init__()
        self.beta = beta

    def mse_loss(self, x, x_recon):
        return F.mse_loss(x_recon, x)
    
    def mae_loss(self, x, x_recon):
        return F.l1_loss(x_recon, x)

    def min_threshold_loss(self, y_true, y_pred, min_val=1/127, penalty_strength=10.0):
        significant_mask = torch.sigmoid((y_true - min_val) * 100)
        penalty = torch.exp(-penalty_strength * (y_pred / min_val))
        loss = significant_mask * penalty
        return loss.mean()

    def kl_divergence(self, mu, logvar):
        # KL divergence between N(mu, sigma^2) and N(0, 1)
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
    
    # def binary_penalty_loss(self, x, x_recon):
    #     """Penalty for model giving 0s where it should be non zero pitch value, or vice versa."""
    #     binary_x = (x > 0.1).float()
    #     binary_x_recon = (x_recon > 0.1).float()
    #     penalty = F.binary_cross_entropy(binary_x_recon, binary_x, reduction='none')
    #     return penalty.mean()

    def forward(self, x, x_recon, mu, logvar):
        mse = self.mse_loss(x, x_recon)
        penalty = 100 * self.min_threshold_loss(x, x_recon, min_val=0.016, penalty_strength=10.0)
        kl = self.kl_divergence(mu, logvar)
        total_loss = mse + penalty + self.beta * kl
        # return total_loss, mse, penalty, kl
        # mae = self.mae_loss(x, x_recon)
        # penalty = self.binary_penalty_loss(x, x_recon)
        # total_loss = mse + penalty + self.beta * kl
        return total_loss, mse, penalty, kl


In [None]:
from tqdm import tqdm

def train_vae(model, dataloader, epochs=10, lr=1e-3, patience=3, beta=1.0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = CNNVAELoss(beta=beta)

    best_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_mse_loss = 0.0
        total_penalty_loss = 0.0
        total_kl_loss = 0.0

        iter_batch = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit="batch", leave=False)
        for batch in iter_batch:
            batch = batch.to(device)
            batch = batch.unsqueeze(1)  # Add channel dimension
            optimizer.zero_grad()
            x_recon, mu, logvar = model(batch)
            loss, mse_loss, penalty_loss, kl_loss = criterion(batch, x_recon, mu, logvar)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_mse_loss += mse_loss.item()
            total_penalty_loss += penalty_loss.item()
            total_kl_loss += kl_loss.item()
            iter_batch.set_postfix(
                loss=loss.item(),
                mse_loss=mse_loss.item(),
                penalty_loss=penalty_loss.item(),
                kl_loss=kl_loss.item()
            )

        avg_loss = total_loss / len(dataloader)
        avg_mse_loss = total_mse_loss / len(dataloader)
        avg_penalty_loss = total_penalty_loss / len(dataloader)
        avg_kl_loss = total_kl_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, MSE: {avg_mse_loss:.4f}, Penalty: {avg_penalty_loss:.4f}, KL: {avg_kl_loss:.4f}")

        # Early stopping check
        if avg_loss < best_loss:
            best_loss = avg_loss
            epochs_no_improve = 0
            best_model_state = model.state_dict()
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("Best model restored.")

In [None]:
vae_model = CNNVAE(input_shape=(60, 640), latent_dim=256 * 2).cuda()
train_vae(vae_model, midi_dataloader, epochs=10, lr=0.0001, beta=0.1)

In [None]:
torch.save(vae_model, 'data/vae_cnn_v3.pth')