In [None]:
# 1. Mount Google Drive first, or download the files with gdown as shown below
# !pip install gdown
# !gdown https://drive.google.com/uc?id=<file_id>
# !gdown --folder https://drive.google.com/drive/folders/<file_id>


In [None]:
# 2. Download dataset/MuseMorphose-main.zip and unzip
# If running in Colab with Google Drive mounted:
!unzip "/content/drive/MyDrive/Capstone/dataset/MuseMorphose-main.zip" -d "./"

In [None]:
# # an example of musemorphose's data
# # download REMI-pop-1.7K dataset
# !wget -O remi_dataset.tar.gz https://zenodo.org/record/4782721/files/remi_dataset.tar.gz?download=1
# !tar xzvf remi_dataset.tar.gz
# !rm remi_dataset.tar.gz

# # compute attributes classes
# !python3 attributes.py

# rename MuseMorphose dataset
# !mv /content/MuseMorphose-main/remi_dataset /content/MuseMorphose-main/remi_dataset_mm

In [None]:
# 3. Untar our data, located in dataset/remi_dataset_0907_updated.tar.gz (the path below assumes Colab with Google Drive mounted)
!tar xzvf /content/drive/MyDrive/Capstone/dataset/remi_dataset_0907_updated.tar.gz

In [3]:
# 4. move our dataset to the folder
!mv /content/content/remi_dataset /content/MuseMorphose-main/remi_dataset

In [None]:
# # Musemorphose remi_vocab
# pickle_load("/content/MuseMorphose-main/pickles/remi_vocab.pkl")

In [None]:
# pickle_load("/content/content/compound-word-transformer-main/dataset/representations/uncond/remi/ailab17k_from-scratch_remi/dictionary.pkl")

In [4]:
# 5. rename/remove the original remi_vocab located in MuseMorphose-main/pickles/remi_vocab.pkl 
!mv /content/MuseMorphose-main/pickles/remi_vocab.pkl /content/MuseMorphose-main/pickles/remi_vocab_mm.pkl

In [5]:
# 7. add our remi_vocab, located in dataset/remi_vocab.pkl
!cp "/content/drive/MyDrive/Capstone/Deep Music Generation Group Folder/dataset/remi_vocab.pkl" "/content/MuseMorphose-main/pickles/"

In [28]:
# 8. obtain label files
# located in shared drive /dataset/remi_splitted/
import csv

def print_dict_ex(d):
  """Print a header line followed by at most the first three items of `d`.

  Items are printed in the dict's iteration order, one "key: value" per line.
  """
  print("Printing 3 lines of examples ...")
  for shown, (key, value) in enumerate(d.items()):
    # `shown` counts the items already printed
    if shown == 3:
      break
    print(f"{key}: {value}")

def load_tuples(fn):
  """Load a two-column (file_name, label) CSV into a dict of label lists.

  The first row is treated as a header and printed. Every following row must
  have a first column of the form "<track>_<recording>_<chunk>"; the track
  part is dropped and the label (second column) is appended to the list
  stored under the key "<recording>_<chunk>".

  Args:
    fn: path of the CSV file to read.

  Returns:
    dict mapping "<recording>_<chunk>" to the list of labels found for it,
    in file order.

  Raises:
    AssertionError: if a row does not have exactly two columns.
    ValueError: if a file name does not split into exactly three "_" parts.
  """
  fn_map = {}
  n_rows = 0
  # newline='' hands line-ending handling to the csv module, so quoted labels
  # that contain embedded newlines are read back unchanged.
  with open(fn, newline='') as csv_file:
    csv_reader = csv.reader(csv_file)

    header = next(csv_reader, None)
    if header is not None:
      assert len(header) == 2
      print(header)

    for row in csv_reader:
      assert len(row) == 2
      _, recording, chunk = row[0].split("_")
      fn_map.setdefault(f"{recording}_{chunk}", []).append(row[1])
      n_rows += 1

    print(f"Processed {n_rows} lines of tuples.")
  return fn_map

# All label CSVs live in the same shared-drive folder
LABEL_DIR = "/content/drive/MyDrive/Capstone/Deep Music Generation Group Folder/dataset/remi_splitted"

