In [73]:
import pickle

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch


In [74]:
with open("train.pkl", "rb") as f:
    data_train = pickle.load(f)

with open("test_no_target.pkl", "rb") as f:
    data_test = pickle.load(f)


In [75]:
X_train = []
y_train = []
X_test = []

for data, label in data_train:
    X_train.append(torch.Tensor(data))
    y_train.append(torch.Tensor(label))

for data in data_test:
    X_test.append(torch.Tensor(data))


In [76]:
train_max_len = max([len(sequence) for sequence in X_train])
test_max_len = max([len(sequence) for sequence in X_test])
max_len = max([train_max_len, test_max_len])


In [77]:
def pad_collate(batch, pad_value=0):
    """
    batch: list[tuple[torch.Tensor]]
    """
    xx, yy = zip(*batch)
    x_lens = [len(x) for x in xx]
    y_lens = [len(y) for y in yy]

    xx_pad = pad_sequence(xx, batch_first=True, padding_value=pad_value)
    yy_pad = pad_sequence(yy, batch_first=True, padding_value=pad_value)

    return xx_pad, yy_pad, x_lens, y_lens


In [78]:
class VariableLenDataset(Dataset):
    def __init__(self, in_data, target):
        self.data = [(x, y) for x, y in zip(in_data, target)]      

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        in_data, target = self.data[idx]
        return in_data, target


In [79]:
train_indices = int(len(data) * 0.7)

train_set = VariableLenDataset(X_train[:train_indices], y_train[:train_indices])
valid_set = VariableLenDataset(X_train[train_indices:], y_train[train_indices:])

train_loader = DataLoader(train_set, batch_size=50, shuffle=True, collate_fn=pad_collate)
valid_loader = DataLoader(valid_set, batch_size=50, shuffle=False, drop_last=False, collate_fn=pad_collate)


In [80]:
a = next(iter(train_loader))


In [85]:
a

(tensor([[144.,  72.,  13.,  ...,   0.,   0.,   0.],
         [144., 132.,  12.,  ...,   0.,   0.,   0.],
         [ 66., 100., 148.,  ...,   0.,   0.,   0.],
         ...,
         [144., 144., 145.,  ...,   0.,   0.,   0.],
         [ -1.,  -1., 144.,  ...,  32.,  -1.,  -1.],
         [ -1.,  -1.,  -1.,  ...,   0.,   0.,   0.]]),
 tensor([], size=(50, 0)),
 [176,
  142,
  592,
  620,
  342,
  447,
  352,
  552,
  244,
  180,
  86,
  56,
  360,
  60,
  44,
  168,
  44,
  88,
  51,
  465,
  388,
  180,
  48,
  910,
  300,
  48,
  672,
  981,
  480,
  196,
  714,
  52,
  172,
  100,
  548,
  195,
  12,
  525,
  240,
  3206,
  472,
  932,
  288,
  124,
  180,
  312,
  276,
  52,
  5322,
  219],
 [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0])