# Import libraries

In [66]:
# %pip install transformers
# %pip install tokenizers
# %pip install youtokentome
%pip install -U catalyst

Collecting catalyst
  Using cached catalyst-20.11-py2.py3-none-any.whl (489 kB)
Collecting deprecation
  Using cached deprecation-2.1.0-py2.py3-none-any.whl (11 kB)
Collecting GitPython>=3.1.1
  Using cached GitPython-3.1.11-py3-none-any.whl (159 kB)
Collecting gitdb<5,>=4.0.1
  Using cached gitdb-4.0.5-py3-none-any.whl (63 kB)
Collecting ipython
  Using cached ipython-7.13.0-py3-none-any.whl (780 kB)
Collecting backcall
  Using cached backcall-0.2.0-py2.py3-none-any.whl (11 kB)
Collecting decorator
  Using cached decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)
Collecting jedi>=0.10
  Using cached jedi-0.17.2-py2.py3-none-any.whl (1.4 MB)
Collecting matplotlib
  Using cached matplotlib-3.3.3-cp37-cp37m-manylinux1_x86_64.whl (11.6 MB)
Collecting cycler>=0.10
  Using cached cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)
Collecting kiwisolver>=1.0.1
  Using cached kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1 MB)
Collecting numpy>=1.16.4
  Using cached numpy-1.19.4-cp37-cp37m-manylinu



In [2]:
from string import ascii_letters
from pathlib import Path
import subprocess
import requests
import random
import shutil
from typing import List
from IPython.display import Audio

# from midi2audio import FluidSynth
# import music21
import zipfile
from tqdm import tqdm


import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
import youtokentome as yttm
from catalyst import dl, metrics

# Load data

In [52]:
# sound_font_zip = download_file('https://storage.yandexcloud.net/hackathon-2020/GeneralUser%20GS%201.471.zip', local_filename='GeneralUser_GS_1.471.zip')
unzip(sound_font_zip)

Extracting : 100%|██████████| 46/46 [00:07<00:00,  6.30it/s]


In [53]:
# trainset_path = download_file('https://storage.yandexcloud.net/hackathon-2020/trainset.zip', local_filename='trainset.zip')
unzip(trainset_path)

Extracting : 100%|██████████| 364009/364009 [26:21<00:00, 230.22it/s]


In [54]:
# testset_path = download_file('https://storage.yandexcloud.net/hackathon-2020/testset.zip', local_filename='testset.zip')
unzip(testset_path)

Extracting : 100%|██████████| 20407/20407 [01:19<00:00, 255.69it/s]


In [11]:
train_paths = collect_children(Path('trainset/abc'))[:128]
# test_paths = collect_children(Path('testset/abc'))[:128]

## Load data into RAM

In [12]:
def process_abc_files(paths):
    texts = []
    for i in tqdm(paths):
        if i.suffix != ".abc":
            continue

        keys = []
        notes = []
        with open(i) as rf:
            for line in rf:
                line = line.strip()
                if line.startswith("%"):
                    continue

                if len(line) > 1 and line[0] in "BCDFGHIKLMmNOPQRrSsTUVWwXZ" and line[1] == ":":
                    keys.append(line)
                else:
                    notes.append(line)

        text = "\n".join(keys)
        notes = "".join(notes)

        if text.endswith("|"):
            text = text[:-1]


        notes = notes.replace(" ", "")
        notes = notes.replace("[", " [")
        notes = notes.replace("]", "] ")
        notes = notes.replace("(", " (")
        notes = notes.replace(")", ") ")
        notes = notes.replace("|", " | ")
        notes = notes.strip()
        notes = notes.replace("  ", " ")
        
        if not keys or not notes:
            continue

        text = text + "\n" + notes + "\n"
        text = " ".join(text.split(" "))     
        texts.append(text)
        
    return texts

In [14]:
train = process_abc_files(train_paths)[:100]
# test = process_abc_files(test_paths)[:100]

100%|██████████| 128/128 [00:00<00:00, 4249.78it/s]


## Make dataloader

In [15]:
tokenizer = yttm.BPE("abc.yttm", n_threads=-1)

