# Environment Setup


In [None]:
!pip install miditoolkit

Collecting miditoolkit
  Downloading miditoolkit-0.1.16-py3-none-any.whl (20 kB)
Collecting mido>=1.1.16 (from miditoolkit)
  Downloading mido-1.2.10-py2.py3-none-any.whl (51 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.1/51.1 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mido, miditoolkit
Successfully installed miditoolkit-0.1.16 mido-1.2.10


In [None]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# In case your drive does not have the dataset:
# Link: https://drive.google.com/drive/folders/1CGrVwID0sqdbVe0oI7a4FFRooyuuowIU?usp=sharing

## Configuration

In [None]:
main_path = '/content/drive/MyDrive/THESIS/MuseMorphose'

config = {
    'data': {
        'data_dir': f'{main_path}/remi_dataset',
        'train_split': f'{main_path}/pickles/train_pieces.pkl',
        'val_split': f'{main_path}/pickles/val_pieces.pkl',
        'test_split': f'{main_path}/pickles/test_pieces.pkl',
        'vocab_path': f'{main_path}/pickles/remi_vocab.pkl',
        'max_bars':         16,
        'enc_seqlen':       128,
        'dec_seqlen':       1280,
        'batch_size':        4
    },
    'model': {
        'enc_n_layer':      6,
        'enc_n_head':       4,
        'enc_d_model':      256,
        'enc_d_ff':         1024,
        'dec_n_layer':      6,
        'dec_n_head':       4,
        'dec_d_model':      256,
        'dec_d_ff':         1024,
        'd_embed':          256,
        'd_latent':         64,
        'd_polyph_emb':     32,
        'd_rfreq_emb':      32,
        'cond_mode':        'in-attn',
        'pretrained_params_path':      None,
        'pretrained_optim_path':       None,
    },
    'training': {
        'device':           'cuda:0',
        'ckpt_dir':         './ckpt/enc_dec_12L-16_bars-seqlen_1280',
        'trained_steps':    0,
        'max_epochs':       2,
        'max_lr':           1.0e-4,
        'min_lr':           5.0e-6,
        'lr_warmup_steps':  200,
        'lr_decay_steps':   150000,
        'no_kl_steps':      10000,
        'kl_cycle_steps':   5000,
        'kl_max_beta':      1.0,
        'free_bit_lambda':  0.25,
        'constant_kl':      False,
        'ckpt_interval':    50,
        'log_interval':     10,
        'val_interval':     50,
    },
    'generate': {
        'temperature':                1.2,
        'nucleus_p':                  0.9,
        'use_latent_sampling':        False,
        'latent_sampling_var':        0.0,
        'max_bars':                   16,       # could be set to match the longest input piece during generation (inference)
        'dec_seqlen':                 1280,     # could be set to match the longest input piece during generation (inference)
        'max_input_dec_seqlen':       1024 ,    # should be set to equal to or less than `dec_seqlen` used during training
    }
}

# Import libraries

In [None]:
import os
import pickle
import time
import math
import random
import numpy as np
import miditoolkit

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from glob import glob
from copy import deepcopy
from scipy.spatial import distance
from scipy.stats import entropy

# Utils

## Converter

In [None]:
def numpy_to_tensor(arr, use_gpu=True, device='cuda:0'):
  if use_gpu:
    return torch.tensor(arr).to(device).float()
  else:
    return torch.tensor(arr).float()

def tensor_to_numpy(tensor):
  return tensor.cpu().detach().numpy()

def pickle_load(f):
  return pickle.load(open(f, 'rb'))

def pickle_dump(obj, f):
  pickle.dump(obj, open(f, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)

## REMI2MIDI

In [None]:
##############################
# constants
##############################
DEFAULT_BEAT_RESOL = 480
DEFAULT_BAR_RESOL = 480 * 4
DEFAULT_FRACTION = 16


##############################
# containers for conversion
##############################
class ConversionEvent(object):
  def __init__(self, event, is_full_event=False):
    if not is_full_event:
      if 'Note' in event:
        self.name, self.value = '_'.join(event.split('_')[:-1]), event.split('_')[-1]
      elif 'Chord' in event:
        self.name, self.value = event.split('_')[0], '_'.join(event.split('_')[1:])
      else:
        self.name, self.value = event.split('_')
    else:
      self.name, self.value = event['name'], event['value']
  def __repr__(self):
    return 'Event(name: {} | value: {})'.format(self.name, self.value)

class NoteEvent(object):
  def __init__(self, pitch, bar, position, duration, velocity):
    self.pitch = pitch
    self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION)
    self.duration = duration
    self.velocity = velocity

class TempoEvent(object):
  def __init__(self, tempo, bar, position):
    self.tempo = tempo
    self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION)

