In [None]:
from e2e_st.trainer import STDataset, SpecConfig
from torch.utils.data import DataLoader
from e2e_st.utils.attention_masks import key_padding_mask, causal_mask
import matplotlib.pyplot as plt

In [None]:
import torchaudio
# Report the audio backends torchaudio lists for this environment.
audio_backends = torchaudio.list_audio_backends()
print("Available backends:", audio_backends)

In [None]:
from e2e_st.text.tokenizer import CustomTokenizer
from transformers import AutoTokenizer

# Register CustomTokenizer under the "custom" key.
# The second positional argument is left as None and CustomTokenizer is passed
# third; per the transformers documentation these are the slow- and
# fast-tokenizer slots respectively, so this registers it as the *fast*
# tokenizer (not via the second parameter).
AutoTokenizer.register("custom", None, CustomTokenizer)


In [None]:
# Load the pretrained tokenizer, asking for the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(
    "alexgichamba/iwslt25_uncased_4096",
    use_fast=True,
)

In [None]:
# Show the vocabulary size reported by the tokenizer.
vocab_size = tokenizer.vocab_size
print(vocab_size)

In [None]:
# Print the four language tags on one line, then their ids on the next.
lang_codes = ("bem", "eng", "fra", "fon")
print(*(getattr(tokenizer, f"{code}_lang_token") for code in lang_codes))
print(*(getattr(tokenizer, f"{code}_lang_token_id") for code in lang_codes))

In [None]:
# English sample sentence, lower-cased before tokenization.
# Tokenize once and reuse the result so the printed tokens and the printed
# count always come from the same text.
eng_sample = "I shall also refer the matter to the College of Quaestors, and I am certain that they will be keen to ensure that we comply with the regulations we ourselves vote on.".lower()
eng_tokens = tokenizer.tokenize(eng_sample)
print(eng_tokens)
print(len(eng_tokens))

In [None]:
# French sample sentence, lower-cased before tokenization.
# Tokenize once and reuse the result so the printed tokens and the printed
# count always come from the same text.
fra_sample = "Je vais soumettre également le problème au Collège des questeurs et je suis certaine que nos questeurs auront à cur de faire en sorte que nous respections la réglementation qu' en effet nous votons.".lower()
fra_tokens = tokenizer.tokenize(fra_sample)
print(fra_tokens)
print(len(fra_tokens))

In [None]:
# Sample sentence (presumably Fon — confirm), lower-cased before tokenization.
# Tokenize once and reuse the result so the printed tokens and the printed
# count always come from the same text.
fon_sample = "Ée yě ɖɔ mɔ̌ ɔ́, Mɔyízi lɛ́ kɔ bó yi ɖɔ nú Mawu Mavɔmavɔ ɖɔ: \"Aklúnɔ, étɛ́wú a wa nǔ xá togun élɔ́?".lower()
fon_tokens = tokenizer.tokenize(fon_sample)
print(fon_tokens)
print(len(fon_tokens))

In [None]:
# Sample sentence (presumably Bemba — confirm), lower-cased before tokenization.
# Tokenize once and reuse the result so the printed tokens and the printed
# count always come from the same text.
bem_sample = "\"Pa kuti kasebanya naikila pali imwe, ali ne cipyu cickalamba, pa kwishibo kuti ali ne nshita inono fye.\" - Ukusokoloa 12:12.".lower()
bem_tokens = tokenizer.tokenize(bem_sample)
print(bem_tokens)
print(len(bem_tokens))

In [None]:
# Spectrogram settings passed to the dataset below.
spec_config = SpecConfig(sample_rate=16000, n_fft=1024, hop_length=256, n_mels=80)

In [None]:
# Dataset over the training manifest, using the tokenizer and spectrogram
# config defined above, with text case standardised to lower case.
sample_dataset = STDataset(
    dataset_json="../corpora/train.json",
    tokenizer=tokenizer,
    spec_config=spec_config,
    case_standardization="lower",
)

# Shuffled batches of 8, assembled by the dataset's own collate function.
sample_loader = DataLoader(
    sample_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=sample_dataset.collate_fn,
)

In [None]:
# Pull exactly one batch for inspection.
# A `for ... break` loop over an empty loader would silently do nothing and
# leave the names below undefined (or stale from an earlier run), so later
# cells would fail confusingly. Fetch the batch explicitly and fail loudly.
batch = next(iter(sample_loader), None)
if batch is None:
    raise RuntimeError("sample_loader yielded no batches")

# Unpack the fields produced by the dataset's collate function.
mels = batch["mel"]
speech_lengths = batch["speech_lengths"]
text_lengths = batch["text_lengths"]
input_tokens = batch["input_tokens"]
st_target_tokens = batch["st_target_tokens"]
asr_target_tokens = batch["asr_target_tokens"]

print(f"mel shape: {mels.shape}")
# The "text shape" line reports the shape of the ASR target tokens.
print(f"text shape: {asr_target_tokens.shape}")
print(f"speech lengths: {speech_lengths}")
print(f"text lengths: {text_lengths}")

# Decode every example in the batch back to text, one blank line between examples.
for input_ids, st_ids, asr_ids in zip(input_tokens, st_target_tokens, asr_target_tokens):
    print(f"input: {tokenizer.decode(input_ids)}")
    print(f"st target: {tokenizer.decode(st_ids)}")
    print(f"asr target: {tokenizer.decode(asr_ids)}")
    print()

In [None]:
# Padding mask for the input tokens, keyed on the tokenizer's pad token id.
pad_id = tokenizer.pad_token_id
pad_mask_text = key_padding_mask(input_tokens, pad_idx=pad_id)
print(f"pad mask text shape: {pad_mask_text.shape}")

In [None]:
# Render the text padding mask as a grayscale image.
fig, ax = plt.subplots(figsize=(16, 16))
ax.imshow(pad_mask_text, interpolation='nearest', cmap='gray')

In [None]:
# NOTE(review): `pad_mask_speeech` is misspelled; the name is kept as-is
# because other cells may reference it.
# The mel batch has its last two axes swapped before being handed to
# key_padding_mask together with the per-example speech lengths.
mels_swapped = mels.permute(0, 2, 1)
pad_mask_speeech = key_padding_mask(mels_swapped, speech_lengths)
print(f"pad mask speech shape: {pad_mask_speeech.shape}")
print(speech_lengths)

# Render the speech padding mask as a grayscale image.
fig, ax = plt.subplots(figsize=(12, 24))
ax.imshow(pad_mask_speeech, interpolation='nearest', cmap='gray')

In [None]:
# Causal mask derived from the input tokens.
causal_mask_text = causal_mask(input_tokens)
print(f"causal mask text shape: {causal_mask_text.shape}")

# Render the causal mask as a grayscale image.
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(causal_mask_text, interpolation='nearest', cmap='gray')