In [2]:
import glob
import os
import sys
import math
import time
import random
import pickle
import joblib
from tqdm import tqdm
import torch
import pretty_midi

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.modules.normalization import LayerNorm

In [None]:
dataset_pickle = '/home/storage/3020/db/K_cluster2_backup/TD/data/data_split.pkl'
with open(dataset_pickle, 'rb') as f:
    files = pickle.load(f)

train_list, val_list, test_list = [], [], []

for file in tqdm(files['train']):
    seqs = joblib.load(file)
    for idx, seq in enumerate(seqs):
        if len(seq) == 0:
            continue
        train_list.append(seq)
        
for file in tqdm(files['val']):
    seqs = joblib.load(file)
    for idx, seq in enumerate(seqs):
        if len(seq) == 0:
            continue
        val_list.append(seq)
    
for file in tqdm(files['test']):
    seqs = joblib.load(file)
    for idx, seq in enumerate(seqs):
        if len(seq) == 0:
            continue
        test_list.append(seq)

In [None]:
SEQUENCE_START = 0
RANGE_NOTE_ON = 128
RANGE_NOTE_OFF = 128
RANGE_VEL = 32
RANGE_TIME_SHIFT = 100

START_IDX = {
    'note_on': 0,
    'note_off': RANGE_NOTE_ON,
    'time_shift': RANGE_NOTE_ON + RANGE_NOTE_OFF,
    'velocity': RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT
}


SEPERATOR               = "========================="

# Taken from the paper
ADAM_BETA_1             = 0.9
ADAM_BETA_2             = 0.98
ADAM_EPSILON            = 10e-9

LR_DEFAULT_START        = 1.0
SCHEDULER_WARMUP_STEPS  = 4000
# LABEL_SMOOTHING_E       = 0.1

# DROPOUT_P               = 0.1

TOKEN_END               = RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_VEL + RANGE_TIME_SHIFT
TOKEN_PAD               = TOKEN_END + 1

VOCAB_SIZE              = TOKEN_PAD + 1

TORCH_FLOAT             = torch.float32
TORCH_INT               = torch.int32

TORCH_LABEL_TYPE        = torch.long

PREPEND_ZEROS_WIDTH     = 4

TORCH_CPU_DEVICE = torch.device("cpu")
USE_CUDA = 1
TORCH_CUDA_DEVICE = torch.device("cuda:8")

In [None]:
def cpu_device():

    return TORCH_CPU_DEVICE

def get_device():
    if((not USE_CUDA) or (TORCH_CUDA_DEVICE is None)):
        return TORCH_CPU_DEVICE
    else:
        return TORCH_CUDA_DEVICE

# GPT-1 generation

In [4]:
class SustainAdapter:
    def __init__(self, time, type):
        self.start =  time
        self.type = type


class SustainDownManager:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.managed_notes = []
        self._note_dict = {} # key: pitch, value: note.start

    def add_managed_note(self, note: pretty_midi.Note):
        self.managed_notes.append(note)

    def transposition_notes(self):
        for note in reversed(self.managed_notes):
            try:
                note.end = self._note_dict[note.pitch]
            except KeyError:
                note.end = max(self.end, note.end)
            self._note_dict[note.pitch] = note.start


# Divided note by note_on, note_off
class SplitNote:
    def __init__(self, type, time, value, velocity):
        ## type: note_on, note_off
        self.type = type
        self.time = time
        self.velocity = velocity
        self.value = value

    def __repr__(self):
        return '<[SNote] time: {} type: {}, value: {}, velocity: {}>'\
            .format(self.time, self.type, self.value, self.velocity)


class Event:
    def __init__(self, event_type, value):
        self.type = event_type
        self.value = value

    def __repr__(self):
        return '<Event type: {}, value: {}>'.format(self.type, self.value)

    def to_int(self):
        return START_IDX[self.type] + self.value

    @staticmethod
    def from_int(int_value):
        info = Event._type_check(int_value)
        return Event(info['type'], info['value'])

    @staticmethod
    def _type_check(int_value):
        range_note_on = range(0, RANGE_NOTE_ON)
        range_note_off = range(RANGE_NOTE_ON, RANGE_NOTE_ON+RANGE_NOTE_OFF)
        range_time_shift = range(RANGE_NOTE_ON+RANGE_NOTE_OFF,RANGE_NOTE_ON+RANGE_NOTE_OFF+RANGE_TIME_SHIFT)

        valid_value = int_value

        if int_value in range_note_on:
            return {'type': 'note_on', 'value': valid_value}
        elif int_value in range_note_off:
            valid_value -= RANGE_NOTE_ON
            return {'type': 'note_off', 'value': valid_value}
        elif int_value in range_time_shift:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF)
            return {'type': 'time_shift', 'value': valid_value}
        else:
            valid_value -= (RANGE_NOTE_ON + RANGE_NOTE_OFF + RANGE_TIME_SHIFT)
            return {'type': 'velocity', 'value': valid_value}