class ChordEvent(object):
  def __init__(self, chord_val, bar, position):
    self.chord_val = chord_val
    self.start_tick = bar * DEFAULT_BAR_RESOL + position * (DEFAULT_BAR_RESOL // DEFAULT_FRACTION)

##############################
# conversion functions
##############################
def read_generated_txt(generated_path):
  f = open(generated_path, 'r')
  return f.read().splitlines()

def remi2midi(events, output_midi_path=None, is_full_event=False, return_first_tempo=False, enforce_tempo=False, enforce_tempo_val=None):
  events = [ConversionEvent(ev, is_full_event=is_full_event) for ev in events]
  # print (events[:20])

  assert events[0].name == 'Bar'
  temp_notes = []
  temp_tempos = []
  temp_chords = []

  cur_bar = 0
  cur_position = 0

  for i in range(len(events)):
    if events[i].name == 'Bar':
      if i > 0:
        cur_bar += 1
    elif events[i].name == 'Beat':
      cur_position = int(events[i].value)
      assert cur_position >= 0 and cur_position < DEFAULT_FRACTION
    elif events[i].name == 'Tempo':
      temp_tempos.append(TempoEvent(
        int(events[i].value), cur_bar, cur_position
      ))
    elif 'Note_Pitch' in events[i].name and \
         (i+1) < len(events) and 'Note_Velocity' in events[i+1].name and \
         (i+2) < len(events) and 'Note_Duration' in events[i+2].name:
      # check if the 3 events are of the same instrument
      temp_notes.append(
        NoteEvent(
          pitch=int(events[i].value),
          bar=cur_bar, position=cur_position,
          duration=int(events[i+2].value), velocity=int(events[i+1].value)
        )
      )
    elif 'Chord' in events[i].name:
      temp_chords.append(
        ChordEvent(events[i].value, cur_bar, cur_position)
      )
    elif events[i].name in ['EOS', 'PAD']:
      continue

  # print (len(temp_tempos), len(temp_notes))
  midi_obj = miditoolkit.midi.parser.MidiFile()
  midi_obj.instruments = [
    miditoolkit.Instrument(program=0, is_drum=False, name='Piano')
  ]

  for n in temp_notes:
    midi_obj.instruments[0].notes.append(
      miditoolkit.Note(int(n.velocity), n.pitch, int(n.start_tick), int(n.start_tick + n.duration))
    )

  if enforce_tempo is False:
    for t in temp_tempos:
      midi_obj.tempo_changes.append(
        miditoolkit.TempoChange(t.tempo, int(t.start_tick))
      )
  else:
    if enforce_tempo_val is None:
      enforce_tempo_val = temp_tempos[1]
    for t in enforce_tempo_val:
      midi_obj.tempo_changes.append(
        miditoolkit.TempoChange(t.tempo, int(t.start_tick))
      )


  for c in temp_chords:
    midi_obj.markers.append(
      miditoolkit.Marker('Chord-{}'.format(c.chord_val), int(c.start_tick))
    )
  for b in range(cur_bar):
    midi_obj.markers.append(
      miditoolkit.Marker('Bar-{}'.format(b+1), int(DEFAULT_BAR_RESOL * b))
    )

  if output_midi_path is not None:
    midi_obj.dump(output_midi_path)

  if not return_first_tempo:
    return midi_obj
  else:
    return midi_obj, temp_tempos

In [None]:
###########################################
# little helpers
###########################################
def word2event(word_seq, idx2event):
  return [ idx2event[w] for w in word_seq ]

def get_beat_idx(event):
  return int(event.split('_')[-1])

In [None]:
###########################################
# sampling utilities
###########################################
def temperatured_softmax(logits, temperature):
  try:
    probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
    assert np.count_nonzero(np.isnan(probs)) == 0
  except:
    print ('overflow detected, use 128-bit')
    logits = logits.astype(np.float128)
    probs = np.exp(logits / temperature) / np.sum(np.exp(logits / temperature))
    probs = probs.astype(float)
  return probs

def nucleus(probs, p):
    probs /= sum(probs)
    sorted_probs = np.sort(probs)[::-1]
    sorted_index = np.argsort(probs)[::-1]
    cusum_sorted_probs = np.cumsum(sorted_probs)
    after_threshold = cusum_sorted_probs > p
    if sum(after_threshold) > 0:
        last_index = np.where(after_threshold)[0][1]
        candi_index = sorted_index[:last_index]
    else:
        candi_index = sorted_index[:3] # just assign a value
    candi_probs = np.array([probs[i] for i in candi_index], dtype=np.float64)
    candi_probs /= sum(candi_probs)
    word = np.random.choice(candi_index, size=1, p=candi_probs)[0]
    return word

In [None]:
device = config['training']['device']
########################################
# generation
########################################
def get_latent_embedding_fast(model, piece_data, use_sampling=False, sampling_var=0.):
  # reshape
  batch_inp = piece_data['enc_input'].permute(1, 0).long().to(device)
  batch_padding_mask = piece_data['enc_padding_mask'].bool().to(device)

  # get latent conditioning vectors
  with torch.no_grad():
    piece_latents = model.get_sampled_latent(
      batch_inp, padding_mask=batch_padding_mask,
      use_sampling=use_sampling, sampling_var=sampling_var
    )

  return piece_latents

def generate_on_latent_ctrl_vanilla_truncate(
        model, latents, rfreq_cls, polyph_cls, event2idx, idx2event,
        max_events=12800, primer=None,
        max_input_len=1280, truncate_len=512,
        nucleus_p=0.9, temperature=1.2
      ):
  latent_placeholder = torch.zeros(max_events, 1, latents.size(-1)).to(device)
  rfreq_placeholder = torch.zeros(max_events, 1, dtype=int).to(device)
  polyph_placeholder = torch.zeros(max_events, 1, dtype=int).to(device)
  print ('[info] rhythm cls: {} | polyph_cls: {}'.format(rfreq_cls, polyph_cls))

  if primer is None:
    generated = [event2idx['Bar_None']]
  else:
    generated = [event2idx[e] for e in primer]
    latent_placeholder[:len(generated), 0, :] = latents[0].squeeze(0)
    rfreq_placeholder[:len(generated), 0] = rfreq_cls[0]
    polyph_placeholder[:len(generated), 0] = polyph_cls[0]

  target_bars, generated_bars = latents.size(0), 0

  steps = 0
  time_st = time.time()
  cur_pos = 0
  failed_cnt = 0

  cur_input_len = len(generated)
  generated_final = deepcopy(generated)
  entropies = []

  while generated_bars < target_bars:
    if len(generated) == 1:
      dec_input = numpy_to_tensor([generated], device=device).long()
    else:
      dec_input = numpy_to_tensor([generated], device=device).permute(1, 0).long()

    latent_placeholder[len(generated)-1, 0, :] = latents[ generated_bars ]
    rfreq_placeholder[len(generated)-1, 0] = rfreq_cls[ generated_bars ]
    polyph_placeholder[len(generated)-1, 0] = polyph_cls[ generated_bars ]

    dec_seg_emb = latent_placeholder[:len(generated), :]
    dec_rfreq_cls = rfreq_placeholder[:len(generated), :]
    dec_polyph_cls = polyph_placeholder[:len(generated), :]

    # sampling
    with torch.no_grad():
      logits = model.generate(dec_input, dec_seg_emb, dec_rfreq_cls, dec_polyph_cls)
    logits = tensor_to_numpy(logits[0])
    probs = temperatured_softmax(logits, temperature)
    word = nucleus(probs, nucleus_p)
    word_event = idx2event[word]

    if 'Beat' in word_event:
      event_pos = get_beat_idx(word_event)
      if not event_pos >= cur_pos:
        failed_cnt += 1
        print ('[info] position not increasing, failed cnt:', failed_cnt)
        if failed_cnt >= 128:
          print ('[FATAL] model stuck, exiting ...')
          return generated
        continue
      else:
        cur_pos = event_pos
        failed_cnt = 0

    if 'Bar' in word_event:
      generated_bars += 1
      cur_pos = 0
      print ('[info] generated {} bars, #events = {}'.format(generated_bars, len(generated_final)))
    if word_event == 'PAD_None':
      continue

    if len(generated) > max_events or (word_event == 'EOS_None' and generated_bars == target_bars - 1):
      generated_bars += 1
      generated.append(event2idx['Bar_None'])
      print ('[info] gotten eos')
      break

    generated.append(word)
    generated_final.append(word)
    entropies.append(entropy(probs))

    cur_input_len += 1
    steps += 1

    assert cur_input_len == len(generated)
    if cur_input_len == max_input_len:
      generated = generated[-truncate_len:]
      latent_placeholder[:len(generated)-1, 0, :] = latent_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0, :]
      rfreq_placeholder[:len(generated)-1, 0] = rfreq_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0]
      polyph_placeholder[:len(generated)-1, 0] = polyph_placeholder[cur_input_len-truncate_len:cur_input_len-1, 0]

      print ('[info] reset context length: cur_len: {}, accumulated_len: {}, truncate_range: {} ~ {}'.format(
        cur_input_len, len(generated_final), cur_input_len-truncate_len, cur_input_len-1
      ))
      cur_input_len = len(generated)

  assert generated_bars == target_bars
  print ('-- generated events:', len(generated_final))
  print ('-- time elapsed: {:.2f} secs'.format(time.time() - time_st))
  return generated_final[:-1], time.time() - time_st, np.array(entropies)

