In [37]:
from nemo.collections.asr.models import EncDecSpeakerLabelModel
from IPython.display import Audio, display
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
import math
import os
import requests
import boto3
import wave
import sys
import contextlib
import collections
from tqdm import tqdm
from tqdm.notebook import tqdm
import yaml
from transformers import AlbertConfig, AlbertModel

In [38]:
torch.cuda.is_available()

True

# Preparation

## [PL-BERT](https://github.com/yl4579/PL-BERT) Pretrained (1M steps)


### Loading PL-BERT

In [3]:
plbert_root = "plbert/"

In [4]:
log_dir = plbert_root+"Checkpoint/"
config_path = os.path.join(log_dir, "config.yml")
plbert_config = yaml.safe_load(open(config_path))

albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
plbert = AlbertModel(albert_base_configuration)

files = os.listdir(log_dir)
ckpts = []
for f in os.listdir(log_dir):
    if f.startswith("step_"): ckpts.append(f)

iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
iters = sorted(iters)[-1]

checkpoint = torch.load(log_dir + "step_" + str(iters) + ".t7", map_location='cpu')
print(log_dir + "step_" + str(iters) + ".t7")
state_dict = checkpoint['net']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    if name.startswith('encoder.'):
        name = name[8:] # remove `encoder.`
        new_state_dict[name] = v

plbert.load_state_dict(new_state_dict)

plbert/Checkpoint/step_1000000.t7


<All keys matched successfully>

### Loading tokenizer

In [5]:
import sys
sys.path.append("plbert/")

In [6]:
from plbert.phonemize import phonemize
from phonemizer.backend import EspeakBackend
from transformers import TransfoXLTokenizer
from plbert.text_utils import TextCleaner
from plbert.text_normalize import normalize_text, remove_accents


In [7]:
global_phonemizer = EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True) #make sure brew install espeak and export location of .dylib
tokenizer = TransfoXLTokenizer.from_pretrained(plbert_config['dataset_params']['tokenizer'])
text = "And also can you please check what is the current temperature setting of your unit both fridge and the freezer?"
text_cleaner = TextCleaner()

177


In [8]:
#I may need to make this function be able to batch input together (or I can later create a collate function)
#pad token = '$'
def tokenize(sents, global_phonemizer, tokenizer, text_cleaner):
    batched = []
    max_id_length = 0
    lengths = []
    for sent in sents:
        pretextcleaned = ' '.join(phonemize(sent, global_phonemizer, tokenizer)['phonemes'])
        cleaned = text_cleaner(pretextcleaned)
        batched.append(torch.LongTensor(cleaned))
        max_id_length = max(max_id_length, len(cleaned))
    phoneme_ids = torch.zeros((len(sents), max_id_length)).long()
    mask = torch.zeros((len(sents), max_id_length)).float()
    for i, c in enumerate(batched):
        phoneme_ids[i,:len(c)] = c
        mask[i,:len(c)] = 1
    return phoneme_ids, mask

In [9]:
def get_pltbert_embs(s, global_phonemizer, tokenizer, text_cleaner):
    """
    Input: list of texts

    Output: output of pretrained Albert model - (batch_size, num_tokens, 768)
    """
    phoneme_ids, attention_mask = tokenize(s, global_phonemizer, tokenizer, text_cleaner)
    return plbert(phoneme_ids, attention_mask=attention_mask).last_hidden_state, attention_mask

In [116]:
get_pltbert_embs([text,text,text], global_phonemizer, tokenizer, text_cleaner).shape

torch.Size([3, 116, 768])

## Whisper

### HuggingFace

In [8]:
from transformers import WhisperProcessor, WhisperModel
# from datasets import load_dataset

processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperModel.from_pretrained("openai/whisper-base")

import librosa
ref_waveform, sr = librosa.load("/Users/ajaybati/Downloads/0.23.0612.1.AT.PHL.alorica/results/originalTestFiles/maleb2.wav", sr=16000)
input_features = processor(ref_waveform, sampling_rate=16000, return_tensors="pt").input_features 