# Negative tuples for each split
neg_test = load_tuples(f"{LABEL_DIR}/neg_tuples_test.csv")
neg_train = load_tuples(f"{LABEL_DIR}/neg_tuples_train.csv")
neg_val = load_tuples(f"{LABEL_DIR}/neg_tuples_val.csv")

# Positive tuples for each split
pos_test = load_tuples(f"{LABEL_DIR}/pos_tuples_test.csv")
pos_train = load_tuples(f"{LABEL_DIR}/pos_tuples_train.csv")
pos_val = load_tuples(f"{LABEL_DIR}/pos_tuples_val.csv")

print_dict_ex(neg_test)

['file_name', 'label']
Processed 63240 lines of tuples.
['file_name', 'label']
Processed 574998 lines of tuples.
['file_name', 'label']
Processed 80007 lines of tuples.
['file_name', 'label']
Processed 63240 lines of tuples.
['file_name', 'label']
Processed 574998 lines of tuples.
['file_name', 'label']
Processed 80007 lines of tuples.
Printing 3 lines of examples ...
23b932508e50b703a605bffda6324c01_0: ['Doo Wop', 'I can honestly say that I believe &quot;Dirt&quot; is a masterpiece.Forget the grunge stereotype, this is a brutal beauty.My favorite tracks are &quot;Them Bones&quot; &quot;Rooster&quot; and &quot;Sickman&quot;.  Talk  about a wall of sound !']
23b932508e50b703a605bffda6324c01_1: ['1979: the fag end of punk was dwindling to a gradual stop as Britain, it seemed, reached a cataclysmic collapse.  For those not in the UK during the infamous "winter of discontent" of 1978-9, Britain was on the brink of meltdown as everyone went on strike, the bins went unemptied, the dead unbur

In [7]:
%cd /content/MuseMorphose-main

/content/MuseMorphose-main


# Dataloader

In [8]:
import os, pickle, random
from glob import glob

import torch
import numpy as np

from torch.utils.data import Dataset, DataLoader

# index -> key name; index 0 is 'A' and each step is the next name in the list
IDX_TO_KEY = dict(enumerate(
  ['A', 'A#', 'B', 'C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#']
))
# key name -> index (inverse of IDX_TO_KEY)
KEY_TO_IDX = {name: idx for idx, name in IDX_TO_KEY.items()}


In [9]:
def get_chord_tone(chord_event):
  """Return the part of the event's 'value' that precedes its first underscore."""
  tone, _, _ = chord_event['value'].partition('_')
  return tone

def transpose_chord(chord_event, n_keys):
  """Return a chord event whose tone is shifted by `n_keys` steps in IDX_TO_KEY.

  The 'N_N' chord is returned as-is (same object). Otherwise a new
  {'name', 'value'} dict is built; the input event is not modified.
  """
  value = chord_event['value']
  if value == 'N_N':
    return chord_event

  old_tone = get_chord_tone(chord_event)
  # Python's % with a positive modulus is never negative, so downward shifts wrap correctly
  new_tone = IDX_TO_KEY[(KEY_TO_IDX[old_tone] + n_keys) % 12]

  return {
    'name': chord_event['name'],
    'value': value.replace(old_tone + '_', new_tone + '_'),
  }

In [10]:
def check_extreme_pitch(raw_events):
  """Return (lowest, highest) pitch over the 'Note_Pitch' events in `raw_events`.

  Pitch values are converted with int(). The lowest value is capped at 128
  and the highest floored at 0, so a sequence without any 'Note_Pitch'
  event yields (128, 0).
  """
  pitches = [int(ev['value']) for ev in raw_events if ev['name'] == 'Note_Pitch']
  return min([128] + pitches), max([0] + pitches)

def transpose_events(raw_events, n_keys):
  """Return a new event list with pitches and chords shifted by `n_keys`.

  Args:
    raw_events: sequence of {'name', 'value'} event dicts.
    n_keys: number of semitones to shift by (may be negative).

  Returns:
    A new list of the same length. 'Note_Pitch' events are replaced by new
    dicts with an int value of pitch + n_keys, 'Chord' events are passed
    through `transpose_chord`, and every other event is reused unchanged.
    The input events are not modified.
  """
  transposed_raw_events = []

  for ev in raw_events:
    if ev['name'] == 'Note_Pitch':
      # int() matches check_extreme_pitch, which also accepts pitch values
      # stored as strings; without it a string pitch raises TypeError here.
      transposed_raw_events.append(
        {'name': ev['name'], 'value': int(ev['value']) + n_keys}
      )
    elif ev['name'] == 'Chord':
      transposed_raw_events.append(transpose_chord(ev, n_keys))
    else:
      transposed_raw_events.append(ev)

  return transposed_raw_events

In [11]:
def pickle_load(path):
  """Load and return the object pickled in the file at `path`.

  The file is opened in a `with` block so the handle is closed as soon as
  loading finishes (or fails).

  NOTE(review): unpickling can execute arbitrary code; only use this on
  files from a trusted source.
  """
  with open(path, 'rb') as f:
    return pickle.load(f)

def convert_event(event_seq, event2idx, to_ndarr=True):
  """Map a sequence of events to their vocabulary indices.

  Args:
    event_seq: sequence of events. If the first element is a dict, every
      element is looked up under the key "<name>_<value>"; otherwise the
      elements themselves are used as keys.
    event2idx: mapping from event key to index.
    to_ndarr: return a numpy array when True, a plain list otherwise.

  Returns:
    The indices in input order. An empty input gives an empty result
    instead of raising IndexError.

  Raises:
    KeyError: if an event key is missing from `event2idx`.
  """
  if len(event_seq) == 0:
    return np.array([], dtype=int) if to_ndarr else []

  if isinstance(event_seq[0], dict):
    event_seq = [event2idx['{}_{}'.format(e['name'], e['value'])] for e in event_seq]
  else:
    event_seq = [event2idx[e] for e in event_seq]

  return np.array(event_seq) if to_ndarr else event_seq

In [29]:
# NOTE(review): `F` is not referenced anywhere in the visible code; this looks
# like an accidental editor auto-import — confirm nothing later uses it, then remove.
from torch.autograd.grad_mode import F
class REMIFullSongTransformerDataset(Dataset):
  """Dataset of REMI event segments paired with text labels.

  Every item is one (piece, label) pair produced by `match_labels`, so a
  piece appears once for each label found for it in `pos_map` / `neg_map`.
  `__getitem__` returns a segment of at most `model_max_bars` bars together
  with per-bar encoder inputs, decoder input/target sequences and
  (optionally) attribute classes.
  """

  def __init__(self, data_dir, vocab_file, 
               model_enc_seqlen=128, model_dec_seqlen=1280, model_max_bars=16,
               pieces=[], do_augment=True, augment_range=range(-6, 7), 
               min_pitch=22, max_pitch=107, pad_to_same=True, use_attr_cls=True,
               appoint_st_bar=None, dec_end_pad_value=None, pos_map=None, neg_map=None):
    """Load the vocabulary, index every piece and match pieces with labels.

    Args:
      data_dir: directory holding the piece pickles (and the 'attr_cls/polyph'
        and 'attr_cls/rhythm' sub-folders read when `use_attr_cls` is set).
      vocab_file: pickle whose element 0 is the event->index mapping and
        element 1 the index->event mapping.
      model_enc_seqlen: per-bar encoder sequence length.
      model_dec_seqlen: decoder sequence length.
      model_max_bars: maximum number of bars per sample.
      pieces: file names (relative to `data_dir`) to use; when empty, every
        '*.pkl' directly inside `data_dir` is used.
      do_augment: apply random transposition in `__getitem__`.
      augment_range: candidate transposition amounts.
      min_pitch, max_pitch: bounds the transposed segment must stay within.
      pad_to_same: pad decoder sequences to `model_dec_seqlen + 1` tokens.
      use_attr_cls: load per-bar attribute classes from `data_dir`.
      appoint_st_bar: fixed start bar; used instead of a random one when it
        leaves room for `model_max_bars` bars.
      dec_end_pad_value: 'EOS' selects the EOS token as the end padding value
        used when `pad_to_same` is False; anything else selects the pad token.
      pos_map, neg_map: dicts mapping a piece's base file name (without
        extension) to its positive / negative labels.

    Raises:
      ValueError: if `pos_map` or `neg_map` is None.
    """
    # Check the label maps up front so a bad call fails before any piece is loaded
    if pos_map is None or neg_map is None:
      raise ValueError("pos_map and neg_map cannot be None.")

    self.vocab_file = vocab_file
    self.read_vocab()

    self.data_dir = data_dir
    self.pieces = pieces
    self.build_dataset()

    self.model_enc_seqlen = model_enc_seqlen
    self.model_dec_seqlen = model_dec_seqlen
    self.model_max_bars = model_max_bars

    self.do_augment = do_augment
    self.augment_range = augment_range
    self.min_pitch, self.max_pitch = min_pitch, max_pitch
    self.pad_to_same = pad_to_same
    self.use_attr_cls = use_attr_cls

    self.appoint_st_bar = appoint_st_bar
    # Only the literal 'EOS' selects the EOS token; every other value means "pad"
    if dec_end_pad_value == 'EOS':
      self.dec_end_pad_value = self.eos_token
    else:
      self.dec_end_pad_value = self.pad_token

    # added for our labels
    self.matched_pieces = self.match_labels(pos_map, neg_map)

  def match_labels(self, pos_map, neg_map):
    """Build the (piece_idx, file_name, is_positive, label) tuples to iterate over.

    For each piece, all of its positive labels come first, then its negative
    labels. Pieces found in neither map produce no entries.
    """
    data_groups = []
    for idx, f in enumerate(self.pieces):
      # base file name up to its first '.', e.g. "dir/abc_0.pkl" -> "abc_0"
      fn = f.split("/")[-1].split(".")[0]
      for label in pos_map.get(fn, []):
        data_groups.append((idx, fn, True, label))
      for label in neg_map.get(fn, []):
        data_groups.append((idx, fn, False, label))
    return data_groups

  def read_vocab(self):
    """Load the vocabulary pickle and derive the special token ids.

    The pad token is given the first index past the loaded vocabulary, so
    `vocab_size` is one larger than the number of loaded events.
    """
    # Read the file once; it holds (event2idx, idx2event)
    vocab_data = pickle_load(self.vocab_file)
    self.event2idx = vocab_data[0]
    self.idx2event = vocab_data[1]
    self.bar_token = self.event2idx['Bar_None']
    self.eos_token = self.event2idx['EOS_None']
    self.pad_token = len(self.event2idx)
    self.vocab_size = self.pad_token + 1
  
  def build_dataset(self):
    """Resolve the piece paths and record the bar boundaries of each piece.

    After this call `self.pieces` is a sorted list of paths and
    `self.piece_bar_pos[i]` holds the bar start offsets of piece i followed
    by the total number of events in that piece.
    """
    if not self.pieces:
      self.pieces = sorted( glob(os.path.join(self.data_dir, '*.pkl')) )
    else:
      self.pieces = sorted( [os.path.join(self.data_dir, p) for p in self.pieces] )

    self.piece_bar_pos = []

    for i, p in enumerate(self.pieces):
      bar_pos, p_evs = pickle_load(p)
      if not i % 200:
        print ('[preparing data] now at #{}'.format(i))
      if bar_pos[-1] == len(p_evs):
        # the stored positions already end with the end-of-piece offset
        print ('piece {}, got appended bar markers'.format(p))
        bar_pos = bar_pos[:-1]
      if len(p_evs) - bar_pos[-1] == 2:
        # got empty trailing bar
        bar_pos = bar_pos[:-1]

      # always terminate with the end-of-piece offset
      bar_pos.append(len(p_evs))

      self.piece_bar_pos.append(bar_pos)

  def get_sample_from_file(self, piece_idx):
    """Pick a segment of at most `model_max_bars` bars from one piece.

    Returns:
      (events, start_bar, bar_positions, n_bars) where `bar_positions` is an
      array of `model_max_bars` offsets relative to the segment start.
    """
    piece_evs = pickle_load(self.pieces[piece_idx])[1]
    if len(self.piece_bar_pos[piece_idx]) > self.model_max_bars and self.appoint_st_bar is None:
      picked_st_bar = random.choice(
        range(len(self.piece_bar_pos[piece_idx]) - self.model_max_bars)
      )
    elif self.appoint_st_bar is not None and self.appoint_st_bar < len(self.piece_bar_pos[piece_idx]) - self.model_max_bars:
      picked_st_bar = self.appoint_st_bar
    else:
      picked_st_bar = 0

    piece_bar_pos = self.piece_bar_pos[piece_idx]

    if len(piece_bar_pos) > self.model_max_bars:
      piece_evs = piece_evs[ piece_bar_pos[picked_st_bar] : piece_bar_pos[picked_st_bar + self.model_max_bars] ]
      picked_bar_pos = np.array(piece_bar_pos[ picked_st_bar : picked_st_bar + self.model_max_bars ]) - piece_bar_pos[picked_st_bar]
      n_bars = self.model_max_bars
    else:
      # short piece: repeat the final offset so the array still has model_max_bars entries
      picked_bar_pos = np.array(piece_bar_pos + [piece_bar_pos[-1]] * (self.model_max_bars - len(piece_bar_pos)))
      n_bars = len(piece_bar_pos)
      assert len(picked_bar_pos) == self.model_max_bars

    return piece_evs, picked_st_bar, picked_bar_pos, n_bars

  def pad_sequence(self, seq, maxlen, pad_value=None):
    """Extend `seq` in place with `pad_value` up to `maxlen` items and return it.

    `pad_value` defaults to the pad token. Sequences already at least
    `maxlen` long are returned unchanged (no truncation).
    """
    if pad_value is None:
      pad_value = self.pad_token

    seq.extend([pad_value] * (maxlen - len(seq)))

    return seq

  def pitch_augment(self, bar_events):
    """Transpose the events by a random amount drawn from `augment_range`.

    Amounts are re-drawn until the segment's lowest and highest pitch stay
    within [min_pitch, max_pitch]. If no amount in `augment_range` satisfies
    that, the events are returned untransposed instead of looping forever.
    """
    bar_min_pitch, bar_max_pitch = check_extreme_pitch(bar_events)

    def in_bounds(shift):
      return (bar_min_pitch + shift >= self.min_pitch
              and bar_max_pitch + shift <= self.max_pitch)

    # Without at least one admissible shift the rejection loop below never ends
    if not any(in_bounds(shift) for shift in self.augment_range):
      return transpose_events(bar_events, 0)

    n_keys = random.choice(self.augment_range)
    while not in_bounds(n_keys):
      n_keys = random.choice(self.augment_range)

    augmented_bar_events = transpose_events(bar_events, n_keys)
    return augmented_bar_events

  def get_attr_classes(self, piece, st_bar):
    """Load the per-bar polyphony / rhythm classes for the picked segment.

    Both lists are cut to the bars [st_bar, st_bar + model_max_bars) and
    padded with 0 up to `model_max_bars` entries.
    """
    polyph_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/polyph', piece))[st_bar : st_bar + self.model_max_bars]
    rfreq_cls = pickle_load(os.path.join(self.data_dir, 'attr_cls/rhythm', piece))[st_bar : st_bar + self.model_max_bars]

    polyph_cls.extend([0 for _ in range(self.model_max_bars - len(polyph_cls))])
    rfreq_cls.extend([0 for _ in range(self.model_max_bars - len(rfreq_cls))])

    assert len(polyph_cls) == self.model_max_bars
    assert len(rfreq_cls) == self.model_max_bars

    return polyph_cls, rfreq_cls

  def get_encoder_input_data(self, bar_positions, bar_events):
    """Split the token sequence into per-bar, fixed-length encoder inputs.

    Args:
      bar_positions: `model_max_bars + 1` offsets delimiting the bars.
      bar_events: token ids of the whole segment.

    Returns:
      (padded_enc_input, enc_padding_mask, enc_lens): a
      (model_max_bars, model_enc_seqlen) int array padded with the pad token,
      a bool mask of the same shape that is False on real tokens, and the
      untruncated token count of each bar.
    """
    assert len(bar_positions) == self.model_max_bars + 1
    enc_padding_mask = np.ones((self.model_max_bars, self.model_enc_seqlen), dtype=bool)
    # the first two positions of every bar are always left unmasked
    enc_padding_mask[:, :2] = False
    padded_enc_input = np.full((self.model_max_bars, self.model_enc_seqlen), dtype=int, fill_value=self.pad_token)
    enc_lens = np.zeros((self.model_max_bars,))

    for b, (st, ed) in enumerate(zip(bar_positions[:-1], bar_positions[1:])):
      enc_padding_mask[b, : (ed-st)] = False
      enc_lens[b] = ed - st
      # the slice is a fresh list, so padding it does not touch bar_events
      within_bar_events = self.pad_sequence(bar_events[st : ed], self.model_enc_seqlen, self.pad_token)
      within_bar_events = np.array(within_bar_events)

      # bars longer than model_enc_seqlen are truncated
      padded_enc_input[b, :] = within_bar_events[:self.model_enc_seqlen]

    return padded_enc_input, enc_padding_mask, enc_lens

  def __len__(self):
    """Number of (piece, label) pairs, not the number of pieces."""
    return len(self.matched_pieces)

  def __getitem__(self, idx):
    """Return the sample dict for the (piece, label) pair at position `idx`.

    The returned 'id' is the index of the underlying piece in `self.pieces`,
    so several items can share the same 'id' while differing in
    'text_label' / 'pos'.
    """
    if torch.is_tensor(idx):
      idx = idx.tolist()

    # `idx` addresses self.matched_pieces; from here on it is replaced by the
    # index of the underlying piece in self.pieces.
    matched_data = self.matched_pieces[idx]
    idx = matched_data[0]
    text_labels = matched_data[3]
    pos = matched_data[2]

    bar_events, st_bar, bar_pos, enc_n_bars = self.get_sample_from_file(idx)
    if self.do_augment:
      bar_events = self.pitch_augment(bar_events)

    if self.use_attr_cls:
      polyph_cls, rfreq_cls = self.get_attr_classes(os.path.basename(self.pieces[idx]), st_bar)
      # broadcast each bar's class over that bar's token positions
      polyph_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int)
      rfreq_cls_expanded = np.zeros((self.model_dec_seqlen,), dtype=int)
      for i, (b_st, b_ed) in enumerate(zip(bar_pos[:-1], bar_pos[1:])):
        polyph_cls_expanded[b_st:b_ed] = polyph_cls[i]
        rfreq_cls_expanded[b_st:b_ed] = rfreq_cls[i]
    else:
      polyph_cls, rfreq_cls = [0], [0]
      polyph_cls_expanded, rfreq_cls_expanded = [0], [0]

    bar_tokens = convert_event(bar_events, self.event2idx, to_ndarr=False)
    bar_pos = bar_pos.tolist() + [len(bar_tokens)]

    enc_inp, enc_padding_mask, enc_lens = self.get_encoder_input_data(bar_pos, bar_tokens)

    # record the true length before pad_sequence extends bar_tokens in place
    length = len(bar_tokens)
    if self.pad_to_same:
      inp = self.pad_sequence(bar_tokens, self.model_dec_seqlen + 1) 
    else:
      inp = self.pad_sequence(bar_tokens, len(bar_tokens) + 1, pad_value=self.dec_end_pad_value)
    # target is the input shifted left by one token
    target = np.array(inp[1:], dtype=int)
    inp = np.array(inp[:-1], dtype=int)
    assert len(inp) == len(target)

    return {
      'id': idx,
      # 'piece_id': int(os.path.basename(self.pieces[idx]).replace('.pkl', '')),
      'piece_id': os.path.basename(self.pieces[idx]).replace('.pkl', ''),
      'st_bar_id': st_bar,
      'bar_pos': np.array(bar_pos, dtype=int),
      'enc_input': enc_inp,
      'dec_input': inp[:self.model_dec_seqlen],
      'dec_target': target[:self.model_dec_seqlen],
      'polyph_cls': polyph_cls_expanded,
      'rhymfreq_cls': rfreq_cls_expanded,
      'polyph_cls_bar': np.array(polyph_cls),
      'rhymfreq_cls_bar': np.array(rfreq_cls),
      'length': min(length, self.model_dec_seqlen),
      'enc_padding_mask': enc_padding_mask,
      'enc_length': enc_lens,
      'enc_n_bars': enc_n_bars,
      'text_label': text_labels,
      'pos': pos,  # whether this is a positive tuple
    }

