In [1]:
%cd ..
%load_ext autoreload
%autoreload 2

/home/teo/userdata/SejongMusic


In [12]:
from make_orchestration_seq import get_argparse
import argparse
import datetime
import copy
from pathlib import Path
from omegaconf import OmegaConf
from fractions import Fraction
from tqdm.auto import tqdm

import torch
from music21 import converter, stream, note as m21_note 

from sejong_music import model_zoo
from sejong_music.yeominrak_processing import OrchestraScoreSeq, ShiftedAlignedScore, TestScore, TestScoreCPH, Tokenizer
from sejong_music.model_zoo import QkvAttnSeq2seq, get_emb_total_size, QkvAttnSeq2seqMore
from sejong_music.decode import MidiDecoder, OrchestraDecoder
from sejong_music.inference_utils import prepare_input_for_next_part, get_measure_specific_output, fix_measure_idx, fill_in_source, get_measure_shifted_output, recover_beat, round_number


In [16]:
args = get_argparse().parse_args([])

In [45]:
inst_codes = ['gmg', 'gyg', 'hg', 'pr', 'dg']
states = {}
configs = {}
gmg_config = OmegaConf.load(f'{args.gmg_state_dir}/config.yaml')
for inst_code in inst_codes:
  state_dir = getattr(args, f'{inst_code}_state_dir')
  states[inst_code] = torch.load(f'{state_dir}/epoch600_model.pt')
  configs[inst_code] = OmegaConf.load(f'{state_dir}/config.yaml')
tokenizer = Tokenizer(parts=None, feature_types=None, json_fn=f'{args.gmg_state_dir}/tokenizer_vocab.json')
era_tokenizer = Tokenizer(parts=None, feature_types=None, json_fn=f'{args.gmg_state_dir}/era_tokenizer_vocab.json')


In [46]:
if 'offset' in tokenizer.tok2idx:
  for i in range(120):
    if i % 4 == 0:
      continue
    if i / 4 not in tokenizer.tok2idx['offset']:
      print(f"i/4 not in tokenizer.tok2idx['offset']: {i/4}")
      if i >= 6:
        tokenizer.tok2idx['offset'][i/4] = tokenizer.tok2idx['offset'][(i-6)/4]
      else:
        tokenizer.tok2idx['offset'][i/4] = tokenizer.tok2idx['offset'][(i-2)/4]


i/4 not in tokenizer.tok2idx['offset']: 0.75
i/4 not in tokenizer.tok2idx['offset']: 1.75
i/4 not in tokenizer.tok2idx['offset']: 3.75
i/4 not in tokenizer.tok2idx['offset']: 6.25
i/4 not in tokenizer.tok2idx['offset']: 6.75
i/4 not in tokenizer.tok2idx['offset']: 8.25
i/4 not in tokenizer.tok2idx['offset']: 12.75
i/4 not in tokenizer.tok2idx['offset']: 13.75
i/4 not in tokenizer.tok2idx['offset']: 15.25
i/4 not in tokenizer.tok2idx['offset']: 15.75
i/4 not in tokenizer.tok2idx['offset']: 18.75
i/4 not in tokenizer.tok2idx['offset']: 23.25
i/4 not in tokenizer.tok2idx['offset']: 24.75
i/4 not in tokenizer.tok2idx['offset']: 25.75
i/4 not in tokenizer.tok2idx['offset']: 26.25
i/4 not in tokenizer.tok2idx['offset']: 27.25
i/4 not in tokenizer.tok2idx['offset']: 27.75
i/4 not in tokenizer.tok2idx['offset']: 29.25


In [47]:
tokenizer.tok2idx['offset']