In [None]:
########################################
# change attribute classes
########################################
def random_shift_attr_cls(n_samples, upper=4, lower=-3):
  return np.random.randint(lower, upper, (n_samples,))

# Dataset

In [None]:
IDX_TO_KEY = {
  0: 'A',
  1: 'A#',
  2: 'B',
  3: 'C',
  4: 'C#',
  5: 'D',
  6: 'D#',
  7: 'E',
  8: 'F',
  9: 'F#',
  10: 'G',
  11: 'G#'
}
KEY_TO_IDX = {
  v:k for k, v in IDX_TO_KEY.items()
}

def get_chord_tone(chord_event):
  tone = chord_event['value'].split('_')[0]
  return tone

def transpose_chord(chord_event, n_keys):
  if chord_event['value'] == 'N_N':
    return chord_event

  orig_tone = get_chord_tone(chord_event)
  orig_tone_idx = KEY_TO_IDX[orig_tone]
  new_tone_idx = (orig_tone_idx + 12 + n_keys) % 12
  new_chord_value = chord_event['value'].replace(
    '{}_'.format(orig_tone), '{}_'.format(IDX_TO_KEY[new_tone_idx])
  )
  new_chord_event = {'name': chord_event['name'], 'value': new_chord_value}
  # print ('keys={}. {} --> {}'.format(n_keys, chord_event, new_chord_event))

  return new_chord_event

