In [None]:
# 180K step model works really well. But is the generation result copied from training data?
# Let's check it.
# Here we extract melody from all songs

import os
import sys
sys.path.append('/home/longshen/work/AccGen/AccGen')
from remi_z import MultiTrack
from sonata_utils import ls, save_json
from tqdm import tqdm

data_dir = '/data1/longshen/Datasets/Piano/POP909/pop909_longshen/data_key_normed'
song_ids = ls(data_dir)
midi_fps = [os.path.join(data_dir, song_id, f'{song_id}.mid') for song_id in song_ids]

res = {}
pbar = tqdm(midi_fps)
for midi_fp in pbar:
    pbar.set_description(os.path.basename(os.path.dirname(midi_fp)))
    # print(f'Processing {midi_fp}...')
    mt = MultiTrack.from_midi(midi_fp)
    mt.filter_tracks(insts=[13])
    mel_remiz = mt.to_remiz_str()
    song_id = os.path.basename(os.path.dirname(midi_fp))
    res[song_id] = mel_remiz

save_dir = '/data1/longshen/Datasets/Piano/POP909/remi_z'
save_fp = os.path.join(save_dir, 'melody_remiz_dict.json')
save_json(res, save_fp)

518:  57%|█████▋    | 491/857 [00:17<00:12, 28.31it/s]


KeyboardInterrupt: 

In [3]:
import os
import sys
sys.path.append('/home/longshen/work/AccGen/AccGen')
from remi_z import MultiTrack
from sonata_utils import ls, save_json, read_json
from tqdm import tqdm

from jiwer import wer as jiwer_wer

save_dir = '/data1/longshen/Datasets/Piano/POP909/remi_z'
save_fp = os.path.join(save_dir, 'melody_remiz_dict.json')
res = read_json(save_fp)

out_fp = '/data1/longshen/Results/AccGenResults/diffusion_prior/phr_latent/full_song/64dim_latent/lencond/tb_logs/version_0/val_gen_midis/step_200000/sample_2_bars_65.mid'
out_mt = MultiTrack.from_midi(out_fp)
mel_program_id = 87
out_mt.filter_tracks(insts=[mel_program_id])
out_remiz = out_mt.to_remiz_str()

wers = {}
pbar = tqdm(res.keys())
for song_id in pbar:
    ref_remiz = res[song_id]
    pbar.set_description(f'Computing WER for {song_id}')
    error_rate = jiwer_wer(ref_remiz, out_remiz)
    wers[song_id] = error_rate

# Get key and value of the minimum WER, top 5
sorted_wers = sorted(wers.items(), key=lambda x: x[1])
print("Top 5 closest matches (lowest WER):")
for i in range(5):
    song_id, error_rate = sorted_wers[i]
    print(f"{i+1}. Song ID: {song_id}, WER: {error_rate:.4f}")

mem_rate = 1 - sorted_wers[0][1]
print(f'Melody Memory Rate: {mem_rate:.4f}')

from evaluation.metrics import Metric
metric = Metric()
top_2_ratio = top_2_ratio = metric.top_two_ratio(out_mt)
print(f'Top 2 Ratio: {top_2_ratio:.4f}')



Computing WER for 909: 100%|██████████| 857/857 [00:00<00:00, 983.37it/s] 


Top 5 closest matches (lowest WER):
1. Song ID: 332, WER: 0.6715
2. Song ID: 221, WER: 0.6988
3. Song ID: 294, WER: 0.7014
4. Song ID: 291, WER: 0.7053
5. Song ID: 851, WER: 0.7053
Melody Memory Rate: 0.3285
Top 2 Ratio: 0.9980


In [4]:
# Test new metric functions in metrics.py
metric = Metric()
out_mt = MultiTrack.from_midi(out_fp)
res = metric.calculate_memorization_metrics_single(out_mt, mel_program_id=87)
res

{'mem_rate': 0.3289085545722714,
 'top2_ratio': 0.9932153392330383,
 'new_sample_flag': 1}

In [15]:
# Calculate metrics for a bunch of MIDI under a directory

out_dir = '/data1/longshen/Results/AccGenResults/diffusion_prior/phr_latent/full_song/64dim_latent/sec_permute/tb_logs/version_0/test_gen_midis/step_200000'

import os
import sys
sys.path.append('/home/longshen/work/AccGen/AccGen')
from remi_z import MultiTrack
from sonata_utils import ls, save_json, read_json
from tqdm import tqdm
from evaluation.metrics import Metric

from jiwer import wer as jiwer_wer

metric = Metric()

all_metrics = {}

out_fns = ls(out_dir, ext='.mid')
for out_fn in out_fns:
    out_fp = os.path.join(out_dir, out_fn)
    out_mt = MultiTrack.from_midi(out_fp)
    
    insts = out_mt.get_unique_insts()
    if 87 in insts:
        out_mt.change_instrument(87, 13)
        out_mt.shift_pitch(-12, 87)

    mel_program_id = 13
    out_mt.filter_tracks(insts=[mel_program_id])
    res = metric.calculate_memorization_metrics_single(out_mt, mel_program_id=mel_program_id)
    print(f'Results for {out_fn}: {res}')

    for k, v in res.items():
        if k not in all_metrics:
            all_metrics[k] = []
        all_metrics[k].append(v)

for k in all_metrics.keys():
    all_metrics[k] = sum(all_metrics[k]) / len(all_metrics[k])

print('Final averaged results:')
print(all_metrics)


Results for sample_0_bars_127.mid: {'mem_rate': 0.36221498371335503, 'top2_ratio': 0.970721125755291, 'new_sample_flag': 1}
Results for sample_10_bars_128.mid: {'mem_rate': 0.3568136932192232, 'top2_ratio': 0.9885257900628606, 'new_sample_flag': 1}
Results for sample_11_bars_125.mid: {'mem_rate': -0.04009762900976299, 'top2_ratio': 0.7906492495417833, 'new_sample_flag': 1}
Results for sample_12_bars_128.mid: {'mem_rate': 0.2997319034852547, 'top2_ratio': 0.9645369969783786, 'new_sample_flag': 1}
Results for sample_13_bars_127.mid: {'mem_rate': 0.1879358437935844, 'top2_ratio': 0.8923239182237755, 'new_sample_flag': 1}
Results for sample_14_bars_125.mid: {'mem_rate': 0.29836956521739133, 'top2_ratio': 0.9953814393634975, 'new_sample_flag': 1}
Results for sample_15_bars_124.mid: {'mem_rate': 0.3352697095435685, 'top2_ratio': 0.999326074243226, 'new_sample_flag': 1}
Results for sample_16_bars_122.mid: {'mem_rate': 0.392507204610951, 'top2_ratio': 0.9457942725348997, 'new_sample_flag': 1}
