In [None]:
%%capture
# Install the third-party packages used below; %%capture hides pip's output.
# NOTE(review): versions are unpinned, so a fresh run may pick up incompatible releases.
!pip install transformers
!pip install pretty_midi
!pip install gdown
!pip install music21

In [24]:
from transformers import GPTNeoForCausalLM
import torch
import os
import pretty_midi
from tqdm.notebook import tqdm
from random import randint
import pandas as pd
from sklearn.model_selection import StratifiedKFold

In [25]:
# Number of token ids reserved for each event family.
RANGE_NOTE_ON = 128
RANGE_NOTE_OFF = 128
RANGE_VEL = 32
RANGE_TIME_SHIFT = 100

# Special tokens follow the 388 event tokens.
TOKEN_END = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT  # 388
TOKEN_PAD = TOKEN_END + 1  # 389
# NOTE(review): TOKEN_START also evaluates to 389, i.e. the same id as TOKEN_PAD, so padding
# and the start token cannot be told apart by id (anything that masks pad ids also masks the
# start token). Changing it would invalidate checkpoints trained with this layout — confirm intent.
TOKEN_START = TOKEN_END + 1
# 390 ids so far, then 4 genre tokens (390..393 per generate()'s docstring) and one more id
# whose purpose is not visible in this notebook.
VOCAB_SIZE = TOKEN_PAD + 1 + 4 + 1

In [27]:
import numpy as np

# Event type ids; each is also the index of that family's size in RANGES.
NOTE_ON = 0
NOTE_OFF = 1
SET_VELOCITY = 2
TIME_SHIFT = 3

MAX_TIME_SHIFT = 1.0  # longest shift a single TIME_SHIFT token can encode, in seconds
TIME_SHIFT_STEP = 0.01  # time-grid resolution used for quantization and TIME_SHIFT tokens, in seconds
RANGES = [128,128,32,100]  # token count per family, in NOTE_ON/NOTE_OFF/SET_VELOCITY/TIME_SHIFT order

PIANO_RANGE = [21,96]  # 76 piano keys (inclusive MIDI pitch bounds)


def encode(midi, use_piano_range=True):
    """
    Encode a MIDI object into a single event-based token sequence for MusicTransformer.

    Parameters
    ----------
    midi : prettyMIDI object
        MIDI to encode; drum tracks are ignored.
    use_piano_range : bool
        if True, notes whose pitch lies outside PIANO_RANGE are skipped.

    Returns
    -------
    list of int
        the whole encoded token sequence (it is not split into chunks);
        an empty list when no notes are kept.
    """
    event_list = get_events(midi, use_piano_range=use_piano_range)
    if not event_list:
        return []
    # Both helpers below mutate event_list in place.
    quantize_(event_list)
    add_time_shifts(event_list)
    return encode_events(event_list)
    
    