def check_extreme_pitch(raw_events):
  low, high = 128, 0
  for ev in raw_events:
    if ev['name'] == 'Note_Pitch':
      low = min(low, int(ev['value']))
      high = max(high, int(ev['value']))

  return low, high

def transpose_events(raw_events, n_keys):
  transposed_raw_events = []

  for ev in raw_events:
    if ev['name'] == 'Note_Pitch':
      transposed_raw_events.append(
        {'name': ev['name'], 'value': ev['value'] + n_keys}
      )
    elif ev['name'] == 'Chord':
      transposed_raw_events.append(
        transpose_chord(ev, n_keys)
      )
    else:
      transposed_raw_events.append(ev)

  assert len(transposed_raw_events) == len(raw_events)
  return transposed_raw_events

def pickle_load(path):
  return pickle.load(open(path, 'rb'))

def convert_event(event_seq, event2idx, to_ndarr=True):
  if isinstance(event_seq[0], dict):
    event_seq = [event2idx['{}_{}'.format(e['name'], e['value'])] for e in event_seq]
  else:
    event_seq = [event2idx[e] for e in event_seq]

  if to_ndarr:
    return np.array(event_seq)
  else:
    return event_seq

In [None]:
class REMIFullSongTransformerDataset(Dataset):
  def __init__(self, data_dir,
                    vocab_file,
                    model_enc_seqlen=128,
                    model_dec_seqlen=1280,
                    model_max_bars=16,
                    pieces=[],
                    do_augment=True,
                    augment_range=range(-6, 7),
                    min_pitch=22,
                    max_pitch=107,
                    pad_to_same=True,
                    use_attr_cls=True,
                    appoint_st_bar=None,
                    dec_end_pad_value=None):
    self.vocab_file = vocab_file
    self.read_vocab()

    self.data_dir = data_dir
    self.pieces = pieces
    self.build_dataset()

    self.model_enc_seqlen = model_enc_seqlen
    self.model_dec_seqlen = model_dec_seqlen
    self.model_max_bars = model_max_bars

    self.do_augment = do_augment
    self.augment_range = augment_range
    self.min_pitch, self.max_pitch = min_pitch, max_pitch
    self.pad_to_same = pad_to_same
    self.use_attr_cls = use_attr_cls

    self.appoint_st_bar = appoint_st_bar
    if dec_end_pad_value is None:
      self.dec_end_pad_value = self.pad_token
    elif dec_end_pad_value == 'EOS':
      self.dec_end_pad_value = self.eos_token
    else:
      self.dec_end_pad_value = self.pad_token

  def read_vocab(self):
    vocab = pickle_load(self.vocab_file)[0]
    self.idx2event = pickle_load(self.vocab_file)[1]
    orig_vocab_size = len(vocab)
    self.event2idx = vocab
    self.bar_token = self.event2idx['Bar_None']
    self.eos_token = self.event2idx['EOS_None']
    self.pad_token = orig_vocab_size
    self.vocab_size = self.pad_token + 1

  def build_dataset(self):
    if not self.pieces:
      self.pieces = sorted(glob(os.path.join(self.data_dir, '*.pkl')) )
    else:
      self.pieces = sorted( [os.path.join(self.data_dir, p) for p in self.pieces] )

    self.piece_bar_pos = []

    for i, p in enumerate(self.pieces):
      bar_pos, p_evs = pickle_load(p)
      if not i % 200:
        print ('[preparing data] now at #{}'.format(i))
      if bar_pos[-1] == len(p_evs):
        print ('piece {}, got appended bar markers'.format(p))
        bar_pos = bar_pos[:-1]
      if len(p_evs) - bar_pos[-1] == 2:
        # got empty trailing bar
        bar_pos = bar_pos[:-1]

      bar_pos.append(len(p_evs))

      self.piece_bar_pos.append(bar_pos)

  def get_sample_from_file(self, piece_idx):
    piece_evs = pickle_load(self.pieces[piece_idx])[1]
    if len(self.piece_bar_pos[piece_idx]) > self.model_max_bars and self.appoint_st_bar is None:
      picked_st_bar = random.choice(
        range(len(self.piece_bar_pos[piece_idx]) - self.model_max_bars)
      )
    elif self.appoint_st_bar is not None and self.appoint_st_bar < len(self.piece_bar_pos[piece_idx]) - self.model_max_bars:
      picked_st_bar = self.appoint_st_bar
    else:
      picked_st_bar = 0

    piece_bar_pos = self.piece_bar_pos[piece_idx]

    if len(piece_bar_pos) > self.model_max_bars:
      piece_evs = piece_evs[ piece_bar_pos[picked_st_bar] : piece_bar_pos[picked_st_bar + self.model_max_bars] ]
      picked_bar_pos = np.array(piece_bar_pos[ picked_st_bar : picked_st_bar + self.model_max_bars ]) - piece_bar_pos[picked_st_bar]
      n_bars = self.model_max_bars
    else:
      picked_bar_pos = np.array(piece_bar_pos + [piece_bar_pos[-1]] * (self.model_max_bars - len(piece_bar_pos)))
      n_bars = len(piece_bar_pos)
      assert len(picked_bar_pos) == self.model_max_bars

    return piece_evs, picked_st_bar, picked_bar_pos, n_bars

  def pad_sequence(self, seq, maxlen, pad_value=None):
    if pad_value is None:
      pad_value = self.pad_token

    seq.extend( [pad_value for _ in range(maxlen- len(seq))] )

    return seq

  def pitch_augment(self, bar_events):
    bar_min_pitch, bar_max_pitch = check_extreme_pitch(bar_events)

    n_keys = random.choice(self.augment_range)
    while bar_min_pitch + n_keys < self.min_pitch or bar_max_pitch + n_keys > self.max_pitch:
      n_keys = random.choice(self.augment_range)

    augmented_bar_events = transpose_events(bar_events, n_keys)
    return augmented_bar_events

  def get_attr_classes(self, piece, st_bar):
    polyph_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/polyph', piece))[st_bar : st_bar + self.model_max_bars]
    rfreq_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/rhythm', piece))[st_bar : st_bar + self.model_max_bars]

    polyph_cls.extend([0 for _ in range(self.model_max_bars - len(polyph_cls))])
    rfreq_cls.extend([0 for _ in range(self.model_max_bars - len(rfreq_cls))])

    assert len(polyph_cls) == self.model_max_bars
    assert len(rfreq_cls) == self.model_max_bars

    return polyph_cls, rfreq_cls

  def get_encoder_input_data(self, bar_positions, bar_events):
    assert len(bar_positions) == self.model_max_bars + 1
    enc_padding_mask = np.ones((self.model_max_bars, self.model_enc_seqlen), dtype=bool)
    enc_padding_mask[:, :2] = False
    padded_enc_input = np.full((self.model_max_bars, self.model_enc_seqlen), dtype=int, fill_value=self.pad_token)
    enc_lens = np.zeros((self.model_max_bars,))

    for b, (st, ed) in enumerate(zip(bar_positions[:-1], bar_positions[1:])):
      enc_padding_mask[b, : (ed-st)] = False
      enc_lens[b] = ed - st
      within_bar_events = self.pad_sequence(bar_events[st : ed], self.model_enc_seqlen, self.pad_token)
      within_bar_events = np.array(within_bar_events)

      padded_enc_input[b, :] = within_bar_events[:self.model_enc_seqlen]

    return padded_enc_input, enc_padding_mask, enc_lens

  def __len__(self):
    return len(self.pieces)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
      idx = idx.tolist()

    bar_events, st_bar, bar_pos, enc_n_bars = self.get_sample_from_file(idx)
    if self.do_augment:
      bar_events = self.pitch_augment(bar_events)

    if self.use_attr_cls:
      polyph_cls, rfreq_cls = self.get_attr_classes(os.path.basename(self.pieces[idx]), st_bar)
      polyph_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int)
      rfreq_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int)
      for i, (b_st, b_ed) in enumerate(zip(bar_pos[:-1], bar_pos[1:])):
        polyph_cls_expanded[b_st:b_ed] = polyph_cls[i]
        rfreq_cls_expanded[b_st:b_ed] = rfreq_cls[i]
    else:
      polyph_cls, rfreq_cls = [0], [0]
      polyph_cls_expanded, rfreq_cls_expanded = [0], [0]

    bar_tokens = convert_event(bar_events, self.event2idx, to_ndarr=False)
    bar_pos = bar_pos.tolist() + [len(bar_tokens)]

    enc_inp, enc_padding_mask, enc_lens = self.get_encoder_input_data(bar_pos, bar_tokens)

    length = len(bar_tokens)
    if self.pad_to_same:
      inp = self.pad_sequence(bar_tokens, self.model_dec_seqlen + 1)
    else:
      inp = self.pad_sequence(bar_tokens, len(bar_tokens) + 1, pad_value=self.dec_end_pad_value)
    target = np.array(inp[1:], dtype=int)
    inp = np.array(inp[:-1], dtype=int)
    assert len(inp) == len(target)

    return {
      'id': idx,
      'piece_id': int(os.path.basename(self.pieces[idx]).replace('.pkl', '')),
      'st_bar_id': st_bar,
      'bar_pos': np.array(bar_pos, dtype=int),
      'enc_input': enc_inp,
      'dec_input': inp[:self.model_dec_seqlen],
      'dec_target': target[:self.model_dec_seqlen],
      'polyph_cls': polyph_cls_expanded,
      'rhymfreq_cls': rfreq_cls_expanded,
      'polyph_cls_bar': np.array(polyph_cls),
      'rhymfreq_cls_bar': np.array(rfreq_cls),
      'length': min(length, self.model_dec_seqlen),
      'enc_padding_mask': enc_padding_mask,
      'enc_length': enc_lens,
      'enc_n_bars': enc_n_bars
    }

