In [1]:
%cd ..
%load_ext autoreload
%autoreload 2

/home/teo/userdata/SejongMusic


In [2]:
from sejong_music.yeominrak_processing import AlignedScore


class MaskedDataset:
    """Index-based wrapper around an ``AlignedScore``.

    Both ``len()`` and item access are forwarded to the wrapped object, so it
    can be consumed wherever a map-style dataset (``__len__`` + ``__getitem__``)
    is expected.

    Args:
        aligned_score: The object to wrap. It must support ``len()`` and
            integer indexing (assumed to hold for ``AlignedScore`` — confirm).
    """

    # The annotation is a string so the class body does not need
    # ``AlignedScore`` to be resolvable at definition time.
    def __init__(self, aligned_score: "AlignedScore"):
        self.aligned_score = aligned_score

    def __len__(self):
        """Return the number of items in the wrapped score."""
        # Fix: this used to read ``self.masked_score``, an attribute that is
        # never assigned, so ``len(dataset)`` raised AttributeError.
        return len(self.aligned_score)

    def __getitem__(self, idx):
        """Return item ``idx`` of the wrapped score, unchanged."""
        return self.aligned_score[idx]
      
# Wrap a freshly constructed AlignedScore (default constructor arguments) for indexed access.
dataset = MaskedDataset(aligned_score=AlignedScore())

In [3]:
# Stand-alone AlignedScore instance used by the exploratory cells below.
aligned_score = AlignedScore()

In [4]:
# Number of entries in measure_features[1] (presumably the measures of part/era index 1 — confirm).
len(aligned_score.measure_features[1])

160

