In [1]:
!pip install librosa pretty_midi
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install miditok
!pip install tqdm

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m41.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592286 sha256=f2e50ba094857ee27c8cd182eeeb9225e8db85c8bf39276f1808fd20661eff84
  Stored in directory: /root/.cache/pip/wheels/e6/95/ac/15ceaeb2823b04d8e638fd1495357adb8d26c00ccac9d7782e
Successfully built pretty_midi
Installing collected packages: mido, pretty_midi
Successf

In [2]:
import os
import torch
import librosa
import pretty_midi
import numpy as np
from google.colab import drive
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import pandas as pd
import miditok
from miditok import TokSequence
from tqdm.auto import tqdm

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [4]:
TOKENIZER_PARAMS = {
"pitch_range": (21, 109),
"beat_res": {(0,4):24, (4,12):12},
"special_tokens": ["PAD", "BOS", "EOS", "MASK"],
"use_chords": True,
"use_rests": False,
"use_tempos": True,
"use_time_signatures": True,
"use_programs": False,
"use_microtiming": False,
"ticks_per_quarter": 320,
"max_microtiming_shift": 0.125,
"num_microtiming_bins": 30,
}
config = miditok.TokenizerConfig(**TOKENIZER_PARAMS)

tokenizer=miditok.PerTok(config)

In [5]:
drive.mount('/content/drive')
DATA_ROOT='/content/drive/My Drive/song2cover'
PAD_TOKEN = tokenizer.pad_token_id
VocabSize = tokenizer.vocab_size

Mounted at /content/drive


In [6]:
def collate_fn(batch):
  xs, ys, zs = zip(*batch)



  y_max_len = max([y.shape[0] for y in ys])

  x_max_len = max([x.shape[2] for x in xs])

  y_padded = []
  for y in ys:
    pad_amt = y_max_len - y.shape[0]
    padded = F.pad(y, (0, pad_amt), value=PAD_TOKEN)
    y_padded.append(padded)

  x_padded = []
  for x in xs:
    pad_amt = x_max_len - x.shape[2]
    padded = F.pad(x, (0, pad_amt), value=0)
    x_padded.append(padded)

  x_batch = torch.stack(x_padded, dim=0)
  y_batch = torch.stack(y_padded, dim=0)
  z_batch = torch.stack(zs, dim=0)

  return x_batch, y_batch, z_batch

In [7]:
class AudioMidiDataset(Dataset):
  def __init__(self, root_dir, tokenizer, transform=None):
    self.root_dir = root_dir
    self.transform = transform
    df = pd.read_csv(os.path.join(root_dir, 'manifest.csv'))
    self.items = df.to_dict('records')
    self.tokenizer = tokenizer

  def __len__(self):
    return len(self.items)

  def __getitem__(self, idx):
    item = self.items[idx]
    audio_path = os.path.join(self.root_dir, item['audio_filepath'])
    midi_path = os.path.join(self.root_dir, item['midi_filepath'])
    diff = torch.tensor(float(item['difficulty'])).float()

    y, sr = librosa.load(audio_path, sr=44100)
    stft = librosa.stft(y, n_fft=1024, hop_length=256)
    log_stft = librosa.amplitude_to_db(np.abs(stft), ref=np.max)
    x = torch.tensor(log_stft).unsqueeze(0).float()

    tokens = self.tokenizer.encode(midi_path)
    y_tokens = torch.flatten(torch.tensor(tokens, dtype=torch.long))

    sample = {
        'x': x,
        'y': y_tokens,
        'diff': diff
    }

    if self.transform:
      sample = self.transform(sample)

    return x, y_tokens, diff

In [8]:
dataset = AudioMidiDataset(DATA_ROOT, tokenizer)

loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

In [9]:
class AudioEncoder(nn.Module):
  def __init__(self, d_model=512):
    super().__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=3, stride=(2,2), padding=1),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3, stride=(2,2), padding=1),
        nn.ReLU(),
        nn.Conv2d(128, d_model, kernel_size=3, stride=(2,2), padding=1),
        nn.ReLU()
    )

    self.project = nn.Linear(in_features=d_model * 65, out_features=d_model)
    self.pos_enc = nn.Embedding(num_embeddings=2048, embedding_dim=d_model)

    enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=d_model * 4)
    self.encoder = nn.TransformerEncoder(enc_layer, num_layers=3)

  def forward(self, x):
    z = self.conv(x)

    B, D, F_enc, T_enc = z.shape

    enc_in = z.permute(0, 3, 1, 2).reshape(B, T_enc, D * F_enc)
    enc_proj = self.project(enc_in)

    positions = torch.arange(T_enc, device=x.device).unsqueeze(0)
    enc_pos = enc_proj + self.pos_enc(positions)

    enc_in_tp = enc_pos.permute(1,0,2)
    enc_out_tp = self.encoder(enc_in_tp)

    enc_out = enc_out_tp.permute(1,0,2)

    return enc_out