In [None]:
dset = REMIFullSongTransformerDataset(
                                        config['data']['data_dir'], config['data']['vocab_path'],
                                        do_augment=True,
                                        model_enc_seqlen=config['data']['enc_seqlen'],
                                        model_dec_seqlen=config['data']['dec_seqlen'],
                                        model_max_bars=config['data']['max_bars'],
                                        pieces=pickle_load(config['data']['train_split']),
                                        pad_to_same=True
                                    )
dset_val = REMIFullSongTransformerDataset(
                                    config['data']['data_dir'], config['data']['vocab_path'],
                                    do_augment=False,
                                    model_enc_seqlen=config['data']['enc_seqlen'],
                                    model_dec_seqlen=config['data']['dec_seqlen'],
                                    model_max_bars=config['data']['max_bars'],
                                    pieces=pickle_load(config['data']['val_split']),
                                    pad_to_same=True
                                )
print ('[info]', '# training samples:', len(dset.pieces))

dloader = DataLoader(dset, batch_size=config['data']['batch_size'], shuffle=True, num_workers=8)
dloader_val = DataLoader(dset_val, batch_size=config['data']['batch_size'], shuffle=True, num_workers=8)

[preparing data] now at #0
[preparing data] now at #200
[preparing data] now at #400
[preparing data] now at #600
[preparing data] now at #800
[preparing data] now at #1000
[preparing data] now at #1200
[preparing data] now at #1400
[preparing data] now at #0
[info] # training samples: 1572




