In [25]:
import soundfile as sf
import torchaudio
from speechbrain.pretrained import EncoderClassifier
classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb", run_opts={"device":"cuda"})

from speechbrain.pretrained import SpeakerRecognition
verification = SpeakerRecognition.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb", run_opts={"device":"cuda"}, 
                                               savedir="pretrained_models/spkrec-ecapa-voxceleb")



In [2]:
import librosa
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import random

from numpy.fft import fft, ifft
import matplotlib.pyplot as plt

In [3]:
class CosineSimilarity(nn.Module):
    def forward(self, tensor_1, tensor_2):
        normalized_tensor_1 = tensor_1 / tensor_1.norm(dim=-1, keepdim=True)
        normalized_tensor_2 = tensor_2 / tensor_2.norm(dim=-1, keepdim=True)
        return (normalized_tensor_1 * normalized_tensor_2).sum(dim=-1)

In [15]:
def shift_pitch(signal, fs, f_ratio):
    """
    Calls psola pitch shifting algorithm
    :param signal: original signal in the time-domain
    :param fs: sample rate
    :param f_ratio: ratio by which the frequency will be shifted
    :return: pitch-shifted signal
    """
    peaks = find_peaks(signal, fs)
    new_signal = psola(signal, peaks, f_ratio)
    return new_signal


def find_peaks(signal, fs, max_hz=950, min_hz=75, analysis_win_ms=40, max_change=1.005, min_change=0.995):
    """
    Find sample indices of peaks in time-domain signal
    :param max_hz: maximum measured fundamental frequency
    :param min_hz: minimum measured fundamental frequency
    :param analysis_win_ms: window size used for autocorrelation analysis
    :param max_change: restrict periodicity to not increase by more than this ratio from the mean
    :param min_change: restrict periodicity to not decrease by more than this ratio from the mean
    :return: peak indices
    """
    N = len(signal)
    min_period = fs // max_hz
    max_period = fs // min_hz

    # compute pitch periodicity
    sequence = int(analysis_win_ms / 1000 * fs)  # analysis sequence length in samples
    periods = compute_periods_per_sequence(signal, sequence, min_period, max_period, N)

    # simple hack to avoid octave error: assume that the pitch should not vary much, restrict range
    mean_period = np.mean(periods)
    max_period = int(mean_period * 1.1)
    min_period = int(mean_period * 0.9)
    periods = compute_periods_per_sequence(signal, sequence, min_period, max_period, N)

    # find the peaks
    peaks = [np.argmax(signal[:int(periods[0]*1.1)])]
    while True:
        prev = peaks[-1]
        idx = prev // sequence  # current autocorrelation analysis window
        if prev + int(periods[idx] * max_change) >= N:
            break
        # find maximum near expected location
        peaks.append(prev + int(periods[idx] * min_change) +
                np.argmax(signal[prev + int(periods[idx] * min_change): prev + int(periods[idx] * max_change)]))
    return np.array(peaks)


def compute_periods_per_sequence(signal, sequence, min_period, max_period, N):
    """
    Computes periodicity of a time-domain signal using autocorrelation
    :param sequence: analysis window length in samples. Computes one periodicity value per window
    :param min_period: smallest allowed periodicity
    :param max_period: largest allowed periodicity
    :return: list of measured periods in windows across the signal
    """
    
    offset = 0  # current sample offset
    periods = []  # period length of each analysis sequence

    while offset < N:
        fourier = fft(signal[offset: offset + sequence])
        fourier[0] = 0  # remove DC component
        autoc = ifft(fourier * np.conj(fourier)).real
        autoc_peak = min_period + np.argmax(autoc[min_period: max_period])
        periods.append(autoc_peak)
        offset += sequence
    return periods


def psola(signal, peaks, f_ratio):
    """
    Time-Domain Pitch Synchronous Overlap and Add
    :param signal: original time-domain signal
    :param peaks: time-domain signal peak indices
    :param f_ratio: pitch shift ratio
    :return: pitch-shifted signal
    """
    N = len(signal)
    # Interpolate
    new_signal = np.zeros(N)
    new_peaks_ref = np.linspace(0, len(peaks) - 1, int(len(peaks) * f_ratio))
    new_peaks = np.zeros(len(new_peaks_ref)).astype(int)

    for i in range(len(new_peaks)):
        weight = new_peaks_ref[i] % 1
        left = np.floor(new_peaks_ref[i]).astype(int)
        right = np.ceil(new_peaks_ref[i]).astype(int)
        new_peaks[i] = int(peaks[left] * (1 - weight) + peaks[right] * weight)

    # PSOLA
    for j in range(len(new_peaks)):
        # find the corresponding old peak index
        i = np.argmin(np.abs(peaks - new_peaks[j]))
        # get the distances to adjacent peaks
        P1 = [new_peaks[j] if j == 0 else new_peaks[j] - new_peaks[j-1],
              N - 1 - new_peaks[j] if j == len(new_peaks) - 1 else new_peaks[j+1] - new_peaks[j]]
        # edge case truncation
        if peaks[i] - P1[0] < 0:
            P1[0] = peaks[i]
        if peaks[i] + P1[1] > N - 1:
            P1[1] = N - 1 - peaks[i]
        # linear OLA window
        window = list(np.linspace(0, 1, P1[0] + 1)[1:]) + list(np.linspace(1, 0, P1[1] + 1)[1:])
        # center window from original signal at the new peak
        new_signal[new_peaks[j] - P1[0]: new_peaks[j] + P1[1]] += window * signal[peaks[i] - P1[0]: peaks[i] + P1[1]]
    return new_signal

