In [None]:
# Expose only the first GPU to this kernel; set before torch is imported so
# CUDA initialisation sees this restriction.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Sanity check: show which CUDA device the notebook will use.
import torch
torch.cuda.get_device_name()

In [None]:
from datasets import load_dataset

# Load the Prothom Alo corpus with the "text" loader: a single "text" column,
# one example per line of the file.
prothom_alo_dataset = load_dataset("text", data_files="../datasets/Bangla Prothom Alo.txt", split="train")

In [None]:
# Temporarily view the dataset as pandas to count empty lines below.
prothom_alo_dataset.set_format("pandas")

In [None]:
# Materialise the whole corpus as a DataFrame (loads all rows into memory).
prothom_alo_df = prothom_alo_dataset[:]

In [None]:
# Number of rows whose text is exactly the empty string.
(prothom_alo_df=='').sum()

In [None]:
# Back to the default python output format before further processing.
prothom_alo_dataset.reset_format()

In [None]:
prothom_alo_dataset

In [None]:
# Drop empty lines. NOTE(review): this is an exact-match check, so lines
# containing only whitespace are kept; use x["text"].strip() if they occur.
prothom_alo_dataset = prothom_alo_dataset.filter(lambda x: x["text"]!="")

In [None]:
prothom_alo_dataset

In [None]:
# train_size = 10_000
# test_size = int(0.1 * train_size)

# Seeded 80/20 train/test split of the full corpus. Despite its name,
# `downsampled_dataset` is not downsampled: the size limits above are disabled.
downsampled_dataset = prothom_alo_dataset.train_test_split(
    train_size=0.8, seed=42
)
downsampled_dataset

In [None]:
# Error-word file: each line holds an original word followed by its erroneous
# variants, whitespace-separated (parsed into `error_words` below).
with open('../datasets/Bangla Error Words.txt', encoding='utf-8') as f:
    lines = f.readlines()

In [None]:
# Maps original word -> list of erroneous variants.
error_words = dict()

In [None]:
for line in lines:
    # Each line: "<original_word> <variant_1> <variant_2> ...".
    combination = line.split()
    # Skip blank lines (combination would be empty and combination[0] would
    # raise IndexError) and lines without any variant (they would store an
    # empty list that replace_error_word cannot sample a replacement from).
    if len(combination) < 2:
        continue
    original_word = combination[0]
    modified_words = combination[1:]
    error_words[original_word] = modified_words

In [None]:
# Inspect the variants recorded for one word.
error_words['গত']

In [None]:
import numpy as np
# Seed NumPy's global RNG: error injection (replace_error_word / corrupt_text)
# and token masking in the data collator all draw from it.
np.random.seed(42)

def replace_error_word(sentence, error_words):
    """Inject one kind of synthetic error into `sentence`.

    Scans `error_words` in insertion order. For the first key that occurs in
    `sentence` and has at least one variant, every occurrence of that key is
    replaced by a single variant drawn uniformly from NumPy's global RNG.

    Args:
        sentence: Text to corrupt.
        error_words: Mapping of original word -> list of erroneous variants.

    Returns:
        The corrupted sentence, or `sentence` unchanged when no key that has
        variants occurs in it.

    Note:
        Matching uses substring containment (`in` / `str.replace`), so a key
        can also match inside a longer word.
    """
    for error_word, variants in error_words.items():
        # A key with no variants cannot be replaced (np.random.randint(0)
        # would raise ValueError); try the next candidate instead.
        if not variants or error_word not in sentence:
            continue
        index = np.random.randint(len(variants))
        return sentence.replace(error_word, variants[index])
    return sentence

In [None]:
# np.random.randint(1,3)

In [None]:
# replace_error_word("তখন আমাদের দেখা হবে।", error_words)

In [None]:
# count = 0
# index = 0
# indices = list()
# for sample in downsampled_dataset["test"].select(range(10)):
#     #replace 15% of time
#     if np.random.random() < 0.15:
#         replaced_sample = replace_error_word(sample['text'], error_words)
#         sample['text'] = replaced_sample
#         print(sample['text'])
#         indices.append(index)
#         count += 1
#     print(index)
#     index += 1
        
# print(count)

In [None]:
def corrupt_text(examples, corruption_probability=0.20):
    """Randomly inject a synthetic error into one example.

    Intended for a non-batched `Dataset.map`, so `examples["text"]` is a single
    string. With probability `corruption_probability` (20% by default) the text
    is passed through `replace_error_word` with the global `error_words`
    dictionary. Otherwise the example is left untouched.

    Args:
        examples: A dataset row with a "text" field; modified in place.
        corruption_probability: Chance of attempting a corruption.

    Returns:
        The same (possibly modified) example dict.
    """
    # Uses NumPy's global RNG so the corrupted split follows the notebook seed.
    if np.random.random() < corruption_probability:
        examples["text"] = replace_error_word(examples["text"], error_words)
    return examples