In [None]:
sample = next(iter(dloader))
sample['dec_input']

tensor([[  0,   1,  67,  ..., 332, 332, 332],
        [  0,   1, 142,  ..., 332, 332, 332],
        [  0,   1,  45,  ..., 332, 332, 332],
        [  0,   1, 296,  ..., 332, 332, 332]])

# Model Architecture

In [None]:
mconf = config['model']

device = config['training']['device']
trained_steps = config['training']['trained_steps']
lr_decay_steps = config['training']['lr_decay_steps']
lr_warmup_steps = config['training']['lr_warmup_steps']
no_kl_steps = config['training']['no_kl_steps']
kl_cycle_steps = config['training']['kl_cycle_steps']
kl_max_beta = config['training']['kl_max_beta']
free_bit_lambda = config['training']['free_bit_lambda']
max_lr, min_lr = config['training']['max_lr'], config['training']['min_lr']

ckpt_dir = config['training']['ckpt_dir']
params_dir = os.path.join(ckpt_dir, 'params/')
optim_dir = os.path.join(ckpt_dir, 'optim/')
pretrained_params_path = config['model']['pretrained_params_path']
pretrained_optim_path = config['model']['pretrained_optim_path']
ckpt_interval = config['training']['ckpt_interval']
log_interval = config['training']['log_interval']
val_interval = config['training']['val_interval']
constant_kl = config['training']['constant_kl']

