In [15]:
import os
import sys
import torch
import torch.nn.functional as F
import torchaudio
import speechbrain as sb
import speechbrain.nnet.schedulers as schedulers
import numpy as np
from tqdm import tqdm
import logging
from hyperpyyaml import load_hyperpyyaml
import csv
device = torch.device('cuda')
logger = logging.getLogger(__name__)

# from mir_eval.separation import bss_eval_sources
from speechbrain.dataio.dataio import read_audio
from fast_bss_eval import bss_eval_sources
import librosa

In [21]:
import time
import concurrent.futures
def parse_audio(path):
   audio,sr = torchaudio.load(path)
   audio = torchaudio.functional.resample(audio, sr, 8000)
   # audio, sr = librosa.load(path,sr=8000)
   # audio = torch.from_numpy(audio).unsqueeze(0)

   # audio = audio
   # print(audio.shape)
   
   return audio
def append_audio(item_name, path, out_est, out_target, out_est_up, out_mix):
   print(path)
   source1_path = f"{path}/{item_name}_source1hat_up.wav"
   source2_path = f"{path}/{item_name}_source2hat_up.wav"
   if not os.path.isfile(source1_path) or not os.path.isfile(source2_path):
      return

   # Orig
   source1_path = f"{path}/{item_name}_source1.wav"
   source2_path = f"{path}/{item_name}_source2.wav"
   out_target[item_name] = [parse_audio(source1_path), parse_audio(source2_path)]

   source1_path = f"{path}/{item_name}_source1hat.wav"
   source2_path = f"{path}/{item_name}_source2hat.wav"
   out_est[item_name] = [parse_audio(source1_path), parse_audio(source2_path)]
   
   source1_path = f"{path}/{item_name}_source1hat_up.wav"
   source2_path = f"{path}/{item_name}_source2hat_up.wav"
   out_est_up[item_name] = [parse_audio(source1_path), parse_audio(source2_path)]

   mix_path = f"{path}/{item_name}_mix.wav"
   out_mix[item_name]=(parse_audio(mix_path))

def read_all_audio(path, upsampled=False):
   
   audio_ids = []
   out_target= {}
   out_est = {}
   out_est_upsampled = {}
   out_mix = {}
   all_files = os.listdir(path)
   item_names = [file_name.split('_')[0] for file_name in all_files] 
   item_names = list(set(item_names))
   
   # with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
   for item_name in tqdm(item_names, total=len(item_names)):
      append_audio(item_name, path, out_est, out_target, out_est_upsampled, out_mix)
      
   audio_ids =list(sorted(out_target.keys())) 
   out_target = list(dict(sorted(out_target.items())).values())
   out_est= list(dict(sorted(out_est.items())).values())
   out_est_upsampled= list(dict(sorted(out_est_upsampled.items())).values())
   out_mix= list(dict(sorted(out_mix.items())).values())
   return audio_ids, out_target, out_est,out_est_upsampled, out_mix


In [17]:
test = {"dog":"arf","ant":"ayy"}
sorted(test.items())

[('ant', 'ayy'), ('dog', 'arf')]

