In [58]:
from torch.utils.data import Dataset, DataLoader
from setting import ENG_FRA_PATH
import collections
import torch
import re
from typing import List

## 通用文件读取函数

In [10]:
def read_data(path:str,encoding='utf-8'):
    """Read a text file and return its lines.

    :param path:     location of the file to read
    :param encoding: text encoding used to decode the file
    :return:         list of lines, each still carrying its line terminator
    """
    # Read-only mode: the file is never written, and 'r+' would needlessly
    # require write permission (failing on read-only datasets).
    with open(path,'r',encoding=encoding) as file:
        return file.readlines()

读取翻译数据对

In [12]:
# Read every line of the file at ENG_FRA_PATH into a list (line terminators included).
lines = read_data(ENG_FRA_PATH)

处理特殊字符的正则表达式

In [23]:
# Matches a newline, a narrow no-break space (U+202F) or a no-break space (U+00A0).
sub_re = re.compile('[\n\u202f\xa0]')

处理句子的函数

In [40]:
def no_space(char, prev_char):
    """Tell whether a space has to be inserted in front of ``char``.

    :param char:      the character being examined
    :param prev_char: the character immediately before it
    :return:          True when ``char`` is one of , . ! ? and ``prev_char`` is not a space
    """
    if char not in {',', '.', '!', '?'}:
        return False
    return prev_char != ' '
def process_line(line):
    """Normalise one raw line of the corpus.

    Every character matched by ``sub_re`` becomes a plain space, the text is
    lower-cased and stripped, and a space is put in front of each , . ! ? that
    does not already follow one.

    :param line: raw line of text
    :return:     the normalised line
    """
    cleaned = sub_re.sub(' ', line).lower().strip()
    pieces = []
    for pos, ch in enumerate(cleaned):
        # The first character has no predecessor, so it is never prefixed.
        if pos > 0 and no_space(ch, cleaned[pos - 1]):
            pieces.append(' ' + ch)
        else:
            pieces.append(ch)
    return ''.join(pieces)

进行转化

In [43]:
# Normalise every line; `lines` is rebound to the cleaned list.
lines = [process_line(raw_line) for raw_line in lines]

# 分词

In [92]:
def tokenize(lines:List[str]):
    """Split tab-separated sentence pairs into whitespace tokens.

    The first tab-separated field of a line is the source sentence and the
    second is the target sentence; any further fields (e.g. an attribution
    column) are ignored. Blank lines are skipped.

    :param lines: sentence-pair lines in ``source<TAB>target`` form
    :return:      ``(sources, targets)``, two parallel lists of token lists
    :raises ValueError: if a non-blank line contains no tab separator
    """
    sources,targets=[],[]
    for idx,line in enumerate(lines,0):
        if not line.strip():
            # Blank line (e.g. a trailing newline at the end of the file): no pair to extract.
            continue
        fields = line.split('\t')
        if len(fields) < 2:
            raise ValueError(f'line {idx} is not a tab-separated sentence pair: {line!r}')
        sources.append(fields[0].split())
        targets.append(fields[1].split())
    return sources,targets

In [77]:
# Parallel lists: sources[i] and targets[i] hold the tokens of the i-th sentence pair.
sources,targets = tokenize(lines)

# 词典