In [17]:
class ABCDataset(Dataset):
    def __init__(self, texts, tokenizer, 
                 context_bars_num=8, 
                 target_bars_num=1,
                 is_test=False):
        
        self.notes = []
        self.keys = []
        for i, text in enumerate(texts):
            if text.count("x8 | "*3) != 0 and not is_test:
                continue
                print(i)
            
            try:
                text = text.strip()
                keys, notes = text.rsplit("\n", 1)
                notes = notes.split(" | ")
            except Exception:
                import pdb;pdb.set_trace()
            
            if len(notes) < context_bars_num + target_bars_num and not is_test:
                continue
                print(i)
                
            self.keys.append(keys)
            self.notes.append(notes)
        
        self.tokenizer = tokenizer
        self.context_bars_num = context_bars_num
        self.target_bars_num = target_bars_num
        self.is_test = is_test
        
    def __len__(self):
        return len(self.keys)
    
    
    def __getitem__(self, idx):
        notes = self.notes[idx]
        keys = self.keys[idx]
        
        if not self.is_test:
            split_indx = random.randint(self.context_bars_num, len(notes) - self.target_bars_num)

            context_notes = notes[split_indx - self.context_bars_num : split_indx]
            target = notes[split_indx: split_indx + self.target_bars_num]
        else:
            context_notes = notes
            target = []

        context = keys + "\n" + " | ".join(context_notes).strip()
        if not context.endswith("|"):
            context += " | "

        target = " | ".join(target)

        context_tokens = self.tokenizer.encode(context, bos=True, eos=True)
        target_tokens = self.tokenizer.encode(target, bos=True, eos=True)
        
        context_tokens = torch.tensor(context_tokens, dtype=torch.long)
        target_tokens = torch.tensor(target_tokens, dtype=torch.long)

        return {"features": context_tokens, "target": target_tokens}
    
    
train_dataset = ABCDataset(train, tokenizer)
# test_dataset = ABCDataset(test, tokenizer, is_test=True)


def collate_function(batch):
    features = [i["features"] for i in batch]
    target = [i["target"] for i in batch]
    
    features_lens = [len(i) for i in features]
    target_lens = [len(i) for i in target]
    
    max_features_len = max(features_lens)
    max_target_len = max(target_lens)
    
    features_mask = torch.tensor([[1] * l + [0] * (max_features_len - l) for l in features_lens],
                                 dtype=torch.bool)
    
    target_mask = torch.tensor([[1] * l + [0] * (max_target_len - l) for l in target_lens],
                                dtype=torch.bool)
    
    features_padded = pad_sequence(features, batch_first=True)
    target_padded = pad_sequence(target, batch_first=True)
    
    return {"features": features_padded, "target": target_padded, 
            "features_mask": features_mask, "target_mask": target_mask}

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_function)
# test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=collate_function)

# Building model

In [40]:
config_encoder = BertConfig()
config_decoder = BertConfig()

config_encoder.vocab_size = tokenizer.vocab_size()
config_decoder.vocab_size = tokenizer.vocab_size()
config_decoder.is_decoder = True
config_decoder.add_cross_attention = True

config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
model = EncoderDecoderModel(config=config)

optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

# Train model

In [28]:
loaders = {
    "train": train_loader
}

In [43]:
class CustomRunner(dl.Runner):
    def _handle_batch(self, batch):
        # model train/valid step
        features = batch["features"]
        features_mask = batch["features_mask"]
        target = batch["target"]
        target_mask = batch["target_mask"]
        
        output = model(input_ids=features, decoder_input_ids=target, labels=target,
                       attention_mask=features_mask, decoder_attention_mask=target_mask)

        self.batch_metrics.update(
            {"loss": output.loss.cpu().item()}
        )
        if self.is_train_loader:
            output.loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

In [44]:
runner = CustomRunner()
# model training
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs",
    num_epochs=5,
    verbose=True,
    load_best_on_end=True
)