def decode(encoded):
    """
    Rebuild a single-piano MIDI object from an event-based token sequence.

    A NOTE_ON for a pitch that is already sounding is ignored, and a NOTE_OFF closes the
    note that is open for its pitch (first-OFF mode). Tokens outside the four event ranges
    (END/PAD/START/genre ids) are skipped. Notes that are still open at the end are closed
    at the final time, unless they would have zero length.

    Parameters
    ----------
    encoded : np.array or list
        token sequence to decode.

    Returns
    -------
    midi_out: PrettyMIDI object
        decoded MIDI with one instrument (program 0, named 'piano').
    """
    # Exclusive upper bound of each event family's token range.
    on_end = RANGES[0]
    off_end = on_end + RANGES[1]
    vel_end = off_end + RANGES[2]
    shift_end = vel_end + RANGES[3]

    midi_out = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(0, name='piano')
    midi_out.instruments.append(piano)

    open_notes = {}  # pitch -> [velocity, start]
    now = 0
    velocity = 100
    for token in encoded:
        if token < on_end:
            open_notes.setdefault(token, [velocity, now])
        elif token < off_end:
            pitch = token - on_end
            started = open_notes.pop(pitch, None)
            if started is not None:
                piano.notes.append(pretty_midi.Note(started[0], pitch, started[1], now))
        elif token < vel_end:
            velocity = max(1, (token - off_end) * 128 // RANGES[2])
        elif token < shift_end:
            now += (token - vel_end + 1) * TIME_SHIFT_STEP

    for pitch, (vel, start) in open_notes.items():
        if start != now:
            piano.notes.append(pretty_midi.Note(vel, pitch, start, now))

    return midi_out


def round_step(x, step=0.01):
    """Snap `x` to the nearest multiple of `step` (uses built-in round(), so ties go to even)."""
    n_steps = round(x / step)
    return n_steps * step


def get_events(midi, use_piano_range=False):
    """
    Flatten all non-drum notes into `[time, type, value]` events sorted by time (helper used in encode()).

    Every note yields SET_VELOCITY and NOTE_ON events at its start and a NOTE_OFF event at its end.
    The sort is stable, so events sharing a time keep their insertion order.
    """
    collected = []
    melodic = (inst for inst in midi.instruments if not inst.is_drum)
    for instrument in melodic:
        for note in instrument.notes:
            if use_piano_range and not PIANO_RANGE[0] <= note.pitch <= PIANO_RANGE[1]:
                continue
            collected += [
                [note.start, SET_VELOCITY, note.velocity],
                [note.start, NOTE_ON, note.pitch],
                [note.end, NOTE_OFF, note.pitch],
            ]
    collected.sort(key=lambda event: event[0])
    return collected


def quantize_(events):
    """Round every event's time (index 0) to the round_step grid, modifying `events` in place."""
    for idx in range(len(events)):
        events[idx][0] = round_step(events[idx][0])


def add_time_shifts(events):
    """
    Insert TIME_SHIFT events wherever the event time changes (helper used in encode()).

    Mutates `events` in place. Before each event whose time differs from the previous event's,
    a `[t_prev, TIME_SHIFT, t - t_prev]` event is inserted. No shift is inserted before the
    first event, so any silence before the first event is dropped.
    Assumes `events` is non-empty (an empty list fails when indexing the times array) and
    sorted by time; encode() guarantees both.
    """
    # populate time_shifts, helper function used in encode()
    times = np.array(list(zip(*events)))[0]  # first column of the events = their times
    diff = np.diff(times, prepend=0)
    idxs = diff.nonzero()[0]  # positions whose time differs from the previous event's
    for i in reversed(idxs):  # go backwards so earlier insertion positions remain valid
        if i == 0:
            continue
        t0 = events[i-1][0] # if i != 0 else 0
        t1 = events[i][0]
        dt = t1-t0
        events.insert(i, [t0, TIME_SHIFT, dt])


def encode_events(events):
    """
    Convert `[time, type, value]` events into a flat list of integer tokens (helper used in encode()).

    Each event family owns a contiguous token range, laid out in RANGES order
    (NOTE_ON, NOTE_OFF, SET_VELOCITY, TIME_SHIFT). A token is `sum(RANGES[:type]) + value`,
    with two exceptions:

    - SET_VELOCITY: the velocity is scaled by RANGES[SET_VELOCITY] / 128 (integer division).
    - TIME_SHIFT: `value` is a duration in seconds. It is emitted as one token per whole
      MAX_TIME_SHIFT, plus one token for the remainder rounded to TIME_SHIFT_STEP. Offset `k`
      within the range is read back by decode() as `(k + 1) * TIME_SHIFT_STEP` seconds.

    Parameters
    ----------
    events : list of [time, type, value]
        quantized events with TIME_SHIFT events already inserted (see add_time_shifts()).

    Returns
    -------
    list of int
        encoded token sequence; event times are carried only by the TIME_SHIFT tokens.
    """
    n_shift = RANGES[TIME_SHIFT]

    def shift_token(seconds):
        # Offset of the TIME_SHIFT token for `seconds`. Use round(), not int(): products such as
        # 0.29 * 100 evaluate to 28.999999999999996, and truncating them encoded a shift one step
        # (10 ms) too short, so decoded timing drifted with every such token.
        return int(round(seconds * n_shift)) - 1

    out = []
    for _time, typ, value in events:
        offset = sum(RANGES[:typ])

        if typ == SET_VELOCITY:
            out.append(offset + value * RANGES[SET_VELOCITY] // 128)

        elif typ == TIME_SHIFT:
            # Shifts longer than MAX_TIME_SHIFT are split into several maximal tokens.
            for _ in range(int(value // MAX_TIME_SHIFT)):
                out.append(offset + shift_token(MAX_TIME_SHIFT))
            remainder = round_step(value % MAX_TIME_SHIFT, TIME_SHIFT_STEP)
            if remainder > 0:
                out.append(offset + shift_token(remainder))

        else:
            out.append(offset + value)

    return out


# Exclusive upper token id of each event family: [128, 256, 288, 388].
RANGES_SUM = np.cumsum(RANGES)


def get_type(ev):
    """
    Return the event family of token `ev`.

    The result is NOTE_ON (0), NOTE_OFF (1), SET_VELOCITY (2) or TIME_SHIFT (3).
    It is -1 for any id at or above the end of the TIME_SHIFT range (special tokens).
    """
    for family, upper in enumerate(RANGES_SUM):
        if ev < upper:
            return family
    return -1

    
def filter_bad_note_offs(events):
    """
    Return the indices of `events` to keep, dropping every NOTE_OFF that has no matching open NOTE_ON.

    Parameters
    ----------
    events : sequence of int
        encoded token sequence.

    Returns
    -------
    list of int
        kept indices in ascending order. The indices are appended while scanning; the previous
        `list(set(...))` relied on set iteration order, which Python does not guarantee.
    """
    sounding = set()  # pitches whose NOTE_ON has not been closed yet
    keep = []

    for i, ev in enumerate(events):
        typ = get_type(ev)

        if typ == NOTE_ON:
            sounding.add(ev)
        elif typ == NOTE_OFF:
            pitch = ev - 128  # NOTE_OFF tokens start right after the 128 NOTE_ON tokens
            if pitch not in sounding:
                # NOTE_OFF without a preceding NOTE_ON: drop the event
                continue
            sounding.remove(pitch)

        keep.append(i)

    return keep

In [28]:
import torch
import numpy as np
from tqdm import tqdm



def generate(model, primer, target_seq_length=1024, temperature=1.0, topk=40, topp=0.99, topp_temperature=1.0, at_least_k=1, use_rp=False, rp_penalty=0.05, rp_restore_speed=0.7, seed=None, **forward_args):
    """
    Generate a batch of samples conditioned on `primer`. Several techniques are used to get better samples:
    - temperature skewing to control the entropy of the distributions
    - top-k sampling
    - top-p (nucleus) sampling (https://arxiv.org/abs/1904.09751)
    - DynamicRepetitionPenaltyProcessor, which discourages repeated notes
    The default values are usually suitable for our models. Only tokens allowed by
    make_whitelist_mask() can be sampled.

    Parameters
    ----------
    model : MusicTransformer
        trained model, called as `model(tokens, **forward_args)`. Its result is indexed as
        `[:, i-1, :]`, so it must be a logits tensor whose last dimension is VOCAB_SIZE.
    primer : torch.Tensor (B x N)
        primer to condition on.
        B = batch_size, N = seq_length.
        We use a primer consisting of one token - the genre. These tokens are {390:'classic', 391:'jazz', 392:'calm', 393:'pop'}.
    target_seq_length : int
        desired length of the generated sequences, including the primer.
    temperature : float
        temperature alters the output distribution of the model. Higher values (> 1.0) lead to more stochastic sampling; lower values lead to more expected and predictable sequences (ending up with endlessly repeating musical patterns).
    topk : int
        restricts sampling from lower probabilities: the number of most likely tokens sampling is limited to (0 = no limit).
    topp : float
        restricts sampling from lower probabilities, but is more adaptive than topk. See (https://arxiv.org/abs/1904.09751).
    topp_temperature : float
        temperature used for the probabilities whose cumulative sum drives top-p selection.
    at_least_k : int
        like topk, but forces sampling from at least k tokens of highest probability (values < 1 are treated as 1).
    use_rp : bool
        whether to use the DynamicRepetitionPenaltyProcessor (RP), which tries to prevent repeated notes.
    rp_penalty : float
        coefficient for RP. Higher values lead to more RP impact.
    rp_restore_speed : float
        how fast the penalty is lifted. Lower values lead to more RP impact.
    seed : int
        fixes the torch and numpy seeds for deterministic generation.
    forward_args : dict
        extra keyword args for the model's forward.

    Returns
    -------
    generated : torch.Tensor (B x target_seq_length)
        the primer followed by the sampled tokens. When N == target_seq_length nothing is sampled
        and the primer is returned as-is. Previously this case raised, because the loop variable
        used in the return statement was never bound.
    """
    device = model.device
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    if at_least_k < 1:
        at_least_k = 1
    B,N = primer.shape
    generated = torch.full((B,target_seq_length), TOKEN_PAD, dtype=torch.int64, device=device)
    generated[..., :N] = primer.to(device)

    if use_rp:
        RP_processor = DynamicRepetitionPenaltyProcessor(B, penalty=rp_penalty, restore_speed=rp_restore_speed, device=device)
    whitelist_mask = make_whitelist_mask()

    model.eval()
    with torch.no_grad():
        for i in tqdm(range(N, target_seq_length)):
            # The full prefix is re-run at every step; only the last position's logits are used.
            logits = model(generated[:, :i], **forward_args)[:, i-1, :]
            logits[:,~whitelist_mask] = float('-inf')
            p = torch.softmax(logits/topp_temperature, -1)

            # apply topk (0 means "keep the whole vocabulary"):
            if topk == 0:
                topk = p.shape[-1]
            p_topk, idxs = torch.topk(p, topk, -1, sorted=True)

            # apply topp: keep the head of the sorted candidates whose cumulative probability stays below topp,
            # but always at least `at_least_k` of them
            mask = p_topk.cumsum(-1) < topp
            mask[:,:at_least_k] = True
            logits_masked = logits.gather(-1, idxs)
            logits_masked[~mask] = float('-inf')
            p_topp = torch.softmax(logits_masked/temperature, -1)

            # apply penalty:
            if use_rp:
                p_penalized = RP_processor.apply_penalty(p_topp, idxs)
                ib = p_penalized.sum(-1) == 0
                if ib.sum() > 0:
                    # if all topp tokens get zeroes due to RP_processor, fall back to the remaining topk tokens
                    p_fallback = p_topk[ib].clone()
                    p_fallback[mask[ib]] = 0.  # zeroing topp
                    p_penalized[ib] = p_fallback

                ib = p_penalized.sum(-1) == 0
                if ib.sum() > 0:
                    # if topk tokens get zeroes as well, fall back to topp without RP
                    print('fallback-2')
                    p_penalized = p_topp
                p_topp = p_penalized

            # sample: draw a position among the candidates, then map it back to a token id
            next_token = idxs.gather(-1, torch.multinomial(p_topp, 1))
            generated[:, i] = next_token.squeeze(-1)

            # update penalty:
            if use_rp:
                RP_processor.update(next_token)

    # After a complete loop every position has been filled, so this equals the old
    # `generated[:, :i+1]`; it also works when the loop never runs (N == target_seq_length).
    return generated


def post_process(generated, remove_bad_generations=True):
    """
    Clean up a batch of generated token sequences.

    Three routines are applied:
        1) pauses of 3+ seconds are replaced by TOKEN_PAD (remove_pauses);
        2) velocities are clipped to range(30,100) (clip_velocity), to avoid dramatically loud notes that do not suit our case;
        3) optionally, samples whose notes are mostly 'glued' repetitions (detect_note_repetition() > 0.9) are removed from the batch.

    Parameters
    ----------
    generated : torch.Tensor (B x N)
        batch of generated samples; it is copied to a CPU numpy array before processing.
    remove_bad_generations : bool
        whether to run routine 3.

    Returns
    -------
    filtered_generated : np.ndarray
        the cleaned samples that were kept, in their original order.
    """
    samples = generated.cpu().numpy()
    # Both helpers modify `samples` in place.
    remove_pauses(samples, 3)
    clip_velocity(samples)

    keep = np.ones(len(samples), dtype=bool)
    if remove_bad_generations:
        for idx, sample in enumerate(samples):
            if detect_note_repetition(decode(sample)) > 0.9:
                keep[idx] = False

        n_bad = np.sum(~keep)
        if n_bad != 0:
            print(f'{n_bad} bad samples will be removed.')

    return samples[keep]
    

def make_whitelist_mask():
    """
    Build a boolean mask over the VOCAB_SIZE token ids (True = may be sampled).

    NOTE_ON and NOTE_OFF ids are allowed only for pitches inside PIANO_RANGE. Every id from 256
    upward is allowed: velocity, time-shift and special tokens.
    """
    low, high = PIANO_RANGE
    allowed = np.zeros(VOCAB_SIZE, dtype=bool)
    for base in (0, 128):  # NOTE_ON block, then NOTE_OFF block
        allowed[base + low:base + high + 1] = True
    allowed[256:] = True
    return allowed

    
class DynamicRepetitionPenaltyProcessor:
    """
    Prevent the model from generating repetitive notes or musical patterns that degrade quality.

    The class dynamically reduces and restores the probabilities of generated notes. It keeps one
    multiplicative factor in [0, 1] per (batch row, NOTE_ON pitch).

    - When a row generates a note, that row's factor for the pitch drops by `penalty`.
    - When a row generates a TIME_SHIFT token with offset k, every factor of that row rises by
      `min(k, 100) / 100 * restore_speed`.

    Rows are independent: a note sampled in one batch row never penalizes another row.

    Parameters
    ----------
    bs : int
        batch_size. We need to know batch_size in advance to create the penalty_matrix.
    device : torch.device or str
        device on which the penalty_matrix is allocated.
    penalty : float
        value by which the factor of a generated note is reduced.
    restore_speed : float
        roughly the inverse of the number of seconds needed to fully restore a factor from 0 to 1:
        for restore_speed equal to 1.0 it takes about 1.0 sec, for 2.0 about 0.5 sec, and so on.
    """
    def __init__(self, bs, device, penalty=0.3, restore_speed=1.0):
        self.bs = bs
        self.penalty = penalty
        self.restore_speed = restore_speed
        self.penalty_matrix = torch.ones(bs, 128, device=device)

    def apply_penalty(self, p, idxs):
        """
        Return a copy of `p` with NOTE_ON candidates scaled by their row's penalty factor.

        `p` and `idxs` are (B x K): candidate probabilities and their token ids. Non-note
        candidates (ids >= 128) are left unchanged; `p` itself is not modified.
        """
        p = p.clone()
        for b in range(len(p)):
            cand = idxs[b]
            is_note = cand < 128
            p[b, is_note] = p[b, is_note] * self.penalty_matrix[b, cand[is_note]]
        return p

    def update(self, next_token):
        """
        Update the factors after sampling `next_token` (B x 1 token ids).

        Returns the per-row restore amounts (B x 1) and a 1-D tensor of the sampled NOTE_ON tokens.
        """
        # Only TIME_SHIFT tokens (ids >= 288) restore: the shift offset, capped at 100, scales the restore.
        restoring = next_token - (128 + 128 + 32)
        restoring = torch.clamp(restoring.float(), 0, 100) / 100 * self.restore_speed
        self.penalty_matrix += restoring

        # Penalize each note only in the row that produced it. The previous code flattened the note
        # tokens of the whole batch and subtracted them from every row's column. A row that emitted a
        # note has restoring == 0, so subtracting just `penalty` matches the old per-row arithmetic.
        flat = next_token.reshape(-1)
        rows = torch.nonzero(flat < 128, as_tuple=True)[0]
        self.penalty_matrix[rows, flat[rows]] -= self.penalty

        torch.clamp(self.penalty_matrix, 0, 1.0, out=self.penalty_matrix)
        return restoring, next_token[next_token < 128]
    

def detect_note_repetition(midi, threshold_sec=0.01):
    """
    Return the fraction of 'glued' repetitions among consecutive same-pitch note pairs.

    Notes from all non-drum instruments are grouped by pitch and ordered by start time. For each
    pair of consecutive notes with the same pitch, the gap `next.start - prev.end` is measured.
    Gaps smaller than `threshold_sec` count as glued; overlaps give negative gaps and count too.
    Used to detect bad generated samples.

    Parameters
    ----------
    midi : prettyMIDI object
    threshold_sec : float
        intervals smaller than threshold_sec are treated as 'glued' notes.

    Returns
    -------
    float
        glued pairs divided by the number of consecutive same-pitch pairs. Returns 0 when there
        are no notes or no pitch occurs twice; this case used to divide by zero and return nan.
    """
    all_notes = [note for inst in midi.instruments if not inst.is_drum for note in inst.notes]
    if not all_notes:
        return 0

    starts = np.array([note.start for note in all_notes], dtype=float)
    ends = np.array([note.end for note in all_notes], dtype=float)
    pitches = np.array([note.pitch for note in all_notes])

    order = np.lexsort([starts, pitches])  # primary key: pitch, secondary key: start
    starts, ends, pitches = starts[order], ends[order], pitches[order]

    same_pitch = pitches[1:] == pitches[:-1]
    gaps = (starts[1:] - ends[:-1])[same_pitch]
    if gaps.size == 0:
        return 0
    return (gaps < threshold_sec).sum() / gaps.size


def remove_pauses(generated, threshold=3):
    """
    Replace long silent pauses with TOKEN_PAD values, in place.

    A pause is a run of consecutive TIME_SHIFT tokens whose total duration is at least
    `threshold` seconds and during which no note is held. The sounding state is taken from the
    tokens *before* the run. It used to be checked after applying the token that ends the run,
    which had two effects:
    - a pause ended directly by a NOTE_ON was never removed;
    - a pause ended by the NOTE_OFF of a held note was removed, shortening that note.

    Parameters
    ----------
    generated : np.ndarray or torch.Tensor (B x N)
        generated batch of sequences (post_process passes a numpy array). Modified in place.
    threshold : int/float
        the minimum number of seconds of silence to treat as a pause.
    """
    is_shift = (generated>=RANGES_SUM[2]) & (generated<RANGES_SUM[3])
    # TIME_SHIFT offset k lasts (k + 1) * 0.01 s; every other token contributes 0 s.
    seconds = ((generated-RANGES_SUM[2])+1)*0.01
    seconds[~is_shift] = 0

    res_ab = [[] for _ in range(seconds.shape[0])]

    for ib, i_seconds in enumerate(seconds):
        a, s = 0, 0  # start index and accumulated duration of the current TIME_SHIFT run
        notes_down = np.zeros(128, dtype=bool)
        for i, (t, ev) in enumerate(zip(i_seconds, generated[ib])):
            if t == 0:
                # Token i is not a TIME_SHIFT, so it closes the run [a, i). notes_down still
                # reflects the state during that run because token i has not been applied yet.
                if s >= threshold and not notes_down.any():
                    res_ab[ib].append([a, i, s])
                s = 0
                a = i+1
            s += t

            typ = get_type(ev)
            if typ == NOTE_ON:
                notes_down[ev] = True
            elif typ == NOTE_OFF:
                notes_down[ev-128] = False

        # A run that reaches the end of the sequence.
        if s >= threshold and not notes_down.any():
            res_ab[ib].append([a, len(i_seconds), s])

    # remove inplace
    for ib, spans in enumerate(res_ab):
        for a, b, s in spans:
            generated[ib, a:b] = TOKEN_PAD
            print(f'pause removed:',ib,f'n={b-a}',a,b,s)

        
def clip_velocity(generated, min_velocity=30, max_velocity=100):
    """
    Clamp SET_VELOCITY tokens so they encode velocities between min_velocity and max_velocity.

    The model sometimes generates overly loud sequences; this neutralizes the effect.
    Only SET_VELOCITY tokens are changed, and `generated` is modified in place.
    The bounds are quantized to the 32 velocity bins, so the defaults become bins 7 and 25,
    which decode back to velocities 28 and 100.

    Parameters
    ----------
    generated : np.ndarray or torch.Tensor (B x N)
        generated batch of sequences.
    min_velocity : int
    max_velocity : int
    """
    vel_base = RANGES_SUM[1]
    lowest = vel_base + min_velocity * 32 // 128
    highest = vel_base + max_velocity * 32 // 128

    is_velocity = (generated >= vel_base) & (generated < RANGES_SUM[2])
    generated[is_velocity] = np.clip(generated[is_velocity], lowest, highest)

In [29]:
def gen_batch(inputs, batch_size):
    """Yield consecutive slices of `inputs` of length `batch_size`; the last one may be shorter."""
    offset = 0
    while offset < len(inputs):
        chunk = inputs[offset:offset + batch_size]
        offset += batch_size
        yield chunk

class MidiDataset(torch.utils.data.Dataset):
    """
    Fixed-length token samples built from the `.mid` files in a directory.

    Each file is encoded with `encode(..., use_piano_range=False)` and cut into chunks of
    `max_len - 2` tokens. Each chunk becomes `[TOKEN_START] + chunk + [TOKEN_END]`, right-padded
    with TOKEN_PAD to `max_len`. Items are `[input_ids, attention_mask]` LongTensors of length
    `max_len`; the mask is 1 for real tokens and 0 for padding.

    Parameters
    ----------
    dir : str
        directory containing the MIDI files (not searched recursively).
    max_len : int
        length of every sample, including START and END.
    max_files : int or None
        load at most this many files, taken in sorted filename order. Sorting makes the subset
        reproducible; os.listdir order is arbitrary. None loads every file.
    """
    def __init__(
        self,
        dir,
        max_len=128,
        max_files=200,
    ):
        names = sorted(name for name in os.listdir(dir) if name.endswith('.mid'))
        if max_files is not None:
            names = names[:max_files]
        self.paths = [os.path.join(dir, name) for name in names]
        self.max_len = max_len
        self.dataset = []

        for path in tqdm(self.paths):
            try:
                midi_file = pretty_midi.PrettyMIDI(path)
                events = encode(midi_file, use_piano_range=False)
            except Exception as exc:
                # Skip unreadable or unencodable files, but say so instead of silently swallowing
                # every error (the old bare `except: pass` also hid bugs and KeyboardInterrupt).
                print(f'skipping {path}: {exc!r}')
                continue

            for chunk in gen_batch(events, self.max_len - 2):
                tokens = [TOKEN_START] + chunk + [TOKEN_END]
                n_pad = self.max_len - len(tokens)
                attention_mask = [1] * len(tokens) + [0] * n_pad
                tokens = tokens + [TOKEN_PAD] * n_pad
                self.dataset.append([torch.tensor(tokens).long(), torch.tensor(attention_mask).long()])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

In [33]:
# Recreate ./music from scratch and fill it with the MIDI archive downloaded from Google Drive
# (midi_post_train.zip is unzipped there and then deleted).
!rm -rf music && mkdir music
!cd music && gdown --fuzzy https://drive.google.com/file/d/13gDAcndQufUD27dNZg-OhxNk5-vgxtan/view?usp=sharing \
    && unzip midi_post_train.zip && rm -rf midi_post_train.zip

In [None]:
# Download a second archive from Google Drive into the working directory
# (presumably music_all.zip, which the next cell unzips).
!gdown --fuzzy https://drive.google.com/file/d/1ZLhYZObENV1_oS_CPqrPZrThwrWH84ju/view?usp=sharing

In [None]:
# -u: only extract files that are new or newer than existing copies.
# NOTE(review): the dataset below reads ./music — confirm whether this archive's contents end up there.
!unzip -u music_all.zip

In [34]:
# Encode the MIDI files in ./music into 512-token training samples.
train_dataset = MidiDataset(dir='music', max_len=512)

In [35]:
# Shuffled mini-batches of 8 samples, loaded by 2 worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
)

In [36]:
# Pretrained GPT-Neo 125M from the Hugging Face Hub; fine-tuned weights are loaded over it next.
BASE_MODEL_NAME = 'EleutherAI/gpt-neo-125M'
model = GPTNeoForCausalLM.from_pretrained(BASE_MODEL_NAME)

In [44]:
# Resume from an earlier fine-tuning checkpoint: a dict whose "model" entry is a state_dict
# (presumably written by save_model()).
# map_location='cpu': the model is still on the CPU at this point, so the weights are loaded there
# directly and the cell also works on a machine without a GPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
model.load_state_dict(torch.load("../input/midineo/gpt_neo_all_genres_ep7", map_location='cpu')["model"])

In [45]:
# NOTE(review): the LM head below is left commented out, so the model keeps the output vocabulary
# it was loaded with; the MIDI token ids (< VOCAB_SIZE) are only a small subset of it.
#model.lm_head = torch.nn.Linear(768, VOCAB_SIZE, bias=False)
model = model.to('cuda')  # training and generation below move their inputs to 'cuda' as well

In [46]:
epochs = 3  # passes over train_loader; also sizes the OneCycleLR schedule

In [47]:
# NOTE(review): the OneCycleLR scheduler created next sets the learning rate itself
# (starting at max_lr / div_factor), so this lr value is effectively overridden.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-8)

In [48]:
# One-cycle schedule over the whole run (one scheduler step per batch):
# warm up during the first 10% of steps from 1e-4 / 25 to 1e-4, then anneal.
steps_total = len(train_loader) * epochs
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-4,
    total_steps=steps_total,
    div_factor=25,
    pct_start=0.1,
)

In [49]:
def save_model(name, model, epoch=None, train_loss=None, val_loss=None, optimizer=None, scheduler=None):
    '''
    Save a PyTorch training checkpoint to ./<name> with torch.save.

    The checkpoint is a dict with the keys 'model', 'epoch', 'train_loss', 'val_loss', 'optimizer'
    and 'scheduler'. The optimizer and scheduler entries hold their state_dicts when given and None
    otherwise. Previously the documented None defaults crashed on `.state_dict()`, and `val_loss`
    was silently dropped.

    Parameters
    ----------
    name : str
        file name, joined onto './'.
    model : torch.nn.Module
    epoch, train_loss, val_loss : optional
        stored as-is.
    optimizer : torch.optim.Optimizer, optional
    scheduler : torch.optim.lr_scheduler.LRScheduler, optional
    '''
    torch.save({
        'model': model.state_dict(),
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
    }, os.path.join('./', name))

In [50]:
from tqdm.notebook import tqdm

average_loss = []
for epoch in tqdm(range(epochs)):
    # The model may still be in eval mode here: from_pretrained() returns it that way, and
    # generate_midi_gpt() calls model.eval(). Switch to train mode so dropout is active.
    model.train()
    average_loss = []
    progress = tqdm(train_loader)
    for i, batch in enumerate(progress):
        optimizer.zero_grad()
        input_ids, attention_mask = batch[0].to('cuda'), batch[1].to('cuda')
        # Padding must not contribute to the loss. Padded positions get label -100, which the
        # Hugging Face causal-LM loss ignores. Mask by attention_mask rather than by token id,
        # because TOKEN_PAD and TOKEN_START share the same id.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)['loss']
        loss.backward()
        optimizer.step()
        scheduler.step()
        average_loss.append(loss.item())
        # Update the progress bar instead of printing one line per step.
        progress.set_postfix(step=i, loss=average_loss[-1], mean_loss=np.mean(average_loss))

    # NOTE(review): checkpoints are named ep6, ep7, ... for epoch 0, 1, ..., while this run resumed
    # from gpt_neo_all_genres_ep7 — confirm the intended numbering offset.
    save_model(
        f'gpt_neo_all_genres_ep{6 + epoch}',
        model=model,
        epoch=epoch,
        train_loss=float(np.mean(average_loss)) if average_loss else None,
        optimizer=optimizer,
        scheduler=scheduler,
    )

In [58]:
def generate_midi_gpt(length=32):
    """
    Sample one token sequence from the fine-tuned GPT-Neo, starting from TOKEN_START.

    Parameters
    ----------
    length : int
        unused; kept for backward compatibility. The output length is fixed at 2048 tokens
        by min_length/max_length below.

    Returns
    -------
    torch.Tensor (1 x 2048)
        generated token ids, including the leading TOKEN_START.
    """
    model.eval()
    input_ids = torch.tensor([[TOKEN_START]], dtype=torch.long, device=model.device)
    # TOKEN_START has the same id as TOKEN_PAD. Without an explicit mask, generate() may build one
    # from pad_token_id and mask out the only prompt token.
    attention_mask = torch.ones_like(input_ids)
    sample_outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        pad_token_id=TOKEN_PAD,
        eos_token_id=TOKEN_END,
        min_length=2048,
        max_length=2048,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
        temperature=1,
        num_return_sequences=1,
        # Was misspelled `repition_penalty`. generate() does not recognise that name: it either
        # ignores it or rejects it as an unused model kwarg.
        repetition_penalty=5,
        num_beams=5,
    )
    return sample_outputs

In [59]:
out = generate_midi_gpt()  # (1 x 2048) tensor of token ids, starting with TOKEN_START

In [60]:
# Drop the leading TOKEN_START, clean up the sample, and write it as a MIDI file.
cleaned = post_process(out[0][1:].unsqueeze(0), remove_bad_generations=True)
if len(cleaned) == 0:
    # post_process() filters out overly repetitive samples; indexing [0] used to fail with a bare IndexError.
    raise RuntimeError('The generated sample was rejected as too repetitive; re-run the generation cell.')
midi_out = decode(cleaned[0])
midi_out.write('test.mid')

In [None]:
# from music21 import converter,instrument

# s = converter.parse('./test.mid')

# for el in s.recurse():
#     if 'Instrument' in el.classes: # or 'Piano'
#         el.activeSite.replace(el, instrument.Contrabassoon())

# s.write('midi', './test_Contrabassoon.mid')