In [9]:
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader

In [11]:
class DD100(Dataset):
    def __init__(self, folder_path):
        self.folder_path = folder_path
        self.file_names = (sorted([f for f in os.listdir(folder_path) if f.endswith('.npy')]))
    
    def __len__(self):
        return len(self.file_names)//2
    
    def __getitem__(self, idx):
        file_name_follower = self.file_names[2*idx]
        file_name_leader = self.file_names[2*idx+1]
        data_follower = np.load(os.path.join(self.folder_path, file_name_follower))
        data_leader = np.load(os.path.join(self.folder_path, file_name_leader))
        data_follower = data_follower.reshape(-1, 55, 3)
        data_leader = data_leader.reshape(-1, 55, 3)
        return data_leader, data_follower


In [17]:
import torch
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __init__(self, num_samples, sequence_length, num_features):
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.num_features = num_features
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Generate random data for x and c
        x = torch.randn(self.sequence_length, self.num_features)
        c = torch.randn(self.sequence_length, self.num_features)
        
        return x, c

# Parameters for the random dataset
num_samples = 98  # Number of samples
sequence_length = 50  # Length of each sequence
num_features = 10  # Number of features in each element of the sequence

# Create an instance of the RandomDataset
dataset = RandomDataset(num_samples, sequence_length, num_features)

# Create a DataLoader for batching and shuffling the data
batch_size = 10
shuffle = True

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Example usage of the data loader
for batch_idx, data in enumerate(data_loader):
    x = data[0]
    y = data[1]
    print(x.shape)
    print(y.shape)
    break

torch.Size([10, 50, 10])
torch.Size([10, 50, 10])


In [21]:
for data in data_loader:
    x, c = data
    print(x.shape)
    print(c.shape)
    

torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([10, 50, 10])
torch.Size([8, 50, 10])
torch.Size([8, 50, 10])


In [26]:
with torch.no_grad():
    for data in data_loader:
        x, c = data
        print(x.shape)
        print(c.shape)
        break

ValueError: too many values to unpack (expected 2)

In [25]:
for data_loader in data_loader:
    x, c = data_loader
    print(x.shape)
    print(c.shape)
    break

torch.Size([10, 50, 10])
torch.Size([10, 50, 10])


In [3]:
def pad_sequence(seq, max_length):
    T,J,D = seq.shape
    padded_seq = np.zeros((max_length,J,D))
    padded_seq[:T] = seq
    return padded_seq

def collate_fn(batch):
    batch_size = len(batch)
    max_length = max([len(data[0]) for data in batch])
    J = batch[0][0].shape[1]
    D = batch[0][0].shape[2]
    data_leader = np.zeros((batch_size, max_length, J, D))
    data_follower = np.zeros((batch_size, max_length, J, D))
    for i, (data_leader_, data_follower_) in enumerate(batch):
        data_leader[i] = pad_sequence(data_leader_, max_length)
        data_follower[i] = pad_sequence(data_follower_, max_length)
    return data_leader, data_follower

In [4]:
train_dataset = DD100('./data/motion/pos3d/train')
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [8]:
for data in train_dataloader:
    print(data[0].shape, data[1].shape)
    break

(32, 3227, 55, 3) (32, 3227, 55, 3)


In [16]:
test_dataset = DD100('./data/motion/pos3d/test')
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [17]:
for data in test_dataloader:
    print(data[0].shape, data[1].shape)
    break

(32, 3006, 55, 3) (32, 3006, 55, 3)


In [24]:
type(data_loader)

torch.utils.data.dataloader.DataLoader