In [1]:
# Switch the working directory to the parent of the notebook's folder
# (presumably the repository root, so project-relative paths and imports resolve).
# NOTE(review): `%cd ..` is not idempotent - re-running this cell moves up one
# more directory each time, so run it exactly once per kernel session.
%cd ..
# Enable autoreload in mode 2 so edited modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2


/home/teo/userdata/SejongMusic


In [30]:
import argparse
import datetime
from pathlib import Path
from omegaconf import OmegaConf
from fractions import Fraction
from tqdm.auto import tqdm

import torch
from music21 import converter, stream, note as m21_note 

from sejong_music import model_zoo
from sejong_music.yeominrak_processing import AlignedScore, SamplingScore, pack_collate, ShiftedAlignedScore, TestScore, TestScoreCPH, Tokenizer, SimpleNote
from sejong_music.model_zoo import QkvAttnSeq2seq, get_emb_total_size, QkvAttnSeq2seqMore
from sejong_music.decode import MidiDecoder
from torch.utils.data import  DataLoader

from make_inference import get_measure_shifted_output, prepare_input_for_next_part, get_measure_specific_output, fix_measure_idx, recover_beat, fill_in_source


In [3]:
# Directory of the training run whose artifacts are loaded below: the model
# checkpoint, config.yaml and tokenizer_vocab.json are all read relative to it.
# NOTE(review): hard-coded to one specific wandb run, relative to the working directory.
state_dir = 'outputs/2024-03-08/00-27-49/wandb/run-20240308_002750-ui4ko24t/files/checkpoints'

In [88]:
# Restore the trained model together with its config and tokenizer from `state_dir`.
device = 'cpu'

# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU still loads on a host without that GPU.
state = torch.load(f'{state_dir}/best_pitch_sim_model.pt', map_location=device)

config = OmegaConf.load(f'{state_dir}/config.yaml')
# parts / feature_types are left empty; presumably the vocabulary comes entirely
# from the saved JSON file - TODO confirm against Tokenizer.
tokenizer = Tokenizer(parts=None, feature_types=None, json_fn=f'{state_dir}/tokenizer_vocab.json')

# The model class is resolved by the name stored in the run config.
model = getattr(model_zoo, config.model_class)(tokenizer, config.model).to(device)
model.is_condition_shifted = True
model.load_state_dict(state)
model.eval()

decoder = MidiDecoder(tokenizer)


In [91]:
class TestScoreCPH2(TestScore):
  """TestScore variant that applies score-specific corrections to part 0.

  The constructor forwards every argument unchanged to ``TestScore``.
  ``fix_part_by_rule`` rewrites a fixed set of pitches in part 0, regroups the
  notes into measures, recomputes the per-part features and doubles two of
  the part-0 feature columns.
  """

  def __init__(self, xml_path='0_edited.musicxml', valid_measure_num=..., slice_measure_num=2, is_valid=False, use_pitch_modification=False, pitch_modification_ratio=0, min_meas=3, max_meas=6, transpose=0, feature_types=..., sampling_rate=None) -> None:
    super().__init__(
      xml_path,
      valid_measure_num,
      slice_measure_num,
      is_valid,
      use_pitch_modification,
      pitch_modification_ratio,
      min_meas,
      max_meas,
      transpose,
      feature_types,
      sampling_rate,
    )

  def fix_part_by_rule(self):
    """Correct the pitches of part 0 and rebuild its measures and features.

    Updates ``self.parts[0].measures``, ``self.measure_features`` and
    ``self.offset_list[0]`` in place; returns nothing.
    """
    # (pitch to replace, replacement). The rules are applied in order and a
    # later rule sees the result of an earlier one, so 56 becomes 57 and then 45.
    pitch_rules = (
      (51, 52),  # 51 is error
      (39, 40),
      (56, 57),
      (57, 45),
    )

    # Wrap every note of part 0 in a SimpleNote, flattened across measures.
    simple_notes = []
    for source_measure in self.parts[0].measures:
      for source_note in source_measure:
        simple_notes.append(SimpleNote(source_note))

    ### pitch modification
    corrected_notes = []
    for simple_note in simple_notes:
      for wrong_pitch, replacement in pitch_rules:
        if simple_note.pitch == wrong_pitch:
          simple_note.pitch = replacement
      corrected_notes.append(simple_note)

    # make measure list: start a new bucket whenever the measure number changes.
    # The sentinel starts at 0, so a first note whose measure number is not 0
    # leaves an empty bucket at the front of the list.
    regrouped = []
    bucket = []
    last_number = 0
    for simple_note in corrected_notes:
      if simple_note.measure_number != last_number:
        regrouped.append(bucket)
        bucket = []
        last_number = simple_note.measure_number
      bucket.append(simple_note)
    regrouped.append(bucket)

    self.parts[0].measures = regrouped
    self.measure_features = [self.get_feature(part_idx) for part_idx in range(len(self.parts))]

    # Double feature columns 2 and 3 of every part-0 note (presumably duration
    # and offset - TODO confirm). The copy is shallow, so the notes themselves
    # are modified in place.
    part_zero = self.measure_features[0].copy()
    for feature_measure in part_zero:
      for feature_row in feature_measure:
        for column in (2, 3):
          feature_row[column] *= 2
    self.measure_features[0] = part_zero
    self.offset_list[0] *= 2