1/5 * Epoch (train):   0% 0/2 [02:29<?, ?it/s][A[A[A[A[A[A[A
1/5 * Epoch (train):   0% 0/2 [13:22<?, ?it/s]
1/5 * Epoch (train):   0% 0/2 [11:57<?, ?it/s]
1/5 * Epoch (train):   0% 0/2 [07:29<?, ?it/s]
1/5 * Epoch (train):   0% 0/2 [10:10<?, ?it/s]
1/5 * Epoch (train):   0% 0/2 [03:06<?, ?it/s]







1/5 * Epoch (train):   0% 0/2 [01:14<?, ?it/s, loss=10.858][A[A[A[A[A[A[A






1/5 * Epoch (train):  50% 1/2 [01:14<01:14, 74.22s/it, loss=10.858][A[A[A[A[A[A[A






1/5 * Epoch (train):  50% 1/2 [01:55<01:14, 74.22s/it, loss=4.640] [A[A[A[A[A[A[A






1/5 * Epoch (train): 100% 2/2 [01:55<00:00, 64.22s/it, loss=4.640][A[A[A[A[A[A[A






1/5 * Epoch (train): 100% 2/2 [01:55<00:00, 57.56s/it, loss=4.640]
[2020-12-18 00:43:34,648] 
1/5 * Epoch 1 (train): loss=8.7004
2/5 * Epoch (train): 100% 2/2 [02:03<00:00, 61.72s/it, loss=13.809]
[2020-12-18 00:45:46,638] 
2/5 * Epoch 2 (train): loss=10.3554
3/5 * Epoch (train): 100% 2/2 [03:37<00:00, 108.84s/

# Generate melodies

In [138]:
z = model.predict(test)

100%|██████████| 10200/10200 [04:08<00:00, 41.06it/s]


# Submit midi

In [140]:
#!nirvana
yadc_submit_results --result-dir ./predict_midi --user $user

[1m[Step 1/2][0m Submitting audio files from ./predict_midi ...


Error: Unexpected error has occurred during execution of your task. Please, try again or contact us via Support: https://console.cloud.yandex.ru/support

# Вспомогательные функции (уже загружены в состояние)

In [3]:
def download_file(url: str, local_filename: str = None) -> str:
    local_filename = local_filename or url.split('/')[-1]
    with requests.get(url, stream=True) as r:
        with open(local_filename, 'wb') as f:
            shutil.copyfileobj(r.raw, f)
    return local_filename

In [4]:
def unzip(zip_filename: str, dst_dir: str = './'):
    with zipfile.ZipFile(zip_filename, 'r') as zf:
        for entry in tqdm(zf.infolist(), desc='Extracting '):
            try:
                zf.extract(entry, dst_dir)
            except zipfile.error as e:
                print(e)
                pass

In [5]:
def collect_children(path: Path) -> List[Path]:
    if path.is_file():
        return [path]
    else:
        result = []
        for child in path.iterdir():
            if child.is_file():
                result.append(child)
            elif child.is_dir():
                result.extend(collect_children(child))
        return result

In [6]:
def read_abc(path: Path) -> List[List[str]]:
    header = []
    bars = []
    is_header = True
    with open(path, 'r') as input_file:
        for line in input_file:
            line = line.strip()
            if is_header and (not line or line.startswith('%') or (line[0] in ascii_letters and line[1] == ':')):
                header.append(line)
            else:
                is_header = False
                bars.append(line)
    return header, bars

In [7]:
def abc2midi(abc_file: Path, midi_file: Path) -> None:
    command = f'abc2midi {abc_file} -o {midi_file}'
    subprocess.run(command.split(), timeout=2, check=True)

In [8]:
def plot_pianoroll(path: Path, title: str = '') -> None:
    ext = path.suffix
    midi_file = path if ext == '.mid' else Path('tmp') / f'{path.stem}.mid'
    if ext == '.abc' and not midi_file.is_file():
        abc2midi(path, midi_file)
    if midi_file.is_file():
        music21.converter.parse(midi_file).plot('pianoroll', title=title)

In [9]:
def player(path: Path):
    ext = path.suffix
    wav_file = path if ext == '.wav' else Path('tmp') / f'{path.stem}.wav'
    if ext == '.mid':
        midi2wav(path, wav_file)
    elif ext == '.abc':
        midi_file = Path('tmp') / f'{path.stem}.mid'
        abc2midi(path, midi_file)
        midi2wav(midi_file, wav_file)
    if wav_file.is_file():
        return Audio(wav_file)
    else:
        print(f'could not convert {path} to .wav')
    return None

In [10]:
def midi2wav(midi_file: Path, wav_file: Path) -> None:
    FluidSynth(sound_font='GeneralUser GS 1.471/GeneralUser GS v1.471.sf2', sample_rate=8000).midi_to_audio(midi_file, wav_file)