In [None]:
# Colab only: make Google Drive available under /content/gdrive; the data_dir
# used further down lives below "/content/gdrive/My Drive/".
from google.colab import drive

drive.mount(mountpoint="/content/gdrive")

In [None]:
!pip install transformers==2.9.1

In [None]:
import random
import os
import csv
import torch
import jieba
import argparse

import random
from transformers import BertTokenizer
from transformers.modeling_bert import BertForSequenceClassification

from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import glue_convert_examples_to_features as convert_examples_to_features

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:', device)

In [None]:
def get_task_processor(task, data_dir):
    """Return the TSV processor configured for ``task``.

    Supported tasks are 'tnews' and 'chnsenticorp'. Both are read without a
    header row, with the label in column 0 and the text in column 1.

    Args:
        task: dataset name.
        data_dir: directory containing the dataset's TSV files.

    Raises:
        ValueError: if ``task`` is not a supported dataset name.
    """
    # (skip_header, label_col, text_col) for each supported task.
    layouts = {
        'tnews': (False, 0, 1),
        'chnsenticorp': (False, 0, 1),
    }
    if task not in layouts:
        raise ValueError('Unknown task')
    skip_header, label_col, text_col = layouts[task]
    return TSVDataProcessor(data_dir=data_dir, skip_header=skip_header,
                            label_col=label_col, text_col=text_col)


def get_data(task, aug_type, data_dir, aug_train=False, aug_dev=False, data_seed=159):
    """Load the train/dev/test examples and the label list for ``task``.

    Args:
        task: dataset name understood by ``get_task_processor``.
        aug_type: augmented training file to use when ``aug_train`` is true:
            'pbert', 'cbert' or 'eda'. Ignored when ``aug_train`` is false.
        data_dir: directory holding the dataset's TSV files.
        aug_train: load the augmented training split instead of the plain one.
        aug_dev: load the augmented dev split instead of the plain one.
        data_seed: value passed to ``random.seed`` before loading.

    Returns:
        ``(examples, labels)`` where ``examples`` maps 'train', 'dev' and
        'test' to the processor's example lists and ``labels`` is the
        processor's ``get_labels(task)`` result.

    Raises:
        ValueError: if ``aug_train`` is true and ``aug_type`` is not one of
            the supported augmentation names (checked before any file I/O).
    """
    # Processor method that yields the augmented training split per aug_type.
    aug_train_loaders = {
        'pbert': 'get_train_p_aug_examples',
        'cbert': 'get_train_c_aug_examples',
        'eda': 'get_train_e_aug_examples',
    }
    if aug_train and aug_type not in aug_train_loaders:
        raise ValueError('Unknown aug_type {!r}; expected one of {}'.format(
            aug_type, sorted(aug_train_loaders)))

    random.seed(data_seed)
    processor = get_task_processor(task, data_dir)

    train_loader = aug_train_loaders[aug_type] if aug_train else 'get_train_examples'
    # Dict literals evaluate left to right, so files are read train -> dev -> test.
    examples = {
        'train': getattr(processor, train_loader)(),
        'dev': processor.get_dev_aug_examples() if aug_dev else processor.get_dev_examples(),
        'test': processor.get_test_examples(),
    }

    for key, value in examples.items():
        print('#{}: {}'.format(key, len(value)))
    return examples, processor.get_labels(task)


class InputExample:
    """Raw text of one classification example.

    Attributes:
        guid: identifier of the example.
        text_a: the (first) text.
        text_b: optional second text; None for single-sentence tasks.
        label: optional label string; None when unknown.
    """

    def __init__(self, guid, text_a, text_b=None, label=None):
        # Plain attribute holder: nothing is validated or transformed here.
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label


