In [None]:
# Mount Google Drive at /content/gdrive (Colab only). The dataset and
# experiment directories used later in this notebook live under
# "/content/gdrive/My Drive/...".
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
!pip install transformers==2.9.1

In [None]:
import re
import csv
import os
import jieba
import logging
import argparse
import random
from tqdm import tqdm, trange

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers.tokenization_bert import BertTokenizer
from transformers.modeling_bert import BertForMaskedLM, BertOnlyMLMHead
from transformers import pipeline

from transformers import AdamW

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('device:', device)

In [None]:
def get_task_processor(task, data_dir):
    """Return a `TSVDataProcessor` configured for `task`.

    Supported tasks are ``'tnews'`` and ``'chnsenticorp'``. Both are read
    without skipping a header row, with the label in column 0 and the text in
    column 1.

    Args:
        task: task name.
        data_dir: directory holding the task's ``train.tsv``, ``dev.tsv`` and
            ``test.tsv``.

    Raises:
        ValueError: if `task` is not a supported task.
    """
    # Per-task TSV layout; support a new dataset by adding an entry here.
    task_layouts = {
        'tnews': dict(skip_header=False, label_col=0, text_col=1),
        'chnsenticorp': dict(skip_header=False, label_col=0, text_col=1),
    }
    layout = task_layouts.get(task)
    if layout is None:
        raise ValueError('Unknown task')
    return TSVDataProcessor(data_dir=data_dir, **layout)


def get_data(task, data_dir):
    """Load every split of `task` and print how many examples each one has.

    Args:
        task: task name understood by `get_task_processor`.
        data_dir: directory holding the task's TSV splits.

    Returns:
        ``(examples, labels)``: ``examples`` maps ``'train'``, ``'dev'`` and
        ``'test'`` to the processor's example lists, and ``labels`` is the
        processor's label list for `task`.
    """
    processor = get_task_processor(task, data_dir)
    split_loaders = (
        ('train', processor.get_train_examples),
        ('dev', processor.get_dev_examples),
        ('test', processor.get_test_examples),
    )
    examples = {split: load() for split, load in split_loaders}
    for split, split_examples in examples.items():
        print('#{}: {}'.format(split, len(split_examples)))
    return examples, processor.get_labels(task)


class InputExample:
    """One raw example for single-sentence classification.

    Args:
        guid: identifier string, e.g. ``"train-3"``.
        text: the example text.
        label: the class label, or ``None`` when the example is unlabeled.
    """

    def __init__(self, guid, text, label=None):
        self.guid, self.text, self.label = guid, text, label