In [None]:
# Inject synthetic errors into the held-out split only (corrupt_text decides per
# example); the train split is left unchanged.
downsampled_dataset["test"] = downsampled_dataset["test"].map(corrupt_text)

In [None]:
downsampled_dataset["test"][:10]

In [None]:
from transformers import PreTrainedTokenizerFast

# Wrap the tokenizer saved in wordpiece_tokenizer_prothom_alo.json as a
# transformers fast tokenizer and register its special tokens.
# NOTE(review): return_special_tokens_mask is passed as a constructor kwarg;
# confirm it has any effect — nothing downstream reads a special_tokens_mask.
tokenizer = PreTrainedTokenizerFast(
    #tokenizer_object=tokenizer,
    tokenizer_file="wordpiece_tokenizer_prothom_alo.json", # You can load from the tokenizer file, alternatively
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
    return_special_tokens_mask = True,
    model_max_length = 512,
)

In [None]:
from transformers import BertConfig, BertForMaskedLM


# Set a configuration for our BERT model: library defaults except for the
# padding id, which is taken from our tokenizer.
# NOTE(review): vocab_size is not passed, so it stays at the BertConfig default;
# confirm it matches the tokenizer's vocabulary size.
wordpiece_bert_config = BertConfig(pad_token_id=tokenizer.pad_token_id)

# Building the model from the config
# Model is randomly initialized (it is replaced further down by the checkpoint
# loaded with BertForMaskedLM.from_pretrained).
model = BertForMaskedLM(wordpiece_bert_config)

print(wordpiece_bert_config)

In [None]:
# wordpiece_bert_config

In [None]:
def tokenize_function(examples):
    """Tokenize a batch of texts into fixed-length, 80-token encodings.

    Every sequence is truncated or padded to exactly 80 tokens. For a fast
    tokenizer, a "word_ids" column is added as well: one list per row, as
    returned by `BatchEncoding.word_ids` for that row.
    """
    encoded = tokenizer(
        examples["text"],
        padding="max_length",
        max_length=80,
        truncation=True,
    )
    if not tokenizer.is_fast:
        return encoded
    n_rows = len(encoded["input_ids"])
    encoded["word_ids"] = [encoded.word_ids(row) for row in range(n_rows)]
    return encoded

In [None]:
# Use batched=True to activate fast multithreading!
tokenized_datasets = downsampled_dataset.map(
    tokenize_function, batched=True, remove_columns=["text"]
)
tokenized_datasets

In [None]:
tokenized_datasets.remove_columns("token_type_ids")

In [None]:
# temp = tokenized_datasets.filter(lambda x:x if 1 in x["input_ids"] else None)

In [None]:
# temp

In [None]:
tokenized_datasets['train'][0]

In [None]:
def group_texts(examples):
    """Add a "labels" column that starts as a copy of "input_ids".

    Despite the name, no grouping or chunking happens here: the sequences were
    already padded/truncated to a fixed length during tokenization. The data
    collator later sets the labels of unmasked positions to -100.

    Returns the same (mutated) batch dict.
    """
    token_rows = examples["input_ids"]
    # Shallow copy: a new outer list holding the same per-row lists.
    examples["labels"] = list(token_rows)
    return examples

In [None]:
# Attach a "labels" column (a copy of input_ids); masking happens later in the
# data collator.
lm_datasets = tokenized_datasets.map(group_texts, batched=True)
lm_datasets

In [None]:
# NOTE(review): this displays 100 raw rows; a handful is usually enough to
# eyeball the injected errors.
downsampled_dataset["test"][:100]

In [None]:
# At this point the decoded input_ids and labels of an example are identical,
# because labels are an unmasked copy of input_ids.
tokenizer.decode(lm_datasets["train"][1]["input_ids"])

In [None]:
tokenizer.decode(lm_datasets["train"][1]["labels"])

In [None]:
tokenizer.pad_token_id

In [None]:
import collections
import numpy as np
from transformers import default_data_collator

# Per-token masking probability. The name is historical: whole-word masking
# (grouping tokens by word_ids) is not done, so tokens are masked independently.
wwm_probability = 0.15


