In [1]:
import torch as th
import numpy as np

from datasets import load_impressionv2_dataset_split

In [2]:
# Load the ImpressionV2 training split. Only the first returned value is used
# in this notebook; the second one is discarded.
# NOTE(review): binding `_` shadows IPython's "last output" alias for the session.
train_ds, _ = load_impressionv2_dataset_split('train')

In [3]:
class TensorDatasetWithTransformer(th.utils.data.Dataset):
    """Dataset wrapper that applies an optional transform to each item on access.

    Args:
        tensor_dataset: any indexable object supporting ``len()``; its items are
            fed to ``transform``.
        transform: optional callable applied to every item when it is fetched.
            If it is ``None`` (or otherwise falsy), items are returned unchanged.
    """

    def __init__(self, tensor_dataset, transform=None):
        self.tensor_dataset = tensor_dataset
        self.transform = transform

    def __len__(self):
        # The wrapper never adds or drops items, so its length is the base length.
        return len(self.tensor_dataset)

    def __getitem__(self, index):
        item = self.tensor_dataset[index]
        return self.transform(item) if self.transform else item

In [4]:
class SamplerTransform:
    """Subsample every modality of an ``(audio, face, text, label)`` sample.

    Keeps ``srA`` rows of ``audio``, ``srF`` rows of ``face`` and ``srT`` rows
    of ``text`` (indexing along the first axis, always in increasing order) and
    passes ``label`` through unchanged.

    Args:
        srA, srF, srT: number of rows to keep for audio / face / text; each
            must not exceed the corresponding first-axis length.
        is_random: if False, keep evenly spaced rows (deterministic; includes
            the first and last row when at least two are requested). If True,
            keep distinct rows drawn uniformly at random without replacement.
        expected_lens: ``(audio, face, text)`` first-axis lengths asserted on
            every sample as a sanity check. The default matches the lengths
            this notebook was written against; pass ``None`` to skip the check.
        rng: optional NumPy ``Generator``/``RandomState`` used when
            ``is_random`` is True. Defaults to NumPy's global RNG, so
            ``np.random.seed`` still controls reproducibility.

    Raises:
        AssertionError: if a sample's lengths differ from ``expected_lens`` or a
            requested count exceeds the available number of rows.
    """

    def __init__(self, srA, srF, srT, is_random=False,
                 expected_lens=(1526, 459, 60), rng=None):
        self.srA = srA
        self.srF = srF
        self.srT = srT
        self.is_random = is_random
        self.expected_lens = expected_lens
        self.rng = rng

    def _indices(self, length, count):
        """Return ``count`` sorted row indices drawn from ``range(length)``."""
        assert count <= length, f"cannot keep {count} rows out of {length}"
        if not self.is_random:
            return np.linspace(0, length - 1, count, dtype=int)
        choice = np.random.choice if self.rng is None else self.rng.choice
        # Draw from all `length` rows. Drawing from `length - 1` would never
        # select the last row, and it would raise ValueError when
        # count == length even though that passes the assert above.
        return np.sort(choice(length, count, replace=False))

    def __call__(self, x):
        audio, face, text, label = x
        lens = (audio.shape[0], face.shape[0], text.shape[0])
        if self.expected_lens is not None:
            expected = tuple(self.expected_lens)
            assert lens == expected, \
                f"unexpected sequence lengths {lens}, expected {expected}"

        a_idx = self._indices(lens[0], self.srA)
        f_idx = self._indices(lens[1], self.srF)
        t_idx = self._indices(lens[2], self.srT)
        return audio[a_idx], face[f_idx], text[t_idx], label

In [5]:
# Wrap the training split so each item is subsampled to 30 randomly chosen,
# order-preserving steps per modality.
ds = TensorDatasetWithTransformer(
    train_ds,
    transform=SamplerTransform(30, 30, 30, is_random=True),
)

In [6]:
# One training item passed through the 30/30/30 random SamplerTransform.
# NOTE(review): the saved output for this cell shows index arrays wrapped in
# lists instead of sampled tensors. That does not match what SamplerTransform
# currently returns (`audio[a_idx, ]` etc.), so the output looks stale and
# should be refreshed by re-running.
ds[0]

([array([ 108,  254,  298,  305,  382,  401,  502,  546,  622,  664,  671,
          687,  734,  736,  775,  828,  876,  943,  949,  994, 1003, 1043,
         1211, 1276, 1293, 1419, 1461, 1477, 1514, 1518])],
 [array([ 65,  81,  83,  99, 112, 118, 122, 126, 128, 139, 158, 170, 181,
         198, 216, 253, 288, 313, 329, 332, 346, 365, 375, 389, 401, 418,
         432, 435, 448, 450])],
 [array([ 0,  3,  4,  5,  7,  8,  9, 11, 13, 14, 15, 16, 17, 18, 19, 20, 22,
         24, 26, 30, 34, 35, 36, 38, 40, 44, 45, 51, 52, 57])],
 tensor([0.5514, 0.5000, 0.5275, 0.6505, 0.7444]))