{'pad': 0,
 'start': 1,
 'end': 2,
 0.0: 3,
 0.25: 4,
 0.5: 5,
 1.0: 6,
 1.25: 7,
 1.5: 8,
 2.0: 9,
 2.25: 10,
 2.5: 11,
 2.75: 12,
 3.0: 13,
 3.25: 14,
 3.5: 15,
 4.0: 16,
 4.25: 17,
 4.5: 18,
 4.75: 19,
 5.0: 20,
 5.25: 21,
 5.5: 22,
 5.75: 23,
 6.0: 24,
 6.5: 25,
 7.0: 26,
 7.25: 27,
 7.5: 28,
 7.75: 29,
 8.0: 30,
 8.5: 31,
 8.75: 32,
 9.0: 33,
 9.25: 34,
 9.5: 35,
 9.75: 36,
 10.0: 37,
 10.25: 38,
 10.5: 39,
 10.75: 40,
 11.0: 41,
 11.25: 42,
 11.5: 43,
 11.75: 44,
 12.0: 45,
 12.25: 46,
 12.5: 47,
 13.0: 48,
 13.25: 49,
 13.5: 50,
 14.0: 51,
 14.25: 52,
 14.5: 53,
 14.75: 54,
 15.0: 55,
 15.5: 56,
 16.0: 57,
 16.25: 58,
 16.5: 59,
 16.75: 60,
 17.0: 61,
 17.25: 62,
 17.5: 63,
 17.75: 64,
 18.0: 65,
 18.25: 66,
 18.5: 67,
 19.0: 68,
 19.25: 69,
 19.5: 70,
 19.75: 71,
 20.0: 72,
 20.25: 73,
 20.5: 74,
 20.75: 75,
 21.0: 76,
 21.25: 77,
 21.5: 78,
 21.75: 79,
 22.0: 80,
 22.25: 81,
 22.5: 82,
 22.75: 83,
 23.0: 84,
 23.5: 85,
 23.75: 86,
 24.0: 87,
 24.25: 88,
 24.5: 89,
 25.0: 90,
 

In [36]:
tg_idx = 1
inst = 'gyg'
device = 'cpu'

target_idx = len(inst_codes) - tg_idx - 1
print(f"target_idx: {target_idx}, inst: {inst}")
model = getattr(model_zoo, configs[inst].model_class)(tokenizer, tokenizer, configs[inst].model).to(device)
model.is_condition_shifted = True
model.load_state_dict(states[inst])
model.eval()
prev_xml_path = 'gen_results/CPH_Orch_0308-2320_seq_gmg.musicxml'

next_dataset = OrchestraScoreSeq(is_valid=True,
                                xml_path=prev_xml_path,
                                valid_measure_num = 'entire',
                                slice_measure_num=configs[inst].data.max_meas,
                                min_meas=configs[inst].data.min_meas,
                                feature_types=configs[inst].model.features,
                                sampling_rate=configs[inst].data.sampling_rate,
                                target_instrument=target_idx)
next_dataset.tokenizer = tokenizer
next_dataset.result_pairs = [x for x in next_dataset.result_pairs if x[1] == 1]
score = stream.Score(id='mainScore')
outputs = []
prev_generation = None

target_idx: 3, inst: gyg


In [39]:
decoder = OrchestraDecoder(tokenizer)


In [37]:
score = stream.Score(id='mainScore')
outputs = []
prev_generation = None
for i in tqdm(range(len(next_dataset))):
  sample, _, _ = next_dataset[i]
  _, output_decoded, (attention_map, output, new_out) = model.shifted_inference(sample, target_idx, prev_generation=prev_generation)
  if i == 0:
    sel_out = torch.cat([output[0:1]] + [get_measure_specific_output(output, i) for i in range(3,6)], dim=0)
  else:
    sel_out = get_measure_specific_output(output, 5) 
    duration = sum([note[2] for note in model.converter(sel_out)])
    if abs(duration - 30) > 0.1:
      break
    
  prev_generation = get_measure_shifted_output(output)
  outputs.append(sel_out)




100%|██████████| 128/128 [00:04<00:00, 25.97it/s]


In [40]:
# outputs.append(get_measure_specific_output(output, 6)) # add last measure
# outputs_tensor = torch.cat(outputs, dim=0)
gen_by_insts[inst] = decoder(model.converter(outputs_tensor))
for _ in range(target_idx+1):
  score.insert(0, copy.copy(gen_by_insts[inst]))

for prev_inst in reversed(inst_codes[:tg_idx]):
  score.insert(0, gen_by_insts[prev_inst])
    
prev_xml_path = output_dir/ f'{save_name}_seq_{inst}.musicxml'


NameError: name 'gen_by_insts' is not defined

In [50]:
model.converter(sample)

[[4, 'start', 'start', 'start', 'start', 'start'],
 [5, 46.0, 3.0, 0.0, 'strong', 0],
 [5, 46.0, 3.0, 3.0, 'weak', 0],
 [5, 39.0, 3.0, 6.0, 'weak', 0],
 [5, 56.0, 3.0, 9.0, 'middle', 0],
 [5, 58.0, 3.0, 12.0, 'weak', 0],
 [5, 58.0, 3.0, 15.0, 'strong', 0],
 [5, 58.0, 1.5, 18.0, 'weak', 0],
 [5, 58.0, 1.5, 19.5, 'weak', 0],
 [5, 56.0, 1.5, 21.0, 'middle', 0],
 [5, 58.0, 1.5, 22.5, 'weak', 0],
 [5, 56.0, 3.0, 24.0, 'weak', 0],
 [5, 53.0, 3.0, 27.0, 'weak', 0],
 [5, 58.0, 3.0, 0.0, 'strong', 1],
 [5, 58.0, 3.0, 3.0, 'weak', 1],
 [5, 53.0, 3.0, 6.0, 'weak', 1],
 [5, 51.0, 3.0, 9.0, 'middle', 1],
 [5, 53.0, 1.5, 12.0, 'weak', 1],
 [5, 56.0, 1.5, 13.5, 'weak', 1],
 [5, 56.0, 3.0, 15.0, 'strong', 1],
 [5, 53.0, 3.0, 18.0, 'weak', 1],
 [5, 56.0, 3.0, 21.0, 'middle', 1],
 [5, 56.0, 3.0, 24.0, 'weak', 1],
 [5, 39.0, 3.0, 27.0, 'weak', 1],
 [5, 53.0, 3.0, 0.0, 'strong', 2],
 [5, 53.0, 1.5, 3.0, 'weak', 2],
 [5, 51.0, 1.5, 4.5, 'weak', 2],
 [5, 58.0, 3.0, 6.0, 'weak', 2],
 [5, 51.0, 3.0, 9.0, 'mid

In [29]:
tokenizer.tok2idx['offset']

{'pad': 0,
 'start': 1,
 'end': 2,
 0.0: 3,
 0.25: 4,
 0.5: 5,
 1.0: 6,
 1.25: 7,
 1.5: 8,
 2.0: 9,
 2.25: 10,
 2.5: 11,
 2.75: 12,
 3.0: 13,
 3.25: 14,
 3.5: 15,
 4.0: 16,
 4.25: 17,
 4.5: 18,
 4.75: 19,
 5.0: 20,
 5.25: 21,
 5.5: 22,
 5.75: 23,
 6.0: 24,
 6.5: 25,
 7.0: 26,
 7.25: 27,
 7.5: 28,
 7.75: 29,
 8.0: 30,
 8.5: 31,
 8.75: 32,
 9.0: 33,
 9.25: 34,
 9.5: 35,
 9.75: 36,
 10.0: 37,
 10.25: 38,
 10.5: 39,
 10.75: 40,
 11.0: 41,
 11.25: 42,
 11.5: 43,
 11.75: 44,
 12.0: 45,
 12.25: 46,
 12.5: 47,
 13.0: 48,
 13.25: 49,
 13.5: 50,
 14.0: 51,
 14.25: 52,
 14.5: 53,
 14.75: 54,
 15.0: 55,
 15.5: 56,
 16.0: 57,
 16.25: 58,
 16.5: 59,
 16.75: 60,
 17.0: 61,
 17.25: 62,
 17.5: 63,
 17.75: 64,
 18.0: 65,
 18.25: 66,
 18.5: 67,
 19.0: 68,
 19.25: 69,
 19.5: 70,
 19.75: 71,
 20.0: 72,
 20.25: 73,
 20.5: 74,
 20.75: 75,
 21.0: 76,
 21.25: 77,
 21.5: 78,
 21.75: 79,
 22.0: 80,
 22.25: 81,
 22.5: 82,
 22.75: 83,
 23.0: 84,
 23.5: 85,
 23.75: 86,
 24.0: 87,
 24.25: 88,
 24.5: 89,
 25.0: 90,
 

In [27]:
model.converter(sel_out)

[[4, 63.0, 3.0, 3.0, 'weak', 2],
 [4, 58.0, 1.5, 4.5, 'weak', 2],
 [4, 60.0, 1.5, 6.0, 'weak', 2],
 [4, 58.0, 1.5, 7.5, 'weak', 2],
 [4, 56.0, 3.0, 10.5, 'weak', 2],
 [4, 63.0, 3.0, 13.5, 'weak', 2],
 [4, 60.0, 3.0, 16.5, 'weak', 2],
 [4, 63.0, 3.0, 19.5, 'weak', 2],
 [4, 60.0, 3.0, 22.5, 'weak', 2],
 [4, 60.0, 3.0, 25.5, 'weak', 2],
 [4, 63.0, 3.0, 28.5, 'weak', 2],
 [4, 60.0, 3.0, 1.5, 'weak', 3]]