# 決定 Tokenizer 與使用 BertForPretraining 來做 BERT 預訓練

In [1]:
from transformers import BertTokenizer, BertForPreTraining, AdamW
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BERT_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from torch.nn import CrossEntropyLoss
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
import pandas as pd
import torch
import random

In [2]:
class MyBertForPreTrainingOutput(BertForPreTrainingOutput):
    def __init__(self, loss=None, prediction_logits=None, seq_relationship_logits=None, hidden_states=None, attentions=None, mlm_loss=None, nsp_loss=None):
        super().__init__(loss=loss, prediction_logits=prediction_logits, seq_relationship_logits=seq_relationship_logits, hidden_states=hidden_states, attentions=attentions)
        self.mlm_loss = mlm_loss
        self.nsp_loss = nsp_loss

In [3]:
class MyBertForPreTraining(BertForPreTraining):
    def __init__(self, config):
        super().__init__(config)
    # @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    # @replace_return_docstrings(output_type=MyBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MyBertForPreTrainingOutput]:
        r"""
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
                config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
                the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
            next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
                pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
                - 0 indicates sequence B is a continuation of sequence A,
                - 1 indicates sequence B is a random sequence.
            kwargs (`Dict[str, any]`, optional, defaults to *{}*):
                Used to hide legacy arguments that have been deprecated.
        Returns:
        Example:
        ```python
        >>> from transformers import AutoTokenizer, BertForPreTraining
        >>> import torch
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        >>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return MyBertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            mlm_loss = masked_lm_loss,
            nsp_loss = next_sentence_loss,
        )

# 取出資料集

In [None]:
class getData():
    def __init__(self, modelType, datapath, maskPercent):
        self.datapath = datapath
        self.tokenizer = BertTokenizer.from_pretrained(modelType)
        self.maskPercent = maskPercent
        self.text = self.toText()
        self.inputs = None
        self.nspPrepare()
        self.mlmPrepare()
        return self.inputs
    
    def toText(self):
        df = pd.read_csv(self.datapath)
        text = []
        for review in df["text"]:
            text.append(review)
        
        return text
    
    def nspPrepare(self):
        bag = [item for sentence in self.text for item in sentence.split('.') if item != '']
        bag_size = len(bag)

        sentence_a = []
        sentence_b = []
        label = []

        for paragraph in self.text:
            sentences = [
                sentence for sentence in paragraph.split('.') if sentence != ''
            ]
            num_sentences = len(sentences)
            if num_sentences > 1:
                start = random.randint(0, num_sentences-2)
                # 50/50 whether is IsNextSentence or NotNextSentence
                if random.random() >= 0.5:
                    # this is IsNextSentence
                    sentence_a.append(sentences[start])
                    sentence_b.append(sentences[start+1])
                    label.append(0)
                else:
                    index = random.randint(0, bag_size-1)
                    # this is NotNextSentence
                    sentence_a.append(sentences[start])
                    sentence_b.append(bag[index])
                    label.append(1)

        self.inputs = self.tokenizer(sentence_a, sentence_b, return_tensors='pt',
                   max_length=512, truncation=True, padding='max_length')
        self.inputs['next_sentence_label'] = torch.LongTensor([label]).T

    def mlmPrepare(self):
        self.inputs['labels'] = self.inputs.input_ids.detach().clone()
        rand = torch.rand(self.inputs.input_ids.shape)
        # create mask array
        mask_arr = (rand < self.maskPercent) * (self.inputs.input_ids != 101) * \
                (self.inputs.input_ids != 102) * (self.inputs.input_ids != 0)
        
        selection = []

        for i in range(self.inputs.input_ids.shape[0]):
            selection.append(
                torch.flatten(mask_arr[i].nonzero()).tolist()
            )

        for i in range(self.inputs.input_ids.shape[0]):
            self.inputs.input_ids[i, selection[i]] = 103

In [16]:
class OurDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)

In [None]:
class trainModel():
    def __init__(self, modelType, inputs, batch_size, epoch, maskPercent):
        self.model = MyBertForPreTraining.from_pretrained(modelType)
        self.tokenizer = BertTokenizer.from_pretrained(modelType)
        self.inputs = inputs
        self.batch_size = batch_size
        self.epoch = epoch
        self.maskPercent = maskPercent
        self.loader = torch.utils.data.DataLoader(OurDataset(self.inputs), \
                                             batch_size=self.batch_size, shuffle=True)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(self.device)
        self.model.train()
        self.optim = AdamW(self.model.parameters(), lr = 5e-5)
        self.rec = pd.read_csv("record.csv")
        self.record = {"mask_percent": 0,
                  "mlm_acc_each_epoch": [], 
                  "mlm_loss_each_epoch": []}
    
    def training(self):
        acc_each_epoch = []
        loss_each_epoch = []
        for epoch in range(self.epoch):
            # setup loop with TQDM and dataloader
            mask_nums = 0
            mlm_correct = 0
            nsp_nums = 0
            nsp_correct = 0
            loop = tqdm(self.loader, leave=True)
            for batch in loop:
                # initialize calculated gradients (from prev step)
                self.optim.zero_grad()
                # pull all tensor batches required for training
                input_ids = batch['input_ids'].to(self.device)
                token_type_ids = batch['token_type_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                next_sentence_label = batch['next_sentence_label'].to(self.device)
                labels = batch['labels'].to(self.device)
                # process
                outputs = self.model(input_ids, attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                next_sentence_label=next_sentence_label,
                                labels=labels)
                
                mask_positions = (input_ids == self.tokenizer.mask_token_id)
                prediction_logits = outputs.prediction_logits[mask_positions]
                predicted_ids = prediction_logits.argmax(-1)
                mask_truth = labels[mask_positions]
                
                seq_relationship_logits = outputs.seq_relationship_logits
                predicted_labels = torch.argmax(seq_relationship_logits, dim=1)
                predicted_label = predicted_labels

                mask_nums += len(predicted_ids)
                mlm_correct += torch.eq(predicted_ids, mask_truth).sum().item()
                nsp_nums += len(predicted_label)
                nsp_correct += predicted_label.eq(torch.squeeze(next_sentence_label)).sum().item()
                
                # extract loss
                loss = outputs.loss
                mlm_loss = outputs.mlm_loss.item()
                nsp_loss = outputs.nsp_loss.item()
                mlm_acc = mlm_correct / mask_nums
                nsp_acc = nsp_correct / nsp_nums
                # calculate loss for every parameter that needs grad update
                loss.backward()
                # update parameters
                self.optim.step()
                # print relevant info to progress bar
                loop.set_description(f'Epoch {epoch}')
                loop.set_postfix(Total_loss='{:.4f}'.format(loss.item()), MLM_Accuracy='{:.4f}'.format(mlm_acc), NSP_Accuracy='{:.4f}'.format(nsp_acc), \
                                MLM_loss='{:.4f}'.format(mlm_loss), NSP_loss='{:.4f}'.format(nsp_loss))
            acc_each_epoch.append(mlm_acc)
            loss_each_epoch.append(mlm_loss)

        self.record["mask_percent"] = self.maskPercent
        self.record["mlm_acc_each_epoch"].append(acc_each_epoch)
        self.record["mlm_loss_each_epoch"].append(loss_each_epoch)
        new_rec = self.rec.append(self.record, ignore_index=True)
        new_rec.to_csv("record.csv", index = None)

In [21]:
datapath = 'bbc-text.csv'
modelType = 'bert-base-cased'
masked15 = 

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 0: 100%|██████████| 371/371 [02:11<00:00,  2.83it/s, MLM_Accuracy=0.2859, MLM_loss=0.1353, NSP_Accuracy=0.8112, NSP_loss=0.2407, Total_loss=0.3760]
Epoch 1:  93%|█████████▎| 345/371 [01:58<00:08,  2.89it/s, MLM_Accuracy=0.3536, MLM_loss=0.0574, NSP_Accuracy=0.9498, NSP_loss=0.0382, Total_loss=0.0956]