In [20]:
import torch as th
import numpy as np

from datasets import load_impressionv2_dataset_split

In [14]:
# Load the 'train' split; the loader's second return value is not used in this notebook.
train_ds, _ = load_impressionv2_dataset_split('train')

In [15]:
# Inspect what the loader returned (displays as a torch TensorDataset).
train_ds

<torch.utils.data.dataset.TensorDataset at 0x7fbae1d9cf10>

In [16]:
class TensorDatasetWithTransformer(th.utils.data.Dataset):
    """Dataset wrapper that optionally post-processes every retrieved item.

    Args:
        tensor_dataset: object supporting ``len()`` and integer indexing
            (a ``TensorDataset`` in this notebook).
        transform: optional callable applied to each item on access; when it
            is falsy (e.g. ``None``) items are returned unchanged.
    """

    def __init__(self, tensor_dataset, transform=None):
        self.tensor_dataset, self.transform = tensor_dataset, transform

    def __len__(self):
        return len(self.tensor_dataset)

    def __getitem__(self, index):
        item = self.tensor_dataset[index]
        # Truthiness (not `is not None`) decides whether to apply the transform.
        return self.transform(item) if self.transform else item

In [64]:
class SamplerTransform:
    """Choose time-step indices for each modality of an ``(audio, face, text, label)`` sample.

    On every call, ``srA`` indices are chosen along axis 0 of ``audio``,
    ``srF`` along axis 0 of ``face`` and ``srT`` along axis 0 of ``text``:

    * ``is_random=False``: evenly spaced via ``np.linspace``, always including
      the first and last position;
    * ``is_random=True``: drawn uniformly without replacement, then sorted.

    Returns ``([a_idx], [f_idx], [t_idx], label)``. Each index array is wrapped
    in a one-element list, and ``label`` is passed through untouched.

    NOTE(review): only the indices are returned, not ``audio[a_idx]`` etc.
    Presumably gathering the sub-sequences is the eventual goal; confirm
    before feeding this into a DataLoader.
    """

    def __init__(self, srA, srF, srT, is_random=False, rng=None,
                 expected_lens=(1526, 459, 60)):
        """
        Args:
            srA, srF, srT: number of indices to pick for audio, face and text.
            is_random: draw random indices instead of evenly spaced ones.
            rng: optional ``np.random.Generator`` / ``RandomState`` used when
                ``is_random`` is True. ``None`` falls back to the global
                ``np.random`` state (the original behaviour).
            expected_lens: ``(audio, face, text)`` lengths asserted on every
                sample. Defaults to the values hard-coded for this dataset;
                pass ``None`` to skip the check.
        """
        self.srA = srA
        self.srF = srF
        self.srT = srT
        self.is_random = is_random
        self.rng = rng
        self.expected_lens = expected_lens

    def _pick(self, length, n):
        """Return ``n`` sorted integer indices in ``[0, length - 1]``."""
        if not self.is_random:
            return np.linspace(0, length - 1, n, dtype=int)
        rng = np.random if self.rng is None else self.rng
        # The population is all `length` positions. The previous `length - 1`
        # population could never pick the last frame, and raised ValueError
        # for n == length even though the `n <= length` assertion passed.
        idx = rng.choice(length, n, replace=False)
        idx.sort()
        return idx

    def __call__(self, x):
        audio, face, text, label = x
        al = audio.shape[0]
        fl = face.shape[0]
        tl = text.shape[0]
        if self.expected_lens is not None:
            assert (al, fl, tl) == tuple(self.expected_lens), (al, fl, tl)
        assert self.srA <= al
        assert self.srF <= fl
        assert self.srT <= tl

        # Draw in audio -> face -> text order so a seeded RNG is reproducible.
        a_idx = self._pick(al, self.srA)
        f_idx = self._pick(fl, self.srF)
        t_idx = self._pick(tl, self.srT)
        return [a_idx, ], [f_idx, ], [t_idx, ], label

