In [2]:
import os
import random
from itertools import product
from pathlib import Path

import fire
import jiwer
import librosa
import numpy as np
import pandas as pd
import torch
import torchaudio
import tqdm

import transformers
from huggingface_hub import hf_hub_download
from loguru import logger
from pydub import AudioSegment
from transformers import pipeline



In [4]:
# Prefer the GPU when one is visible, but fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Fetch the TorchScript ECAPA2 checkpoint from the Hugging Face Hub,
# deserialize it on CPU, then move the module to `device`.
ecapa2_file = hf_hub_download(
    repo_id='Jenthe/ECAPA2',
    filename='ecapa2.pt',
    cache_dir=None,
)
ecapa2 = torch.jit.load(ecapa2_file, map_location='cpu')
ecapa2 = ecapa2.to(device)

In [15]:
def torch_rms_norm(wav, db_level=-27.0):
    """Scale ``wav`` so that its RMS level equals ``db_level`` dBFS.

    Args:
        wav: Audio tensor; the size of its last dimension is used as the
            sample count, while the energy is summed over every element.
        db_level: Target RMS level in dB relative to full scale.

    Returns:
        The rescaled tensor. An all-zero (or empty) input is returned
        unchanged: no gain can lift digital silence to a target level, and the
        division below would otherwise fill the output with NaN/inf.
    """
    energy = torch.sum(wav ** 2)
    if energy == 0:
        return wav
    # Linear amplitude corresponding to the requested dB level.
    r = 10 ** (db_level / 20)
    # Gain that maps the current RMS onto `r`.
    a = torch.sqrt((wav.size(-1) * (r ** 2)) / energy)
    return wav * a

In [16]:
def get_ecapa2_spk_embedding(path=None, audio=None, ref_dBFS=None, model_sr=16000):
    """Compute an L2-normalised ECAPA2 speaker embedding for one recording.

    Args:
        path: Audio file to load with ``torchaudio.load`` (``str`` or
            ``pathlib.Path``). Used when it points to an existing file.
        audio: Fallback ``(samples, sample_rate)`` pair, used when ``path`` is
            None or does not exist. ``samples`` is converted to a float tensor
            and given a leading dimension.
        ref_dBFS: If not None, the waveform is RMS-normalised to this level
            with ``torch_rms_norm`` before embedding.
        model_sr: Sample rate the audio is resampled to when it differs.

    Returns:
        The embedding as a NumPy array with singleton dimensions squeezed out,
        or None when the waveform contains no samples.

    Raises:
        ValueError: If neither a usable ``path`` nor ``audio`` is available.
    """
    if path is not None and Path(path).exists():
        audio, sr = torchaudio.load(path)
    elif audio is not None:
        audio, sr = audio
        audio = torch.FloatTensor(audio).unsqueeze(0)
    elif path is not None:
        # A path was given but is missing: say so instead of claiming it was None.
        raise ValueError(f'Audio file `{path}` does not exist and no `audio` was provided')
    else:
        raise ValueError('One of `path` or `audio` arguments should not be None')

    # Nothing to embed; applies to both input routes.
    if audio.size(-1) == 0:
        return None

    # Resample to the rate passed as `model_sr` (16 kHz by default).
    if sr != model_sr:
        audio = torchaudio.functional.resample(audio, sr, model_sr)

    # RMS-normalise to the reference dBFS so that every input reaches the
    # model at the same loudness level.
    if ref_dBFS is not None:
        audio = torch_rms_norm(audio, db_level=ref_dBFS)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        embed = ecapa2(audio.to(device))
        # Make sure the output is L2-normalised along the feature dimension.
        embed = torch.nn.functional.normalize(embed, p=2, dim=1)
    # NOTE(review): a multi-channel file reaches the model as several rows and
    # would produce one embedding per channel — confirm inputs are mono.
    return embed.cpu().detach().squeeze().numpy()

In [9]:
# Load the `test_metadata.csv` tables of the VoxPopuli and LJ Speech folders.
vp_speakers, lg_speakers = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../data/facebook_voxpopuli/test_metadata.csv',
        '../data/keithito_lj_speech/test_metadata.csv',
    )
)

