In [None]:
!pip install transformers[torch] datasets torch numpy sentencepiece

In [None]:
# Kaggle input locations: the dataset saved to disk and the SentencePiece model file.
INP_HF_DATASET_PATH = "/kaggle/input/iwslt-en-zh/ds"
INP_SPM_MODEL_PATH = "/kaggle/input/spiecebpeunproc/other/base-spiece/1/en.model"
# Key inside each `translation` record that selects the source-language sentence.
INP_SRC_LANG = "en"

# Directory for Trainer checkpoints and the final saved model.
OUTPUT_PATH = "/kaggle/working/models"

In [None]:
import inspect

import datasets
import numpy as np
import sentencepiece as spm
import torch
from torch.utils.data import RandomSampler
from transformers import (
    BatchEncoding,
    BertConfig,
    BertForMaskedLM,
    BertTokenizerFast,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

In [None]:
# Pick an accelerator: Apple MPS is preferred, then CUDA, otherwise CPU.
# `dev` is a display label; `device` is the matching torch.device.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    dev, device = "MPS", torch.device("mps")
elif torch.cuda.is_available() and torch.cuda.device_count():
    dev, device = "CUDA", torch.device("cuda")
else:
    dev, device = "CPU", torch.device("cpu")

# Newly created tensors default to `device` from here on.
torch.set_default_device(device)

In [None]:
# Load the dataset previously written to disk, then keep handles to the two splits used below.
ds = datasets.load_from_disk(INP_HF_DATASET_PATH)
t, v = ds["train"], ds["validation"]

In [None]:
from collections.abc import Mapping

# Exact number of positions (BOS and EOS included) every tokenised example is
# truncated or padded to.
SEQUENCE_LENGTH: int = 288


class Tokeniser:
    """Minimal SentencePiece-backed tokenizer for this notebook.

    Exposes the small part of the Hugging Face tokenizer interface that the
    code here relies on: ``__call__``, ``pad``, ``pad_token_id`` and ``__len__``.
    Every encoded example is ``[BOS, *pieces, EOS]``, truncated or right-padded
    to exactly ``SEQUENCE_LENGTH`` positions.
    """

    # Per-example output keys, in the order they are collected into batches.
    out_keys = [
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "special_tokens_mask",
    ]

    def __init__(self, lang, model_file):
        """Load the SentencePiece model.

        Args:
            lang: source-language code; stored on the instance, not otherwise
                used by this class.
            model_file: path to a SentencePiece ``.model`` file.
        """
        self.lang = lang
        self.model = spm.SentencePieceProcessor(model_file=model_file)

        # NOTE(review): hard-coded pad id. If the SentencePiece model has no
        # dedicated pad piece (model.pad_id() == -1), id 3 is presumably an
        # ordinary vocabulary piece — confirm against the trained model.
        self.pad_token_id = 3
        self.bos_id = self.model.bos_id()
        self.eos_id = self.model.eos_id()

    def _process_id(self, input_ids):
        """Wrap one sentence's piece ids in BOS/EOS and fit to ``SEQUENCE_LENGTH``.

        Args:
            input_ids: list of SentencePiece ids for a single sentence, without
                BOS/EOS.

        Returns:
            Dict mapping each name in ``out_keys`` to a list of exactly
            ``SEQUENCE_LENGTH`` ints. Over-long inputs are truncated but still
            end in EOS. Short inputs are right-padded with ``pad_token_id``;
            padding has attention 0 and special-token mask 1.
        """
        input_ids = [self.bos_id, *input_ids, self.eos_id]
        o_len = len(input_ids)
        token_type_ids = [0] * o_len
        attention_mask = [1] * o_len
        special_tokens_mask = [1] + [0] * (o_len - 2) + [1]

        if o_len > SEQUENCE_LENGTH:
            # Truncate, re-appending EOS so the sequence still ends with it.
            input_ids = input_ids[: SEQUENCE_LENGTH - 1] + [self.eos_id]
            token_type_ids = token_type_ids[:SEQUENCE_LENGTH]
            attention_mask = attention_mask[:SEQUENCE_LENGTH]
            special_tokens_mask = special_tokens_mask[: SEQUENCE_LENGTH - 1] + [1]

        elif o_len < SEQUENCE_LENGTH:
            # EOS is already in place from the wrap above; only pad. (Appending
            # another EOS here would put a duplicate, labelled EOS at the first
            # masked position.)
            pad_len = SEQUENCE_LENGTH - o_len
            input_ids += [self.pad_token_id] * pad_len
            token_type_ids += [0] * pad_len
            attention_mask += [0] * pad_len
            special_tokens_mask += [1] * pad_len

        return {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "special_tokens_mask": special_tokens_mask,
        }

    def encode(self, row):
        """Encode one string, or dispatch a list of strings to ``encode_batch``."""
        if isinstance(row, list):
            return self.encode_batch(row)

        raw_ids = self.model.encode(row)
        return self._process_id(raw_ids)

    def encode_batch(self, rows):
        """Encode a list of strings into a dict of per-example lists (column layout)."""
        ids = [self._process_id(piece_ids) for piece_ids in self.model.encode(rows)]
        return {key: [example[key] for example in ids] for key in Tokeniser.out_keys}

    def pad(self, inputs, **_kwargs):
        """Collate fixed-length examples into a ``BatchEncoding`` of PyTorch tensors.

        No padding happens here, because ``_process_id`` already emits
        ``SEQUENCE_LENGTH``-long lists. Keyword arguments (e.g. ``return_tensors``)
        are accepted for HF-collator compatibility and ignored; the result is
        always ``"pt"`` tensors.

        Args:
            inputs: a non-empty list/tuple of per-example mappings, or an
                already-batched mapping, which is passed through unchanged.
        """
        if (
            isinstance(inputs, (list, tuple))
            and len(inputs) > 0
            and isinstance(inputs[0], Mapping)
        ):
            # Gather only keys the examples actually carry. Callers may drop
            # some columns (e.g. Trainer's remove_unused_columns may strip
            # `special_tokens_mask`), which would otherwise raise KeyError.
            first = inputs[0]
            keys = [key for key in Tokeniser.out_keys if key in first]
            inputs = {key: [example[key] for example in inputs] for key in keys}
        return BatchEncoding(inputs, tensor_type="pt")

    def __call__(self, inputs, **_kwargs):
        """Tokenise ``inputs`` (a string or list of strings); extra kwargs are ignored."""
        return self.encode(inputs)

    def __len__(self):
        """Number of pieces in the SentencePiece model (used as the vocabulary size)."""
        return len(self.model)


# One shared tokenizer instance: used for dataset mapping, model sizing and the collator.
tokenizer = Tokeniser(lang=INP_SRC_LANG, model_file=INP_SPM_MODEL_PATH)

In [None]:
def get_row_data(batch):
    """Tokenise the source-language side of a batch of translation records.

    Args:
        batch: a batched ``datasets`` row dict whose ``"translation"`` column
            holds one mapping per example, indexed by ``INP_SRC_LANG``.

    Returns:
        Whatever the global ``tokenizer`` returns for a list of strings; for
        ``Tokeniser`` that is a dict of per-example lists keyed by
        ``input_ids``, ``token_type_ids``, ``attention_mask`` and
        ``special_tokens_mask``.
    """
    source_texts = [record[INP_SRC_LANG] for record in batch["translation"]]
    return tokenizer(source_texts)


# Tokenise both splits in batches; `test_dataset` is the validation split, used for evaluation.
train_dataset, test_dataset = (
    split.map(get_row_data, batched=True) for split in (t, v)
)


In [None]:
# Small BERT sized to the custom SentencePiece vocabulary and the fixed sequence length.
config = BertConfig(
    vocab_size=len(tokenizer),
    # Must cover every position the tokenizer emits, so reuse SEQUENCE_LENGTH
    # rather than repeating the literal 288.
    max_position_embeddings=SEQUENCE_LENGTH,
    hidden_size=256,
    num_attention_heads=8,  # 256 / 8 = 32 dims per head
    # BertConfig defaults pad_token_id to 0, and BERT uses it as the word-embedding
    # padding_idx (that row receives no gradient). The tokenizer pads with its own
    # id, so point the model at the id that really is padding.
    pad_token_id=tokenizer.pad_token_id,
    # Add or modify other config parameters as needed
)

In [None]:


model = BertForMaskedLM(config=config)

# `config` was built with vocab_size=len(tokenizer), so this should be a no-op; it is
# kept as a guard that the embedding matrix matches the tokenizer's vocabulary.
model.resize_token_embeddings(new_num_tokens=len(tokenizer))

# NOTE(review): with mlm=False the collator copies input_ids into labels (padding -> -100)
# and applies no masking. BertForMaskedLM scores each position against its own, fully
# visible input token, so this objective looks trivially learnable. Real MLM needs
# mlm=True, which in turn needs a mask token (`mask_token`, `convert_tokens_to_ids`)
# that the Tokeniser class does not define.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# transformers renamed `evaluation_strategy` to `eval_strategy`; the old keyword is
# deprecated and rejected by newer releases. The install cell is unpinned, so use
# whichever keyword this installed version accepts.
_EVAL_STRATEGY_KW = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir=OUTPUT_PATH,
    overwrite_output_dir=True,
    num_train_epochs=20,
    # Per-device examples per optimizer step: 10 * 8 = 80.
    per_device_train_batch_size=10,
    gradient_accumulation_steps=8,
    per_device_eval_batch_size=64,
    # eval_steps defaults to logging_steps, so evaluation and checkpointing share a
    # 1000-step cadence, as load_best_model_at_end requires.
    logging_steps=1000,
    save_steps=1000,
    load_best_model_at_end=True,
    save_total_limit=3,
    use_cpu=dev == "CPU",
    # Presumably disabled because torch.set_default_device may place batches on the
    # accelerator already, where host-memory pinning does not apply.
    dataloader_pin_memory=False,
    # fp16 AMP targets CUDA; requesting it on MPS can make TrainingArguments reject the config.
    fp16=dev == "CUDA",
    **{_EVAL_STRATEGY_KW: "steps"},
)

