In [None]:
# Colab-specific setup: mount Google Drive at /content/gdrive so that the
# experiment folders referenced further down are reachable from this runtime.
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
# Install the pinned transformers release that the import cell below relies on.
!pip install transformers==4.16

In [None]:
import os
import re
import csv
import jieba
import random
import numpy as np
from transformers import pipeline
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline

In [None]:
def read_file(file_p):
    """Read the UTF-8 text file at ``file_p`` and return its lines.

    Each returned entry has leading and trailing whitespace (including the
    line terminator) stripped; blank lines are kept as empty strings.
    """
    with open(file_p, encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]

In [None]:
def get_task_processor(task, data_dir):
    """
    Build the TSV processor for one of the supported tasks.

    The supported tasks are 'tnews' and 'chnsenticorp'. Both are read with the
    same layout: no header row, label in column 0, text in column 1.

    Args:
        task: Name of the task; must be 'tnews' or 'chnsenticorp'.
        data_dir: Directory holding the task's train/dev/test TSV files.

    Returns:
        A ``TSVDataProcessor`` configured for ``data_dir``.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    supported_tasks = ('tnews', 'chnsenticorp')
    if task not in supported_tasks:
        # Name the offending value so a typo in the task is easy to spot.
        raise ValueError('Unknown task: {!r} (expected one of: {})'.format(
            task, ', '.join(supported_tasks)))
    return TSVDataProcessor(data_dir=data_dir, skip_header=False, label_col=0, text_col=1)
    

def get_data(task, data_dir, data_seed=159):
    """Load the train/dev/test examples and the label list for ``task``.

    Seeds Python's ``random`` module with ``data_seed`` before loading, prints
    the size of every split, and returns ``(examples, labels)`` where
    ``examples`` maps 'train', 'dev' and 'test' to their example lists.
    """
    random.seed(data_seed)
    processor = get_task_processor(task, data_dir)

    # Dict literals evaluate in source order, so the splits are still read
    # train -> dev -> test.
    examples = {
        'train': processor.get_train_examples(),
        'dev': processor.get_dev_examples(),
        'test': processor.get_test_examples(),
    }

    for split_name in examples:
        print('#{}: {}'.format(split_name, len(examples[split_name])))

    labels = processor.get_labels(task)
    return examples, labels


class InputExample:
    """One labelled (or unlabelled) text example for sequence classification.

    Attributes are stored exactly as given: ``guid`` identifies the example,
    ``text`` is its content and ``label`` is optional (``None`` by default).
    """

    def __init__(self, guid, text, label=None):
        self.guid, self.text, self.label = guid, text, label


class InputFeatures:
    """Container for the feature fields of one example.

    Besides attribute access, the fields can be indexed positionally in the
    order ``input_ids``, ``input_mask``, ``segment_ids``, ``label_id``.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id

    def __getitem__(self, item):
        # A list (not a tuple) is indexed so that slices come back as lists.
        ordered_fields = [
            self.input_ids,
            self.input_mask,
            self.segment_ids,
            self.label_id,
        ]
        return ordered_fields[item]


class DatasetProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the ``get_*`` methods; ``_read_tsv`` is a shared
    helper for loading tab-separated files.
    """

    def get_train_examples(self):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self, task_name):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Args:
            input_file: Path of the TSV file to read.
            quotechar: Quote character handed to ``csv.reader``; the default
                ``None`` is passed through unchanged.

        Returns:
            A list with one entry per row, each a list of column strings.
        """
        # Decode explicitly as UTF-8 (matching read_file) instead of relying on
        # the platform default, and open with newline='' as the csv module
        # requires so the reader handles line endings itself.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            return list(reader)


