In [2]:
# Move to the parent directory so the project modules imported below
# (yeominrak_processing, model_zoo, decode) and the relative paths used later
# (yamls/, outputs/, gen_results/) resolve — presumably the repository root; see the printed path.
# NOTE(review): `%cd ..` is not idempotent; re-running this cell moves up another level.
%cd ..
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

/home/teo/userdata/git_libraries/Yeominrak


In [3]:
from pathlib import Path
from music21 import converter, stream, note as m21_note 
import torch

from tqdm.auto import tqdm

from yeominrak_processing import AlignedScore, SamplingScore, pack_collate, ShiftedAlignedScore, OrchestraScore
from model_zoo import Seq2seq, Converter, AttentionSeq2seq, QkvAttnSeq2seq, get_emb_total_size, QkvAttnSeq2seqOrch
import random as random
from decode import MidiDecoder, OrchestraDecoder
import torch
from torch.utils.data import Dataset, DataLoader
from omegaconf import OmegaConf
from fractions import Fraction

import math
def round_number(number, decimals=3):
  """Truncate ``number`` to ``decimals`` decimal places (rounding toward -infinity).

  Despite the name this floors, e.g. 2/3 -> 0.666. The results are used as keys into
  the tokenizer's offset table, so the flooring convention must be kept.
  """
  factor = 10.0 ** decimals
  return math.floor(factor * number) / factor


def prepare_input_for_next_part(outputs_tensor, fill_value=3):
  """Shift the condition columns down by one row to build the next part's input.

  Columns ``[:3]`` are treated as the prediction and kept in place; the remaining
  columns (the condition) move one row down, so row ``k`` pairs prediction ``k`` with
  condition ``k - 1``. Row 0 has no preceding condition and gets ``fill_value`` in
  every condition column; the last row's condition is dropped.

  Args:
    outputs_tensor: 2-D tensor with at least 3 columns
      (assumed (num_tokens, num_features) — TODO confirm against the caller).
    fill_value: value written into row 0's condition columns (defaults to 3).

  Returns:
    A tensor with the same shape, dtype and device as a non-empty ``outputs_tensor``.
  """
  pred, condition = outputs_tensor[:, :3], outputs_tensor[:, 3:]
  # new_full matches the input's dtype/device and the real condition width, so the
  # concatenation works for non-CPU tensors and for condition widths other than 3.
  filler = pred.new_full((1, condition.shape[1]), fill_value)
  first_row = torch.cat([pred[0:1], filler], dim=1)
  shifted_rows = torch.cat([pred[1:], condition[:-1]], dim=1)
  return torch.cat([first_row, shifted_rows], dim=0)

def get_measure_specific_output(output:torch.LongTensor, measure_idx_in_token:int):
  """Return the rows that immediately follow rows tagged with ``measure_idx_in_token``.

  A row ``k`` matches when ``output[k, -1] == measure_idx_in_token``; the function
  returns ``output[k + 1]`` for every match, in order. The one-row offset presumably
  reflects the shifted condition column (cf. ``prepare_input_for_next_part``) — TODO confirm.
  If the final row matches, ``k + 1`` is out of range and indexing raises IndexError.
  """
  hits = (output[:, -1] == measure_idx_in_token).nonzero(as_tuple=True)[0]
  following_rows = hits + 1
  return output[following_rows]


def fix_measure_idx(sample:torch.LongTensor):
  """Recompute column 5 (the measure index) of a token sequence; returns a copy.

  Rows are visited in order. Rows whose column 1 equals 1 or 2 are left untouched
  (presumably special/start tokens — TODO confirm). For every other row, a value of 3
  in column 3 (a beat start, per the original comment) advances a counter that starts
  at 2, and the row's column 5 is set to the current counter. ``sample`` is not modified.
  """
  fixed = sample.clone()
  measure_counter = 2
  for row_idx in range(fixed.shape[0]):
    row = fixed[row_idx]  # a view: writes go into `fixed`
    if bool(row[1] == 1) or bool(row[1] == 2):
      continue
    if row[3] == 3:  # beat_start
      measure_counter += 1
    row[5] = measure_counter
  return fixed