class InputFeatures(object):
    """Model-ready features of one example.

    Indexing follows the order (input_ids, input_mask, segment_ids, label_id).
    Slices return a list, and the instance can be iterated or unpacked like a
    4-element sequence through the ``__getitem__`` protocol.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id

    def __getitem__(self, item):
        ordered_fields = list((self.input_ids, self.input_mask,
                               self.segment_ids, self.label_id))
        return ordered_fields[item]


class DatasetProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self, task_name):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a tab-separated file into a list of rows (each a list of str).

        The file is always decoded as UTF-8, so non-ASCII (e.g. Chinese) text
        is read the same way regardless of the platform's locale encoding.
        ``newline=""`` is what the csv module expects from its input file.
        With the default ``quotechar=None`` no quote processing is applied.
        Blank lines come back as empty lists.
        """
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            return list(reader)


class TSVDataProcessor(DatasetProcessor):
    """Loads tab-separated splits (and augmented variants) from ``data_dir``.

    Each row holds a label at ``label_col`` and a text at ``text_col``. The
    text is word-segmented with jieba and the words are re-joined with single
    spaces.
    """

    def __init__(self, data_dir, skip_header, label_col, text_col):
        self.data_dir = data_dir
        self.skip_header = skip_header  # drop the first row of every file
        self.label_col = label_col
        self.text_col = text_col

    def _load(self, filename, set_type):
        """Read ``filename`` from ``data_dir`` and convert its rows to examples."""
        return self._create_examples(
            self._read_tsv(os.path.join(self.data_dir, filename)), set_type)

    def get_train_examples(self):
        """See base class (reads train.tsv)."""
        return self._load("train.tsv", "train")

    def get_train_p_aug_examples(self):
        """Training examples from train_aug.tsv (get_data's aug_type='pbert')."""
        return self._load("train_aug.tsv", "train")

    def get_train_c_aug_examples(self):
        """Training examples from cbert_aug.tsv (get_data's aug_type='cbert')."""
        return self._load("cbert_aug.tsv", "train")

    def get_train_e_aug_examples(self):
        """Training examples from eda_aug.tsv (get_data's aug_type='eda')."""
        return self._load("eda_aug.tsv", "train")

    def get_dev_examples(self):
        """See base class (reads dev.tsv)."""
        return self._load("dev.tsv", "dev")

    def get_dev_aug_examples(self):
        """Dev examples from dev_aug.tsv (used by get_data when aug_dev=True)."""
        return self._load("dev_aug.tsv", "dev")

    def get_test_examples(self):
        """See base class (reads test.tsv)."""
        return self._load("test.tsv", "test")

    def get_labels(self, task_name):
        """Return the sorted distinct labels found in train.tsv.

        ``task_name`` is accepted for interface compatibility and is unused.
        """
        rows = self._read_tsv(os.path.join(self.data_dir, "train.tsv"))
        if self.skip_header:
            rows = rows[1:]
        # Parse with the same TSV reader as the examples so labels match them
        # exactly: no header row, no blank-line "labels", and no trailing
        # newline when label_col is the last column.
        return sorted({row[self.label_col] for row in rows if row})

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if self.skip_header and i == 0:
                continue
            if not line:  # blank line in the TSV file
                continue
            guid = "%s-%s" % (set_type, i)
            # jieba word segmentation (accurate mode), words joined by spaces.
            text_a = ' '.join(jieba.cut(line[self.text_col], cut_all=False))
            label = line[self.label_col]
            examples.append(
                InputExample(guid=guid, text_a=text_a, label=label))
        return examples


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

In [None]:
BERT_MODEL = 'hfl/chinese-roberta-wwm-ext'


