In [None]:
!pip install pretty_midi
import os, sys, shutil
import time
import json
import math
import argparse
import itertools
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, Subset, DistributedSampler, Dataset

from lib import constants
from lib.model.transformer import MusicTransformer
from lib.inverse_power_with_warmup_sheduler import InversePowerWithWarmupLRScheduler
from lib.encoded_dataset import EncodedDataset
from lib.augmentations import MusicAugmentations

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import zipfile

# Unpack the project package archive into /content so that `lib` is importable.
archive = '/content/lib.zip'
with zipfile.ZipFile(archive) as zip_file:
    zip_file.extractall(path='/content')

In [None]:

# Padding token id used by the training / validation code (taken from the project's
# `lib.constants`, which is also what the model is built with).
PAD_TOKEN = constants.TOKEN_PAD

# Training configuration. Every key is also exported as a module-level global
# (num_epochs, batch_size, lr, gpus, ...) because later cells read the settings by bare name.
params = {
    'NAME': 'model_name',
    'DS_FILE_PATH': 'ds_files.pt',
    'SEED': 0,
    'num_epochs': 100,
    'batch_size': 1,
    'num_workers': 0,
    'val_every': 6000,
    'save_every': 6000,
    'lr': 1e-4,
    'use_scheduler': True,
    'peak_lr': 1e-4,
    'warmup_steps': 4000,
    'power': 2,
    'shift': 100000,
    'LOAD_NAME': '',
    'LOG_TOTAL_NORM': True,
    'CLIPPING': False,
    'gpus': [0, 1, 2, 3, 4],
}
globals().update(params)

import torch

# Local copy of the event-vocabulary layout. NOTE(review): no visible cell reads these
# names (the model and training code use `constants.*` from `lib`); keep them in sync
# with lib/constants.py or drop them.
RANGE_NOTE_ON, RANGE_NOTE_OFF = 128, 128
RANGE_VEL = 32
RANGE_TIME_SHIFT = 100

TOKEN_END = sum((RANGE_NOTE_ON, RANGE_NOTE_OFF, RANGE_VEL, RANGE_TIME_SHIFT))
TOKEN_PAD = TOKEN_END + 1
# NOTE(review): only 4 ids are reserved after PAD, but 5 genres are used and the genre
# token is computed as VOCAB_SIZE - 4 - 1 + genre_id; under this layout genre 0 would
# share its id with TOKEN_PAD. Confirm against lib/constants.py.
VOCAB_SIZE = TOKEN_PAD + 1 + 4

TORCH_FLOAT, TORCH_INT = torch.float32, torch.int32
TORCH_LABEL_TYPE = torch.long

PREPEND_ZEROS_WIDTH = 4


## ДАТАСЕТ


In [None]:

import os
import torch
import joblib
import hashlib
import pretty_midi
import numpy as np
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor

In [None]:
DATA_DIR = '/content/drive/MyDrive/music/nsynth-valid/test_dataset'
OUTPUT_DIR = '/content/drive/MyDrive/music/nsynth-valid/encoded_dataset'
DS_FILE_PATH = './ds_files.pt'  # the index of encoded *.pt files is written here
GENRES = ['classic', 'jazz', 'calm', 'pop', 'hiphop']
MAX_LEN = 2048

print('collecting *.mid files...')
FILES = [str(midi_file) for midi_file in Path(DATA_DIR).rglob('*.mid')]

# Index whatever encoded tensors already exist; the encoding cell rewrites this index
# after it has (re)encoded FILES.
ds_files = [str(pt_file) for pt_file in Path(OUTPUT_DIR).rglob('*.pt')]
torch.save(ds_files, DS_FILE_PATH)

print('ds_files.pt saved to', os.path.abspath(DS_FILE_PATH))

collecting *.mid files...
ds_files.pt saved to /content/ds_files.pt


In [None]:
# Reload the saved index (a list of encoded *.pt paths) to inspect it.
pt_file = torch.load(f='/content/ds_files.pt')

/content/drive/MyDrive/music/nsynth-valid/encoded_dataset/pop/pop_0_3c1c5acc2d141a741dd07e570183cd93_0.pt


