In [1]:
%cd /home/dongmin/23FW-NCG/jeongganbo-omr
%load_ext autoreload
%autoreload 2

/home/dongmin/userdata/jeongganbo-omr


In [2]:
from time import time
import csv
import re
import glob
import json
from random import randint, choice, uniform, seed
from collections import defaultdict
import pickle
from pathlib import Path

from tqdm import tqdm
import matplotlib.pyplot as plt

from omegaconf import OmegaConf, DictConfig

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

from exp_utils import JeongganSynthesizer, get_img_paths
from exp_utils.model_zoo import OMRModel, TransformerOMR
from exp_utils.train_utils import Trainer, Dataset, Tokenizer, pad_collate, LabelStudioDataset, draw_low_confidence_plot

dprint = lambda d: print(json.dumps(d, indent=2))

  from .autonotebook import tqdm as notebook_tqdm


## pipline mock-up

### load model and inference

In [3]:
# random seed setting
project_root_dir = Path('.')
output_dir = project_root_dir / 'outputs'

dir_date = '2024-05-04'
dir_time = '20-31-38'

exp_dir = output_dir / dir_date / dir_time

conf = OmegaConf.load(exp_dir / '.hydra' / 'config.yaml')
device = torch.device(conf.general.device)

print('\ndata_set loading...')

note_img_path_dict = get_img_paths(project_root_dir / 'test/synth/src', ['notes', 'symbols'])

test_set = LabelStudioDataset(project_root_dir / conf.data_path.test, project_root_dir / 'jeongganbo-png/splited-pngs', remove_borders=conf.test_setting.remove_borders, is_valid=True)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False, collate_fn=pad_collate, num_workers=conf.dataloader.num_workers_load)

print('COMPLETE: data_set loading')

model_dir = exp_dir / 'model'

print('\nload tokenizer')

tokenizer_vocab_fn = model_dir / f'{conf.general.model_name}_tokenizer.txt'
tokenizer = Tokenizer(vocab_txt_fn=tokenizer_vocab_fn)

print('COMPLETE: load tokenizer')

print('\nmodel initializing...')
model = TransformerOMR(conf.model.dim, len(tokenizer.vocab), enc_depth=conf.model.enc_depth, dec_depth=conf.model.dec_depth, num_heads=conf.model.num_heads, dropout=conf.model.dropout)
model.load_state_dict(torch.load(model_dir / f'{conf.general.model_name}_HL_{conf.test_setting.target_metric}_best.pt', map_location='cpu')['model'])

tester = Trainer(model, 
                 None, #optimizer
                 None, #loss_fn
                 None, #train_loader
                 test_loader, 
                 tokenizer,
                 device=device, 
                 scheduler=None,
                 aux_loader=None,
                 aux_freq=None,
                 mix_aux=None,
                 aux_valid_loader=None,
                 wandb=None, 
                 model_name=conf.general.model_name,
                 model_save_path=model_dir,
                 checkpoint_logger=None)

print('COMPLETE: model initializing')

print('\nTesting...')

_, _, pred_tensor_list, confidence_tensor_list = tester.validate(with_confidence=True)

print('COMPELETE: Testing')


data_set loading...
COMPLETE: data_set loading

load tokenizer
COMPLETE: load tokenizer

model initializing...
COMPLETE: model initializing

Testing...


                                             

COMPELETE: Testing




### process test result

In [4]:
pred_list = []

for b_idx in range(len(pred_tensor_list)):
  pred_list += pred_tensor_list[b_idx].tolist()

pred_list = list( enumerate(pred_list) )

In [5]:
filter_specials = lambda seq: list( map(int, filter( lambda x: x not in (0, 1, 2), seq )) )

def calc_acc(gt, prd):
  gt, prd = torch.tensor(gt, dtype=torch.long), torch.tensor(prd, dtype=torch.long)
  
  if prd.shape[0] > gt.shape[0]:
    prd = prd[:gt.shape[0]]
  elif prd.shape[0] < gt.shape[0]:
    prd = torch.cat([prd, torch.zeros(gt.shape[0] - prd.shape[0])], dim=-1)
  
  num_tokens = gt.shape[0]
  num_match_token = prd == gt
  num_match_token = num_match_token[gt != 0].sum().item()
  
  return num_match_token / num_tokens

