In [39]:
import musdb
import museval
import torch as th
import numpy as np
from glob import glob
import torchaudio
import torchaudio.transforms as transforms
import os
from inference import *
from model import MusicSep

In [40]:
# Load the pre-trained model checkpoint
checkpoint = th.load('./models/unet/model_vocals_33.pt')

# Create an instance of your model
model = MusicSep() 

# Load the model's learned parameters from the checkpoint
model.load_state_dict(checkpoint)

# Set the model to evaluation mode
model.eval()

print("eval mode")

eval mode


In [41]:
def new_sdr(references, estimates):
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    num = th.sum(th.square(references), dim=(2, 3))
    den = th.sum(th.square(references - estimates), dim=(2, 3))
    num += delta
    den += delta
    scores = 10 * th.log10(num / den)
    return scores

In [42]:
def eval_track(references, estimates, win, hop, compute_sdr=True):
    references = references.transpose(1, 2).double()
    estimates = estimates.transpose(1, 2).double()

    new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]

    if not compute_sdr:
        return None, new_scores
    else:
        references = references.numpy()
        estimates = estimates.numpy()
        scores = museval.metrics.bss_eval(
            references, estimates,
            compute_permutation=False,
            window=win,
            hop=hop,
            framewise_filters=False,
            bsseval_sources_version=False)[:-1]
        return scores, new_scores

In [43]:
def get_torch(song_idx, idx_spec, dir = "../data/data_spect_real_513/train/vocals/"):
    paths_vocals = sorted(glob(dir + "/" + song_idx + "*.pt"))
    return th.load(paths_vocals[idx])

In [44]:
song_idx = "00100"
idx = 3
dir_v = "../data/data_spect_real_513/test/vocals"
dir_m = "../data/data_spect_real_513/test/mix"

def eval_clip(song_idx, idx, dir_gt, dir_mix):
    gt = get_torch(song_idx, idx, dir_gt)
    gt = transform2to4(gt)
    reference = get_wav(gt)
    
    mix = get_torch(song_idx, idx, dir_m)
    result = calculate(model, input_tensor = mix)
    estimate = get_wav(result)
    score = eval_track(reference.unsqueeze(0), estimate.unsqueeze(0), win = 1024, hop = 512, compute_sdr=True)
    score1 = np.array(score[0])
    score2 = np.array(score[1])
    return np.median(score1), np.median(score2)

In [45]:
song_idx = "00100"
idx = 3
dir_v = "../data/data_spect_real_513/test/vocals"
dir_m = "../data/data_spect_real_513/test/mix"

def eval_clip_given_path(gt_path, mix_path):
    gt = th.load(gt_path)
    gt = transform2to4(gt)
    reference = get_wav(gt)
    
    mix = th.load(mix_path)
    result = calculate(model, input_tensor = mix)
    estimate = get_wav(result)
    #print(th.sum(reference))
    if th.sum(reference) == 0:
        return None, None
    score = eval_track(reference.unsqueeze(0), estimate.unsqueeze(0), win = 1024, hop = 512, compute_sdr=True)
    score1 = np.array(score[0])
    score2 = np.array(score[1])
    return np.median(score1), np.median(score2)

In [46]:
song_idx = "002"
# idx = 3
dir_v = "../data/data_spect_real_513/test/vocals"
dir_m = "../data/data_spect_real_513/test/mix"
# eval_clip(song_idx, idx, dir_v, dir_m)

In [47]:
def eval_song(song_idx, dir_gt, dir_mix):
    score1 = []
    score2 = []
    spec_mix_files = sorted(glob(dir_mix + '/' + song_idx + "*.pt"))
    spec_gt_files = sorted(glob(dir_gt + '/' + song_idx + "*.pt"))
    length = len(spec_mix_files)
    for idx in range(length):
        s1, s2 = eval_clip(song_idx, idx, dir_v, dir_m)
        score1.append(s1)
        score2.append(s2)
    score1 = np.array(score1)
    score2 = np.array(score2)
    return np.median(score1), np.median(score2)

In [48]:
#eval_song(song_idx, dir_v, dir_m)

In [49]:
def eval_train(dir_gt, dir_mix):
    score1 = []
    score2 = []
    spec_mix_folder = sorted(os.listdir(dir_mix))
    spec_gt_folder = sorted(os.listdir(dir_gt))
    length = len(spec_mix_folder)
    print(length)
    for i in range(length):
        print(i, "out of", length, end='\r')
        mix_filename = spec_mix_folder[i]
        gt_filename = spec_gt_folder[i]
        mix_path = os.path.join(dir_mix, mix_filename)
        gt_path = os.path.join(dir_gt, gt_filename)
        
        s1, s2 = eval_clip_given_path(gt_path, mix_path)
        if s1 == None:
            print("SKIP            ")
            continue
        score1.append(s1)
        score2.append(s2)
        print("s1:", round(s1, 2), "   s2:", round(s2, 2), "        ")
        
    score1 = np.array(score1)
    score2 = np.array(score2)
    return np.median(score1), np.median(score2)

In [None]:
eval_train(dir_v, dir_m)

735
s1: 0.0    s2: 6.06         
s1: 8.76    s2: 3.19         
s1: 5.93    s2: 3.35         
s1: 7.3    s2: 5.33         
s1: 5.9    s2: 5.35         
s1: 10.42    s2: 5.81         
s1: 6.96    s2: 5.03         
s1: 5.28    s2: 5.38         
s1: 4.35    s2: 3.75         
s1: 7.11    s2: 5.33         
SKIP            
SKIP            
s1: nan    s2: 1.67         
s1: nan    s2: 1.0         
s1: 2.94    s2: 2.36         
s1: 0.89    s2: 0.83         
s1: 1.29    s2: 0.77         
s1: 1.78    s2: 0.88         
s1: nan    s2: 1.59         
s1: nan    s2: 2.06         
s1: nan    s2: 2.5         
s1: 1.6    s2: 1.19         
s1: 0.94    s2: 0.54         
s1: 1.74    s2: 0.51         
s1: nan    s2: 0.83         
s1: nan    s2: -2.31         
s1: nan    s2: -10.64         
s1: 0.22    s2: -1.55         
s1: 1.95    s2: 0.73         
s1: 1.87    s2: 1.0         
s1: nan    s2: 0.63         
SKIP            
SKIP            
SKIP            
SKIP            
SKIP            
SKIP            
s