In [12]:
# Rows of two individual VoxPopuli speakers, plus a seeded random draw of
# 10 rows from the LJ Speech table.
speaker_4525 = vp_speakers.loc[vp_speakers['speaker_id'].isin([4525])]
speaker_197469 = vp_speakers.loc[vp_speakers['speaker_id'].isin([197469])]
speaker_lg = lg_speakers.sample(n=10, random_state=42)

In [17]:
# Embed every recording listed for VoxPopuli speaker 4525.
emb_4525 = []
for audio_id in speaker_4525['audio_id']:
    audio_path = Path(f'../data/facebook_voxpopuli/wavs/{audio_id}.wav')
    # Loudness of the file itself, used as the RMS-normalisation target.
    ref_dBFS = AudioSegment.from_file(audio_path).dBFS
    embedding = get_ecapa2_spk_embedding(path=audio_path, ref_dBFS=ref_dBFS)
    if embedding is None:
        # Empty recordings yield None, which would break the later np.vstack.
        logger.warning(f'Skipping empty audio file: {audio_path}')
        continue
    emb_4525.append(embedding)

In [18]:
# Embed every recording listed for VoxPopuli speaker 197469.
emb_197469 = []
for audio_id in speaker_197469['audio_id']:
    audio_path = Path(f'../data/facebook_voxpopuli/wavs/{audio_id}.wav')
    # Loudness of the file itself, used as the RMS-normalisation target.
    ref_dBFS = AudioSegment.from_file(audio_path).dBFS
    embedding = get_ecapa2_spk_embedding(path=audio_path, ref_dBFS=ref_dBFS)
    if embedding is None:
        # Empty recordings yield None, which would break the later np.vstack.
        logger.warning(f'Skipping empty audio file: {audio_path}')
        continue
    emb_197469.append(embedding)

In [19]:
# Embed the sampled LJ Speech recordings.
emb_lg = []
for audio_id in speaker_lg['audio_id']:
    audio_path = Path(f'../data/keithito_lj_speech/wavs/{audio_id}.wav')
    # Loudness of the file itself, used as the RMS-normalisation target.
    ref_dBFS = AudioSegment.from_file(audio_path).dBFS
    embedding = get_ecapa2_spk_embedding(path=audio_path, ref_dBFS=ref_dBFS)
    if embedding is None:
        # Empty recordings yield None, which would break the later np.vstack.
        logger.warning(f'Skipping empty audio file: {audio_path}')
        continue
    emb_lg.append(embedding)

In [26]:
# Stack the collected embeddings vertically into a single array.
# NOTE(review): rebinds the name from a list to an array — presumably one row
# per recording; confirm every embedding is 1-D.
emb_4525 = np.vstack(emb_4525)

In [27]:
# Stack the collected embeddings vertically into a single array.
# NOTE(review): rebinds the name from a list to an array — presumably one row
# per recording; confirm every embedding is 1-D.
emb_197469=np.vstack(emb_197469)

In [32]:
# Stack the collected embeddings vertically into a single array.
# NOTE(review): rebinds the name from a list to an array — presumably one row
# per recording; confirm every embedding is 1-D.
emb_lg=np.vstack(emb_lg)

In [29]:
from sklearn.metrics.pairwise import cosine_distances

In [33]:
# Cross-speaker: mean pairwise cosine distance, speaker 4525 vs. speaker 197469.
np.mean(cosine_distances(emb_4525, emb_197469))

0.9190702

In [34]:
# Cross-speaker: mean pairwise cosine distance, LJ Speech sample vs. speaker 197469.
np.mean(cosine_distances(emb_lg, emb_197469))

1.0102003

In [35]:
# Cross-speaker: mean pairwise cosine distance, LJ Speech sample vs. speaker 4525.
np.mean(cosine_distances(emb_lg, emb_4525))

1.0549443

In [36]:
# Same-speaker: mean pairwise cosine distance within the LJ Speech sample.
# The matrix is compared with itself, so the zero self-distances on the
# diagonal are part of the average.
np.mean(cosine_distances(emb_lg, emb_lg))

0.15678295

In [37]:
# Same-speaker: mean pairwise cosine distance within speaker 4525.
# The matrix is compared with itself, so the zero self-distances on the
# diagonal are part of the average.
np.mean(cosine_distances(emb_4525, emb_4525))

