In [4]:
import os
import sys
def is_kaggle():
    return 'KAGGLE_KERNEL_RUN_TYPE' in os.environ

# Usage
if is_kaggle():
    print("Running on Kaggle")
    sys.path.append('/kaggle/input/urex-helperscripts')
    sys.path.append('/kaggle/input/pop909-midis')
    !pip install pretty_midi
    !pip install miditok
else:
    print("Not running on Kaggle")

Not running on Kaggle


In [5]:
from pathlib import Path
from tqdm import tqdm

import torch  # Ensure PyTorch is imported
import torch.optim as optim
from torch.utils.data import DataLoader

import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence


from miditok import REMI
from miditok.pytorch_data import DatasetMIDI, DataCollator

from vae import BetaVAE


In [6]:
use_wandb = True

if use_wandb:
    import wandb
    if is_kaggle():
        # !pip install wandb
        from kaggle_secrets import UserSecretsClient
        user_secrets = UserSecretsClient()
        wandb_api_key = user_secrets.get_secret("wandb_api_key")

    else:
        wandb_api_key = os.getenv("wandb_api_key")
    
    wandb.login(key=wandb_api_key)
        

    wandb.init(
        project="midi-vae", 
        entity="midi-vae", 
        name="testing"  # Set your run name here
    )

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmarcusongqy[0m ([33mmidi-vae[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# aux fns

def midi_data_loader(folder, shuffle=True):
    tokenizer = REMI()  # using defaults parameters
    midi_paths = [path.resolve() for path in Path(folder).rglob("*.mid")][:100] # to limit files actually used, for testing purposes

    dataset = DatasetMIDI(
        files_paths=midi_paths,
        tokenizer=tokenizer,
        max_seq_len=1024,
        bos_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer["BOS_None"],
    )
    collator = DataCollator(tokenizer.pad_token_id)
    data_loader = DataLoader(dataset=dataset, collate_fn=collator, batch_size=32, shuffle=shuffle)

    return data_loader

def collate_fn(batch):
    # Pad sequences to the same length
    batch = pad_sequence(batch, batch_first=True, padding_value=0)
    # Add a feature dimension
    batch = batch.unsqueeze(-1)
    return batch

In [8]:
# Define model parameters
input_dim = 128           # Each time step has 128 features (piano roll)
hidden_dim = 256          # GRU hidden dimension
latent_dim = 64           # Size of the latent space
beta = 4.0                # Adjust beta for stronger disentanglement

batch_size = 32

model = BetaVAE(input_dim, hidden_dim, latent_dim, beta=beta)

learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

num_epochs = 10

if use_wandb:
    config_dict = {
        "input_dim": input_dim,
        "hidden_dim": hidden_dim,
        "latent_dim": latent_dim,
        "beta": beta,
        "batch_size": batch_size,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs,
        "optimizer": optimizer.__class__.__name__,  # Store the optimizer type
    }
    # Update wandb config
    wandb.config.update(config_dict)

In [9]:
train_data_dir = "dataset_train"
test_data_dir = "dataset_test"

if is_kaggle():
    train_data_dir = "/kaggle/input/pop909-midis/" + train_data_dir
    test_data_dir = "/kaggle/input/pop909-midis/" + test_data_dir

train_data_loader = midi_data_loader(train_data_dir, shuffle=True)
test_data_loader = midi_data_loader(test_data_dir, shuffle=False)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("Using device:", device)

Using device: cpu


In [11]:
# Create an embedding layer (vocab_size depends on your tokenizer)
vocab_size = len(train_data_loader.dataset.tokenizer.vocab)
embedding_layer = nn.Embedding(vocab_size, input_dim).to(device)

model.train()

BetaVAE(
  (encoder_gru): GRU(128, 256, num_layers=2, batch_first=True)
  (fc_mu): Linear(in_features=256, out_features=64, bias=True)
  (fc_logvar): Linear(in_features=256, out_features=64, bias=True)
  (fc_latent_to_hidden): Linear(in_features=64, out_features=256, bias=True)
  (decoder_gru): GRU(128, 256, num_layers=2, batch_first=True)
  (output_layer): Linear(in_features=256, out_features=128, bias=True)
)

In [12]:
# main training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    total_loss = 0.0
    
    for batch in tqdm(train_data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch"):
        optimizer.zero_grad()
        
        tokens = batch["input_ids"].to(device)
        embedded = embedding_layer(tokens.long())  # shape: (batch, seq_len, input_dim)
        
        # Forward pass through VAE
        recon_x, mu, logvar = model(embedded.float())
        
        # Compute loss
        loss, _, _ = model.loss_function(recon_x, embedded.float(), mu, logvar)
        
        # Backprop
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_data_loader)
    print(f"Average training loss: {avg_loss:.4f}")
    
    if use_wandb:
        # Log the loss to wandb
        wandb.log({"epoch": epoch + 1, "avg_loss": avg_loss})

Epoch 1/10


Epoch 1/10: 100%|██████████| 4/4 [00:29<00:00,  7.43s/batch]


Average training loss: 23.8699
Epoch 2/10


Epoch 2/10: 100%|██████████| 4/4 [00:33<00:00,  8.31s/batch]


Average training loss: 7.5967
Epoch 3/10


Epoch 3/10: 100%|██████████| 4/4 [00:29<00:00,  7.31s/batch]


Average training loss: 5.0029
Epoch 4/10


Epoch 4/10: 100%|██████████| 4/4 [00:27<00:00,  6.97s/batch]


Average training loss: 3.9621
Epoch 5/10


Epoch 5/10: 100%|██████████| 4/4 [00:26<00:00,  6.71s/batch]


Average training loss: 2.7086
Epoch 6/10


Epoch 6/10: 100%|██████████| 4/4 [00:27<00:00,  6.91s/batch]


Average training loss: 2.2318
Epoch 7/10


Epoch 7/10: 100%|██████████| 4/4 [00:26<00:00,  6.64s/batch]


Average training loss: 1.8572
Epoch 8/10


Epoch 8/10: 100%|██████████| 4/4 [00:27<00:00,  6.79s/batch]


Average training loss: 1.6368
Epoch 9/10


Epoch 9/10: 100%|██████████| 4/4 [00:26<00:00,  6.62s/batch]


Average training loss: 1.4879
Epoch 10/10


Epoch 10/10: 100%|██████████| 4/4 [00:27<00:00,  6.92s/batch]

Average training loss: 1.4100





In [15]:
# saving the model at the end of training
torch.save(model.state_dict(), "model.pth")
if use_wandb and is_kaggle():
    wandb.save("model.pth")

In [16]:
model.eval()
with torch.no_grad():
    total_test_loss = 0.
    for batch in test_data_loader:
        tokens = batch["input_ids"].to(device)
        embedded = embedding_layer(tokens.long())  # (batch, seq_len, input_dim)

        recon_x, mu, logvar = model(embedded.float())
        loss, _, _ = model.loss_function(recon_x, embedded.float(), mu, logvar)
        total_test_loss += loss.item()
        
    avg_test_loss = total_test_loss / len(test_data_loader)
    print(f"Average test loss: {avg_test_loss:.4f}")
    
    if use_wandb:
        wandb.log({"average test loss": avg_test_loss})

Average test loss: 4.8111


In [17]:
if use_wandb:
    wandb.finish()

0,1
average test loss,▁
avg_loss,█▃▂▂▁▁▁▁▁▁
epoch,▁▂▃▃▄▅▆▆▇█

0,1
average test loss,4.8111
avg_loss,1.40999
epoch,10.0


In [18]:
# Example function to generate from an existing tokenized MIDI file
def generate_from_token_file(test_midi_file_path, test_output_file_path):
    # Create a small dataset/loader from the single file
    tokenizer = REMI()
    single_dataset = DatasetMIDI(
        files_paths=[Path(test_midi_file_path)],
        tokenizer=tokenizer,
        max_seq_len=1024,
        bos_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer["BOS_None"],
    )
    collator = DataCollator(tokenizer.pad_token_id)
    single_loader = DataLoader(single_dataset, batch_size=1, shuffle=False, collate_fn=collator)

    model.eval()
    with torch.no_grad():
        batch = next(iter(single_loader))
        tokens = batch["input_ids"].to(device)  # shape: (1, seq_len)
        
        # Embed tokens
        embedded = embedding_layer(tokens.long())  # shape: (1, seq_len, input_dim)
        
        # Encode to latent
        mu, logvar = model.encode(embedded.float())
        z = model.reparameterize(mu, logvar)
        
        # Decode back to feature vectors
        decoded = model.decode(z, seq_len=256)  # pick a sequence length
        predicted_tokens = torch.argmax(decoded, dim=-1)  # shape: (1, seq_len)

        # convert predicted tokens to a plain Python list, so that __ids_to_tokens can read it
        predicted_tokens = predicted_tokens.squeeze().tolist()
        
        # Convert integers to token strings
        token_strings = tokenizer._ids_to_tokens(predicted_tokens)
        # Convert token strings back to MIDI
        generated_midi = tokenizer([token_strings])
        # print(len(tokens))
        generated_midi.dump_midi(Path(test_output_file_path))

if is_kaggle():
    test_midi_file_path = "/kaggle/input/pop909-midis/dataset_valid/001_t0_0.mid"
    test_output_file_path = "/kaggle/working/trained_decoded_estimate.mid"
else:
    test_midi_file_path = "dataset_valid/001_t0_0.mid"
    test_output_file_path = "trained_decoded_estimate.mid"

generate_from_token_file(test_midi_file_path, test_output_file_path)