class TokenDecoder(nn.Module):
  def __init__(self, vocab_size, d_model=512, max_len=2048):
    super().__init__()

    self.token_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
    self.pos_enb = nn.Embedding(num_embeddings=max_len, embedding_dim=d_model)
    self.d_model = d_model

    dec_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=8, dim_feedforward=d_model * 4)
    self.decoder = nn.TransformerDecoder(dec_layer, num_layers=3)

    self.output_fc = nn.Linear(d_model, vocab_size)

  def forward(self, y_input, memory):
    B, T_tgt = y_input.shape

    tok_embed = self.token_emb(y_input)

    pos = torch.arange(T_tgt, device = y_input.device).unsqueeze(0)
    pos_embed = self.pos_enb(pos)

    tgt = tok_embed + pos_embed

    tgt_mask = nn.Transformer.generate_square_subsequent_mask(T_tgt).to(y_input.device)

    tgt2 = tgt.permute(1,0,2)
    mem2 = memory.permute(1,0,2)

    dec_out2 = self.decoder(tgt2, mem2, tgt_mask=tgt_mask)

    dec_out = dec_out2.permute(1,0,2)

    logits = self.output_fc(dec_out)

    return logits


class Audio2Midi(nn.Module):
  def __init__(self, vocab_size, d_model=512, max_tok_len=2048):
    super().__init__()
    self.encoder = AudioEncoder(d_model=d_model)
    self.decoder = TokenDecoder(vocab_size, d_model=d_model, max_len=max_tok_len)

  def forward(self, x_audio, y_input):
    memory = self.encoder(x_audio)
    logits = self.decoder(y_input, memory)

    return logits



In [12]:
model = Audio2Midi(vocab_size=VocabSize, d_model=512, max_tok_len=2048).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

num_epochs = 10

for epoch in range(1, num_epochs+1):
  model.train()
  total_loss = 0.0

  batch_iterator = tqdm(loader, desc=f"Epoch {epoch:02d}", leave=False, bar_format='{l_bar}{bar} | {percentage:3.0f}%')

  for x_batch, y_batch, d_batch in batch_iterator:
      x_batch = x_batch.to(device)
      y_batch = y_batch.to(device)

      y_input  = y_batch[:, :-1]
      y_target = y_batch[:, 1:]

      optim.zero_grad()

      logits = model(x_batch, y_input)

      B, Tm, V = logits.shape
      loss = criterion(
          logits.reshape(-1, V),
          y_target.reshape(-1)
      )

      loss.backward()
      optim.step()

      total_loss += loss.item()
      batch_iterator.set_postfix(loss=f"{loss.item():.4f}")

  avg_loss = total_loss / len(loader)
  print(f"Epoch {epoch:02d} — Loss {avg_loss:.4f}")

torch.save({
  'epoch': epoch,
  'model_state_dict': model.state_dict(),
  'optimizer_state_dict': optim.state_dict()}, os.path.join(DATA_ROOT, 'checkpoint.pt'))




Epoch 01:   0%|           |   0%

Epoch 01 — Loss 5.0327


Epoch 02:   0%|           |   0%

Epoch 02 — Loss 4.0986


Epoch 03:   0%|           |   0%

Epoch 03 — Loss 3.4159


Epoch 04:   0%|           |   0%

Epoch 04 — Loss 2.9502


Epoch 05:   0%|           |   0%

Epoch 05 — Loss 2.7137


Epoch 06:   0%|           |   0%

Epoch 06 — Loss 2.5430


Epoch 07:   0%|           |   0%

Epoch 07 — Loss 2.4248


Epoch 08:   0%|           |   0%

Epoch 08 — Loss 2.3624


Epoch 09:   0%|           |   0%

Epoch 09 — Loss 2.2714


Epoch 10:   0%|           |   0%

Epoch 10 — Loss 2.2115


In [61]:
torch.save(model.state_dict(), os.path.join(DATA_ROOT, 'model.pt'))

In [17]:
def generate_midi(model, tokenizer, audio_path, max_len=2048, device='cpu'):
  model.eval()

  y, sr = librosa.load(audio_path, sr=44100)
  stft = librosa.stft(y, n_fft=1024, hop_length=256)
  log_stft = librosa.amplitude_to_db(np.abs(stft), ref=np.max)
  x = torch.tensor(log_stft, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

  with torch.no_grad():
    memory = model.encoder(x)

  generated = [tokenizer.special_tokens[1]]
  for _ in range(max_len):
    y_input = torch.tensor([tokenizer.vocab[t] for t in generated], dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
      logits = model.decoder(y_input, memory)
    next_logits = logits[:, -1, :]
    next_id = next_logits.argmax(dim=-1).item()
    next_token = [key for key, val in tokenizer.vocab.items() if val == next_id][0]
    generated.append(next_token)
    if next_id == tokenizer.special_tokens[2]:
      break

  if generated[0] == tokenizer.special_tokens[1]:
    generated = generated[1:]
  if generated[-1] == tokenizer.special_tokens[2]:
    generated = generated[:-1]

  generated_seq = TokSequence(tokens=generated)
  print(generated_seq)
  tokens = tokenizer.decode([generated_seq])
  return tokens


In [18]:
path = os.path.join(DATA_ROOT, "generate.mp3")

tokens = generate_midi(model, tokenizer, path, device=device)

TokSequence(tokens=['Tempo_121.29', 'Pitch_49', 'Velocity_79', 'Duration_0.156.320', 'Pitch_56', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.169.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.169.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.169.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.156.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'TimeShift_0.169.320', 'Pitch_51', 'Velocity_79', 'Duration_0.156.320', 'T

In [22]:
tokens.dump_midi(os.path.join(DATA_ROOT, 'generated.mid'))