# 修改 Masking 位置策略
原本 Masking 位置為完全隨機 => 以前被 Mask 過的位置，之後不會再 Mask

In [22]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [23]:
from transformers import BertTokenizer, BertForPreTraining, AdamW
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput, BertPreTrainingHeads, BertConfig, BERT_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.albert.modeling_albert import AlbertSOPHead
from transformers.utils import ModelOutput
from transformers.utils.doc import add_start_docstrings_to_model_forward, replace_return_docstrings
from torch.nn import CrossEntropyLoss
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
import pandas as pd
import torch
import random
import copy

In [24]:
# 設定 BertForPreTraining 輸出形式
# 主要是想從原本的 BertForPreTrainingOutput 多輸出 mlm_loss 和 nsp_loss
class MyBertForPreTrainingOutput(BertForPreTrainingOutput):
    """
    Output type of [`MyBertForPreTraining`].
    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        mlm_loss (`float`):
            MLM loss.
        nsp_loss (`float`):
            NSP loss.
    """
    def __init__(self, loss=None, prediction_logits=None, seq_relationship_logits=None, hidden_states=None, attentions=None, mlm_loss=None, nsp_loss=None):
        super().__init__(loss=loss, prediction_logits=prediction_logits, seq_relationship_logits=seq_relationship_logits, hidden_states=hidden_states, attentions=attentions)
        self.mlm_loss = mlm_loss
        self.nsp_loss = nsp_loss

In [25]:
class MyAlbertSOPHead(torch.nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()  

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size , config.num_labels)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        dropout_pooled_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_pooled_output)
        return logits

In [26]:
class BertPretrainingHeadsWithSOP(BertPreTrainingHeads):
    def __init__(self, config):
        super().__init__(config)
        self.seq_relationship = MyAlbertSOPHead(config)