def _divide_note(notes):
    result_array = []
    notes.sort(key=lambda x: x.start)

    for note in notes:
        on = SplitNote('note_on', note.start, note.pitch, note.velocity)
        off = SplitNote('note_off', note.end, note.pitch, None)
        result_array += [on, off]
    return result_array


def _merge_note(snote_sequence):
    note_on_dict = {}
    result_array = []

    for snote in snote_sequence:
        # print(note_on_dict)
        if snote.type == 'note_on':
            note_on_dict[snote.value] = snote
        elif snote.type == 'note_off':
            try:
                on = note_on_dict[snote.value]
                off = snote
                if off.time - on.time == 0:
                    continue
                result = pretty_midi.Note(on.velocity, snote.value, on.time, off.time)
                result_array.append(result)
            except:
                print('info removed pitch: {}'.format(snote.value))
    return result_array


def _snote2events(snote: SplitNote, prev_vel: int):
    result = []
    if snote.velocity is not None:
        modified_velocity = snote.velocity // 4
        if prev_vel != modified_velocity:
            result.append(Event(event_type='velocity', value=modified_velocity))
    result.append(Event(event_type=snote.type, value=snote.value))
    return result


def _event_seq2snote_seq(event_sequence):
    timeline = 0
    velocity = 0
    snote_seq = []

    for event in event_sequence:
        if event.type == 'time_shift':
            timeline += ((event.value+1) / 100)
        if event.type == 'velocity':
            velocity = event.value * 4
        else:
            snote = SplitNote(event.type, timeline, event.value, velocity)
            snote_seq.append(snote)
    return snote_seq


def _make_time_sift_events(prev_time, post_time):
    time_interval = int(round((post_time - prev_time) * 100))
    results = []
    while time_interval >= RANGE_TIME_SHIFT:
        results.append(Event(event_type='time_shift', value=RANGE_TIME_SHIFT-1))
        time_interval -= RANGE_TIME_SHIFT
    if time_interval == 0:
        return results
    else:
        return results + [Event(event_type='time_shift', value=time_interval-1)]


def _control_preprocess(ctrl_changes):
    sustains = []

    manager = None
    for ctrl in ctrl_changes:
        if ctrl.value >= 64 and manager is None:
            # sustain down
            manager = SustainDownManager(start=ctrl.time, end=None)
        elif ctrl.value < 64 and manager is not None:
            # sustain up
            manager.end = ctrl.time
            sustains.append(manager)
            manager = None
        elif ctrl.value < 64 and len(sustains) > 0:
            sustains[-1].end = ctrl.time
    return sustains


def _note_preprocess(susteins, notes):
    note_stream = []

    if susteins:    # if the midi file has sustain controls
        for sustain in susteins:
            for note_idx, note in enumerate(notes):
                if note.start < sustain.start:
                    note_stream.append(note)
                elif note.start > sustain.end:
                    notes = notes[note_idx:]
                    sustain.transposition_notes()
                    break
                else:
                    sustain.add_managed_note(note)

        for sustain in susteins:
            note_stream += sustain.managed_notes
    
    else:       # else, just push everything into note stream
        for note_idx, note in enumerate(notes):
            note_stream.append(note)

    note_stream.sort(key= lambda x: x.start)
    return note_stream

In [5]:
class EPianoDataset(Dataset):

    def __init__(self, midi_list, max_seq=2048, random_seq=True):
        self.max_seq    = max_seq
        self.random_seq = random_seq
        self.data_files = midi_list

    def __len__(self):

        return len(self.data_files)

    def __getitem__(self, idx):

        raw_mid = torch.tensor(self.data_files[idx], dtype=TORCH_LABEL_TYPE, device=cpu_device())
        x, tgt = process_midi(raw_mid, self.max_seq, self.random_seq)

        return x, tgt
    
def process_midi(raw_mid, max_seq, random_seq):

    x   = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device())
    tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=cpu_device())

    raw_len     = len(raw_mid)
    full_seq    = max_seq + 1 # Performing seq2seq

    if(raw_len == 0):
        return x, tgt

    if(raw_len < full_seq):
        if tgt.shape[0] == raw_len:
            #print(f'Tgt shape: {tgt.shape} Raw len: {raw_len} Skipping')
            x[:raw_len]         = raw_mid
            tgt[:raw_len-1]     = raw_mid[1:]
            tgt[raw_len-1]        = TOKEN_END
        else:
            x[:raw_len]         = raw_mid
            tgt[:raw_len-1]     = raw_mid[1:]
            tgt[raw_len]        = TOKEN_END
    else:
        # Randomly selecting a range
        if(random_seq):
            end_range = raw_len - full_seq
            start = random.randint(SEQUENCE_START, end_range)

        # Always taking from the start to as far as we can
        else:
            start = SEQUENCE_START

        end = start + full_seq

        data = raw_mid[start:end]

        x = data[:max_seq]
        tgt = data[1:full_seq]

    return x, tgt