class InputFeatures(object):
    """Encoded features of one classification example.

    Indexing an instance (``features[k]``, including negative indices and
    slices) reads the fields in the order input_ids, input_mask, segment_ids,
    label_id. Note: a later cell redefines ``InputFeatures`` with a different
    constructor, so this version is only in effect until that cell runs.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id

    def __getitem__(self, item):
        ordered_fields = [
            self.input_ids,
            self.input_mask,
            self.segment_ids,
            self.label_id,
        ]
        return ordered_fields[item]


class DatasetProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self, task_name):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None, encoding="utf-8"):
        """Read a tab-separated file into a list of rows.

        Args:
            input_file: path of the ``.tsv`` file.
            quotechar: quote character for the csv reader. With the default
                ``None`` (and no explicit ``quoting``) the csv module disables
                quoting, so quote characters are kept as ordinary text.
            encoding: text encoding of the file. UTF-8 is used explicitly
                instead of the platform locale so Chinese text decodes the same
                way on every machine.

        Returns:
            list of rows, each a list of field strings (a blank line yields an
            empty row ``[]``).
        """
        # newline="" is how the csv module expects its input files to be opened.
        with open(input_file, "r", encoding=encoding, newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))


class TSVDataProcessor(DatasetProcessor):
    """Processor for a tab-separated dataset to be augmented.

    ``data_dir`` must contain ``train.tsv``, ``dev.tsv`` and ``test.tsv``. Every
    file is parsed with ``_read_tsv`` and turned into `InputExample`s by
    ``_create_examples``, so labels and examples are always split the same way.
    """

    def __init__(self, data_dir, skip_header, label_col, text_col):
        """
        Args:
            data_dir: directory containing the three ``.tsv`` splits.
            skip_header: if True, the first row of each file is ignored.
            label_col: zero-based column index of the label.
            text_col: zero-based column index of the text.
        """
        self.data_dir = data_dir
        self.skip_header = skip_header
        self.label_col = label_col
        self.text_col = text_col

    def get_train_examples(self):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(self.data_dir, "train.tsv")), "train")

    def get_dev_examples(self):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(self.data_dir, "dev.tsv")), "dev")

    def get_test_examples(self):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(self.data_dir, "test.tsv")), "test")

    def get_labels(self, task_name):
        """Return the sorted distinct labels of the training set.

        Labels are taken from the same parsed examples used for training. The
        header row (when ``skip_header``) and blank lines are therefore excluded,
        and a label in the last column does not carry a trailing newline.
        ``task_name`` is unused; it is kept for the base-class interface.
        """
        return sorted({example.label for example in self.get_train_examples()})

    def _create_examples(self, lines, set_type):
        """Turn parsed TSV rows into `InputExample`s with guid ``"<set_type>-<row index>"``."""
        examples = []
        for (i, line) in enumerate(lines):
            if self.skip_header and i == 0:
                continue
            if not line:
                # Blank line (e.g. at the end of the file): the csv reader yields
                # an empty row, which has no label/text columns to read.
                continue
            guid = "%s-%s" % (set_type, i)
            text = line[self.text_col]
            label = line[self.label_col]
            examples.append(
                InputExample(guid=guid, text=text, label=label))
        return examples


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


In [None]:
class InputFeatures(object):
    """Masked-LM training features for one example.

    Redefines the earlier classification-oriented ``InputFeatures``. Field
    meanings, as produced by ``convert_examples_to_features``:

    * ``init_ids`` - padded token ids before masking.
    * ``input_ids`` - padded token ids after masking.
    * ``input_mask`` - 1 for real tokens, 0 for padding.
    * ``masked_lm_labels`` - original id at masked positions, -100 elsewhere.
    * ``label_length`` - number of tokens of the prepended label.
    """

    def __init__(self, init_ids, input_ids, input_mask, masked_lm_labels, label_length):
        self.init_ids, self.input_ids = init_ids, input_ids
        self.input_mask = input_mask
        self.masked_lm_labels, self.label_length = masked_lm_labels, label_length


def convert_examples_to_features(examples, max_seq_length, tokenizer,
                                 masked_lm_prob=0.15, max_predictions_per_seq=20):
    """Build label-prepended, whole-word-masked MLM features (one per example).

    Each example is processed as follows:

    * The text is segmented into words with jieba.
    * The label is prepended (``"<label> <word> <word> ..."``).
    * The string is tokenized and wrapped as ``[CLS] ... [SEP]``.
    * About ``masked_lm_prob`` of the words after the label are replaced by
      ``[MASK]``: at least one, at most ``max_predictions_per_seq``, and never
      more than exist.

    A word is a token plus any following ``##`` continuation tokens, and all of
    its tokens are masked together. ``[CLS]``, ``[SEP]`` and the label tokens
    are never masked.

    Args:
        examples: iterable of `InputExample` with string ``text`` and ``label``.
        max_seq_length: length every id list is truncated/padded to.
        tokenizer: BERT tokenizer providing ``tokenize`` and
            ``convert_tokens_to_ids``.
        masked_lm_prob: fraction of candidate words to mask.
        max_predictions_per_seq: upper bound on the number of masked words.

    Returns:
        list of `InputFeatures`. Their ``masked_lm_labels`` hold the original
        token id at masked positions and -100 (the ignore index of the
        masked-LM loss) everywhere else.
    """
    features = []
    # Only [CLS] and [SEP] are added around tokens_a. The label tokens are
    # already part of tokens_a, so they must not be subtracted a second time.
    max_tokens_a = max_seq_length - 2

    for example in examples:
        words = list(jieba.cut(example.text, cut_all=False))
        modified_example = example.label + " " + ' '.join(words)
        label_len = len(tokenizer.tokenize(example.label))
        tokens_a = tokenizer.tokenize(modified_example, truncation=True, max_length=512)
        tokens_a = tokens_a[:max_tokens_a]

        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]

        # Candidate words for masking. Skip [CLS]/[SEP] and positions
        # 1..label_len (the prepended label). A "##" token extends the previous word.
        cand_indexes = []
        for i, token in enumerate(tokens):
            if token == "[CLS]" or token == "[SEP]" or i < label_len + 1:
                continue
            if cand_indexes and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        num_to_predict = min(max_predictions_per_seq,
                             max(1, int(round(len(cand_indexes) * masked_lm_prob))))
        # Without this cap, an example with no maskable words (e.g. empty text)
        # would make random.sample raise. Such a feature has no masked positions.
        num_to_predict = min(num_to_predict, len(cand_indexes))
        masked = random.sample(cand_indexes, num_to_predict)

        init_ids = tokenizer.convert_tokens_to_ids(tokens)          # before masking
        masked_lm_labels = [-100] * max_seq_length
        output_tokens = list(tokens)
        for word in masked:
            for i in word:
                masked_lm_labels[i] = init_ids[i]
                output_tokens[i] = "[MASK]"
        input_ids = tokenizer.convert_tokens_to_ids(output_tokens)  # after masking

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        init_ids = init_ids + padding
        input_ids = input_ids + padding
        input_mask = input_mask + padding

        assert len(init_ids) == max_seq_length
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length

        features.append(
            InputFeatures(init_ids=init_ids,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          masked_lm_labels=masked_lm_labels,
                          label_length=label_len))
    return features

In [None]:
def prepare_data(features):
    """Stack per-example features into a `TensorDataset` of int64 tensors.

    Tensor order (as unpacked by the training and evaluation loops):
    init_ids, input_ids, input_mask, masked_lm_labels, label_length.
    """
    field_order = ("init_ids", "input_ids", "input_mask",
                   "masked_lm_labels", "label_length")
    columns = []
    for field in field_order:
        values = [getattr(feature, field) for feature in features]
        columns.append(torch.tensor(values, dtype=torch.long))
    return TensorDataset(*columns)

In [None]:
def compute_dev_loss(model, dev_dataloader, device=None):
    """Return the mean masked-LM loss of `model` over `dev_dataloader`.

    Each batch's loss (as returned by the model) is weighted by the number of
    examples actually in that batch, so a smaller final batch is not
    over-counted. Runs under ``torch.no_grad()`` and leaves the model in eval
    mode.

    Args:
        model: masked-LM model returning the loss as its first output when
            called with ``input_ids``, ``attention_mask`` and ``masked_lm_labels``.
        dev_dataloader: yields 5-tuples laid out like ``prepare_data``'s dataset
            (init_ids, input_ids, input_mask, masked_lm_labels, label_length).
        device: where batches are moved. Defaults to the device of the model's
            parameters.

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    total_loss = 0.
    num_examples = 0
    with torch.no_grad():
        for batch in dev_dataloader:
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, masked_lm_labels, _ = batch
            outputs = model(input_ids=input_ids,
                            attention_mask=input_mask,
                            masked_lm_labels=masked_lm_labels)
            batch_size = input_ids.size(0)
            total_loss += outputs[0].item() * batch_size
            num_examples += batch_size
    if num_examples == 0:
        raise ValueError("dev_dataloader yielded no examples")
    return total_loss / num_examples

