-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f9402cf
commit 7496685
Showing
45 changed files
with
2,090 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| import os | ||
| import torch | ||
| import joblib | ||
| import hashlib | ||
| import pretty_midi | ||
| import numpy as np | ||
| from tqdm import tqdm | ||
| from pathlib import Path | ||
| from concurrent.futures import ProcessPoolExecutor | ||
|
|
||
| from lib import constants | ||
| from lib import midi_processing | ||
|
|
||
# Root directory that is scanned recursively for '*.mid' files; the genre is
# taken from the directory level directly below it ('DATA_DIR/GENRE/xxx.mid').
DATA_DIR = 'test_dataset'
# Root directory that receives one sub-directory of encoded chunks per genre.
OUTPUT_DIR = 'encoded_dataset'
DS_FILE_PATH = './ds_files.pt'  # path where ds_files.pt will be created

GENRES = ['classic', 'jazz', 'calm', 'pop']
# Every saved chunk is padded to exactly this many tokens.
MAX_LEN = 2048

print('creating dirs...')
for _genre in GENRES:
    os.makedirs(os.path.join(OUTPUT_DIR, _genre), exist_ok=True)

print('collecting *.mid files...')
# Sorted on purpose: encode_fn receives an index into FILES, and any process
# that re-imports this module rebuilds the list. Sorting keeps the
# index -> file mapping independent of filesystem enumeration order.
FILES = sorted(map(str, Path(DATA_DIR).rglob('*.mid')))
|
|
||
def encode_fn(i):
    """Load the i-th MIDI file, encode it and save fixed-length chunks to disk.

    ``FILES[i]`` is encoded with ``midi_processing.encode`` and cut into
    consecutive chunks of ``MAX_LEN`` events. Each chunk is filtered with
    ``midi_processing.filter_bad_note_offs``, zero-padded to ``MAX_LEN``,
    terminated with ``constants.TOKEN_END`` and written to
    ``OUTPUT_DIR/GENRE/<file name>_<md5 of path>_<chunk index>.pt``.

    Args:
        i: Index into the module-level ``FILES`` list.

    Returns:
        -1 if the file could not be loaded or is not inside a genre
        directory, 1 if a trailing chunk shorter than 256 events was
        discarded, 0 otherwise.

    Raises:
        AssertionError: If the genre directory name is not in ``GENRES``.
    """
    file = FILES[i]
    max_len = MAX_LEN
    min_tail_len = 256  # a shorter trailing chunk is discarded

    try:
        midi = pretty_midi.PrettyMIDI(file)
        # take GENRE from 'DATA_DIR/GENRE/.../xxx.mid'; dropping the file name
        # first makes a file that sits directly in DATA_DIR raise IndexError,
        # which is reported as "not loaded" like any other load failure
        genre = Path(file).relative_to(DATA_DIR).parts[:-1][0]
    except Exception:  # parsers raise many error types; keep Ctrl-C working
        print(f'{i} not loaded')
        return -1

    assert genre in GENRES, f'{genre} is not in {GENRES}'

    stem = os.path.splitext(os.path.basename(file))[0]
    # hash of the full path keeps equally named files from different
    # directories from overwriting each other
    h = hashlib.md5(file.encode()).hexdigest()
    save_name = f'{OUTPUT_DIR}/{genre}/{stem}_{h}'

    events = np.array(midi_processing.encode(midi, use_piano_range=True))
    # cut positions at every multiple of max_len that fits into the sequence
    split_idxs = np.arange(max_len, events.shape[0] + 1, max_len)
    splits = np.split(events, split_idxs, axis=0)
    if splits[-1].shape[0] < min_tail_len:
        splits.pop()
        drop_last = 1
    else:
        drop_last = 0

    for chunk_idx, split in enumerate(splits):
        keep_idxs = midi_processing.filter_bad_note_offs(split)
        split = split[keep_idxs]
        # end token goes right after the data, or overwrites the last slot
        # when the chunk is already full
        eos_idx = min(max_len - 1, len(split))
        split = np.pad(split, [[0, max_len - len(split)]])
        split[eos_idx] = constants.TOKEN_END
        try:
            torch.save(split, f'{save_name}_{chunk_idx}.pt')
        except OSError:  # if fname is too long
            # fall back to the hash-only name for this and all later chunks
            save_name = f'{OUTPUT_DIR}/{genre}/{h}'
            torch.save(split, f'{save_name}_{chunk_idx}.pt')
    return drop_last