# Build the inference dataset from the Chwipunghyeong score with the rule-fixed
# TestScoreCPH2 class, using every measure ('entire') rather than an explicit
# list of measure indices.
cph2_options = dict(
  xml_path="music_score/chwipunghyeong.musicxml",
  is_valid=True,
  valid_measure_num='entire',
  slice_measure_num=4,
  transpose=+8,
  feature_types=config.model.features,
)
test_set = TestScoreCPH2(**cph2_options)
# Use the tokenizer restored from the checkpoint directory.
test_set.tokenizer = tokenizer

In [16]:
# Same score and options as the TestScoreCPH2 cell, but built with the library's
# TestScoreCPH class.
# NOTE(review): this rebinds `test_set`, so when the cells run top to bottom it
# replaces the TestScoreCPH2 dataset created in the previous cell. Its execution
# count (16) is lower than that cell's (91) - confirm which dataset the
# inference cell is meant to use.
cph_options = dict(
  xml_path="music_score/chwipunghyeong.musicxml",
  is_valid=True,
  valid_measure_num='entire',
  slice_measure_num=4,
  transpose=+8,
  feature_types=config.model.features,
)
test_set = TestScoreCPH(**cph_options)
# Use the tokenizer restored from the checkpoint directory.
test_set.tokenizer = tokenizer

In [85]:
# Inspect the pitch vocabulary: the special tokens followed by the known pitch values.
tokenizer.vocab['pitch']

['pad',
 'start',
 'end',
 34.0,
 36.0,
 39.0,
 41.0,
 43.0,
 44.0,
 46.0,
 48.0,
 49.0,
 51.0,
 53.0,
 55.0,
 56.0,
 58.0,
 60.0,
 63.0]

In [89]:
# Register a 6.0 duration token when the loaded vocabulary lacks one, and give
# the encoder an extra embedding row for it, cloned from the 5.0 token's row.
# The guard makes the cell safe to re-run.
duration_vocab = tokenizer.vocab['duration']
duration_tok2idx = tokenizer.tok2idx['duration']

if 6.0 not in duration_vocab:
  duration_vocab.append(6.0)
  # The length is read before the key is inserted, so 6.0 gets the next free index.
  duration_tok2idx[6.0] = len(duration_tok2idx)

  # add weight to the model embedding
  # layers[2] is presumably the duration embedding (index 2 matches 'duration'
  # in tokenizer.key_types) - TODO confirm.
  duration_embedding = model.encoder.emb.layers[2]
  dup_weight = duration_embedding.weight[duration_tok2idx[5.0]]

  duration_embedding.weight.data = torch.cat(
    [duration_embedding.weight, dup_weight.unsqueeze(0)],
    dim=0,
  )

tokenizer.vocab['duration']

['pad',
 'start',
 'end',
 8.0,
 0.0,
 Fraction(1, 3),
 0.5,
 Fraction(2, 3),
 1.0,
 1.5,
 1.6666666666666665,
 2.0,
 2.5,
 3.0,
 4.0,
 5.0,
 6.0]

In [92]:
# Inspect the first encoded sample of the dataset (a tensor of token indices).
test_set[0]

tensor([[ 0,  1,  1,  1,  1,  1],
        [ 0, 12, 11,  3,  3,  3],
        [ 0, 13, 13,  7,  5,  3],
        [ 0, 12, 13, 19,  3,  3],
        [ 0, 12, 11,  3,  3,  4],
        [ 0, 12, 16,  7,  5,  4],
        [ 0, 12, 11,  3,  3,  5],
        [ 0, 13,  8,  7,  5,  5],
        [ 0, 12,  6, 11,  4,  5],
        [ 0, 10,  9, 13,  5,  5],
        [ 0,  9, 11, 19,  3,  5],
        [ 0, 10,  8, 25,  5,  5],
        [ 0,  2,  2,  2,  2,  2]])