# Wire model, arguments, data and collator together; evaluation runs on the validation split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
)

def _device_train_sampler(train_dataset=None):
    """Shuffling sampler whose RNG lives on the default tensor device.

    ``torch.set_default_device(device)`` earlier in the notebook makes
    ``torch.randperm`` inside ``RandomSampler`` allocate on ``device``. Presumably
    the stock sampler's CPU generator then fails on an accelerator, which is why
    it is overridden here.

    Args:
        train_dataset: the dataset to sample. Newer transformers releases pass
            it to ``_get_train_sampler``; older ones call it with no arguments,
            in which case ``trainer.train_dataset`` is used.

    Returns:
        A ``RandomSampler`` seeded from ``training_args.seed``, so shuffling
        follows the configured seed.
    """
    dataset = trainer.train_dataset if train_dataset is None else train_dataset
    generator = torch.Generator(device)
    generator.manual_seed(training_args.seed)
    return RandomSampler(dataset, generator=generator)


trainer._get_train_sampler = _device_train_sampler

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient

# Authenticate Weights & Biases with the key stored as the Kaggle secret "wandb_sec",
# so the key never appears in the notebook.
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_sec")
wandb.login(key=secret_value_0)


In [None]:
import numpy as np

# Compatibility shim: restore the `np.object` alias (removed in NumPy 1.24),
# presumably because a dependency used during training still references it.
setattr(np, "object", object)

trainer.train()
# With load_best_model_at_end=True, the weights saved here are the best checkpoint's.
trainer.save_model(OUTPUT_PATH)