def make_gt_pred_tuple_list(ls):
  tup_list = []
  
  for didx, pred in ls:
    img, label = test_set.get_item_by_idx(didx)
    
    label_enc = tokenizer(label)[1:-1]
    pred = filter_specials(pred)
    
    acc = calc_acc(label_enc, pred)
    
    tup_list.append( (didx, img, label_enc, pred, acc) )
  
  return tup_list

pred_list = make_gt_pred_tuple_list(pred_list)

In [6]:
pred_list_long = list( filter( lambda x: len(x[2]) > 2, pred_list ) )

len(pred_list_long)

1281

In [7]:
pred_list_long = sorted(pred_list_long, key=lambda x: x[4])

### pick test cases

In [8]:
seed(1)

test_ls = []

for i in range(11):
  f = [t for t in pred_list_long if t[4] < (i + 1)*0.1 and t[4] > i*0.1]
  test_ls += [choice(f)] if len(f) > 0 else []

len(test_ls)

10

In [9]:
list(zip([1, 5, 7], [2, 10, 14]))

[(1, 2), (5, 10), (7, 14)]

In [10]:
test_ls

[(402,
  array([[[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],
  
         [[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],
  
         [[255, 255, 255],
          [255, 255, 255],
          [254, 254, 254],
          ...,
          [254, 254, 254],
          [255, 255, 255],
          [255, 255, 255]],
  
         ...,
  
         [[255, 255, 255],
          [255, 255, 255],
          [254, 254, 254],
          ...,
          [254, 254, 254],
          [255, 255, 255],
          [255, 255, 255]],
  
         [[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],
  
         [[255, 255, 255],
          [255, 255, 255],
          [255, 2

## New metrics
* NER: note error rate
* PER: position error rate
* NPER: note-position pair error rate

In [11]:
list( map(lambda et: (et[0], et[1]), enumerate([1, 2, 3, 4]) ))

[(0, 1), (1, 2), (2, 3), (3, 4)]

In [12]:
'느나니'[0]

'느'

In [13]:
def get_notes_and_positions(label):
  pattern = r'([^_\s:]+|_+[^_\s:]+|:\d+|[-])'
  
  def clean_findings(fds):
    int_idx = []
    
    for i, t in enumerate(fds):
      try:
        int(t[1:]) # if str p cannot be cast into integer goto except block
        int_idx.append(i)
      except:
        pass
    
    if len(int_idx) < 1:
      return [fds]
    
    split_fds = []
    start_idx = 0
    
    for i in int_idx:
      end_idx = i + 1
      split_fds.append(fds[start_idx:end_idx])
      start_idx = end_idx
    
    return split_fds

  token_groups = label.split()

  notes = []
  positions = []
  ornaments = []

  for group in token_groups:
    findings = re.findall(pattern, group)
    findings = clean_findings(findings)
    
    for fd in findings:
      note = fd[0]
      position = fd[-1]
      ornament = fd[1:-1]
      
      if note[0] == '_':
        ornament = [note] + ornament
        note = '<pad>'
        
      notes.append(note)
      positions.append(position[1:])
      ornaments.append(ornament)

  return notes, positions, ornaments

for test_case in test_ls:
  label, pred = test_case[2:4]
  
  label_dec = tokenizer.decode( label )
  pred_dec = tokenizer.decode( pred )
  
  label_note, label_pos, label_ornament = get_notes_and_positions(label_dec)
  pred_note, pred_pos, pred_ornament = get_notes_and_positions(pred_dec)
  
  label_note_pos = list( zip(label_note, label_pos) )
  pred_note_pos = list( zip(pred_note, pred_pos) )
  
  map_ornamnet_pos = lambda ol, pl: [(o, p) for i, p in enumerate(pl) for o in ol[i]]
  map_ornament_note_pos = lambda ol, tl: [ (o, *t) for i, t in enumerate(tl) for o in ol[i]]
  
  print(label_dec)
  print(map_ornamnet_pos(label_ornament, label_pos))
  print(map_ornament_note_pos(label_ornament, label_note_pos))
  
  print()

청남_노네:2 -:4 니나*:6 청임:8
[('_노네', '2')]
[('_노네', '청남', '2')]

고_덧길이표:10 태_반길이표:11
[('_덧길이표', '10'), ('_반길이표', '11')]
[('_덧길이표', '고', '10'), ('_반길이표', '태', '11')]

황:2 -:5 태_니레:8
[('_니레', '8')]
[('_니레', '태', '8')]

황_루러표:5
[('_루러표', '5')]
[('_루러표', '황', '5')]

노:2 -:4 니_미는표:6 느나:8
[('_미는표', '6')]
[('_미는표', '니', '6')]

황:2 -_미는표:5 배중:8
[('_미는표', '5')]
[('_미는표', '-', '5')]

황_자출:5
[('_자출', '5')]
[('_자출', '황', '5')]

태:2 -:5 -:7 흘림표_니:9
[('_니', '9')]
[('_니', '흘림표', '9')]

-:2 -_미는표:5 태:8
[('_미는표', '5')]
[('_미는표', '-', '5')]

배임:2 -:4 태:6 배임:7 황_루러표:9
[('_루러표', '9')]
[('_루러표', '황', '9')]



In [14]:
test_lbs = [
  '임_노니로:1 중:3 -:5 무_느니-르_노니로:8_느니-르_노니로:8',
  '황:1 태:3 -:5 -:7 배임_노니로:9_노니로:9',
  '황:1 태:3 -_흘림표:5 -:8 니:9니:9',
  '태:1 황:3 -:4 니:6 배남_흘림표:8_흘림표:8',
]

for lb in test_lbs:
  print(lb)
  print(get_notes_and_positions(lb))

# print(test_lbs[2])
# get_notes_and_positions(test_lbs[2])

임_노니로:1 중:3 -:5 무_느니-르_노니로:8_느니-르_노니로:8
(['임', '중', '-', '무', '<pad>'], ['1', '3', '5', '8', '8'], [['_노니로'], [], [], ['_느니-르', '_노니로'], ['_느니-르', '_노니로']])
황:1 태:3 -:5 -:7 배임_노니로:9_노니로:9
(['황', '태', '-', '-', '배임', '<pad>'], ['1', '3', '5', '7', '9', '9'], [[], [], [], [], ['_노니로'], ['_노니로']])
황:1 태:3 -_흘림표:5 -:8 니:9니:9
(['황', '태', '-', '-', '니', '니'], ['1', '3', '5', '8', '9', '9'], [[], [], ['_흘림표'], [], [], []])
태:1 황:3 -:4 니:6 배남_흘림표:8_흘림표:8
(['태', '황', '-', '니', '배남', '<pad>'], ['1', '3', '4', '6', '8', '8'], [[], [], [], [], ['_흘림표'], ['_흘림표']])


In [15]:
def filter_pos_list(pl):
  pl_f = []
  
  for p in pl:
    try:
      int(p) # if str p cannot be cast into integer goto except block
      pl_f.append(p)
    except:
      pass
  
  return pl_f


def tok2idx_list(tl, is_pos=False):
  prefix = ':' if is_pos else '' 
  return [ tokenizer.tok2idx[prefix + t] for t in tl ]

def calc_metric(gt, prd): # +) [꾸밈](GT 꾸밈 0개 이상일 때), [꾸밈, 위치], [본음, 꾸밈, 위치]
  gt_tok_dict = defaultdict(int)
  
  for tok in gt:
    gt_tok_dict[tok] += 1
  
  num_match = 0
  
  for tok in prd:
    if gt_tok_dict[tok] > 0:
      num_match += 1
      gt_tok_dict[tok] -= 1
  
  eps = 1e-8
  
  gt_len = len( [ tok for tok in gt if tok != '<pad>' or tok != 0 ] )
  recall = num_match / gt_len if gt_len > 0 else 0
  
  precision = num_match / len(prd) if len(prd) > 0 else 0
  f1 = 2 * (precision * recall) / (precision + recall + eps)
  
  return precision, recall, f1

def calc_all_metrics(gt, prd, encode=True):
  gt_dec = tokenizer.decode( gt )
  prd_dec = tokenizer.decode( prd )

  gt_note, gt_pos, gt_ornament = get_notes_and_positions( gt_dec )
  prd_note, prd_pos, prd_ornament = get_notes_and_positions( prd_dec )
  
  gt_pos, prd_pos = ( filter_pos_list(pl) for pl in (gt_pos, prd_pos) )
  
  if encode:
    gt_note, prd_note = ( tok2idx_list(nl) for nl in ( gt_note, prd_note ) )
    gt_pos, prd_pos = ( tok2idx_list(pl, is_pos=True) for pl in ( gt_pos, prd_pos ) )
    gt_ornament, prd_ornament = ( [ tok2idx_list(osl) for osl in ol ] for ol in ( gt_ornament, prd_ornament ))
  
  gt_ornament_flat, prd_ornament_flat = sum(gt_ornament, []), sum(prd_ornament, [])

  gt_note_pos = list( zip(gt_note, gt_pos) )
  prd_note_pos = list( zip(prd_note, prd_pos) )
  
  map_ornament_pos = lambda ol, pl: [(o, p) for i, p in enumerate(pl) for o in ol[i]]
  map_ornament_note_pos = lambda ol, tl: [ (o, *t) for i, t in enumerate(tl) for o in ol[i]]
  
  gt_ornament_pos = map_ornament_pos(gt_ornament, gt_pos)
  prd_ornament_pos = map_ornamnet_pos(prd_ornament, prd_pos)
  
  gt_ornament_note_pos = map_ornament_note_pos(gt_ornament, gt_note_pos)
  prd_ornament_note_pos = map_ornament_note_pos(prd_ornament, prd_note_pos)
  
  ner = calc_metric(gt_note, prd_note)
  per = calc_metric(gt_pos, prd_pos)
  nper = calc_metric(gt_note_pos, prd_note_pos)
  
  oer = calc_metric(gt_ornament_flat, prd_ornament_flat)
  oper = calc_metric(gt_ornament_pos, prd_ornament_pos)
  onper = calc_metric(gt_ornament_note_pos, prd_ornament_note_pos)
  
  return ((gt_note, gt_pos, gt_ornament, gt_note_pos, gt_ornament_pos, gt_ornament_note_pos), (prd_note, prd_pos, prd_ornament, prd_note_pos, prd_ornament_pos, prd_ornament_note_pos)), (ner, per, nper, oer, oper, onper)

test_index = 1
label, pred = test_ls[test_index][2:4]

((label_note, label_pos, label_ornament, label_note_pos, label_ornament_pos, label_ornament_note_pos), (pred_note, pred_pos, pred_ornament, pred_note_pos, pred_ornament_pos, pred_ornament_note_pos)), (ner, per, nper, oer, oper, onper) = calc_all_metrics(label, pred, encode=False)

print(calc_metric(label_note, pred_note), label_note, pred_note)
print(calc_metric(label_pos, pred_pos), label_pos, pred_pos)
print(calc_metric(label_note_pos, pred_note_pos), label_note_pos, pred_note_pos)
print(calc_metric(sum(label_ornament, []), sum(pred_ornament, [])), sum(label_ornament, []), sum(pred_ornament, []))
print(calc_metric(label_ornament_pos, pred_ornament_pos), label_ornament_pos, pred_ornament_pos)
print(calc_metric(label_ornament_note_pos, pred_ornament_note_pos), label_ornament_note_pos, pred_ornament_note_pos)

(1.0, 1.0, 0.999999995) ['고', '태'] ['고', '태']
(1.0, 1.0, 0.999999995) ['10', '11'] ['10', '11']
(1.0, 1.0, 0.999999995) [('고', '10'), ('태', '11')] [('고', '10'), ('태', '11')]
(1.0, 0.5, 0.6666666622222223) ['_덧길이표', '_반길이표'] ['_반길이표']
(1.0, 0.5, 0.6666666622222223) [('_덧길이표', '10'), ('_반길이표', '11')] [('_반길이표', '11')]
(1.0, 0.5, 0.6666666622222223) [('_덧길이표', '고', '10'), ('_반길이표', '태', '11')] [('_반길이표', '태', '11')]


## Test new metrics

### test: with decoded tokens

In [16]:
for idx, test_case in enumerate(test_ls):
  label, pred = test_case[2:4]

  label_dec = tokenizer.decode( label )
  pred_dec = tokenizer.decode( pred )

  ((label_note, label_pos, label_ornament, label_note_pos, label_ornament_pos, label_ornament_note_pos), (pred_note, pred_pos, pred_ornament, pred_note_pos, pred_ornament_pos, pred_ornament_note_pos)), (ner, per, nper, oer, oper, onper) = calc_all_metrics(label, pred, encode=False)
  
  round_tuple_el = lambda t: tuple(map(lambda x: round(x, 4), t))
  
  ner, per, nper, oer, oper, onper = ( round_tuple_el(mv) for mv in (ner, per, nper, oer, oper, onper) )

  print(f'#{idx} {round(test_case[4], 4)}')
  print(label_dec)
  print(pred_dec)
  print(f'NER: {ner}')
  print(label_note)
  print(pred_note)
  print(f'PER: {per}')
  print(label_pos)
  print(pred_pos)
  print(f'NPER: {nper}')
  print(label_note_pos)
  print(pred_note_pos)
  print(f'OER: {oer}')
  print(label_ornament)
  print(pred_ornament)
  print(f'OPER: {oper}')
  print(label_ornament_pos)
  print(pred_ornament_pos)
  print(f'ONPER: {onper}')
  print(label_ornament_note_pos)
  print(pred_ornament_note_pos)
  print()

#0 0.0833
청남_노네:2 -:4 니나*:6 청임:8
청남:2 -:4 니나*:6 청임:8
NER: (1.0, 1.0, 1.0)
['청남', '-', '니나*', '청임']
['청남', '-', '니나*', '청임']
PER: (1.0, 1.0, 1.0)
['2', '4', '6', '8']
['2', '4', '6', '8']
NPER: (1.0, 1.0, 1.0)
[('청남', '2'), ('-', '4'), ('니나*', '6'), ('청임', '8')]
[('청남', '2'), ('-', '4'), ('니나*', '6'), ('청임', '8')]
OER: (0, 0.0, 0.0)
[['_노네'], [], [], []]
[[], [], [], []]
OPER: (0, 0.0, 0.0)
[('_노네', '2')]
[]
ONPER: (0, 0.0, 0.0)
[('_노네', '청남', '2')]
[]

#1 0.1429
고_덧길이표:10 태_반길이표:11
고:10 태_반길이표:11
NER: (1.0, 1.0, 1.0)
['고', '태']
['고', '태']
PER: (1.0, 1.0, 1.0)
['10', '11']
['10', '11']
NPER: (1.0, 1.0, 1.0)
[('고', '10'), ('태', '11')]
[('고', '10'), ('태', '11')]
OER: (1.0, 0.5, 0.6667)
[['_덧길이표'], ['_반길이표']]
[[], ['_반길이표']]
OPER: (1.0, 0.5, 0.6667)
[('_덧길이표', '10'), ('_반길이표', '11')]
[('_반길이표', '11')]
ONPER: (1.0, 0.5, 0.6667)
[('_덧길이표', '고', '10'), ('_반길이표', '태', '11')]
[('_반길이표', '태', '11')]

#2 0.2222
황:2 -:5 태_니레:8
황:10 태_니레:11
NER: (1.0, 0.6667, 0.8)
['황', '-', '태']
['황', '태']
PER: (0

### test: with encoded indices

In [18]:
for idx, test_case in enumerate(test_ls):
  label, pred = test_case[2:4]

  label_dec = tokenizer.decode( label )
  pred_dec = tokenizer.decode( pred )

  ((label_note, label_pos, label_ornament, label_note_pos, label_ornament_pos, label_ornament_note_pos), (pred_note, pred_pos, pred_ornament, pred_note_pos, pred_ornament_pos, pred_ornament_note_pos)), (ner, per, nper, oer, oper, onper) = calc_all_metrics(label, pred, encode=True)
  
  round_tuple_el = lambda t: tuple(map(lambda x: round(x, 4), t))
  
  ner, per, nper, oer, oper, onper = ( round_tuple_el(mv) for mv in (ner, per, nper, oer, oper, onper) )

  print(f'#{idx} {round(test_case[4], 4)}')
  print(label_dec)
  print(pred_dec)
  print(f'NER: {ner}')
  print(label_note)
  print(pred_note)
  print(f'PER: {per}')
  print(label_pos)
  print(pred_pos)
  print(f'NPER: {nper}')
  print(label_note_pos)
  print(pred_note_pos)
  print(f'OER: {oer}')
  print(label_ornament)
  print(pred_ornament)
  print(f'OPER: {oper}')
  print(label_ornament_pos)
  print(pred_ornament_pos)
  print(f'ONPER: {onper}')
  print(label_ornament_note_pos)
  print(pred_ornament_note_pos)
  print()

#0 0.0833
청남_노네:2 -:4 니나*:6 청임:8
청남:2 -:4 니나*:6 청임:8
NER: (1.0, 1.0, 1.0)
[119, 4, 89, 121]
[119, 4, 89, 121]
PER: (1.0, 1.0, 1.0)
[12, 14, 16, 18]
[12, 14, 16, 18]
NPER: (1.0, 1.0, 1.0)
[(119, 12), (4, 14), (89, 16), (121, 18)]
[(119, 12), (4, 14), (89, 16), (121, 18)]
OER: (0, 0.0, 0.0)
[[34], [], [], []]
[[], [], [], []]
OPER: (0, 0.0, 0.0)
[(34, 12)]
[]
ONPER: (0, 0.0, 0.0)
[(34, 119, 12)]
[]

#1 0.1429
고_덧길이표:10 태_반길이표:11
고:10 태_반길이표:11
NER: (1.0, 1.0, 1.0)
[76, 127]
[76, 127]
PER: (1.0, 1.0, 1.0)
[6, 7]
[6, 7]
NPER: (1.0, 1.0, 1.0)
[(76, 6), (127, 7)]
[(76, 6), (127, 7)]
OER: (1.0, 0.5, 0.6667)
[[52], [58]]
[[], [58]]
OPER: (1.0, 0.5, 0.6667)
[(52, 6), (58, 7)]
[(58, 7)]
ONPER: (1.0, 0.5, 0.6667)
[(52, 76, 6), (58, 127, 7)]
[(58, 127, 7)]

#2 0.2222
황:2 -:5 태_니레:8
황:10 태_니레:11
NER: (1.0, 0.6667, 0.8)
[140, 4, 127]
[140, 127]
PER: (0.0, 0.0, 0.0)
[12, 15, 18]
[6, 7]
NPER: (0.0, 0.0, 0.0)
[(140, 12), (4, 15), (127, 18)]
[(140, 6), (127, 7)]
OER: (1.0, 1.0, 1.0)
[[], [], [49]]
[[], 

In [19]:
for q, v in enumerate( (1, 2, 3, 4) ):
  print(q, v)

0 1
1 2
2 3
3 4


In [20]:
temp = [1, 2, 3]
temp *= 3

temp[0] = 4

temp

[4, 2, 3, 1, 2, 3, 1, 2, 3]

## Test Pipeline

In [23]:
def test(project_root_dir, exp_dir, csv_path):
  conf = OmegaConf.load(exp_dir / '.hydra' / 'config.yaml')
  device = torch.device(conf.general.device)

  print('\n...data_set loading...')
  test_set = LabelStudioDataset(project_root_dir / conf.data_path.test, project_root_dir / 'jeongganbo-png/splited-pngs', remove_borders=conf.test_setting.remove_borders, is_valid=True)
  test_loader = DataLoader(test_set, batch_size=1000, shuffle=False, collate_fn=pad_collate, num_workers=conf.dataloader.num_workers_load)

  print('COMPLETE: data_set loading')

  model_dir = exp_dir / 'model'

  print('\n...tokenizer loading...')

  tokenizer_vocab_fn = model_dir / f'{conf.general.model_name}_tokenizer.txt'
  tokenizer = Tokenizer(vocab_txt_fn=tokenizer_vocab_fn)
  
  test_set.tokenizer = tokenizer

  print('COMPLETE: load tokenizer')

  print('\n...model initializing...')
  model = TransformerOMR(conf.model.dim, len(tokenizer.vocab), enc_depth=conf.model.enc_depth, dec_depth=conf.model.dec_depth, num_heads=conf.model.num_heads, dropout=conf.model.dropout)
  model.load_state_dict(torch.load(model_dir / f'{conf.general.model_name}_HL_{conf.test_setting.target_metric}_best.pt', map_location='cpu')['model'])

  tester = Trainer(model, 
                  None, #optimizer
                  None, #loss_fn
                  None, #train_loader
                  test_loader, 
                  tokenizer,
                  device=device, 
                  scheduler=None,
                  aux_loader=None,
                  aux_freq=None,
                  mix_aux=None,
                  aux_valid_loader=None,
                  wandb=None, 
                  model_name=conf.general.model_name,
                  model_save_path=model_dir,
                  checkpoint_logger=None)

  print('COMPLETE: model initializing')

  print('\n...testing...')

  _, metric_dict, pred_tensor_list, _ = tester.validate(with_confidence=True)
  
  print(metric_dict)
  print(conf.test_setting.target_metric)

  print('COMPLETE: Testing')
  
  print('\n...processing test result...')
  
  pred_list = []

  for b_idx in range(len(pred_tensor_list)):
    pred_list += pred_tensor_list[b_idx].tolist()

  pred_list = list( enumerate(pred_list) )
  pred_list = make_gt_pred_tuple_list(pred_list)
  
  print('COMPLETE: processing test result')
  
  print('\n...calculating metrics...')
  
  num_total_case = 0
  num_metrics = 6
  metric_ls = [ [] for _ in range(6) ] # ner, per, nper, oer, oper, onper
  
  for case in pred_list:
    label, pred = case[2:4]
    
    _, new_metric = calc_all_metrics(label, pred, encode=True)
    
    for m_idx, nmt in enumerate(new_metric):
      metric_ls[m_idx].append(nmt)    
    
    num_total_case += 1
  
  print(num_total_case)
  print(len(metric_ls)) # should be 6
  print(len(metric_ls[1])) # should be the same as num_total_case
  print(len(metric_ls[0][0])) # 3
  
  metric_values = []
  
  for mvl in metric_ls:
    avg_ls = []
    
    for i in range(3):
      [mt[i] for mt in mvl]
      avg_ls.append( sum([mt[i] for mt in mvl]) / num_total_case )
    
    metric_values.append(tuple(avg_ls))
  
  exact_all = metric_dict[conf.test_setting.target_metric]
  
  print([ mv[2] for mv in metric_values ])
  
  with open(csv_path, 'a', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    # writer.writerow([conf.general.model_name, metric_dict[conf.test_setting.target_metric]] + metric_values)
    writer.writerow( [conf.general.model_name, exact_all] + [mv[2] for mv in metric_values] )
  
  print('COMPLETE: calculating metrics')

project_root_dir = Path('.')
output_dir = project_root_dir / 'outputs'
dir_date = '2024-05-04'
dir_time = '20-31-38'

exp_dir = output_dir / dir_date / dir_time  

csv_path = project_root_dir / 'test' / 'new_metric_test.csv'

pred_list = test(project_root_dir, exp_dir, csv_path)


...data_set loading...
COMPLETE: data_set loading

...tokenizer loading...
COMPLETE: load tokenizer

...model initializing...
COMPLETE: model initializing

...testing...


                                             

{'exact_all': 0.8610526315789473, 'exact_pos': 0.9450827067669173, 'exact_note': 0.9392631578947368, 'exact_length': 0.9500225563909774, 'note_pitch': 0.9958318735678708, 'note_pcalss': 0.9977233327777721}
exact_all
COMPLETE: Testing

...processing test result...
COMPLETE: processing test result

...calculating metrics...
1532
6
1532
3
[0.9858388950245259, 0.985482524532533, 0.9781987110326292, 0.5130657067076098, 0.49333115500794245, 0.4911553499724041]
COMPLETE: calculating metrics


In [143]:
temp = []

for i in range(3):
  temp.append((1, 2, 3))

len(temp)

3

In [161]:
[()] * 6

[(), (), (), (), (), ()]

In [28]:
num_total_case = 0
num_metrics = 6
metric_ls = [()] * num_metrics # ner, per, nper, oer, oper, onper

for case in pred_list[:2]:
  label, pred = case[2:4]
  
  print(tokenizer.decode(label))
  print(tokenizer.decode(pred))
  
  _, new_metric = calc_all_metrics(label, pred, encode=True)
  print(new_metric)
  
  for m_idx in range(num_metrics):
    m_t = new_metric[m_idx]
    m_ls = list(metric_ls[m_idx])
    m_ls += [m_t]
    
    metric_ls[m_idx] = tuple(m_ls)
  
  num_total_case += 1

print(num_total_case)
print(len(metric_ls)) # should be 6
print(len(metric_ls[0])) # should be the same as num_total_case
print(len(metric_ls[0][0])) # 3
print(metric_ls[0])

metric_values = []

for mvl in metric_ls:
  avg_ls = []
  
  for i in range(3):
    [mt[i] for mt in mvl]
    avg_ls.append( sum([mt[i] for mt in mvl]) / num_total_case )
  
  metric_values.append(tuple(avg_ls))

황:5
황:5
((1.0, 1.0, 0.999999995), (1.0, 1.0, 0.999999995), (1.0, 1.0, 0.999999995), (0, 0, 0.0), (0, 0, 0.0), (0, 0, 0.0))
겹요성표:5
겹요성표:5
((1.0, 1.0, 0.999999995), (1.0, 1.0, 0.999999995), (1.0, 1.0, 0.999999995), (0, 0, 0.0), (0, 0, 0.0), (0, 0, 0.0))
2
6
2
3
((1.0, 1.0, 0.999999995), (1.0, 1.0, 0.999999995))


In [18]:
dir_dates = ['2024-05-03', '2024-05-04']
dir_times = [
  [
    '21-55-11', 
    '22-41-31', 
    '23-27-48', 
  ],
  [
    '00-14-10', 
    '01-48-46', 
    '03-22-26', 
    '04-55-03', 
    '06-29-30', 
    '08-02-58', 
    '09-35-37', 
    '11-10-10', 
    '12-43-58', 
    '14-16-32', 
    '15-50-51', 
    '17-24-27', 
    '18-57-02', 
    '20-31-38', 
    '01-01-26', 
    '02-36-04', 
    '04-08-48', 
    '05-42-17', 
    '07-16-44', 
    '08-49-17', 
    '10-22-54', 
    '11-57-33', 
    '13-30-12', 
    '15-03-40', 
    '16-38-09', 
    '18-10-43', 
    '19-44-22',     
    '20-31-38',
  ]
]

In [21]:
from test_new_metric import test as new_test

project_root_dir = Path('.')
output_dir = project_root_dir / 'outputs'
csv_path = project_root_dir / 'test' / 'new_metric_test.csv'

dir_date, dir_time = (dir_dates[1], dir_times[1][-1])

exp_dir = output_dir / dir_date / dir_time  
new_test(project_root_dir, exp_dir, csv_path)


...tokenizer loading...
COMPLETE: load tokenizer

...data_set loading...
COMPLETE: data_set loading

...model initializing...
COMPLETE: model initializing

...testing...


                                             

COMPLETE: Testing

...processing test result...
COMPLETE: processing test result

...calculating metrics...
COMPLETE: calculating metrics


In [24]:
from test_new_metric import test as new_test

project_root_dir = Path('.')
output_dir = project_root_dir / 'outputs'
csv_path = project_root_dir / 'test' / 'new_metric_test.csv'

for date_idx, dir_date in enumerate(dir_dates):
  for dir_time in dir_times[date_idx]:
    
    exp_dir = output_dir / dir_date / dir_time  
    new_test(project_root_dir, exp_dir, csv_path)

TypeError: 'enumerate' object is not subscriptable