<a href="https://colab.research.google.com/github/JamesBolt22/Supervised_Contrastive_learning_for_onset_detection/blob/main/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Load libraries
from scipy.ndimage import uniform_filter1d as meanfilt
from scipy.signal import medfilt as medfilt
from scipy.ndimage import maximum_filter1d as maxfilt
from scipy.ndimage import median_filter as medfilt2
import madmom
import librosa
%matplotlib inline
import matplotlib.pyplot as plt
import librosa.display
import tensorflow as tf
import numpy as np
import os
import math
import random
import h5py
import keras.backend as K
import scipy

In [None]:
#@title Calculate frame times from onset timings in seconds
def calculate_frame_times(hop_len, sample_rate, start_times, length):
    """Convert onset times in seconds to spectrogram frame indices.

    Each time ``t`` is mapped to ``ceil(t * sample_rate / hop_len)``.

    Args:
        hop_len: Hop length between consecutive frames, in samples.
        sample_rate: Audio sample rate in Hz.
        start_times: Onset times in seconds. May be a scalar, a 0-d array
            (what ``np.loadtxt`` returns for a one-line file), or any
            sequence/array of times.
        length: Unused; kept so existing callers keep working.

    Returns:
        1-D float64 ``np.ndarray`` holding one frame index per input time
        (empty when ``start_times`` is empty).
    """
    # atleast_1d lets a lone onset (scalar / 0-d array) through instead of
    # failing when iterated, and the vectorised form avoids the quadratic
    # cost of growing an array with np.append.
    times = np.atleast_1d(np.asarray(start_times, dtype=float)).ravel()
    return np.ceil((times * sample_rate) / hop_len)

In [None]:
#@title Calculate chunks and zero pad spectrograms
#Both functions taken from https://github.com/rohitma38/cnn-onset-detection

#function to create N-frame overlapping chunks of the full audio spectrogram  
def makechunks(x,duration):
    """Slice a spectrogram into overlapping fixed-width chunks.

    Args:
        x: 2-D array indexed as ``[bin, frame]``.
        duration: Number of frames per chunk.

    Returns:
        Array of shape ``(x.shape[1], x.shape[0], duration)``. Chunk ``k``
        holds ``x[:, k:k + duration]`` for ``k < x.shape[1] - duration``;
        the remaining trailing chunks are left as zeros.
    """
    n_bins = x.shape[0]
    n_frames = x.shape[1]
    chunks = np.zeros([n_frames, n_bins, duration])

    start = 0
    # only windows starting before n_frames - duration are copied
    while start < n_frames - duration:
        chunks[start] = x[:, start:start + duration]
        start += 1

    return chunks

#function to zero pad both ends of the spectrogram
def zeropad2d(x,n_frames):
    """Pad a 2-D array with ``n_frames`` zero columns on each side.

    Args:
        x: 2-D array; rows are kept as-is, columns are extended.
        n_frames: Number of zero columns to add before and after ``x``.

    Returns:
        Array of shape ``(x.shape[0], x.shape[1] + 2 * n_frames)``.
    """
    pad = np.zeros([x.shape[0], n_frames])
    return np.concatenate((pad, x, pad), axis=1)


In [None]:
#@title Peak picking functions
def moving_av(odf, thresh):
    """Build the peak-picking threshold curve for an onset detection function.

    The ODF is median filtered (size 3), then a running maximum (size 3,
    along axis 0, edges extended with the nearest value) is taken over the
    filtered signal. Each value of that running maximum is floored at
    ``thresh``.

    Args:
        odf: Onset detection function values (array-like).
        thresh: Minimum value the threshold curve may take.

    Returns:
        ``(threshold_function, median_filtered_odf)`` where
        ``threshold_function`` is a Python list with one entry per row of
        the filtered ODF and ``median_filtered_odf`` is the filtered array.
    """
    width = 3
    smoothed = medfilt2(odf, width)
    running_max = maxfilt(smoothed, width, mode='nearest', axis=0)

    floored = []
    for value in running_max:
        floored.append(max(value, thresh))

    return floored, smoothed

In [None]:
#@title Calculate f1 score, precision and recall

