In [None]:
#TEST_COPY.IPYNB

import os
import sys
import torch
import torch.nn.functional as F
import torchaudio
import speechbrain as sb
import speechbrain.nnet.schedulers as schedulers
import numpy as np
from tqdm import tqdm
import logging
from hyperpyyaml import load_hyperpyyaml
import csv

logger = logging.getLogger(__name__)

from mir_eval.separation import bss_eval_sources
from speechbrain.dataio.dataio import read_audio

class Separation(sb.Brain):
    """SpeechBrain ``Brain`` used to score separation outputs.

    In this file it is only driven through ``get_metrics``. That method receives
    reference sources, estimated sources and mixtures from the caller, which are
    loaded from disk by ``read_all_audio``. It reports SI-SNR, SI-SNRi, SDR and
    SDRi for them.
    """

    def compute_objectives(self, predictions, targets):
        """Return ``self.hparams.loss(targets, predictions)``.

        This is described as the SI-SNR loss. ``get_metrics`` records its
        negation as the SI-SNR value.
        """
        return self.hparams.loss(targets, predictions)

    def _stack_sources(self, sources):
        """Stack the first ``num_spks`` per-speaker tensors on a new last axis.

        All selected elements of ``sources`` must share one shape. From
        ``read_all_audio`` that is presumably ``(1, time)`` for mono audio, to
        be confirmed for other callers. The result is moved to ``self.device``.
        """
        per_speaker = [sources[i] for i in range(self.hparams.num_spks)]
        return torch.stack(per_speaker, dim=-1).to(self.device)

    def get_metrics(self, targets, preds, mixtures):
        """Compute per-item SI-SNR(i) and SDR(i), save them to CSV, log the means.

        Results are written to ``<output_folder>/test_results.csv``, with one row
        per item. The means are then logged.

        Arguments
        ---------
        targets : iterable
            One entry per item. Each entry is indexable by speaker, and each
            element is an equally shaped tensor (presumably ``(1, time)``,
            as built by ``read_all_audio``).
        preds : iterable
            Estimated sources, in the same layout as ``targets``.
        mixtures : iterable
            One mixture tensor per item. A 1-D ``(time,)`` mixture gets a
            leading axis of size 1 so it lines up with the stacked sources.

        Raises
        ------
        ValueError
            If the three inputs do not contain the same number of items.
        """
        # Materialize once so lengths can be checked; zip() would silently
        # drop the tail of the longer inputs.
        targets, preds, mixtures = list(targets), list(preds), list(mixtures)
        if not len(targets) == len(preds) == len(mixtures):
            raise ValueError(
                "targets, preds and mixtures must have the same number of "
                f"items, got {len(targets)}, {len(preds)} and {len(mixtures)}"
            )

        num_spks = self.hparams.num_spks
        output_folder = self.hparams.output_folder
        # Make sure the results folder exists before the CSV is written into it.
        os.makedirs(output_folder or ".", exist_ok=True)
        save_file = os.path.join(output_folder, "test_results.csv")

        all_sdrs = []
        all_sdrs_i = []
        all_sisnrs = []
        all_sisnrs_i = []
        # "snt_id" is left blank (DictWriter's default restval) because items
        # arrive here without identifiers.
        csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]

        # The csv module requires newline="" on files it writes.
        with open(save_file, "w", newline="") as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
            writer.writeheader()

            for target_srcs, pred_srcs, mixture in zip(targets, preds, mixtures):
                target = self._stack_sources(target_srcs)
                pred = self._stack_sources(pred_srcs)

                # Loss value of the estimates; its negation is the SI-SNR.
                sisnr = self.compute_objectives(pred, target)

                # Baseline: use the unprocessed mixture as the estimate for
                # every speaker. A 1-D mixture needs the same leading axis the
                # stacked sources carry (indexed with [0] below).
                if mixture.dim() == 1:
                    mixture = mixture.unsqueeze(0)
                mixture_signal = torch.stack([mixture] * num_spks, dim=-1).to(
                    target.device
                )
                sisnr_baseline = self.compute_objectives(mixture_signal, target)
                sisnr_i = sisnr - sisnr_baseline

                # SDR via mir_eval on the first leading-axis element,
                # transposed so each source is one row.
                target_np = target[0].t().detach().cpu().numpy()
                sdr, _, _, _ = bss_eval_sources(
                    target_np,
                    pred[0].t().detach().cpu().numpy(),
                )
                sdr_baseline, _, _, _ = bss_eval_sources(
                    target_np,
                    mixture_signal[0].t().detach().cpu().numpy(),
                )
                sdr_mean = sdr.mean()
                sdr_i = sdr_mean - sdr_baseline.mean()

                sisnr_value = -sisnr.item()
                sisnr_i_value = -sisnr_i.item()
                writer.writerow(
                    {
                        "sdr": sdr_mean,
                        "sdr_i": sdr_i,
                        "si-snr": sisnr_value,
                        "si-snr_i": sisnr_i_value,
                    }
                )

                all_sdrs.append(sdr_mean)
                all_sdrs_i.append(sdr_i)
                all_sisnrs.append(sisnr_value)
                all_sisnrs_i.append(sisnr_i_value)

        if not all_sisnrs:
            # Averaging empty lists would only log NaN with a RuntimeWarning.
            logger.warning("No items to evaluate; %s holds only the header", save_file)
            return

        logger.info("Mean SISNR is {}".format(np.mean(all_sisnrs)))
        logger.info("Mean SISNRi is {}".format(np.mean(all_sisnrs_i)))
        logger.info("Mean SDR is {}".format(np.mean(all_sdrs)))
        logger.info("Mean SDRi is {}".format(np.mean(all_sdrs_i)))

