# 模型训练之前，进行的batch操作
### 1、tensorflow版本的dataset

### 2、pytorch版本的dataset

In [1]:
import tensorflow as tf
import torch

### tokenizer/vocab

In [4]:
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

PAD_TOKEN = '[PAD]'
UNKNOWN_TOKEN = '[UNK]'
START_DECODING = '[START]'
STOP_DECODING = '[STOP]'


class Vocab:
    """Bidirectional word <-> id mapping loaded from a vocabulary file.

    Ids 0-3 are reserved for UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING and
    STOP_DECODING; words read from the file get consecutive ids from 4 on.
    """

    def __init__(self, vocab_path, max_size):
        """Load the vocabulary file.

        Args:
            vocab_path: UTF-8 text file expected to hold two whitespace-separated
                fields per line; only the first field (the word) is used. Lines
                with a different number of fields are reported and skipped.
            max_size: maximum total vocabulary size, the 4 special tokens
                included; no further file words are added once it is reached.

        Raises:
            Exception: if the file contains a special token or a duplicate word.
        """
        self.word2id = {UNKNOWN_TOKEN: 0, PAD_TOKEN: 1, START_DECODING: 2, STOP_DECODING: 3}
        self.id2word = {0: UNKNOWN_TOKEN, 1: PAD_TOKEN, 2: START_DECODING, 3: STOP_DECODING}
        self.count = 4
        reserved = {SENTENCE_START, SENTENCE_END, PAD_TOKEN, UNKNOWN_TOKEN,
                    START_DECODING, STOP_DECODING}

        with open(vocab_path, "r", encoding="utf-8") as f:
            for line in f:
                # Stop as soon as the vocabulary is full so size() never exceeds max_size.
                if self.count >= max_size:
                    break
                pair = line.split()
                if len(pair) != 2:
                    # Malformed (or empty) line: report it and skip it rather than index into it.
                    print("Warning:the error line is :{}".format(line.rstrip("\n")))
                    continue
                w = pair[0]
                if w in reserved:
                    raise Exception("word <%s> is exception" % w)
                if w in self.word2id:
                    raise Exception("word <%s> is repetition" % w)
                self.word2id[w] = self.count
                self.id2word[self.count] = w
                self.count += 1

    def word_to_id(self, word):
        """Return the id of ``word``, or the UNKNOWN_TOKEN id for out-of-vocabulary words."""
        return self.word2id.get(word, self.word2id[UNKNOWN_TOKEN])

    def id_to_word(self, w_id):
        """Return the word for ``w_id``.

        Raises:
            Exception: if ``w_id`` is not a known id.
        """
        if w_id not in self.id2word:
            raise Exception("id %s can not found" % w_id)
        return self.id2word[w_id]

    def size(self):
        """Return the number of entries in the vocabulary, special tokens included."""
        return self.count

### 1、tf版本1
#### 读取数据：
    * tf.data.TextLineDataset(txt_path)   #按行读取文本文件，每一行作为dataset中的一个字符串元素
    * tf.data.Dataset.zip((dataset_train_x, dataset_train_y))   #当作zip使用
    * train_dataset.shuffle(1000, reshuffle_each_iteration=True).repeat()   #shuffle+repeat()，buffer_size=1000
#### 将迭代器处理成dataset:
    * tf.data.Dataset.from_generator(generator,output_types={key:tf.int32....},output_shapes={key:shape...})   #generator必须是可调用对象，每次调用返回一个新的迭代器（如生成器）
    为什么还要再用这个？
    答：因为上面读取数据后，还要针对train、test、eval不同情况分别处理，处理后得到的是Python字典而不再是dataset，所以需要用from_generator重新包装成dataset。
#### 构建batch数据,并设置默认填充：
    * dataset.padded_batch(batch_size,padded_shapes={key:shape...},padding_values={key:value...},drop_remainder=True)
    drop_remainder：最后一个不足batch_size的batch是否丢弃
