In [1]:
from transformers import AutoTokenizer
from datasets import load_dataset
from transformers import (
    T5ForConditionalGeneration,
    BartForConditionalGeneration,
)
import collections
import numpy as np
from transformers.data.data_collator import default_data_collator
from indobenchmark import IndoNLGTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#kalau t5
MODEL_CHECKPOINT = "../models/pt-indot5-TA_PT"
TOK_CHECKPOINT = "Wikidepia/IndoT5-base"
SAVE_PATH = "../models/pt-indot5-MLM_TA_PT"
#kalau bart
# MODEL_CHECKPOINT = "indobenchmark/indobart-v2"
# SAVE_PATH = "models/pt-indobart-MLM_PT"

In [4]:
# Set up the t5 tokenizer
tokenizer = AutoTokenizer.from_pretrained(TOK_CHECKPOINT)

#set up bart tokenizer
#tokenizer = IndoNLGTokenizer.from_pretrained(MODEL_CHECKPOINT)

# Data

In [5]:
CHUNK_SIZE = 128

In [6]:
def tokenize_function(examples):
    result = tokenizer(examples["clean_tweet"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result

In [7]:
def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // CHUNK_SIZE) * CHUNK_SIZE
    # Split by chunks of max_len
    result = {
        k: [t[i : i + CHUNK_SIZE] for i in range(0, total_length, CHUNK_SIZE)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

In [8]:
def drop_duplicate_id(tokens, highest_id):
    extra_ids = tokenizer.convert_tokens_to_ids([f'<extra_id_{id}>' for id in range(highest_id)])
    extra_ids = {key:0 for key in extra_ids}
    del_idx = []
    for i,tok in enumerate(tokens):
        if tok in extra_ids:
            extra_ids[tok]+=1
            if extra_ids[tok]>1:
                del_idx.append(i)
    new_tokens = [tokens[i] for i in range(len(tokens)) if i not in del_idx]
    return new_tokens

In [9]:
wwm_probability = 0.2


def multi_word_masking(examples):
    new_inputs = []
    new_labels = []
    for i in range(len(examples['input_ids'])):
        word_ids = examples['word_ids'][i]
        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = examples["input_ids"][i]
        labels = examples["labels"][i]
        input_masked = -1
        label_masked = -1
        prev_mask = None
        for word_id, masked in enumerate(mask):
            if prev_mask!=masked:
                if masked==0:
                    label_masked+=1
                elif masked==1:
                    input_masked+=1
                prev_mask=masked
            for idx in mapping[word_id]:
                #if not masking input_ids then we mask the label
                if masked==0:
                    labels[idx] = tokenizer.convert_tokens_to_ids(f'<extra_id_{label_masked}>')
                #if masking then we mask the input_ids
                elif masked==1:
                    input_ids[idx] = tokenizer.convert_tokens_to_ids(f'<extra_id_{input_masked}>')
        #dropping the same extra_id
        input_ids = drop_duplicate_id(input_ids, input_masked+1)
        labels = drop_duplicate_id(labels, label_masked+1)
        new_inputs.append(input_ids)
        new_labels.append(labels)
    examples['labels'] = new_labels
    examples['input_ids'] = new_inputs
    examples['attention_mask'] = [[1 for i in range(len(examples['input_ids'][j]))] for j in range(len(examples['input_ids']))]
    return examples

In [10]:
raw_dataset = load_dataset('csv', data_files='../Data/post-train/MLM/clean_tweet.csv')
tokenized_datasets = raw_dataset.map(
    tokenize_function, batched=True, remove_columns=raw_dataset['train'].column_names
)
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
downsampled_dataset = lm_datasets["train"].train_test_split(test_size=0.1, seed=42)
downsampled_dataset = downsampled_dataset.map(multi_word_masking, batched=True)

Found cached dataset csv (C:/Users/danendra/.cache/huggingface/datasets/csv/default-a0da231f93618037/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
100%|██████████| 1/1 [00:00<00:00, 34.50it/s]
Loading cached processed dataset at C:\Users\danendra\.cache\huggingface\datasets\csv\default-a0da231f93618037\0.0.0\6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1\cache-ecfd2fdf1dcb97c8.arrow
Loading cached processed dataset at C:\Users\danendra\.cache\huggingface\datasets\csv\default-a0da231f93618037\0.0.0\6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1\cache-c68c1a0da8738c95.arrow
Loading cached split indices for dataset at C:\Users\danendra\.cache\huggingface\datasets\csv\default-a0da231f93618037\0.0.0\6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1\cache-8e9f8a7eceb1a98c.arrow and C:\Users\danendra\.cache\huggingface\datasets\csv\default-a0da231f93618037\0.0.0\6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bc

In [11]:
tokenizer.decode(downsampled_dataset['train'][0]['input_ids'])

'naksir bgt kulot<extra_id_0> Shopee tp mau beli mikir kali karena<extra_id_1> ngestalk lagi.<extra_id_2> gue beneran naksir ama kulotnya</s><extra_id_3> ga lagi deh pake<extra_id_4> express, kapok :")</s><extra_id_5> emang set<extra_id_6> xpress aja.. tapi setidaknya aku pake yang standard huuuu orang<extra_id_7> free bodoh<extra_id_8> ga cek</s> <unk> gapake bukalapak kakk coba wa<extra_id_9></s> <unk>Langit Merah Jakarta,'

In [12]:
tokenizer.decode(downsampled_dataset['train'][0]['labels'])

'<extra_id_0> di<extra_id_1> takut dianggap<extra_id_2> padahal<extra_id_3></s> <unk>udahlah<extra_id_4> shopee<extra_id_5></s> <unk>tapi sellernya<extra_id_6> shopee<extra_id_7> sama"<extra_id_8> banget<extra_id_9></s> aja</s><extra_id_10>'

# Model training


## Using MLM task

In [13]:
from transformers import (
    T5ForConditionalGeneration,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
)

In [14]:
#model_checkpoint = "Wikidepia/IndoT5-base"
model = T5ForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT).to("cuda")

In [15]:
datacollator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [16]:
batch_size = 8
# Show the training loss with every epoch
logging_steps = len(downsampled_dataset["train"]) // batch_size
#model_name = model_checkpoint.split("/")[-1]

training_args = TrainingArguments(
    SAVE_PATH,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    resume_from_checkpoint=True,
    num_train_epochs=10,
    save_total_limit=2,
)

In [17]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    tokenizer=tokenizer,
    data_collator=datacollator,
    
)

# Eval

In [18]:
import math

perplexity before training

In [19]:
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


100%|██████████| 245/245 [00:15<00:00, 15.34it/s]

>>> Perplexity: 92.12





In [20]:
trainer.train()

  2%|▏         | 500/22020 [02:12<1:35:09,  3.77it/s]

{'loss': 3.2899, 'learning_rate': 1.954586739327884e-05, 'epoch': 0.23}


  5%|▍         | 1000/22020 [04:24<1:33:08,  3.76it/s]

{'loss': 2.7823, 'learning_rate': 1.9091734786557677e-05, 'epoch': 0.45}


  7%|▋         | 1500/22020 [06:34<1:27:14,  3.92it/s]

{'loss': 2.592, 'learning_rate': 1.8637602179836514e-05, 'epoch': 0.68}


  9%|▉         | 2000/22020 [08:35<1:17:40,  4.30it/s]

{'loss': 2.4979, 'learning_rate': 1.818346957311535e-05, 'epoch': 0.91}


                                                      
 10%|█         | 2202/22020 [09:37<1:05:41,  5.03it/s]

{'eval_loss': 2.2875609397888184, 'eval_runtime': 14.2432, 'eval_samples_per_second': 137.399, 'eval_steps_per_second': 17.201, 'epoch': 1.0}


 11%|█▏        | 2500/22020 [11:09<1:14:36,  4.36it/s] 

{'loss': 2.4182, 'learning_rate': 1.772933696639419e-05, 'epoch': 1.14}


 14%|█▎        | 3000/22020 [13:04<1:13:22,  4.32it/s]

{'loss': 2.3417, 'learning_rate': 1.7275204359673027e-05, 'epoch': 1.36}


 16%|█▌        | 3500/22020 [14:59<1:11:23,  4.32it/s]

{'loss': 2.3218, 'learning_rate': 1.6821071752951864e-05, 'epoch': 1.59}


 18%|█▊        | 4000/22020 [16:55<1:09:17,  4.33it/s]

{'loss': 2.2735, 'learning_rate': 1.63669391462307e-05, 'epoch': 1.82}


                                                      
 20%|██        | 4404/22020 [18:42<1:01:20,  4.79it/s]

{'eval_loss': 2.086479663848877, 'eval_runtime': 13.8511, 'eval_samples_per_second': 141.288, 'eval_steps_per_second': 17.688, 'epoch': 2.0}


 20%|██        | 4500/22020 [19:30<1:05:16,  4.47it/s] 

{'loss': 2.2296, 'learning_rate': 1.591280653950954e-05, 'epoch': 2.04}


 23%|██▎       | 5000/22020 [21:34<1:08:38,  4.13it/s]

{'loss': 2.1842, 'learning_rate': 1.5458673932788377e-05, 'epoch': 2.27}


 25%|██▍       | 5500/22020 [23:33<1:04:40,  4.26it/s]

{'loss': 2.1768, 'learning_rate': 1.5004541326067212e-05, 'epoch': 2.5}


 27%|██▋       | 6000/22020 [25:32<1:03:01,  4.24it/s]

{'loss': 2.1394, 'learning_rate': 1.455040871934605e-05, 'epoch': 2.72}


 30%|██▉       | 6500/22020 [27:31<1:05:24,  3.96it/s]

{'loss': 2.1252, 'learning_rate': 1.4096276112624887e-05, 'epoch': 2.95}


                                                      
 30%|███       | 6606/22020 [28:11<52:35,  4.88it/s]

{'eval_loss': 2.009526252746582, 'eval_runtime': 14.0508, 'eval_samples_per_second': 139.28, 'eval_steps_per_second': 17.437, 'epoch': 3.0}


 32%|███▏      | 7000/22020 [30:13<1:06:52,  3.74it/s] 

{'loss': 2.1046, 'learning_rate': 1.3642143505903725e-05, 'epoch': 3.18}


 34%|███▍      | 7500/22020 [32:15<1:03:50,  3.79it/s]

{'loss': 2.0819, 'learning_rate': 1.3188010899182562e-05, 'epoch': 3.41}


 36%|███▋      | 8000/22020 [34:14<54:19,  4.30it/s]  

{'loss': 2.0623, 'learning_rate': 1.27338782924614e-05, 'epoch': 3.63}


 39%|███▊      | 8500/22020 [36:12<51:45,  4.35it/s]  

{'loss': 2.0691, 'learning_rate': 1.2279745685740236e-05, 'epoch': 3.86}


                                                    
 40%|████      | 8808/22020 [37:39<45:39,  4.82it/s]

{'eval_loss': 1.9624103307724, 'eval_runtime': 14.1465, 'eval_samples_per_second': 138.338, 'eval_steps_per_second': 17.319, 'epoch': 4.0}


 41%|████      | 9000/22020 [38:46<50:09,  4.33it/s]   

{'loss': 2.0562, 'learning_rate': 1.1825613079019073e-05, 'epoch': 4.09}


 43%|████▎     | 9500/22020 [40:43<49:01,  4.26it/s]

{'loss': 2.0304, 'learning_rate': 1.137148047229791e-05, 'epoch': 4.31}


 45%|████▌     | 10000/22020 [42:40<47:20,  4.23it/s]

{'loss': 2.0212, 'learning_rate': 1.0917347865576748e-05, 'epoch': 4.54}


 48%|████▊     | 10500/22020 [44:43<51:04,  3.76it/s]

{'loss': 2.0161, 'learning_rate': 1.0463215258855586e-05, 'epoch': 4.77}


 50%|████▉     | 11000/22020 [46:54<47:25,  3.87it/s]

{'loss': 1.9884, 'learning_rate': 1.0009082652134423e-05, 'epoch': 5.0}


                                                     
 50%|█████     | 11010/22020 [47:13<43:56,  4.18it/s]

{'eval_loss': 1.9358924627304077, 'eval_runtime': 15.6354, 'eval_samples_per_second': 125.164, 'eval_steps_per_second': 15.67, 'epoch': 5.0}


 52%|█████▏    | 11500/22020 [49:43<47:33,  3.69it/s]   

{'loss': 1.9954, 'learning_rate': 9.554950045413262e-06, 'epoch': 5.22}


 54%|█████▍    | 12000/22020 [51:46<38:44,  4.31it/s]

{'loss': 1.9707, 'learning_rate': 9.1008174386921e-06, 'epoch': 5.45}


 57%|█████▋    | 12500/22020 [53:45<37:02,  4.28it/s]

{'loss': 1.9604, 'learning_rate': 8.646684831970936e-06, 'epoch': 5.68}


 59%|█████▉    | 13000/22020 [55:44<34:57,  4.30it/s]

{'loss': 1.9853, 'learning_rate': 8.192552225249773e-06, 'epoch': 5.9}


                                                     
 60%|██████    | 13212/22020 [56:49<30:46,  4.77it/s]

{'eval_loss': 1.9152294397354126, 'eval_runtime': 14.3072, 'eval_samples_per_second': 136.785, 'eval_steps_per_second': 17.124, 'epoch': 6.0}


 61%|██████▏   | 13500/22020 [58:21<33:34,  4.23it/s]   

{'loss': 1.9704, 'learning_rate': 7.73841961852861e-06, 'epoch': 6.13}


 64%|██████▎   | 14000/22020 [1:00:20<32:16,  4.14it/s]

{'loss': 1.9518, 'learning_rate': 7.284287011807448e-06, 'epoch': 6.36}


 66%|██████▌   | 14500/22020 [1:02:19<29:27,  4.26it/s]

{'loss': 1.9625, 'learning_rate': 6.8301544050862855e-06, 'epoch': 6.58}


 68%|██████▊   | 15000/22020 [1:04:19<27:49,  4.20it/s]

{'loss': 1.9395, 'learning_rate': 6.376021798365123e-06, 'epoch': 6.81}


                                                       
 70%|███████   | 15414/22020 [1:06:13<22:54,  4.81it/s]

{'eval_loss': 1.9046376943588257, 'eval_runtime': 14.3251, 'eval_samples_per_second': 136.614, 'eval_steps_per_second': 17.103, 'epoch': 7.0}


 70%|███████   | 15500/22020 [1:06:59<25:41,  4.23it/s]   

{'loss': 1.9325, 'learning_rate': 5.9218891916439605e-06, 'epoch': 7.04}


 73%|███████▎  | 16000/22020 [1:08:58<23:52,  4.20it/s]

{'loss': 1.9173, 'learning_rate': 5.467756584922798e-06, 'epoch': 7.27}


 75%|███████▍  | 16500/22020 [1:10:57<21:27,  4.29it/s]

{'loss': 1.9304, 'learning_rate': 5.013623978201635e-06, 'epoch': 7.49}


 77%|███████▋  | 17000/22020 [1:12:56<19:53,  4.21it/s]

{'loss': 1.9256, 'learning_rate': 4.559491371480473e-06, 'epoch': 7.72}


 79%|███████▉  | 17500/22020 [1:14:54<17:38,  4.27it/s]

{'loss': 1.9416, 'learning_rate': 4.1053587647593104e-06, 'epoch': 7.95}


                                                       
 80%|████████  | 17616/22020 [1:15:36<15:32,  4.72it/s]

{'eval_loss': 1.8973430395126343, 'eval_runtime': 14.3092, 'eval_samples_per_second': 136.765, 'eval_steps_per_second': 17.122, 'epoch': 8.0}


 82%|████████▏ | 18000/22020 [1:17:33<15:46,  4.25it/s]   

{'loss': 1.9251, 'learning_rate': 3.6512261580381475e-06, 'epoch': 8.17}


 84%|████████▍ | 18500/22020 [1:19:32<13:53,  4.22it/s]

{'loss': 1.8973, 'learning_rate': 3.197093551316985e-06, 'epoch': 8.4}


 86%|████████▋ | 19000/22020 [1:21:31<12:06,  4.16it/s]

{'loss': 1.9142, 'learning_rate': 2.742960944595822e-06, 'epoch': 8.63}


 89%|████████▊ | 19500/22020 [1:23:29<09:58,  4.21it/s]

{'loss': 1.9092, 'learning_rate': 2.2888283378746596e-06, 'epoch': 8.86}


                                                       
 90%|█████████ | 19818/22020 [1:25:04<08:33,  4.29it/s]

{'eval_loss': 1.8934144973754883, 'eval_runtime': 15.9221, 'eval_samples_per_second': 122.911, 'eval_steps_per_second': 15.387, 'epoch': 9.0}


 91%|█████████ | 20000/22020 [1:26:17<08:55,  3.77it/s]  

{'loss': 1.9186, 'learning_rate': 1.8346957311534968e-06, 'epoch': 9.08}


 93%|█████████▎| 20500/22020 [1:28:30<06:47,  3.73it/s]

{'loss': 1.9026, 'learning_rate': 1.3805631244323345e-06, 'epoch': 9.31}


 95%|█████████▌| 21000/22020 [1:30:42<04:01,  4.23it/s]

{'loss': 1.8904, 'learning_rate': 9.264305177111717e-07, 'epoch': 9.54}


 98%|█████████▊| 21500/22020 [1:32:39<02:02,  4.26it/s]

{'loss': 1.9046, 'learning_rate': 4.722979109900091e-07, 'epoch': 9.76}


100%|█████████▉| 22000/22020 [1:34:35<00:04,  4.25it/s]

{'loss': 1.9098, 'learning_rate': 1.8165304268846506e-08, 'epoch': 9.99}


                                                       
100%|██████████| 22020/22020 [1:34:54<00:00,  4.80it/s]

{'eval_loss': 1.8925002813339233, 'eval_runtime': 14.1359, 'eval_samples_per_second': 138.442, 'eval_steps_per_second': 17.332, 'epoch': 10.0}


100%|██████████| 22020/22020 [1:35:14<00:00,  3.85it/s]

{'train_runtime': 5714.9929, 'train_samples_per_second': 30.812, 'train_steps_per_second': 3.853, 'train_loss': 2.0988503976695436, 'epoch': 10.0}





TrainOutput(global_step=22020, training_loss=2.0988503976695436, metrics={'train_runtime': 5714.9929, 'train_samples_per_second': 30.812, 'train_steps_per_second': 3.853, 'train_loss': 2.0988503976695436, 'epoch': 10.0})

perplexity after training

In [21]:
model.save_pretrained(SAVE_PATH)

In [22]:
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

100%|██████████| 245/245 [00:14<00:00, 17.46it/s]

>>> Perplexity: 6.64





In [24]:
# eval_results = trainer.evaluate()
# print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

100%|██████████| 245/245 [00:14<00:00, 17.18it/s]

>>> Perplexity: 6.46





# For BART

In [18]:
from transformers import PreTrainedTokenizerFast

In [39]:
raw_dataset = load_dataset('csv', data_files='data/quadruplet_only.csv')
tokenized_datasets = raw_dataset.map(
    tokenize_function, batched=True, remove_columns=raw_dataset['train'].column_names
)
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
downsampled_dataset = lm_datasets["train"].train_test_split(test_size=0.1, seed=42)

Found cached dataset csv (C:/Users/danendra/.cache/huggingface/datasets/csv/default-5ee63b2ab57cbf46/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
100%|██████████| 1/1 [00:00<00:00, 500.04it/s]
                                                   

In [41]:
downsampled_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 72
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 8
    })
})

