In [7]:
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
import random
import os

In [17]:
class MyDataset(Dataset):
    def __init__(self):
#         super(MyDataset).__init()
        self.datas = np.random.randn(128, 12, 15)
    
    def __getitem__(self, index):
        return torch.tensor(self.datas[index]).unsqueeze(0), index%4
    
    def __len__(self):
        return len(self.datas)
    
class PadCollate:
    """
    a variant of callate_fn that pads according to the longest sequence in
    a batch of sequences
    """
    def __init__(self, dim=0, min_chunk_size=2, max_chunk_size=4, normlize=True,
                 num_batch=0, fix_len=False):
        """
        args:
            dim - the dimension to be padded (dimension of time in sequences)
        """
        self.dim = dim
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self.num_batch = num_batch
        self.fix_len = fix_len
        self.normlize = normlize

        if self.fix_len:
            self.frame_len = np.random.randint(low=self.min_chunk_size, high=self.max_chunk_size)
        else:
            assert num_batch > 0
            batch_len = []
            self.iteration = 0
            # print('==> Generating %d different random length...' % (int(np.ceil(num_batch/100))))
            # for i in range(int(np.ceil(num_batch/100))):
            #     batch_len.append(np.random.randint(low=self.min_chunk_size, high=self.max_chunk_size))
            # self.batch_len = np.repeat(batch_len, 100)

            print('==> Generating %d different random length...' % (num_batch))
            for i in range(num_batch):
                batch_len.append(np.random.randint(low=self.min_chunk_size, high=self.max_chunk_size))
            self.batch_len = np.array(batch_len)
            avg_len = int((min_chunk_size + max_chunk_size) / 2)
            while np.mean(self.batch_len[:num_batch]) < avg_len:
                self.batch_len += 1
                self.batch_len = self.batch_len.clip(max=self.max_chunk_size)

            print('==> Average of utterance length is %d. ' % (np.mean(self.batch_len[:num_batch])))
            print(self.batch_len[:num_batch])

    def pad_collate(self, batch):
        """
        args:
            batch - list of (tensor, label)
        reutrn:
            xs - a tensor of all examples in 'batch' after padding
            ys - a LongTensor of all labels in batch
        """
        if self.fix_len:
            frame_len = self.frame_len
        else:
            # frame_len = np.random.randint(low=self.min_chunk_size, high=self.max_chunk_size)
#             frame_len = self.batch_len[self.iteration % self.num_batch]
            frame_len = random.choice(self.batch_len)
#             self.iteration += 1
#             print(os.getpid())
#             if self.iteration >= self.num_batch:
#                 random.shuffle(self.batch_len)
#                 print(self.batch_len)
            
#             self.iteration %= self.num_batch
            
        # pad according to max_len
        # print()
        xs = torch.stack(list(map(lambda x: x[0], batch)), dim=0)
        xs_shape = xs.shape

        if frame_len < xs_shape[self.dim]:
            start = np.random.randint(low=0, high=xs_shape[self.dim] - frame_len)
            end = start + frame_len
            if self.dim == 2:
                xs = xs[:, :, start:end, :].contiguous()
            elif self.dim == 3:
                xs = xs[:, :, :, start:end].contiguous()
        else:
            xs = xs.contiguous()

        ys = torch.LongTensor(list(map(lambda x: x[1], batch)))

        return xs, ys

    def __call__(self, batch):
        return self.pad_collate(batch)


In [18]:
a = MyDataset()

kwargs = {'num_workers': 4, 'pin_memory': True} 
ad = DataLoader(a, batch_size=4, collate_fn=PadCollate(dim=2, num_batch=int(np.ceil(len(a) / 2)),
                                                                         min_chunk_size=1,
                                                                         max_chunk_size=7),
                                                   shuffle=False, **kwargs)

==> Generating 64 different random length...
==> Average of utterance length is 4. 
[5 4 3 4 5 5 4 7 3 4 5 6 7 2 5 2 5 2 7 6 2 7 2 6 5 7 7 7 6 3 4 6 4 6 5 6 6
 4 7 4 6 5 5 5 6 5 6 2 7 3 5 6 3 3 4 7 2 2 5 6 5 4 3 3]


In [19]:
def train(ad):
    for x in ad:
        print(x[0].shape[2])

In [20]:
for i in range(3):
    train(ad)

7
6
7
2
4
4
6
6
7
5
4
6
6
4
2
4
6
6
5
5
7
4
6
2
6
2
2
7
4
3
5
5
7
6
4
6
6
4
5
5
3
7
6
5
6
5
5
6
2
7
4
5
3
6
7
6
6
6
7
5
3
3
6
3
7
7
2
2
5
7
5
5
5
7
2
6
5
6
5
5
6
5
5
3
4
4
6
5
6
5
7
6
4
3
3
5