In [27]:
# 修改本來的 BertForPreTraining
class MyBertForPreTraining(BertForPreTraining):
    def __init__(self, config, nspTask = "NSP"):
        super().__init__(config)
        if nspTask == "SOP":
            self.cls = BertPretrainingHeadsWithSOP(config)
            
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MyBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MyBertForPreTrainingOutput]:
        r"""
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
                config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
                the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
            next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
                pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
                - 0 indicates sequence B is a continuation of sequence A,
                - 1 indicates sequence B is a random sequence.
            kwargs (`Dict[str, any]`, optional, defaults to *{}*):
                Used to hide legacy arguments that have been deprecated.
        Returns:
        Example:
        ```python
        >>> from transformers import AutoTokenizer, BertForPreTraining
        >>> import torch
        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        >>> model = BertForPreTraining.from_pretrained("bert-base-uncased")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return MyBertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            mlm_loss = masked_lm_loss,
            nsp_loss = next_sentence_loss,
        )

# 取出資料集

In [28]:
class getData():
    def __init__(self, modelType, datapath, nspTask = "NSP"):
        self.datapath = datapath
        self.tokenizer = BertTokenizer.from_pretrained(modelType)
        self.nspTask = nspTask
        self.df = pd.read_pickle(self.datapath)
        self.sentence_a = []
        self.sentence_b = []
        self.label = []
        self.important_label = []
        self.inputs = None
        self.nspPrepare()
        self.inputs['labels'] = self.inputs.input_ids.detach().clone()
    
    def nspPrepare(self):
        if self.nspTask == "NSP":
            self.nspData()
        # elif self.nspTask == "SOP":
        #     self.sopData()

        self.inputs = self.tokenizer(self.sentence_a, self.sentence_b, return_tensors='pt',
                   max_length=512, truncation=True, padding='max_length')
        self.inputs['next_sentence_label'] = torch.LongTensor([self.label]).T

        fixed_len = 512
        self.important_label = [sublist + [0] * (fixed_len - len(sublist)) for sublist in self.important_label]
        self.important_label = torch.LongTensor(self.important_label)
        # self.inputs['important_label'] = self.important_label
        
        mask_important = torch.full(self.important_label.shape, False)
        mask_important[self.important_label >= 2] = True
        self.inputs['mask_important'] = mask_important
    
    def nspData(self):  
        text_idx = 0
        sen_idx = 0
        text_num = len(self.df)
        while text_idx < text_num:
            num_sentences = len(self.df.iloc[text_idx, 2])
            if num_sentences > 1:
                start = random.randint(0, num_sentences-2)
                # 50/50 whether is IsNextSentence or NotNextSentence
                if random.random() >= 0.5:
                    # this is IsNextSentence
                    self.sentence_a.append(self.df.iloc[text_idx, 2][start])
                    self.sentence_b.append(self.df.iloc[text_idx, 2][start+1])
                    self.label.append(0)
                    combine = self.df.iloc[text_idx, 3][start] + self.df.iloc[text_idx, 3][start+1][1:]
                    combine_len = len(combine)
                    if combine_len <= 512:
                        self.important_label.append(combine)
                    else:
                        a_cpy = copy.deepcopy(self.df.iloc[text_idx, 3][start])
                        b_cpy = copy.deepcopy(self.df.iloc[text_idx, 3][start+1])
                        len_a = len(a_cpy)
                        len_b = len(b_cpy)
                        while combine_len > 512:
                            if len_a >= len_b:
                                a_cpy.pop(-2)
                                len_a -= 1
                            else:
                                b_cpy.pop(-2)
                                len_b -= 1
                            combine_len -= 1
                        self.important_label.append(a_cpy + b_cpy[1:])
                else:
                    text_rand = text_idx
                    rand_sen = sen_idx
                    while (text_rand == text_idx) and (rand_sen in [sen_idx, sen_idx+1]):
                        text_rand = random.randint(0, text_num-1)
                        rand_sen = random.randint(0, len(self.df.iloc[text_rand, 2])-1)
                    # this is NotNextSentence
                    self.sentence_a.append(self.df.iloc[text_idx, 2][start])
                    self.sentence_b.append(self.df.iloc[text_rand, 2][rand_sen])
                    self.label.append(1)
                    combine = self.df.iloc[text_idx, 3][start] + self.df.iloc[text_rand, 3][rand_sen][1:]
                    combine_len = len(combine)
                    if combine_len <= 512:
                        self.important_label.append(combine)
                    else:
                        a_cpy = copy.deepcopy(self.df.iloc[text_idx, 3][start])
                        b_cpy = copy.deepcopy(self.df.iloc[text_rand, 3][rand_sen])
                        len_a = len(a_cpy)
                        len_b = len(b_cpy)
                        while combine_len > 512:
                            if len_a >= len_b:
                                a_cpy.pop(-2)
                                len_a -= 1
                            else:
                                b_cpy.pop(-2)
                                len_b -= 1
                            combine_len -= 1
                        self.important_label.append(a_cpy + b_cpy[1:])
            text_idx += 1
    
    # def sopData(self):
    #     for paragraph in text:
    #         sentences = [
    #             sentence for sentence in paragraph.split('.') if sentence != ''
    #         ]
    #         num_sentences = len(sentences)
    #         if num_sentences > 1:
    #             start = random.randint(0, num_sentences-2)
    #             # 50/50 whether is IsNextSentence or NotNextSentence
    #             if random.random() >= 0.5:
    #                 # this is IsNextSentence
    #                 self.sentence_a.append(sentences[start])
    #                 self.sentence_b.append(sentences[start+1])
    #                 self.label.append(0)
    #             else:
    #                 # this is NotNextSentence
    #                 self.sentence_a.append(sentences[start+1])
    #                 self.sentence_b.append(sentences[start])
    #                 self.label.append(1)
    
    def returnInput(self):
        return self.inputs

In [29]:
class OurDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)

In [30]:
class trainModel():
    def __init__(self, modelType, inputs, batch_size, epoch, acc_goal_each_epoch, masking_method = "propose", saveModelName = "", saveCSV = True, nspTask = "NSP"):
        self.model = MyBertForPreTraining.from_pretrained(modelType)
        self.tokenizer = BertTokenizer.from_pretrained(modelType)
        self.inputs = inputs
        self.batch_size = batch_size
        self.epoch = epoch
        self.acc_goal_each_epoch = acc_goal_each_epoch  # 每個 epoch 的 MLM 正確率基準
        self.masking_method = masking_method
        self.saveModelName = saveModelName
        self.saveCSV = saveCSV
        self.loader = torch.utils.data.DataLoader(OurDataset(self.inputs), \
                                             batch_size=self.batch_size, shuffle=True)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(self.device)
        self.model.train()
        self.optim = AdamW(self.model.parameters(), lr = 5e-5)
        self.last_acc = 0.0
        
        if os.path.isfile("record_maskPos.csv"):
            self.rec = pd.read_csv("record_maskPos.csv")
        else:
            self.rec = pd.DataFrame({"mlm_acc_each_epoch":[], "mlm_loss_each_epoch":[], 'Mask_Percent_each_epoch':[]})
            
        self.training()
        # self.save_model(self.saveModelName)
    
    # 把輸入序列對做 Masking
    #   mask_ori: 原始可被 Mask 的位置 (非特殊 token 的地方)
    #   mask_imp: 重要 token 的位置
    def mlmPrepare(self, input_sentences, maskPercentNow, mask_ori, mask_imp):
        # mask_arr: 表示本次要 Mask 的位置，True 的地方表示 Mask 
        mask_arr = torch.full(mask_ori.shape, False)    # 先初始化 (全部先填成 False)
        
        # 把輸入 batch 內的序列對依序處理
        for i in range(len(mask_ori)):
            num_to_mask = round(len(torch.where(mask_ori[i])[0]) * (maskPercentNow * 0.01)) # 表示該序列對有幾個 token 要 Mask
            imp_pos = torch.where(mask_imp[i])  # 表示 important tokens 的位置 (index)
            imp_len = len(imp_pos[0])           # 表示有多少個 important tokens

            # 判斷 important tokens 是否夠用
            #   如果不夠的話，拿 not important tokens 來補
            if num_to_mask <= imp_len:
                mask_index = torch.randperm(imp_len)[:num_to_mask]  # 從 imp_pos 中隨機提取 num_to_mask 個元素做為要被 Mask 的 index
                mask_arr[i, imp_pos[0][mask_index]] = True          # 更新 mask_arr，將位置在 mask_index 的元素改為 True，表示 "要做 Mask"
            else:
                mask_notImp = mask_ori[i] ^ mask_imp[i]             # mask_notImp 表示不重要 token 的位置
                # 先把 imp_pos Mask 掉
                mask_index = torch.randperm(imp_len)[:imp_len]
                num_to_mask -= imp_len
                mask_arr[i, imp_pos[0][mask_index]] = True

                # 剩下的 num_to_mask 由不重要的 tokens mask
                notImp_pos = torch.where(mask_notImp)
                notImp_pos_len = len(notImp_pos[0])
                new_index = torch.randperm(notImp_pos_len)[:num_to_mask]
                # 更新 mask_arr、mask_avai
                mask_arr[i, notImp_pos[0][new_index]] = True

        selection = []
        for i in range(input_sentences.shape[0]):
            selection.append(
                torch.flatten(mask_arr[i].nonzero()).tolist()
            )

        rand_mask_type = copy.deepcopy(selection)

        for row in range(len(rand_mask_type)):
            for col in range(len(rand_mask_type[row])):
                rand_mask_type[row][col] = random.random()

        vocab_size = len(self.tokenizer.vocab)
        vocab = self.tokenizer.get_vocab()
        special_tokens = [vocab['[CLS]'], vocab['[SEP]'], vocab['[MASK]'], vocab['[UNK]'],  vocab['[PAD]']]

        for i in range(input_sentences.shape[0]):
            for j in range(len(selection[i])):
                if rand_mask_type[i][j] < 0.10:
                    continue
                elif rand_mask_type[i][j] < 0.20:
                    rand_num = vocab['[CLS]']
                    while rand_num in special_tokens:
                        rand_num = random.randint(1, vocab_size-1)
                    input_sentences[i, selection[i][j]] = rand_num
                else:
                    input_sentences[i, selection[i][j]] = 103
        
        return input_sentences, mask_arr

    def training(self):
        acc_each_epoch = []
        loss_each_epoch = []
        Mask_Percent_each_epoch = []
        stay = 0
        percent_now = 6

        for epoch in range(self.epoch):
            # setup loop with TQDM and dataloader
            mask_nums = 0
            mlm_correct = 0
            nsp_nums = 0
            nsp_correct = 0
            loop = tqdm(self.loader, leave=True)

            for batch_index, batch in enumerate(loop):
                can_mask = (batch["input_ids"] != 101) * (batch["input_ids"] != 102) * (batch["input_ids"] != 0)

                input_sentences, mask_arr = self.mlmPrepare(batch["input_ids"].detach().clone(), percent_now, \
                                                            can_mask, batch["mask_important"])
                
                # initialize calculated gradients (from prev step)
                self.optim.zero_grad()
                # pull all tensor batches required for training
                input_ids = input_sentences.to(self.device)
                token_type_ids = batch['token_type_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                next_sentence_label = batch['next_sentence_label'].to(self.device)
                labels = batch['labels'].to(self.device)
                # process
                outputs = self.model(input_ids, attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                next_sentence_label=next_sentence_label,
                                labels=labels)
                
                prediction_logits = outputs.prediction_logits[mask_arr]
                predicted_ids = prediction_logits.argmax(-1)
                
                seq_relationship_logits = outputs.seq_relationship_logits
                predicted_labels = torch.argmax(seq_relationship_logits, dim=1)
                predicted_label = predicted_labels

                mask_nums += len(predicted_ids)
                mlm_correct += torch.eq(predicted_ids, labels[mask_arr]).sum().item()
                nsp_nums += len(predicted_label)
                nsp_correct += predicted_label.eq(torch.squeeze(next_sentence_label)).sum().item()
                
                # extract loss
                loss = outputs.loss
                mlm_loss = outputs.mlm_loss.item()
                nsp_loss = outputs.nsp_loss.item()
                mlm_acc = mlm_correct / mask_nums
                nsp_acc = nsp_correct / nsp_nums
                # calculate loss for every parameter that needs grad update
                loss.backward()
                # update parameters
                self.optim.step()
                # print relevant info to progress bar
                loop.set_description(f'Epoch {epoch}')
                loop.set_postfix(Total_loss='{:.4f}'.format(loss.item()), MLM_Accuracy='{:.4f}'.format(mlm_acc), NSP_Accuracy='{:.4f}'.format(nsp_acc), \
                                MLM_loss='{:.4f}'.format(mlm_loss), NSP_loss='{:.4f}'.format(nsp_loss), Mask_Percent=percent_now)
            
            acc_each_epoch.append(mlm_acc)
            loss_each_epoch.append(mlm_loss)
            Mask_Percent_each_epoch.append(percent_now)

            if self.masking_method == "DMLM":
                percent_now += 1
            elif self.masking_method == "propose":
                if (mlm_acc >= self.acc_goal_each_epoch[epoch] * 0.01) or stay >= 2:
                    stay = 0
                    percent_now = 6 + epoch + 1
                else:
                    stay += 1
            elif self.masking_method == "adaptive":
                if mlm_acc > self.last_acc:
                    percent_now += 1
                else:
                    percent_now -= 1
                self.last_acc = mlm_acc
            
            if epoch % 2 == 1:
                self.save_model(self.saveModelName + "_epoch" + str(epoch + 1))


        if self.saveCSV:
            
            new_rec = pd.concat([self.rec, pd.DataFrame(pd.DataFrame({'mlm_acc_each_epoch': [acc_each_epoch], 'mlm_loss_each_epoch': [loss_each_epoch], 'Mask_Percent_each_epoch': [Mask_Percent_each_epoch]}))], ignore_index=True)
            new_rec.to_csv("record_maskPos.csv", index = False)
        torch.cuda.empty_cache()
    
    def save_model(self, model_name):
        self.model.save_pretrained(model_name)

In [31]:
datapath = 'bbc-text-with-important-003.pkl'
modelType = 'bert-base-cased'
epoch = 10
batch_size = 6
nsp_input = getData(modelType = modelType, datapath = datapath, nspTask = "NSP")
epoch_acc = [33.7, 42.1, 44.2, 45.7, 47.3, 49.0, 50.6, 51.9, 53.8 , 55.6]

In [11]:
mask_dyn_grow1 = trainModel(modelType = modelType, inputs = nsp_input.returnInput(), batch_size = batch_size, epoch = epoch, acc_goal_each_epoch = epoch_acc, masking_method = "DMLM", saveModelName = "saved_model/saved_model_SelMask_DMLM")
mask_dyn_grow1 = None

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 0: 100%|██████████| 371/371 [05:53<00:00,  1.05it/s, MLM_Accuracy=0.2572, MLM_loss=0.0449, Mask_Percent=6, NSP_Accuracy=0.8252, NSP_loss=0.5870, Total_loss=0.6319]
Epoch 1: 100%|██████████| 371/371 [06:10<00:00,  1.00it/s, MLM_Accuracy=0.3765, MLM_loss=0.0457, Mask_Percent=7, NSP_Accuracy=0.9402, NSP_loss=0.7486, Total_loss=0.7943]
Epoch 2: 100%|██████████| 371/371 [06:16<00:00,  1.01s/it, MLM_Accuracy=0.3878, MLM_loss=0.0074, Mask_Percent=8, NSP_Accuracy=0.9591, NSP_loss=0.3208, Total_loss=0.3282]
Epoch 3: 100%|██████████| 371/371 [06:14<00:00,  1.01s/it, MLM_Accuracy=0.4009, MLM_loss=0.0286, Mask_Percent=9, NSP_Accuracy=0.9789, NSP_loss=0.0022, Total_loss=0.0307]
Epoch 4: 100%|██████████| 371/371 [05:05<00:00,  1.21it/s, MLM_Accuracy=0.4135, MLM_loss=0.0136, Mask_Percent=10, NSP_Accuracy=0.9816, NSP_loss=0.0666, Total_loss=0.0802]
Epoch 5: 100%|██████████| 371/371 [03:27<00:00,  1.78it/s, MLM_Accuracy

In [None]:
mask_dyn = trainModel(modelType = modelType, inputs = nsp_input.returnInput(), batch_size = batch_size, epoch = epoch, acc_goal_each_epoch = epoch_acc, saveModelName = "saved_model/saved_model_SelMask_propose")
mask_dyn = None