|
|
||
# Guarded so that worker processes which import this module (as happens with
# the spawn/forkserver start methods) only get the definitions above and do
# not try to start a process pool of their own.
if __name__ == '__main__':
    cpu_count = joblib.cpu_count()
    print(f'starting encoding in {cpu_count} processes...')
    with ProcessPoolExecutor(cpu_count) as pool:
        # list() drains the lazy map so tqdm advances and every file is
        # processed before the pool is shut down
        statuses = list(tqdm(pool.map(encode_fn, range(len(FILES))), position=0, total=len(FILES)))

    print('collecting encoded (*.pt) files...')
    ds_files = list(map(str, Path(OUTPUT_DIR).rglob('*.pt')))
    print('total encoded files:', len(ds_files))

    # index of all encoded chunk paths
    torch.save(ds_files, DS_FILE_PATH)
    print('ds_files.pt saved to', os.path.abspath(DS_FILE_PATH))
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+17.1 KB
src/encoded_dataset/calm/calm_0_7bfc0a94983dd5eb495ae0555efa4521_2.pt
Binary file not shown.
Binary file added
BIN
+17.1 KB
src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_0.pt
Binary file not shown.
Binary file added
BIN
+17.1 KB
src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_1.pt
Binary file not shown.
Binary file added
BIN
+17.1 KB
src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_2.pt
Binary file not shown.
Binary file added
BIN
+17.1 KB
src/encoded_dataset/classic/classic_0_94d87b7dc3b6ee96d83f8a173085ce8c_3.pt
Binary file not shown.
Binary file added
BIN
+17.1 KB
src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_0.pt
Binary file not shown.
Binary file added
BIN
+17.1 KB
src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_1.pt
Binary file not shown.
Binary file added
BIN
+16.9 KB
src/encoded_dataset/jazz/jazz_0_4cf8f1246ebc3c24375aca4539fd8adb_2.pt
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "241daf4b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import os\n", | ||
| "import argparse\n", | ||
| "import numpy as np\n", | ||
| "from tqdm import tqdm\n", | ||
| "import time\n", | ||
| "import torch\n", | ||
| "import pretty_midi\n", | ||
| "\n", | ||
| "from lib import constants\n", | ||
| "from lib import midi_processing\n", | ||
| "from lib import generation\n", | ||
| "from lib.midi_processing import PIANO_RANGE\n", | ||
| "from lib.model.transformer import MusicTransformer" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "487a4fdf", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "def decode_and_write(generated, primer, genre, out_dir):\n", | ||
| " for i, (gen, g) in enumerate(zip(generated, genre)):\n", | ||
| " midi = midi_processing.decode(gen)\n", | ||
| " midi.write(f'{out_dir}/gen_{i:>02}_{id2genre[g]}.mid')" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "ba57008b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "id2genre = {0:'classic',1:'jazz',2:'calm',3:'pop'}\n", | ||
| "genre2id = dict([[x[1],x[0]] for x in id2genre.items()])\n", | ||
| "tuned_params = {\n", | ||
| " 0: 1.1,\n", | ||
| " 1: 0.95,\n", | ||
| " 2: 0.9,\n", | ||
| " 3: 1.0\n", | ||
| "}" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "074130e6", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "load_path = '../checkpoints/model_big_v3_378k.pt'\n", | ||
| "out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S')\n", | ||
| "genre_to_generate = 'calm' # Use one of ['classic', 'jazz', 'calm', 'pop']\n", | ||
| "batch_size = 8\n", | ||
| "device = torch.device('cuda:0')\n", | ||
| "remove_bad_generations = True" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "40f48a16", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "params = dict(\n", | ||
| " target_seq_length = 512,\n", | ||
| " temperature = tuned_params[genre2id[genre_to_generate]],\n", | ||
| " topk = 40,\n", | ||
| " topp = 0.99,\n", | ||
| " topp_temperature = 1.0,\n", | ||
| " at_least_k = 1,\n", | ||
| " use_rp = False,\n", | ||
| " rp_penalty = 0.05,\n", | ||
| " rp_restore_speed = 0.7,\n", | ||
| " seed = None,\n", | ||
| ")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "9ae21f1d", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# START GENERATION\n", | ||
| "\n", | ||
| "os.makedirs(out_dir, exist_ok=True)\n", | ||
| "genre_id = genre2id[genre_to_generate]\n", | ||
| "\n", | ||
| "# init model\n", | ||
| "print('loading model...')\n", | ||
| "model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()\n", | ||
| "model.load_state_dict(torch.load(load_path, map_location=device))\n", | ||
| "\n", | ||
| "# add information about genre (first token)\n", | ||
| "primer_genre = np.repeat([genre_id], batch_size)\n", | ||
| "primer = torch.tensor(primer_genre)[:,None] + constants.VOCAB_SIZE - 4\n", | ||
| "\n", | ||
| "print('generating to:', os.path.abspath(out_dir))\n", | ||
| "generated = generation.generate(model, primer, **params)\n", | ||
| "generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations)\n", | ||
| "\n", | ||
| "decode_and_write(generated, primer, primer_genre, out_dir)" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "Python 3 (ipykernel)", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.7.10" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| import os | ||
| import time | ||
| import torch | ||
| import argparse | ||
| import pretty_midi | ||
| import numpy as np | ||
| from tqdm import tqdm | ||
|
|
||
| from lib import constants | ||
| from lib import midi_processing | ||
| from lib import generation | ||
| from lib.midi_processing import PIANO_RANGE | ||
| from lib.model.transformer import MusicTransformer | ||
|
|
||
|
|
||
def decode_and_write(generated, primer, genre, out_dir):
    '''Turn each generated event sequence into a MIDI file inside out_dir.

    Sequences in `generated` are paired positionally with the ids in `genre`;
    every pair is written to `gen_<index>_<genre name>.mid`. `primer` is
    accepted but not used here.
    '''
    for idx, pair in enumerate(zip(generated, genre)):
        events, genre_id = pair
        decoded = midi_processing.decode(events)
        decoded.write(f'{out_dir}/gen_{idx:>02}_{id2genre[genre_id]}.mid')