def fill_in_source(src, num_measure):
  """Drop the tokens of the first ``4 - num_measure`` measures from ``src``.

  ``src`` is a sequence of tokens (lists such as ``[6, 46.0, 6.0, 6.0, 'weak', 0]``)
  whose last element is a measure index. A token is dropped when that index is in
  ``range(4 - num_measure)``; the remaining tokens are returned, in order, as a new list.
  When ``num_measure >= 4`` the range is empty and every token is kept.
  """
  return [token for token in src if token[-1] not in range(4 - num_measure)]

def get_measure_shifted_output(output: torch.Tensor):
  """Cut out the tokens that follow the first measure and re-label them one measure earlier.

  The slice starts at the first row whose last column equals 4 and stops before the
  first later row whose last column equals 6. If no later row is tagged 6, the slice
  stops before the final row.
  NOTE(review): that fallback drops the last row. This is harmless when the last row
  is the closing token (last column 7, as the inference loop enforces), but confirm
  it is intended.

  The returned tensor is a copy. Its first row gets columns 1 and 2 set to 1
  (presumably marking it as a start token — TODO confirm), and column 5 (the measure
  index for 6-column tokens) is decremented by 1 on every row.

  Raises:
    ValueError: if no row is tagged with measure index 4, or the resulting slice is
      empty. The previous loop-based search surfaced these cases as a confusing
      UnboundLocalError/IndexError.
  """
  measure_col = output[:, -1]
  start_rows = torch.nonzero(measure_col == 4).flatten()
  if start_rows.numel() == 0:
    raise ValueError("get_measure_shifted_output: no token with measure index 4 in output")
  start = int(start_rows[0])

  stop_rows = torch.nonzero(measure_col[start:] == 6).flatten()
  stop = start + int(stop_rows[0]) if stop_rows.numel() > 0 else len(output) - 1

  shifted = output[start:stop].clone()
  if shifted.shape[0] == 0:
    raise ValueError("get_measure_shifted_output: no tokens between the first measure-4 token and the stop row")
  shifted[0, 1] = 1
  shifted[0, 2] = 1
  shifted[:, 5] -= 1
  return shifted


In [4]:

# Inference device; defined first so checkpoints are mapped onto it when loaded
# (without map_location, a checkpoint saved from a GPU run fails on a CPU-only host).
device = 'cpu'

save_name = 'vanilla_CHP'  # NOTE(review): not referenced by any visible cell

# NOTE(review): the model-building cell reloads `state` / `state_alt` from the
# epoch400 / epoch200 checkpoints, so these two loads are superseded there.
state = torch.load('outputs/2023-12-20/16-02-36/epoch600_model.pt', map_location=device)
state_alt = torch.load('outputs/2023-12-20/16-02-36/best_loss_model.pt', map_location=device)

config = OmegaConf.load('yamls/orchestration.yaml')
config = get_emb_total_size(config)

output_dir = Path('gen_results/')
output_dir.mkdir(parents=True, exist_ok=True)


In [5]:
# Validation split of the original score. In the cells shown, only its tokenizers are
# used (for the model, the decoders and test_set).
val_dataset = OrchestraScore(
    is_valid=True,
    xml_path='yeominlak_omr5.musicxml',
    use_pitch_modification=False,
    slice_measure_num=config.data.max_meas,
    min_meas=config.data.min_meas,
    feature_types=config.model.features,
    sampling_rate=config.data.sampling_rate,
)

# Input score to orchestrate, read as shifted windows over the entire piece.
# An alternative input, 'gen_results/CHP_scoreCPH_seq2seq_grace2_errorfixed.musicxml',
# can be passed as xml_path with the same arguments.
test_set = ShiftedAlignedScore(
    is_valid=True,
    xml_path='gen_results/CHP_scoreseq2seq_grace2_errorfixed.musicxml',
    valid_measure_num='entire',
    slice_measure_num=config.data.max_meas,
    min_meas=config.data.min_meas,
    feature_types=config.model.features,
    sampling_rate=config.data.sampling_rate,
)

# Keep the pairs whose first entry is 0 (truncated to len(slice_info)), then set both
# entries of each kept pair to 7. NOTE(review): meaning of index 7 is dataset-defined.
test_set.result_pairs = [
    pair for pair in test_set.result_pairs if pair[0] == 0
][:len(test_set.slice_info)]
for pair in test_set.result_pairs:
  pair[0] = pair[1] = 7

# Share (not copy) the model's source tokenizer: later edits to
# test_set.tokenizer.tok2idx therefore also change val_dataset.era_dataset.tokenizer.
test_set.tokenizer = val_dataset.era_dataset.tokenizer


In [6]:

