In [None]:
import os
import shutil
import sys
from os import mkdir, path

# Mount Google Drive: the notebook folder and the MIDI inputs used below
# live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Notebook folder on Drive. Use a plain space: a shell-style "\ " is not a
# valid Python escape, so the backslash would be kept literally and end up in
# every os.path.join() built from this path (e.g. the env symlink target).
# IPython's %cd still accepts the unescaped space.
NOTEBOOK_PATH = "/content/drive/MyDrive/Colab Notebooks/"
# NOTE(review): the default is a local-workstation path, not a Colab/Drive one.
# Set the DATASET_PATH environment variable to point elsewhere.
DATASET_PATH = os.environ.get(
    "DATASET_PATH", "/media/daftpunk3/home/soonbeom/Dataset/URMPPlus/edm_violin"
)

In [None]:
%cd $NOTEBOOK_PATH

symlink_path = '/content/notebooks'
if path.exists(symlink_path):
    shutil.rmtree(symlink_path)

os.symlink(path.join(NOTEBOOK_PATH, 'env'), symlink_path)
sys.path.insert(0, symlink_path)

!pip install --target=$symlink_path pretty_midi

In [None]:
!git clone https://github.com/SoonbeomChoi/miniSynth.git
%cd miniSynth
!git pull origin main
!git merge

In [None]:
import torch
import torch.nn as nn
import librosa
import matplotlib.pyplot as plt
import IPython.display as ipd
from tqdm import tqdm

import config, data, preprocess, vocoder
from model import Model

In [None]:
# Run preprocess
# Fail fast with a clear message: DATASET_PATH is configured as a local
# workstation path, which will not exist on a Colab VM unless changed.
if not path.exists(DATASET_PATH):
    raise FileNotFoundError(f"Dataset path not found: {DATASET_PATH}")
preprocess.run(DATASET_PATH)

In [None]:
# Train model
# makedirs(exist_ok=True) also creates missing parent directories and is safe
# to re-run, unlike the exists()/mkdir() pair.
os.makedirs(config.exp_path, exist_ok=True)
checkpoint_path = path.join(config.exp_path, 'checkpoint.pt')

dataloader = data.load()
model = Model().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.L1Loss()

# Training runs in rounds of `save_step` optimizer steps, with a checkpoint
# after each round. int() truncates, so if stop_step is not a multiple of
# save_step the remaining steps are not run.
num_rounds = int(config.stop_step / config.save_step)
for round_idx in range(num_rounds):
    model.train()
    progress_bar = tqdm(range(config.save_step), leave=False)
    for _ in progress_bar:
        # dataloader['train'] is consumed with next(), so it must be an
        # iterator (presumably infinite/cycling; confirm in data.load()).
        note, mel = next(dataloader['train'])
        note = note.cuda()
        mel = mel.cuda()

        optimizer.zero_grad()
        mel_gen = model(note)
        loss = criterion(mel_gen, mel)
        loss.backward()
        optimizer.step()
        progress_bar.set_description(
            f"[{round_idx + 1}/{num_rounds}] Loss - {loss.item():.4f}")

    # Overwrite the single checkpoint file with the latest weights.
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Select MIDI file
MIDI_FILE = "/content/drive/MyDrive/Academic Share/BigEDMViolin/mid/BEVFE_19_Violin_80_BPM_G.mid"
data = preprocess.preprocess(MIDI_FILE, test=True)

In [None]:
# Synthesize audio
model.load_state_dict(torch.load(path.join(config.exp_path, 'checkpoint.pt')))
model.eval()

audio = []
for d in data:
    with torch.no_grad():
        note = d['note'].unsqueeze(0).cuda()
        mel_pred = model(note)
        audio.append(vocoder.run(mel_pred))

audio = torch.cat(audio, dim=-1).cpu().numpy()

In [None]:
# A bare `import librosa` does not reliably load the display submodule.
import librosa.display

fig, ax = plt.subplots(figsize=(14, 4))
# `waveplot` was removed in librosa 0.10; `waveshow` is its replacement.
librosa.display.waveshow(audio, sr=config.sample_rate, ax=ax)
ax.set_title("Synthesized audio waveform");

In [None]:
# Embed an inline audio player for the synthesized signal.
ipd.Audio(data=audio, rate=config.sample_rate)