# PyTorch text data processing, code version 1: processing text-similarity data

Compared with the version 2 code, the code below does not use gensim, and it processes the data one sample at a time, i.e. one line at a time; sentence1 and sentence2 are not split apart and handled separately.

I assembled the whole thing myself, so the code is fairly tidy.
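The functions below all read their settings from a `config_base` object that is passed in but never defined in this file. For reference, here is a minimal sketch of what such an object could look like; every attribute name is taken from the code that follows, while the paths and values are only illustrative placeholders, not the author's settings. The expected input files are tab-separated, one sample per line: `sentence1`, `sentence2`, `label`.

```python
import torch

# Hypothetical config object -- the attribute names match what the code below accesses,
# but the concrete paths and values here are just example settings.
class ConfigBase:
    train_path = './data/train.txt'   # each line: sentence1 \t sentence2 \t label
    dev_path = './data/dev.txt'
    test_path = './data/test.txt'
    w2i_path = './data/vocab.pkl'     # where the word2index dict is cached as a pickle
    min_freq = 1                      # drop characters that appear fewer times than this
    max_size = 10000                  # keep at most this many characters in the vocabulary
    UNK, PAD = '<UNK>', '<PAD>'       # special tokens appended to the vocabulary
    pad_size = 32                     # every sentence is padded/truncated to this length
    batch_size = 64
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

With that in mind, here is the full code: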

"""
@author: DASOU
@time: 20200726
"""
import torch
import os
import pickle as pkl

## Read the raw training data and build the word2index vocabulary
def get_word_voc(config_base):
    train_path=config_base.train_path
    file=open(train_path,'r',encoding='utf-8')
    lines=file.readlines()
    min_freq,max_size,UNK,PAD=config_base.min_freq,config_base.max_size,config_base.UNK,config_base.PAD
    vocab_dic={}
    for line in lines:
        try:
            line=line.strip().split('\t')
        except:
            print('The data format is not correct, please format it like the example data')
            exit()
        try:
            if len(line)==3:
                sen=line[0]+line[1]
                tokenizer = lambda x: [y for y in x]  ## character-level tokenizer
                for word in tokenizer(sen):
                    vocab_dic[word] = vocab_dic.get(word, 0) + 1 ## count each character's frequency so low-frequency ones can be filtered later
        except:
            print('The data format is not correct, please format it like the example data')
            exit()
    file.close()
    vocab_list = sorted([_ for _ in vocab_dic.items() if _[1] >= min_freq], key=lambda x: x[1], reverse=True)[:max_size]## drop low-frequency characters and keep at most max_size of them, sorted by frequency
    vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}## after filtering, build the word-to-index mapping in frequency order
    vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1}) ## append the indices for the unknown and pad tokens
    return vocab_dic
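# Illustrative shape of the returned dict (hypothetical characters and order): characters are
# indexed in descending frequency order, with the two special tokens appended at the end, e.g.
#   {'的': 0, '好': 1, ..., '<UNK>': N, '<PAD>': N + 1}  where N is the number of kept characters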


def load_data(cate,vocab_dic,config_base):
    ## Pick the file that matches the requested split
    if cate=='train':
        data_path=config_base.train_path
    elif cate=='dev':
        data_path = config_base.dev_path
    else:
        data_path = config_base.test_path
    file=open(data_path,'r',encoding='utf-8')
    contents=[]
    for line in file.readlines():
        words_line1=[]
        words_line2=[]
        line=line.strip().split('\t')
        sen1,sen2,label=line[0],line[1],line[2]
        tokenizer = lambda x: [y for y in x]  ## character-level tokenizer
        token_sen1=tokenizer(sen1)
        token_sen2 = tokenizer(sen2)
        sen1_len = len(token_sen1)
        sen2_len = len(token_sen2)

        ## Pad or truncate both sentences to pad_size
        if config_base.pad_size:
            if len(token_sen1) < config_base.pad_size:
                token_sen1.extend([config_base.PAD] * (config_base.pad_size - len(token_sen1)))
            else:
                token_sen1 = token_sen1[:config_base.pad_size]

            if len(token_sen2) < config_base.pad_size:
                token_sen2.extend([config_base.PAD] * (config_base.pad_size - len(token_sen2)))
            else:
                token_sen2 = token_sen2[:config_base.pad_size]
        ## Map characters to indices, falling back to UNK for out-of-vocabulary characters
        for word1 in token_sen1:
            words_line1.append(vocab_dic.get(word1, vocab_dic.get(config_base.UNK)))

        for word2 in token_sen2:
            words_line2.append(vocab_dic.get(word2, vocab_dic.get(config_base.UNK)))
        contents.append((words_line1,words_line2,int(label)))
    file.close()
    return contents
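# Illustrative element of the returned `contents` list, with a hypothetical pad_size of 5:
#   ([12, 7, 3, pad_idx, pad_idx],   # sentence1 as character indices, padded with PAD's index
#    [5, 9, 2, 14, 6],               # sentence2 as character indices
#    1)                              # the integer similarity label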

# Load the cached word2index mapping, or build and save it from the training data
def get_w2i(config_base):

    if not os.path.exists(config_base.w2i_path):
        print('There is no pre-built word2index; processing the data now to build it')
        vocab_dic = get_word_voc(config_base)
        pkl.dump(vocab_dic, open(config_base.w2i_path, 'wb'))
        vocab_size = len(vocab_dic)
    else:
        print('A pre-built word2index exists; loading it now')
        vocab_dic = pkl.load(open(config_base.w2i_path, 'rb'), encoding='utf-8')
        vocab_size = len(vocab_dic)
    return vocab_dic,vocab_size

class DatasetIterater():
    def __init__(self, batches, config_base):
        self.batch_size = config_base.batch_size
        self.batches = batches
        self.n_batches = len(batches) // config_base.batch_size
        self.residue = False  # whether there is a leftover partial batch at the end
        if len(batches) % self.batch_size != 0:
            self.residue = True
        self.index = 0
        self.device = config_base.device

    def _to_tensor(self, datas):
        ## Stack the index lists of a batch into LongTensors and move them to the target device
        x1 = torch.LongTensor([_[0] for _ in datas]).to(self.device)
        x2 = torch.LongTensor([_[1] for _ in datas]).to(self.device)
        y = torch.LongTensor([_[2] for _ in datas]).to(self.device)

        return (x1, x2), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            ## The last, smaller-than-batch_size batch
            batches = self.batches[self.index * self.batch_size: len(self.batches)]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

        elif self.index >= self.n_batches:
            self.index = 0
            raise StopIteration
        else:
            batches = self.batches[self.index * self.batch_size: (self.index + 1) * self.batch_size]
            self.index += 1
            batches = self._to_tensor(batches)
            return batches

    def __iter__(self):
        return self

    def __len__(self):
        if self.residue:
            return self.n_batches + 1
        else:
            return self.n_batches

def build_iterator(dataset,config_base):
    iter = DatasetIterater(dataset,config_base)
    return iter
```
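
Putting the pieces together, a typical call sequence looks like the sketch below. It uses only the functions defined in this file plus the hypothetical `ConfigBase` sketched earlier.

```python
# Minimal usage sketch, assuming the hypothetical ConfigBase from above.
config_base = ConfigBase()

vocab_dic, vocab_size = get_w2i(config_base)              # build or load the word2index mapping
train_data = load_data('train', vocab_dic, config_base)   # list of (ids1, ids2, label) tuples
train_iter = build_iterator(train_data, config_base)      # batches of ((x1, x2), y)

for (x1, x2), y in train_iter:
    # x1, x2: LongTensor of shape [batch_size, pad_size] (smaller for the last partial batch);
    # y: LongTensor of shape [batch_size]
    pass  # feed the pair of index tensors and the labels into a similarity model here
```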