In [1]:
import sys
sys.path.append("plbert/")

from nemo.collections.asr.models import EncDecSpeakerLabelModel
from IPython.display import Audio, display
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
import math
import os
import requests
import boto3
import wave
import sys
import contextlib
import collections
from tqdm import tqdm
from tqdm.notebook import tqdm
import yaml
from transformers import AlbertConfig, AlbertModel

from phonemize import phonemize
from phonemizer.backend import EspeakBackend
from transformers import TransfoXLTokenizer
from text_normalize import normalize_text, remove_accents
from text_utils import TextCleaner

import whisper
import librosa


[NeMo W 2023-07-27 19:30:15 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.
      def backtrace(trace: np.ndarray):
    


## PL-BERT Batch Processing

In [2]:
plbert_root = "plbert/"
log_dir = plbert_root+"Checkpoint/"
config_path = os.path.join(log_dir, "config.yml")
# Use a context manager so the config file handle is closed after parsing.
with open(config_path) as config_file:
    plbert_config = yaml.safe_load(config_file)

albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
plbert = AlbertModel(albert_base_configuration)

# Latest checkpoint = highest N among regular files in log_dir named "step_<N>.<ext>".
ckpts = [f for f in os.listdir(log_dir)
         if f.startswith("step_") and os.path.isfile(os.path.join(log_dir, f))]
if not ckpts:
    raise FileNotFoundError(f"no step_* checkpoint files found in {log_dir}")
iters = max(int(f.split('_')[-1].split('.')[0]) for f in ckpts)

# os.path.join avoids the double slash that `log_dir + "/step_"` produced.
checkpoint = torch.load(os.path.join(log_dir, f"step_{iters}.t7"), map_location='cpu')

# Keep only the encoder weights. Strip an optional "module." prefix (added by
# DataParallel-style wrappers) and then the "encoder." prefix, so the keys line up
# with AlbertModel's own parameter names. Keys without "module." are no longer mangled.
from collections import OrderedDict
state_dict = checkpoint['net']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len('module.'):] if k.startswith('module.') else k
    if name.startswith('encoder.'):
        new_state_dict[name[len('encoder.'):]] = v

plbert.load_state_dict(new_state_dict)
plbert.eval()