def read_all_audio(path):
    """Load targets, estimates and mixtures for every test item in ``path``.

    Each file name is split on ``"_"`` and its first field is the item id.
    Files are classified by name:

    - a name containing ``"hat"`` is an estimated source;
    - otherwise, a name containing ``"mix"`` is the item's mixture;
    - every other file is a reference (target) source.

    Subdirectories are skipped.

    Arguments
    ---------
    path : str
        Directory holding the audio files.

    Returns
    -------
    out_target : list of list
        Per item, its reference sources, each as ``read_audio(...).unsqueeze(0)``.
    out_est : list of list
        Per item, its estimated sources, in the same layout as ``out_target``.
    out_mix : list
        Per item, its mixture exactly as returned by ``read_audio``.

    The three lists are index-aligned: entry ``k`` of each describes the same
    item, with items in lexicographically sorted id order. Files are visited
    in sorted name order, so the source order within an item is deterministic.

    Raises
    ------
    ValueError
        If an item has more than one mixture, or lacks targets, estimates or
        a mixture.
    """
    targets = {}
    estimates = {}
    mixtures = {}
    for filename in sorted(os.listdir(path)):
        filepath = os.path.join(path, filename)
        if not os.path.isfile(filepath):
            continue
        item_num = filename.split("_")[0]
        # Read each file once; sources get a leading axis, mixtures do not.
        audio = read_audio(filepath)
        if "hat" in filename:
            estimates.setdefault(item_num, []).append(audio.unsqueeze(0))
        elif "mix" in filename:
            if item_num in mixtures:
                raise ValueError(f"More than one mixture for item {item_num!r}")
            mixtures[item_num] = audio
        else:
            targets.setdefault(item_num, []).append(audio.unsqueeze(0))

    # One shared, sorted id order keeps the three returned lists aligned.
    items = sorted(targets.keys() | estimates.keys() | mixtures.keys())
    incomplete = [
        item
        for item in items
        if item not in targets or item not in estimates or item not in mixtures
    ]
    if incomplete:
        raise ValueError(
            "Items missing targets, estimates or a mixture: "
            + ", ".join(incomplete)
        )

    out_target = [targets[item] for item in items]
    out_est = [estimates[item] for item in items]
    out_mix = [mixtures[item] for item in items]
    return out_target, out_est, out_mix

# Test data are loaded once here; run() scores these module-level lists.
out_target, out_est, out_mix = read_all_audio(path="data/test")

def run():
    """Load ``hyperparams.yaml``, build a ``Separation`` brain and score the test data.

    The module-level ``out_target``, ``out_est`` and ``out_mix`` lists are
    passed to ``Separation.get_metrics``.
    """
    hparams_path = "hyperparams.yaml"
    # An explicit argument list is handed to the parser; only the run
    # options and overrides are used from its result.
    _, run_opts, overrides = sb.parse_arguments([hparams_path])
    with open(hparams_path) as yaml_file:
        hparams = load_hyperpyyaml(yaml_file, overrides)

    # Load the pretrained model when the yaml defines a pretrainer.
    if "pretrained_separator" in hparams:
        pretrainer = hparams["pretrained_separator"]
        # NOTE(review): a `run_on_main(...collect_files)` step was commented
        # out here; confirm the files are already collected before loading.
        pretrainer.load_collected(device=run_opts["device"])

    separator = Separation(
        modules=hparams["modules"],
        hparams=hparams,
        run_opts=run_opts,
    )
    separator.get_metrics(out_target, out_est, out_mix)
# Evaluate only when executed as a script or notebook, not when imported.
if __name__ == "__main__":
    run()
