Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
bazanovvanya committed Sep 16, 2021
1 parent f9402cf commit 7496685
Show file tree
Hide file tree
Showing 45 changed files with 2,090 additions and 0 deletions.
Binary file added img/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/.DS_Store
Binary file not shown.
Binary file added src/ds_files.pt
Binary file not shown.
80 changes: 80 additions & 0 deletions src/encode_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
import torch
import joblib
import hashlib
import pretty_midi
import numpy as np
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor

from lib import constants
from lib import midi_processing

DATA_DIR = 'test_dataset'       # root dir scanned recursively for *.mid files
OUTPUT_DIR = 'encoded_dataset'  # encoded .pt chunks are written under OUTPUT_DIR/<genre>/
DS_FILE_PATH = './ds_files.pt'  # path where ds_files.pt will be created

GENRES = ['classic', 'jazz', 'calm', 'pop']
MAX_LEN = 2048  # tokens per encoded chunk

print('creating dirs...')
# plain loop instead of a list comprehension used only for its side effects
for g in GENRES:
    os.makedirs(os.path.join(OUTPUT_DIR, g), exist_ok=True)

print('collecting *.mid files...')
FILES = list(map(str, Path(DATA_DIR).rglob('*.mid')))

def encode_fn(i):
    """Load the i-th MIDI file, encode it to event tokens, split into
    fixed-length chunks, pad each chunk to MAX_LEN, and save every chunk
    as a tensor (.pt) under OUTPUT_DIR/<genre>/.

    Returns:
        -1 if the file could not be loaded/parsed;
        1 if the trailing short chunk (< 256 tokens) was dropped;
        0 otherwise.
    """
    file = FILES[i]
    max_len = MAX_LEN

    path, fname = os.path.split(file)
    try:
        midi = pretty_midi.PrettyMIDI(file)
        # take GENRE from 'data/GENRE/xxx.mid'; Path.parts is separator-agnostic,
        # unlike the previous path.split('/') which broke on Windows paths
        genre = Path(file).parts[1]
    except Exception:  # narrowed from bare except: don't swallow KeyboardInterrupt
        print(f'{i} not loaded')
        return -1

    assert genre in GENRES, f'{genre} is not in {GENRES}'

    fname, ext = os.path.splitext(fname)
    # hash of the full path disambiguates identically-named files from different dirs
    h = hashlib.md5(file.encode()).hexdigest()
    save_name = f'{OUTPUT_DIR}/{genre}/{fname}_{h}'

    events = midi_processing.encode(midi, use_piano_range=True)
    events = np.array(events)
    # chunk boundaries every max_len tokens; the final split holds the remainder
    split_idxs = np.cumsum([max_len]*(events.shape[0]//max_len))
    splits = np.split(events, split_idxs, axis=0)
    n_last = splits[-1].shape[0]
    if n_last < 256:
        # remainder too short to be a useful training example — drop it
        splits.pop(-1)
        drop_last = 1
    else:
        drop_last = 0

    # chunk_idx renamed from `i` to avoid shadowing the function parameter
    for chunk_idx, split in enumerate(splits):
        keep_idxs = midi_processing.filter_bad_note_offs(split)
        split = split[keep_idxs]
        eos_idx = min(max_len - 1, len(split))
        split = np.pad(split, [[0, max_len - len(split)]])
        split[eos_idx] = constants.TOKEN_END  # mark end-of-sequence
        try:
            torch.save(split, f'{save_name}_{chunk_idx}.pt')
        except OSError:  # if fname is too long, fall back to a hash-only name
            save_name = f'{OUTPUT_DIR}/{genre}/{h}'
            torch.save(split, f'{save_name}_{chunk_idx}.pt')
    return drop_last

# Guard the multiprocessing entry point: under the 'spawn' start method
# (the default on Windows and macOS) worker processes re-import this module,
# and an unguarded ProcessPoolExecutor would recursively spawn children.
if __name__ == '__main__':
    cpu_count = joblib.cpu_count()
    print(f'starting encoding in {cpu_count} processes...')
    with ProcessPoolExecutor(cpu_count) as pool:
        # x[i] is encode_fn's status code for FILES[i] (-1 load failure, 0/1 ok)
        x = list(tqdm(pool.map(encode_fn, range(len(FILES))), position=0, total=len(FILES)))

    print('collecting encoded (*.pt) files...')
    ds_files = list(map(str, Path(OUTPUT_DIR).rglob('*.pt')))
    print('total encoded files:', len(ds_files))

    # index of all encoded chunk paths, consumed by the training pipeline
    torch.save(ds_files, DS_FILE_PATH)
    print('ds_files.pt saved to', os.path.abspath(DS_FILE_PATH))
Binary file added src/encoded_dataset/.DS_Store
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
141 changes: 141 additions & 0 deletions src/generate.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "241daf4b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import argparse\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"import time\n",
"import torch\n",
"import pretty_midi\n",
"\n",
"from lib import constants\n",
"from lib import midi_processing\n",
"from lib import generation\n",
"from lib.midi_processing import PIANO_RANGE\n",
"from lib.model.transformer import MusicTransformer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "487a4fdf",
"metadata": {},
"outputs": [],
"source": [
"def decode_and_write(generated, primer, genre, out_dir):\n",
" for i, (gen, g) in enumerate(zip(generated, genre)):\n",
" midi = midi_processing.decode(gen)\n",
" midi.write(f'{out_dir}/gen_{i:>02}_{id2genre[g]}.mid')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba57008b",
"metadata": {},
"outputs": [],
"source": [
"id2genre = {0:'classic',1:'jazz',2:'calm',3:'pop'}\n",
"genre2id = dict([[x[1],x[0]] for x in id2genre.items()])\n",
"tuned_params = {\n",
" 0: 1.1,\n",
" 1: 0.95,\n",
" 2: 0.9,\n",
" 3: 1.0\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "074130e6",
"metadata": {},
"outputs": [],
"source": [
"load_path = '../checkpoints/model_big_v3_378k.pt'\n",
"out_dir = 'generated_' + time.strftime('%d-%m-%Y_%H-%M-%S')\n",
"genre_to_generate = 'calm' # Use one of ['classic', 'jazz', 'calm', 'pop']\n",
"batch_size = 8\n",
"device = torch.device('cuda:0')\n",
"remove_bad_generations = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40f48a16",
"metadata": {},
"outputs": [],
"source": [
"params = dict(\n",
" target_seq_length = 512,\n",
" temperature = tuned_params[genre2id[genre_to_generate]],\n",
" topk = 40,\n",
" topp = 0.99,\n",
" topp_temperature = 1.0,\n",
" at_least_k = 1,\n",
" use_rp = False,\n",
" rp_penalty = 0.05,\n",
" rp_restore_speed = 0.7,\n",
" seed = None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ae21f1d",
"metadata": {},
"outputs": [],
"source": [
"# START GENERATION\n",
"\n",
"os.makedirs(out_dir, exist_ok=True)\n",
"genre_id = genre2id[genre_to_generate]\n",
"\n",
"# init model\n",
"print('loading model...')\n",
"model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()\n",
"model.load_state_dict(torch.load(load_path, map_location=device))\n",
"\n",
"# add information about genre (first token)\n",
"primer_genre = np.repeat([genre_id], batch_size)\n",
"primer = torch.tensor(primer_genre)[:,None] + constants.VOCAB_SIZE - 4\n",
"\n",
"print('generating to:', os.path.abspath(out_dir))\n",
"generated = generation.generate(model, primer, **params)\n",
"generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations)\n",
"\n",
"decode_and_write(generated, primer, primer_genre, out_dir)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
95 changes: 95 additions & 0 deletions src/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import time
import torch
import argparse
import pretty_midi
import numpy as np
from tqdm import tqdm

from lib import constants
from lib import midi_processing
from lib import generation
from lib.midi_processing import PIANO_RANGE
from lib.model.transformer import MusicTransformer


def decode_and_write(generated, primer, genre, out_dir):
    """Decode each event sequence back to MIDI and write it to *out_dir*.

    File names encode the batch index (zero-padded) and the genre label.
    """
    for i, pair in enumerate(zip(generated, genre)):
        gen, g = pair
        decoded = midi_processing.decode(gen)
        decoded.write(f'{out_dir}/gen_{i:>02}_{id2genre[g]}.mid')


# Mapping between the model's integer genre ids and human-readable names.
id2genre = {0: 'classic', 1: 'jazz', 2: 'calm', 3: 'pop'}
# Inverse mapping (name -> id), built with a dict comprehension instead of
# the previous dict([[v, k] ...]) construction.
genre2id = {name: idx for idx, name in id2genre.items()}
# Per-genre sampling temperatures found by manual tuning (keyed by genre id).
tuned_params = {
    0: 1.1,
    1: 0.95,
    2: 0.9,
    3: 1.0,
}


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--genre')
    parser.add_argument('--target_seq_length', default=512, type=int)
    parser.add_argument('--temperature', default=None, type=float)
    parser.add_argument('--topk', default=40, type=int)
    parser.add_argument('--topp', default=0.99, type=float)
    parser.add_argument('--topp_temperature', default=1.0, type=float)
    parser.add_argument('--at_least_k', default=1, type=int)
    parser.add_argument('--use_rp', action='store_true')
    # BUGFIX: these two were declared type=int although their values are
    # floats — any fractional CLI value (e.g. --rp_penalty 0.05) raised
    # ValueError in int().
    parser.add_argument('--rp_penalty', default=0.05, type=float)
    parser.add_argument('--rp_restore_speed', default=0.7, type=float)
    parser.add_argument('--seed', default=None, type=int)
    parser.add_argument('--device', default='cuda:0')
    parser.add_argument('--keep_bad_generations', action='store_true')
    parser.add_argument('--out_dir', default=None)
    parser.add_argument('--load_path', default=None)
    parser.add_argument('--batch_size', default=8, type=int)
    args = parser.parse_args()


    try:
        genre_id = genre2id[args.genre]
    except KeyError:
        raise KeyError("Invalid genre name. Use one of ['classic', 'jazz', 'calm', 'pop']")

    load_path = args.load_path or '../checkpoints/model_big_v3_378k.pt'
    out_dir = args.out_dir or ('generated_' + time.strftime('%d-%m-%Y_%H-%M-%S'))
    batch_size = args.batch_size
    device = torch.device(args.device)
    remove_bad_generations = not args.keep_bad_generations

    # Fallback values used when a flag is not supplied on the command line.
    default_params = dict(
        target_seq_length = 512,
        temperature = tuned_params[genre_id],  # per-genre tuned temperature
        topk = 40,
        topp = 0.99,
        topp_temperature = 1.0,
        at_least_k = 1,
        use_rp = False,
        rp_penalty = 0.05,
        rp_restore_speed = 0.7,
        seed = None,
    )

    # BUGFIX: compare against None instead of truthiness, so explicitly
    # passed falsy values (e.g. --topk 0, --seed 0) are honored rather
    # than silently replaced by the defaults.
    params = {k: args.__dict__[k] if args.__dict__[k] is not None else default_params[k]
              for k in default_params}

    os.makedirs(out_dir, exist_ok=True)

    # init model
    print('loading model...')
    model = MusicTransformer(device, n_layers=12, d_model=1024, dim_feedforward=2048, num_heads=16, vocab_size=constants.VOCAB_SIZE, rpr=True).to(device).eval()
    model.load_state_dict(torch.load(load_path, map_location=device))

    # add information about genre (first token): genre tokens occupy the
    # last 4 slots of the vocabulary
    primer_genre = np.repeat([genre_id], batch_size)
    primer = torch.tensor(primer_genre)[:,None] + constants.VOCAB_SIZE - 4

    print('generating to:', os.path.abspath(out_dir))
    generated = generation.generate(model, primer, **params)
    generated = generation.post_process(generated, remove_bad_generations=remove_bad_generations)

    decode_and_write(generated, primer, primer_genre, out_dir)
Loading

0 comments on commit 7496685

Please sign in to comment.