In [1]:
%cd ..
%load_ext autoreload
%autoreload 2

/home/teo/userdata/SejongMusic


In [15]:
import torch
from pathlib import Path
from omegaconf import OmegaConf

from sejong_music.jg_code import JeongganDataset, JeongganTokenizer
from sejong_music.model_zoo import JeongganTransSeq2seq
from sejong_music.inference import JGInferencer
from sejong_music.jg_to_staff_converter import JGToStaffConverter, BeatToGenConverter

checkpoint_dir = Path('wandb/run-20240410_181816-zoguj9tp/files/checkpoints')
config_path = checkpoint_dir / 'config.yaml'
vocab_path = checkpoint_dir / 'tokenizer_vocab.json'
model_path = checkpoint_dir / 'inst_0/iter17400_model.pt'

config = OmegaConf.load(config_path)
dataset = JeongganDataset(is_valid=True, use_offset=True)
tokenizer = JeongganTokenizer(None, None, json_fn=vocab_path)
dataset.tokenizer = tokenizer
model = JeongganTransSeq2seq(tokenizer, config.model)
inferencer = JGInferencer(model, is_condition_shifted=True, is_orch=True, top_p=0.7, temperature=0.9)
decoder = JGToStaffConverter()

state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)
model.cuda()
model.eval()
None

In [4]:
beat_dataset = JeongganDataset(is_valid=True, use_offset=True)
gen_dataset = JeongganDataset(is_valid=True, use_offset=False)

beat_dataset[0]