class Classifier:
    """Fine-tunes and evaluates a BERT sequence classifier built from ``BERT_MODEL``.

    Typical use: ``load_data`` for each split, ``get_optimizer`` once, then
    ``train_epoch`` / ``evaluate`` per epoch and ``analysis`` for a per-class
    breakdown.
    """

    def __init__(self, label_list, device, cache_dir):
        """
        Args:
            label_list: ordered label names; a label's index is its class id.
            device: torch device that the model and every batch are moved to.
            cache_dir: download cache passed to ``from_pretrained``.
        """
        self._label_list = label_list
        self._device = device

        # do_basic_tokenize=False: WordPiece runs directly on whitespace-separated
        # tokens. Presumably the texts arrive already word-segmented with spaces;
        # TODO confirm against the data pipeline.
        self._tokenizer = BertTokenizer.from_pretrained(BERT_MODEL,
                                                        do_basic_tokenize=False,
                                                        do_lower_case=True,
                                                        cache_dir=cache_dir)

        self._model = BertForSequenceClassification.from_pretrained(BERT_MODEL,
                                                                    num_labels=len(label_list),
                                                                    cache_dir=cache_dir)
        self._model.to(device)

        # Both are created by get_optimizer(); train_epoch() needs them.
        self._optimizer = None
        self._scheduler = None

        self._dataset = {}
        self._data_loader = {}

    def load_data(self, set_type, examples, batch_size, max_length, shuffle):
        """Tokenize ``examples`` and store a DataLoader for them under ``set_type``."""
        self._dataset[set_type] = examples
        self._data_loader[set_type] = _make_data_loader(
            examples=examples,
            label_list=self._label_list,
            tokenizer=self._tokenizer,
            batch_size=batch_size,
            max_length=max_length,
            shuffle=shuffle)

    def get_optimizer(self, learning_rate, warmup_steps, t_total):
        """Create the AdamW optimizer and linear warm-up/decay LR scheduler.

        ``t_total`` is the total number of scheduler steps. ``train_epoch``
        steps the scheduler once per batch, so this should be batches-per-epoch
        times the number of epochs.
        """
        self._optimizer, self._scheduler = _get_optimizer(
            self._model, learning_rate=learning_rate,
            warmup_steps=warmup_steps, t_total=t_total)

    def _batch_to_inputs(self, batch):
        """Move a (input_ids, attention_mask, token_type_ids, labels) batch to the device as model kwargs."""
        moved = [t.to(self._device) for t in batch]
        return {'input_ids': moved[0],
                'attention_mask': moved[1],
                'token_type_ids': moved[2],
                'labels': moved[3]}

    def train_epoch(self):
        """Run one pass over the 'train' loader, stepping optimizer and scheduler per batch.

        Raises:
            RuntimeError: if ``get_optimizer`` has not been called.
        """
        if self._optimizer is None:
            raise RuntimeError('call get_optimizer() before train_epoch()')
        self._model.train()

        for batch in tqdm(self._data_loader['train'], desc='Training'):
            inputs = self._batch_to_inputs(batch)

            self._optimizer.zero_grad()
            outputs = self._model(**inputs)
            loss = outputs[0]  # first output is the loss when labels are passed
            loss.backward()
            self._optimizer.step()
            self._scheduler.step()

    def _predict(self, set_type):
        """Return (predicted class ids, gold class ids) as 1-D CPU tensors for ``set_type``.

        Raises:
            ValueError: if the loader for ``set_type`` yields no batches.
        """
        self._model.eval()

        preds_all, labels_all = [], []
        for batch in tqdm(self._data_loader[set_type],
                          desc="Evaluating {} set".format(set_type)):
            inputs = self._batch_to_inputs(batch)
            with torch.no_grad():
                logits = self._model(**inputs)[1]  # outputs are (loss, logits, ...)
            preds_all.append(torch.argmax(logits, dim=1).cpu())
            labels_all.append(inputs['labels'].cpu())

        if not preds_all:
            raise ValueError('no examples loaded for set {!r}'.format(set_type))
        return torch.cat(preds_all, dim=0), torch.cat(labels_all, dim=0)

    def evaluate(self, set_type):
        """Return the accuracy (fraction in [0, 1]) on the ``set_type`` split."""
        preds, labels = self._predict(set_type)
        return torch.sum(preds == labels).item() / labels.shape[0]

    def analysis(self, set_type):
        """Per-class accuracy on the ``set_type`` split.

        Returns:
            ``(header, rows)``: ``header`` names the columns. ``rows`` holds one
            ``(label, total, correct, accuracy)`` tuple per entry of the label
            list, as Python numbers. ``accuracy`` is nan for a class with no
            examples in the split.
        """
        preds, labels = self._predict(set_type)
        num_classes = len(self._label_list)

        # minlength keeps exactly one slot per class even when the highest class
        # ids never occur (or are never predicted correctly) in this split.
        totals = torch.bincount(labels, minlength=num_classes)
        correct = torch.bincount(labels[preds == labels], minlength=num_classes)
        accuracy = correct.double() / totals.double()  # 0/0 -> nan

        header = 'class, total number of entries, number of correctly classified entries, accuracy'
        return header, list(zip(self._label_list, totals.tolist(),
                                correct.tolist(), accuracy.tolist()))