In [64]:
class Vocab:
    """Two-way mapping between tokens and integer indices.

    Index 0 is always ``'<unk>'``; the reserved tokens follow in the order
    given, then the corpus tokens from most to least frequent.
    """

    def __init__(self,tokens:List[List[str]],min_freq=0,reserved_tokens:List[str]=None):
        """
        :param tokens:          tokenized corpus, one inner list per sentence
        :param min_freq:        corpus tokens seen fewer times than this get no index
                                (looking them up yields the ``'<unk>'`` index)
        :param reserved_tokens: special tokens indexed right after ``'<unk>'``
        """
        # Frequency table, most frequent token first.
        counter = Vocab.count_corpus(tokens)
        self._token_freq = sorted(counter.items(),key=lambda x:x[1],reverse=True)
        # Lookup tables. Every token is registered at most once, so a repeated
        # reserved token (or a reserved '<unk>') cannot create two indices for
        # the same token and '<unk>' always stays at index 0.
        self.idx2token = []
        self.token2idx = {}
        for token in ['<unk>'] + list(reserved_tokens or []):
            self._register(token)
        # _token_freq is sorted in descending order, so the first token below
        # the threshold ends the scan.
        for token,freq in self._token_freq:
            if freq < min_freq:
                break
            self._register(token)

    def _register(self,token):
        """Give ``token`` the next free index unless it already has one."""
        if token not in self.token2idx:
            self.token2idx[token] = len(self.idx2token)
            self.idx2token.append(token)

    def __len__(self):
        """Number of distinct tokens in the vocabulary."""
        return len(self.idx2token)

    def __getitem__(self, tokens):
        """Map a token, or an arbitrarily nested list/tuple of tokens, to indices.

        Tokens without an index map to ``self.unk``.
        """
        if not isinstance(tokens,(list,tuple)):
            return self.token2idx.get(tokens,self.unk)
        else:
            return [self.__getitem__(token) for token in tokens]

    def to_tokens(self,indices):
        """Map a single index, or a flat list/tuple of indices, back to tokens."""
        if not isinstance(indices,(list,tuple)):
            return self.idx2token[indices]
        else:
            return [self.idx2token[idx] for idx in indices]

    @property
    def unk(self):
        """Index of the unknown token."""
        return 0

    @property
    def token_freq(self):
        """``(token, count)`` pairs sorted by descending count."""
        return self._token_freq

    @classmethod
    def count_corpus(cls,tokens):
        """Count token occurrences in a corpus given as a sequence of sentences.

        :param tokens: sequence whose elements are lists/tuples of tokens
        :return:       ``collections.Counter`` of token frequencies
        :raises TypeError: if the first element is not a list or tuple
        """
        if len(tokens) == 0:
            # An empty corpus simply has no tokens to count.
            return collections.Counter()
        if not isinstance(tokens[0],(tuple,list)):
            raise TypeError('tokens must be a sequence of token lists, one per sentence')
        return collections.Counter(token for line in tokens for token in line)

In [78]:
# One vocabulary per language: min_freq is 2 and both reserve the
# padding, begin-of-sequence and end-of-sequence markers.
src_vocab = Vocab(sources,2,['<pad>','<bos>','<eos>'])
tar_vocab = Vocab(targets,2,['<pad>','<bos>','<eos>'])

最长的句子(`max_seq_len` 返回最长的句子本身,长度需再用 `len` 取得)

In [76]:
def max_seq_len(sentences):
    """Return the longest sentence (the sentence itself, not its length).

    When several sentences share the greatest length, the first one wins.

    :param sentences: list or tuple of sized items
    :raises TypeError: if ``sentences`` is neither a list nor a tuple
    """
    if isinstance(sentences, (list, tuple)):
        return max(sentences, key=len)
    raise TypeError

In [79]:
# Display the longest source sentence, re-joined with single spaces.
' '.join(max_seq_len(sources))

'me , too .'

# 语料库