In [None]:
def generate_causal_mask(seq_len):
    mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    mask.requires_grad = False
    return mask

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_embed, max_pos=20480):
        super(PositionalEncoding, self).__init__()
        self.d_embed = d_embed
        self.max_pos = max_pos

        pe = torch.zeros(max_pos, d_embed)
        position = torch.arange(0, max_pos, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, seq_len, bsz=None):
        pos_encoding = self.pe[:seq_len, :]

        if bsz is not None:
          pos_encoding = pos_encoding.expand(seq_len, bsz, -1)
        print(f'Positional Encoding: {pos_encoding.size()}')
        return pos_encoding

class TokenEmbedding(nn.Module):
  def __init__(self, n_token, d_embed, d_proj):
    super(TokenEmbedding, self).__init__()

    self.n_token = n_token
    self.d_embed = d_embed
    self.d_proj = d_proj
    self.emb_scale = d_proj ** 0.5

    self.emb_lookup = nn.Embedding(n_token, d_embed)
    if d_proj != d_embed:
      self.emb_proj = nn.Linear(d_embed, d_proj, bias=False)
    else:
      self.emb_proj = None

  def forward(self, inp_tokens):
    inp_emb = self.emb_lookup(inp_tokens)

    if self.emb_proj is not None:
      inp_emb = self.emb_proj(inp_emb)
    result = inp_emb.mul_(self.emb_scale) # inp_emb = inp_emb * self.emb_scale
    print(f'Token Emb: {result.size()}')
    return result

