<a href="https://colab.research.google.com/github/Annesya/voice-speech-metamers/blob/master/model_two_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install speechbrain



In [2]:
%%capture
!pip install datasets -U
!pip install librosa
!pip install jiwer

In [3]:
from speechbrain.inference.speaker import EncoderClassifier
import torch
from collections import defaultdict
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from transformers import WhisperForConditionalGeneration, WhisperProcessor, AutoFeatureExtractor, WhisperModel, WhisperTokenizer
from datasets import load_dataset, DatasetDict, Audio, load_metric

# ***DATA*** ***PREPARATION***

In [4]:
common_voice = DatasetDict()

common_voice_train = load_dataset("fsicoli/common_voice_17_0", "ja", split="train")
common_voice_test = load_dataset("fsicoli/common_voice_17_0", "ja", split="test")

print(common_voice)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    
})


In [5]:
common_voice_train = common_voice_train.remove_columns(["accent", "age", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "down_votes", "gender", "locale", "segment", "up_votes"])

In [6]:
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

## ***Word Vocabulary Processing***

In [7]:
import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\⋯\、\。\《\》\「\」\！\（\）\，\：\；\？\～\|]'
# Feel free to add more unwanted symbols

def remove_special_characters(batch):
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
    return batch

In [8]:
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

In [None]:
## Inlcude text normalization if needed

In [9]:
def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [10]:
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

Map:   0%|          | 0/10039 [00:00<?, ? examples/s]

Map:   0%|          | 0/6261 [00:00<?, ? examples/s]

In [11]:
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
print(len(vocab_dict))
print()
print(vocab_dict)

2610

