Commit
Merge pull request #4 from andreasjansson/main
Add Cog config and demo link
annahung31 committed Aug 29, 2021
2 parents a345ea8 + 79ed89f commit e468255
Showing 3 changed files with 113 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -7,6 +7,7 @@ This is the official repository of **EMOPIA: A Multi-Modal Pop Piano Dataset For

- [Paper on Arxiv](https://arxiv.org/abs/2108.01374)
- [Demo Page](https://annahung31.github.io/EMOPIA/)
- [Interactive demo and Docker image on Replicate](https://replicate.ai/annahung31/emopia)
- [Dataset at Zenodo](https://zenodo.org/record/5090631#.YPPo-JMzZz8)

* Note: We release the transcribed MIDI files. As for the audio, due to copyright issues we only release the YouTube IDs of the tracks and their timestamps. You can use an [open source crawler](https://github.com/ytdl-org/youtube-dl) to fetch the audio files (see the sketch below).
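
A minimal sketch of fetching one track's audio with youtube-dl's Python API, assuming the package and ffmpeg are installed; the video ID is a placeholder for one taken from the dataset metadata, and trimming to the annotated timestamps is omitted:

```python
# Hypothetical helper: download one clip's audio as MP3 with youtube-dl.
# Requires `pip install youtube_dl` and ffmpeg on the PATH.
import youtube_dl

video_id = "YOUTUBE_ID_FROM_METADATA"  # placeholder: take the ID from the dataset metadata
options = {
    "format": "bestaudio/best",
    "outtmpl": f"{video_id}.%(ext)s",
    "postprocessors": [
        {"key": "FFmpegExtractAudio", "preferredcodec": "mp3", "preferredquality": "192"}
    ],
}
with youtube_dl.YoutubeDL(options) as ydl:
    ydl.download([f"https://www.youtube.com/watch?v={video_id}"])
```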
22 changes: 22 additions & 0 deletions cog.yaml
@@ -0,0 +1,22 @@
predict: "predict.py:Predictor"
build:
gpu: true
system_packages:
- "ffmpeg"
- "fluidsynth"
python_packages:
- "torch==1.7.0"
- "scikit-learn==0.24.1"
- "seaborn==0.11.1"
- "numpy==1.19.5"
- "miditoolkit==0.1.14"
- "pandas==1.1.5"
- "tqdm==4.62.2"
- "matplotlib==3.4.3"
- "scipy==1.7.1"
- "midiSynth==0.3"
- "wheel==0.37.0"
- "ipdb===0.13.9"
- "pyfluidsynth==1.3.0"
pre_install:
- "pip install pytorch-fast-transformers==0.4.0" # needs to be installed after the main pip install
90 changes: 90 additions & 0 deletions predict.py
@@ -0,0 +1,90 @@
# based on workspace/transformer/generate.ipynb

import subprocess
from pathlib import Path
import tempfile
import os
import pickle
import sys
import torch
import numpy as np
from midiSynth.synth import MidiSynth
import cog

sys.path.insert(0, "workspace/transformer")
from utils import write_midi
from models import TransformerModel


EMOTIONS = {
    "High valence, high arousal": 1,
    "Low valence, high arousal": 2,
    "Low valence, low arousal": 3,
    "High valence, low arousal": 4,
}


class Predictor(cog.Predictor):
    def setup(self):
        print("Loading dictionary...")
        path_dictionary = "dataset/co-representation/dictionary.pkl"
        with open(path_dictionary, "rb") as f:
            self.dictionary = pickle.load(f)
        event2word, self.word2event = self.dictionary

        n_class = []  # num of classes for each token
        for key in event2word.keys():
            n_class.append(len(event2word[key]))
        n_token = len(n_class)

        print("Loading model...")
        path_saved_ckpt = "exp/pretrained_transformer/loss_25_params.pt"
        self.net = TransformerModel(n_class, is_training=False)
        self.net.cuda()
        self.net.eval()

        self.net.load_state_dict(torch.load(path_saved_ckpt))

        self.midi_synth = MidiSynth()

    @cog.input(
        "emotion",
        type=str,
        default="High valence, high arousal",
        options=EMOTIONS.keys(),
        help="Emotion to generate for",
    )
    @cog.input("seed", type=int, default=-1, help="Random seed, -1 for random")
    def predict(self, emotion, seed):
        if seed < 0:
            seed = int.from_bytes(os.urandom(2), "big")
        torch.manual_seed(seed)
        np.random.seed(seed)
        print(f"Prediction seed: {seed}")

        out_dir = Path(tempfile.mkdtemp())
        midi_path = out_dir / "out.midi"
        wav_path = out_dir / "out.wav"
        mp3_path = out_dir / "out.mp3"

        emotion_tag = EMOTIONS[emotion]
        res, _ = self.net.inference_from_scratch(
            self.dictionary, emotion_tag, n_token=8
        )
        try:
            write_midi(res, str(midi_path), self.word2event)
            self.midi_synth.midi2audio(str(midi_path), str(wav_path))
            subprocess.check_output(
                [
                    "ffmpeg",
                    "-i",
                    str(wav_path),
                    "-af",
                    "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse",  # strip silence
                    str(mp3_path),
                ],
            )
            return mp3_path
        finally:
            midi_path.unlink(missing_ok=True)
            wav_path.unlink(missing_ok=True)