In [18]:
import numpy as np
class Separation(sb.Brain):
   def compute_objectives(self, predictions, targets):
        """Computes the si-snr loss"""
        return self.hparams.loss(targets, predictions)
     
   def get_metrics(self,audio_ids, targets, preds, mixtures,output_path):

        # Create folders where to store audio
      # save_file = os.path.join(output_path, "test_results.csv")
      save_file = output_path

        # Variable init

      all_sdrs = []
      all_sdrs_i = []
      all_sisnrs = []
      all_sisnrs_i = []
      csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]

        
      
      with open(save_file, "w") as results_csv:
         writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
         writer.writeheader()
            
         for audio_id, target, pred,mixture_ in tqdm(zip(audio_ids,targets, preds,mixtures), total=len(targets)):
            target = torch.cat(
            [target[i].unsqueeze(-1) for i in range(self.hparams.num_spks)],
            dim=-1,
        ).to(self.device)
            
            pred = torch.cat(
            [pred[i].unsqueeze(-1) for i in range(self.hparams.num_spks)],
            dim=-1,
        ).to(self.device)
            
  
            sisnr = self.compute_objectives(pred, target)
            # COmpute SI-SNR Improvement
            mixture_signal = torch.stack(
               [mixture_] * self.hparams.num_spks, dim=-1
            )
            
            arrs = [pred[0], target[0], mixture_signal[0]]
            for item in arrs:
               temp_item = (item!=np.inf).all() and (item!=np.NINF).all()
               if not temp_item:
                  continue
            # if mixture_signal[0] == np.inf or mixture_signal[1] ==np.inf:
            
            mixture_signal = mixture_signal.to(target.device)
            sisnr_baseline = self.compute_objectives(
               mixture_signal, target
            )
            sisnr_i = sisnr.cpu().numpy() - sisnr_baseline.cpu().numpy()
  # Compute SDR
            sdr, _, _, _ = bss_eval_sources(
               target[0].t(),
               pred[0].t(),
            )

            sdr_baseline, _, _, _ = bss_eval_sources(
               target[0].t(),
               mixture_signal[0].t(),
            )
            
            sdr = sdr.cpu().numpy()
            sdr_baseline = sdr_baseline.cpu().numpy()
           
            sdr_i = sdr.mean() - sdr_baseline.mean()

            # Saving on a csv file
            row = {
               "snt_id": audio_id,
               "sdr": sdr.mean(),
               "sdr_i": sdr_i,
               "si-snr": -sisnr.item(),
               "si-snr_i": -sisnr_i.item(),
            }
            writer.writerow(row)

            # Metric Accumulation
            all_sdrs.append(sdr.mean())
            all_sdrs_i.append(sdr_i.mean())
            all_sisnrs.append(-sisnr.item())
            all_sisnrs_i.append(-sisnr_i.item())
      # logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
      # logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
      # logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
      # logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))   


In [25]:
# Load hyperparameters file with command-line overrides
hparams_file, run_opts, overrides = sb.parse_arguments(["hyperparams.yaml"])
hparams_file = 'hyperparams.yaml'
#    run_opts = {"device": "cuda:0"}
with open(hparams_file) as fin:
   hparams = load_hyperpyyaml(fin, overrides)
   
# Load pretrained model if pretrained_separator is present in the yaml
if "pretrained_separator" in hparams:
   # run_on_main(hparams["pretrained_separator"].collect_files)
   hparams["pretrained_separator"].load_collected(
      device=run_opts["device"]
   )
# Brain class initialization
separator = Separation(
   modules=hparams["modules"],
   run_opts={"device": "cuda"},
   hparams=hparams,
)

configs = ["standard_model"]
# configs = ["w_noise_speedperturb","w_noise_wavedrop"]
# configs = ["no_noise_speedperturb"]
UPSAMPLE= [False]
# UPSAMPLE= [True]
curr_aud =[]
for config in configs:
   print(f"Reading {config}")
   # out = read_all_audio(f'results/{config}/audio_results')
   # audio_ids,out_target, out_est, out_est_upsampled, out_mix = out
   # audio_ids,out_target, out_est, out_est_upsampled, out_mix = read_all_audio(f'results/{config}/audio_results')
   audio_ids,out_target, out_est, out_est_upsampled, out_mix = read_all_audio(f'../sequential/separated')
   print(audio_ids)
   
   for is_upsample in UPSAMPLE:
      if is_upsample:
         out_target = out_est
         out_est = out_est_upsampled   
      # upsample = "upsampled" if upsample else "not_upsampled"
      output_path = f"results/{config}/{config}_{is_upsample}_test.csv"
      separator.get_metrics(audio_ids, out_target, out_est, out_mix, output_path=output_path)
      print(f"Done with {config} | upsamle: {is_upsample}")
# output_path = f"results/epoch_30/true_test_result.csv"
# separator.get_metrics(audio_ids, out_target, out_est, out_mix, output_path=output_path)
# separator.get_metrics(audio_ids, out_est, out_est_upsampled, out_mix, output_path=output_path)

Reading standard_model


100%|██████████| 2/2 [00:00<?, ?it/s]


../sequential/separated
../sequential/separated
[]


0it [00:00, ?it/s]

Done with standard_model | upsamle: False





### Retrieves the worst, median and best separation audio mixtures

In [24]:

import numpy as np
import pandas as pd
import shutil
# Get worst
def get_worst(column,df):
   return df.loc[df[column].idxmin()]
# Get best 
def get_best(column,df):
   return df.loc[df[column].idxmax()]
# Get Median
def get_median(column,df):
   df.loc[df[column]== df[column].median()]
   df.sort_values(by=column, inplace=True)
   return df[df[column] < df[column].median()].iloc[-1]