In [147]:
class MMDecoder(nn.Module):
  def __init__(self, n_layer, n_head, d_model, d_ff, d_seg_emb, dropout=0.1, activation='relu', cond_mode='in-attn'):
    super().__init__()
    self.n_layer = n_layer
    self.n_head = n_head
    self.d_model = d_model
    self.d_ff = d_ff
    self.d_seg_emb = 1280
    self.dropout = dropout
    self.emb_dropout = nn.Dropout(dropout).cuda()
    self.activation = activation
    self.cond_mode = cond_mode

    self.pe = PositionalEncoding(mconf['d_embed']).cuda()
    self.token_emb = TokenEmbedding(dset.vocab_size, mconf['d_embed'], mconf['dec_d_model']).cuda()

    self.decoder_layers = nn.ModuleList()
    for i in range(n_layer):
        self.decoder_layers.append(
        nn.TransformerEncoderLayer(d_model, n_head, d_ff, dropout, activation).cuda()
        )

    self.linear_softmax = nn.Sequential(
        nn.Linear(config['model']['dec_d_model'], dset.vocab_size),
        nn.Softmax()
    ).cuda() #size (batch_size, )

  def forward(self, x):
    dec_inp = x
    print('Positional Embedding')
    dec_token_emb = self.token_emb(dec_inp)
    dec_inp = self.emb_dropout(dec_token_emb) + self.pe(dec_inp.size(0))


    attn_mask = generate_causal_mask(x.size(0)).cuda()
    print(attn_mask.size())

    print('Decoder')
    print(x.size())
    out = dec_inp
    for i in range(self.n_layer):
        out = self.decoder_layers[i](out, src_mask=attn_mask)
        print(f'Layer {i}: {out.size()}')
    out = self.linear_softmax(out)
    print(f'Output layer: {out.size()}')
    return out


# Training and Validating

## Train

In [148]:
model = MMDecoder(n_layer=mconf['dec_n_layer'],
                  n_head=mconf['dec_n_head'],
                  d_model=mconf['dec_d_model'],
                  d_ff=mconf['dec_d_ff'],
                  d_seg_emb=1280,
                  dropout=0.1,
                  activation='relu',
                  cond_mode='in-attn')

In [149]:
model(sample['dec_input'].cuda()).size()

Positional Embedding
Token Emb: torch.Size([4, 1280, 256])
Positional Encoding: torch.Size([4, 1, 256])
torch.Size([4, 4])
Decoder
torch.Size([4, 1280])
Layer 0: torch.Size([4, 1280, 256])
Layer 1: torch.Size([4, 1280, 256])
Layer 2: torch.Size([4, 1280, 256])
Layer 3: torch.Size([4, 1280, 256])
Layer 4: torch.Size([4, 1280, 256])
Layer 5: torch.Size([4, 1280, 256])
Output layer: torch.Size([4, 1280, 333])


  input = module(input)


torch.Size([4, 1280, 333])

In [136]:
sample['dec_input'].size()

torch.Size([4, 1280])

In [137]:
sample['dec_target'].size()

torch.Size([4, 1280])

In [None]:
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print ('[info] model # params:', n_params)

[info] model # params: 4823808


In [155]:
loss_fn = torch.nn.CrossEntropyLoss()
opt_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(opt_params, lr=max_lr)

loss_results = []

for epoch in range(config['training']['max_epochs']):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(dloader, 0):
        input = data['dec_input'].permute(1, 0).cuda()
        dec_tgt = data['dec_target'].permute(1, 0).cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        dec_logits = model(input)
        loss = loss_fn(dec_logits.view(-1, dec_logits.size(-1)), dec_tgt.contiguous().view(-1)).float()
        loss_results.append(loss)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

[1;30;43mKết quả truyền trực tuyến bị cắt bớt đến 5000 dòng cuối.[0m
Layer 0: torch.Size([1280, 4, 256])
Layer 1: torch.Size([1280, 4, 256])
Layer 2: torch.Size([1280, 4, 256])
Layer 3: torch.Size([1280, 4, 256])
Layer 4: torch.Size([1280, 4, 256])
Layer 5: torch.Size([1280, 4, 256])
Output layer: torch.Size([1280, 4, 333])
Positional Embedding
Token Emb: torch.Size([1280, 4, 256])
Positional Encoding: torch.Size([1280, 1, 256])
torch.Size([1280, 1280])
Decoder
torch.Size([1280, 4])
Layer 0: torch.Size([1280, 4, 256])
Layer 1: torch.Size([1280, 4, 256])
Layer 2: torch.Size([1280, 4, 256])
Layer 3: torch.Size([1280, 4, 256])
Layer 4: torch.Size([1280, 4, 256])
Layer 5: torch.Size([1280, 4, 256])
Output layer: torch.Size([1280, 4, 333])
Positional Embedding
Token Emb: torch.Size([1280, 4, 256])
Positional Encoding: torch.Size([1280, 1, 256])
torch.Size([1280, 1280])
Decoder
torch.Size([1280, 4])
Layer 0: torch.Size([1280, 4, 256])
Layer 1: torch.Size([1280, 4, 256])
Layer 2: torch.Size