def decode_midi(idx_array, file_path=None):
    event_sequence = [Event.from_int(idx) for idx in idx_array]
    # print(event_sequence)
    snote_seq = _event_seq2snote_seq(event_sequence)
    note_seq = _merge_note(snote_seq)
    note_seq.sort(key=lambda x:x.start)

    mid = pretty_midi.PrettyMIDI()
    # if want to change instument, see https://www.midi.org/specifications/item/gm-level-1-sound-set
    instument = pretty_midi.Instrument(1, False, "Generated by Music Transformer AI")
    instument.notes = note_seq

    mid.instruments.append(instument)
    if file_path is not None:
        mid.write(file_path)
    return mid

In [6]:
n_workers = 1
batch_size = 2
random_seq = True

rpr = False #'store_true'
max_seq = 512 # Used later to generate primers
n_layers = 6
num_heads = 8
d_model = 512
dim_feedforward = 1024
dropout = 0.1

Pre-saved train/val/test split loading

In [7]:
with open('/home/storage/3020/db/K_cluster2_backup/TD/data/data_split.pkl', 'rb') as f:
    data = pickle.load(f)
print(len(data))
print(data.keys())


test_list = []
for file in tqdm(data['test']):
    seqs = joblib.load(file)
    for idx, seq in enumerate(seqs):
        if len(seq) == 0:
            continue
        test_list.append(seq)

  0%|          | 0/125 [00:00<?, ?it/s]

3
dict_keys(['train', 'val', 'test'])


100%|██████████| 125/125 [01:44<00:00,  1.20it/s]


In [8]:
print(f'Test length: {len(test_list)}')
test_dataset = EPianoDataset(test_list, max_seq, random_seq)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=n_workers)

Test length: 34815