def print_audio_info(stat, df):
   print(f"status: {stat} | sdr: {df['sdr']} | sdr_i: {df['sdr_i']} | si-snr: {df['si-snr']} | si-snr_i: {df['si-snr_i']}")

def write_demo_audio(src_path,dest_path):
   shutil.copy(src_path, dest_path)

def write_image(dest_path):
   pass
   
status = ['Worst', 'Median', 'Best']

configs = ["standard_model", "no_noise_speedperturb","w_noise_speedperturb", "w_noise_wavedrop"]
# configs = ["no_noise_speedperturb"]
for config in configs:
   for upsample in [False, True]:
      df = pd.read_csv(f'results/{config}/{config}_{upsample}_test.csv')
      df.replace([np.inf, -np.inf], np.nan, inplace=True)
   # drop rows with NaNs
      df.dropna(inplace=True)
      # df.replace()
      print(f"------------CONFIG: {config} | upsampled: {upsample}")
      worst = get_worst('si-snr',df)
      median = get_median('si-snr',df)
      best = get_best('si-snr',df)
      print_audio_info(status[0], worst)
      print_audio_info(status[1], median)
      print_audio_info(status[2], best)

      mean_values = df.loc[:, df.columns != 'snt_id'].mean()
      print(f'mean_values {mean_values}')

------------CONFIG: standard_model | upsampled: False


TypeError: reduction operation 'argmin' not allowed for this dtype

In [None]:
df = df.sort_values(by='snt_id')
df.head()

Unnamed: 0,snt_id,sdr,sdr_i,si-snr,si-snr_i
0,item0,8.709541,12.099934,7.763442,11.358368
1,item1,-1.304001,-0.404314,-2.374741,-1.409151
2,item10,10.37476,11.388518,9.455255,10.703505
3,item100,8.781614,10.849915,7.796484,10.323214
4,item1000,8.809902,9.547372,8.316668,9.167922


In [None]:
# import numpy as np
# import librosa

# # Load the original audio signal
# audio_file_orig = 'original_audio.wav'
# y_orig, sr_orig = librosa.load(audio_file_orig)

# # Load the upsampled audio signal
# audio_file_upsamp = 'upsampled_audio.wav'
# y_upsamp, sr_upsamp = librosa.load(audio_file_upsamp)

# # Compute the SNR of the upsampled audio signal compared to the original signal
# diff = np.sum((y_orig - y_upsamp)**2)
# snr = 10 * np.log10(np.sum(y_orig**2) / diff)

# print('SNR: {:.2f} dB'.format(snr))


In [None]:
import librosa

# Load the original audio signal
audio_file_orig = 'original_audio.wav'
y_orig, sr_orig = librosa.load(audio_file_orig)

# Load the upsampled audio signal
audio_file_upsamp = 'upsampled_audio.wav'
y_upsamp, sr_upsamp = librosa.load(audio_file_upsamp)

# Compute the spectral centroid and rolloff of the original and upsampled audio signals
spec_centroid_orig = librosa.feature.spectral_centroid(y_orig, sr=sr_orig)[0]
spec_rolloff_orig = librosa.feature.spectral_rolloff(y_orig, sr=sr_orig)[0]

spec_centroid_upsamp = librosa.feature.spectral_centroid(y_upsamp, sr=sr_upsamp)[0]
spec_rolloff_upsamp = librosa.feature.spectral_rolloff(y_upsamp, sr=sr_upsamp)[0]

# Compute the mean and standard deviation of the spectral features
centroid_diff = spec_centroid_orig.mean() - spec_centroid_upsamp.mean()
rolloff_diff = spec_rolloff_orig.mean() - spec_rolloff_upsamp.mean()

centroid_std_diff = spec_centroid_orig.std() - spec_centroid_upsamp.std()
rolloff_std_diff = spec_rolloff_orig.std() - spec_rolloff_upsamp.std()

# Print the results
print('Spectral centroid difference: {:.2f}'.format(centroid_diff))
print('Spectral rolloff difference: {:.2f}'.format(rolloff_diff))

print('Spectral centroid std difference: {:.2f}'.format(centroid_std_diff))
print('Spectral rolloff std difference: {:.2f}'.format(rolloff_std_diff))


Spectral centroid difference: -215.47
Spectral rolloff difference: -267.58
Spectral centroid std difference: -97.98
Spectral rolloff std difference: -108.21


In [None]:
sr_upsamp, sr_orig

(22050, 22050)