0.23634158

In [None]:
# Same-speaker: mean pairwise cosine distance within speaker 197469.
# This cell previously repeated the speaker-4525 computation verbatim;
# speaker 197469 was the only group left without a same-speaker figure.
# The zero self-distances on the diagonal are part of the average.
cosine_distances(emb_197469, emb_197469).mean()

In [5]:
# Read the LJ Speech metadata table; the bare name on the last line renders it.
metadata = pd.read_csv(filepath_or_buffer='../data/keithito_lj_speech/metadata.csv')
metadata

Unnamed: 0,audio_id,raw_text,speaker_id,whisper_transcription
0,LJ049-0081,There was no Federal criminal jurisdiction ove...,lg_speaker,There was no federal criminal jurisdiction ov...
1,LJ028-0384,So Babylon was buried and forgotten.,lg_speaker,So Babylon was buried and forgotten.
2,LJ034-0111,"When arrested, he gave his weight as 140 pound...",lg_speaker,When arrested he gave his weight as 140 pound...
3,LJ030-0101,"Admiral George G. Burkley, physician to the Pr...",lg_speaker,"Admiral George G. Berkeley, physician to the ..."
4,LJ034-0171,"Edwards said, quote, Look at that guy there in...",lg_speaker,"Edwards said, quote, look at that guy there i..."
...,...,...,...,...
13094,LJ013-0152,at Montreuil. He was arraigned at the Old Bail...,lg_speaker,at Montreuil. He was arraigned at the Old Bai...
13095,LJ045-0052,which Mrs. Paine had made in part to give her ...,lg_speaker,"which Mrs. Payne had made, in part, to give h..."
13096,LJ019-0382,"As the years passed, great want of uniformity ...",lg_speaker,"As the years passed, great want of uniformity..."
13097,LJ007-0186,"the undue authority given to prisoners, the le...",lg_speaker,"the undue authority given to prisoners, the l..."


In [9]:
# Inspect the table ordered by clip identifier (returns a sorted copy).
metadata.sort_values(by='audio_id')

Unnamed: 0,audio_id,raw_text,speaker_id,whisper_transcription
4693,LJ001-0002,in being comparatively modern.,lg_speaker,in being comparatively modern.
3108,LJ001-0003,For although the Chinese took impressions from...,lg_speaker,For although the Chinese took impressions fro...
5778,LJ001-0004,"produced the block books, which were the immed...",lg_speaker,"produced the block books, which were the imme..."
3187,LJ001-0005,the invention of movable metal letters in the ...,lg_speaker,The invention of movable metal letters in the...
662,LJ001-0006,"And it is worth mention in passing that, as an...",lg_speaker,And it is worth mention in passing that as an...
...,...,...,...,...
2350,LJ050-0274,made certain recommendations which it believes...,lg_speaker,"made certain recommendations, which it believ..."
10228,LJ050-0275,materially improve upon the procedures in effe...,lg_speaker,materially improve upon the procedures in eff...
1171,LJ050-0276,"As has been pointed out, the Commission has no...",lg_speaker,"As has been pointed out, the Commission has n..."
5405,LJ050-0277,with the active cooperation of the responsible...,lg_speaker,with the active cooperation of the responsibl...


In [7]:
# Shape of the array of distinct raw transcripts.
pd.unique(metadata['raw_text']).shape

(13073,)

In [10]:
# Distinct transcripts, in `audio_id` order, collected as a Series named 'prompt'.
# NOTE(review): `.iloc[1:]` drops the first row of the sorted table — confirm
# which clip is meant to be excluded and why.
lg_prompts = pd.Series(
    metadata.sort_values(by='audio_id')['raw_text'].iloc[1:].unique(),
    name='prompt',
)
lg_prompts

0        For although the Chinese took impressions from...
1        produced the block books, which were the immed...
2        the invention of movable metal letters in the ...
3        And it is worth mention in passing that, as an...
4        "the earliest book printed with movable types,...
                               ...                        
13067    made certain recommendations which it believes...
13068    materially improve upon the procedures in effe...
13069    As has been pointed out, the Commission has no...
13070    with the active cooperation of the responsible...
13071    the recommendations we have here suggested wou...
Name: prompt, Length: 13072, dtype: object