# Some offsets on the thirds and halves grids are missing from the source tokenizer's
# offset vocabulary. Alias each missing key to the index of an existing offset so
# lookups succeed. The table is shared with val_dataset.era_dataset.tokenizer.
offset_vocab = test_set.tokenizer.tok2idx['offset']

# Thirds: i/3 for i in 0..29, aliased to the same third-position at 6/3.. (i < 6) or 21/3.. (i >= 6).
for i in range(30):
  offset = round_number(i / 3, 3)
  if offset not in offset_vocab:
    other_value = i % 3 + (6 if i < 6 else 21)
    offset_vocab[offset] = offset_vocab[round_number(other_value / 3, 3)]

# Halves: i/2 for i in 0..29, aliased to the offset 1.0 earlier, (i - 2)/2.
# Runs after the thirds loop and in ascending order, so aliases added earlier can be reused.
for i in range(30):
  offset = round_number(i / 2, 3)
  if offset not in offset_vocab:
    other_value = i - 2
    offset_vocab[offset] = offset_vocab[round_number(other_value / 2, 3)]


In [29]:
# Checkpoints used for generation (they supersede the state dicts loaded in the
# configuration cell). map_location lets GPU-saved checkpoints load on `device`.
state = torch.load('outputs/2023-12-20/16-02-36/epoch400_model.pt', map_location=device)
state_alt = torch.load('outputs/2023-12-20/16-02-36/epoch200_model.pt', map_location=device)


def _build_orchestra_model(state_dict):
  """Build a QkvAttnSeq2seqOrch on `device`, load `state_dict`, and switch it to eval mode."""
  net = QkvAttnSeq2seqOrch(val_dataset.era_dataset.tokenizer, val_dataset.tokenizer, config.model).to(device)
  net.is_condition_shifted = True  # NOTE(review): presumably read by shifted_inference — confirm
  net.load_state_dict(state_dict)
  net.eval()
  return net


model = _build_orchestra_model(state)
model_alt = _build_orchestra_model(state_alt)

decoder = OrchestraDecoder(val_dataset.tokenizer)
source_decoder = MidiDecoder(val_dataset.era_dataset.tokenizer)

# Alias non-integer quarter offsets (i/4 with i % 4 != 0) that are missing from the
# model's offset vocabulary to the offset 0.5 earlier, (i - 2)/4. Quarters are exact
# in binary floating point, so the raw i/4 keys need no round_number.
if 'offset' in model.tokenizer.tok2idx:
  for i in range(120):
    if i % 4 == 0:
      continue
    if i / 4 not in model.tokenizer.tok2idx['offset']:
      model.tokenizer.tok2idx['offset'][i/4] = model.tokenizer.tok2idx['offset'][(i-2)/4]

# model_alt shares the patched tokenizer object with model.
model_alt.tokenizer = model.tokenizer

In [35]:
# Orchestrate the input score one instrument at a time (target_idx 0..5). A window
# slides over test_set, and each window is conditioned on tokens cut from the
# previous window's generation.
# A generation is resampled until it is complete (last token's final column minus 3
# equals 4). After the first window, it is also resampled until its measure-5 tokens'
# durations sum to MEASURE_DURATION. Both resampling loops are capped at
# MAX_GENERATION_TRIES, so a model that never produces a valid measure raises instead
# of hanging the kernel.
MAX_GENERATION_TRIES = 100
MEASURE_DURATION = 30       # target sum of note durations (model.converter units) for one measure
DURATION_TOLERANCE = 0.1


def _sample_complete_generation(sample, target_idx, prev_generation):
  """Call model.shifted_inference until the output ends on the expected final measure.

  Returns the raw result (src, output_decoded, (attention_map, output, new_out)).
  Raises RuntimeError after MAX_GENERATION_TRIES incomplete samples.
  """
  for _attempt in range(MAX_GENERATION_TRIES):
    result = model.shifted_inference(sample, target_idx, prev_generation=prev_generation)
    _src, _decoded, (_attn, generated, _new) = result
    if generated[-1][-1] - 3 == 4:
      return result
  raise RuntimeError(
      f"no complete generation after {MAX_GENERATION_TRIES} tries (inst: {target_idx})")


print("Start inference")
score = stream.Score(id='mainScore')

prev_generation = None
outputs = []
source_part = stream.Part()
merged_part = stream.Part()
srcs = []

