In [1]:
!pip install mlflow torch tqdm

Collecting mlflow
  Downloading mlflow-2.22.0-py3-none-any.whl.metadata (30 kB)
Collecting mlflow-skinny==2.22.0 (from mlflow)
  Downloading mlflow_skinny-2.22.0-py3-none-any.whl.metadata (31 kB)
Collecting alembic!=1.10.0,<2 (from mlflow)
  Downloading alembic-1.15.2-py3-none-any.whl.metadata (7.3 kB)
Collecting docker<8,>=4.0.0 (from mlflow)
  Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)
Collecting graphene<4 (from mlflow)
  Downloading graphene-3.4.3-py2.py3-none-any.whl.metadata (6.9 kB)
Collecting gunicorn<24 (from mlflow)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==2.22.0->mlflow)
  Downloading databricks_sdk-0.50.0-py3-none-any.whl.metadata (38 kB)
Collecting fastapi<1 (from mlflow-skinny==2.22.0->mlflow)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn<1 (from mlflow-skinny==2.22.0->mlflow)
  Downloading uvicorn-0.34.2-py3-none-any.whl.metadata (6.5 k

In [4]:
!7z x /content/midi_files.7z -o./content/midi


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,2 CPUs Intel(R) Xeon(R) CPU @ 2.20GHz (406F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan /content/                   1 file, 2609800 bytes (2549 KiB)

Extracting archive: /content/midi_files.7z
--
Path = /content/midi_files.7z
Type = 7z
Physical Size = 2609800
Headers Size = 255666
Method = LZMA:23
Solid = +
Blocks = 1

  0%      0% 855 - midi_files/evaluation/midi/140410.mid                                                  0% 1371 - midi_files/evaluation/midi/1623626.mid                                                    0% 1751 - midi_files/evaluation/midi/175975.mid

In [9]:
!pip install miditoolkit mido

Collecting miditoolkit
  Downloading miditoolkit-1.0.1-py3-none-any.whl.metadata (4.9 kB)
Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading miditoolkit-1.0.1-py3-none-any.whl (24 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mido, miditoolkit
Successfully installed miditoolkit-1.0.1 mido-1.3.3


In [10]:


from pathlib import Path

import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

from tokenizer7 import PerfectMIDITokenizer


In [26]:
# Filesystem layout: extracted dataset root, its three splits, the tokenizer
# vocabulary file, and the directory that receives model checkpoints.
DATA_DIR = Path("./content/midi")
TRAIN_DIR = DATA_DIR / "midi_files" / "train" / "midi"
VAL_DIR = DATA_DIR / "midi_files" / "validation" / "midi"
TEST_DIR = DATA_DIR / "midi_files" / "evaluation" / "midi"

VOCAB_PATH = Path("vocab.json")

# Created eagerly so the training loop can save into it without further checks.
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)

In [27]:
# Every tunable value used by the tokenizer, the model and the training loop.
# The whole dict is also logged to MLflow as the run's hyperparameters.
CONFIG = {
    # --- Tokenizer (forwarded verbatim to PerfectMIDITokenizer) ---
    "ticks_per_beat": 480,
    "time_step": 30,
    "max_shift": 4 * 480,
    "tempo_step": 5_000,
    "min_tempo": 200_000,
    "max_tempo": 1_000_000,
    "use_velocity": False,
    "add_bar_tokens": True,
    # --- Model ---
    "vocab_size": None,  # placeholder; filled in from the tokenizer once it is built
    "d_model": 512,
    "nhead": 8,
    "num_layers": 6,
    "dim_feedforward": 2_048,
    # --- Training ---
    "max_seq_len": 1_024,
    "batch_size": 8,
    "lr": 0.0001,
    "epochs": 20,
    # Use the GPU when one is visible to torch, otherwise fall back to the CPU.
    "device": "cpu" if not torch.cuda.is_available() else "cuda",
}

In [28]:
# Build the tokenizer from the tokenizer-related CONFIG entries; each listed key
# is forwarded as the keyword argument of the same name.
tokenizer = PerfectMIDITokenizer(
    **{
        option: CONFIG[option]
        for option in (
            "ticks_per_beat",
            "time_step",
            "max_shift",
            "tempo_step",
            "min_tempo",
            "max_tempo",
            "use_velocity",
            "add_bar_tokens",
        )
    }
)

In [29]:
# Write the tokenizer's vocabulary to VOCAB_PATH (the commented-out reload cell
# further down reads it back with tokenizer.load_vocab).
tokenizer.save_vocab(VOCAB_PATH)
# Replace the None placeholder in CONFIG; the model's embedding table and output
# layer are sized from this value.
CONFIG["vocab_size"] = tokenizer.vocab_size
print(f"Vocab size: {tokenizer.vocab_size}")


Vocab size: 521


In [30]:
def load_and_encode(folder: Path, tokenizer, max_len: int):
    """Encode every ``*.mid`` file directly inside ``folder`` to a fixed-length id list.

    Args:
        folder: Directory to scan (non-recursive) for ``*.mid`` files.
        tokenizer: Object exposing ``encode(path: str)`` returning a sequence of
            token ids, and a ``pad_token_id`` attribute.
        max_len: Exact length of every returned sequence.

    Returns:
        A list with one ``max_len``-long list of ids per file, ordered by sorted
        file path. Longer encodings are truncated on the right; shorter ones are
        right-padded with ``tokenizer.pad_token_id``.
    """
    pad_id = tokenizer.pad_token_id
    sequences = []
    # glob() yields entries in filesystem-dependent order; sorting makes the
    # dataset order reproducible across machines and runs.
    for midi_file in sorted(folder.glob("*.mid")):
        # Copy before padding so a list owned (or cached) by the tokenizer is
        # never mutated in place.
        ids = list(tokenizer.encode(str(midi_file)))[:max_len]
        ids.extend([pad_id] * (max_len - len(ids)))
        sequences.append(ids)
    return sequences

# Encode the training and validation splits with the same tokenizer and the same
# fixed sequence length.
train_seqs, val_seqs = (
    load_and_encode(split_dir, tokenizer, CONFIG["max_seq_len"])
    for split_dir in (TRAIN_DIR, VAL_DIR)
)
print(f"Loaded {len(train_seqs)} train, {len(val_seqs)} val sequences")


Loaded 9552 train, 2400 val sequences


In [31]:
class MidiSequenceDataset(Dataset):
    """Next-token-prediction view over a collection of token-id sequences.

    Each item is a dict with ``input_ids`` (the sequence without its last token)
    and ``target_ids`` (the same sequence shifted left by one position), both as
    ``torch.long`` tensors.
    """

    def __init__(self, sequences):
        # Stored as given; conversion to tensors happens per item in __getitem__.
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        token_ids = torch.tensor(self.sequences[idx], dtype=torch.long)
        inputs, targets = token_ids[:-1], token_ids[1:]
        return {"input_ids": inputs, "target_ids": targets}

# Wrap each encoded split in a Dataset, then batch it. Only the training loader
# shuffles; validation keeps dataset order.
train_dataset, val_dataset = MidiSequenceDataset(train_seqs), MidiSequenceDataset(val_seqs)
train_loader = DataLoader(dataset=train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)


In [32]:
class TransformerCLM(nn.Module):
    """Causal language model: token + learned position embeddings, a stack of
    ``nn.TransformerEncoderLayer`` blocks run under a causal mask, and a linear
    projection back to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / output logits).
        d_model: Embedding and hidden width.
        nhead: Attention heads per layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Width of each layer's feed-forward sub-block.
        max_seq_len: Number of rows in the position-embedding table, i.e. the
            longest input ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward, max_seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(max_seq_len, d_model)
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)
        self.max_seq_len = max_seq_len

    def forward(self, input_ids):
        """Return next-token logits for every position.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = input_ids.size()
        # Fail early with a readable message: an out-of-range position index
        # would otherwise surface as an opaque embedding lookup error.
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        # Causal mask: True above the diagonal marks future positions that a
        # token is not allowed to attend to.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        x = self.transformer(x, mask=mask)
        logits = self.lm_head(x)
        return logits

# Instantiate the network from the model-related CONFIG entries (each key is
# passed as the keyword argument of the same name) and move it to the
# configured device.
model = TransformerCLM(
    **{
        name: CONFIG[name]
        for name in (
            "vocab_size",
            "d_model",
            "nhead",
            "num_layers",
            "dim_feedforward",
            "max_seq_len",
        )
    }
).to(CONFIG["device"])

In [33]:

# 7. Training & validation loops
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"])
# Padded target positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


def _run_epoch(loader, desc, train):
    """Run one full pass over ``loader`` and return the mean per-batch loss.

    Args:
        loader: DataLoader yielding dicts with "input_ids" and "target_ids".
        desc: Label shown on the tqdm progress bar.
        train: When True the model is put in train mode and the optimizer is
            stepped after every batch; when False the model is put in eval mode
            and gradients are disabled.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader, desc=desc):
            input_ids = batch["input_ids"].to(CONFIG["device"])
            target_ids = batch["target_ids"].to(CONFIG["device"])
            logits = model(input_ids)
            # Flatten (batch, seq) so every position is one classification example.
            loss = criterion(logits.view(-1, CONFIG["vocab_size"]), target_ids.view(-1))
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss / len(loader)


mlflow.set_experiment("midi_transformer_clm")
with mlflow.start_run():
    # Log hyperparameters
    mlflow.log_params(CONFIG)

    best_val_loss = float('inf')
    best_ckpt_path = None
    for epoch in range(1, CONFIG["epochs"] + 1):
        train_loss = _run_epoch(train_loader, f"Epoch {epoch} [Train]", train=True)
        val_loss = _run_epoch(val_loader, f"Epoch {epoch} [Val]", train=False)

        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        mlflow.log_metrics({"train_loss": train_loss, "val_loss": val_loss}, step=epoch)

        # Checkpoint (and attach to the run) every time validation loss improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            ckpt_path = CHECKPOINT_DIR / f"best_epoch{epoch}.pt"
            torch.save(model.state_dict(), ckpt_path)
            mlflow.log_artifact(str(ckpt_path))
            best_ckpt_path = ckpt_path

    # The last epoch is not necessarily the best one: restore the best
    # checkpoint so the logged model (and the in-memory model used for
    # inference below) carries the weights with the lowest validation loss.
    if best_ckpt_path is not None:
        model.load_state_dict(torch.load(best_ckpt_path, map_location=CONFIG["device"]))
    model.eval()

    mlflow.pytorch.log_model(model, "transformer_clm_model")

print("Training complete.")


2025/04/29 13:55:10 INFO mlflow.tracking.fluent: Experiment with name 'midi_transformer_clm' does not exist. Creating a new experiment.


Epoch 1 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 1 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 1: train_loss=1.6903, val_loss=1.1119


Epoch 2 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 2 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 2: train_loss=1.0370, val_loss=0.9017


Epoch 3 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 3 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 3: train_loss=0.9169, val_loss=0.8358


Epoch 4 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 4 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 4: train_loss=0.8443, val_loss=0.7535


Epoch 5 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 5 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 5: train_loss=0.7665, val_loss=0.6805


Epoch 6 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 6 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 6: train_loss=0.7066, val_loss=0.6460


Epoch 7 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 7 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 7: train_loss=0.6663, val_loss=0.6138


Epoch 8 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 8 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 8: train_loss=0.6380, val_loss=0.5964


Epoch 9 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 9 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 9: train_loss=0.6172, val_loss=0.5824


Epoch 10 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 10 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 10: train_loss=0.6015, val_loss=0.5798


Epoch 11 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 11 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 11: train_loss=0.5900, val_loss=0.5666


Epoch 12 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 12 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 12: train_loss=0.5783, val_loss=0.5608


Epoch 13 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 13 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 13: train_loss=0.5681, val_loss=0.5546


Epoch 14 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 14 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 14: train_loss=0.5602, val_loss=0.5510


Epoch 15 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 15 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 15: train_loss=0.5512, val_loss=0.5457


Epoch 16 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 16 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 16: train_loss=0.5434, val_loss=0.5397


Epoch 17 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 17 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 17: train_loss=0.5357, val_loss=0.5355


Epoch 18 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 18 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 18: train_loss=0.5292, val_loss=0.5400


Epoch 19 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 19 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 19: train_loss=0.5237, val_loss=0.5390


Epoch 20 [Train]:   0%|          | 0/1194 [00:00<?, ?it/s]

Epoch 20 [Val]:   0%|          | 0/300 [00:00<?, ?it/s]

Epoch 20: train_loss=0.5169, val_loss=0.5303




Training complete.


### Inference: Generate MIDI Sequences





### 1) Reload the tokenizer & model (optional)

In [None]:

# tokenizer = PerfectMIDITokenizer(
#     ticks_per_beat=CONFIG["ticks_per_beat"],
#     time_step=CONFIG["time_step"],
#     max_shift=CONFIG["max_shift"],
#     tempo_step=CONFIG["tempo_step"],
#     min_tempo=CONFIG["min_tempo"],
#     max_tempo=CONFIG["max_tempo"],
#     use_velocity=CONFIG["use_velocity"],
#     add_bar_tokens=CONFIG["add_bar_tokens"],
# )
# tokenizer.load_vocab(str(VOCAB_PATH))


# model = TransformerCLM(
#     vocab_size=tokenizer.vocab_size,
#     d_model=CONFIG["d_model"],
#     nhead=CONFIG["nhead"],
#     num_layers=CONFIG["num_layers"],
#     dim_feedforward=CONFIG["dim_feedforward"],
#     max_seq_len=CONFIG["max_seq_len"],
# ).to(CONFIG["device"])

# find most recent checkpoint
# ckpts = sorted(CHECKPOINT_DIR.glob("best_epoch*.pt"), key=lambda p: p.stat().st_mtime)
# assert ckpts, "No checkpoints found!"
# latest_ckpt = ckpts[-1]
# model.load_state_dict(torch.load(latest_ckpt, map_location=CONFIG["device"]))
# model.eval()
# print(f"Loaded checkpoint: {latest_ckpt.name}")





### 2) Sampling Function

In [34]:

@torch.no_grad()
def generate(
    model: "TransformerCLM",
    tokenizer: "PerfectMIDITokenizer",
    prefix_ids: list[int],
    max_length: int = 512,
    temperature: float = 1.0,
    top_k: int = 50
) -> list[int]:
    """Autoregressively extend ``prefix_ids`` with temperature / top-k sampling.

    Args:
        model: Module mapping a ``(1, seq_len)`` id tensor to
            ``(1, seq_len, vocab_size)`` logits. If it exposes ``max_seq_len``,
            only the most recent ``max_seq_len`` tokens are fed to it.
        tokenizer: Only ``eos_token_id`` is read, to stop generation early.
        prefix_ids: Non-empty list of seed token ids; it is not modified.
        max_length: Upper bound on the returned length (prefix included).
        temperature: Positive divisor applied to the logits before sampling.
        top_k: Number of highest-logit candidates to sample from; capped at the
            vocabulary size.

    Returns:
        The prefix followed by the sampled ids, ending at the first EOS token or
        once ``max_length`` ids have been produced.

    Raises:
        ValueError: On an empty prefix, non-positive temperature, or ``top_k < 1``.
    """
    if not prefix_ids:
        raise ValueError("prefix_ids must contain at least one token id")
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Follow the model's own device instead of relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    context_limit = getattr(model, "max_seq_len", None)

    generated = list(prefix_ids)
    # Sample with dropout disabled, then put the model back in whatever mode
    # the caller had it in.
    was_training = model.training
    model.eval()
    try:
        for _ in range(max_length - len(prefix_ids)):
            # Keep the input within the model's positional range.
            context = generated[-context_limit:] if context_limit else generated
            inp = torch.tensor([context], dtype=torch.long, device=device)
            logits = model(inp)               # (1, seq_len, vocab_size)
            next_logits = logits[0, -1, :] / temperature

            # top-k filtering; k cannot exceed the number of logits available
            k = min(top_k, next_logits.size(-1))
            topk = torch.topk(next_logits, k)
            probs = torch.softmax(topk.values, dim=-1)
            next_token = topk.indices[torch.multinomial(probs, 1)].item()

            generated.append(next_token)
            if next_token == tokenizer.eos_token_id:
                break
    finally:
        model.train(was_training)
    return generated


### 3) Prepare a seed & Generate


In [92]:
# a) Start from the BOS token only:
# seed = [tokenizer.bos_token_id]
# b) Or prime with an existing short MIDI (here: its first 73 tokens)

seed = tokenizer.tokens_to_ids(tokenizer.encode_to_tokens("Fur Elise.mid")[:73])

print("Seed IDs:", seed)

gen_ids = generate(
    model,
    tokenizer,
    prefix_ids=seed,
    max_length=CONFIG["max_seq_len"],
    temperature=1.5,
    top_k=50
)

gen_tokens = tokenizer.ids_to_tokens(gen_ids)
# Print only the first 50 tokens, as the label states, instead of dumping the
# entire generated sequence into the cell output.
print("Generated token sequence (first 50):", gen_tokens[:50])


Seed IDs: [1, 3, 3, 3, 260, 3, 260, 3, 260, 3, 260, 3, 260, 3, 260, 3, 260, 156, 260, 156, 262, 157, 260, 157, 260, 154, 260, 154, 262, 155, 260, 155, 260, 156, 260, 156, 262, 157, 260, 157, 260, 154, 260, 154, 262, 155, 260, 155, 260, 156, 260, 156, 262, 157, 260, 157, 260, 146, 260, 146, 262, 147, 260, 147, 260, 152, 260, 152, 262, 153, 260, 153, 260]
Generated token sequence (first 50): ['<BOS>', '<UNK>', '<UNK>', '<UNK>', 'TIME_SHIFT_30', '<UNK>', 'TIME_SHIFT_30', '<UNK>', 'TIME_SHIFT_30', '<UNK>', 'TIME_SHIFT_30', '<UNK>', 'TIME_SHIFT_30', '<UNK>', 'TIME_SHIFT_30', '<UNK>', 'TIME_SHIFT_30', 'NOTE_ON_76', 'TIME_SHIFT_30', 'NOTE_ON_76', 'TIME_SHIFT_90', 'NOTE_OFF_76', 'TIME_SHIFT_30', 'NOTE_OFF_76', 'TIME_SHIFT_30', 'NOTE_ON_75', 'TIME_SHIFT_30', 'NOTE_ON_75', 'TIME_SHIFT_90', 'NOTE_OFF_75', 'TIME_SHIFT_30', 'NOTE_OFF_75', 'TIME_SHIFT_30', 'NOTE_ON_76', 'TIME_SHIFT_30', 'NOTE_ON_76', 'TIME_SHIFT_90', 'NOTE_OFF_76', 'TIME_SHIFT_30', 'NOTE_OFF_76', 'TIME_SHIFT_30', 'NOTE_ON_75', 'TIME

In [93]:
# Optional: decode the seed on its own as well (disabled).
# tokenizer.decode(seed, "elise_seed75.mid")

# Destination for the generated piece, relative to the working directory.
output_path = "generated_elise73-60.mid"
# decode receives the generated ids and the destination path; presumably it
# writes a MIDI file there, as the message below reports (confirm in tokenizer7).
tokenizer.decode(gen_ids, output_path)
print(f"🎹 New MIDI written to {output_path}")


🎹 New MIDI written to generated_elise73-60.mid


In [81]:
!zip -r mlruns.zip mlruns

  adding: mlruns/ (stored 0%)
  adding: mlruns/0/ (stored 0%)
  adding: mlruns/0/meta.yaml (deflated 25%)
  adding: mlruns/388485100549338541/ (stored 0%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/ (stored 0%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/tags/ (stored 0%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/tags/mlflow.runName (stored 0%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/tags/mlflow.user (stored 0%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/tags/mlflow.source.type (stored 0%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/tags/mlflow.source.name (deflated 5%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/tags/mlflow.log-model.history (deflated 43%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602f609aa4cf/artifacts/ (stored 0%)
  adding: mlruns/388485100549338541/3f9028efc6bc4de8a153602