In [None]:
import numpy
import sys
sys.path.append('/scr/vidit/Foundation_Models/FoundationModels')
from FoundationModels import *
from torchvision.transforms import v2

In [None]:
class TensorAugmentationDINO(object):
    """DINO-style multi-crop augmentation: two global views plus N local crops.

    Every view passes through the same light augmentation pipeline: random
    horizontal and vertical flips, plus a rare elastic deformation. Each view
    is then wrapped as an image tensor and normalized with ``mean``/``std``.

    NOTE(review): ``v2.Normalize`` rejects integer tensors, and no dtype
    conversion happens here, so this presumably expects floating-point image
    tensors as input (as the class name suggests) -- confirm against callers.

    Args:
        global_crops_scale: Scale range intended for a global
            ``RandomResizedCrop``. Currently unused (global views are not
            cropped), kept for interface compatibility.
        local_crops_scale: ``scale`` range for the local ``RandomResizedCrop``.
        local_crops_number: Number of local crops produced per call.
        local_crop_size: Output side length of each local crop (default 96).
        elastic_p: Probability of applying ``ElasticTransform`` to a view
            (default 0.01).
        mean: Per-channel means passed to ``Normalize`` (default one channel).
        std: Per-channel stds passed to ``Normalize`` (default one channel).
    """

    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number,
                 local_crop_size=96, elastic_p=0.01, mean=(0.485,), std=(0.229,)):
        # Local import so the file does not rely on a star import for PIL's
        # ``Image``; torchvision is already a dependency of this notebook.
        from torchvision.transforms import InterpolationMode

        # ``ToImageTensor`` was renamed ``ToImage`` in newer torchvision
        # releases; use whichever one this install provides.
        to_image = getattr(v2, "ToImage", None) or v2.ToImageTensor
        normalize = v2.Normalize(mean=list(mean), std=list(std))

        # Shared by every view. RandomApply draws its own coin per call, so the
        # elastic deformation is applied independently to each view.
        augmentation_pipeline = v2.Compose([
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
            v2.RandomApply([v2.ElasticTransform()], p=elastic_p),
        ])

        def build_view(*leading):
            """Compose optional leading transforms with the shared tail."""
            return v2.Compose([*leading, augmentation_pipeline, to_image(), normalize])

        # Global views: full image (no crop), augmented and normalized.
        self.global_transfo1 = build_view()
        self.global_transfo2 = build_view()

        # Local views: small random-resized crops, then the same tail.
        self.local_crops_number = local_crops_number
        self.local_transfo = build_view(
            v2.RandomResizedCrop(
                local_crop_size,
                scale=local_crops_scale,
                interpolation=InterpolationMode.BICUBIC,
                antialias=True,
            )
        )

    def __call__(self, image):
        """Return ``[global_1, global_2, local_1, ..., local_N]`` views of ``image``."""
        crops = [self.global_transfo1(image), self.global_transfo2(image)]
        crops.extend(self.local_transfo(image) for _ in range(self.local_crops_number))
        return crops

    def normalize(self, x, eps=1e-7):
        """Standardize ``x`` to zero mean / unit std over its last two dims.

        Statistics are computed separately for each index of the leading
        dimensions, using the population std (``unbiased=False``). ``eps``
        guards against division by zero on constant inputs.

        Returns a new tensor; ``x`` itself is left unmodified.
        """
        m = x.mean((-2, -1), keepdim=True)
        s = x.std((-2, -1), unbiased=False, keepdim=True)
        return (x - m) / (s + eps)


In [None]:
from FoundationModels.dataset.dataset import IterableImageArchive
from FoundationModels.dataset import dataset_config
from FoundationModels.dataset.dataset_functions import randomize, split_for_workers, get_proc_split
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchvision import datasets, transforms


import os

# Dataset location. Override with the CHAMMI_DATA_PATH environment variable
# (e.g. to point at /scr/data/CHAMMIv2m.zip) instead of editing this cell.
DATA_PATH = os.environ.get("CHAMMI_DATA_PATH", "/scr/data/foundation_data/CHAMMIv2s.zip")
SEED = 42
BATCH_SIZE = 1
NUM_WORKERS = 8

# Preprocessing: 224px center crop, tensor conversion, then normalization with
# the standard ImageNet RGB mean/std values (three channels).
# NOTE(review): the DINO augmentation above normalizes with single-channel
# stats; confirm which channel count this archive's images actually have.
# NOTE(review): v2.ToTensor is deprecated in recent torchvision in favour of
# ToImage + ToDtype(scale=True); revisit when pinning the torchvision version.
transform = v2.Compose([
    v2.CenterCrop(224),
    v2.ToTensor(),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

config = dataset_config.DatasetConfig(
    DATA_PATH,
    split_fns=[randomize, split_for_workers],
    transform=transform,
    seed=SEED,
)

dataset = IterableImageArchive(config)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    worker_init_fn=dataset.worker_init_fn,
)