In [None]:
# NOTE(review): the dataset cells above already read paths under /content/drive;
# on a fresh kernel this mount has to run before them.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
def encode_fn(i):
    '''Load the i-th MIDI file, encode it, cut it into MAX_LEN chunks and save them.

    Reads the module-level ``FILES``, ``MAX_LEN``, ``GENRES`` and ``OUTPUT_DIR``.
    The genre is the name of the directory that contains the MIDI file. Each kept
    chunk is filtered with ``midi_processing.filter_bad_note_offs``, zero-padded
    to ``MAX_LEN``, terminated with ``constants.TOKEN_END`` and saved as
    ``{OUTPUT_DIR}/{genre}/{name}_{md5}_{k}.pt``. If that file name is rejected
    with ``OSError``, it falls back to ``{OUTPUT_DIR}/{genre}/{md5}_{k}.pt``.

    Parameters
    ----------
    i : int
        Index into ``FILES``.

    Returns
    -------
    int
        -1 if the MIDI file could not be loaded, 1 if a trailing chunk shorter
        than 256 events was dropped, 0 otherwise.

    Raises
    ------
    AssertionError
        If the containing directory name is not one of ``GENRES``.
    '''
    # NOTE(review): `midi_processing` is imported in a later cell; on a fresh kernel
    # that import must run before this function is called.
    file = FILES[i]
    max_len = MAX_LEN

    path, fname = os.path.split(file)
    try:
        midi = pretty_midi.PrettyMIDI(file)
    except Exception as exc:  # unreadable MIDI: skip this file, keep encoding the rest
        print(f'{i} not loaded ({exc})')
        return -1
    genre = path.split('/')[-1]
    assert genre in GENRES, f'{genre} is not in {GENRES}'

    fname, _ = os.path.splitext(fname)
    h = hashlib.md5(file.encode()).hexdigest()
    genre_dir = f'{OUTPUT_DIR}/{genre}'
    os.makedirs(genre_dir, exist_ok=True)  # make sure the target directory exists before saving
    save_name = f'{genre_dir}/{fname}_{h}'

    events = np.array(midi_processing.encode(midi, use_piano_range=True))
    split_idxs = np.cumsum([max_len] * (events.shape[0] // max_len))
    splits = np.split(events, split_idxs, axis=0)
    # Drop a trailing chunk shorter than 256 events.
    if splits[-1].shape[0] < 256:
        splits.pop(-1)
        drop_last = 1
    else:
        drop_last = 0

    for chunk_idx, split in enumerate(splits):
        keep_idxs = midi_processing.filter_bad_note_offs(split)
        split = split[keep_idxs]
        # EOS goes right after the last event, or on the final slot when the chunk is full.
        eos_idx = min(max_len - 1, len(split))
        split = np.pad(split, [[0, max_len - len(split)]])
        split[eos_idx] = constants.TOKEN_END
        try:
            torch.save(split, f'{save_name}_{chunk_idx}.pt')
        except OSError:  # e.g. file name too long: fall back to the hash alone
            save_name = f'{genre_dir}/{h}'
            torch.save(split, f'{save_name}_{chunk_idx}.pt')
    return drop_last

In [None]:
#закодируем датасет
cpu_count = joblib.cpu_count()
print(f'starting encoding in {cpu_count} processes...')
x = list(tqdm(map(encode_fn, range(len(FILES))), position=0, total=len(FILES)))

print('collecting encoded (*.pt) files...')
ds_files = list(map(str, Path(OUTPUT_DIR).rglob('*.pt')))
print('total encoded files:', len(ds_files))

torch.save(ds_files, DS_FILE_PATH)
print('ds_files.pt saved to', os.path.abspath(DS_FILE_PATH))

In [None]:
import midi_processing

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset


class EncodedDataset(Dataset):
    """
    Dataset of pre-encoded event sequences for training and evaluating the model.

    NOTE(review): this cell redefines the ``EncodedDataset`` imported from
    ``lib.encoded_dataset`` in the first cell; this definition is the one used below.

    Parameters
    ----------
    ds_files : str
        Path to the ``ds_files.pt`` index (a list of paths to encoded ``*.pt`` files).
    prefix_path : str
        String prepended to every path from the index when a sample is loaded.
        Occasionally convenient for relocating the dataset.
    transform : MusicAugmentations, optional
        Callable applied to each loaded sample; its result is converted with
        ``torch.from_numpy``.

    Each item is ``(x, genre_id, idx)``. The genre id comes from the name of the
    directory that contains the file (``.../<genre>/<file>.pt``, the layout the
    encoder writes). Unknown names fall back to 0.
    """
    def __init__(self, ds_files, prefix_path='', transform=None):
        self.transform = transform
        self.files = torch.load(ds_files)
        self.prefix_path = prefix_path
        self.genre2id = {'classic': 0, 'jazz': 1, 'calm': 2, 'pop': 3, 'hiphop': 4}
        # Use the parent directory name rather than a fixed path component. Index
        # entries are absolute paths such as '/content/drive/.../encoded_dataset/pop/x.pt',
        # where a fixed component like split('/')[1] is 'content', which would map
        # every file to the fallback genre 0.
        self.genre = [
            self.genre2id.get(os.path.basename(os.path.dirname(f)), 0)
            for f in self.files
        ]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        x = torch.load(self.prefix_path + self.files[idx])
        if self.transform:
            x = torch.from_numpy(self.transform(x))
        genre = self.genre[idx]
        return x, genre, idx

# train


In [None]:
def create_dataloaders(batch_size, num_workers=0, n_val=2):
    '''Build the train/validation DataLoaders from the files listed in ``DS_FILE_PATH``.

    The file list is shuffled with a fixed seed (0). The first ``n_val`` entries
    form the validation split and the rest form the training split. Training
    samples go through ``MusicAugmentations``; validation samples are not
    augmented. Both loaders use a ``DistributedSampler`` driven by the
    module-level ``world_size`` and ``rank``.

    Parameters
    ----------
    batch_size : int
        Training batch size; the validation loader uses ``4 * batch_size``.
    num_workers : int
        Worker processes for each DataLoader.
    n_val : int
        Number of files held out for validation (default 2, the previous fixed value).

    Returns
    -------
    (DataLoader, DataLoader)
        ``(train_loader, val_loader)``.
    '''
    print('loading data...')

    aug = MusicAugmentations()
    tr_dataset = EncodedDataset(DS_FILE_PATH, transform=aug)
    vl_dataset = EncodedDataset(DS_FILE_PATH, transform=None)

    # Fixed shuffle so the split is identical on every run (this also re-seeds
    # numpy's global RNG).
    np.random.seed(0)
    idxs = np.random.permutation(len(tr_dataset))
    vl, tr = np.split(idxs, [n_val])
    train_dataset = Subset(tr_dataset, tr)
    val_dataset = Subset(vl_dataset, vl)
    print("size", len(train_dataset))
    print("size", len(val_dataset))

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                              pin_memory=False, num_workers=num_workers)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size * 4, sampler=val_sampler,
                            pin_memory=False, num_workers=num_workers)

    return train_loader, val_loader