for target_idx in range(6):
  outputs = []
  prev_generation = None
  for i in tqdm(range(len(test_set))):
    sample, _, _ = test_set[i]
    src, output_decoded, (attention_map, output, new_out) = _sample_complete_generation(
        sample, target_idx, prev_generation)
    final_measure_num = output[-1][-1] - 3
    if target_idx == 0 and i % test_set.slice_measure_number == 0:
      srcs.append(src)
    if i == 0:
      # First window: the leading row plus the rows selected for measure indices 3, 4, 5.
      sel_out = torch.cat(
          [output[0:1]] + [get_measure_specific_output(output, m) for m in range(3, 6)], dim=0)
    else:
      sel_out = get_measure_specific_output(output, 5)
      duration = sum([note[2] for note in model.converter(sel_out)])
      duration_retries = 0
      while abs(duration - MEASURE_DURATION) > DURATION_TOLERANCE:
        # Printed only for freshly measured durations (no stale repeats on incomplete resamples).
        print(f"duration error: {duration}, inst: {target_idx}, measure: {i}")
        duration_retries += 1
        if duration_retries > MAX_GENERATION_TRIES:
          raise RuntimeError(
              f"measure duration still {duration} after {MAX_GENERATION_TRIES} resamples "
              f"(inst: {target_idx}, measure: {i})")
        src, output_decoded, (attention_map, output, new_out) = _sample_complete_generation(
            sample, target_idx, prev_generation)
        final_measure_num = output[-1][-1] - 3
        sel_out = get_measure_specific_output(output, 5)
        duration = sum([note[2] for note in model.converter(sel_out)])

    prev_generation = get_measure_shifted_output(output)
    outputs.append(sel_out)

  if target_idx == 0:
    # Add the remaining tokens of the last window, then decode the source part once;
    # the first token of the concatenated source is skipped.
    srcs.append(fill_in_source(src, i % test_set.slice_measure_number))
    source_converted = [y for x in srcs for y in x]
    source_part = source_decoder(source_converted[1:])
    score.insert(0, source_part)

  # Per window only measure-5 rows were kept; add the last window's measure-6 rows too.
  outputs.append(get_measure_specific_output(output, 6))
  outputs_tensor = torch.cat(outputs, dim=0)
  final_midi = decoder(model.converter(outputs_tensor))
  score.insert(0, final_midi)



Start inference


  0%|          | 0/505 [00:00<?, ?it/s]

duration error: 32.25, inst: 0, measure: 29
duration error: 30.25, inst: 0, measure: 31
duration error: 35.25, inst: 0, measure: 44
duration error: 30.25, inst: 0, measure: 50
duration error: 31.0, inst: 0, measure: 60
duration error: 34.75, inst: 0, measure: 126
duration error: 34.75, inst: 0, measure: 140
duration error: 34.0, inst: 0, measure: 145
duration error: 34.75, inst: 0, measure: 159
duration error: 34.75, inst: 0, measure: 161
duration error: 31.75, inst: 0, measure: 179
duration error: 33.75, inst: 0, measure: 182
duration error: 30.5, inst: 0, measure: 194
duration error: 31.75, inst: 0, measure: 196
duration error: 35.25, inst: 0, measure: 226
duration error: 30.25, inst: 0, measure: 268
duration error: 31.75, inst: 0, measure: 277
duration error: 30.25, inst: 0, measure: 283
duration error: 31.75, inst: 0, measure: 284
duration error: 32.25, inst: 0, measure: 304
duration error: 34.0, inst: 0, measure: 304
duration error: 30.25, inst: 0, measure: 321
duration error: 32.

  0%|          | 0/505 [00:00<?, ?it/s]

duration error: 31.75, inst: 1, measure: 6
duration error: 31.75, inst: 1, measure: 7
duration error: 31.75, inst: 1, measure: 36
duration error: 31.75, inst: 1, measure: 133
duration error: 31.75, inst: 1, measure: 156
duration error: 31.75, inst: 1, measure: 156
duration error: 30.25, inst: 1, measure: 214
duration error: 31.25, inst: 1, measure: 250
duration error: 31.75, inst: 1, measure: 250
duration error: 35.25, inst: 1, measure: 264
duration error: 31.75, inst: 1, measure: 271
duration error: 31.75, inst: 1, measure: 339
duration error: 30.25, inst: 1, measure: 379
duration error: 30.25, inst: 1, measure: 394
duration error: 30.25, inst: 1, measure: 412
duration error: 32.25, inst: 1, measure: 416
duration error: 31.75, inst: 1, measure: 421
duration error: 31.25, inst: 1, measure: 437
duration error: 31.75, inst: 1, measure: 451
duration error: 32.25, inst: 1, measure: 469


  0%|          | 0/505 [00:00<?, ?it/s]