In [11]:
# Persist the prompts as a one-column CSV (header row, no index column).
lg_prompts.to_csv('../data/dpo_dataset/lg_prompts.csv', index=False, header=True)

In [3]:
# Load the table of generated samples with their metric and rank columns;
# the bare name on the last line renders it.
df = pd.read_parquet(path='../data/dpo_dataset/finale_samples.parquet')
df

Unnamed: 0,gen_id,audio_id,speaker_id,prompt_id,text,gpt_codes,cer,mer,wer,wil,wip,secs,utmos,cer_rank,secs_rank,utmos_rank
0,0,p236_441.wav,p236_441,5600,"Through rigorous market analysis, the investme...","[8, 28, 256, 728, 105, 371, 325, 577, 487, 225...",0.000000,0.000000,0.000000,0.000000,1.000000,0.318090,4.197962,1,3,9
1,1,p236_441.wav,p236_441,5600,"Through rigorous market analysis, the investme...","[8, 28, 256, 604, 487, 82, 838, 467, 225, 587,...",0.000000,0.000000,0.000000,0.000000,1.000000,0.247824,3.638512,2,1,1
2,2,p236_441.wav,p236_441,5600,"Through rigorous market analysis, the investme...","[294, 256, 325, 445, 728, 487, 467, 459, 225, ...",0.181818,0.166667,0.181818,0.242424,0.757576,0.317722,3.883627,10,2,4
3,3,p236_441.wav,p236_441,5600,"Through rigorous market analysis, the investme...","[746, 256, 333, 728, 487, 82, 577, 225, 28, 12...",0.000000,0.000000,0.000000,0.000000,1.000000,0.356188,4.129373,3,10,8
4,4,p236_441.wav,p236_441,5600,"Through rigorous market analysis, the investme...","[305, 82, 256, 270, 105, 487, 467, 536, 225, 2...",0.000000,0.000000,0.000000,0.000000,1.000000,0.337849,4.014897,4,8,7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,5,p239_429.wav,p239_429,1016,"While negotiating with the opposing party, the...","[483, 288, 49, 775, 257, 584, 161, 128, 119, 2...",0.000000,0.000000,0.000000,0.000000,1.000000,0.629990,4.193303,6,2,8
96,6,p239_429.wav,p239_429,1016,"While negotiating with the opposing party, the...","[483, 288, 775, 257, 299, 304, 433, 584, 871, ...",0.000000,0.000000,0.000000,0.000000,1.000000,0.627538,3.659092,7,1,1
97,7,p239_429.wav,p239_429,1016,"While negotiating with the opposing party, the...","[555, 288, 49, 435, 368, 299, 906, 128, 584, 1...",0.000000,0.000000,0.000000,0.000000,1.000000,0.657178,3.848847,8,9,5
98,8,p239_429.wav,p239_429,1016,"While negotiating with the opposing party, the...","[30, 288, 299, 56, 269, 371, 193, 64, 59, 256,...",0.000000,0.000000,0.000000,0.000000,1.000000,0.650794,3.839072,9,5,4


In [13]:
def f3(row):
    """Return the harmonic mean of a row's three ranks, each scaled by 1/10.

    Reads ``cer_rank``, ``secs_rank`` and ``utmos_rank`` from ``row``.

    NOTE(review): a harmonic mean is dominated by its smallest term, so one
    very low rank pulls the score down regardless of the other two — confirm
    this is the intended aggregation.
    """
    scaled = [row[col] / 10 for col in ('cer_rank', 'secs_rank', 'utmos_rank')]
    return len(scaled) / sum(1 / value for value in scaled)

In [4]:
# Mirror secs_rank and utmos_rank (r -> 11 - r) so they run in the opposite
# direction, then average the three ranks into `rank`.
# The `rank` column doubles as a marker that the mirroring has been applied:
# without this guard, re-running the cell would silently flip the ranks back.
if 'rank' not in df.columns:
    df['secs_rank'] = 11 - df['secs_rank']
    df['utmos_rank'] = 11 - df['utmos_rank']
    df['rank'] = df[['cer_rank', 'secs_rank', 'utmos_rank']].mean(axis=1)