In [None]:
def init_model(lr, seed=0):
    '''Create the MusicTransformer on ``device``, optionally restore weights, and build AdamW.

    Parameters
    ----------
    lr : float
        Learning rate for AdamW (weight decay is fixed at 1e-5).
    seed : int
        Seed passed to ``torch.manual_seed`` before the weights are initialised.

    Returns
    -------
    (MusicTransformer, torch.optim.AdamW)
        The model (weights loaded from ``LOAD_NAME`` when it is non-empty) and its optimizer.
    '''
    torch.manual_seed(seed)
    model = MusicTransformer(
        device,
        n_layers=12,
        d_model=1024,
        dim_feedforward=2048,
        num_heads=16,
        vocab_size=constants.VOCAB_SIZE,
        rpr=True,
    ).to(device)

    if LOAD_NAME != '':
        state = torch.load(LOAD_NAME, map_location=device)
        model.load_state_dict(state)
        print(f'Loaded model from {LOAD_NAME}')

    n_params = sum(p.numel() for p in model.parameters())
    print(n_params / 1e6, 'M parameters')

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    return model, optimizer

In [None]:
def validate(model, val_loader):
    '''Compute token-level cross-entropy and accuracy of ``model`` over ``val_loader``.

    Batches are prepared exactly as in training. Zeros (the encoder's padding)
    become ``PAD_TOKEN``. The target is the padded sequence. The input is the
    same sequence with its last position replaced by the genre token and rolled
    right by one, so the genre token comes first. PAD positions are excluded
    from both metrics. The model's previous train/eval mode is restored on exit.

    Returns
    -------
    (float, float)
        Mean cross-entropy per non-PAD token and accuracy over non-PAD tokens;
        ``(nan, nan)`` if the loader yields no non-PAD tokens.
    '''
    was_training = model.training
    total_ce = 0.0
    n_correct = 0
    n_tokens = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, genre, _ in val_loader:
                x[x == 0] = PAD_TOKEN
                tgt = x.clone()
                # NOTE(review): with 5 genres this offset spans VOCAB_SIZE-5 .. VOCAB_SIZE-1;
                # if lib's VOCAB_SIZE is TOKEN_PAD + 1 + 4, genre 0 shares its id with PAD_TOKEN.
                x[:, -1] = constants.VOCAB_SIZE - 4 - 1 + genre
                x = torch.roll(x, 1, -1)
                x, tgt = x.to(device), tgt.to(device)

                logits = model(x)
                pred = logits.argmax(-1)

                mask = tgt != PAD_TOKEN
                n_tokens += mask.sum().item()
                total_ce += F.cross_entropy(logits.view(-1, logits.shape[-1]), tgt.flatten(),
                                            ignore_index=PAD_TOKEN, reduction='sum').item()
                n_correct += (pred[mask] == tgt[mask]).sum().item()
    finally:
        model.train(was_training)

    if n_tokens == 0:  # empty loader / all-PAD batches: avoid ZeroDivisionError
        return float('nan'), float('nan')
    return total_ce / n_tokens, n_correct / n_tokens