def _make_mlm_dataloader(examples, tokenizer, max_seq_length, batch_size, shuffle):
    """Featurize `examples` for the masked LM and wrap them in a DataLoader (random order iff `shuffle`)."""
    features = convert_examples_to_features(examples, max_seq_length, tokenizer)
    data = prepare_data(features)
    sampler = RandomSampler(data) if shuffle else SequentialSampler(data)
    return DataLoader(data, sampler=sampler, batch_size=batch_size)


def _make_optimizer(model, learning_rate):
    """Build AdamW with weight decay 0.01, except 0.0 for bias/LayerNorm-style parameters."""
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight']
    named_params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    return AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8)


def _save_checkpoint(model, tokenizer, save_dir):
    """Write the fine-tuned model and its tokenizer to `save_dir`, creating it if missing."""
    # save_pretrained writes into an existing directory.
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)


def train_pbert_and_augment(task_name, data_dir, output_dir, max_seq_length, train_batch_size, num_train_epochs, learning_rate, cache):
    """Fine-tune `BERT_MODEL` as a label-prepended masked LM on one task.

    Train and dev examples are featurized with `convert_examples_to_features`
    (label prepended, whole words masked). The dev loss is computed after every
    epoch. Whenever it improves, the fine-tuned model and tokenizer are saved
    to ``<output_dir>/model``. This function only fine-tunes and saves
    checkpoints; it does not generate augmented data itself. It relies on the
    notebook globals `BERT_MODEL`, `device` and `logger`.

    Args:
        task_name: task passed to `get_task_processor`.
        data_dir: directory with the task's TSV splits.
        output_dir: created if missing; checkpoints go to its ``model`` subdir.
        max_seq_length: padded sequence length of every feature.
        train_batch_size: batch size for both training and dev evaluation.
        num_train_epochs: number of epochs (converted with ``int``).
        learning_rate: AdamW learning rate.
        cache: ``cache_dir`` for the pretrained weights and vocabulary.
    """
    os.makedirs(output_dir, exist_ok=True)
    processor = get_task_processor(task_name, data_dir)

    # load train and dev data
    train_examples = processor.get_train_examples()
    dev_examples = processor.get_dev_examples()

    tokenizer = BertTokenizer.from_pretrained(BERT_MODEL,
                                              do_basic_tokenize=False,
                                              model_max_length=512,
                                              cache_dir=cache)
    model = BertForMaskedLM.from_pretrained(BERT_MODEL,
                                            cache_dir=cache)
    model.to(device)

    # Train data is shuffled; dev data is evaluated in file order.
    train_dataloader = _make_mlm_dataloader(train_examples, tokenizer, max_seq_length,
                                            train_batch_size, shuffle=True)
    dev_dataloader = _make_mlm_dataloader(dev_examples, tokenizer, max_seq_length,
                                          train_batch_size, shuffle=False)

    num_train_examples = len(train_dataloader.dataset)
    num_train_steps = int(num_train_examples / train_batch_size * num_train_epochs)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", num_train_examples)
    logger.info("  Batch size = %d", train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    optimizer = _make_optimizer(model, learning_rate)

    save_dir = os.path.join(output_dir, 'model')
    log_every = 3
    best_dev_loss = float('inf')
    for epoch in trange(int(num_train_epochs), desc="Epoch"):
        model.train()
        running_loss = 0.
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            _, input_ids, input_mask, masked_lm_labels, _ = batch
            outputs = model(input_ids=input_ids,
                            attention_mask=input_mask,
                            masked_lm_labels=masked_lm_labels)
            loss = outputs[0]
            loss.backward()
            optimizer.step()
            model.zero_grad()

            running_loss += loss.item()
            if (step + 1) % log_every == 0:
                print("avg_loss: {}".format(running_loss / log_every))
                # Reset only after reporting, so the printed value averages
                # the last `log_every` steps.
                running_loss = 0.

        # eval on dev after every epoch
        dev_loss = compute_dev_loss(model, dev_dataloader)
        print("Epoch {}, Dev loss {}".format(epoch, dev_loss))
        if dev_loss < best_dev_loss:
            best_dev_loss = dev_loss
            print("Saving model. Best dev so far {}".format(best_dev_loss))
            # Save the fine-tuned model/tokenizer (not a freshly built default pipeline).
            _save_checkpoint(model, tokenizer, save_dir)
    

In [None]:
# Pretrained checkpoint to fine-tune. Alternatives: 'bert-base-chinese',
# 'hfl/chinese-roberta-wwm-ext'.
BERT_MODEL = 'hfl/chinese-bert-wwm'

logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
logger = logging.getLogger(__name__)

In [None]:
# Hyper-parameters for label-prepended BERT fine-tuning.
task_name = "tnews"  # or 'chnsenticorp'
max_seq_length = 64
train_batch_size = 50
num_train_epochs = 5
learning_rate = 4e-5
cache = "transformers_cache"
# Per-task data lives in <data_root>/<task_name>/exp_<id>.
data_root = "/content/gdrive/My Drive/project"


for exp_id in range(1):
    data_dir = os.path.join(data_root, task_name, "exp_{}".format(exp_id))
    output_dir = data_dir
    save_model_path = os.path.join(output_dir, 'model')
    # exist_ok: the directory persists on Drive between runs, so re-running the
    # notebook must not fail with FileExistsError here.
    os.makedirs(save_model_path, exist_ok=True)
    train_pbert_and_augment(task_name, data_dir, output_dir, max_seq_length, train_batch_size, num_train_epochs, learning_rate, cache)
