# Data utilities

> Shared data-pipeline helpers: resize/normalize image transforms, LeJEPA augmentation pipelines, and a `ConcatDataset` that can advance the buffers of all its constituent datasets at once.

In [None]:
#| default_exp data.utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import math
from os.path import join, exists
import torch
from torchvision import transforms
import numpy as np

# Size constants used by the data pipeline.
# NOTE(review): the meanings of ASIZE / LSIZE / RSIZE / SIZE are not visible in
# this module; the names suggest action / latent / recurrent sizes — confirm.
ASIZE = 3
LSIZE = 32
RSIZE = 256
# Square side length that `transform_train` / `transform_test` resize images to.
RED_SIZE = 32
SIZE = 40

# An earlier train-time variant (RandomHorizontalFlip, no Normalize) was kept
# here as commented-out code; it has been dropped in favour of the shared
# pipeline below.
def _resized_normalized_pipeline():
    """Return a fresh Compose: to PIL -> resize to (RED_SIZE, RED_SIZE) ->
    tensor -> per-channel (x - 0.5) / 0.5 normalization."""
    steps = [
        transforms.ToPILImage(),
        transforms.Resize((RED_SIZE, RED_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(steps)


# Train and test currently use the same deterministic preprocessing; each gets
# its own Compose instance so they can diverge independently later.
transform_train = _resized_normalized_pipeline()
transform_test = _resized_normalized_pipeline()


In [None]:
#| export
import torch
from torchvision.transforms import v2
# Side length (pixels) of the square images both LeJEPA pipelines produce.
_LEJEPA_SIZE = 42

# Stochastic training-view pipeline: random crop, color/grayscale/blur
# augmentations, horizontal flip, then float conversion and normalization.
lejepa_train_tf = v2.Compose([
    v2.ToPILImage(),
    v2.RandomResizedCrop(_LEJEPA_SIZE, scale=(0.8, 1.0)),
    # ColorJitter(brightness, contrast, saturation, hue), applied with p=0.8
    v2.RandomApply([v2.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    v2.RandomGrayscale(p=0.2),
    # Small 3x3 kernel because the images are low resolution
    v2.RandomApply([v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.1),
    v2.RandomHorizontalFlip(),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # (x - 0.5) / 0.5 per channel: [0, 1] -> [-1, 1], matching a Tanh output
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# Deterministic evaluation pipeline: resize then center-crop to the same size,
# with the same float conversion and normalization as training.
lejepa_test_tf = v2.Compose([
    v2.ToPILImage(),
    v2.Resize(_LEJEPA_SIZE),
    v2.CenterCrop(_LEJEPA_SIZE),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [None]:
#| export
from torch.utils.data import ConcatDataset

class BufferAwareConcatDataset(ConcatDataset):
    """
    A ConcatDataset whose constituent datasets can be advanced together.

    `load_next_buffer()` forwards the call to every constituent dataset that
    implements it, then recomputes `cumulative_sizes` so that `len()` and
    global indexing reflect any change in the children's lengths.
    """
    def load_next_buffer(self):
        """
        Call `load_next_buffer()` on each underlying dataset, then refresh sizes.

        A dataset that has no `load_next_buffer` attribute is skipped and a
        `RuntimeWarning` is emitted (rather than printing to stdout), so a
        mixed concatenation keeps working and callers can filter or capture
        the diagnostic through the standard `warnings` machinery.
        """
        import warnings

        for dataset in self.datasets:
            if not hasattr(dataset, 'load_next_buffer'):
                warnings.warn(
                    f"Dataset {type(dataset)} is missing load_next_buffer(); skipping it",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            dataset.load_next_buffer()

        # Reloading may change child lengths, so rebuild the running totals
        # of child lengths from scratch.
        self.cumulative_sizes = self.cumsum(self.datasets)

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()