input_features.shape

decoder_input_ids = torch.tensor([[1,1]]) * model.config.decoder_start_token_id

out = model(input_features=input_features, decoder_input_ids=decoder_input_ids)

out['encoder_last_hidden_state'].shape

torch.Size([1, 1500, 512])

### Main Whisper

In [3]:
import whisper

model = whisper.load_model("small.en")

# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("miipherTestDataset/train/clean_trainset_wav/p234_003.wav")
audio = whisper.pad_or_trim(audio)

# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio)

      def backtrace(trace: np.ndarray):
    


In [5]:
a = model.embed_audio(mel.reshape(1,*mel.shape).cuda())

In [6]:
a.dtype

torch.float32

In [19]:
model.dims

ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=768, n_audio_head=12, n_audio_layer=12, n_vocab=51864, n_text_ctx=448, n_text_state=768, n_text_head=12, n_text_layer=12)

## ECAPA/SpeakerNet Embeddings

### NeMo

In [None]:
from nemo.collections.asr.models import EncDecSpeakerLabelModel
speaker_encoder = EncDecSpeakerLabelModel.from_pretrained(model_name="speakerverification_speakernet")
ecapa = EncDecSpeakerLabelModel.from_pretrained(model_name='ecapa_tdnn')

ecapa.eval()
speaker_encoder.eval()

def get_emb(model, wav_path):
    return model.get_embedding(wav_path)

### JIT traced

In [20]:
import librosa
ref_path = "miipherTestDataset/train/clean_trainset_wav/p234_003.wav"
audio, sr = librosa.load(ref_path, sr=16000)
audio = np.array([audio])
audio_length = audio.shape[-1]
audio_signal, audio_signal_len = (
    torch.tensor(audio),
    torch.tensor([audio_length])
)
kwarg = {'input_signal': audio_signal.cuda(), 'input_signal_length': audio_signal_len.cuda()} #batched input

# spNet = torch.jit.load("/Users/ajaybati/Downloads/0.23.0612.1.AT.PHL.alorica/spNet_traced.jit")
ecapa = torch.jit.load("ecapa2_traced.jit")

ecapa(**kwarg)[-1].shape #logits, embedding

torch.Size([1, 192])

## Conformer Blocks (loading 1 block)

In [10]:
from conformer.conformer.encoder import ConformerBlock
"""
Conformer block contains two Feed Forward modules sandwiching the Multi-Headed Self-Attention module
and the Convolution module. This sandwich structure is inspired by Macaron-Net, which proposes replacing
the original feed-forward layer in the Transformer block into two half-step feed-forward layers,
one before the attention layer and one after.

Args:
    encoder_dim (int, optional): Dimension of conformer encoder
    num_attention_heads (int, optional): Number of attention heads
    feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
    conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
    feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
    attention_dropout_p (float, optional): Probability of attention module dropout
    conv_dropout_p (float, optional): Probability of conformer convolution module dropout
    conv_kernel_size (int or tuple, optional): Size of the convolving kernel
    half_step_residual (bool): Flag indication whether to use half step residual or not

Inputs: inputs
    - **inputs** (batch, time, dim): Tensor containing input vector

Returns: outputs
    - **outputs** (batch, time, dim): Tensor produces by conformer block.
"""
ConformerBlock(
    encoder_dim: int = 512,
    num_attention_heads: int = 8,
    feed_forward_expansion_factor: int = 4,
    conv_expansion_factor: int = 2,
    feed_forward_dropout_p: float = 0.1,
    attention_dropout_p: float = 0.1,
    conv_dropout_p: float = 0.1,
    conv_kernel_size: int = 31,
    half_step_residual: bool = True
)

SyntaxError: invalid syntax (2324367004.py, line 26)

## Post-Net
https://github.com/NVIDIA/tacotron2/blob/master/hparams.py

