PyTorch text-data processing, code version 2: processing text-similarity data

The code here is based on https://github.com/DA-southampton/TextMatch/blob/master/SiaGRU/data.py; thanks to the original author. It is the data-loading module for a Chinese sentence-pair similarity dataset (LCQMC-style): each line of the input file holds two sentences and a 0/1 similarity label; the sentences are split into characters, mapped to vocabulary indices, padded or truncated to a fixed length, and wrapped in a PyTorch Dataset.

```python
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 12 15:30:14 2020

@author: zhaog
"""
import re
import gensim
import numpy as np
import pandas as pd
import torch
from hanziconv import HanziConv  # dasou: Chinese text library, converts Traditional Chinese to Simplified
from torch.utils.data import Dataset

class LCQMC_Dataset(Dataset):
    def __init__(self, LCQMC_file, vocab_file, max_char_len):
        # p: first sentences, h: second sentences, label: similarity tags
        p, h, self.label = load_sentences(LCQMC_file)
        word2idx, _, _ = load_vocab(vocab_file)
        # map characters to indices and pad/truncate every sequence to max_char_len
        self.p_list, self.p_lengths, self.h_list, self.h_lengths = word_index(p, h, word2idx, max_char_len)
        self.p_list = torch.from_numpy(self.p_list).type(torch.long)
        self.h_list = torch.from_numpy(self.h_list).type(torch.long)
        self.max_length = max_char_len

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return self.p_list[idx], self.p_lengths[idx], self.h_list[idx], self.h_lengths[idx], self.label[idx]
    
# Load the training data: read the sentence pairs and labels from a tab-separated file
## dasou: pandas is used here to pull out whole columns of the similarity data and process them together,
## instead of handling one sample (one line) at a time. A drawback of this approach is that a really large
## file cannot be read entirely into memory by pandas; for a few GB of data that is not a problem, though.
def load_sentences(file, data_size=None):
    df = pd.read_csv(file, sep='\t', header=None)  # dasou: sep/header chosen to match my data format
    p = map(get_word_list, df[0].values[0:data_size])  # each element of p looks like ['晚', '上', '尿', '多', '吃', '什', '么', '药']
    h = map(get_word_list, df[1].values[0:data_size])
    label = df[2].values[0:data_size]
    #p_c_index, h_c_index = word_index(p, h)
    return p, h, label

# word -> index: map each character to its vocabulary id, record the (truncated) lengths,
# then pad/truncate every sequence to max_char_len
def word_index(p_sentences, h_sentences, word2idx, max_char_len):
    p_list, p_length, h_list, h_length = [], [], [], []
    for p_sentence, h_sentence in zip(p_sentences, h_sentences):
        p = [word2idx[word] for word in p_sentence if word in word2idx]
        h = [word2idx[word] for word in h_sentence if word in word2idx]
        p_list.append(p)
        p_length.append(min(len(p), max_char_len))
        h_list.append(h)
        h_length.append(min(len(h), max_char_len))
    p_list = pad_sequences(p_list, maxlen=max_char_len)
    h_list = pad_sequences(h_list, maxlen=max_char_len)
    return p_list, p_length, h_list, h_length

# Load the vocabulary: one token per line; build word -> index and index -> word maps
def load_vocab(vocab_file):
    with open(vocab_file, encoding='utf-8') as f:
        vocab = [line.strip() for line in f]
    word2idx = {word: index for index, word in enumerate(vocab)}
    idx2word = {index: word for index, word in enumerate(vocab)}
    return word2idx, idx2word, vocab

''' Split a sentence into tokens: Chinese is split character by character, English words and digits
    are split on whitespace/punctuation, everything is lower-cased, and Traditional Chinese is
    converted to Simplified. '''
def get_word_list(query):
    query = HanziConv.toSimplified(query.strip())
    regEx = re.compile('[\\W]+')  # split on any run of characters that are not letters/digits/underscore
    res = re.compile(r'([\u4e00-\u9fa5])')  # [\u4e00-\u9fa5] is the Chinese character range; the capturing group keeps each character
    sentences = regEx.split(query.lower())
    str_list = []
    for sentence in sentences:
        # splitting on the capturing group yields each Chinese character as its own token;
        # fragments without Chinese characters come back unchanged
        str_list.extend(res.split(sentence))
    return [w for w in str_list if len(w.strip()) > 0]

def load_embeddings(embedding_path):
    model = gensim.models.KeyedVectors.load_word2vec_format(embedding_path, binary=False)
    # row 0 stays all zeros and serves as the padding vector (padding index 0)
    embedding_matrix = np.zeros((len(model.index2word) + 1, model.vector_size))
    # fill the embedding matrix word by word (model.index2word is the gensim < 4.0 API;
    # in gensim 4.x use model.index_to_key instead)
    for idx, word in enumerate(model.index2word):
        embedding_matrix[idx + 1] = model[word]
    return embedding_matrix

def pad_sequences(sequences, maxlen=None, dtype='int32', padding='post',
                  truncating='post', value=0.):
    """ pad_sequences
    Make all sequences the same length: if maxlen is given, every sequence becomes maxlen long;
    otherwise the length of the longest sequence is used. Both padding and truncating can be
    done in two ways, 'post' and 'pre': 'post' works from the tail of the sequence, 'pre' from
    the head; the default is 'post' for both.
    Arguments:
        sequences: list of sequences
        maxlen: int, maximum length
        dtype: dtype of the returned array
        padding: padding mode, 'pre' or 'post'
        truncating: truncating mode, 'pre' or 'post'
        value: float, value used for padding
    Returns:
        x: numpy array of padded sequences with shape (number_of_sequences, maxlen)
    """
    lengths = [len(s) for s in sequences]
    nb_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)
    x = (np.ones((nb_samples, maxlen)) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue  # empty list was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError("Truncating type '%s' not understood" % truncating)
        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError("Padding type '%s' not understood" % padding)
    return x
```
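
To make the character-level pipeline concrete, here is a small sketch that runs the example sentence from the comment in `load_sentences` through `get_word_list` and `pad_sequences`. The tiny `word2idx` below is made up purely for illustration; it is not the real vocabulary.

```python
sentence = "晚上尿多吃什么药"
chars = get_word_list(sentence)
# chars == ['晚', '上', '尿', '多', '吃', '什', '么', '药']

# a made-up mini vocabulary, for illustration only (index 0 is kept for padding)
word2idx = {c: i + 1 for i, c in enumerate("晚上尿多吃什么药")}
ids = [word2idx[c] for c in chars if c in word2idx]

padded = pad_sequences([ids], maxlen=10)
# padded == array([[1, 2, 3, 4, 5, 6, 7, 8, 0, 0]], dtype=int32)
```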
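
A minimal sketch of how the Dataset might be plugged into a DataLoader for training; the file paths and `max_char_len` here are placeholders, not part of the original code.

```python
from torch.utils.data import DataLoader

train_file = "data/LCQMC_train.txt"  # placeholder path; tab-separated: sentence1 \t sentence2 \t label
vocab_file = "data/vocab.txt"        # placeholder path; one character per line
max_char_len = 50                    # assumed maximum character length

train_set = LCQMC_Dataset(train_file, vocab_file, max_char_len)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

for p, p_len, h, h_len, label in train_loader:
    # p, h: LongTensor of shape (batch_size, max_char_len)
    # p_len, h_len: actual (truncated) lengths; label: similarity tag
    print(p.shape, h.shape, label.shape)
    break
```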
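
Finally, a sketch of how the matrix returned by `load_embeddings` could initialise an `nn.Embedding` layer. This assumes the vocabulary file lists its tokens in the same order as the word-vector file, so that `word2idx` indices line up with the matrix rows; the embedding file path is a placeholder.

```python
import torch
import torch.nn as nn

embedding_matrix = load_embeddings("data/char_vectors.txt")  # placeholder word2vec-format text file
embeddings = torch.from_numpy(embedding_matrix).float()

# index 0 is the all-zero row built in load_embeddings, so it can act as the padding index
embedding_layer = nn.Embedding.from_pretrained(embeddings, freeze=False, padding_idx=0)

# p is the (batch_size, max_char_len) LongTensor of character indices from the DataLoader;
# embedding_layer(p) has shape (batch_size, max_char_len, vector_size)
```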