In [5]:
def revers_f(org_voice, dis_voice, sr, a):
    '''
    step 1: revers disguished voice
    step 2: calculate similarity
    '''
    rev_vc = librosa.effects.pitch_shift(dis_voice, sr=sr, n_steps=-a, bins_per_octave=12)
    org_voice = torch.from_numpy(org_voice).float()
    rev_vc = torch.from_numpy(rev_vc).float()
    org_emb = torch.tensor(classifier.encode_batch(org_voice))
    rev_emb = torch.tensor(classifier.encode_batch(rev_vc))
    sim = F.cosine_similarity(org_emb, rev_emb).mean()
    return sim, a

In [7]:
def TIFS(org_voice, dis_voice, sr):
    org_sim = 0 
    org_alpha = 0
    for a in range(-8, 8, 1):
        sim, alpha = revers_f(org_voice, dis_voice, sr, a)
        if sim >= org_sim:
            org_sim = sim
            org_alpha = alpha
            
    return org_alpha

In [108]:
def revers_f(org_voice, dis_voice, sr, a):
    '''
    step 1: revers disguished voice
    step 2: calculate similarity
    '''
    rev_vc = librosa.effects.pitch_shift(dis_voice, sr=sr, n_steps=-a, bins_per_octave=12)
    rev_vc = torch.from_numpy(rev_vc).float()
    rev_emb = torch.tensor(classifier.encode_batch(rev_vc))
    return rev_emb, a

In [6]:
def revers_f_psola(org_voice, dis_voice, sr, a):
    '''
    step 1: revers disguished voice
    step 2: calculate similarity
    '''
    # N = len(dis_voice)
    f_ratio = 2 ** (-a / 12)
    rev_vc = shift_pitch(dis_voice, sr, f_ratio)
    org_voice = torch.from_numpy(org_voice).float()
    rev_vc = torch.from_numpy(rev_vc).float()
    org_emb = torch.tensor(classifier.encode_batch(org_voice))
    rev_emb = torch.tensor(classifier.encode_batch(rev_vc))
    sim = F.cosine_similarity(org_emb, rev_emb).mean()
    return sim, a

In [109]:
def TIFS(org_voice, dis_voice, sr):
    org_sim = 0 
    org_alpha = 0
    org_voice = torch.from_numpy(org_voice).float()
    org_emb = torch.tensor(classifier.encode_batch(org_voice))
    
    for a in range(-8, 8, 1):
        rev_emb, alpha = revers_f(org_voice, dis_voice, sr, a)
        sim = F.cosine_similarity(org_emb, rev_emb).mean()
        if sim >= org_sim:
            org_sim = sim
            org_alpha = alpha
            
    return org_alpha

In [8]:
def TIFS_psola(org_voice, dis_voice, sr):
    org_sim = 0 
    org_alpha = 0
    for a in range(-8, 8, 1):
        sim, alpha = revers_f_psola(org_voice, dis_voice, sr, a)
        if sim >= org_sim:
            org_sim = sim
            org_alpha = alpha
            
    return org_alpha

In [9]:
def data_prepar_joint_spk(data_file):
    file_list, file_org_list,label_list = [], [], []
    with open(data_file, 'r', encoding='utf-8') as infile:
        # file_list, label_list = [], []
        for line in infile:
            data_line = line.strip("\n").split()  # 去除首尾换行符，并按空格划分    
            file_list.append(data_line[0])
            file_org_list.append(data_line[1])
            label_list.append(float(data_line[2]))
   
    return file_list, file_org_list, label_list

In [102]:
test_files, org_files, test_labels = data_prepar_joint_spk('pitch_scaling_validation_seen_matlab.txt')

In [103]:
len(test_files)

10720

In [104]:
random.seed(1234)
random.shuffle(test_files)

random.seed(1234)
random.shuffle(org_files)

random.seed(1234)
random.shuffle(test_labels)

test_files = test_files[0:2000]
org_files = org_files[0:2000]
test_labels = test_labels[0:2000]