In [138]:
class EngFra(Dataset):
    """Sentence pairs loaded from a tab-separated file, exposed as padded index tensors.

    Item ``i`` is ``((src_ids, src_valid_len), (tar_ids, tar_valid_len))``.
    """

    def __init__(self, data_path:str, num_example=None, min_freq:int=2, reserved_tokens:List=None):
        """
        :param data_path:       file with one ``source<TAB>target`` pair per line
        :param num_example:     keep only this many leading lines; None keeps them all
        :param min_freq:        minimum frequency handed to both vocabularies
        :param reserved_tokens: extra special tokens, placed before '<pad>', '<bos>', '<eos>'
        """
        # Build a fresh list so the caller's list is never mutated
        # (re-using one list for several datasets must not make it grow).
        reserved_tokens = list(reserved_tokens or []) + ['<pad>', '<bos>', '<eos>']
        # Raw tokenized corpus
        self.sources,self.targets = self._load_data(data_path,num_example)
        # Vocabularies
        self.src_vocab = Vocab(self.sources,min_freq,reserved_tokens)
        self.tar_vocab = Vocab(self.targets,min_freq,reserved_tokens)
        # Longest sentence of each side, plus one
        self.src_seq_len = len(max_seq_len(self.sources)) + 1
        self.tar_seq_len = len(max_seq_len(self.targets)) + 1
        # Pre-compute the padded tensors once
        self.t_src,self.t_tar = self.process_seq(self.src_seq_len,self.tar_seq_len,'<pad>','<eos>',len(self.sources))

    def process_seq(self,src_seq_len,tar_seq_len,pad,eos,len):
        """Pad the first ``len`` sentence pairs and convert them to tensors.

        :param src_seq_len: ``seq_len`` passed to ``padding`` for source sentences
        :param tar_seq_len: ``seq_len`` passed to ``padding`` for target sentences
        :param pad:         padding token
        :param eos:         end-of-sequence token
        :param len:         number of leading sentence pairs to process
        :return:            ``(sources, targets)``, each a list of ``(ids, valid_len)`` tensor pairs
        """
        # The ``len`` parameter (name kept for API compatibility) hides the
        # builtin inside this method, so bind it to a clearer name right away.
        count = max(len, 0)
        sources = [self.padding(seq,src_seq_len,pad,eos) for seq in self.sources[:count]]
        targets = [self.padding(seq,tar_seq_len,pad,eos) for seq in self.targets[:count]]
        # Padded token lists -> index tensors
        sources = self.seq2tensor(sources,self.src_vocab)
        targets = self.seq2tensor(targets,self.tar_vocab)
        return sources,targets

    @classmethod
    def padding(cls,line:list,seq_len,padding_token:str,eos_token:str):
        """Append the end marker and pad ``line``.

        :return: ``(padded_tokens, valid_len)`` where ``valid_len`` is the number
                 of original tokens (the end marker is not counted)
        """
        # NOTE(review): for len(line) <= seq_len the result has seq_len + 1 tokens
        # (tokens + eos + padding); longer lines are not truncated. Left unchanged
        # because downstream tensor shapes depend on it -- confirm this is intended.
        return line + [eos_token]+ [padding_token]*(seq_len-len(line)),len(line)

    @classmethod
    def seq2tensor(cls,pairs:list,vocab:'Vocab'):
        """Turn ``(tokens, valid_len)`` pairs into ``(LongTensor ids, LongTensor valid_len)`` pairs.

        :param pairs: iterable of ``(tokens, valid_len)``
        :param vocab: mapping from token to index, queried as ``vocab[token]``
        """
        res = []
        for seq,valid_len in pairs:
            seq = torch.tensor([vocab[token] for token in seq],dtype=torch.long)
            valid_len = torch.tensor(valid_len,dtype=torch.long)
            res.append((seq,valid_len))
        return res

    @classmethod
    def _load_data(cls,data_path,num_example=None):
        """Read, clean and tokenize the first ``num_example`` lines of ``data_path``.

        :return: ``(sources, targets)`` as produced by ``tokenize``
        """
        raw_lines = read_data(data_path,encoding='utf-8')
        # Slicing with None keeps every line; slicing first means only the
        # lines that are actually kept get cleaned.
        return tokenize([process_line(line) for line in raw_lines[:num_example]])

    def __getitem__(self, item):
        """Return ``((src_ids, src_valid_len), (tar_ids, tar_valid_len))`` for index ``item``."""
        return self.t_src[item],self.t_tar[item]

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.sources)

In [165]:
# Build the dataset from the first 600 lines of the file; min_freq of 2 is passed to both vocabularies.
data = EngFra(data_path=ENG_FRA_PATH,num_example=600,min_freq=2)

In [166]:
# Serve the dataset in batches of two sentence pairs.
dataloader = DataLoader(dataset=data,batch_size=2)

In [168]:
# Pull out just the first batch: the loop exits right after `pair_data` is bound.
for idx,pair_data in enumerate(dataloader,0):
    break

In [169]:
# Split the batch into its source part and its target part,
# mirroring the (source, target) pair returned by EngFra.__getitem__.
src,tar = pair_data

In [170]:
# The source part carries the token-index sequences and their valid lengths.
src_seq,src_len = src

In [171]:
# Map the indices of the first source sequence in the batch back to tokens.
[data.src_vocab.idx2token[idx] for idx in src_seq[0].detach()]

['go', '.', '<eos>', '<pad>', '<pad>', '<pad>']