In [40]:
class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

In [41]:
import torch
from torch import nn
from torch.nn import functional as F
class PostNet(nn.Module):
    """Postnet
        - Five 1-d convolution with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(PostNet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
            )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x

# Miipher2.0

### Model Setup

In [12]:
from conformer.conformer.encoder import ConformerBlock

import torch
import torch.nn as nn
from torch import Tensor
import pytorch_lightning as pl

In [24]:
class Miipher2(pl.LightningModule):
    """
    Args:
        num_classes (int): Number of classification classes

    Inputs: inputs, input_lengths
        - plbertDim: feature dimension of pl-bert embeddings 
        - spEncDim: feature dimension of speaker encoder embeddings 
        - whisperLen: first dimension (time steps) of whisper embeddings
        - whisperDim: feature dimension of whisper embeddings
        - modelDim: hidden_size that most of the model operates on
        - crossAttHeads: number of heads in cross attention layer
        - conformerBlockSettings: Conformer Block hyper-parameters (look above in ConformerBlock section)
        - hparams: PostNet hyper-parameters, look here: https://github.com/NVIDIA/tacotron2/blob/master/hparams.py

    Returns: outputs, output_lengths
        - **outputs**
        - **output_lengths** 
    """
    def __init__(self, plbertDim, spEncDim, whisperLen, whisperDim, modelDim,
                 crossAttHeads, crossAttDim, conformerBlockSettings, hparams, learning_rate=1e-4) -> None:
        super(Miipher2, self).__init__()

        self.learning_rate = learning_rate
        #init 3 linear layers
        self.plbert_lin = nn.Linear(plbertDim, modelDim)
        self.whisper_lin = nn.Linear(whisperDim, modelDim)
                         
        #Stack 4 times:
            #init cross attention module - input (128), hidden_size (512)
            #init Conformer Block - attention -> input (128), hidden_size (512)
            #init layer norm
        self.cross_attention = nn.ModuleList([])
        self.layer_norm = nn.ModuleList([])
        self.conformerBlock = nn.ModuleList([])
        for x in range(2): #changed 1->2
            # nn.Linear(modelDim, crossAttDim), 
            # nn.Linear(modelDim, crossAttDim), 
            # nn.Linear(modelDim, crossAttDim), 
            self.cross_attention.append(nn.MultiheadAttention(modelDim, crossAttHeads, batch_first=True)) #changed->dropout=0
            self.layer_norm.append(nn.LayerNorm([whisperLen, modelDim]))
            self.conformerBlock.append(ConformerBlock(*conformerBlockSettings))

        #make sure to reset hparams while loading it. input dimension to postnet is whisperDim => n_mel_channels=whisperDim. 
        #maybe change postnet_embedding_dim
        self.postnet = PostNet(hparams)
        self.whisper_proj = nn.Linear(modelDim, whisperDim)
        self.layer_norm2 = nn.LayerNorm([whisperLen, modelDim])
        #LOSS
        self.l2_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()


    
    def forward(self, plbert_embs, spEnc_embs, whisper_embs, attention_mask):
        """
        Inputs: plbert_embs, spEnc_embs, whisper_embs
            - plbert_embs (Tensor[batch_size, num_tokens, hidden_size]): png-bert replacement; Phoneme-level embeddings from transcript
            - spEnc_embs (Tensor[batch_size, hidden_size]): Speaker Encoder Embeddings (speakerNet/Ecapa)
            - whisper_embs (Tensor[batch_size, seq_len, hidden_size]): w2v-bert replacement; Whisper encoder-level embeddings of audio
    
        Returns: outputs
            - outputs (Tensor[batch_size, seq_len, hidden_size]): predicted, clean whisper-level embeddings
        """
        #all linear output will be hidden_dim=modelDim
        plbert_out = self.plbert_lin(plbert_embs) #(b, num_tokens, model_dim)

        whisper_out = self.whisper_lin(whisper_embs)
        
        film1_out = plbert_out
        for x in range(2): #changed 1->2
            att = self.cross_attention[x] #need to do attention masking here
            query, key, value = whisper_out, film1_out, film1_out
            att_out, _ = att(query, key, value, key_padding_mask=(1-attention_mask))

            whisper_out = whisper_out + att_out
            layer_out = self.layer_norm[x](whisper_out)
            conf_out = self.conformerBlock[x](layer_out)
            whisper_out = whisper_out + conf_out

        # self.layer_norm2(whisper_out)
        whisper_out = self.whisper_proj(whisper_out)
        whisper_permute = torch.permute(whisper_out, (0,2,1))
        post_out = self.postnet(whisper_permute)
        output = whisper_out + torch.permute(post_out, (0,2,1))
        return output

    def norm_loss(self, gt, preds):
        """
        Inputs: gt, preds
            - gt(Tensor[batch_size, seq_len, hidden_size])
            - preds(Tensor[batch_size, seq_len, hidden_size])
    
        Returns: loss (reduction='mean')
            - loss (Tensor[1]) = (1 norm + 2 norm + spectral convergence), reduced
        """
        norm1 = self.l1_loss(preds, gt)
        norm2 = self.l2_loss(preds, gt)
        spectr = norm2/((gt**2).sum()/np.prod(list(gt.shape)))
        loss = norm1 + norm2 + spectr
        return loss, norm1, norm2, spectr
    
    def training_step(self, train_batch, batch_idx):
        plbertembs, whisper_noisy, speakerembs, whisper_clean, att_mask = train_batch
        logits = self.forward(plbertembs, speakerembs, whisper_noisy, att_mask)
        loss, norm1, norm2, spectr = self.norm_loss(logits, whisper_clean)
        # print("-"*10+"DEBUGGING"+"-"*10)
        # print(logits)
        # print("*"*20)
        # print(whisper_clean)
        # print("*"*20)
        # print(loss, norm1, norm2, spectr)
        # print("-"*10+"DEBUGGING done"+"-"*10)
        self.log_dict(
            {'train_loss':loss,
             'train_norm1_loss':norm1,
             'train_norm2_loss':norm2,
             'train_spectral':spectr}, 
            on_step=True, 
            on_epoch=True, 
            prog_bar=True
        )
        return norm2 #changed from loss->norm2->loss
    
    def validation_step(self, val_batch, batch_idx):
        plbertembs, whisper_noisy, speakerembs, whisper_clean, att_mask = val_batch
        logits = self.forward(plbertembs, speakerembs, whisper_noisy, att_mask)
        loss, norm1, norm2, spectr = self.norm_loss(logits, whisper_clean)
        self.log_dict(
            {'test_loss':loss,
             'test_norm1_loss':norm1,
             'test_norm2_loss':norm2,
             'test_spectral':spectr}, 
            on_step=True, 
            on_epoch=True, 
            prog_bar=True
        )
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
        # return {"optimizer": optimizer, "lr_scheduler": {
        #     "scheduler": torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.5, total_iters=7)}}


    # def on_before_optimizer_step(self, optimizer, optimizer_idx):
    #     # Compute the 2-norm for each layer
    #     # If using mixed precision, the gradients are already unscaled here
    #     norms = grad_norm(self.layer, norm_type=2)
    #     self.log_dict(norms)

### Model Instantiation

In [25]:
import sys
sys.path.append("plbert/")

In [26]:
from hparams import create_hparams
from hparams import HParams

In [27]:
MODEL_DIM = 512 #changed from 256->512
whisperLen = 1500
whisper_dim = 768
plbert_dim = 768
speakernet_dim = 192
crossAttHeads = 8 #changed from 16->8
crossAttDim = 256

hparams = create_hparams()
hparams.n_mel_channels = whisper_dim
hparams.postnet_embedding_dim = 512

conformerBlockSettings = HParams(encoder_dim = MODEL_DIM,
    num_attention_heads = 4, #changed from 8->4
    feed_forward_expansion_factor = 2, #changed from 4->2
    conv_expansion_factor = 2,
    #removed all dropouts
    feed_forward_dropout_p = 0,
    attention_dropout_p = 0,
    conv_dropout_p = 0,
    conv_kernel_size = 7, #changed from 31->11->3->7
    half_step_residual = True)
conformerBlockSettings = conformerBlockSettings.tup()

### Evaluation

In [None]:
miipher = Miipher2.load_from_checkpoint("checkpoints_test4iter3/epoch=79-step=461520-train_loss=0.04.ckpt",
                                        plbertDim=plbert_dim, 
                                        spEncDim=speakernet_dim, 
                                        whisperLen=whisperLen, 
                                        whisperDim=whisper_dim, 
                                        modelDim=MODEL_DIM,
                                        crossAttHeads=crossAttHeads, 
                                        crossAttDim=crossAttDim, 
                                        conformerBlockSettings=conformerBlockSettings, 
                                        hparams=hparams)
miipher.eval()

In [18]:
import whisper

whispermodel = whisper.load_model("small.en")

# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("miipherTestDataset/test/noisy_testset_wav/p232_104.wav")
audio = whisper.pad_or_trim(audio)

# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio)
a = whispermodel.embed_audio(mel.reshape(1,*mel.shape).cuda()).detach().cpu()

      def backtrace(trace: np.ndarray):
    


In [19]:
import torch
vocoder = torch.jit.load("gs-vec2wav-base_19_July_23.jit")
spNet = torch.jit.load("spNet_traced.jit")
ecapa = torch.jit.load("ecapa_traced.jit")

In [20]:
import librosa
def evaluate(model, text, wavPath):
    #get plbert
    plbert_embs, att_mask = get_pltbert_embs([text], global_phonemizer, tokenizer, text_cleaner)
    #get whisper
    wav, _ = librosa.load(wavPath, sr=16000, duration=20)
    audio = whisper.pad_or_trim(wav)
    
    mel = whisper.log_mel_spectrogram(audio)
    a = whispermodel.embed_audio(mel.reshape(1,*mel.shape).cuda()).detach().cpu()
    
    #get spenc
    kwarg = {'input_signal': torch.tensor(audio).reshape((1,-1)), 'input_signal_length': torch.tensor([audio.shape[-1]])}
    spnet_embs = spNet(**kwarg)[-1]
    ecapa_embs = ecapa(**kwarg)[-1]
    # ecapa_embs = torch.zeros(*ecapa_embs.shape)


    return model(plbert_embs, ecapa_embs, a, att_mask), spnet_embs, a

def convert(vocoder, spEnc, modelOut, path):
    
    out = vocoder(torch.permute(modelOut,(0,2,1)), spEnc.reshape((1,-1,1)))
    torchaudio.save(path, out.squeeze(0), 16000)
    
    

In [54]:
#"miipherTestDataset/train/trainset_txt/p299_148.txt"
#"miipherTestDataset/train/noisy_traainset_wav/p299_148.wav"
with open("unseen/sp07_exhibition_sn15.txt", 'r') as f:
    text = "hi my name is dumbo"
print(text)
wavPath = "unseen/sp07_exhibition_sn15.wav"

clean_whisper_out, spnet_embs, noisy = evaluate(miipher, text, wavPath)

hi my name is dumbo


In [56]:
convert(vocoder, spnet_embs, clean_whisper_out, "vocoderTest3.wav")

In [90]:
wav, _ = librosa.load(wavPath.replace("noisy","clean"), sr=16000, duration=20)
audio = whisper.pad_or_trim(wav)

# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio)
actual = whispermodel.embed_audio(mel.reshape(1,*mel.shape).cuda()).detach().cpu()

In [91]:
miipher.norm_loss(actual, clean_whisper_out)[0]

tensor(0.3413, grad_fn=<AddBackward0>)