#Returns true positives, false positives, false negatives and true negatives
def f1_score_new(gt_frames, final_output, total_frames, tolerance=3):
    """Count detection outcomes by greedily matching detections to onsets.

    Detections are processed in order. Each one is matched to the nearest
    ground-truth frame that has not been matched yet; if that distance is
    strictly less than ``tolerance`` frames it counts as a true positive
    and that single ground-truth frame is consumed, otherwise it is a
    false positive.

    Args:
        gt_frames: Ground-truth onset frame indices (array-like).
        final_output: Detected onset frame indices (array-like).
        total_frames: Total number of frames, used only to derive TN.
        tolerance: Exclusive upper bound on the allowed frame distance
            (default 3, i.e. at most 2 frames of inaccuracy).

    Returns:
        Tuple ``(TP, FP, FN, TN)`` with
        ``TN = total_frames - TP - FP - FN``.
    """
    # Work on a private 1-D float copy so list inputs work and the caller's
    # array is never touched.
    remaining = np.atleast_1d(np.asarray(gt_frames, dtype=float)).ravel()

    TP = 0
    FP = 0

    for detection in np.atleast_1d(final_output).ravel():

        if remaining.size == 0:
            # nothing left to match against
            FP += 1
            continue

        distances = np.abs(remaining - detection)
        # argmin picks exactly ONE nearest onset; on a tie (detection
        # equidistant from two onsets) only the first is consumed, so the
        # other can still be matched or counted as a miss.
        nearest = int(np.argmin(distances))

        if distances[nearest] < tolerance:
            TP += 1
            remaining = np.delete(remaining, nearest)
        else:
            FP += 1

    # every ground-truth onset that was never matched is a miss
    FN = int(remaining.shape[0])
    TN = total_frames - TP - FP - FN

    return TP, FP, FN, TN

In [None]:
#@title Create input test data

#Code adapted from https://github.com/rohitma38/cnn-onset-detection
def create_data(audio_path, stats_path='/content/drive/MyDrive/means_stds.npy'):
  """Build the model input tensor for one audio file.

  The audio is loaded at 44.1 kHz and turned into three mel spectrograms
  (n_fft 1024, 2048 and 4096; hop 441 samples; 80 mel bands from 27.5 Hz
  to 16 kHz). Each is scaled with ``10*log10(1e-10 + S)``, normalised with
  the stored means/stds, zero padded by the context length on both sides
  and cut into 15-frame chunks by ``makechunks``.

  Args:
    audio_path: Path of the audio file to load.
    stats_path: ``.npy`` file where ``stats[0]`` holds the means and
      ``stats[1]`` the stds, each indexed by FFT-size position (0, 1, 2).
      Presumably one value per mel band; confirm against the file that
      produced it.

  Returns:
    Array of shape ``(n_chunks, 80, 15, 3)``; the last axis holds the three
    FFT sizes in the order 1024, 2048, 4096.
  """
  #context parameters
  contextlen=7 #+- frames
  duration=2*contextlen+1
  x,fs=librosa.load(audio_path, sr=44100)

  stats=np.load(stats_path)
  means=stats[0]
  stds=stats[1]

  channels=[]
  for idx,n_fft in enumerate((1024,2048,4096)):
    #get mel spectrogram; the audio must be passed as y=... because
    #librosa >= 0.10 rejects positional audio arguments
    melgram=librosa.feature.melspectrogram(y=x,sr=fs,n_fft=n_fft, hop_length=441,n_mels=80, fmin=27.5, fmax=16000)

    #log scaling
    melgram=10*np.log10(1e-10+melgram)

    #normalize (column vectors broadcast across the frame axis)
    melgram=(melgram-np.atleast_2d(means[idx]).T)/np.atleast_2d(stds[idx]).T

    #zero pad ends, then make chunks
    melgram=zeropad2d(melgram,contextlen)
    channels.append(makechunks(melgram,duration))

  #stack the three resolutions as the trailing channel axis
  full_input = np.stack(channels, axis=3)

  return full_input

In [None]:
#@title Merge peaks

#If two or more peaks are in consecutive frames then they are merged
def merge_consecutive_peaks(peaks):
  """Collapse runs of adjacent peak frames into a single peak.

  A peak is dropped when the next peak in the sequence is fewer than 2
  frames away, so each run of consecutive frames is represented by its
  last frame. The input is not modified and no sentinel value is used, so
  legitimate values can never collide with a marker.

  Args:
    peaks: 1-D array-like of peak frame indices in detection order.

  Returns:
    New 1-D ``np.ndarray`` with the merged peaks (empty for empty input).
  """
  peaks = np.asarray(peaks)

  keep = np.ones(peaks.shape[0], dtype=bool)
  # the last peak is always kept; every other peak is kept only if its
  # successor is at least 2 frames away
  keep[:-1] = ~(np.abs(np.diff(peaks)) < 2)

  return peaks[keep]

