In [27]:
import random
import numpy as np
from tensorflow.python.client import device_lib#获取cpu，gpu操作包
from word_sequence import WordSequence

VOCAB_SIZE_THRESHOLD_CPU = 50000

'''获取当前GPU信息'''
def _get_available_gpus():
    local_device_protos = device_lib.list_local_devices()
    return [x.name for x in local_device_protos if x.device_type == 'GPU']

'''根据输入输出的字典大小来选择，是在CPU上embedding还是在GPU上进行embedding'''
def _get_embed_device(vocab_size):#超过某个临界值选择在cpu处理
    gpus = _get_available_gpus()
    if not gpus or vocab_size > VOCAB_SIZE_THRESHOLD_CPU:
        return "/cpu:0"
    return "/gpu:0"


def transform_sentence(sentence, ws, max_len=None, add_end=False):#句子，word_sequence，最大长度
    """ 单独的句子的转换 """
    encoded = ws.transform(
        sentence,
        max_len= max_len if max_len is not None else len(sentence))
    encoded_len = len(sentence) + (1 if add_end else 0)
    if encoded_len > len(encoded):
        encoded_len = len(encoded)
    #[4, 4, 5, 6]
    return encoded, encoded_len


def batch_flow(data, ws, batch_size, raw=False, add_end=True):#ws可以使一个WordSequence的对象也可是对象的列表或元组
                                                                #data同理
    """
    从数据中随机去生成batch_size的数据，然后给转换后输出出去
    raw:是否返回原始对象，如果为True，假设结果ret， 那么len(ret) == len(data) * 3
        如果为false，那么len(ret) == len(data） * 2
    Q = (q1, q2, q3 ... qn)
    A = (a1, a2, a3 ... an)
    len(Q) == len(A)
    batch_flow([Q, A], ws, batch_size = 32)
    raw = False:
    next(generator) == q_i_encoded, q_i_len, a_i_encoded, a_i_len
    raw = True:
    next(generator) == q_i_encoded, q_i_len, q_i, a_i_encoded, a_i_len, a_i
    """
    #ws数量要和data数量要保持一致（多个）,len(data) == len(ws)
    all_data = list(zip(*data))
    #[(['1', '2'], ['a', 'b']), (['2', '3', '4'], ['b', 'c', 'd']), (['1', '3', '4'], ['a', 'c', 'd'])]
    if isinstance(ws, (list, tuple)):#ws是否是list或tuple
        assert len(ws) == len(data), 'ws的长度必须等于data的长度 if ws 是一个list or tuple'

    if isinstance(add_end, bool):
        add_end = [add_end] * len(data)
    else:
        assert(isinstance(add_end, (list, tuple))), 'add_end不是boolean，就应该是一个list(tuple) of boolean'
        assert len(add_end) == len(data), '如果add_end 是list(tuple)，那么add_end的长度应该和输入数据的长度一致'

    mul = 2
    if raw:
        mul = 3

    while True:
        data_batch = random.sample(all_data, batch_size) #在all_data数据中随机抽取生成batch_size个数据
        batches = [[] for i in range(len(data) * mul)]

        max_lens = []#求所有数据中每句句子的长度
        for j in range(len(data)):#len（data）就是2，为[x,y]      
            max_len = max([len(x[j]) if hasattr(x[j], '__len__') else 0 for x in data_batch]) + (1 if add_end[j] else 0)
            #data_batch为如下格式[(['1', '2'], ['a', 'b']), (['2', '3', '4'], ['b', 'c', 'd']), (['1', '3', '4'], ['a', 'c', 'd'])]
            #即在all_data中选一些元素出来
            #判断在这一个batch中，最长的数字1,2以及最长的字母对应abc
            max_lens.append(max_len)#判断对象是否有某种属性/方法

        for d in data_batch:
            for j in range(len(data)):
                if isinstance(ws, (list, tuple)):
                    w = ws[j]
                else:
                    w = ws

                #添加结束标记（结尾）
                line = d[j]
                if add_end[j] and isinstance(line, (tuple, list)):
                    line = list(line) + [WordSequence.END_TAG]
                if w is not None:
                    x, xl = transform_sentence(line, w, max_lens[j], add_end[j])#data，ws可以为不同组输入，如输入3组data，分别对应三组ws
                    batches[j * mul].append(x) #最后生成batches时 batches[0,1]分别为第一组数据的翻译，长度。[2,3]为第二组数据的
                    batches[j * mul + 1].append(xl)#如果需要raw，还要返回原始数据,此时batches[0,1,2]对应第一组数据的转换，长度，原始数据
                else:                             #3,4,5对应第二组数据的
                    batches[j * mul].append(line)
                    batches[j * mul + 1].append(line)
                if raw:
                    batches[j * mul + 2].append(line)
        batches = [np.asarray(x) for x in batches]
        yield batches