def bangla_data_collator(features):
    """Randomly mask tokens in each feature and collate them into a batch.

    For each feature, every position is selected independently with
    probability `wwm_probability`. Selected positions that do not hold a
    special token ([UNK], [PAD], [CLS], [SEP], [MASK]) are replaced by [MASK]
    in "input_ids" and keep their original id in "labels". All other labels
    are set to -100, so the loss only covers masked tokens.

    Args:
        features: List of dicts with "input_ids" and "labels" (equal-length
            lists of token ids). Modified in place.

    Returns:
        The batch produced by `default_data_collator(features)`.
    """
    # Token ids that must never be masked; loop-invariant, so build once.
    special_tokens = {
        tokenizer.unk_token_id,
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.mask_token_id,
    }
    for feature in features:
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        mask = np.random.binomial(1, wwm_probability, (len(input_ids),))

        new_labels = [-100] * len(labels)
        for idx in np.flatnonzero(mask):
            if input_ids[idx] not in special_tokens:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        # Assign after the loop. Assigning inside it meant a feature with no
        # sampled positions kept its full, unmasked labels and contributed loss
        # on every token, padding included.
        feature["labels"] = new_labels

    return default_data_collator(features)

In [None]:
from transformers import DataCollatorForLanguageModeling

# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
# word_ids would only be needed for whole-word masking, which the custom
# collator does not do; drop it so it is not batched. Not idempotent: re-running
# this cell fails because the column is already gone.
lm_datasets = lm_datasets.remove_columns(["word_ids"])
data_collator = bangla_data_collator

In [None]:
# Preview: decode a few collated training examples to see where [MASK] lands.
# (Collating consumes NumPy's global RNG, shifting later random masks.)
samples = [lm_datasets["train"][i] for i in range(3)]
# for sample in samples:
#     _ = sample.pop("word_ids")

for chunk in bangla_data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")

In [None]:
# Inspect one collated example: [MASK] ids in input_ids and their targets in
# labels (expected to be -100 at positions that were not masked).
samples = [lm_datasets["train"][i] for i in range(1)]

chunk = data_collator(samples)
print(chunk["input_ids"])
print(chunk["labels"])

In [None]:
# Load a previously saved checkpoint; this replaces the randomly initialised
# model built from wordpiece_bert_config.
# NOTE(review): the same directory is the Trainer's output_dir with
# overwrite_output_dir=True, so saving there may overwrite this checkpoint —
# confirm that is intended.
model = BertForMaskedLM.from_pretrained("models/wordpiece/bert-base-pretrained-prothom-alo")

In [None]:
# Reload the tokenizer saved with the checkpoint. This rebinds the global
# `tokenizer` that tokenize_function and bangla_data_collator read.
tokenizer = PreTrainedTokenizerFast.from_pretrained("models/wordpiece/bert-base-pretrained-prothom-alo")

In [None]:
#  disable weights and biases logging
import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
from transformers import TrainingArguments

batch_size = 64
# Log the training loss roughly once per epoch (len(train) // batch_size steps).
# Clamp to 1 so a training split smaller than one batch does not yield
# logging_steps=0.
logging_steps = max(1, len(downsampled_dataset["train"]) // batch_size)


training_args = TrainingArguments(
    num_train_epochs = 6,
    # "none" disables every reporting integration. report_to=None does not:
    # transformers 4.x treats None as "all installed integrations".
    report_to = "none",
    # NOTE(review): the model was loaded from this same directory; confirm
    # that writing checkpoints here (overwrite_output_dir=True) is intended.
    output_dir="models/wordpiece/bert-base-pretrained-prothom-alo",
    overwrite_output_dir=True,
    # Renamed to eval_strategy in newer transformers releases.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    #push_to_hub=True,
    # Mixed-precision training; needs a CUDA device.
    fp16=True,
    logging_steps=logging_steps,
    # load_best_model_at_end needs save_strategy to match the evaluation
    # strategy (both "epoch" here).
    load_best_model_at_end=True,
    save_strategy = "epoch",
)

In [None]:
from transformers import Trainer

# Train on the unmodified train split; evaluate on the test split, part of which
# was corrupted by corrupt_text. Token masking is done by the custom collator.
# NOTE(review): newer transformers releases prefer processing_class= over
# tokenizer=; confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
import math

# Perplexity = exp(eval loss) on the test split, using the loaded checkpoint
# (no trainer.train() call precedes this cell). Masks are drawn at random by the
# collator, so the value can vary between runs.
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

In [None]:
# text = 'পরে সেখানে সংক্ষিপ্ত সমাবেশ অনুষ্ঠিত হয় ।'
# Same sentence with the word 'হয়' replaced by [MASK], for a fill-mask demo.
text =  'পরে সেখানে সংক্ষিপ্ত সমাবেশ অনুষ্ঠিত [MASK] ।'

In [None]:
tokenizer.tokenize(text)

In [None]:
import torch

# Fill-mask demo: print the sentence with the model's top-5 candidates for the
# first [MASK]. Inputs follow the model's device rather than a hard-coded
# "cuda", so the cell also works when the model is on the CPU.
model.eval()
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the candidates with the highest logits for the first [MASK]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")