In [65]:
# Seed NumPy's global RNG so the random index draws displayed below are
# reproducible on Restart & Run All (with is_random=True, SamplerTransform
# draws from np.random unless given its own generator).
np.random.seed(0)
ds = TensorDatasetWithTransformer(train_ds, SamplerTransform(30, 30, 30, is_random=True))

In [66]:
# One sampled item: a list holding an index array for each of audio/face/text, plus the label tensor.
ds[0]

([array([  68,   97,  120,  128,  135,  206,  226,  289,  402,  476,  584,
          721,  742,  756,  776,  781,  863,  925, 1008, 1026, 1105, 1122,
         1178, 1236, 1237, 1323, 1347, 1406, 1477, 1498])],
 [array([  8,  29,  30,  35,  44,  46,  66,  68, 103, 109, 113, 145, 177,
         182, 236, 239, 251, 253, 260, 276, 280, 328, 334, 348, 349, 354,
         360, 375, 399, 408])],
 [array([ 0,  1,  4,  7, 13, 14, 17, 18, 19, 21, 23, 24, 25, 30, 32, 35, 36,
         37, 39, 40, 42, 44, 45, 48, 51, 52, 53, 55, 56, 58])],
 tensor([0.5514, 0.5000, 0.5275, 0.6505, 0.7444]))

In [30]:
# Pull one raw sample straight from `train_ds` so this cell is self-contained.
# Previously `audio`, `face`, `text`, `label` were only bound by a later cell,
# so this raised NameError on Restart & Run All.
audio, face, text, label = train_ds[1]
audio.shape, face.shape, text.shape, label.shape

(torch.Size([1526, 24]),
 torch.Size([459, 512]),
 torch.Size([60, 768]),
 torch.Size([5]))

In [None]:
# Sanity-check the per-modality sequence lengths (the same values SamplerTransform asserts).
# Relies on `audio`, `face` and `text` having been bound by an earlier cell.
al, fl, tl = (m.shape[0] for m in (audio, face, text))

assert al == 1526
assert fl == 459
assert tl == 60

In [25]:
# Integer linspace keeps both endpoints; SamplerTransform's deterministic branch relies on this.
np.linspace(start=0, stop=3, num=4, dtype=int)

array([0, 1, 2, 3])

In [17]:
# Untransformed view of train_ds. This rebinds `ds`, so the sampled wrapper
# built earlier is no longer reachable under that name.
ds = TensorDatasetWithTransformer(train_ds, transform=None)

In [27]:
# Full raw sample. This dumps large tensors; inspecting `.shape` of each element is usually enough.
ds[1]

(tensor([[-1.7219e+00,  2.7733e+00, -8.3079e-04,  ...,  9.1841e-01,
           2.5316e-01,  7.1584e-01],
         [-1.7197e+00,  2.9463e+00,  3.3612e-01,  ...,  9.2970e-01,
           3.9145e-01,  7.2137e-01],
         [-1.7174e+00,  3.3067e+00,  6.6952e-01,  ...,  9.4596e-01,
           4.8278e-01,  7.1030e-01],
         ...,
         [ 1.7259e+00,  1.6559e+00, -2.1393e-01,  ...,  6.6622e-01,
          -2.3618e-01,  3.8720e-01],
         [ 1.7282e+00,  2.0043e+00,  1.9660e-01,  ...,  7.1194e-01,
          -2.1082e-01,  5.2063e-01],
         [ 1.7304e+00,  2.3335e+00,  4.9705e-01,  ...,  7.1788e-01,
          -3.0368e-01,  5.4108e-01]]),
 tensor([[1.0537, 0.4731, 0.2720,  ..., 0.8538, 0.8822, 1.0834],
         [1.1153, 0.2641, 1.1917,  ..., 0.6055, 1.5155, 0.5456],
         [1.1148, 0.0180, 1.0736,  ..., 1.0421, 1.7672, 0.2304],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.000

In [31]:
# Unpack one sample from `ds` into its four components for inspection.
audio, face, text, label = ds[1]