In [96]:
# Generate the target part for the whole score, one dataset window at a time,
# then put the source part and the generated part into a single music21 Score.
test_set.tokenizer = tokenizer
print("Start inference")
score = stream.Score(id='mainScore')

prev_generation = None
outputs = []
source_part = stream.Part()
merged_part = stream.Part() 
srcs = []
start_idx = 1
target_idx = 2
# Largest accepted gap between a generated measure's duration and the target measure length.
duration_tolerance = 0.001

# The code after the loop reads `src`, `output` and `i`, so an empty dataset
# cannot be handled; fail early with a clear message instead.
if len(test_set) == 0:
  raise ValueError('test_set is empty; there is nothing to generate')

for i in tqdm(range(len(test_set))):
  sample = test_set[i]
  sample[:,0] = start_idx # regard as era 1
  src, output_decoded, (attention_map, output, new_out) = model.shifted_inference(sample, target_idx, prev_generation=prev_generation)
  # Keep one source window per slice so that consecutive windows are not duplicated.
  if i % test_set.slice_measure_number == 0:
    srcs.append(src)
  if i == 0:
    # First window: keep the first output row plus measures 3, 4 and 5.
    sel_out = torch.cat([output[0:1]] + [get_measure_specific_output(output, measure_idx) for measure_idx in range(3,6)], dim=0)
  else:
    # Later windows contribute only measure 5.
    sel_out = get_measure_specific_output(output, 5)
    sel_duration = sum([x[2] for x in model.converter(sel_out)])
    if abs(sel_duration - model.tokenizer.measure_duration[target_idx]) > duration_tolerance:
      print(f'Generated duration: {sel_duration}, target idx: {target_idx}, Running Alt model inference')
      # The fallback model is not created anywhere in the visible cells; report
      # that explicitly rather than failing on a bare name lookup.
      if 'model_alt' not in globals():
        raise NameError(
          "name 'model_alt' is not defined: the generated measure has the wrong "
          "duration and the fallback needs an alternative model loaded as `model_alt`"
        )
      src, output_decoded, (attention_map, output, new_out) = model_alt.shifted_inference(sample, target_idx, prev_generation=prev_generation)
      sel_out = get_measure_specific_output(output, 5) 
      sel_duration = sum([x[2] for x in model.converter(sel_out)])
      assert abs(sel_duration - model.tokenizer.measure_duration[target_idx]) < duration_tolerance, \
        f'Alt model also generated a wrong duration: {sel_duration} (window {i})'
  # The generation of this window conditions the next one.
  prev_generation = get_measure_shifted_output(output)
  outputs.append(sel_out)

# Add the source of the last window, passing how far it is into its slice.
srcs.append(fill_in_source(src, i % test_set.slice_measure_number))
source_converted = [y for x in srcs for y in x]
source_part = decoder(source_converted)
outputs.append(get_measure_specific_output(output, 6)) # add last measure
outputs_tensor = torch.cat(outputs, dim=0)
final_midi = decoder(model.converter(outputs_tensor))
score.insert(0, source_part)
score.insert(0, final_midi)


Start inference


  0%|          | 0/130 [00:00<?, ?it/s]

100%|██████████| 130/130 [00:01<00:00, 67.66it/s]


In [97]:
# Save the combined source + generated score as MusicXML.
output_path = Path('gen_results/chwipunghyeong_testest4.musicxml')
# Create the output folder first so a fresh checkout without gen_results/ does not fail on write.
output_path.parent.mkdir(parents=True, exist_ok=True)
score.write('musicxml', fp=str(output_path))

PosixPath('/home/teo/userdata/SejongMusic/gen_results/chwipunghyeong_testest4.musicxml')

In [25]:
# Look up the token index of pitch 53.0; tokenizer.key_types[1] is 'pitch'.
tokenizer.tok2idx[tokenizer.key_types[1]][53.0]

13

In [20]:
# List the feature names known to the tokenizer, in order.
tokenizer.key_types

['index',
 'pitch',
 'duration',
 'offset',
 'dynamic',
 'measure_idx',
 'offset_fraction']