In [110]:
def est(test_files, org_files, test_labels):   
    estimate_alpha, GT = [], []
    i = 0
    since = time.time()
    for dis_v, org_v, label in zip(test_files, org_files, test_labels):          
        dis_voice, dis_sr = librosa.load(dis_v) # , sr=22050
        org_voice, _ = librosa.load(org_v) # , sr=22050
        try:
            predict_a = TIFS(org_voice, dis_voice, dis_sr)
            # predict_a = TIFS_psola(org_voice, dis_voice, dis_sr)
            estimate_alpha.append(predict_a) 
            GT.append(label) 
        except:
            print(dis_v)
        if i % 128 == 0:
            time_elapsed = time.time() - since 
            print (int(i/128),'/',int(len(test_files)/128)) 
            print('Testing complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        i += 1
     
    label = torch.tensor(GT)
    alpha = torch.tensor(estimate_alpha)
    err = (alpha - label).abs().mean()
    
    return err

In [111]:
err = est(test_files, org_files, test_labels)

  org_emb = torch.tensor(classifier.encode_batch(org_voice))
  rev_emb = torch.tensor(classifier.encode_batch(rev_vc))


0 / 15
Testing complete in 0m 1s
1 / 15
Testing complete in 3m 1s
2 / 15
Testing complete in 5m 57s
3 / 15
Testing complete in 9m 5s
4 / 15
Testing complete in 12m 1s
5 / 15
Testing complete in 15m 4s
6 / 15
Testing complete in 18m 9s
7 / 15
Testing complete in 21m 1s
8 / 15
Testing complete in 23m 53s
9 / 15
Testing complete in 26m 49s
10 / 15
Testing complete in 29m 40s
11 / 15
Testing complete in 32m 32s
12 / 15
Testing complete in 35m 20s
13 / 15
Testing complete in 38m 23s
14 / 15
Testing complete in 41m 18s
15 / 15
Testing complete in 44m 14s


In [112]:
err

tensor(0.9793)

In [47]:
estimate_alpha = []
i= 0
for dis_v, org_v in zip(test_files, org_files):       
    dis_voice, dis_sr = librosa.load(dis_v)
    org_voice, _ = librosa.load(org_v)
    predict_a = TIFS(org_voice, dis_voice, dis_sr)
    estimate_alpha.append(predict_a)     
    if i % 128 == 0:
        print (int(i/128),'/',int(len(test_files)/128)) 
    i += 1

  org_emb = torch.tensor(classifier.encode_batch(org_voice))
  rev_emb = torch.tensor(classifier.encode_batch(rev_vc))


0 / 7


KeyboardInterrupt: 

In [None]:
estimate_alpha

In [36]:
estimate_alpha

[0, 5]

In [None]:
E:/datasets/project3/AISHELL-3/unseen_Audacity/-0.5\SSB06930015.wav E:/datasets/project3/AISHELL-3/unseen/SSB0693\SSB06930229.wav -0.5
E:/datasets/project3/AISHELL-3/unseen_Audacity/4.5\SSB10020405.wav E:/datasets/project3/AISHELL-3/unseen/SSB1002\SSB10020246.wav 4.5
E:/datasets/project3/AISHELL-3/unseen_Audacity/6.0\SSB07170241.wav E:/datasets/project3/AISHELL-3/unseen/SSB0717\SSB07170008.wav 6.0
E:/datasets/project3/AISHELL-3/unseen_Audacity/6.0\SSB07170246.wav E:/datasets/project3/AISHELL-3/unseen/SSB0717\SSB07170471.wav 6.0
E:/datasets/project3/AISHELL-3/unseen_Audacity/-6.5\SSB09970303.wav E:/datasets/project3/AISHELL-3/unseen/SSB0997\SSB09970006.wav -6.5
E:/datasets/project3/AISHELL-3/unseen_Audacity/-6.5\SSB09970370.wav E:/datasets/project3/AISHELL-3/unseen/SSB0997\SSB09970291.wav -6.5
E:/datasets/project3/AISHELL-3/unseen_Audacity/-6.5\SSB09970450.wav E:/datasets/project3/AISHELL-3/unseen/SSB0997\SSB09970124.wav -6.5
E:/datasets/project3/AISHELL-3/unseen_Audacity/-6.5\SSB10000040.wav E:/datasets/project3/AISHELL-3/unseen/SSB1000\SSB10000320.wav -6.5
E:/datasets/project3/AISHELL-3/unseen_Audacity/-3.0\SSB10020039.wav E:/datasets/project3/AISHELL-3/unseen/SSB1002\SSB10020011.wav -3.0
E:/datasets/project3/AISHELL-3/unseen_Audacity/-3.0\SSB10020045.wav E:/datasets/project3/AISHELL-3/unseen/SSB1002\SSB10020425.wav -3.0