In [23]:
from dataset import AVADataset

from torch.utils.data import DataLoader
import torch

In [None]:
class AVADataLoader(DataLoader):
    """
    DataLoader for the AVA dataset.

    Implements automatic padding to the maximum number of speakers in the batch.

    Every dataset item must be a ``(mel, images, target, bbox)`` tuple of
    tensors. ``collate_fn`` indexes ``target`` as ``(T, num_speakers)`` and
    ``bbox`` as ``(T, num_speakers, k)`` with ``k <= 4``, and reads the
    attributes ``T``, ``C``, ``W`` and ``H`` from the dataset.

    Args:
        dataset: Dataset yielding the 4-tuples described above.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Number of worker processes used for loading.
        **kwargs: Any other ``torch.utils.data.DataLoader`` keyword argument
            (``drop_last``, ``pin_memory``, ...). ``collate_fn`` is always
            supplied by this class and must not be passed.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0, **kwargs):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            **kwargs,
        )

    def collate_fn(self, batch):
        """
        Collate function to pad the batch to the maximum number of speakers.

        Args:
            batch: Sequence of ``(mel, images, target, bbox)`` samples.

        Returns:
            Tuple ``(mel, images, targets, bboxes)``:
                * ``mel``: the per-sample mel tensors stacked on a new
                  leading batch dimension.
                * ``images``: the stacked images viewed as
                  ``(-1, T, C, W, H)``.
                * ``targets``: ``torch.long`` tensor of shape
                  ``(len(batch), T, max_speakers)``, zero-padded.
                * ``bboxes``: ``torch.float`` tensor of shape
                  ``(len(batch), T, max_speakers, 4)``, zero-padded.
        """
        mel, images, targets, bboxes = zip(*batch)

        n_samples = len(batch)
        n_frames = self.dataset.T
        # Largest speaker count in this batch; everything is padded up to it.
        max_speakers = max(target.shape[1] for target in targets)

        # NOTE(review): padded speaker slots are filled with 0, which cannot be
        # told apart from a genuine label 0 / all-zero box -- confirm that the
        # consumer masks padded slots or that 0 is a safe filler.
        targets_padded = torch.zeros(
            (n_samples, n_frames, max_speakers), dtype=torch.long)
        bboxes_padded = torch.zeros(
            (n_samples, n_frames, max_speakers, 4), dtype=torch.float)
        for i, (target, bbox) in enumerate(zip(targets, bboxes)):
            targets_padded[i, :, :target.shape[1]] = target
            bboxes_padded[i, :, :bbox.shape[1], :bbox.shape[2]] = bbox

        mel = torch.stack(mel, dim=0)
        images = torch.stack(images, dim=0).view(
            -1, n_frames, self.dataset.C, self.dataset.W, self.dataset.H)

        # The padded tensors are returned as allocated. Re-viewing them with a
        # leading -1 would be a no-op, and it raises RuntimeError when the
        # batch contains no speakers (zero-element tensor, ambiguous -1).
        return mel, images, targets_padded, bboxes_padded


In [38]:
# Build the training split and wrap it in the speaker-padding loader.
# Batches are drawn in dataset order and loaded in the main process.
train_dataset = AVADataset(mode="train")
train_loader = AVADataLoader(
    train_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

In [39]:
from tqdm import tqdm

# Smoke test: pull every batch through the loader once so that collation
# problems surface, with a progress bar sized to the number of batches.
for i, (mel, images, targets, bboxes) in tqdm(
    enumerate(train_loader),
    total=len(train_loader),
):
    pass  # iterating is the whole check; nothing is computed per batch



  9%|▉         | 6/64 [00:08<01:26,  1.49s/it]


KeyboardInterrupt: 