duration error: 31.75, inst: 2, measure: 9
duration error: 32.25, inst: 2, measure: 28
duration error: 30.25, inst: 2, measure: 81
duration error: 32.5, inst: 2, measure: 90
duration error: 31.75, inst: 2, measure: 94
duration error: 31.75, inst: 2, measure: 177
duration error: 31.75, inst: 2, measure: 209
duration error: 31.75, inst: 2, measure: 280
duration error: 31.75, inst: 2, measure: 301
duration error: 30.25, inst: 2, measure: 338
duration error: 31.75, inst: 2, measure: 373
duration error: 32.25, inst: 2, measure: 428
duration error: 31.75, inst: 2, measure: 492


  0%|          | 0/505 [00:00<?, ?it/s]

duration error: 32.25, inst: 3, measure: 124


  0%|          | 0/505 [00:00<?, ?it/s]

duration error: 31.5, inst: 4, measure: 317


  0%|          | 0/505 [00:00<?, ?it/s]

duration error: 32.5, inst: 5, measure: 25
duration error: 31.0, inst: 5, measure: 162
duration error: 30.5, inst: 5, measure: 193
duration error: 30.5, inst: 5, measure: 193
duration error: 32.25, inst: 5, measure: 228
duration error: 32.75, inst: 5, measure: 236


In [36]:
# Write `score` (source part plus the generated instrument parts) as MusicXML.
# The path is relative, so it lands in the current working directory, not output_dir.
# NOTE(review): the filename spells "orchestraion"; kept as-is since other tools may expect it.
score.write('musicxml', fp='chihwapyeong_orchestraion_final.musicxml')

PosixPath('/home/teo/userdata/git_libraries/Yeominrak/chihwapyeong_orchestraion_final.musicxml')

In [33]:
# Debug check: final column of the last generated token minus 3, from the most recent
# inference run. The inference loop resamples until this equals 4.
final_measure_num

tensor(4)

In [12]:

# Unvalidated variant of the inference loop: one shifted_inference call per window,
# with no resampling on an incomplete generation or a wrong measure duration.
# NOTE(review): running this cell overwrites `score`, `srcs` and the other globals
# built by the validated loop; its recorded run was interrupted.


def _pick_generated_tokens(generated, is_first_window):
  """Rows kept from one generation.

  First window: the leading row plus the rows selected for measure indices 3, 4, 5.
  Otherwise: the rows selected for measure index 5 only.
  """
  if is_first_window:
    per_measure = [get_measure_specific_output(generated, m) for m in range(3, 6)]
    return torch.cat([generated[0:1]] + per_measure, dim=0)
  return get_measure_specific_output(generated, 5)


print("Start inference")
score = stream.Score(id='mainScore')

prev_generation = None
outputs = []
source_part = stream.Part()
merged_part = stream.Part()
srcs = []

for target_idx in range(6):
  outputs = []
  prev_generation = None
  for i in tqdm(range(len(test_set))):
    sample, _, _ = test_set[i]
    src, output_decoded, (attention_map, output, new_out) = model.shifted_inference(
        sample, target_idx, prev_generation=prev_generation)
    if target_idx == 0 and i % test_set.slice_measure_number == 0:
      srcs.append(src)
    sel_out = _pick_generated_tokens(output, i == 0)
    if i != 0:
      # Computed but not checked in this variant.
      duration = sum([note[2] for note in model.converter(sel_out)])
    prev_generation = get_measure_shifted_output(output)
    outputs.append(sel_out)

  if target_idx == 0:
    srcs.append(fill_in_source(src, i % test_set.slice_measure_number))
    source_converted = [y for x in srcs for y in x]
    source_part = source_decoder(source_converted[1:])
    score.insert(0, source_part)

  # Add the last window's measure-6 rows (only measure 5 was kept per window).
  outputs.append(get_measure_specific_output(output, 6))
  outputs_tensor = torch.cat(outputs, dim=0)
  final_midi = decoder(model.converter(outputs_tensor))
  score.insert(0, final_midi)


Start inference


  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