In [32]:
# Swap pos_map / neg_map for the val or test maps to build val_dset / test_dset
train_dset = REMIFullSongTransformerDataset(
  data_dir='./remi_dataset',
  vocab_file='./pickles/remi_vocab.pkl',
  model_enc_seqlen=128,
  model_dec_seqlen=1280,
  model_max_bars=16,
  do_augment=True,
  min_pitch=22,
  max_pitch=107,
  use_attr_cls=True,
  pos_map=pos_train,
  neg_map=neg_train,
)
# Special token ids and vocabulary size, then the number of (piece, label) samples
print(train_dset.bar_token, train_dset.pad_token, train_dset.vocab_size)
print('length:', len(train_dset))

# for i in random.sample(range(len(dset)), 100):
# for i in range(len(dset)):
#   sample = dset[i]
  # print (i, len(sample['bar_pos']), sample['bar_pos'])
  # print (i)
  # print ('******* ----------- *******')
  # print ('piece: {}, st_bar: {}'.format(sample['piece_id'], sample['st_bar_id']))
  # print (sample['enc_input'][:8, :16])
  # print (sample['dec_input'][:16])
  # print (sample['dec_target'][:16])
  # print (sample['enc_padding_mask'][:32, :16])
  # print (sample['length'])


[preparing data] now at #0
[preparing data] now at #200
[preparing data] now at #400
[preparing data] now at #600
[preparing data] now at #800
[preparing data] now at #1000
[preparing data] now at #1200
[preparing data] now at #1400
[preparing data] now at #1600
[preparing data] now at #1800
[preparing data] now at #2000
[preparing data] now at #2200
[preparing data] now at #2400
[preparing data] now at #2600
[preparing data] now at #2800
[preparing data] now at #3000
[preparing data] now at #3200
[preparing data] now at #3400
[preparing data] now at #3600
[preparing data] now at #3800
[preparing data] now at #4000
[preparing data] now at #4200
[preparing data] now at #4400
[preparing data] now at #4600
[preparing data] now at #4800
[preparing data] now at #5000
[preparing data] now at #5200
[preparing data] now at #5400
[preparing data] now at #5600
[preparing data] now at #5800
[preparing data] now at #6000
[preparing data] now at #6200
[preparing data] now at #6400
[preparing data] 