In [None]:
#@title Test model

def test_model_on_file(model_contrastive,song_name, thresh):
  """Evaluate the model on one song at one peak-picking threshold.

  Args:
    model_contrastive: Object exposing ``predict(inputs)``; its output is
      used as the onset detection function.
    song_name: Base name (no extension) of the ``.flac`` audio file and
      the matching ``.onsets`` annotation file.
    thresh: Minimum threshold passed to ``moving_av``.

  Returns:
    ``(TP, FP, FN, TN)`` as returned by ``f1_score_new``.
  """
  audio_path = "/content/drive/MyDrive/onsets_audio/onsets/audio/" + song_name + ".flac"
  onset_path = "/content/drive/MyDrive/onsets_annotations/onsets/annotations/onsets/" + song_name + ".onsets"

  hop_len = 441
  sample_rate = 44100

  #Creates the input data and loads true onsets
  inputs = create_data(audio_path)
  #atleast_1d: loadtxt gives a 0-d array for a file with a single onset,
  #which cannot be iterated
  onsets = np.atleast_1d(np.loadtxt(onset_path))

  #predict outputs
  pred_contrastive = model_contrastive.predict(inputs)

  #peak picking: a frame is a peak when the filtered ODF reaches the threshold curve
  threshold_function, median_Filter_Odf = moving_av(pred_contrastive, thresh)
  picked = [i for i in range(median_Filter_Odf.shape[0])
            if median_Filter_Odf[i] >= threshold_function[i]]
  #build the array once instead of growing it with np.append per frame
  contrastive_frames = merge_consecutive_peaks(np.array(picked, dtype=float))

  #calculates true onsets (np.max would raise on an empty annotation file)
  last_onset = np.max(onsets) if onsets.size else 0
  onset_frames = calculate_frame_times(hop_len, sample_rate, onsets, last_onset)

  #total number of frames fed to the model, used to derive true negatives
  max_frame = inputs.shape[0]

  #calculates model scores
  contrastive_classifier_f1 = f1_score_new(onset_frames, contrastive_frames,  max_frame)

  return contrastive_classifier_f1

#sets model to test
model_contrastive = tf.keras.models.load_model("/content/drive/MyDrive/Masters/Models/Hyperparam/temp/temp_0.7",compile=False)

#sets song names for test set
#atleast_1d: genfromtxt returns a 0-d array when the fold file lists a single song
song_names = np.atleast_1d(np.genfromtxt('/content/drive/MyDrive/onsets_splits/onsets/splits/8-fold_cv_random_7.fold',dtype='str'))

f1_total = 0
tot = 0

TP_tot = 0
TN_tot = 0
FP_tot = 0
FN_tot = 0

def _ratio(numerator, denominator):
  """Return numerator/denominator, or 0.0 when the denominator is zero."""
  return numerator / denominator if denominator else 0.0

#Runs through threshold values (each value listed once so no threshold is evaluated twice)
thresholds = [0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]

#NOTE: test_model_on_file recomputes the spectrograms and model predictions
#for every (threshold, song) pair, so the sweep cost scales with len(thresholds).
for thresh in thresholds:

  #counts are pooled over all songs for the current threshold
  TP_tot = 0
  TN_tot = 0
  FP_tot = 0
  FN_tot = 0

  for song_name in song_names:

    TP, FP, FN, TN = test_model_on_file(model_contrastive,song_name, thresh)
    TP_tot += TP
    TN_tot += TN
    FN_tot += FN
    FP_tot += FP
    tot += 1

  #Calculates final metrics; a metric whose denominator is zero (e.g. precision
  #when a high threshold yields no detections) is reported as 0.0 instead of
  #aborting the sweep with ZeroDivisionError
  f1_score = _ratio(TP_tot, TP_tot + 0.5*(FP_tot + FN_tot))
  precision = _ratio(TP_tot, TP_tot + FP_tot)
  recall = _ratio(TP_tot, TP_tot + FN_tot)
  accuracy = _ratio(TP_tot + TN_tot, TP_tot + TN_tot + FN_tot + FP_tot)

  #identify which threshold the following four numbers belong to
  print("threshold", thresh)
  print("f1",f1_score)
  print(precision)
  print(recall)
  print(accuracy)