In [7]:
# The same item without any transform, for comparison with `ds[0]`.
train_ds[0]

(tensor([[-1.7219, -0.6741, -0.0686,  ..., -1.0177,  1.2637, -1.1608],
         [-1.7197, -0.6999,  0.0864,  ..., -0.1044,  0.2689, -0.3515],
         [-1.7174, -0.6589,  0.6075,  ...,  0.8071, -0.5965,  0.4629],
         ...,
         [ 1.7259,  3.9920,  0.5464,  ...,  0.7302,  1.4042,  0.6708],
         [ 1.7282,  4.0712,  0.7517,  ...,  0.7433,  1.5186,  0.6367],
         [ 1.7304,  3.9432,  0.9267,  ...,  0.7425,  1.4939,  0.6272]]),
 tensor([[1.4787, 0.9806, 0.2581,  ..., 1.4020, 0.5203, 1.1110],
         [1.7456, 1.5411, 0.3094,  ..., 1.7443, 0.5398, 1.3978],
         [1.5188, 1.2019, 0.7119,  ..., 1.0248, 1.1424, 1.0761],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]),
 tensor([[-0.6210,  0.1200,  0.4032,  ..., -0.4498,  0.8524,  0.5574],
         [ 0.2091,  0.0823, -0.2423,  ..., -0.0468,  1.3281,  0.6255],
         [

In [30]:
# Shapes of one raw (untransformed) training sample. The sample is unpacked here
# so the cell runs on a fresh kernel instead of depending on `audio`, `face`,
# `text` and `label` being left over from a later cell.
audio, face, text, label = train_ds[1]
audio.shape, face.shape, text.shape, label.shape

(torch.Size([1526, 24]),
 torch.Size([459, 512]),
 torch.Size([60, 768]),
 torch.Size([5]))

In [None]:
# Confirm the fixed per-modality sequence lengths that SamplerTransform asserts
# on. They are read from a raw training sample so this cell does not depend on
# variables left over from other cells.
audio, face, text, label = train_ds[1]
al, fl, tl = audio.shape[0], face.shape[0], text.shape[0]

assert (al, fl, tl) == (1526, 459, 60), f"unexpected sequence lengths: {(al, fl, tl)}"

In [25]:
# np.linspace(..., dtype=int) includes both endpoints. SamplerTransform's
# deterministic mode therefore keeps the first and last step of each sequence
# whenever at least two steps are requested.
linspace_demo = np.linspace(0, 3, 4, dtype=int)
assert linspace_demo.tolist() == [0, 1, 2, 3]
linspace_demo

array([0, 1, 2, 3])

In [17]:
# Re-wrap the training split with no transform, so items come back exactly as
# stored (equivalent to indexing `train_ds` directly).
# NOTE(review): this rebinds `ds`, replacing the subsampling wrapper.
ds = TensorDatasetWithTransformer(tensor_dataset=train_ds, transform=None)

In [27]:
# Second training item. When `ds` has no transform, this is identical to `train_ds[1]`.
ds[1]

(tensor([[-1.7219e+00,  2.7733e+00, -8.3079e-04,  ...,  9.1841e-01,
           2.5316e-01,  7.1584e-01],
         [-1.7197e+00,  2.9463e+00,  3.3612e-01,  ...,  9.2970e-01,
           3.9145e-01,  7.2137e-01],
         [-1.7174e+00,  3.3067e+00,  6.6952e-01,  ...,  9.4596e-01,
           4.8278e-01,  7.1030e-01],
         ...,
         [ 1.7259e+00,  1.6559e+00, -2.1393e-01,  ...,  6.6622e-01,
          -2.3618e-01,  3.8720e-01],
         [ 1.7282e+00,  2.0043e+00,  1.9660e-01,  ...,  7.1194e-01,
          -2.1082e-01,  5.2063e-01],
         [ 1.7304e+00,  2.3335e+00,  4.9705e-01,  ...,  7.1788e-01,
          -3.0368e-01,  5.4108e-01]]),
 tensor([[1.0537, 0.4731, 0.2720,  ..., 0.8538, 0.8822, 1.0834],
         [1.1153, 0.2641, 1.1917,  ..., 0.6055, 1.5155, 0.5456],
         [1.1148, 0.0180, 1.0736,  ..., 1.0421, 1.7672, 0.2304],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.000

In [31]:
# Unpack one sample into its four components: audio, face, text, label.
# NOTE(review): the shapes cell labelled In [30] reads these names but appears
# earlier in the notebook. Run top-to-bottom on a fresh kernel, it fails unless
# it defines them itself.
audio, face, text, label = ds[1]