In [14]:
# Row-wise harmonic-mean score computed by `f3`.
df['f3_rank'] = df.apply(f3, axis='columns')

In [36]:
def build_dataset(group):
    """Build one winner/loser pair from the samples of a single group.

    Rows are sorted by ascending ``f3_rank``. With four or more rows the
    second row is the winner and the second-to-last row is the loser. Smaller
    groups fall back to first vs. last, because the fixed positions would
    otherwise swap winner and loser (2 rows) or pick the same row twice
    (3 rows).

    Args:
        group: DataFrame with the columns ``f3_rank``, ``rank``, ``audio_id``,
            ``speaker_id``, ``text`` and ``gpt_codes``.

    Returns:
        A ``pd.Series`` holding the loser's ``audio_id``/``speaker_id``/``text``,
        both code sequences (``mel_cond_l``, ``mel_cond_w``) and both ``rank``
        values (``l_rank``, ``w_rank``).

    Raises:
        IndexError: If the group has fewer than two rows.
    """
    sorted_g = group.sort_values('f3_rank')
    n_rows = len(sorted_g)
    if n_rows < 2:
        raise IndexError(
            f'build_dataset needs at least 2 samples per group, got {n_rows}'
        )
    if n_rows >= 4:
        win_pos, lose_pos = 1, n_rows - 2
    else:
        win_pos, lose_pos = 0, n_rows - 1
    good = sorted_g.iloc[win_pos]
    bad = sorted_g.iloc[lose_pos]
    # NOTE(review): pairs are chosen by `f3_rank` but reported with `rank`
    # (the mean rank) — confirm that this mix is intended.
    res = pd.Series(
        {
            'audio_id': bad['audio_id'],
            'speaker_id': bad['speaker_id'],
            'text': bad['text'],
            'mel_cond_l': bad['gpt_codes'],
            'mel_cond_w': good['gpt_codes'],
            'l_rank':  bad['rank'],
            'w_rank': good['rank'],
        }
    )
    return res

In [37]:
# Build one pair per `prompt_id` and write the resulting table to parquet.
(
    df.groupby('prompt_id')
    .apply(build_dataset)
    .to_parquet('../data/dpo_dataset/finale_dpo_data.parquet')
)

In [16]:
# Spot-check: all samples of prompt 100, ordered by ascending f3_rank.
df.loc[df['prompt_id'] == 100].sort_values(by='f3_rank')