|
|
||
|
|
||
# Numeric genre id -> genre name.
id2genre = {0: 'classic', 1: 'jazz', 2: 'calm', 3: 'pop'}
# Reverse lookup: genre name -> numeric genre id.
genre2id = {name: genre_id for genre_id, name in id2genre.items()}
# Sampling temperature used for each genre id when none is given explicitly.
tuned_params = {
    0: 1.1,
    1: 0.95,
    2: 0.9,
    3: 1.0,
}
|
|
||
|
|
||
def build_arg_parser():
    """Create the command-line parser of the generation script.

    Returns:
        argparse.ArgumentParser: parser for all generation options.
        ``--temperature`` defaults to ``None`` so that a per-genre value can
        be substituted after parsing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--genre')
    parser.add_argument('--target_seq_length', default=512, type=int)
    parser.add_argument('--temperature', default=None, type=float)
    parser.add_argument('--topk', default=40, type=int)
    parser.add_argument('--topp', default=0.99, type=float)
    parser.add_argument('--topp_temperature', default=1.0, type=float)
    parser.add_argument('--at_least_k', default=1, type=int)
    parser.add_argument('--use_rp', action='store_true')
    # Both values are fractions (defaults 0.05 / 0.7); they must be parsed as
    # float, otherwise an explicit value such as '0.1' is rejected.
    parser.add_argument('--rp_penalty', default=0.05, type=float)
    parser.add_argument('--rp_restore_speed', default=0.7, type=float)
    parser.add_argument('--seed', default=None, type=int)
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--keep_bad_generations', action='store_true')
    parser.add_argument('--out_dir', default=None)
    parser.add_argument('--load_path', default=None)
    parser.add_argument('--batch_size', default=8, type=int)
    return parser


def resolve_params(args, default_params):
    """Overlay parsed command-line values on top of ``default_params``.

    Args:
        args: Namespace returned by the parser; must have an attribute for
            every key of ``default_params``.
        default_params: Mapping of parameter name to fallback value.

    Returns:
        dict: One entry per key of ``default_params``. The fallback is used
        only when the command-line value is ``None``, so explicit falsy
        values such as ``--seed 0`` are kept.
    """
    cli_values = vars(args)
    return {
        key: fallback if cli_values[key] is None else cli_values[key]
        for key, fallback in default_params.items()
    }


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    try:
        genre_id = genre2id[args.genre]
    except KeyError:
        # `from None` drops the uninformative inner KeyError from the traceback
        raise KeyError("Invalid genre name. Use one of ['classic', 'jazz', 'calm', 'pop']") from None

    load_path = args.load_path or '../checkpoints/model_big_v3_378k.pt'
    out_dir = args.out_dir or ('generated_' + time.strftime('%d-%m-%Y_%H-%M-%S'))
    batch_size = args.batch_size
    device = torch.device(args.device)
    remove_bad_generations = not args.keep_bad_generations

    default_params = dict(
        target_seq_length=512,
        temperature=tuned_params[genre_id],  # per-genre default
        topk=40,
        topp=0.99,
        topp_temperature=1.0,
        at_least_k=1,
        use_rp=False,
        rp_penalty=0.05,
        rp_restore_speed=0.7,
        seed=None,
    )

    params = resolve_params(args, default_params)

    os.makedirs(out_dir, exist_ok=True)

    # init model
    print('loading model...')
    model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()
    model.load_state_dict(torch.load(load_path, map_location=device))

    # add information about genre (first token)
    primer_genre = np.repeat([genre_id], batch_size)
    # NOTE(review): presumably the last 4 vocabulary ids are the genre tokens,
    # one per entry of id2genre - confirm against lib.constants
    primer = torch.tensor(primer_genre)[:, None] + constants.VOCAB_SIZE - 4

    print('generating to:', os.path.abspath(out_dir))
    generated = generation.generate(model, primer, **params)
    generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations)

    decode_and_write(generated, primer, primer_genre, out_dir)
Oops, something went wrong.