{' ': 0, '&': 1, '(': 2, ')': 3, '/': 4, '[': 5, ']': 6, 'a': 7, 'b': 8, 'c': 9, 'd': 10, 'e': 11, 'f': 12, 'g': 13, 'h': 14, 'i': 15, 'j': 16, 'k': 17, 'l': 18, 'm': 19, 'n': 20, 'o': 21, 'p': 22, 'q': 23, 'r': 24, 's': 25, 't': 26, 'u': 27, 'v': 28, 'w': 29, 'x': 30, 'y': 31, 'z': 32, '–': 33, '—': 34, '―': 35, '’': 36, '…': 37, '☆': 38, '♡': 39, '々': 40, '〇': 41, '〈': 42, '〉': 43, '『': 44, '』': 45, '〜': 46, 'ぁ': 47, 'あ': 48, 'ぃ': 49, 'い': 50, 'ぅ': 51, 'う': 52, 'ぇ': 53, 'え': 54, 'ぉ': 55, 'お': 56, 'か': 57, 'が': 58, 'き': 59, 'ぎ': 60, 'く': 61, 'ぐ': 62, 'け': 63, 'げ': 64, 'こ': 65, 'ご': 66, 'さ': 67, 'ざ': 68, 'し': 69, 'じ': 70, 'す': 71, 'ず': 72, 'せ': 73, 'ぜ': 74, 'そ': 75, 'ぞ': 76, 'た': 77, 'だ': 78, 'ち': 79, 'っ': 80, 'つ': 81, 'づ': 82, 'て': 83, 'で': 84, 'と': 85, 'ど': 86, 'な': 87, 'に': 88, 'ぬ': 89, 'ね': 90, 'の': 91, 'は': 92, 'ば': 93, 'ぱ': 94, 'ひ': 95, 'び': 96, 'ぴ': 97, 'ふ': 98, 'ぶ': 99, 'ぷ': 100, 'へ': 101, 'べ': 102, 'ぺ': 103, 'ほ': 104, 'ぼ': 105, 'ぽ': 106, 'ま': 107, 'み': 108, 'む': 109, 'め'

In [12]:
vocab_dict["^"] = vocab_dict[" "] # | is a valid punctuation in bengali, equivalent to full stop (".")
del vocab_dict[" "]

In [13]:
# truncating the test set size to 1024
if len(common_voice_test) > 1024:
    common_voice_test_full = common_voice_test
    common_voice_test = common_voice_test.select(range(1024))

## ***Speaker ID processing***

In [14]:
speaker_ids_train = [common_voice_train[i]["client_id"] for i in range(len(common_voice_train))]
#speaker_ids_test = common_voice_test.features["speaker_id"].feature.names

In [15]:
# Count the number of unique strings
speaker_ids_unique = sorted(set(speaker_ids_train))

In [16]:
## map the speaker ids to the corresponding labels
speaker_id_map = {speaker_id: i for i, speaker_id in enumerate(speaker_ids_unique)}

In [17]:
len(speaker_ids_unique[0])

128

## ***Prepare data loader***

In [18]:
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Japanese", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [19]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched"
    batch["input_values"] = audio["array"]
    batch["input_length"] = len(batch["input_values"])
    batch["word_labels"] = tokenizer(batch["sentence"]).input_ids
    batch["client_id"] = batch["client_id"].split("\n")
    #print(batch["client_id"])
    #batch["client_id"] = batch["client_id"].tolist()
    # batch["speaker_labels"] = []
    # for i, id in enumerate(batch["client_id"]):
    #   print(id)
    #   print(speaker_id_map[id])
    #   batch["speaker_labels"].append(speaker_id_map[id])
    batch["speaker_labels"] = [speaker_id_map[speaker_id] for speaker_id in batch["client_id"]]
    return batch

Speaker IDs in the train data does not always match speaker id in the test data

In [20]:
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
# common_voice_test_1 = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

In [21]:
# filtering out longer inputs
max_input_length_in_sec = 5.0
sampling_rate = 16000

common_voice_train = common_voice_train.filter(
    lambda x: x < max_input_length_in_sec * sampling_rate,
    input_columns=["input_length"]
)
print(len(common_voice_train))

duration = 0
for i in range(len(common_voice_train)):
    duration += common_voice_train[i]["input_length"] / sampling_rate
print(f"Train set duration: {duration / 3600:.2f} hours")

common_voice_test = common_voice_test.filter(
    lambda x: x < max_input_length_in_sec * sampling_rate,
    input_columns=["input_length"]
)
print(len(common_voice_test))

Filter:   0%|          | 0/10039 [00:00<?, ? examples/s]

5993
Train set duration: 6.28 hours


ValueError: Input column ['input_length'] not in the dataset. Current columns in the dataset: ['client_id', 'path', 'audio', 'sentence', 'variant']

## ***Define Data collator***

In [48]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Japanese", task="transcribe")

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: processor
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_values"]} for feature in features]
        # print(input_features.keys())
        #pad the inputs to max length
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["word_labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["word_labels"] = labels

        return batch


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [49]:
model_whisper = WhisperModel.from_pretrained("openai/whisper-base")
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, decoder_start_token_id=model_whisper.config.decoder_start_token_id)

In [50]:
train_dataloader = DataLoader(common_voice_train, batch_size=32, shuffle=True, collate_fn=data_collator)
for batch in train_dataloader:
    break

print({k:v.shape for k,v in batch.items()})

ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (2,2)  and requested shape (1,2)

## ***Defining evaluation matrices***

In [24]:
metric = load_metric("wer")

  metric = load_metric("wer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [25]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}


# ***MODEL BUILDING***

In [None]:
## ECAPA encoding
# classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")
# signal = common_voice_train[0]["audio"]["array"]
# # fs = common_voice_train[0]["audio"]["sampling_rate"]
# embeddings = classifier.encode_batch(torch.tensor(signal))

# ## Whisper Encoding
# model = WhisperModel.from_pretrained("openai/whisper-base")
# feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")

# inputs = feature_extractor(common_voice_train[0]["audio"]["array"], sampling_rate=common_voice_train[0]["audio"]["sampling_rate"], return_tensors="pt")
# input_features = inputs.input_features
# decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
# last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).encoder_last_hidden_state
# list(last_hidden_state.shape)

In [None]:
decoder_input_ids = torch.tensor([[1, 1]]) * model_whisper.config.decoder_start_token_id
decoder_input_ids

tensor([[50258, 50258]])

In [None]:
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids)

In [None]:
model

WhisperModel(
  (encoder): WhisperEncoder(
    (conv1): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
    (embed_positions): Embedding(1500, 512)
    (layers): ModuleList(
      (0-5): 6 x WhisperEncoderLayer(
        (self_attn): WhisperSdpaAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=False)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    

Embedding shape from Whisper: Batch X 1500 X 512 --> Batch X 1500 X 64

Embedding shape from ECAPA: Batch X 1 X 192 --> Batch X 1 X 64

In [None]:
# Define your downsampling convolutional layer
downsample_conv = nn.Conv1d(in_channels=512, out_channels=64, kernel_size=1)

# Assuming your input tensor is named input_tensor
# Reshape the tensor to fit the convolutional layer
input_tensor = torch.ones(2,1500,512)
input_tensor = input_tensor.permute(0, 2, 1)  # Change shape to batch * 512 * 1500

# Apply convolution
output_tensor = downsample_conv(input_tensor)

# Check the shape
print(output_tensor.shape)  # Should be batch * 64 * 500


torch.Size([2, 64, 1500])


In [None]:
x = torch.ones(2, 1500, 512)
x = torch.mean(x,dim=1)
print(x.shape)

torch.Size([2, 512])


In [26]:
class SpeechModel(torch.nn.Module):
    def __init__(self):
        super(SpeechModel, self).__init__()
        self.sampling_rate = 16000
        self.num_speaker_class = 578 # change this after analyzing dataset
        self.word_vocab = 2610 # change this after analyzing dataset -> len(vocab_dict)
        self.whisper_encoder = WhisperModel.from_pretrained("openai/whisper-base")
        self.decoder_input_ids = torch.tensor([[1, 1]]) * self.whisper_encoder.config.decoder_start_token_id
        self.whisper_feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
        self.ecapa_encoder = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb")

        # Define downsampling layers
        self.whisper_downsample = nn.Conv1d(in_channels=512, out_channels=64, kernel_size=1)
        self.ecapa_downsample = nn.Conv1d(in_channels=192, out_channels=64, kernel_size=1)

        # Define Transformer layers
        self.transformer_encoder_single = nn.TransformerEncoderLayer(d_model=128, nhead=8, dim_feedforward=512, batch_first=True)
        self.transformer_decoder_single = nn.TransformerDecoderLayer(d_model=128, nhead=8, dim_feedforward=512, batch_first=True)
        self.transformer_encoder_stack = nn.TransformerEncoder(self.transformer_encoder_single, num_layers=6)
        self.transformer_decoder_stack = nn.TransformerDecoder(self.transformer_decoder_single, num_layers=6)

        self.decoder_input_ids = torch.tensor([[1, 1]]) * self.whisper_encoder.config.decoder_start_token_id

        # Define prediction heads
        self.next_word_prediction_head = nn.Linear(128,self.word_vocab)
        self.speaker_recognition_head = nn.Linear(128,self.num_speaker_class)
        self.speaker_softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # Forward pass for whisper branch
        x_whisper = self.whisper_feature_extractor(x, sampling_rate=self.sampling_rate, return_tensors="pt").input_features
        whisper_embedding = self.whisper_encoder(x_whisper, decoder_input_ids=self.decoder_input_ids).encoder_last_hidden_state
        whisper_embedding = whisper_embedding.permute(0, 2, 1)
        downsampled_whisper_embedding = self.whisper_downsample(whisper_embedding)
        downsampled_whisper_embedding = downsampled_whisper_embedding.permute(0, 2, 1)
        downsampled_whisper_embedding = torch.mean(downsampled_whisper_embedding, dim=1)

        # Forward pass for ECAPA branch
        ecapa_embedding = self.ecapa_encoder.encode_batch(x)
        ecapa_embedding = ecapa_embedding.permute(0, 2, 1)
        downsampled_ecapa_embedding = self.ecapa_downsample(ecapa_embedding)
        downsampled_ecapa_embedding = downsampled_ecapa_embedding.permute(0, 2, 1)
        downsampled_ecapa_embedding = torch.squeeze(downsampled_ecapa_embedding,dim=1)

        # Concatenate downscaled embeddings
        concatenated_embeddings = torch.cat((downsampled_whisper_embedding, downsampled_ecapa_embedding), dim=-1)

        # Transformer layers
        transformer_output = self.transformer_encoder_stack(concatenated_embeddings)
        #transformer_decoder_output = self.transformer_decoder_stack(self.decoder_input_ids, transformer_output)
        transformer_decoder_output = self.transformer_decoder_stack(concatenated_embeddings, transformer_output)

        # Task-specific heads
        word_prediction = self.next_word_prediction_head(transformer_decoder_output)
        speaker_recognition = self.speaker_recognition_head(transformer_output)
        speaker_recognition = self.speaker_softmax(speaker_recognition)

        #next_word_prediction = self.next_word_prediction_head(transformer_output

        return word_prediction, speaker_recognition

# Define model
model = SpeechModel()
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_data = torch.tensor(ds[0]["audio"]["array"])
output_word, output_speaker = model.forward(input_data)
print(output_word.shape)
print(output_speaker.shape)


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


torch.Size([1, 2610])
torch.Size([1, 578])


In [27]:
# Define loss functions
next_word_loss_function = nn.CTCLoss()
speaker_recognition_loss_function = nn.CrossEntropyLoss()

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Load dataset
# dataset = CommonVoiceDataset(root_dir='path_to_commonvoice_dataset', transform=transform)

# # Training loop
# num_epochs = 3
# for epoch in range(num_epochs):
#     for batch in dataloader:
#         optimizer.zero_grad()
#         inputs, targets = batch
#         next_word_prediction, speaker_recognition = model(inputs)

#         # Calculate loss
#         next_word_loss = next_word_loss_function(next_word_prediction, targets)
#         # Calculate speaker recognition loss
#         speaker_recognition_loss = speaker_recognition_loss_function(speaker_recognition, speaker_labels)

#         total_loss = next_word_loss + speaker_recognition_loss

#         # Backpropagation
#         total_loss.backward()
#         optimizer.step()

RuntimeError: each element in list of batch should be of equal size

In [32]:
common_voice_train

Dataset({
    features: ['input_values', 'input_length', 'word_labels', 'speaker_labels'],
    num_rows: 10039
})

In [29]:
common_voice_train['input_length'][9]

50304