In [42]:
import numpy as np
class BeatPosConverter:
    """Map a beat position from one era's beat grid onto another era's grid.

    Era indices are partitioned into groups (``era_group``). Eras in the same
    group are treated as sharing a grid, so positions pass through untouched.
    Between two *adjacent* groups the position is mapped by piecewise-linear
    interpolation over a table of anchor points; ``anchor_points[g]`` relates
    group ``g`` (dict keys) to group ``g + 1`` (dict values).
    """

    def __init__(self) -> None:
        self.era_group = [(0, 1), (2, 3, 4, 5), (6, 7)]

        # Reverse lookup: era index -> index of the group that contains it.
        self.era2group = {}
        for group_idx, members in enumerate(self.era_group):
            for era in members:
                self.era2group[era] = group_idx

        self.anchor_points = [
            {0.0: 0.0, 2.0: 2.0, 3.0: 3.0, 5.0: 6.0, 6.0: 8.0, 7.0: 9.0, 8.0: 10.0},
            {0.0: 0.0, 1.0: 1.0, 2.0: 2.0, 3.0: 3.0, 6.0: 5.0, 7.0: 6.0, 8.0: 7.0, 9.0: 8.0, 10.0: 10.0},
        ]
        # Measure length for each era index, used to wrap/unwrap positions
        # that lie beyond the first measure.
        self.measure_len = [8, 8, 10, 10, 10, 10, 10, 10]

    def __call__(self, beat_pos: float, org_era_idx: int, target_era_idx: int):
        """Convert ``beat_pos`` from era ``org_era_idx`` to era ``target_era_idx``.

        Args:
            beat_pos: Position measured from the start of the first measure;
                it may span several measures.
            org_era_idx: Era index the position is expressed in.
            target_era_idx: Era index to express the position in.

        Returns:
            ``beat_pos`` itself when both eras belong to the same group,
            otherwise the interpolated position in the target era, offset by
            the number of whole measures that were stripped beforehand.

        Raises:
            KeyError: If an era index is not listed in ``era_group``.
            AssertionError: If the two eras are more than one group apart.
        """
        src_group = self.era2group[org_era_idx]
        dst_group = self.era2group[target_era_idx]
        if src_group == dst_group:
            return beat_pos
        assert abs(src_group - dst_group) == 1, f"Only one group difference is allowed. {org_era_idx} -> {target_era_idx}"

        # Strip whole measures so that the anchor table, which describes a
        # single measure, can be applied; they are added back at the end.
        src_len = self.measure_len[org_era_idx]
        num_measures = 0
        if beat_pos >= src_len:
            num_measures = int(beat_pos // src_len)
            beat_pos -= num_measures * src_len

        # The table is stored in the lower -> higher group direction.
        table = self.anchor_points[min(src_group, dst_group)]
        lower_side = np.array(list(table.keys()))
        upper_side = np.array(list(table.values()))
        if src_group > dst_group:
            # Going from the higher group down: read the table backwards.
            xp, fp = upper_side, lower_side
        else:
            xp, fp = lower_side, upper_side

        within_measure = np.interp(beat_pos, xp, fp)
        return within_measure + num_measures * self.measure_len[target_era_idx]
    
                          
    
b_pos_conv = BeatPosConverter()

# Sanity checks for the converter.
# This call used to be a bare expression whose result was silently discarded;
# it is now checked. 9.0 is an anchor key of the second table and maps to 8.0.
assert b_pos_conv(9.0, 5, 7) == 8.0, b_pos_conv(9.0, 5, 7)

# Position 0.0 is preserved, both within a group and across adjacent groups.
assert b_pos_conv(0.0, 0, 1) == 0.0
assert b_pos_conv(0.0, 1, 0) == 0.0
assert b_pos_conv(0.0, 0, 2) == 0.0
assert b_pos_conv(0.0, 2, 0) == 0.0
# Positions between anchors are linearly interpolated, in either direction.
assert b_pos_conv(4.0, 1, 5) == 4.5, b_pos_conv(4.0, 1, 5)
assert abs(b_pos_conv(4.0, 5, 1) - 3.66666) < 0.0001
assert abs(b_pos_conv(5.0, 5, 7) - 4.33333) < 0.0001
assert b_pos_conv(5.0, 7, 5) == 6.0
assert abs(b_pos_conv(4.33333, 7, 5) - 5.0) < 0.0001

In [49]:


import numpy as np

# Walk two parts in parallel over the same measure window and print the note
# pairs that start at the same (converted) position and share the same value
# in field 1 (presumably the pitch — confirm against the feature layout).
front_idx, back_idx = 5, 7
start_measure, end_measure = 0, 2

# Flatten the selected measures of each part into one list of notes.
front_part = []
for measure in aligned_score.measure_features[front_idx][start_measure:end_measure]:
    front_part.extend(measure)
back_part = []
for measure in aligned_score.measure_features[back_idx][start_measure:end_measure]:
    back_part.extend(measure)

# Running sum of field 2 (presumably the duration — confirm), with a leading 0
# so that entry i is the position at which note i starts.
f_accum_position = np.cumsum([0] + [note[2] for note in front_part])
b_accum_position = np.cumsum([0] + [note[2] for note in back_part])

f_note_id = b_note_id = 0
while f_note_id < len(front_part) and b_note_id < len(back_part):
    # Front positions are converted into the back part's beat grid first.
    f_pos = b_pos_conv(f_accum_position[f_note_id], front_idx, back_idx)
    b_pos = b_accum_position[b_note_id]
    # print(f" f_pos: {f_pos}, b_pos: {b_pos}")
    if f_pos > b_pos:
        # The back part is behind: advance it.
        b_note_id += 1
    elif f_pos < b_pos:
        # The front part is behind: advance it.
        f_note_id += 1
    else:
        # Same position in both parts: report the pair if field 1 matches.
        f_note, b_note = front_part[f_note_id], back_part[b_note_id]
        if f_note[1] == b_note[1]:
            print(f_note, b_note)
        f_note_id += 1
        b_note_id += 1

[5, 51.0, 1.0, 0.0, 'strong', 1] [7, 51.0, 1.0, 0.0, 'strong', 1]
[5, 39.0, 0.0, 1.0, 'weak', 1] [7, 39.0, 0.0, 1.0, 'weak', 1]
[5, 51.0, 2.0, 1.0, 'weak', 1] [7, 51.0, 2.0, 1.0, 'weak', 1]
[5, 51.0, 2.0, 3.0, 'middle', 1] [7, 51.0, 1.0, 3.0, 'middle', 1]
[5, 39.0, 0.0, 1.0, 'weak', 2] [7, 39.0, 0.0, 1.0, 'weak', 2]
[5, 51.0, 1.0, 1.0, 'weak', 2] [7, 51.0, 1.0, 1.0, 'weak', 2]
[5, 51.0, 1.0, 2.0, 'weak', 2] [7, 51.0, 1.0, 2.0, 'weak', 2]
[5, 51.0, 1.0, 9.0, 'weak', 2] [7, 51.0, 1.0, 8.0, 'weak', 2]


In [None]:
from music21 import converter
from pathlib import Path
import pickle
from sejong_music.yeominrak_processing import Part

class T5Dataset:
    """Score loaded from a MusicXML file, with its parts cached as a pickle.

    Attributes:
        score: Result of ``converter.parse(xml_path)``.
        parts: One ``Part`` per entry of ``score.parts``, constructed as
            ``Part(part, index)``. Read from ``<xml_path stem>.pkl`` when that
            file exists, otherwise built and written there.
    """

    def __init__(self, xml_path) -> None:
        """Parse ``xml_path`` and load (or build and cache) its parts.

        Args:
            xml_path: Path of the score file. The cache file sits next to it,
                with the suffix replaced by ``.pkl``.
        """
        self.score = converter.parse(xml_path)
        cache_path = Path(xml_path).with_suffix('.pkl')
        if cache_path.exists():
            # NOTE(security): unpickling executes arbitrary code; only load
            # cache files that this project wrote itself.
            # ``with`` closes the handle — the previous bare ``open`` leaked it.
            with open(cache_path, 'rb') as cache_file:
                self.parts = pickle.load(cache_file)
        else:
            self.parts = [Part(part, i) for i, part in enumerate(self.score.parts)]
            # Closing the file deterministically also guarantees the cache is
            # fully flushed to disk before anyone tries to read it back.
            with open(cache_path, 'wb') as cache_file:
                pickle.dump(self.parts, cache_file)


In [46]:
# Display the flattened note lists of the two parts being compared.
front_part, back_part

([[4, 51.0, 1.0, 0.0, 'strong', 1],
  [4, 39.0, 0.0, 1.0, 'weak', 1],
  [4, 51.0, 1.0, 1.0, 'weak', 1],
  [4, 39.0, 1.0, 2.0, 'weak', 1],
  [4, 51.0, 3.0, 3.0, 'middle', 1],
  [4, 51.0, 1.0, 6.0, 'strong', 1],
  [4, 39.0, 0.0, 7.0, 'weak', 1],
  [4, 51.0, 1.0, 7.0, 'weak', 1],
  [4, 48.0, 1.0, 8.0, 'middle', 1],
  [4, 53.0, 1.0, 9.0, 'weak', 1],
  [4, 51.0, 1.0, 0.0, 'strong', 2],
  [4, 39.0, 0.0, 1.0, 'weak', 2],
  [4, 51.0, 1.0, 1.0, 'weak', 2],
  [4, 48.0, 1.0, 2.0, 'weak', 2],
  [4, 51.0, 3.0, 3.0, 'middle', 2],
  [4, 56.0, 1.0, 6.0, 'strong', 2],
  [4, 39.0, 0.0, 7.0, 'weak', 2],
  [4, 56.0, 1.0, 7.0, 'weak', 2],
  [4, 56.0, 0.5, 8.0, 'middle', 2],
  [4, 53.0, 0.5, 8.5, 'weak', 2],
  [4, 51.0, 1.0, 9.0, 'weak', 2]],
 [[7, 51.0, 1.0, 0.0, 'strong', 1],
  [7, 39.0, 0.0, 1.0, 'weak', 1],
  [7, 51.0, 2.0, 1.0, 'weak', 1],
  [7, 51.0, 1.0, 3.0, 'middle', 1],
  [7, 46.0, 0.5, 4.0, 'weak', 1],
  [7, 51.0, 0.5, 4.5, 'weak', 1],
  [7, 39.0, 0.0, 5.0, 'strong', 1],
  [7, 51.0, 1.0, 5.0, 'st

In [5]:
# Inspect the first item; MaskedDataset.__getitem__ forwards the index to the wrapped aligned_score.
dataset[0]

(tensor([[ 6,  1,  1,  1,  1,  1],
         [ 6, 10,  6,  3,  3,  3],
         [ 6, 13,  6,  4,  5,  3],
         [ 6,  6,  4,  5,  5,  3],
         [ 6, 13, 11,  5,  5,  3],
         [ 6, 13,  8, 11,  4,  3],
         [ 6, 10,  6, 15,  5,  3],
         [ 6, 13,  6, 17,  5,  3],
         [ 6,  6,  4, 19,  3,  3],
         [ 6, 13,  8, 19,  3,  3],
         [ 6, 13,  8, 22,  5,  3],
         [ 6, 16,  8, 25,  4,  3],
         [ 6, 14,  8, 29,  5,  3],
         [ 6, 13,  8, 32,  5,  3],
         [ 6, 10,  6,  3,  3,  4],
         [ 6, 13,  6,  4,  5,  4],
         [ 6,  6,  4,  5,  5,  4],
         [ 6, 13,  8,  5,  5,  4],
         [ 6, 13,  8,  7,  5,  4],
         [ 6, 11,  6, 11,  4,  4],
         [ 6, 13,  6, 13,  5,  4],
         [ 6, 10,  8, 15,  5,  4],
         [ 6, 14,  6, 19,  3,  4],
         [ 6, 16,  6, 20,  5,  4],
         [ 6,  6,  4, 22,  5,  4],
         [ 6, 16,  8, 22,  5,  4],
         [ 6, 16,  8, 25,  4,  4],
         [ 6, 14,  8, 29,  5,  4],
         [ 6, 13,  8