In [53]:
tokenizer.decode(downsampled_dataset['train'][1]['input_ids'], skip_special_tokens=False)

' beli barang di dengan pengiriman pake si driver ngeluh hanya dapat perak sya sich gak percaya, masa sich sya jawab gitu, lah wong sy bayar rb sameday dari toko didepok, gk tau dah. intinya ada dua driver gosend, dan setelah gue liat kejadiannya di cctv kantor.. asumsi gue itu tu komplotan driver. pas baca thread ini, ya kok yakinnya malah penjualnya yang main. wkwkwk damn, iya bro. kitanya sebagai buyer jadi rugi. dia enak, tinggal ajuin klaim hilang terus dapet dana kita. sedangkan kita, harus nunggu approval dulu dan proses klaim segala macem. sial emang. gk mau buruk sangka'

In [52]:
tokenizer.decode(downsampled_dataset['train'][1]['input_ids'], skip_special_tokens=False).split('.')

[' beli barang di dengan pengiriman pake si driver ngeluh hanya dapat perak sya sich gak percaya, masa sich sya jawab gitu, lah wong sy bayar rb sameday dari toko didepok, gk tau dah',
 ' intinya ada dua driver gosend, dan setelah gue liat kejadiannya di cctv kantor',
 '',
 ' asumsi gue itu tu komplotan driver',
 ' pas baca thread ini, ya kok yakinnya malah penjualnya yang main',
 ' wkwkwk damn, iya bro',
 ' kitanya sebagai buyer jadi rugi',
 ' dia enak, tinggal ajuin klaim hilang terus dapet dana kita',
 ' sedangkan kita, harus nunggu approval dulu dan proses klaim segala macem',
 ' sial emang',
 ' gk mau buruk sangka']