AlbertModel(
  (embeddings): AlbertEmbeddings(
    (word_embeddings): Embedding(178, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (encoder): AlbertTransformer(
    (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
    (albert_layer_groups): ModuleList(
      (0): AlbertLayerGroup(
        (albert_layers): ModuleList(
          (0): AlbertLayer(
            (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (attention): AlbertAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (attention_dropout): Dropout(p=0, inplace=False)
        

In [3]:
# Example batch of raw sentences. In the training pipeline this step belongs in the
# collate function: each batch of sentences is phonemized, padded, and then passed
# through PL-BERT.
batch_of_text = [
    "Hi, And also can you please check what is the current temperature setting of your unit both fridge and the freezer?",
    "Hi, do you mind returning the toy back to my house?",
    "Hi, I love the dress you are wearing!",
    "With the batch of sentences, we need to pad them, and then pass it through the model.",
]

In [4]:
# Shared text-processing objects; they are passed explicitly to tokenize/get_pltbert_embs.
# EspeakBackend needs the espeak library installed (e.g. `brew install espeak`), and the
# location of its .dylib must be exported so the backend can find it.
global_phonemizer = EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
# The tokenizer name is read from the PL-BERT config loaded from the checkpoint directory.
tokenizer = TransfoXLTokenizer.from_pretrained(plbert_config['dataset_params']['tokenizer'])
text_cleaner = TextCleaner()

177


In [5]:
# Batched phoneme encoding for PL-BERT; padding uses id 0.
def tokenize(sents, global_phonemizer, tokenizer, text_cleaner):
    """Phonemize and id-encode a batch of sentences, zero-padding to the longest one.

    Args:
        sents: sequence of raw sentence strings.
        global_phonemizer, tokenizer: forwarded to ``phonemize`` for every sentence.
        text_cleaner: callable turning the space-joined phoneme string into a list of ids.

    Returns:
        (phoneme_ids, mask): LongTensors of shape (len(sents), longest_sequence).
        Padded positions are 0 in both tensors; ``mask`` is 1 on real tokens.
        The notes here originally called the pad token '$'. Id 0 is the embedding's
        padding_idx in the loaded model, so presumably TextCleaner maps '$' to 0 (TODO confirm).
    """
    encoded = [
        torch.LongTensor(
            text_cleaner(' '.join(phonemize(sentence, global_phonemizer, tokenizer)['phonemes']))
        )
        for sentence in sents
    ]
    longest = max((len(ids) for ids in encoded), default=0)

    phoneme_ids = torch.zeros((len(sents), longest), dtype=torch.long)
    mask = torch.zeros((len(sents), longest), dtype=torch.long)
    for row, ids in enumerate(encoded):
        width = len(ids)
        phoneme_ids[row, :width] = ids
        mask[row, :width] = 1
    return phoneme_ids, mask

In [6]:
def get_pltbert_embs(s, global_phonemizer, tokenizer, text_cleaner):
    """Embed a batch of sentences with the frozen PL-BERT encoder.

    Input: list of texts.

    Output: ``last_hidden_state`` of the pretrained ALBERT model, shape
    (batch_size, num_tokens, hidden); hidden is 768 for the loaded config.

    The encoder is in eval mode and its outputs are detached downstream, so the
    forward pass runs under ``torch.no_grad()``. This avoids building an unused
    autograd graph, and as a result the returned tensor has no grad_fn. Inputs are
    moved to plbert's device, so the function also works if the model is moved to a GPU.
    """
    phoneme_ids, attention_mask = tokenize(s, global_phonemizer, tokenizer, text_cleaner)
    device = next(plbert.parameters()).device
    with torch.no_grad():
        return plbert(phoneme_ids.to(device), attention_mask=attention_mask.to(device)).last_hidden_state

In [7]:
get_pltbert_embs(batch_of_text, global_phonemizer, tokenizer, text_cleaner).shape

torch.Size([4, 123, 768])

## Speaker Encoder (ECAPA) Batch Processing

In [8]:
# spNet = torch.jit.load("/Users/ajaybati/Downloads/0.23.0612.1.AT.PHL.alorica/spNet_traced.jit")
ecapa = torch.jit.load("ecapa2_traced.jit")

In [9]:
audios = [
    "miipherTestDataset/train/clean_trainset_wav/p234_001.wav",
    "miipherTestDataset/train/clean_trainset_wav/p234_002.wav",
    "miipherTestDataset/train/clean_trainset_wav/p234_003.wav",
    "miipherTestDataset/train/clean_trainset_wav/p234_004.wav",
]
audios = audios * 2  # repeat to form a batch of 8
audio_loaded = []
for path in audios:
    audio, _ = librosa.load(path, sr=16000)
    # Pad/trim to Whisper's fixed window (480000 samples at 16 kHz); each entry is (1, samples).
    audio_loaded.append(torch.tensor(whisper.pad_or_trim(np.array([audio]))))
# (N, 1, samples) -> (N, samples). Squeeze only dim 1 so a single-file list keeps its batch axis.
audio_loaded = torch.stack(audio_loaded).squeeze(1)
audio_loaded.shape

torch.Size([8, 480000])

In [10]:
def batch_spenc(batch, device="cuda"):
    """Build keyword arguments for the traced speaker encoder (``ecapa``).

    Args:
        batch: waveform tensor of shape (batch, samples). A single 1-D waveform is
            treated as a batch of one.
        device: device both tensors are moved to (default "cuda", matching the old
            hard-coded ``.cuda()``).

    Returns:
        dict with ``input_signal`` of shape (batch, samples), and ``input_signal_length``,
        an int64 tensor with one entry per row. Every entry equals ``samples``, because
        all rows are padded to the same length.
    """
    if batch.dim() == 1:
        # `.squeeze()` on a batch of one drops the batch axis. Without this guard,
        # len(batch) would be the sample count and we'd emit one length per sample.
        batch = batch.unsqueeze(0)
    lengths = torch.full((batch.shape[0],), batch.shape[-1], dtype=torch.long, device=device)
    return {'input_signal': batch.to(device), 'input_signal_length': lengths}
kwarg = batch_spenc(audio_loaded)

In [11]:
ecapa(**kwarg)[-1].shape #logits, embedding

torch.Size([8, 192])

### Speed Testing

In [1]:
import torch
def create_input(seconds):
    return torch.randn(1,16000*seconds)

ten = create_input(10)
thirty = create_input(30)
twomin = create_input(120)
fivemin = create_input(300)

import time
def test(input):
    kwarg = {'input_signal': input, 'input_signal_length': torch.tensor([input.shape[-1]])}
    spNet_times = []
    for x in range(20):
        start = time.time()
        out = spNet(**kwarg)
        spNet_times.append(time.time()-start)

    ecapa_times = []
    for x in range(20):
        start = time.time()
        out = ecapa(**kwarg)
        ecapa_times.append(time.time()-start)

    return spNet_times, ecapa_times

In [2]:
torch.set_num_threads(1)

In [7]:
times = []

In [17]:
import numpy as np
sp, ec = test(fivemin)

sp = np.array(sp)
ec = np.array(ec)

In [18]:
times.append([(np.median(ec), ec.mean()), (np.median(sp), sp.mean())])

In [19]:
times

[[(0.0795280933380127, 0.07995404005050659),
  (0.03753793239593506, 0.041062581539154056)],
 [(0.2759588956832886, 0.27689037322998045),
  (0.09117209911346436, 0.09421477317810059)],
 [(1.3495440483093262, 1.353237557411194),
  (0.4264800548553467, 0.42521281242370607)],
 [(4.634666442871094, 4.688384139537812),
  (0.9876257181167603, 0.9849279642105102)]]

## Whisper Batch Processing

In [11]:
torch.cuda.empty_cache() # PyTorch thing

In [29]:
%load_ext autoreload
%autoreload 2

In [5]:
import whisper

model = whisper.load_model("small.en")

# # load audio and pad/trim it to fit 30 seconds
# audio = whisper.load_audio("/Users/ajaybati/Downloads/0.23.0612.1.AT.PHL.alorica/out.wav")
# audio = whisper.pad_or_trim(audio)

# # make log-Mel spectrogram and move to the same device as the model
# mel = whisper.log_mel_spectrogram(audio)

# model.embed_audio(mel.reshape(1,*mel.shape)).shape

In [31]:
from whisper_encode_batch import Batch, process_data

In [32]:
audios = ["miipherTestDataset/train/clean_trainset_wav/"+x for x in sorted(os.listdir("miipherTestDataset/train/clean_trainset_wav/"))]
len(audios)

23075

In [33]:
process_data(audios, Batch(3,400), "whisperEmbs/clean_trainset_embs/")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 23075/23075 [1:10:03<00:00,  5.49it/s]


In [34]:
audios = ["miipherTestDataset/train/noisy_trainset_wav/"+x for x in sorted(os.listdir("miipherTestDataset/train/noisy_trainset_wav/"))]
len(audios)

23075

In [35]:
process_data(audios, Batch(3,400), "whisperEmbs/noisy_trainset_embs/")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 23075/23075 [1:09:35<00:00,  5.53it/s]


In [36]:
audios = ["miipherTestDataset/test/clean_testset_wav/"+x for x in sorted(os.listdir("miipherTestDataset/test/clean_testset_wav/"))]
len(audios)

824

In [37]:
process_data(audios, Batch(3,400), "whisperEmbs/clean_testset_embs/")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 824/824 [01:41<00:00,  8.08it/s]


In [38]:
audios = ["miipherTestDataset/test/noisy_testset_wav/"+x for x in sorted(os.listdir("miipherTestDataset/test/noisy_testset_wav/"))]
len(audios)

824

In [39]:
process_data(audios, Batch(3,400), "whisperEmbs/noisy_testset_embs/")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 824/824 [01:45<00:00,  7.79it/s]


In [28]:
query = np.load("whisperEmbs/noisy_testset_embs/batch400_2.npy", mmap_mode='r')

In [29]:
query = query[0]

In [None]:
audio, _ = librosa.load("miipherTestDataset/test/clean_testset_wav/p257_433.wav", sr=16000, duration=20)
audio = whisper.pad_or_trim(audio)

# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio)

test = model.embed_audio(mel.reshape(1,*mel.shape).cuda())

In [None]:
query[0][:10]

In [27]:
test.squeeze()[0][:10]

tensor([-0.2969, -0.1863,  0.2197, -0.3074, -0.9798, -0.3804, -1.6542,  0.2116,
        -1.4315,  0.3202], device='cuda:0', grad_fn=<SliceBackward0>)

In [31]:
audios = ["miipherTestDataset/test/noisy_testset_wav/"+x for x in sorted(os.listdir("miipherTestDataset/test/noisy_testset_wav/"))]
audios[822]

'miipherTestDataset/test/noisy_testset_wav/p257_433.wav'

In [39]:
whisperEmbTrainClean = 'whisperEmbs/clean_trainset_embs'
whisperEmbTrainNoisy = 'whisperEmbs/noisy_trainset_embs'

In [40]:
import os
whisperembclean = [f"{whisperEmbTrainClean}/{a}" for a in sorted([f for f in os.listdir(whisperEmbTrainClean) if 'batch' in f], key=lambda x: int(x.split('_')[-1].split('.')[0]))]
whisperembnoisy = [f"{whisperEmbTrainNoisy}/{a}" for a in sorted([f for f in os.listdir(whisperEmbTrainNoisy) if 'batch' in f], key=lambda x: int(x.split('_')[-1].split('.')[0]))]
clean = sorted(os.listdir("miipherTestDataset/train/clean_trainset_wav/"))
noisy = sorted(os.listdir("miipherTestDataset/train/noisy_trainset_wav/"))


In [47]:
whisperembclean

['whisperEmbs/clean_trainset_embs/batch400_1.npy',
 'whisperEmbs/clean_trainset_embs/batch400_2.npy',
 'whisperEmbs/clean_trainset_embs/batch400_3.npy',
 'whisperEmbs/clean_trainset_embs/batch400_4.npy',
 'whisperEmbs/clean_trainset_embs/batch400_5.npy',
 'whisperEmbs/clean_trainset_embs/batch400_6.npy',
 'whisperEmbs/clean_trainset_embs/batch400_7.npy',
 'whisperEmbs/clean_trainset_embs/batch400_8.npy',
 'whisperEmbs/clean_trainset_embs/batch400_9.npy',
 'whisperEmbs/clean_trainset_embs/batch400_10.npy',
 'whisperEmbs/clean_trainset_embs/batch400_11.npy',
 'whisperEmbs/clean_trainset_embs/batch400_12.npy',
 'whisperEmbs/clean_trainset_embs/batch400_13.npy',
 'whisperEmbs/clean_trainset_embs/batch400_14.npy',
 'whisperEmbs/clean_trainset_embs/batch400_15.npy',
 'whisperEmbs/clean_trainset_embs/batch400_16.npy',
 'whisperEmbs/clean_trainset_embs/batch400_17.npy',
 'whisperEmbs/clean_trainset_embs/batch400_18.npy',
 'whisperEmbs/clean_trainset_embs/batch400_19.npy',
 'whisperEmbs/clean_t

In [None]:
audios = ["miipherTestDataset/test/noisy_testset_wav/"+x for x in sorted(os.listdir("miipherTestDataset/test/noisy_testset_wav/"))]

audios2 = ["miipherTestDataset/test/clean_testset_wav/"+x for x in sorted(os.listdir("miipherTestDataset/test/clean_testset_wav/"))]
[(a,b) for a,b in zip(audios,audios2)]

## Creating Dataset

In [36]:
from tqdm.notebook import tqdm

In [15]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import os
class MiipherDataset(Dataset):
    """Per-utterance dataset pairing noisy/clean Whisper embeddings with audio and text.

    Item ``idx`` is ``(whisper_noisy, speaker_wav, sentence, whisper_clean)``:

    - the Whisper embeddings are memory-mapped rows of the precomputed ``.npy`` batch files;
    - ``speaker_wav`` is the noisy recording, loaded eagerly at 16 kHz and passed through
      ``whisper.pad_or_trim``;
    - ``sentence`` is the stripped contents of the transcript file.

    Inputs are matched purely by position, so every list must share one ordering
    (callers sort them). Count mismatches raise ``ValueError`` instead of silently
    pairing the wrong items.
    """

    def __init__(self, noisy_filepath, clean_filepath, text_filepath, whisperEmbfilepathNoisy, whisperEmbfilepathClean):
        self.noisyfilepaths = noisy_filepath
        self.cleanfilepaths = clean_filepath
        self.sentpaths = text_filepath
        self._check_counts(clean_paths=len(clean_filepath), text_paths=len(text_filepath))
        self.spencembs = []
        self.all_sents = []

        #change to enable batched input
        for noisy, text in tqdm(zip(self.noisyfilepaths, self.sentpaths), total=len(self.noisyfilepaths)):
            # Text: load the transcript as a single stripped string.
            with open(text, 'r') as f:
                self.all_sents.append(f.read().strip())
            # Speaker encoder input: (1, samples) tensor padded/trimmed to Whisper's window.
            audio, _ = librosa.load(noisy, sr=16000)
            self.spencembs.append(torch.tensor(whisper.pad_or_trim(np.array([audio]))))

        self.whisper_noisy = self._load_rows(whisperEmbfilepathNoisy)
        self.whisper_clean = self._load_rows(whisperEmbfilepathClean)
        self._check_counts(whisper_noisy=len(self.whisper_noisy), whisper_clean=len(self.whisper_clean))

    def _check_counts(self, **counts):
        """Raise ValueError if any named count differs from the number of noisy files."""
        expected = len(self.noisyfilepaths)
        mismatched = {name: count for name, count in counts.items() if count != expected}
        if mismatched:
            raise ValueError(f"expected {expected} items (one per noisy file) but got {mismatched}")

    @staticmethod
    def _load_rows(npy_paths):
        """Memory-map each .npy batch file and return all rows, in file order."""
        batches = [np.load(path, mmap_mode='r') for path in tqdm(npy_paths)]
        return [loaded[i, :, :] for loaded in batches for i in range(len(loaded))]

    def __len__(self):
        return len(self.noisyfilepaths)

    def __getitem__(self, idx):
        return self.whisper_noisy[idx], self.spencembs[idx], self.all_sents[idx], self.whisper_clean[idx]

In [38]:
def collate_fn(batch):
    """Collate MiipherDataset items into batch tensors.

    batch = list of tuples(whisper_noisy, speaker wav, raw sentences, whisper_clean)

    Returns (plbertembs, whisper_noisy, speakerembs, whisper_clean):
      - plbertembs: PL-BERT hidden states for the sentences (``get_pltbert_embs``).
      - whisper_noisy / whisper_clean: stacked embeddings as float32 tensors.
      - speakerembs: the last element of the traced ECAPA model's output (the embedding).
    """
    whisper_noisy, speaker_wav, raw_sents, whisper_clean = zip(*batch)
    # Dataset items are numpy memmaps, and torch.stack only accepts tensors.
    whisper_noisy = torch.stack([torch.tensor(a, dtype=torch.float32) for a in whisper_noisy])
    whisper_clean = torch.stack([torch.tensor(d, dtype=torch.float32) for d in whisper_clean])

    plbertembs = get_pltbert_embs(list(raw_sents), global_phonemizer, tokenizer, text_cleaner)

    # (B, 1, samples) -> (B, samples). squeeze(1) keeps the batch axis when B == 1.
    speaker_wav = torch.stack(speaker_wav).squeeze(1)
    speakerembs = ecapa(**batch_spenc(speaker_wav))[-1]
    return plbertembs, whisper_noisy, speakerembs, whisper_clean

In [None]:
# from torch.utils.data import DataLoader
# from torch.utils.data import RandomSampler

# BATCH_SIZE = 8
# # train_sampler = RandomSampler(train_data)

# train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn) #, sampler=train_sampler

# whisperembclean = [f"whisperEmbs/{a}" for a in os.listdir("whisperEmbs/")[:-1]]
# whisperembnoisy = [f"whisperEmbs/{a}" for a in os.listdir("whisperEmbs/")[:-1]]
# clean = [f"miipherTestDataset/clean_testset_wav/{a}" for a in sorted(os.listdir("miipherTestDataset/clean_testset_wav/"))]
# noisy = [f"miipherTestDataset/noisy_testset_wav/{a}" for a in sorted(os.listdir("miipherTestDataset/noisy_testset_wav/"))]
# text = [f"miipherTestDataset/testset_txt/{a}" for a in sorted(os.listdir("miipherTestDataset/testset_txt"))]

# train_data = MiipherDataset(noisy, clean, text, whisperembnoisy, whisperembclean)

In [22]:
# import os
# test_wav_clean = 'miipherTestDataset/test/clean_testset_wav'
# test_wav_noisy = 'miipherTestDataset/test/noisy_testset_wav'
# test_txt = 'miipherTestDataset/test/testset_txt'
# clean = [f"{test_wav_clean}/{a}" for a in sorted(os.listdir(test_wav_clean))]
# noisy = [f"{test_wav_noisy}/{a}" for a in sorted(os.listdir(test_wav_noisy))]
# text = [f"{test_txt}/{a}" for a in sorted(os.listdir(test_txt))]
# d = [(a,b,c) for a,b,c in zip(clean,noisy,text) if 'wav' in a+b+c]

In [13]:
import pytorch_lightning as pl
class MiipherLightningModule(pl.LightningDataModule):
    """LightningDataModule serving Miipher splits built from ``MiipherDataset``.

    Despite the name, this is a DataModule, not a LightningModule. In this version
    ``setup`` builds only the test split. The train split is disabled, so
    ``train_dataloader`` raises AttributeError until the TRAIN line in ``setup`` is
    re-enabled. Both loaders use the ``collate_fn`` given to the constructor.
    """
    def __init__(self, batch_size, collate_fn):
        super().__init__()
        self.collate_fn = collate_fn
        self.batch_size = batch_size
        self.whisperEmbTrainClean = 'whisperEmbs/clean_trainset_embs'
        self.whisperEmbTrainNoisy = 'whisperEmbs/noisy_trainset_embs'
        self.whisperEmbTestClean = 'whisperEmbs/clean_testset_embs'
        self.whisperEmbTestNoisy = 'whisperEmbs/noisy_testset_embs'
        self.train_wav_clean = 'miipherTestDataset/train/clean_trainset_wav'
        self.train_wav_noisy = 'miipherTestDataset/train/noisy_trainset_wav'
        self.train_txt = 'miipherTestDataset/train/trainset_txt'
        self.test_wav_clean = 'miipherTestDataset/test/clean_testset_wav'
        self.test_wav_noisy = 'miipherTestDataset/test/noisy_testset_wav'
        self.test_txt = 'miipherTestDataset/test/testset_txt'

    @staticmethod
    def _sorted_files(directory):
        """Paths of non-hidden entries in `directory`, sorted by name.

        Skipping hidden files (e.g. .DS_Store) keeps the audio and text lists aligned by position.
        """
        return [f"{directory}/{a}" for a in sorted(os.listdir(directory)) if not a.startswith('.')]

    @staticmethod
    def _batch_files(directory):
        """Paths of non-hidden files whose name contains 'batch', ordered by their trailing integer."""
        names = [f for f in os.listdir(directory) if 'batch' in f and not f.startswith('.')]
        names.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
        return [f"{directory}/{a}" for a in names]

    def _build_split(self, wav_noisy, wav_clean, txt, emb_noisy, emb_clean):
        """Create a MiipherDataset from one split's directories."""
        return MiipherDataset(self._sorted_files(wav_noisy), self._sorted_files(wav_clean),
                              self._sorted_files(txt), self._batch_files(emb_noisy),
                              self._batch_files(emb_clean))

    def setup(self, stage):
        #TRAIN (disabled: it eagerly loads every training waveform)
        # self.miipher_train = self._build_split(self.train_wav_noisy, self.train_wav_clean, self.train_txt, self.whisperEmbTrainNoisy, self.whisperEmbTrainClean)

        #TEST
        self.miipher_test = self._build_split(self.test_wav_noisy, self.test_wav_clean, self.test_txt,
                                              self.whisperEmbTestNoisy, self.whisperEmbTestClean)

    def train_dataloader(self):
        # Use self.collate_fn (the constructor argument), not the module-level name.
        return DataLoader(self.miipher_train, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.miipher_test, batch_size=self.batch_size, collate_fn=self.collate_fn)

In [15]:
d = MiipherLightningModule(4, collate_fn)

In [18]:
d.setup("")

  0%|          | 0/824 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/822 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/822 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

In [21]:
d.miipher_test[1]

(memmap([[-0.80728567, -0.41651925,  0.22006467, ...,  1.21752286,
           0.61488628,  1.4338907 ],
         [ 0.09944215,  0.35345295,  0.40228635, ...,  0.79812413,
          -0.41561455,  0.63997132],
         [ 0.64994407,  0.40558368,  0.09346313, ..., -1.13413262,
           1.2571727 , -0.55739397],
         ...,
         [-0.05057729, -0.03693019, -0.03641871, ..., -0.0162362 ,
           0.00329678, -0.00825675],
         [-1.67652476, -0.28942066,  0.3632668 , ...,  0.26832759,
          -1.15833521,  1.09190559],
         [-0.87837517, -0.14397915,  0.12409495, ...,  0.74352312,
          -0.74356645,  0.60087365]]),
 tensor([[-0.0113, -0.0181, -0.0153,  ...,  0.0000,  0.0000,  0.0000]]),
 'Ask her to bring these things with her from the store.',
 memmap([[-0.42953297, -0.6658935 ,  0.09071337, ...,  1.18186104,
          -0.14253475,  1.30860972],
         [ 0.53896588, -0.06958443, -0.15426952, ..., -0.78864527,
           0.20755263, -0.21173766],
         [ 1.4121798

In [17]:
torch.tensor(d.miipher_test.whisper_clean[0])

tensor([[-0.2210, -0.5719, -0.3844,  ...,  1.0616, -0.4346,  1.1844],
        [ 0.8678,  0.6076,  0.0555,  ..., -0.0454, -0.4874,  0.8251],
        [ 0.7029, -0.2993,  0.0506,  ...,  0.5186, -0.7027,  0.4898],
        ...,
        [-0.0417, -0.0383, -0.0326,  ..., -0.0070,  0.0041, -0.0073],
        [-1.6980, -0.2917,  0.4673,  ...,  0.4527, -1.3656,  1.0160],
        [-0.8866, -0.4458,  0.4003,  ...,  0.6203, -0.9785,  0.3831]],
       dtype=torch.float64)

In [376]:
i = 0
for a,b,c,d in train_iterator:
    print(a.shape,b.shape,c.shape,d.shape)
    break

Here
Done with plbert, now speaker Emb
speaker embs done
torch.Size([8, 116, 768]) torch.Size([8, 1500, 768]) torch.Size([8, 192]) torch.Size([8, 1500, 768])


In [8]:
import torch
import torch.nn as nn
import numpy as np
loss = nn.L1Loss()
loss2 = nn.MSELoss()
input = torch.randn(8,1500, 768)
target = torch.randn(8,1500, 768)
zeros = torch.zeros_like(target)
output = loss(input, target)
output2 = loss2(input,target)

In [9]:
output2

tensor(1.9999)

In [None]:
a = abs(target - input)
a

In [14]:
(a**2).sum()/(8*1500*768)

tensor(1.9999)

In [12]:
loss2(target,zeros)

tensor(0.9999)

In [13]:
(target**2).sum()/np.prod(list(target.shape))

tensor(0.9999)

### Testing: revised dataset (speaker audio loaded lazily in `collate_fn`)

In [16]:
class MiipherDataset(Dataset):
    """Per-utterance dataset pairing noisy/clean Whisper embeddings with text and audio paths.

    Item ``idx`` is ``(whisper_noisy, noisy_wav_path, sentence, whisper_clean)``:

    - the Whisper embeddings are memory-mapped rows of the precomputed ``.npy`` batch files;
    - transcripts are read eagerly;
    - the noisy waveform is *not* decoded here. Its path is returned, and ``collate_fn``
      loads it. The paths live in the ``spencembs`` attribute, keeping the attribute name
      used by ``__getitem__``.

    Inputs are matched purely by position, so every list must share one ordering
    (callers sort them). Count mismatches raise ``ValueError`` instead of silently
    pairing the wrong items.
    """

    def __init__(self, noisy_filepath, clean_filepath, text_filepath, whisperEmbfilepathNoisy, whisperEmbfilepathClean):
        self.noisyfilepaths = noisy_filepath
        self.cleanfilepaths = clean_filepath
        self.sentpaths = text_filepath
        self._check_counts(clean_paths=len(clean_filepath), text_paths=len(text_filepath))

        # Text: each transcript file becomes one stripped string.
        self.all_sents = []
        for text in tqdm(self.sentpaths):
            with open(text, 'r') as f:
                self.all_sents.append(f.read().strip())
        # Speaker encoder input: keep the noisy path; audio is decoded in collate_fn.
        self.spencembs = list(self.noisyfilepaths)

        self.whisper_noisy = self._load_rows(whisperEmbfilepathNoisy)
        self.whisper_clean = self._load_rows(whisperEmbfilepathClean)
        self._check_counts(whisper_noisy=len(self.whisper_noisy), whisper_clean=len(self.whisper_clean))

    def _check_counts(self, **counts):
        """Raise ValueError if any named count differs from the number of noisy files."""
        expected = len(self.noisyfilepaths)
        mismatched = {name: count for name, count in counts.items() if count != expected}
        if mismatched:
            raise ValueError(f"expected {expected} items (one per noisy file) but got {mismatched}")

    @staticmethod
    def _load_rows(npy_paths):
        """Memory-map each .npy batch file and return all rows, in file order."""
        batches = [np.load(path, mmap_mode='r') for path in tqdm(npy_paths)]
        return [loaded[i, :, :] for loaded in batches for i in range(len(loaded))]

    def __len__(self):
        return len(self.noisyfilepaths)

    def __getitem__(self, idx):
        return self.whisper_noisy[idx], self.spencembs[idx], self.all_sents[idx], self.whisper_clean[idx]

def _load_speaker_wav(path):
    """Load `path` at 16 kHz and pad/trim it to Whisper's window; returns shape (1, samples)."""
    audio, _ = librosa.load(path, sr=16000)
    return torch.tensor(whisper.pad_or_trim(np.array([audio])))


def collate_fn(batch):
    """Collate MiipherDataset items into detached CPU tensors.

    batch = list of tuples(whisper_noisy, speaker wav path, raw sentence, whisper_clean)

    Returns (plbertembs, whisper_noisy, speakerembs, whisper_clean, raw_sents):
      - plbertembs: PL-BERT hidden states for the sentences.
      - whisper_noisy / whisper_clean: stacked float32 embedding tensors.
      - speakerembs: the last element of the traced ECAPA model's output (the embedding).
      - raw_sents: the list of input sentences, unchanged.
    """
    whisper_noisy = torch.stack([torch.tensor(a, dtype=torch.float32) for a, _, _, _ in batch])
    whisper_clean = torch.stack([torch.tensor(d, dtype=torch.float32) for _, _, _, d in batch])
    raw_sents = [c for _, _, c, _ in batch]
    # (B, 1, samples) -> (B, samples). squeeze(1) keeps the batch axis when B == 1.
    speaker_wav = torch.stack([_load_speaker_wav(b) for _, b, _, _ in batch]).squeeze(1)

    # Both encoders are frozen feature extractors whose outputs are detached below,
    # so there is no point recording an autograd graph.
    with torch.no_grad():
        plbertembs = get_pltbert_embs(raw_sents, global_phonemizer, tokenizer, text_cleaner)
        speakerembs = ecapa(**batch_spenc(speaker_wav))[-1]

    return plbertembs.cpu().detach(), whisper_noisy, speakerembs.cpu().detach(), whisper_clean, raw_sents


""" 
Lightning Module 
"""

import pytorch_lightning as pl
class MiipherLightningModule(pl.LightningDataModule):
    """LightningDataModule serving the Miipher train and test splits.

    Despite the name, this is a DataModule, not a LightningModule. ``setup`` builds
    both splits with ``MiipherDataset``, and both dataloaders batch them with the
    ``collate_fn`` passed to the constructor.
    """

    def __init__(self, batch_size, collate_fn):
        super().__init__()
        self.collate_fn = collate_fn
        self.batch_size = batch_size
        self.whisperEmbTrainClean = 'whisperEmbs/clean_trainset_embs'
        self.whisperEmbTrainNoisy = 'whisperEmbs/noisy_trainset_embs'
        self.whisperEmbTestClean = 'whisperEmbs/clean_testset_embs'
        self.whisperEmbTestNoisy = 'whisperEmbs/noisy_testset_embs'
        self.train_wav_clean = 'miipherTestDataset/train/clean_trainset_wav'
        self.train_wav_noisy = 'miipherTestDataset/train/noisy_trainset_wav'
        self.train_txt = 'miipherTestDataset/train/trainset_txt'
        self.test_wav_clean = 'miipherTestDataset/test/clean_testset_wav'
        self.test_wav_noisy = 'miipherTestDataset/test/noisy_testset_wav'
        self.test_txt = 'miipherTestDataset/test/testset_txt'

    @staticmethod
    def _sorted_files(directory):
        """Paths of non-hidden entries in `directory`, sorted by name.

        Skipping hidden files (e.g. .DS_Store) keeps the audio and text lists aligned by position.
        """
        return [f"{directory}/{a}" for a in sorted(os.listdir(directory)) if not a.startswith('.')]

    @staticmethod
    def _batch_files(directory):
        """Paths of non-hidden files whose name contains 'batch', ordered by their trailing integer."""
        names = [f for f in os.listdir(directory) if 'batch' in f and not f.startswith('.')]
        names.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
        return [f"{directory}/{a}" for a in names]

    def _build_split(self, wav_noisy, wav_clean, txt, emb_noisy, emb_clean):
        """Create a MiipherDataset from one split's directories."""
        return MiipherDataset(self._sorted_files(wav_noisy), self._sorted_files(wav_clean),
                              self._sorted_files(txt), self._batch_files(emb_noisy),
                              self._batch_files(emb_clean))

    def setup(self, stage):
        # `stage` is not inspected: both splits are built on every call.
        #TRAIN
        self.miipher_train = self._build_split(self.train_wav_noisy, self.train_wav_clean, self.train_txt,
                                               self.whisperEmbTrainNoisy, self.whisperEmbTrainClean)
        #TEST
        self.miipher_test = self._build_split(self.test_wav_noisy, self.test_wav_clean, self.test_txt,
                                              self.whisperEmbTestNoisy, self.whisperEmbTestClean)

    def train_dataloader(self):
        # Use self.collate_fn (the constructor argument), not the module-level name.
        return DataLoader(self.miipher_train, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.miipher_test, batch_size=self.batch_size, collate_fn=self.collate_fn)

In [None]:
dataTest = MiipherLightningModule(3, collate_fn)
dataTest.setup("")

In [18]:
training = dataTest.val_dataloader()

In [19]:
iterator = iter(training)

In [21]:
batch = next(iterator)
batch[-1]

    To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
     (Triggered internally at ../third_party/nvfuser/csrc/manager.cpp:335.)
      return forward_call(*args, **kwargs)
    


['She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.',
 'When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.',
 'The rainbow is a division of white light into many beautiful colors.']

In [22]:
plbert_embs, whisper_noisy, ecapa_embs, whisper_clean, raw_sents = batch
raw_sents

['She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.',
 'When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.',
 'The rainbow is a division of white light into many beautiful colors.']

In [23]:
whisper_noisy[2]

tensor([[-0.0519,  0.0268,  0.2603,  ...,  1.1658, -0.5511,  0.8337],
        [ 0.2664,  0.6931,  0.0195,  ...,  0.6229, -0.9968,  1.4332],
        [ 0.7123,  0.7288, -0.2482,  ..., -0.0896, -0.8955,  0.8289],
        ...,
        [-0.0462, -0.0370, -0.0352,  ..., -0.0165,  0.0076, -0.0127],
        [-1.6867, -0.2436,  0.3077,  ...,  0.2227, -1.0488,  1.1712],
        [-0.8994, -0.0606,  0.0365,  ...,  0.5178, -0.5199,  0.7744]])

In [24]:
whisper_clean[2]

tensor([[-0.3804, -0.2556, -0.0031,  ...,  0.8562, -0.3259,  1.2493],
        [ 0.4227,  0.6743, -0.1595,  ...,  0.0965, -0.5177,  0.8075],
        [ 0.5985,  0.7990, -0.2802,  ..., -0.0729, -0.3918,  0.4857],
        ...,
        [-0.0409, -0.0365, -0.0354,  ..., -0.0131,  0.0075, -0.0129],
        [-1.7670, -0.2528,  0.4785,  ...,  0.3382, -1.3123,  1.0631],
        [-0.9448, -0.2337,  0.3379,  ...,  0.6196, -0.8520,  0.6278]])

In [48]:
import whisper

model = whisper.load_model("small.en")

In [49]:
def convert(path, whisper_model=None, duration=20):
    """Return the Whisper encoder output (``embed_audio``) for one audio file.

    Loads at most `duration` seconds at 16 kHz (default 20, as before) and
    pads/trims to Whisper's fixed window. It then builds the log-Mel spectrogram
    and encodes it as a batch of one.

    Args:
        path: audio file path.
        whisper_model: Whisper model to use; defaults to the notebook-global ``model``.
        duration: maximum seconds of audio to read.

    Returns:
        Encoder output with a leading batch axis of 1. It is computed under
        ``torch.no_grad()``, since this is inference only.
    """
    whisper_model = model if whisper_model is None else whisper_model
    wav, _ = librosa.load(path, sr=16000, duration=duration)
    mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(wav))
    with torch.no_grad():
        # Send the batch to wherever the model lives instead of assuming CUDA.
        return whisper_model.embed_audio(mel.unsqueeze(0).to(whisper_model.device))

In [51]:
a = convert("miipherTestDataset/test/clean_testset_wav/p232_007.wav")

In [34]:
torch.equal(a.detach().cpu().squeeze(),whisper_noisy[2])

False