In [None]:
# Single-process training on gpus[rank]. No torch.distributed process group is ever
# initialised in this notebook and only rank 0 runs, so the world size must be 1.
# With world_size = len(gpus), the DistributedSampler would hand this process only
# ~1/len(gpus) of the training files each epoch (22 files -> 5 steps/epoch in the log).
rank, world_size = 0, 1

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

device = torch.device(f'cuda:{gpus[rank]}')
print(rank, gpus[rank], device)

train_loader, val_loader = create_dataloaders(batch_size, num_workers)
torch.cuda.empty_cache()
model, optimizer = init_model(lr, SEED)
if use_scheduler:
    scheduler = InversePowerWithWarmupLRScheduler(optimizer, peak_lr=peak_lr, warmup_steps=warmup_steps,
                                                  power=power, shift=shift)

if rank == 0:
    save_dir = f'output/{NAME}'
    save_name = f'{NAME}'
    if os.path.exists(save_dir):
        print(f'WARNING: {save_dir} exists! It may rewrite useful files')
    os.makedirs(save_dir, exist_ok=True)
    writer = SummaryWriter(f'runs/{save_name}')

# TRAIN
LS = {'loss': [], 'lr': [], 'val_ce': [], 'val_acc': []}

# init_model() returns a plain module (no DistributedDataParallel), so `model.module`
# does not exist. Resolve the underlying network once, for both the wrapped and
# unwrapped cases, and use it for parameter-level logging and every checkpoint.
raw_model = model.module if hasattr(model, 'module') else model