(tensor([[  1,   1,   1,   1, 112],
         [  5, 118, 132, 152, 113],
         [ 35, 120, 132, 152, 113],
         ...,
         [  3, 120, 146, 155, 116],
         [  4, 118, 147, 155, 116],
         [  2,   2,   2,   2, 112]]),
 tensor([[  1, 118, 132, 152, 112],
         [  5, 120, 132, 152, 112],
         [ 35, 120, 132, 152, 112],
         ...,
         [ 11, 126, 147, 155, 112],
         [ 74, 126, 147, 155, 112],
         [  4, 119, 132, 156, 112]]),
 tensor([  5,  35,  61,   3,   5,  34,  82,   3,  11,  85,   3,   5,  30,   3,
          11, 100,   4,   5,  35,  61,   3,   5,  34,  82,  11,  83,   3,   5,
          38,  11,  34,   3,   5,  39,  61,   3,   6,  83,  11,  38,   3,   5,
          34,  69,  11,  83,   3,   5,  38,   7, 100,  11,  38,  95,   3,   3,
           3,   3,  11,  35,   3,   5,  34,  69,   7,  38,   3,  11,  34,  69,
           3,   5,  39,  61,   3,   5,  39,  84,  11,  38,   3,   5,  34,  69,
           6,  38,  11,  42,  77,   4,   5,  38,   3,  11,  85

In [24]:
sample, tgt, shifted_tgt = dataset[10]
src, output, (attn, output_tensor, _) = inferencer.inference(sample, 'daegeum')

shifted_tgt = inferencer.beat2gen(tokenizer.decode(shifted_tgt))
shifted_tgt = [x for x in shifted_tgt if x != '']
out_decoded = [x[0] for x in output]
src_gen = inferencer.beat2gen([x[0] for x in src])

                                                   

In [25]:
sum([1 for x in out_decoded if x in ('|', '\n')]), sum([1 for x in shifted_tgt if x in ('|', '\n')])

(64, 64)

In [17]:
from sejong_music.eval_metrics import per_jg_note_acc, onset_f1
target_part_idx = 'daegeum'

jg_note_acc = per_jg_note_acc(out_decoded, shifted_tgt, inst=target_part_idx, tokenizer=inferencer.tokenizer)
f1, prec, recall = onset_f1(out_decoded, shifted_tgt, inst=target_part_idx, tokenizer=inferencer.tokenizer)

jg_note_acc, f1

AssertionError: Length of pred and gt must be the same. (384, 6) != (366, 6)

In [25]:
note_dict, stream = decoder.convert_inference_result(output, src, inferencer.tokenizer.decode(shifted_tgt))

TypeError: list indices must be integers or slices, not str

In [22]:
from sejong_music.eval_metrics import per_jg_note_acc, onset_f1
target_part_idx = 'daegeum'

jg_note_acc = per_jg_note_acc(out_decoded, shifted_tgt, inst=target_part_idx, tokenizer=inferencer.tokenizer)
f1, prec, recall = onset_f1(out_decoded, shifted_tgt, inst=target_part_idx, tokenizer=inferencer.tokenizer)


IndexError: index 360 is out of bounds for axis 0 with size 360

In [14]:
inferencer.beat2gen(tokenizer.decode(shifted_tgt)).split(' ')

[':0',
 '태',
 '나니로',
 '|',
 ':2',
 '황',
 '니나',
 ':8',
 '니나*',
 '|',
 ':2',
 '중',
 ':8',
 '황',
 '|',
 ':0',
 '임',
 '나니로',
 '|',
 ':6',
 '니나*',
 ':8',
 '중',
 '|',
 ':2',
 '황',
 '노니로',
 ':8',
 '니나*',
 '|',
 ':2',
 '중',
 ':5',
 '쉼표',
 ':8',
 '중',
 '미는표',
 '|',
 '',
 '|',
 '',
 '|',
 '',
 '|',
 ':8',
 '태',
 '|',
 ':2',
 '황',
 '노니로',
 ':5',
 '중',
 '|',
 ':8',
 '황',
 '노니로',
 '|',
 ':0',
 '임',
 '나니로',
 '|',
 ':2',
 '임',
 '니레',
 ':8',
 '중',
 '|',
 ':2',
 '황',
 '노니로',
 ':6',
 '중',
 ':8',
 '무',
 '느니-르',
 '\n',
 ':0',
 '중',
 '|',
 ':8',
 '니레나',
 '|',
 ':2',
 '황',
 '노니로',
 ':8',
 '쉼표',
 '|',
 ':3',
 '중',
 ':6',
 '니',
 ':7',
 '무',
 ':9',
 '노',
 '|',
 ':2',
 '청황',
 ':8',
 '니나*',
 '|',
 ':1',
 '임',
 '노니로',
 ':3',
 '중',
 ':5',
 '무',
 '느니-르',
 '노니로',
 '|',
 ':0',
 '중',
 '|',
 ':2',
 '황',
 '노니로',
 ':8',
 '태',
 '니레',
 '|',
 ':0',
 '배임',
 '|',
 ':2',
 '배임',
 '니나',
 ':8',
 '니나*',
 '|',
 ':2',
 '황',
 ':8',
 '니나*',
 '|',
 ':0',
 '중',
 '|',
 ':9',
 '니나*',
 '|',
 ':2',
 '황',
 '노니로',
 ':8',
 '태',
 '니레',
 '|',
 '

In [8]:
out_decoded

[[':0', 'prevbeat:0', 'jg:0', 'gak:0', 'daegeum'],
 ['태', 'prevbeat:0', 'jg:0', 'gak:0', 'daegeum'],
 ['나니로', 'prevbeat:0', 'jg:0', 'gak:0', 'daegeum'],
 ['|', 'prev|', 'jg:1', 'gak:0', 'daegeum'],
 [':2', 'prevbeat:0', 'jg:1', 'gak:0', 'daegeum'],
 ['황', 'prevbeat:0', 'jg:1', 'gak:0', 'daegeum'],
 ['니나', 'prevbeat:0', 'jg:1', 'gak:0', 'daegeum'],
 [':8', 'prevbeat:2/3', 'jg:1', 'gak:0', 'daegeum'],
 ['니나*', 'prevbeat:2/3', 'jg:1', 'gak:0', 'daegeum'],
 ['|', 'prev|', 'jg:2', 'gak:0', 'daegeum'],
 [':2', 'prevbeat:0', 'jg:2', 'gak:0', 'daegeum'],
 ['중', 'prevbeat:0', 'jg:2', 'gak:0', 'daegeum'],
 [':9', 'prevbeat:5/6', 'jg:2', 'gak:0', 'daegeum'],
 ['황', 'prevbeat:5/6', 'jg:2', 'gak:0', 'daegeum'],
 ['|', 'prev|', 'jg:3', 'gak:0', 'daegeum'],
 [':0', 'prevbeat:0', 'jg:3', 'gak:0', 'daegeum'],
 ['임', 'prevbeat:0', 'jg:3', 'gak:0', 'daegeum'],
 ['나니로', 'prevbeat:0', 'jg:3', 'gak:0', 'daegeum'],
 ['|', 'prev|', 'jg:4', 'gak:0', 'daegeum'],
 [':6', 'prevbeat:1/2', 'jg:4', 'gak:0', 'daegeum