In [1]:
import torch 
import time 
from sklearn.model_selection import KFold

import torch.nn as nn
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from gensim.models.phrases import Phrases
from gensim.models import Word2Vec

from src.train_utils import set_seed, ModelSave, get_torch_device, EarlyStop, TrainParams
from src.evaluation import classification_inference
from src.metric import  multi_cls_metrics,multi_cls_log
from src.dataset.tokenizer import GensimTokenizer

from iflytek_app.dataset import MixDataset
from iflytek_app.models import Textcnn
from iflytek_app.process import train_process, test_process, result_process

# Pick the torch device for this session via the project helper. The recorded
# cell output ("No GPU available, using the CPU instead.") shows this run fell
# back to the CPU.
device = get_torch_device()
# Project helper called with its default arguments -- presumably seeds the
# random number generators for reproducibility (TODO confirm in src.train_utils).
set_seed()

No GPU available, using the CPU instead.


In [2]:
# Wrap two saved gensim Word2Vec checkpoints in the project's GensimTokenizer.
# NOTE(review): the file names suggest a char-level model (c2v) and a
# phrase/word-level model (w2v) -- confirm against the training script.
c2v = GensimTokenizer( Word2Vec.load('./checkpoint/char_min1_win5_sg_d100'))
w2v = GensimTokenizer(Word2Vec.load('./checkpoint/phrase_min1_win5_sg_d100'))
# Initialise the vocabulary of each tokenizer (behaviour defined in
# src.dataset.tokenizer.GensimTokenizer, not visible here).
w2v.init_vocab()
c2v.init_vocab()
# Load the saved gensim Phrases model.
phraser = Phrases.load('./checkpoint/phrase_tokenizer')
# train_process() yields the training data and a label -> index mapping; the
# recorded output below appears to show a describe()-style summary and that
# mapping. test_process() yields the test data.
df, label2idx = train_process()
test = test_process()

                id           l1           l2          len
