In [1]:
from typing import Any
import numpy as np
import random
import os
import argparse
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import pytorch_lightning as pl
from lightning.pytorch.accelerators import find_usable_cuda_devices
from asteroid.data import DNSDataset,LibriMix
from asteroid.models import DCCRNet, DCCRNet_mini
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from tools_for_model import near_avg_index, max_index, min_index, Bar
import yaml
from pprint import pprint
from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict
from asteroid.metrics import get_metrics
from dataloader import create_dataloader
from tools_for_model import cal_pesq, cal_stoi
from torch_stoi import NegSTOILoss
from asteroid.utils import tensors_to_device
from asteroid.dsp.normalization import normalize_estimates
from framework import MultiResolutionSTFTLoss, SPKDLoss, build_review_kd
import feature_extraction
import config as cfg
from asteroid.metrics import WERTracker, MockWERTracker
import pandas as pd

In [2]:
data_set = LibriMix(
            csv_dir='/root/NTH_student/Speech_Enhancement_new/knowledge_distillation_CLSKD/data/wav16k/min/dev',
            task='enh_single',
            sample_rate=16000,
            n_src=1,
            segment=3,
        )

Drop 0 utterances from 3000 (shorter than 3 seconds)


In [3]:
train_loader = create_dataloader(mode='valid',dataset=data_set)
device = torch.device('cpu')

COMPUTE_METRICS = ["stoi"]
#COMPUTE_METRICS = ["si_sdr", "sdr", "sir", "sar", "stoi"]
wer_tracker = (MockWERTracker())

In [4]:
X,Y = 0,0
i = 0
for x,y in Bar(train_loader):
    if i == 2: break
    X=x
    Y=y
    i += 1
#model = DCCRNet.from_pretrained('JorisCos/DCCRNet_Libri1Mix_enhsingle_16k')
model = DCCRNet_mini.from_pretrained('/root/NTH_student/Speech_Enhancement_new/knowledge_distillation_CLSKD/checkpoint/the_best_model.pth')
loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

  96/3000: [=>................................................] - ETA 4.7s

In [5]:
model_device = next(model.parameters()).device

In [9]:
n_save_ex = 5
sample_rate = 16000
ex_save_dir = '/root/NTH_student/Speech_Enhancement_new/knowledge_distillation_CLSKD/example_CLSKD'
series_list = []
torch.no_grad().__enter__()
save_idx = random.sample(range(len(data_set)), n_save_ex)

for idx in range(len(X)):
    mix = X[idx]
    sources = Y[idx]
    mix, sources = tensors_to_device([mix, sources], device=model_device)
    est_sources = model(mix.unsqueeze(0))
    loss, reordered_sources = loss_func(est_sources, sources[None], return_est=True)
    mix_np = mix.cpu().data.numpy()
    sources_np = sources.cpu().data.numpy()
    est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
    # For each utterance, we get a dictionary with the mixture path,
    # the input and output metrics
    utt_metrics = get_metrics(
        mix_np,
        sources_np,
        est_sources_np,
        sample_rate=16000,
        metrics_list=COMPUTE_METRICS)
    utt_metrics["mix_path"] = data_set.mixture_path
    est_sources_np_normalized = normalize_estimates(est_sources_np, mix_np)
    utt_metrics.update(
        **wer_tracker(
            mix=mix_np,
            clean=sources_np,
            estimate=est_sources_np_normalized,
            sample_rate=sample_rate,
        )
    )
    series_list.append(pd.Series(utt_metrics))


    #save some examples
    if idx in save_idx:
        local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
        os.makedirs(local_save_dir, exist_ok=True)
        sf.write(local_save_dir + "mixture.wav", mix_np, sample_rate)
        # Loop over the sources and estimates
        for src_idx, src in enumerate(sources_np):
            sf.write(local_save_dir + "s{}.wav".format(src_idx), src, sample_rate)
        for src_idx, est_src in enumerate(est_sources_np_normalized):
            sf.write(
                local_save_dir + "s{}_estimate.wav".format(src_idx),
                est_src,
                sample_rate,
            )

In [7]:
utt_metrics.items()

dict_items([('input_stoi', 0.7868099251140626), ('stoi', 0.872142777344521), ('mix_path', '/root/NTH_student/data/Libri2Mix/wav16k/min/dev/mix_single/2277-149874-0002_2412-153947-0014.wav')])

In [8]:
all_metrics_df = pd.DataFrame(series_list)

 # Print and save summary metrics
final_results = {}
for metric_name in COMPUTE_METRICS:
    input_metric_name = "input_" + metric_name
    ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
    final_results[metric_name] = all_metrics_df[metric_name].mean()
    final_results[metric_name + "_imp"] = ldf.mean()

print("Overall metrics :")
pprint(final_results)

Overall metrics :
{'stoi': 0.8651579246456313, 'stoi_imp': 0.07627895924089204}