i_val = 0
i_step = -1
best_ce = float('inf')
patience = 0  # validations since the last improvement (tracked and printed, not acted on)
for ep in range(num_epochs):
    model.train()
    train_loader.sampler.set_epoch(ep)  # reshuffle the DistributedSampler every epoch
    bar = tqdm(train_loader, position=rank) if rank == 0 else train_loader
    for x, genre, idxs in bar:
        i_step += 1
        # Build (input, target). Zero padding becomes PAD_TOKEN, the target is the full
        # sequence, and the input is the genre token followed by the sequence shifted right.
        # NOTE(review): genre 0 may share its id with PAD_TOKEN depending on lib's VOCAB_SIZE.
        x[x == 0] = PAD_TOKEN
        tgt = x.clone()
        x[:, -1] = constants.VOCAB_SIZE - 4 - 1 + genre
        x = torch.roll(x, 1, -1)
        x, tgt = x.to(device), tgt.to(device)
        torch.cuda.empty_cache()
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), tgt.flatten(), ignore_index=PAD_TOKEN)

        optimizer.zero_grad()
        loss.backward()

        if CLIPPING:
            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPPING).item()
        else:
            total_norm = 0

        optimizer.step()

        if use_scheduler:
            scheduler.step()

        if i_step == warmup_steps - 1 and rank == 0:
            torch.save(raw_model.state_dict(), f'{save_dir}/model_{save_name}_after_warmup.pt')

        if rank == 0:
            # logs
            cur_lr = optimizer.param_groups[0]['lr']
            LS['loss'] += [loss.item()]
            LS['lr'] += [cur_lr]
            writer.add_scalar('Train/embedding_weight_norm', torch.norm(raw_model.embedding.weight).item(), i_step)
            writer.add_scalar('Train/embedding_grad_norm', torch.norm(raw_model.embedding.weight.grad).item(), i_step)
            writer.add_scalar('Train/output_weight_norm', torch.norm(raw_model.Wout.weight).item(), i_step)
            writer.add_scalar('Train/output_grad_norm', torch.norm(raw_model.Wout.weight.grad).item(), i_step)
            writer.add_scalar('Train/loss', loss.item(), i_step)
            writer.add_scalar('Train/perplexity', math.exp(loss.item()), i_step)
            writer.add_scalar('Train/lr', cur_lr, i_step)
            if LOG_TOTAL_NORM:
                # Global L2 norm over all gradients (after clipping when CLIPPING is set);
                # parameters that received no gradient this step are skipped.
                total_norm = 0.
                for p in model.parameters():
                    if p.grad is None:
                        continue
                    total_norm += p.grad.detach().norm(2).item() ** 2
                total_norm = total_norm ** 0.5
                writer.add_scalar('Train/total_grad_norm', total_norm, i_step)
            bar.set_postfix(loss=loss.item(), lr=cur_lr, norm=total_norm)

        # VALIDATION
        if i_step % val_every == val_every - 1:
            val_ce, val_acc = validate(model, val_loader)
            if world_size > 1:
                # Average per-rank metrics (requires an initialised process group).
                ce_all, acc_all = [[torch.zeros(1, device=device) for _ in range(world_size)] for _ in range(2)]
                for gathered, local in zip([ce_all, acc_all], [val_ce, val_acc]):
                    torch.distributed.all_gather(gathered, torch.tensor([local], dtype=torch.float32, device=device))
                val_ce, val_acc = [torch.cat(a).mean().item() for a in [ce_all, acc_all]]
            if rank == 0:
                # log, save, patience tracking
                LS['val_ce'] += [val_ce]
                LS['val_acc'] += [val_acc]
                writer.add_scalar('Val/ce', val_ce, i_val)
                writer.add_scalar('Val/acc', val_acc, i_val)
                writer.add_scalar('Val/perplexity', math.exp(val_ce), i_val)
                if val_ce < best_ce:
                    patience = 0
                    best_ce = val_ce
                    torch.save({'history': LS, 'epoch': ep, 'params': params}, f'{save_dir}/hist_{save_name}_best.pt')
                    torch.save(raw_model.state_dict(), f'{save_dir}/model_{save_name}_best.pt')
                else:
                    patience += 1
                print(f'{ep}: val_ce={val_ce}, val_acc={val_acc}, patience={patience}')
            i_val += 1

        # CHECKPOINT
        if (i_step % save_every == save_every - 1) and rank == 0:
            torch.save({'history': LS, 'epoch': ep, 'params': params}, f'{save_dir}/hist_{save_name}.pt')
            torch.save(raw_model.state_dict(), f'{save_dir}/model_{save_name}_{(i_step + 1) // 1000}k.pt')
    


0 0 cuda:0
loading data...
size 22
size 2
103.180682 M parameters


