In [2]:
import torch
from torch.utils.data import DataLoader, Dataset
import pickle
import numpy as np

In [5]:
class MKDataSet(Dataset):
    """Dataset of paired, integer-encoded source/target sequences.

    The pickle file is expected to hold a 2-tuple ``(en_num, ch_num)`` of
    equally long sequences; item ``i`` of the dataset is the pair
    ``(en_num[i], ch_num[i])``.

    Args:
        start_index: Index of the first pair to keep. It is honoured whether
            or not ``max_len`` is given.
        max_len: Maximum number of pairs to keep, counted from
            ``start_index``. ``None`` keeps everything up to the end.
        data_path: Location of the pickle file to load.
    """

    def __init__(self, start_index=0, max_len=None,
                 data_path="./data/s04_8000_word2num.plk"):
        # NOTE(review): pickle.load can execute arbitrary code while
        # deserialising -- only point data_path at files you trust.
        with open(data_path, "rb") as f:
            self.en_num, self.ch_num = pickle.load(f)

        # Slice whenever either bound was requested. Previously start_index
        # was silently dropped when max_len was None.
        if start_index != 0 or max_len is not None:
            end_index = None if max_len is None else start_index + max_len
            self.en_num = self.en_num[start_index:end_index]
            self.ch_num = self.ch_num[start_index:end_index]

    def __len__(self):
        """Return the number of sequence pairs."""
        assert len(self.en_num) == len(self.ch_num), \
            "source and target corpora must contain the same number of entries"
        return len(self.en_num)

    def __getitem__(self, index):
        """Return the ``(source, target)`` pair stored at ``index``."""
        return self.en_num[index], self.ch_num[index]

    def collate_fn(self, batch):
        """Pad a batch and split the targets into decoder input/output.

        Every source is right-padded with 0 to the longest source in the
        batch, and every target to the longest target. The padded targets
        are then shifted by one position: the decoder input drops the last
        element, the decoder output drops the first.

        Args:
            batch: Iterable of ``(source, target)`` sequence pairs.

        Returns:
            Three tensors ``(src, dec_input, dec_output)``. ``src`` has one
            row per pair and as many columns as the longest source; the two
            decoder tensors have one column fewer than the longest target.

        Raises:
            ValueError: If ``batch`` is empty (raised by ``max``).
        """
        pairs = list(batch)
        # list(...) lets tuples or other sequences be padded like lists.
        sources = [list(src) for src, _ in pairs]
        targets = [list(tgt) for _, tgt in pairs]

        src_max_len = max(len(seq) for seq in sources)
        tgt_max_len = max(len(seq) for seq in targets)

        src_num = [seq + [0] * (src_max_len - len(seq)) for seq in sources]
        tgt_num = [seq + [0] * (tgt_max_len - len(seq)) for seq in targets]

        # The shift is applied after padding, so both decoder tensors stay
        # rectangular.
        dec_input_tgt_num = [seq[:-1] for seq in tgt_num]
        dec_output_tgt_num = [seq[1:] for seq in tgt_num]

        return (torch.tensor(src_num),
                torch.tensor(dec_input_tgt_num),
                torch.tensor(dec_output_tgt_num))

def handleOneTrajectorySerial(serial_no):
    """Placeholder with no implementation yet.

    ``serial_no`` is accepted but not used; the function always returns None.
    """
    return None

In [None]:
# en_num and ch_num have the same length; entries at the same index
# correspond to each other.
with open("./data/s04_8000_word2num.plk", "rb") as f:
    word2num_pairs = pickle.load(f)
en_num, ch_num = word2num_pairs

print(len(ch_num))

In [None]:
# Preview only the first few encoded sentences: printing the whole corpus
# floods the notebook output and bloats the saved file.
# Swap in en_num to inspect the source side instead.
PREVIEW_ROWS = 5
print(ch_num[:PREVIEW_ROWS])