# Training split of the same score, with the pitch-modification settings from
# config.data. Unlike val_dataset it also passes max_meas. Used below only for inspection.
train_dataset = OrchestraScore(
    is_valid=False,
    xml_path='yeominlak_omr5.musicxml',
    use_pitch_modification=config.data.use_pitch_modification,
    pitch_modification_ratio=config.data.modification_ratio,
    min_meas=config.data.min_meas,
    max_meas=config.data.max_meas,
    feature_types=config.model.features,
    sampling_rate=config.data.sampling_rate,
    slice_measure_num=config.data.max_meas,
)


In [20]:
# Debug: print the last target token of get_processed_feature(7, 0, idx) for every
# training sample (one line per sample).
# NOTE(review): the unpacking rebinds `_`, IPython's last-output name.
for idx in range(len(train_dataset)):
  _, tgt, _ = train_dataset.get_processed_feature(7, 0, idx)
  print(tgt[-1])

tensor([ 0, 26, 27,  3,  3,  7])
tensor([ 0, 27,  6,  3,  3,  7])
tensor([ 0, 27,  6,  3,  3,  7])
tensor([ 0, 22, 15,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 18, 15,  3,  3,  7])
tensor([ 0, 23,  6,  3,  3,  7])
tensor([ 0, 22, 15,  3,  3,  7])
tensor([ 0, 23,  7,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 22, 15,  3,  3,  7])
tensor([ 0, 18, 15,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 26,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 19,  6,  3,  3,  7])
tensor([ 0, 19, 10,  3,  3,  7])
tensor([ 0, 18, 15,  3,  3,  7])
tensor([ 0, 22, 15,  3,  3,  7])
tensor([ 0, 18, 15,  3,  3,  7])
tensor([ 0, 22,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 23,  6,  3,  3,  7])
tensor([ 0, 20,  6,  3,  3,  7])
tensor([ 0, 18, 15,  3,  3,  7])
tensor([ 0

In [72]:
# Debug check: the training dataset's slice_measure_number.
train_dataset.slice_measure_number

2

In [86]:
# Inspect one training window: the shifted target's shape and its second-to-last token
# for get_processed_feature(7, 5, INSPECT_IDX).
# NOTE(review): meaning of the (7, 5) arguments is defined by OrchestraScore — confirm.
INSPECT_IDX = 100
source, _, shifted_tgt = train_dataset.get_processed_feature(7, 5, INSPECT_IDX)
print(shifted_tgt.shape, shifted_tgt[-2])

[[6, 'start', 'start', 0.0, 'strong', 0], [6, 46.0, 6.0, 6.0, 'weak', 0], [6, 39.0, 3.0, 9.0, 'middle', 0], [6, 46.0, 6.0, 15.0, 'strong', 0], [6, 46.0, 6.0, 21.0, 'middle', 0], [6, 46.0, 3.0, 24.0, 'weak', 0], [6, 44.0, 6.0, 0.0, 'strong', 1], [6, 56.0, 6.0, 6.0, 'weak', 1], [6, 53.0, 3.0, 9.0, 'middle', 1], [6, 51.0, 6.0, 15.0, 'strong', 1], [6, 46.0, 6.0, 21.0, 'middle', 1], [6, 48.0, 3.0, 24.0, 'weak', 1], [6, 53.0, 3.0, 27.0, 'weak', 1], [6, 56.0, 3.0, 0.0, 'strong', 2], [6, 51.0, 6.0, 6.0, 'weak', 2], [6, 39.0, 3.0, 9.0, 'middle', 2], [6, 51.0, 6.0, 15.0, 'strong', 2], [6, 51.0, 6.0, 21.0, 'middle', 2], [6, 51.0, 3.0, 24.0, 'weak', 2], [6, 53.0, 6.0, 0.0, 'strong', 3], [6, 51.0, 6.0, 6.0, 'weak', 3], [6, 48.0, 3.0, 9.0, 'middle', 3], [6, 51.0, 6.0, 15.0, 'strong', 3], [6, 56.0, 6.0, 21.0, 'middle', 3], [6, 53.0, 3.0, 24.0, 'weak', 3], [6, 51.0, 6.0, 0, 'strong', 4]]
[[6, 'start', 'start', 0.0, 'strong', 0], [6, 60.0, 6.0, 6.0, 'weak', 0], [6, 65.0, 3.0, 9.0, 'middle', 0], [6, 46.