In [9]:
class MusicTransformer(nn.Module):

    def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
                 dropout=0.1, max_sequence=2048, rpr=False):
        super(MusicTransformer, self).__init__()

        self.dummy      = DummyDecoder()

        self.nlayers    = n_layers
        self.nhead      = num_heads
        self.d_model    = d_model
        self.d_ff       = dim_feedforward
        self.dropout    = dropout
        self.max_seq    = max_sequence
        self.rpr        = rpr

        self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)

        self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)

        if(not self.rpr):
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy
            )
        # RPR Transformer
        else:
            encoder_norm = LayerNorm(self.d_model)
            encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
            encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
            )

        self.Wout       = nn.Linear(self.d_model, VOCAB_SIZE)
        self.softmax    = nn.Softmax(dim=-1)

    # forward
    def forward(self, x, mask=True):
        if(mask is True):
            mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(get_device())
        else:
            mask = None
        x = self.embedding(x)
        x = x.permute(1,0,2)
        x = self.positional_encoding(x)
        x_out = self.transformer(src=x, tgt=x, src_mask=mask)
        x_out = x_out.permute(1,0,2)
        y = self.Wout(x_out)
        del mask
        return y

    # generate
    def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):

        assert (not self.training), "Cannot generate while in training mode"

        print("Generating sequence of max length:", target_seq_length)

        gen_seq = torch.full((1,target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

        num_primer = len(primer)
        gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())

        cur_i = num_primer
        while(cur_i < target_seq_length):

            y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
            token_probs = y[:, cur_i-1, :]

            if(beam == 0):
                beam_ran = 2.0
            else:
                beam_ran = random.uniform(0,1)

            if(beam_ran <= beam_chance):
                token_probs = token_probs.flatten()
                top_res, top_i = torch.topk(token_probs, beam)

                beam_rows = top_i // VOCAB_SIZE
                beam_cols = top_i % VOCAB_SIZE

                gen_seq = gen_seq[beam_rows, :]
                gen_seq[..., cur_i] = beam_cols

            else:
                distrib = torch.distributions.categorical.Categorical(probs=token_probs)
                next_token = distrib.sample()
                gen_seq[:, cur_i] = next_token


                # Let the transformer decide to end if it wants to
                if(next_token == TOKEN_END):
                    print("Model called end of sequence at:", cur_i, "/", target_seq_length)
                    break

            cur_i += 1
            if(cur_i % 50 == 0):
                print(cur_i, "/", target_seq_length)

        return gen_seq[:, :cur_i]

class DummyDecoder(nn.Module):

    def __init__(self):
        super(DummyDecoder, self).__init__()

    def forward(self, tgt, memory, tgt_mask, memory_mask,tgt_key_padding_mask,memory_key_padding_mask):

        return memory
    
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [10]:
model = MusicTransformer(n_layers=n_layers, num_heads=num_heads,
            d_model=d_model, dim_feedforward=dim_feedforward, dropout=dropout,
            max_sequence=2048, rpr=rpr).to(get_device())

In [11]:
model.load_state_dict(torch.load('/home/storage/3020/db/K_cluster2_backup/TD/gpt1_best_acc_bsize2.pth', map_location=get_device()))
model.eval()

MusicTransformer(
  (dummy): DummyDecoder()
  (embedding): Embedding(390, 512)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=1024, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=1024, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (

# Random generation  
Generate without using beam search

In [41]:
idx = 0
primer, _  = test_dataset[idx]
primer = primer.to(get_device())
print(primer.shape)

out_dir = '/home/storage/3020/db/K_cluster2_backup/TD/gpt1_examples'
save_path = os.path.join(out_dir, f"primer_{idx}.mid")
decode_midi(primer[:max_seq].cpu().numpy(), file_path=save_path)

torch.Size([512])
info removed pitch: 54


<pretty_midi.pretty_midi.PrettyMIDI at 0x7f9f6e5f9c10>

In [39]:
target_seq_length = 2048
rand_seq = model.generate(primer[:max_seq], target_seq_length, beam=0)

f_path = os.path.join(out_dir, f"rand_{idx}.mid")
decode_midi(rand_seq[0].cpu().numpy(), file_path=f_path)

Generating sequence of max length: 2048
550 / 2048
600 / 2048
650 / 2048
700 / 2048
750 / 2048
800 / 2048
850 / 2048
900 / 2048
950 / 2048
1000 / 2048
1050 / 2048
1100 / 2048
1150 / 2048
1200 / 2048
1250 / 2048
1300 / 2048
1350 / 2048
1400 / 2048
1450 / 2048
1500 / 2048
1550 / 2048
1600 / 2048
1650 / 2048
1700 / 2048
1750 / 2048
1800 / 2048
1850 / 2048
1900 / 2048
1950 / 2048
2000 / 2048
info removed pitch: 66


<pretty_midi.pretty_midi.PrettyMIDI at 0x7f9f6e63dad0>

# Beam search generation  
Generate from same primer using beam search of 2

In [13]:
idx = 0
primer, _  = test_dataset[idx]
primer = primer.to(get_device())
out_dir = '/home/storage/3020/db/K_cluster2_backup/TD/gpt1_examples'
save_path = os.path.join(out_dir, f"primerbeam_{idx}.mid")
decode_midi(primer[:max_seq].cpu().numpy(), file_path=save_path)

info removed pitch: 54
info removed pitch: 66
info removed pitch: 62
info removed pitch: 57


<pretty_midi.pretty_midi.PrettyMIDI at 0x7f0e4fd0a990>

In [14]:
target_seq_length = 2048
beam = 2
beam_seq = model.generate(primer[:max_seq], target_seq_length, beam=beam)

f_path = os.path.join(out_dir, f"beam{beam}_{idx}.mid")

decode_midi(beam_seq[0].cpu().numpy(), file_path=f_path)

Generating sequence of max length: 2048
550 / 2048
600 / 2048
650 / 2048
700 / 2048
750 / 2048
800 / 2048
850 / 2048
900 / 2048
950 / 2048
1000 / 2048
1050 / 2048
1100 / 2048
1150 / 2048
1200 / 2048
1250 / 2048
1300 / 2048
1350 / 2048
1400 / 2048
1450 / 2048
1500 / 2048
1550 / 2048
1600 / 2048
1650 / 2048
1700 / 2048
1750 / 2048
1800 / 2048
1850 / 2048
1900 / 2048
1950 / 2048
2000 / 2048
info removed pitch: 54
info removed pitch: 66
info removed pitch: 62
info removed pitch: 57
info removed pitch: 127
info removed pitch: 126
info removed pitch: 29
info removed pitch: 126
info removed pitch: 41
info removed pitch: 27
info removed pitch: 24
info removed pitch: 126
info removed pitch: 127
info removed pitch: 126
info removed pitch: 126
info removed pitch: 126
info removed pitch: 55
info removed pitch: 126
info removed pitch: 126
info removed pitch: 126
info removed pitch: 126
info removed pitch: 126
info removed pitch: 126
info removed pitch: 59
info removed pitch: 126
info removed pitch:

<pretty_midi.pretty_midi.PrettyMIDI at 0x7f0e7e7f8910>