Unnamed: 0,gen_id,audio_id,speaker_id,prompt_id,text,gpt_codes,cer,mer,wer,wil,wip,secs,utmos,cer_rank,secs_rank,utmos_rank,rank,f3_rank
89,9,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[154, 392, 28, 82, 200, 640, 225, 338, 279, 62...",0.0,0.0,0.0,0.0,1.0,0.130334,3.601701,8,1,1,3.333333,0.141176
80,0,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[631, 225, 28, 82, 128, 640, 484, 76, 105, 365...",0.0,0.0,0.0,0.0,1.0,0.083068,2.871134,1,7,6,4.666667,0.229091
83,3,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[807, 28, 82, 640, 76, 225, 338, 154, 534, 487...",0.0,0.0,0.0,0.0,1.0,0.099155,3.43714,2,4,2,2.666667,0.24
81,1,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[225, 28, 256, 487, 200, 534, 76, 627, 640, 44...",0.3,0.3,0.3,0.51,0.49,0.123393,1.74554,10,2,10,7.333333,0.428571
86,6,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[154, 99, 28, 82, 487, 85, 640, 225, 627, 13, ...",0.0,0.0,0.0,0.0,1.0,0.08475,3.412921,5,6,3,4.666667,0.428571
84,4,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[401, 28, 128, 627, 640, 76, 487, 338, 225, 36...",0.0,0.0,0.0,0.0,1.0,0.054122,3.043425,3,9,5,5.666667,0.465517
88,8,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[640, 225, 82, 368, 11, 730, 701, 551, 571, 91...",0.0,0.0,0.0,0.0,1.0,0.115861,2.2851,7,3,9,6.333333,0.510811
87,7,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[807, 99, 28, 82, 627, 76, 151, 487, 225, 640,...",0.0,0.0,0.0,0.0,1.0,0.072053,3.133993,6,8,4,6.0,0.553846
85,5,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[85, 28, 82, 640, 571, 225, 487, 76, 831, 551,...",0.0,0.0,0.0,0.0,1.0,0.045662,2.521324,4,10,7,7.0,0.608696
82,2,20150610-0900-PLENARY-15-en_20150610-19_26_22_...,20150610-0900-PLENARY-15-en_20150610-19_26_22_9,100,Key performance indicators showed a substantia...,"[571, 225, 28, 1003, 82, 99, 640, 487, 76, 13,...",0.2,0.181818,0.2,0.263636,0.736364,0.097615,2.461971,9,5,8,7.333333,0.687898


In [18]:
# Spot-check: all samples of prompt 101, ordered by ascending f3_rank.
df.loc[df['prompt_id'] == 101].sort_values(by='f3_rank')

Unnamed: 0,gen_id,audio_id,speaker_id,prompt_id,text,gpt_codes,cer,mer,wer,wil,wip,secs,utmos,cer_rank,secs_rank,utmos_rank,rank,f3_rank
50,0,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[701, 25, 270, 256, 487, 258, 627, 151, 28, 22...",0.0,0.0,0.0,0.0,1.0,0.380009,3.960942,1,3,3,2.333333,0.18
55,5,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[534, 299, 368, 211, 371, 12, 225, 627, 487, 1...",0.0,0.0,0.0,0.0,1.0,0.327664,4.182682,3,8,1,4.0,0.205714
59,9,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[534, 642, 115, 256, 487, 258, 627, 225, 128, ...",0.1,0.1,0.1,0.19,0.81,0.402301,3.67485,8,1,5,4.666667,0.226415
56,6,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[200, 323, 11, 266, 173, 7, 226, 338, 65, 66, ...",0.0,0.0,0.0,0.0,1.0,0.337886,4.10921,4,6,2,4.0,0.327273
51,1,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[534, 25, 527, 767, 68, 323, 226, 200, 266, 66...",0.0,0.0,0.0,0.0,1.0,0.295025,3.718721,2,10,4,5.333333,0.352941
54,4,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[17, 256, 269, 368, 299, 56, 74, 627, 487, 301...",0.3,0.3,0.3,0.455556,0.544444,0.399505,3.647018,10,2,6,6.0,0.391304
53,3,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[627, 767, 527, 160, 59, 481, 775, 30, 487, 35...",0.1,0.1,0.1,0.19,0.81,0.371423,3.643049,5,4,7,5.333333,0.506024
52,2,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[551, 405, 369, 299, 59, 435, 561, 323, 266, 6...",0.2,0.2,0.2,0.288889,0.711111,0.341844,3.468195,9,5,10,8.0,0.72973
57,7,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[627, 871, 527, 68, 323, 59, 727, 405, 256, 48...",0.1,0.1,0.1,0.1,0.9,0.300127,3.623864,6,9,8,7.666667,0.744828
58,8,LJ001-0001.wav,LJ001-0001,101,The novelist's brushstrokes danced with vibran...,"[534, 487, 167, 82, 225, 258, 627, 587, 151, 2...",0.1,0.1,0.1,0.19,0.81,0.330368,3.603032,7,7,9,7.666667,0.756


In [34]:
def f(group):
    """Count how many of the three ranks exceed 5 for the row with the lowest ``rank``."""
    rank_cols = ['cer_rank', 'secs_rank', 'utmos_rank']
    best_row = group.sort_values(by='rank').iloc[0]
    above_five = best_row[rank_cols] > 5
    return sum(above_five)

In [35]:
# Average, over prompts, of the count returned by `f` for each prompt's group.
df.groupby(by='prompt_id').apply(f).mean()

0.3156

In [32]:
# NOTE(review): `_` is IPython's "previous output" variable, so this cell's
# result depends on whichever cell happened to run just before it (and the
# `for _, row in ...` loops in this notebook also rebind `_`). It cannot be
# reproduced from a fresh kernel — replace `_` with an explicit expression or
# delete the cell.
_.mean()

0.0

In [12]:
# Largest `cer` value in the table.
df.loc[:, 'cer'].max()

39.09090909090909