# Utilities for data 

> Shared data utilities: size constants, train/test image transforms, and a `ConcatDataset` subclass that forwards buffer reloading to its member datasets.

In [None]:
#| default_exp data.utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import math
from os.path import join, exists
import torch
from torchvision import transforms
import numpy as np

# Size constants. Only RED_SIZE is used in this cell (as the resize target);
# the meanings of the others are inferred from their names -- TODO confirm.
ASIZE = 3      # presumably the action-vector size
LSIZE = 32     # presumably the latent size
RSIZE = 256    # presumably the recurrent hidden size
RED_SIZE = 32  # side length that both transforms below resize images to
SIZE = 40      # presumably the raw input frame size

# Training pipeline: to PIL -> square resize -> random horizontal flip -> tensor.
transform_train = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((RED_SIZE, RED_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: same as training minus the random flip, so the output
# for a given input is deterministic.
transform_test = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((RED_SIZE, RED_SIZE)),
        transforms.ToTensor(),
    ]
)


In [None]:
#| export
from torch.utils.data import ConcatDataset

class BufferAwareConcatDataset(ConcatDataset):
    """
    A ConcatDataset that can advance the buffers of all its member datasets.

    Member datasets are expected to expose a ``load_next_buffer()`` method;
    members that lack it are skipped with a warning.
    """

    def load_next_buffer(self):
        """
        Call ``load_next_buffer()`` on every member dataset that defines it.

        ``ConcatDataset`` caches member lengths in ``cumulative_sizes`` when it
        is constructed, and uses that cache for ``__len__`` and index lookup.
        Reloading a buffer may change a member's length, so the cached offsets
        are recomputed afterwards to keep length and indexing consistent.

        Emits a ``UserWarning`` for each member that has no
        ``load_next_buffer`` attribute; such members are left untouched.
        """
        import warnings  # local import keeps this exported cell self-contained

        for dataset in self.datasets:
            loader = getattr(dataset, 'load_next_buffer', None)
            if loader is None:
                # Best-effort: skip the member but make the problem visible.
                warnings.warn(
                    f"Dataset {type(dataset)} is missing load_next_buffer()",
                    stacklevel=2,
                )
                continue
            loader()

        # Refresh the cached offsets; yields the same values if no member's
        # length changed, and fixes __len__/__getitem__ if one did.
        self.cumulative_sizes = self.cumsum(self.datasets)

In [None]:
#| hide
import nbdev

# Export this notebook's `#| export` cells to the module named by `#| default_exp`.
nbdev.nbdev_export()