count  4199.000000  4199.000000  4199.000000  4199.000000
mean   2099.000000     8.969278    37.087878    46.057156
std    1212.291219     4.576621    79.204914    79.332999
min       0.000000     2.000000     1.000000     4.000000
25%    1049.500000     5.000000     6.000000    15.000000
50%    2099.000000     8.000000    12.000000    22.000000
75%    3148.500000    12.000000    26.000000    36.000000
max    4198.000000    32.000000   946.000000   961.000000
{'14784131 14858934 14784131 14845064': 0, '14852788 14717848 15639958 15632020': 1, '14844856 14724258 14925237 14854807': 2, '14925756 15639967 14853254 14728639': 3, '14844593 14924945': 4, '15709098 14716590 14924703 14779559': 5, '14726332 14728344 14854542 14844591': 6, '14858934 15636660 15704193 14849963': 7, '15710359 14847407 14845602 14859696': 8, '14794687 14782344': 9, '15630486 15702410 14718849 15709093': 10, '15632285 15706536 14721977 14925219': 11, '147829

In [3]:
# -*-coding:utf-8 -*-

from src.preprocess.str_utils import *
import random
from concurrent.futures import ThreadPoolExecutor


class Augmenter(object):
    """
    Base class for text augmenters.

    Action: Delete, Swap, Substitute
    Granularity: char, word, entity, sentence

    Subclasses implement :meth:`action`, which maps one input sample to one
    augmented sample (or to a falsy value when nothing was produced).

    :param min_sample: multiplier applied to ``len(data)``; augmentation passes
        are repeated (up to 3 times) until more than this many unique samples
        have been collected
    :param max_sample: multiplier applied to ``len(data)``; upper bound on the
        number of samples returned
    :param prob: stored on the instance for subclasses; not used by this class
    """

    def __init__(self, min_sample, max_sample, prob):
        self.min_sample = min_sample
        self.max_sample = max_sample
        self.prob = prob

    def action(self, text):
        """
        Core Augment action: turn one sample into one augmented sample.
        Must be provided by subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def _data_check(data):
        """Return True when ``data`` is a usable (non-empty, non-None) sample."""
        return bool(data)

    @staticmethod
    def _dedup_key(sample):
        """
        Return a hashable stand-in for ``sample``, used only to detect
        duplicates. Lists/tuples (e.g. token lists) are converted to tuples;
        any other unhashable value falls back to its ``repr``.
        """
        if isinstance(sample, (list, tuple)):
            return tuple(Augmenter._dedup_key(item) for item in sample)
        try:
            hash(sample)
        except TypeError:
            return repr(sample)
        return sample

    def _param_check(self):
        pass

    def augment(self, data, n_thread=4):
        """
        Run :meth:`action` over every sample and collect the unique results.

        :param data: iterable of input samples
        :param n_thread: number of worker threads used to run ``action``
        :return: list of unique augmented samples, in first-seen order, or a
            random subset of them when there are more than
            ``len(data) * max_sample``
        """
        data = list(data)
        if not data:
            return []
        min_output = len(data) * self.min_sample
        # random.sample needs a non-negative integer sample size.
        # By default the number of augmented samples is <= the number of originals.
        max_output = max(0, int(len(data) * self.max_sample))
        max_retry = 3
        # Keyed by a hashable form of the sample so that list outputs can be
        # de-duplicated; values keep the sample exactly as `action` returned it.
        unique = {}
        for _ in range(max_retry):
            with ThreadPoolExecutor(n_thread) as executor:
                for aug_data in executor.map(self.action, data):
                    if self._data_check(aug_data):
                        unique.setdefault(self._dedup_key(aug_data), aug_data)
            if len(unique) > min_output:
                break
        result = list(unique.values())
        if len(result) > max_output:
            # Sample from a list: random.sample no longer accepts a set.
            return random.sample(result, max_output)
        return result
        
        
class WordAugmenter(Augmenter):
    """
    Word-granularity augmenter base.

    Holds the tokenizer used to split text and the list of filter objects
    that decide which token positions must be left untouched.
    """

    def __init__(self, min_sample, max_sample, prob, tokenizer):
        super().__init__(min_sample, max_sample, prob)
        self.tokenizer = tokenizer
        # A token is excluded from augmentation when any of these handlers'
        # check() returns a truthy value for it.
        self.filters = [stop_word_handler, punctuation_handler, emoji_handler]

    def get_aug_index(self, tokens):
        """
        Return the set of positions in ``tokens`` that no filter rejects.

        :param tokens: sequence of tokens
        :return: set of integer indexes eligible for augmentation
        """
        return {
            position
            for position, token in enumerate(tokens)
            if not any(handler.check(token) for handler in self.filters)
        }


class W2vSynomous(WordAugmenter):
    """
    Substitute eligible words with one of their nearest neighbours in a
    word-vector model.

    :param topn: number of nearest neighbours to draw the replacement from
    """

    def __init__(self, min_sample, max_sample, prob, tokenizer, topn=10):
        super(W2vSynomous, self).__init__(min_sample, max_sample, prob, tokenizer)
        self.topn = topn

    def gen_synom(self, word):
        """
        With probability ``self.prob`` return a random one of the ``topn``
        most similar words to ``word``; otherwise (or when the lookup fails,
        e.g. for an out-of-vocabulary word) return None.
        """
        if random.random() >= self.prob:
            return None
        try:
            model = self.tokenizer.model
            # gensim >= 4 only exposes most_similar on the KeyedVectors
            # (`model.wv`); fall back to the object itself when it has no `wv`.
            vectors = getattr(model, 'wv', model)
            neighbours = vectors.most_similar(word, topn=self.topn)
            return random.choice(neighbours)[0]
        except Exception:
            # Best effort: a word that cannot be looked up is simply not replaced.
            return None

    def action(self, text):
        """
        Tokenize ``text`` and replace eligible tokens with a synonym.

        :return: the new token list, or None when no token was replaced
        """
        words = self.tokenizer.tokenize(text)
        # Tokens are not mutated below, so the eligible positions are computed once.
        aug_index = self.get_aug_index(words)
        new_sample = []
        changed = False
        for i, t in enumerate(words):
            synom = self.gen_synom(t) if i in aug_index else None
            if synom is None:
                # Not eligible, not selected, or lookup failed: keep the original token.
                new_sample.append(t)
            else:
                new_sample.append(synom)
                changed = True
        if changed:
            return new_sample
        else:
            return None


class WordnetSynomous(WordAugmenter):
    """
    Substitute eligible words with a synonym taken from a synonym-group file.

    :param wordnet_file: path of the synonym-group file (defaults to the
        previously hard-coded ``'word_net.text'``)
    """

    def __init__(self, min_sample, max_sample, prob, tokenizer, wordnet_file='word_net.text'):
        super(WordnetSynomous, self).__init__(min_sample, max_sample, prob, tokenizer)
        self.wordnet = self.load(wordnet_file)

    def load(self, file):
        """
        Build a ``word -> [other words of the same group]`` mapping.

        Each line is whitespace-separated: a group code followed by the words
        of the group. Only lines whose code ends with ``'='`` are used
        (presumably the marker for true synonym groups -- confirm against the
        data file). Words that have no other member in their group are skipped.
        """
        wordnet = {}
        with open(file, 'r') as f:
            for line in f:
                # split() (not split(" ")) so repeated spaces do not create empty "words".
                fields = line.split()
                if not fields or not fields[0].endswith('='):
                    continue
                for i in range(1, len(fields)):
                    synonyms = fields[1:i] + fields[(i + 1):]
                    if synonyms:
                        wordnet[fields[i]] = synonyms
        return wordnet

    def gen_synom(self, word):
        """
        Return a random synonym of ``word`` with probability ``self.prob``;
        return ``word`` unchanged when it has no synonyms or is not selected.
        """
        candidates = self.wordnet.get(word)
        if candidates and random.random() < self.prob:
            return random.choice(candidates)
        else:
            return word

    def action(self, text):
        """
        Tokenize ``text`` and replace eligible tokens with a synonym.

        :return: the new token list, or None when no token was replaced
        """
        tokens = self.tokenizer.tokenize(text)
        # Tokens are not mutated below, so the eligible positions are computed once.
        aug_index = self.get_aug_index(tokens)
        new_sample = []
        changed = False
        for i, t in enumerate(tokens):
            replacement = self.gen_synom(t) if i in aug_index else t
            if replacement != t:
                changed = True
            new_sample.append(replacement)
        if changed:
            return new_sample
        else:
            return None


class WordShuffle(WordAugmenter):
    """
    Randomly swap eligible tokens with a token further to the right.
    """

    def __init__(self, min_sample, max_sample, prob, tokenizer):
        super(WordShuffle, self).__init__(min_sample, max_sample, prob, tokenizer)

    def get_swap_pos(self, left, right):
        """
        Pick the position to swap with.

        With probability ``self.prob`` return a random index in
        ``[left, right]``; otherwise return ``left - 1`` (the caller's own
        position, i.e. no swap). An empty range (``left > right``) always
        yields ``left - 1``.
        """
        if left <= right and random.random() < self.prob:
            return random.randint(left, right)
        else:
            return left - 1

    def action(self, text):
        """
        Tokenize ``text`` and shuffle it by swapping tokens left to right.

        :return: the shuffled token list
        """
        # Copy so the in-place swaps never touch a list owned by the tokenizer.
        tokens = list(self.tokenizer.tokenize(text))
        # Bound swaps by the number of tokens, not the character length of `text`.
        last = len(tokens) - 1
        new_sample = []
        for i in range(len(tokens)):
            # Re-evaluated every iteration on purpose: earlier swaps may have
            # moved a different token into position i.
            if i in self.get_aug_index(tokens):
                pos = self.get_swap_pos(i + 1, last)
                tokens[i], tokens[pos] = tokens[pos], tokens[i]
            new_sample.append(tokens[i])
        return new_sample


class WordDelete(WordAugmenter):
    """
    Randomly drop eligible tokens, each with probability ``prob``.
    """

    def __init__(self, min_sample, max_sample, prob, tokenizer):
        super(WordDelete, self).__init__(min_sample, max_sample, prob, tokenizer)

    def action(self, text):
        """
        Tokenize ``text`` and delete eligible tokens at random.

        :return: list of the tokens that were kept (may be empty)
        """
        tokens = self.tokenizer.tokenize(text)
        # Tokens are not mutated, so compute the eligible positions once
        # instead of once per token.
        aug_index = self.get_aug_index(tokens)
        new_sample = []
        for i, t in enumerate(tokens):
            if i in aug_index and random.random() < self.prob:
                continue
            new_sample.append(t)
        return new_sample

## Random Delete 

## Random Insert