In [26]:
# Directory that holds both arrays of this experiment run.
experiment_dir = 'data/ArduCopter3_6_12 bug_free 10000/server0_5000000_5000999/experiment/output/PA/0'
states = np.load(f'{experiment_dir}/states_np_5000000_0.npy')
profiles = np.load(f'{experiment_dir}/profiles_np_5000000_0.npy')

# For trajectory i, overwrite columns 3..5 of every state with the first three
# values of profile i, then keep only the first six columns of each state.
# NOTE(review): presumably columns 0..2 are the vehicle position and the
# profile values are the target waypoint -- confirm against the data producer.
handled_states_with_wps = []
for i in range(profiles.shape[0]):
    # One slice assignment per trajectory replaces the per-element Python loop.
    states[i, :, 3:6] = profiles[i, 0:3]
    handled_states_with_wps.extend(states[i, :, 0:6])

src_len = 10
tar_len = 10
window_size = src_len + tar_len

# "+ 1" keeps the final full window: the last valid start index is
# len - window_size, which range(len - window_size) used to exclude.
# NOTE(review): windows are cut from the flattened list, so a window can start
# in one trajectory and end in the next -- confirm whether that is intended.
num_windows = len(handled_states_with_wps) - window_size + 1
src_list = [handled_states_with_wps[k: k + src_len]
            for k in range(num_windows)]
tar_list = [handled_states_with_wps[k + src_len: k + window_size]
            for k in range(num_windows)]

print(src_list[0], tar_list[0])
print('-------------')
print(src_list[1], tar_list[1])

[array([-35.1902464 ,  40.8865888 , 417.58      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902464 ,  40.8865888 , 417.78      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902464 ,  40.8865888 , 417.97      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902464 ,  40.8865888 , 418.15      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902463 ,  40.8865888 , 418.32      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902463 ,  40.8865888 , 418.47      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902463 ,  40.8865888 , 418.61      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902463 ,  40.8865888 , 418.73      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902463 ,  40.8865888 , 418.86      , -35.19016658,
        40.88666068, 426.09418472]), array([-35.1902461 ,  40.886589  , 418.98      , -35.19016658,
        40.88666068, 426.09418472])]

In [6]:
if __name__ == '__main__':
    # Smoke test: show one raw pair, then print every collated batch.
    data = MKDataSet()
    print(data[1])
    dataloder = DataLoader(data, batch_size=2, collate_fn=data.collate_fn)
    for batch in dataloder:
        enc_inputs, dec_inputs, dec_outputs = batch
        labelled_tensors = (
            ('enc_inputs:', enc_inputs),
            ('dec_inputs:', dec_inputs),
            ('dec_outputs:', dec_outputs),
        )
        for label, tensor in labelled_tensors:
            print(label, tensor)
        print("*" * 10)

([458, 1136, 414], [1, 1614, 89, 2030, 2])
enc_inputs: tensor([[ 458, 1136,  414],
        [ 458, 1136,  414]])
dec_inputs: tensor([[   1, 1448,  315, 2030],
        [   1, 1614,   89, 2030]])
dec_outputs: tensor([[1448,  315, 2030,    2],
        [1614,   89, 2030,    2]])
**********
enc_inputs: tensor([[1032,  414],
        [  90,  292]])
dec_inputs: tensor([[   1, 1614, 3284, 3223, 2783, 2030],
        [   1, 1448, 1667, 1667,  906,    2]])
dec_outputs: tensor([[1614, 3284, 3223, 2783, 2030,    2],
        [1448, 1667, 1667,  906,    2,    0]])
**********
enc_inputs: tensor([[ 90, 292],
        [773, 414]])
dec_inputs: tensor([[   1, 1448, 1667, 3301,  906],
        [   1, 1448, 1609,  906,    2]])
dec_outputs: tensor([[1448, 1667, 3301,  906,    2],
        [1448, 1609,  906,    2,    0]])
**********
enc_inputs: tensor([[ 773,    0,    0,    0],
        [ 940, 1045, 1103,  292]])
dec_inputs: tensor([[   1, 1448, 1609,    2],
        [   1, 1614,   89, 2030]])
dec_outputs: tensor([[