In [11]:
import os
import sys
import numpy as np
import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
import epitran

In [2]:
normal_tokenizer = AutoTokenizer.from_pretrained("psktoure/BERT_BPE_wikitext")
phonetic_tokenizer = AutoTokenizer.from_pretrained("psktoure/BERT_BPE_phonetic_wikitext")

tokenizer_config.json:   0%|          | 0.00/1.14k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

In [3]:
print("Len normal tokenizer vocab: ", len(normal_tokenizer.get_vocab()))
print("Len phonetic tokenizer vocab: ", len(phonetic_tokenizer.get_vocab()))

Len normal tokenizer vocab:  30522
Len phonetic tokenizer vocab:  30522


In [8]:
epi = epitran.Epitran("eng-Latn")

In [12]:
text = "Hello, my name is Paul. I am a student at the University of Toronto."
tokenized_text = normal_tokenizer(text)
print(tokenized_text)
ids_to_tokens = normal_tokenizer.convert_ids_to_tokens(tokenized_text["input_ids"])
print(ids_to_tokens)
print(tokenized_text.word_ids())

{'input_ids': [1, 12325, 8, 434, 935, 68, 2034, 9, 31, 79, 23, 3467, 57, 51, 1184, 65, 5758, 9, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
['[CLS]', 'hello', ',', 'my', 'name', 'is', 'paul', '.', 'i', 'am', 'a', 'student', 'at', 'the', 'university', 'of', 'toronto', '.', '[SEP]']
[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, None]


In [13]:
nums = [1, 4, 5, 4]
idx = nums.index(4)
print(idx)

1


In [None]:
def xsampa_tokens(word, phonetic_tokenizer):
    phonetic_word = "".joint(epi.xsampa_list(word))
    tokenized_word = phonetic_tokenizer(phonetic_word, add_special_tokens=False)
    ids = tokenized_word["input_ids"]
    return ids


In [14]:
from transformers import DataCollatorForLanguageModeling
import random

class CustomDataCollatorForMLM(DataCollatorForLanguageModeling):
    def __init__(self, tokenizer, phonetic_tokenizer, word_to_phonetic, mlm_probability=0.15):
        super().__init__(tokenizer=tokenizer, mlm_probability=mlm_probability)
        self.phonetic_tokenizer = phonetic_tokenizer
        self.word_to_phonetic = word_to_phonetic

    def __call__(self, examples):
        # Tokenize normal and phonetic text
        normal_texts = [e['normal_text'] for e in examples]
        phonetic_texts = [e['phonetic_text'] for e in examples]
        
        # Tokenize both
        normal_encodings = self.tokenizer(normal_texts, return_tensors="pt", padding=True, truncation=True)
        phonetic_encodings = self.phonetic_tokenizer(phonetic_texts, return_tensors="pt", padding=True, truncation=True)
        
        # Generate MLM masks for normal text
        input_ids = normal_encodings.input_ids
        labels = input_ids.clone()  # Original labels for computing loss
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(input_ids.tolist(), already_has_special_tokens=True)
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        
        masked_indices = torch.bernoulli(probability_matrix).bool()
        input_ids[masked_indices] = self.tokenizer.mask_token_id  # Replace with mask token

        # Handle corresponding phonetic tokens
        phonetic_labels = phonetic_encodings.input_ids.clone()
        for idx, (normal_sentence, phonetic_sentence) in enumerate(zip(normal_texts, phonetic_texts)):
            # Get masked words
            for token_idx in masked_indices[idx].nonzero():
                word_id = normal_encodings.word_ids(batch_index=idx)[token_idx.item()]
                if word_id is not None:  # Ignore special tokens
                    word = self.tokenizer.decode(normal_encodings.input_ids[idx][word_id])
                    phonetic_tokens = self.word_to_phonetic.get(word, [])
                    # Find and mask in phonetic text
                    for p_token in phonetic_tokens:
                        p_index = phonetic_encodings.input_ids[idx].tolist().index(p_token)
                        phonetic_encodings.input_ids[idx][p_index] = self.phonetic_tokenizer.mask_token_id

        # Return modified normal and phonetic encodings
        return {
            'input_ids': input_ids,
            'labels': labels,
            'phonetic_input_ids': phonetic_encodings.input_ids,
            'phonetic_labels': phonetic_labels
        }


In [None]:
text = "Hello, my name is Paul. I am a student at the University of Toronto."
encoded = normal_tokenizer(text, return_tensors="pt")
labels = encoded.input_ids.clone()
probability_matrix = torch.full(labels.shape, 0.15)
print(probability_matrix)
special_tokens_mask = normal_tokenizer.get_special_tokens_mask(encoded.input_ids.tolist(), already_has_special_tokens=True)
print(special_tokens_mask)
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
normal_tokenizer.pad()

tensor([[0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,
         0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,
         0.1500]])
[0]
tensor([[0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,
         0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,
         0.1500]])
tensor([[False, False, False, False, False, False, False,  True, False, False,
         False, False, False,  True, False,  True, False, False, False]])


In [34]:
print(encoded.input_ids)

tensor([[    1, 12325,     8,   434,   935,    68,  2034,     9,    31,    79,
            23,  3467,    57,    51,  1184,    65,  5758,     9,     2]])


In [37]:
print(normal_tokenizer.encode(normal_tokenizer.cls_token, add_special_tokens=False))
print(normal_tokenizer.encode(phonetic_tokenizer.cls_token, add_special_tokens=False))

[1]
[1]
