Commit

initial upload

wayne391 committed Dec 8, 2020
1 parent f21d663 commit a83b7ab
Showing 123 changed files with 659 additions and 2,452 deletions.
Binary file removed .DS_Store
49 changes: 47 additions & 2 deletions README.md
@@ -1,2 +1,47 @@
# compound-word-transformer
implementation of compound word transformer
# Compound Word Transformer


Authors: [Wen-Yi Hsiao](), [Jen-Yu Liu](), [Yin-Cheng Yeh](), [Yi-Hsuan Yang]()

[**Paper (arXiv)**]() | [**Audio demo (Google Drive)**]()

Official PyTorch implementation of the AAAI 2021 paper "Compound Word Transformer: Learning to Compose Full-Song Music over Dynamic Directed Hypergraphs".

We present a new variant of the Transformer that processes multiple consecutive tokens at once at each time step. The proposed method greatly reduces the length of the resulting sequence and therefore improves training and inference efficiency. We employ it to learn to compose expressive Pop piano music of full-song length (involving up to 10K individual tokens per song). In this repository, we open-source our **AILabs.tw Pop17K** dataset and the code for unconditional generation.
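
As a conceptual sketch (not the authors' implementation), one compound time step can combine several token types into a single model input by giving each type its own embedding and projecting the concatenation to the model dimension; all names and sizes below are illustrative.

```python
# Conceptual sketch of a compound (multi-token) time step.
# Each token type gets its own embedding; the concatenation is projected to
# the model dimension, so one step carries what would otherwise be several
# consecutive tokens. Vocabulary sizes and dimensions are illustrative.
import torch
import torch.nn as nn

class CompoundEmbedding(nn.Module):
    def __init__(self, vocab_sizes, emb_dim=64, d_model=512):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(v, emb_dim) for v in vocab_sizes])
        self.proj = nn.Linear(emb_dim * len(vocab_sizes), d_model)

    def forward(self, tokens):  # tokens: (batch, seq_len, n_types)
        parts = [emb(tokens[..., i]) for i, emb in enumerate(self.embs)]
        return self.proj(torch.cat(parts, dim=-1))

# e.g. 6 token types (type, beat, chord, tempo, pitch, duration) per step
emb = CompoundEmbedding([4, 18, 135, 65, 87, 18])
x = emb(torch.zeros(2, 16, 6, dtype=torch.long))  # -> (2, 16, 512)
```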


## Dependencies

* python 3.6
* Required packages:
```bash
pip install miditoolkit
pip install torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --user pytorch-fast-transformers
pip install chorder
```

``chorder`` is our in-house rule-based symbolic chord recognition algorithm, developed by our former intern, [joshuachang2311](https://github.com/joshuachang2311/chorder). He is also a jazz pianist.
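
A minimal usage sketch of ``chorder`` is shown below; the call and attribute names follow how `analyzer.py` in this repository uses the package and may differ in other versions, and the file path is illustrative.

```python
# Minimal chorder usage sketch (API names as used by analyzer.py in this repo;
# treat them as assumptions for other chorder versions).
import miditoolkit
from chorder import Dechorder

midi_obj = miditoolkit.midi.parser.MidiFile('song.mid')  # illustrative path
chords = Dechorder.dechord(midi_obj)          # one chord estimate per beat
for beat, chord in enumerate(chords):
    if chord.is_complete():                   # root/quality/bass all detected
        print(beat, chord.root_pc, chord.quality, chord.bass_pc)
```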


## Model
In this work, we conduct two scenarios of generation:
* unconditional generation
  * To see the experimental results and the discussion, please refer to [here](./workspace/uncond/Experiments.md).

* conditional generation, lead sheet to full MIDI (ls2midi)
  * [**Work in progress**] The code associated with this part is planned to be open-sourced in the future:
    * melody extraction (skyline)
    * objective metrics
    * model

## Dataset
To prepare your own training data, please refer to the [documentation]() for further details.
The full workspace of our **AILabs.tw Pop17K** dataset is available [here](https://drive.google.com/drive/folders/1DY54sxeCcQfVXdGXps5lHwtRe7D_kBRV?usp=sharing).


## Acknowledgement
- The PyTorch code for Transformer-XL is modified from [kimiyoung/transformer-xl](https://github.com/kimiyoung/transformer-xl).
- Thanks to [Yu-Hua Chen](https://github.com/ss12f32v).

Binary file added assets/data_proc_diagram.png
Binary file modified dataset/.DS_Store
43 changes: 43 additions & 0 deletions dataset/Dataset.md
@@ -0,0 +1,43 @@
# Datasets

In this document, we demonstrate our team's standard data processing pipeline. By following the instructions and running the corresponding Python scripts, you can easily generate and customize your own dataset.


<p align="center">
<img src="../assets/data_proc_diagram.png" width="500">
</p>


## 1. From `audio` to `midi_transcribed`
We collect audio clips of piano performances from YouTube.

* run Google Magenta's [Onsets and Frames](https://github.com/magenta/magenta/tree/master/magenta/models/onsets_frames_transcription)

## 2. From `midi_transcribed` to `midi_synchronized`
In this step, we use [madmom](https://github.com/CPJKU/madmom) for beat/downbeat tracking. Next, we interpolate 480 ticks between two adjacent beats and map each absolute time to its corresponding tick. Lastly, we infer tempo changes from the time intervals between adjacent beats. We choose a beat resolution of 480 because it is a common setting in modern DAWs. Notice that we do not quantize any timing in this step, so tiny offsets are preserved for later use (a sketch of the tick mapping follows the step below).

* run `synchronizer.py`
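
Below is a minimal sketch of the beat-to-tick mapping and tempo inference described above, assuming beat times in seconds from madmom; the helper names (`seconds_to_ticks`, `tempo_per_beat`) and the linear interpolation are illustrative, not the exact logic of `synchronizer.py`.

```python
# Minimal sketch of beat-to-tick mapping and tempo inference (illustrative,
# not the exact implementation of synchronizer.py).
import numpy as np

TICKS_PER_BEAT = 480  # beat resolution used throughout the pipeline

def seconds_to_ticks(t, beat_times):
    """Map an absolute time (seconds) to a tick by interpolating 480 ticks
    between the two beats that enclose it."""
    beat_times = np.asarray(beat_times)
    i = np.clip(np.searchsorted(beat_times, t) - 1, 0, len(beat_times) - 2)
    frac = (t - beat_times[i]) / (beat_times[i + 1] - beat_times[i])
    return int(round((i + frac) * TICKS_PER_BEAT))

def tempo_per_beat(beat_times):
    """Infer a BPM value for each beat from the interval to the next beat."""
    intervals = np.diff(np.asarray(beat_times))
    return 60.0 / intervals

beats = [0.50, 1.02, 1.53, 2.05]        # e.g. madmom beat tracker output
print(seconds_to_ticks(1.28, beats))    # tick of a note onset at 1.28 s
print(tempo_per_beat(beats))            # approximate BPM between adjacent beats
```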

## 3. From `midi_synchronized` to `midi_analyzed`
In this step, we apply our in-house rule-based symbolic melody extraction and chord recognition algorithms to obtain the desired information. Only the chord recognition code is open-sourced [here](https://github.com/joshuachang2311/chorder). A sketch of the skyline melody idea follows the step below.

* run `analyzer.py`
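
As an illustration of the skyline idea for melody extraction (keep only the highest-sounding note at each onset), here is a minimal sketch; it is a simplification, not the in-house algorithm itself, and assumes miditoolkit-style note objects with `.start` and `.pitch`.

```python
# Illustrative "skyline" melody sketch: at each onset keep only the
# highest-pitched note. A simplification, not the in-house algorithm.
def skyline(notes):
    """notes: objects with .start and .pitch (e.g. miditoolkit Note)."""
    notes = sorted(notes, key=lambda n: (n.start, -n.pitch))
    melody, prev_start = [], None
    for note in notes:
        if note.start != prev_start:      # first (highest) note at this onset
            melody.append(note)
            prev_start = note.start
    return melody
```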

## 4. From `midi_analyzed` to `Corpus`
We quantize everything (duration, velocity, BPM) in this step, and append an EOS (end-of-sequence) token to each piece (see the sketch after the step below).

* run `midi2corpus.py`
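
The quantization can be pictured as snapping raw values to fixed bins. The sketch below mirrors the bin definitions in the configuration block of `midi2corpus.py`; the helper `quantize` is an illustrative name, not a function from the script.

```python
# Bin-based quantization sketch; bin definitions mirror midi2corpus.py,
# the quantize() helper is illustrative.
import numpy as np

DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64 + 1, dtype=int)
DEFAULT_BPM_BINS = np.linspace(32, 224, 64 + 1, dtype=int)

def quantize(value, bins):
    """Snap a raw value to the nearest bin value."""
    return bins[np.argmin(np.abs(bins - value))]

print(quantize(87, DEFAULT_VELOCITY_BINS))   # raw MIDI velocity -> quantized
print(quantize(93.7, DEFAULT_BPM_BINS))      # raw BPM -> quantized
```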

## 5. From `Corpus` to `Representation`
We have two kinds of representations, Compound Word (**CP**) and **REMI**, and two tasks, unconditional and conditional generation, resulting in four combinations. Go to the corresponding `task/repre` folder and run the scripts (a toy illustration of the two representations follows the list below).


* run `corpus2events.py`: generate human-readable tokens and re-arrange the data.
* run `events2words.py`: build the dictionary and renumber the tokens.
* run `compile.py`: discard disqualified songs that exceed the length limit, reshape the data for Transformer-XL, and generate masks for variable-length sequences.
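
As a toy illustration (not the repository's exact vocabulary), the same note event could be encoded as a flat REMI-style token stream or as a single CP-style compound step that groups the co-occurring tokens:

```python
# Toy illustration of REMI vs. CP for one note event (vocabulary is
# illustrative, not the exact token set built by events2words.py).
remi_stream = [
    "Bar", "Position_1/16", "Chord_C_maj", "Tempo_120",
    "Pitch_60", "Duration_4", "Velocity_20",
]

cp_step = {                 # one time step = one compound word
    "type": "Note",
    "beat": "Position_1/16",
    "chord": "Chord_C_maj",
    "tempo": "Tempo_120",
    "pitch": "Pitch_60",
    "duration": "Duration_4",
    "velocity": "Velocity_20",
}
```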

---

## AILabs.tw Pop17K dataset

Alternatively, you can refer to [here](https://drive.google.com/drive/folders/1DY54sxeCcQfVXdGXps5lHwtRe7D_kBRV?usp=sharing) to obtain the entire workspace and the pre-processed training data originally used in our paper.
114 changes: 16 additions & 98 deletions dataset/midi/analyzer.py → dataset/analyzer.py
@@ -1,17 +1,16 @@
import os
import copy
import numpy as np
import multiprocessing as mp

import miditoolkit
from miditoolkit.midi import parser as mid_parser
from miditoolkit.pianoroll import parser as pr_parser
from miditoolkit.midi.containers import Marker, Instrument, TempoChange

from chorder import Dechorder
from sf_segmenter.segmenter import Segmenter


segmenter = Segmenter()
num2pitch = {
0: 'C',
1: 'C#',
@@ -27,6 +26,7 @@
11: 'B',
}


def traverse_dir(
root_dir,
extension=('mid', 'MID', 'midi'),
@@ -65,74 +65,11 @@ def traverse_dir(
return file_list


def quantize_melody(notes, tick_resol=240):
melody_notes = []
for note in notes:
# cut too long notes
if note.end - note.start > tick_resol * 8:
note.end = note.start + tick_resol * 4

# quantize
note.start = int(np.round(note.start / tick_resol) * tick_resol)
note.end = int(np.round(note.end / tick_resol) * tick_resol)

# append
melody_notes.append(note)
return melody_notes


def extract_melody(notes):
    # quantize
melody_notes = quantize_melody(notes)

# sort by start, pitch from high to low
melody_notes.sort(key=lambda x: (x.start, -x.pitch))

# exclude notes < 60
bins = []
prev = None
tmp_list = []
for nidx in range(len(melody_notes)):
note = melody_notes[nidx]
if note.pitch >= 60:
if note.start != prev:
if tmp_list:
bins.append(tmp_list)
tmp_list = [note]
else:
tmp_list.append(note)
prev = note.start

# preserve only highest one at each step
notes_out = []
for b in bins:
notes_out.append(b[0])

# avoid overlapping
notes_out.sort(key=lambda x:x.start)
for idx in range(len(notes_out) - 1):
if notes_out[idx].end >= notes_out[idx+1].start:
notes_out[idx].end = notes_out[idx+1].start

# delete note having no duration
notes_clean = []
for note in notes_out:
if note.start != note.end:
notes_clean.append(note)

# filtered by interval
notes_final = [notes_clean[0]]
for i in range(1, len(notes_clean) -1):
if ((notes_clean[i].pitch - notes_clean[i-1].pitch) <= -9) and \
((notes_clean[i].pitch - notes_clean[i+1].pitch) <= -9):
continue
else:
notes_final.append(notes_clean[i])
notes_final += [notes_clean[-1]]
return notes_final


def proc_one(path_infile, path_outfile):
print('----')
print(' >', path_infile)
print(' >', path_outfile)

# load
midi_obj = miditoolkit.midi.parser.MidiFile(path_infile)
midi_obj_out = copy.deepcopy(midi_obj)
@@ -157,26 +94,7 @@ def proc_one(path_infile, path_outfile):
if m.text != prev_chord:
prev_chord = m.text
dedup_chords.append(m)

# --- structure --- #
# structure analysis
bounds, labs = segmenter.proc_midi(path_infile)
bounds = np.round(bounds / 4)
bounds = np.unique(bounds)
print(' > [structure] bars:', bounds)
print(' > [structure] labs:', labs)

bounds_marker = []
for i in range(len(labs)):
b = bounds[i]
l = int(labs[i])
bounds_marker.append(
Marker(time=int(b*4*480), text='Boundary_'+str(l)))

# --- melody --- #
melody_notes = extract_melody(notes)
melody_notes = quantize_melody(melody_notes)


# --- global properties --- #
# global tempo
tempos = [b.tempo for b in midi_obj.tempo_changes][:40]
@@ -189,13 +107,8 @@ def proc_one(path_infile, path_outfile):
fn = os.path.basename(path_outfile)
os.makedirs(path_outfile[:-len(fn)], exist_ok=True)

# save piano (0) and melody (1)
melody_track = Instrument(program=0, is_drum=False, name='melody')
melody_track.notes = melody_notes
midi_obj_out.instruments.append(melody_track)

# markers
midi_obj_out.markers = dedup_chords + bounds_marker
midi_obj_out.markers = dedup_chords
midi_obj_out.markers.insert(0, Marker(text='global_bpm_'+str(int(global_bpm)), time=0))

# save
@@ -217,7 +130,8 @@ def proc_one(path_infile, path_outfile):
n_files = len(midifiles)
    print('num files:', n_files)

# run
# collect
data = []
for fidx in range(n_files):
path_midi = midifiles[fidx]
print('{}/{}'.format(fidx, n_files))
@@ -226,5 +140,9 @@ def proc_one(path_infile, path_outfile):
path_infile = os.path.join(path_indir, path_midi)
path_outfile = os.path.join(path_outdir, path_midi)

# proc
proc_one(path_infile, path_outfile)
# append
data.append([path_infile, path_outfile])

# run, multi-thread
pool = mp.Pool()
pool.starmap(proc_one, data)
13 changes: 7 additions & 6 deletions dataset/midi/midi2corpus.py → dataset/midi2corpus.py
@@ -10,15 +10,15 @@
import seaborn as sns

# ================================================== #
# Conig #
# Configuration #
# ================================================== #
BEAT_RESOL = 480
BAR_RESOL = BEAT_RESOL * 4
TICK_RESOL = BEAT_RESOL // 4
INSTR_NAME_MAP = {'piano': 0, 'melody': 1}
INSTR_NAME_MAP = {'piano': 0}
MIN_BPM = 40
MIN_VELOCITY = 40
NOTE_SORTING = 1 # 0: low first / 1: high first
NOTE_SORTING = 1 # 0: ascending / 1: descending

DEFAULT_VELOCITY_BINS = np.linspace(0, 128, 64+1, dtype=np.int)
DEFAULT_BPM_BINS = np.linspace(32, 224, 64+1, dtype=np.int)
@@ -27,6 +27,8 @@
BEAT_RESOL/8, BEAT_RESOL*8+1, BEAT_RESOL/8)

# ================================================== #


def traverse_dir(
root_dir,
extension=('mid', 'MID', 'midi'),
@@ -88,7 +90,7 @@ def proc_one(path_midi, path_outfile):
instr_notes[instr_idx].sort(
key=lambda x: (x.start, -x.pitch))
else:
raise ValueError('Unknown note sorting type. ')
raise ValueError(' [x] Unknown type of sorting.')

# load chords
chords = []
@@ -116,7 +118,6 @@
marker.text.split('_')[1] == 'bpm':
gobal_bpm = int(marker.text.split('_')[2])


# --- process items to grid --- #
# compute empty bar offset at head
first_note_time = min([instr_notes[k][0].start for k in instr_notes.keys()])
@@ -220,7 +221,7 @@ if __name__ == '__main__':
if __name__ == '__main__':
# paths
path_indir = './midi_analyzed'
path_outdir = '../corpus'
path_outdir = './corpus'
os.makedirs(path_outdir, exist_ok=True)

# list files
File renamed without changes.
Binary file added dataset/midi_synchronized/.DS_Store
File renamed without changes.
Binary file added dataset/midi_transcribed/.DS_Store
File renamed without changes.
Binary file modified dataset/representations/.DS_Store
File renamed without changes.