class TSVDataProcessor(DatasetProcessor):
    """Processor for a TSV dataset to be augmented.

    Expects ``train.tsv``, ``dev.tsv`` and ``test.tsv`` inside ``data_dir``.

    Args:
        data_dir: Directory containing the three TSV files.
        skip_header: When true, the first row of every file is ignored.
        label_col: Index of the column holding the label.
        text_col: Index of the column holding the text.
    """

    def __init__(self, data_dir, skip_header, label_col, text_col):
        self.data_dir = data_dir
        self.skip_header = skip_header
        self.label_col = label_col
        self.text_col = text_col

    def get_train_examples(self):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(self.data_dir, "train.tsv")), "train")

    def get_dev_examples(self):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(self.data_dir, "dev.tsv")), "dev")

    def get_test_examples(self):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(self.data_dir, "test.tsv")), "test")

    def get_labels(self, task_name):
        """Return the sorted distinct labels found in ``train.tsv``.

        ``task_name`` is accepted for interface compatibility and is not used.
        The file is parsed with the same reader as the examples, so labels
        never carry a trailing line terminator, the header row is ignored when
        ``skip_header`` is set, and blank lines contribute no label.
        """
        rows = self._read_tsv(os.path.join(self.data_dir, "train.tsv"))
        if self.skip_header:
            rows = rows[1:]
        # csv.reader yields an empty list for a blank line; skip those rows.
        return sorted({row[self.label_col] for row in rows if row})

    def _create_examples(self, lines, set_type):
        """Turn parsed TSV rows into `InputExample`s.

        Args:
            lines: Rows as returned by ``_read_tsv``.
            set_type: Split name used as the guid prefix ("train", "dev", ...).

        Returns:
            A list of `InputExample`s whose guid is ``"<set_type>-<row index>"``.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if self.skip_header and i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            examples.append(InputExample(
                guid=guid,
                text=line[self.text_col],
                label=line[self.label_col]))
        return examples


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


In [None]:
class BertAugmentor(object):
    """Augment labelled sentences with a fill-mask pipeline.

    Each example is rendered as ``"<label> <text>"``; words are either masked
    out (replacement) or extra mask tokens are inserted between words
    (insertion), and the pipeline's predictions fill the masks.

    Args:
        model_dir: Model name or path handed to the 'fill-mask' pipeline.
        beam_size: Number of candidates requested per mask (``top_k``) and
            kept after every beam-search step.
    """

    def __init__(self, model_dir='bert-base-chinese', beam_size=3):
        self.beam_size = beam_size
        self.model = pipeline('fill-mask', model=model_dir, top_k=beam_size)
        self.mask_token = self.model.tokenizer.mask_token

    def _fill_first_mask(self, text, masks_left):
        """Return the candidate dicts for the first mask of one text.

        For a single string the pipeline returns a flat list of candidates
        when the text holds one mask and one list per mask when it holds
        several; only the first mask's list is used in the latter case.
        """
        predictions = self.model(text)
        return predictions[0] if masks_left > 1 else predictions

    def gen_sen(self, query, num_mask):
        """Fill ``num_mask`` masks in ``query`` left to right with beam search.

        Args:
            query: Text containing ``num_mask`` mask tokens.
            num_mask: Number of mask tokens still present in ``query``.

        Returns:
            Up to ``beam_size`` candidate dicts sorted by descending 'score',
            where 'score' is the product of the per-mask scores. Candidates
            produced after the first step also carry 'cur_score', the score
            of their last filled mask.
        """
        tops = self._fill_first_mask(query, num_mask)
        num_mask -= 1
        # `> 0` (rather than truthiness) so a non-positive count cannot loop forever.
        while num_mask > 0:
            cur_tops = []
            # Expand every beam on its own. Querying the pipeline with the whole
            # list and taking `[0]` would keep only the first beam's per-mask
            # lists whenever more than one mask is left.
            for parent in tops:
                pre_score = parent['score']
                for each in self._fill_first_mask(parent['sequence'], num_mask):
                    each['cur_score'] = each['score']
                    each['score'] = pre_score * each['score']
                    cur_tops.append(each)
            tops = sorted(cur_tops, key=lambda x: x['score'], reverse=True)[:self.beam_size]
            num_mask -= 1
        return tops

    def _segment_with_label(self, query):
        """Return ``[label, " ", word_1, ..., word_k]`` using jieba segmentation."""
        return [query.label, " "] + list(jieba.cut(query.text, cut_all=False))

    def _split_label_and_text(self, candidates, label_len):
        """Convert candidate dicts into ``[label, text]`` pairs.

        Spaces that are not adjacent to a printable ASCII character are
        removed from each sequence, then the first ``label_len`` characters
        are taken as the label and the remainder as the text.
        """
        aug_list = []
        for candidate in candidates:
            cleaned = re.sub("(?<![ -~]) (?![ -~])", '', candidate['sequence'])
            aug_list.append([cleaned[:label_len], cleaned[label_len:]])
        return aug_list

    def word_replacement(self, query, n):
        """Create up to ``n`` augmentations by masking and re-predicting words.

        Args:
            query: Example exposing ``.label`` and ``.text``.
            n: Number of words to sample and number of results to keep.

        Returns:
            A list of ``[label, text]`` pairs, best-scoring first. Empty when
            the text has fewer than two segmented words.
        """
        seg_list = self._segment_with_label(query)
        # Indices 0 and 1 hold the label and the separator. The first word
        # (index 2) is excluded as well, as in the original sampling.
        # NOTE(review): confirm that never masking the first word is intended.
        candidate_index = list(range(3, len(seg_list)))
        # randomly sample n index to replace
        replace_index = random.sample(candidate_index, min(n, len(candidate_index)))
        out_arr = []
        for cur_index in replace_index:
            new_query = seg_list.copy()
            # One mask token per character of the replaced word.
            word_len = len(new_query[cur_index])
            new_query[cur_index] = self.mask_token * word_len
            out_arr.extend(self.gen_sen(''.join(new_query), word_len))
        best = sorted(out_arr, key=lambda x: x['score'], reverse=True)[:n]
        return self._split_label_and_text(best, len(query.label))

    def word_insertion(self, query, n):
        """Create up to ``n`` augmentations by inserting predicted characters.

        Args:
            query: Example exposing ``.label`` and ``.text``.
            n: Number of insertion points to sample and results to keep.

        Returns:
            A list of ``[label, text]`` pairs, best-scoring first.
        """
        seg_list = self._segment_with_label(query)
        # Insertion points lie after the first word, up to and including the
        # end of the sentence.
        candidate_index = list(range(3, len(seg_list) + 1))
        # randomly sample n index to insert at
        insert_index = random.sample(candidate_index, min(n, len(candidate_index)))
        out_arr = []
        # randomly insert [MASK] between words
        for cur_index in insert_index:
            new_query = seg_list.copy()
            # randomly insert k mask tokens with 1 <= k <= 3
            insert_num = np.random.randint(1, 4)
            for _ in range(insert_num):
                new_query.insert(cur_index, self.mask_token)
            out_arr.extend(self.gen_sen(''.join(new_query), insert_num))
        best = sorted(out_arr, key=lambda x: x['score'], reverse=True)[:n]
        return self._split_label_and_text(best, len(query.label))

    def aug(self, query, num_aug=9):
        """Augment one example with both techniques.

        Replacement is asked for ``int(num_aug / 2)`` results and insertion
        for two more than that; the two result lists are concatenated.
        """
        num_new_per_technique = int(num_aug / 2)
        augmented_sentences = self.word_replacement(query, num_new_per_technique)
        augmented_sentences += self.word_insertion(query, num_new_per_technique + 2)
        return augmented_sentences

    def augment(self, example, num_aug=9, aug_train=True, out_dir=None):
        """Augment every example and write the results to a TSV file.

        Args:
            example: Iterable of examples exposing ``.label`` and ``.text``.
            num_aug: Forwarded to ``aug`` for every example.
            aug_train: Write ``train_aug.tsv`` when true, else ``dev_aug.tsv``.
            out_dir: Directory for the output file. When ``None`` the
                module-level ``data_dir`` variable is used, which is what the
                original implementation always read.
        """
        if out_dir is None:
            # Backward-compatible fallback to the notebook-global data_dir.
            out_dir = data_dir
        file_name = 'train_aug.tsv' if aug_train else 'dev_aug.tsv'
        # The context manager guarantees the rows are flushed and the handle is
        # closed; UTF-8 keeps the output independent of the platform default.
        with open(os.path.join(out_dir, file_name), 'w', encoding='utf-8', newline='') as out:
            out_writer = csv.writer(out, delimiter='\t')
            for query in example:
                for pair in self.aug(query, num_aug):
                    out_writer.writerow([pair[0], pair[1]])

In [None]:
# Driver: build train_aug.tsv for each of the ten tnews experiment folders
# (exp_0 .. exp_9) under the mounted Drive.
for exp_id in range(10):
    # NOTE: data_dir must stay a notebook-level variable; BertAugmentor.augment
    # reads it to decide where the augmented TSV is written.
    data_dir = os.path.join("/content/gdrive/My Drive/project/tnews", "exp_{}".format(exp_id))
    processor = get_task_processor('tnews', data_dir)
    train_examples = processor.get_train_examples()
    # model = BertAugmentor(model_dir = data_dir+'/model')
    # A fresh augmentor with the default model is created for every experiment.
    model = BertAugmentor()
    model.augment(train_examples)
  