In [35]:
# Never request more workers than there are CPUs: asking for more makes
# DataLoader emit a warning about the worker count.
n_workers = min(24, os.cpu_count() or 1)
train_dloader = DataLoader(train_dset, batch_size=4, shuffle=False, num_workers=n_workers)

# Smoke test: look at the first batch only. Tensor fields are summarised by
# dtype and shape; everything else (e.g. strings) is printed as-is.
for batch in train_dloader:
  for k, v in batch.items():
    if torch.is_tensor(v):
      print (k, ':', v.dtype, v.size())
    else:
      print(k, ':', v)
  break

  cpuset_checked))


id : torch.int64 torch.Size([4])
piece_id : ['000869391f4fc43161b1d30676cb3fa8_0', '000869391f4fc43161b1d30676cb3fa8_0', '000869391f4fc43161b1d30676cb3fa8_0', '000869391f4fc43161b1d30676cb3fa8_0']
st_bar_id : torch.int64 torch.Size([4])
bar_pos : torch.int64 torch.Size([4, 17])
enc_input : torch.int64 torch.Size([4, 16, 128])
dec_input : torch.int64 torch.Size([4, 1280])
dec_target : torch.int64 torch.Size([4, 1280])
polyph_cls : torch.int64 torch.Size([4, 1280])
rhymfreq_cls : torch.int64 torch.Size([4, 1280])
polyph_cls_bar : torch.int64 torch.Size([4, 16])
rhymfreq_cls_bar : torch.int64 torch.Size([4, 16])
length : torch.int64 torch.Size([4])
enc_padding_mask : torch.bool torch.Size([4, 16, 128])
enc_length : torch.float64 torch.Size([4, 16])
enc_n_bars : torch.int64 torch.Size([4])
text_label : ["'B*Witched' is my first pop album I brought, and I am impressed. Unlike other pop albums, this one has a strong cultural feel in most of the songs. This is what makes the album so original