In [None]:
wwm_probability = 0.2


def text_infilling(examples):
    new_inputs = []
    new_labels = []
    for i in range(len(examples['input_ids'])):
        word_ids = examples['word_ids'][i]
        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = examples["input_ids"][i]
        labels = examples["labels"][i]
        new_inputs = []
        new_labels = []
        for word_id, masked in enumerate(mask):
            for idx in mapping[word_id]:
                masked_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        #dropping the same extra_id
        input_ids = drop_duplicate_id(input_ids, input_masked+1)
        labels = drop_duplicate_id(labels, label_masked+1)
        new_inputs.append(input_ids)
        new_labels.append(labels)
    examples['labels'] = new_labels
    examples['input_ids'] = new_inputs
    examples['attention_mask'] = [[1 for i in range(len(examples['input_ids'][j]))] for j in range(len(examples['input_ids']))]
    return examples

In [None]:
# wwm_probability = 0.2


# def whole_word_masking(examples):
#     new_inputs = []
#     new_labels = []
#     for i in range(len(examples['input_ids'])):
#         word_ids = examples['word_ids'][i]
#         # Create a map between words and corresponding token indices
#         mapping = collections.defaultdict(list)
#         current_word_index = -1
#         current_word = None
#         for idx, word_id in enumerate(word_ids):
#             if word_id is not None:
#                 if word_id != current_word:
#                     current_word = word_id
#                     current_word_index += 1
#                 mapping[current_word_index].append(idx)

#         # Randomly mask words
#         mask = np.random.binomial(1, wwm_probability, (len(mapping),))
#         input_ids = examples["input_ids"][i]
#         labels = examples["labels"][i]
#         input_masked = 0
#         label_masked = 0
#         for word_id, masked in enumerate(mask):
#             for idx in mapping[word_id]:
#                 #if not masking input_ids then we mask the label
#                 if masked==0:
#                     labels[idx] = tokenizer.convert_tokens_to_ids(f'<extra_id_{label_masked}>')
#                 #if masking then we mask the input_ids
#                 elif masked==1:
#                     input_ids[idx] = tokenizer.convert_tokens_to_ids(f'<extra_id_{input_masked}>')
#             if masked==0:
#                 label_masked+=1
#             elif masked==1:
#                 input_masked+=1
#         new_inputs.append(input_ids)
#         new_labels.append(labels)
#     examples['labels'] = new_labels
#     examples['input_ids'] = new_inputs
#     return examples