def _get_optimizer(model, learning_rate, warmup_steps, t_total):
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)
    return optimizer, scheduler


def _make_data_loader(examples, label_list, tokenizer, batch_size, max_length, shuffle):
    features = convert_examples_to_features(examples,
                                            tokenizer,
                                            label_list=label_list,
                                            max_length=max_length,
                                            output_mode="classification")

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
import math

# ---- Reproducibility -------------------------------------------------------
# Seed Python's `random` and torch (CPU and all GPUs). torch's RNG drives the
# newly initialised classifier weights, dropout and DataLoader shuffling.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---- Hyper-parameters ------------------------------------------------------
epochs = 10
min_epochs = 0             # dev improvements before this epoch do not trigger a test run
learning_rate = 4e-5
warmup_steps = 100         # linear LR warm-up, in optimizer steps (batches)
batch_size = 15
max_seq_length = 64
hidden_dropout_prob = 0.1  # NOTE(review): not passed to the model anywhere in this cell
cache = "transformers_cache"

examples, label_list = get_data(
    task='tnews',
    aug_type='pbert',
    data_dir='/content/gdrive/My Drive/project/tnews/exp_0',
    aug_train=False,
    aug_dev=False,
    data_seed=SEED)

# The LR scheduler is stepped once per batch, so t_total must be the number of
# optimizer steps over *all* epochs; a smaller value drives the LR to zero
# before training is finished.
steps_per_epoch = math.ceil(len(examples['train']) / batch_size)
t_total = steps_per_epoch * epochs

classifier = Classifier(label_list=label_list, device=device, cache_dir=cache)
classifier.get_optimizer(learning_rate=learning_rate,
                         warmup_steps=warmup_steps,
                         t_total=t_total)

classifier.load_data(
    'train', examples['train'], batch_size, max_length=max_seq_length, shuffle=True)
# Order does not affect accuracy, so evaluation splits are read deterministically.
classifier.load_data(
    'dev', examples['dev'], batch_size, max_length=max_seq_length, shuffle=False)
classifier.load_data(
    'test', examples['test'], batch_size, max_length=max_seq_length, shuffle=False)

print('=' * 60, '\n', 'Training', '\n', '=' * 60, sep='')
best_dev_acc, final_test_acc = -1., -1.
for epoch in range(epochs):
    classifier.train_epoch()
    dev_acc = classifier.evaluate('dev')

    # Model selection: evaluate on test only when dev accuracy improves.
    if epoch >= min_epochs:
        do_test = (dev_acc > best_dev_acc)
        best_dev_acc = max(best_dev_acc, dev_acc)
    else:
        do_test = False

    print('Epoch {}, Dev Acc: {:.4f}, Best Ever: {:.4f}'.format(
        epoch, 100. * dev_acc, 100. * best_dev_acc))

    if do_test:
        final_test_acc = classifier.evaluate('test')
        print('Test Acc: {:.4f}'.format(100. * final_test_acc))

print('Final Dev Acc: {:.4f}, Final Test Acc: {:.4f}'.format(
    100. * best_dev_acc, 100. * final_test_acc))


In [None]:
# Per-class breakdown on the test split.
# NOTE: this uses the weights after the final epoch, which are not necessarily
# those of the best-dev epoch whose test accuracy was reported above.
header, per_class = classifier.analysis('test')
print(header)
for label, total, correct, acc in per_class:
    print('{}, {}, {}, {:.4f}'.format(label, int(total), int(correct), float(acc)))