#### map操作：
    * 对dataset中的每个元素应用映射函数(map)，用于调整输出格式或做数据转换

In [5]:
class Batcher:
    """Build batched ``tf.data`` pipelines for the train / test / eval splits.

    Each line of an ``*_x`` file is a whitespace-tokenised source text and the
    line at the same position in the matching ``*_y`` file is its target text.
    Tokens are mapped to ids with ``vocab.word_to_id``.
    """

    def __init__(self, vocab, train_x_path, train_y_path, test_x_path, eval_x_path, eval_y_path,
                 max_enc_len, max_dec_len):
        """Store the vocabulary, the data file paths and the length limits.

        Args:
            vocab: object exposing ``word_to_id(word) -> int`` (e.g. ``Vocab``).
            train_x_path, train_y_path: source / target files used in "train" mode.
            test_x_path: source file used in "test" mode (no targets).
            eval_x_path, eval_y_path: source / target files used in any other mode.
            max_enc_len: encoder inputs are truncated and padded to this length.
            max_dec_len: decoder token sequences are truncated to this length.
        """
        self.vocab = vocab
        self.train_x_path = train_x_path
        self.train_y_path = train_y_path
        self.test_x_path = test_x_path
        self.eval_x_path = eval_x_path
        self.eval_y_path = eval_y_path
        self.max_enc_len = max_enc_len
        self.max_dec_len = max_dec_len

    def _encode_source(self, text, pad_id):
        """Turn one source line into a fixed-length list of ids.

        The text is split on whitespace, truncated to ``max_enc_len`` tokens,
        mapped to ids and right-padded with ``pad_id`` up to ``max_enc_len``.

        Returns:
            (ids, enc_len) where ``enc_len`` is the number of real (unpadded) tokens.
        """
        words = text.split()[:self.max_enc_len]
        enc_len = len(words)
        ids = [self.vocab.word_to_id(w) for w in words]
        # Pad after the id lookup so padded positions carry pad_id, not the UNK id.
        ids += [pad_id] * (self.max_enc_len - enc_len)
        return ids, enc_len

    def example_generator(self, mode, batch_size):
        """Yield one example dict per input line for ``mode``.

        Args:
            mode: "train" or "test"; any other value reads the eval files.
            batch_size: not used here; batching happens in ``batch_generator``.

        Yields:
            dict with keys enc_len, enc_input, dec_input, target, dec_len,
            article, abstract, abstract_sents. In "train" mode the data is
            shuffled and repeated without end; only "train" fills
            dec_input/target.
        """
        pad_id = self.vocab.word_to_id(PAD_TOKEN)
        if mode == "train":
            start_id = self.vocab.word_to_id(START_DECODING)
            stop_id = self.vocab.word_to_id(STOP_DECODING)
            dataset_train_x = tf.data.TextLineDataset(self.train_x_path)
            dataset_train_y = tf.data.TextLineDataset(self.train_y_path)
            data_train = tf.data.Dataset.zip((dataset_train_x, dataset_train_y))
            # Shuffle buffer of 1000 lines, reshuffled on every pass; repeat() with no count never ends.
            data_train = data_train.shuffle(1000, reshuffle_each_iteration=True).repeat()
            for x, y in data_train:
                x = x.numpy().decode("utf-8")
                y = y.numpy().decode("utf-8")
                enc_x, enc_len = self._encode_source(x, pad_id)
                dec_x = y.split()[:self.max_dec_len]
                dec_input, dec_outputs = self.get_dec_inp(dec_x, start_id, stop_id)
                yield {
                    "enc_len": enc_len,
                    "enc_input": enc_x,
                    "dec_input": dec_input,
                    "target": dec_outputs,
                    "dec_len": len(dec_x),
                    "article": x,
                    "abstract": y,
                    "abstract_sents": [""],
                }

        elif mode == "test":
            for line in tf.data.TextLineDataset(self.test_x_path):
                x = line.numpy().decode("utf-8")
                enc_x, enc_len = self._encode_source(x, pad_id)
                yield {
                    "enc_len": enc_len,
                    "enc_input": enc_x,
                    "dec_input": [],
                    "target": [],
                    "dec_len": self.max_dec_len,
                    "article": x,
                    "abstract": '',
                    "abstract_sents": [],
                }

        else:
            dataset_eval_x = tf.data.TextLineDataset(self.eval_x_path)
            dataset_eval_y = tf.data.TextLineDataset(self.eval_y_path)
            data_eval = tf.data.Dataset.zip((dataset_eval_x, dataset_eval_y))
            for x, y in data_eval:
                x = x.numpy().decode("utf-8")
                y = y.numpy().decode("utf-8")
                enc_x, enc_len = self._encode_source(x, pad_id)
                yield {
                    "enc_len": enc_len,
                    "enc_input": enc_x,
                    "dec_input": [],
                    "target": [],
                    "dec_len": self.max_dec_len,
                    "article": x,
                    "abstract": y,
                    "abstract_sents": [],
                }

    def batch_generator(self, batch_size, mode):
        """Return a ``tf.data.Dataset`` of padded batches for ``mode``.

        Each element is a pair ``(encoder_dict, decoder_dict)``. dec_input and
        target are padded to ``max_dec_len`` with id 1. The last incomplete
        batch is dropped.
        """
        dataset = tf.data.Dataset.from_generator(
            lambda: self.example_generator(mode, batch_size),
            output_types={
                "enc_len": tf.int32,
                "enc_input": tf.int32,
                "dec_input": tf.int32,
                "target": tf.int32,
                "dec_len": tf.int32,
                "article": tf.string,
                "abstract": tf.string,
                "abstract_sents": tf.string,
            },
            output_shapes={
                "enc_len": [],
                "enc_input": [None],
                "dec_input": [None],
                "target": [None],
                "dec_len": [],
                "article": [],
                "abstract": [],
                "abstract_sents": [None],
            },
        )
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes={
                "enc_len": [],
                "enc_input": [None],
                "dec_input": [self.max_dec_len],
                "target": [self.max_dec_len],
                "dec_len": [],
                "article": [],
                "abstract": [],
                "abstract_sents": [None],
            },
            padding_values={
                "enc_len": -1,
                "enc_input": 1,
                "dec_input": 1,
                "target": 1,
                "dec_len": -1,
                "article": b'',
                "abstract": b'',
                "abstract_sents": b'',
            },
            drop_remainder=True,
        )

        def update(entry):
            # Split the flat batch dict into (encoder features, decoder features).
            return (
                {
                    "enc_len": entry["enc_len"],
                    "enc_input": entry["enc_input"],
                    "article": entry["article"],
                },
                {
                    "dec_input": entry["dec_input"],
                    "target": entry["target"],
                    "dec_len": entry["dec_len"],
                    "abstract": entry["abstract"],
                    "abstract_sents": entry["abstract_sents"],
                },
            )

        return dataset.map(update)

    def batcher(self, params):
        """Build the batched dataset from ``params["batch_size"]`` and ``params["mode"]``."""
        return self.batch_generator(params["batch_size"], params["mode"])

    def get_dec_inp(self, seq, start_id, stop_id):
        """Build decoder input and target id lists from the token list ``seq``.

        ``dec_inp`` is ``[start_id] + ids(seq)`` truncated to ``max_dec_len``.
        ``dec_out`` is ``dec_inp`` shifted left by one with ``stop_id`` appended,
        so both lists always have the same length.
        """
        seq_id = [self.vocab.word_to_id(w) for w in seq]
        dec_inp = ([start_id] + seq_id)[:self.max_dec_len]
        dec_out = dec_inp[1:] + [stop_id]
        assert len(dec_inp) == len(dec_out)
        return dec_inp, dec_out
        

# tf - tfrecord方式
TODO：待后续更新。