100%|██████████| 5/5 [00:07<00:00,  1.50s/it, loss=6.21, lr=1.5e-7, norm=19.6]
100%|██████████| 5/5 [00:06<00:00,  1.31s/it, loss=6.13, lr=2.75e-7, norm=14]
100%|██████████| 5/5 [00:06<00:00,  1.21s/it, loss=6.21, lr=4e-7, norm=11]
100%|██████████| 5/5 [00:05<00:00,  1.18s/it, loss=6.17, lr=5.25e-7, norm=8.34]
100%|██████████| 5/5 [00:06<00:00,  1.27s/it, loss=6.11, lr=6.5e-7, norm=8.19]
100%|██████████| 5/5 [00:05<00:00,  1.18s/it, loss=6.16, lr=7.75e-7, norm=10.8]
100%|██████████| 5/5 [00:05<00:00,  1.19s/it, loss=6.03, lr=9e-7, norm=9.85]
100%|██████████| 5/5 [00:06<00:00,  1.24s/it, loss=5.97, lr=1.03e-6, norm=8.07]
100%|██████████| 5/5 [00:05<00:00,  1.14s/it, loss=6.14, lr=1.15e-6, norm=9.05]
100%|██████████| 5/5 [00:05<00:00,  1.17s/it, loss=6.1, lr=1.28e-6, norm=8.92]
100%|██████████| 5/5 [00:05<00:00,  1.16s/it, loss=5.86, lr=1.4e-6, norm=8.67]
100%|██████████| 5/5 [00:06<00:00,  1.20s/it, loss=5.89, lr=1.53e-6, norm=7.28]
100%|██████████| 5/5 [00:05<00:00,  1.16s/it, loss=5.8

In [None]:
# Manual checkpoint of the current weights (the model is used unwrapped, so its own state_dict is saved).
torch.save(obj=model.state_dict(), f=f'{save_dir}/model_1_{(i_step + 1) // 1000}k.pt')

# generate

In [None]:
import os
import time
import torch
import argparse
import pretty_midi
import numpy as np
from tqdm import tqdm

from lib import constants
from lib import midi_processing
from lib import generation
from lib.midi_processing import PIANO_RANGE
from lib.model.transformer import MusicTransformer


def decode_and_write(generated, primer, genre, out_dir='/content/output'):
    '''Decode event-based sequences to MIDI and write one file per sequence.

    Parameters
    ----------
    generated : iterable
        Event sequences accepted by ``midi_processing.decode`` (one per sample).
    primer : object
        Unused; kept for call-site compatibility.
    genre : iterable of int
        Genre id of each sample, mapped to a name through ``id2genre``.
    out_dir : str
        Output directory; created if it does not exist.

    Returns
    -------
    list of str
        Paths of the written files, named ``gen_{i:02}_{genre_name}.mid``.
    '''
    os.makedirs(out_dir, exist_ok=True)
    written = []
    for i, (gen, g) in enumerate(zip(generated, genre)):
        midi = midi_processing.decode(gen)
        out_path = os.path.join(out_dir, f'gen_{i:>02}_{id2genre[g]}.mid')
        midi.write(out_path)
        written.append(out_path)
    return written

        
# Genre id <-> name, in the same order as the training dataset's genre ids.
id2genre = dict(enumerate(['classic', 'jazz', 'calm', 'pop', 'hiphop']))
genre2id = {name: gid for gid, name in id2genre.items()}
# Per-genre tuning values. NOTE(review): not referenced by any visible cell.
tuned_params = {0: 1.1, 1: 0.95, 2: 0.9, 3: 1.0, 4: 1.05}

In [None]:
# Sampling settings passed as generation.generate(model, primer, **params).
# NOTE: this rebinds `params`, replacing the training configuration dict defined earlier.
params = dict(
    target_seq_length=1024,
    temperature=1.0,
    topk=40,
    topp=0.99,
    topp_temperature=1.0,
    at_least_k=1,
    use_rp=False,
    rp_penalty=0.05,
    rp_restore_speed=0.7,
    seed=None,
)

In [None]:
print('loading model...')
map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MusicTransformer(
    map_location,
    n_layers=12,
    d_model=1024,
    dim_feedforward=2048,
    num_heads=16,
    vocab_size=constants.VOCAB_SIZE,
    rpr=True,
).to(map_location).eval()
checkpoint_state = torch.load('/content/drive/MyDrive/music/model_1_0k.pt', map_location=map_location)
model.load_state_dict(checkpoint_state)

# Primer: a single genre token per sample (id 4 = 'hiphop' for every sample in the batch),
# using the same genre-token offset as training. `batch_size` comes from the training config.
primer_genre = np.repeat([4], batch_size)
primer = torch.tensor(primer_genre)[:, None] + constants.VOCAB_SIZE - 4 - 1


loading model...


In [None]:
generated = generation.generate(model, primer, **params)
generated = generation.post_process(generated, remove_bad_generations=False)

# Write into /content/output: the following cells read the result from
# /content/output/gen_00_hiphop.mid, so writing to /content/ left them reading a
# stale or missing file.
os.makedirs('/content/output', exist_ok=True)
decode_and_write(generated, primer, primer_genre, '/content/output')

100%|██████████| 1023/1023 [01:02<00:00, 16.39it/s]


In [None]:
# Load the first generated sample back for inspection.
pm = pretty_midi.PrettyMIDI(midi_file='/content/output/gen_00_hiphop.mid')

In [None]:
# FluidSynth was previously used here before the cell that imports it, which is a
# NameError on a fresh kernel. NOTE(review): midi2audio is installed by the later
# `!pip install midi2audio` cell; run that cell first on a fresh runtime.
from midi2audio import FluidSynth

midi_path = '/content/output/gen_00_hiphop.mid'
# Swap only the file extension (str.replace would also rewrite any '.mid' elsewhere in the path).
wav_path = os.path.splitext(midi_path)[0] + '.wav'
FluidSynth().midi_to_audio(midi_path, wav_path)

In [None]:
!pip install midi2audio
from midi2audio import FluidSynth

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting midi2audio
  Downloading midi2audio-0.1.1-py2.py3-none-any.whl (8.7 kB)
Installing collected packages: midi2audio
Successfully installed midi2audio-0.1.1