def batch_flow_bucket(data, ws, batch_size, raw=False, add_end=True,#分块切分数据
                      n_bucket=5, bucket_ind=1, debug=False):#n_bucket把数据分成五分，bucket_ind切分的维度

    all_data = list(zip(*data))
    #[(['1', '2'], ['a', 'b']), (['2', '3', '4'], ['b', 'c', 'd']), (['1', '3', '4'], ['a', 'c', 'd'])]
    lengths = sorted(list(set([len(x[bucket_ind]) for x in all_data])))#通过排序得到数据长度,是排序后的所有x[bucked_ind]的集合的列表
    if n_bucket > len(lengths): #如[2,3]：x y中的单词长度只有2和3这两种
        n_bucket = len(lengths)

    splits = np.array(lengths)[(np.linspace(0, 1, 5, endpoint=False) * len(lengths)).astype(int)].tolist()
                                                    #结尾添加top标记

    splits += [np.inf] #np.inf无限大的正整数
    if debug:
        print(splits)

    ind_data = {}
    for x in all_data:
        l = len(x[bucket_ind])
        for ind, s in enumerate(splits[:-1]):
            if l >= s and l <= splits[ind + 1]:
                if ind not in ind_data:
                    ind_data[ind] = []
                ind_data[ind].append(x)
                break

    inds = sorted(list(ind_data.keys()))
    ind_p = [len(ind_data[x]) / len(all_data) for x in inds]#分布的概率
    if debug:
        print(np.sum(ind_p), ind_p)

    if isinstance(ws, (list, tuple)):
        assert len(ws) == len(data), "len(ws) 必须等于len(data)，ws是list或者是tuple"

    if isinstance(add_end, bool):
        add_end = [add_end] * len(data)
    else:
        assert(isinstance(add_end, (list, tuple))), "add_end 不是 boolean，就应该是一个list(tuple) of boolean"
        assert len(add_end) == len(data), "如果add_end 是list(tuple)，那么add_end的长度应该和输入数据长度是一致的"

    mul = 2
    if raw:
        mul = 3

    while True:
        choice_ind = np.random.choice(inds, p=ind_p)
        if debug:
            print('choice_ind', choice_ind)
        data_batch = random.sample(ind_data[choice_ind], batch_size)
        batches = [[] for i in range(len(data) * mul)]

        max_lens = []
        for j in range(len(data)):
            max_len = max([
                len(x[j]) if hasattr(x[j], '__len__') else 0
                for x in data_batch
            ]) + (1 if add_end[j] else 0)

            max_lens.append(max_len)

        for d in data_batch:
            for j in range(len(data)):
                if isinstance(ws, (list, tuple)):
                    w = ws[j]
                else:
                    w = ws

                #添加结尾
                line = d[j]
                if add_end[j] and isinstance(line, (tuple, list)):
                    line = list(line) + [WordSequence.END_TAG]

                if w is not None:
                    x, xl = transform_sentence(line, w, max_lens[j], add_end[j])
                    batches[j * mul].append(x)
                    batches[j * mul + 1].append(xl)
                else:
                    batches[j * mul].append(line)
                    batches[j * mul + 1].append(line)
                if raw:
                    batches[j * mul + 2].append(line)
        batches = [np.asarray(x) for x in batches]

        yield batches

def test_batch_flow():
    from fake_data import generate
    x_data, y_data, ws_input, ws_target = generate(size=100)
    flow = batch_flow([x_data, y_data], [ws_input, ws_target], 4)
    x, xl, y, yl = next(flow)
    #x = next(flow)
    print(x)
    print(xl)
    print(y)
    print(yl)
    #print(x.shape, y.shape, xl.shape, yl.shape)

def test_batch_flow_bucket():
    from fake_data import generate
    x_data, y_data, ws_input, ws_target = generate(size=100)
    flow = batch_flow_bucket([x_data, y_data], [ws_input, ws_target], 4, debug=True)
    for _ in range(10):
        x, xl, y, yl = next(flow)
        print(x.shape, y.shape, xl.shape, yl.shape)


if __name__ == '__main__':
    # size = 300000
    # print(_get_embed_device(size))
    #test_batch_flow_bucket()
    test_batch_flow()

[[12  8  9 10 11 10  3  0  0  0]
 [12  5 10 12  4 10 11  7  6  3]
 [11  9  4  5  8  6  4 11 12  3]
 [ 8  9  9  3  0  0  0  0  0  0]]
[ 8 10 10  5]
[[7 5 5 6 6 7 6 6 7 7 6 6 7 3 0]
 [7 4 6 6 7 7 4 6 6 7 7 5 5 4 3]
 [7 6 6 7 4 4 5 5 4 4 7 7 3 0 0]
 [5 5 6 6 7 6 6 7 3 0 0 0 0 0 0]]
[15 15 14 10]


In [18]:
x = [['1','2'],
    ['2','3','4'],
    ['1','3','4']]

y = [['a','b'],
    ['b','c','d'],
    ['a','c','d']]

data  = [x,y]

all_data = list(zip(*data))
print(all_data)

lengths = sorted(list(set([len(x[1]) for x in all_data])))
print(lengths)

[(['1', '2'], ['a', 'b']), (['2', '3', '4'], ['b', 'c', 'd']), (['1', '3', '4'], ['a', 'c', 'd'])]
[2, 3]
