## **CHECK SIMPLE AUGMENTATION**

In [None]:
from PIL import Image
import torch
from torchvision import datasets, transforms

# Location of the ImageNet training split and the relative-area range
# sampled by the random resized crop.
data_path = "/gpfsdswork/dataset/imagenet/train"
global_crops_scale = (0.8, 1.)

# Untransformed view of the folder, kept for side-by-side comparison.
original_dataset = datasets.ImageFolder(data_path, transform=None)

# Augmented view: one random resized crop converted to a tensor.
# No mean/std normalisation is applied in this check.
transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            224,
            scale=global_crops_scale,
            interpolation=Image.BICUBIC,
        ),
        transforms.ToTensor(),
    ]
)
dataset = datasets.ImageFolder(data_path, transform=transform)
print(f"Data loaded: there are {len(dataset)} images.")

In [None]:
from random import randint

# Draw one random index and fetch it from both views of the folder so the
# augmented crop can be compared with its source image.
# The upper bound is capped by the dataset size: randint is inclusive on both
# ends, so a folder with fewer than 10001 images would otherwise raise
# IndexError.
n = randint(0, min(10000, len(dataset) - 1))
img_resized_crop = dataset[n]
img = original_dataset[n]

In [None]:
# Display element 0 of the item fetched from original_dataset (the
# untransformed image; element 1 would be its label).
img[0]

In [None]:
# Convert the cropped tensor (element 0 of the augmented item) back to a PIL
# image so the notebook can render it.
transforms.ToPILImage()(img_resized_crop[0])

## **CHECK DINO DATA AUGMENTATION**

In [None]:
from torchvision import transforms
import utils

class DataAugmentationDINO(object):
    """Multi-crop augmentation producing two global views and several local views.

    Calling an instance on an image returns a list of tensors: the first
    global crop, the second global crop, then ``local_crops_number`` local
    crops. Global crops are 224 px, local crops are 96 px. No mean/std
    normalisation is applied; the pipelines end with ``ToTensor`` only.

    Args:
        global_crops_scale: ``scale`` range passed to ``RandomResizedCrop``
            for both global crops.
        local_crops_scale: ``scale`` range passed to ``RandomResizedCrop``
            for the local crops.
        local_crops_number: how many local crops ``__call__`` produces.
    """

    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
        self.local_crops_number = local_crops_number

        # Flip / colour jitter / grayscale stage shared by all three pipelines.
        jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        # Final stage: tensor conversion only (normalisation intentionally absent).
        to_tensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        def random_crop(size, scale):
            # Bicubic random resized crop of the requested output size.
            return transforms.RandomResizedCrop(size, scale=scale, interpolation=Image.BICUBIC)

        # First global view. The positional argument of utils.GaussianBlur is
        # presumably its application probability -- confirm against utils.
        self.global_transfo1 = transforms.Compose([
            random_crop(224, global_crops_scale),
            jitter,
            utils.GaussianBlur(1.0),
            to_tensor,
        ])
        # Second global view: lighter blur setting plus solarization.
        self.global_transfo2 = transforms.Compose([
            random_crop(224, global_crops_scale),
            jitter,
            utils.GaussianBlur(0.1),
            utils.Solarization(0.2),
            to_tensor,
        ])
        # Local (small) views.
        self.local_transfo = transforms.Compose([
            random_crop(96, local_crops_scale),
            jitter,
            utils.GaussianBlur(p=0.5),
            to_tensor,
        ])

    def __call__(self, image):
        """Return ``[global1, global2, local_1, ..., local_n]`` for ``image``."""
        views = [self.global_transfo1(image), self.global_transfo2(image)]
        views.extend(self.local_transfo(image) for _ in range(self.local_crops_number))
        return views

In [None]:
# Rebind original_dataset to an untransformed ImageFolder over a different
# ImageNet copy than the one used in the first section.
original_dataset = datasets.ImageFolder('/gpfswork/rech/uli/ssos027/dino_experience/data/ImageNet/train', transform=None)
# Item at index 1; data[0] is the image that the following cells augment and display.
data = original_dataset[1]

In [None]:
# Build the multi-crop augmentation with explicit argument names.
transfo = DataAugmentationDINO(
    global_crops_scale=(0.4, 1.0),
    local_crops_scale=(0.05, 0.4),
    local_crops_number=6,
)
# Augment the sample image and render the first global crop (index 0).
transforms.ToPILImage()(transfo(data[0])[0])

In [None]:
# Display the untransformed source image for comparison with the crop above.
data[0]

## **CHECK COCO DATASET DEFINITION**

In [None]:
import glob
import numpy as np
import os

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
from pycocotools.coco import COCO
import torch


# TODO: maybe reduce the number of classes used for training.
# Check how existing works on object-centric representation perform their training.
class COCODataset(Dataset):
    """COCO images, optionally paired with per-pixel category targets.

    Three flavours are selected by ``dataset``:

    * ``'COCO'``: ``train2017`` images enumerated through the COCO annotation
      API. Items are ``(transform(image), (category_map, one_hot_mask))``.
    * ``'COCOplus'``: every ``*.jpg`` under ``train2017`` and
      ``unlabeled2017``. Items are ``(transform(image), None)``.
    * ``'COCOval'``: every ``*.jpg`` under ``val2017``. Items are
      ``(transform(image), None)``.

    Args:
        dataset: ``'COCO'``, ``'COCOplus'`` or ``'COCOval'``.
        data_dir: root folder holding the COCO sub-directories.
        transform: callable applied to the RGB PIL image.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the three names.
    """

    def __init__(
        self,
        dataset,
        data_dir,
        transform,
    ):
        super().__init__()
        if dataset == 'COCO':
            ann_file = data_dir + '/annotations/instances_train2017.json'
            self.coco = COCO(ann_file)
            self.ids = self.coco.getImgIds()  # list of image ids
            self.cat_ids = self.coco.getCatIds()  # list of category ids
            self.root = data_dir + '/train2017/'
            # Targets get a deterministic resize + centre crop to 224x224.
            # NEAREST interpolation keeps label values from being blended.
            self.target_transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(224, interpolation=TF.InterpolationMode.NEAREST),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])
            self.num_cat = len(self.cat_ids)  # TODO only use no crowd annotations ???
        elif dataset == 'COCOplus':
            self.fpaths = self._as_path_array(
                glob.glob(data_dir + '/train2017/*.jpg')
                + glob.glob(data_dir + '/unlabeled2017/*.jpg')
            )
        elif dataset == 'COCOval':
            self.fpaths = self._as_path_array(glob.glob(data_dir + '/val2017/*.jpg'))
        else:
            raise NotImplementedError(
                f"Unknown dataset {dataset!r}; expected 'COCO', 'COCOplus' or 'COCOval'."
            )
        self.dataset = dataset
        self.transform = transform

    @staticmethod
    def _as_path_array(paths):
        """Return ``paths`` sorted and stored as a NumPy array.

        Sorting makes the index -> file mapping reproducible (glob order is
        not guaranteed). The array form follows the original author's note
        that it avoids a memory leak -- presumably the per-worker memory
        growth seen with large Python lists; confirm if relevant.
        """
        return np.array(sorted(paths))

    def set_transform(self, transform):
        """Replace the image transform applied by ``__getitem__``."""
        self.transform = transform

    def __len__(self):
        """Number of images in the selected flavour."""
        if self.dataset == 'COCO':
            return len(self.ids)
        return len(self.fpaths)

    def __getitem__(self, idx):
        """Return ``(transform(image), target)`` for the image at ``idx``.

        For ``'COCO'`` the target is ``(category_map, one_hot_mask)``; for
        the other flavours it is ``None``.
        """
        if self.dataset == 'COCO':
            img_id = self.ids[idx]

            # Load image
            # type : PIL.Image.Image
            # size : (W, H)
            fname = self.coco.loadImgs(img_id)[0]['file_name']
            image = Image.open(os.path.join(self.root, fname)).convert('RGB')

            # Get all the annotations linked to our image
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)

            # Per-pixel category id map of size (H, W); 0 means "no annotation".
            # Where annotations overlap, the larger category id wins.
            # Category ids are not contiguous (original note: 80 categories
            # spread over ids starting at 1), hence the explicit enumeration
            # of self.cat_ids below.
            anns_img = np.zeros((image.size[1], image.size[0]), dtype=np.uint8)
            for ann in anns:
                anns_img = np.maximum(anns_img, self.coco.annToMask(ann) * ann['category_id'])

            # One channel per category, each resized from (H, W) to (224, 224).
            # NOTE(review): target_transform ends with ToTensor, which rescales
            # uint8 input by 1/255, so "hot" entries here (and the category map
            # below) are stored as value/255 rather than the raw integer --
            # confirm that downstream code expects this.
            one_hot_mask = torch.zeros(224, 224, self.num_cat)
            for i, class_id in enumerate(self.cat_ids):
                mask = (anns_img == class_id).astype(np.uint8)
                one_hot_mask[..., i] = self.target_transform(mask)

            # transform image for training
            transfo_img = self.transform(image)

            # Resize the category map itself from (H, W) to (224, 224).
            anns_img = self.target_transform(anns_img)

            # type : tuple[transformed image, tuple[torch.tensor, torch.tensor]]
            return transfo_img, (anns_img, one_hot_mask)

        # str() turns the NumPy string scalar back into a plain path string.
        fpath = str(self.fpaths[idx])
        image = Image.open(fpath).convert('RGB')
        # type : tuple[transformed image, None]
        return self.transform(image), None

In [None]:
# NOTE: this cell rebinds the notebook globals `dataset` and `transform`.
dataset = 'COCO'
data_dir = '/gpfswork/rech/uli/ssos027/dino_experience/data/COCO'
transform = DataAugmentationDINO(
        (0.7, 1.0),
        (0.05, 0.4),
        6,
    )

coco = COCODataset(dataset, data_dir, transform)
# COCODataset.__getitem__ returns a 2-tuple (crops, (category_map, one_hot_mask)),
# so unpacking three values raised ValueError.
transfo_img, _ = coco[1]
# The dataset item does not include the untransformed image, so load it
# explicitly; later cells re-apply coco.transform to `img`.
img = Image.open(
    os.path.join(coco.root, coco.coco.loadImgs(coco.ids[1])[0]['file_name'])
).convert('RGB')

In [None]:
# Same global-crop scale as before, but larger local crops (0.3-0.7 of the image).
transform = DataAugmentationDINO(
        (0.7, 1.0),
        (0.3, 0.7),
        6,
    )
# Assign the attribute directly: the COCODataset defined in this notebook keeps
# its transform in `.transform`, which __getitem__ reads on every access.
coco.transform = transform

In [None]:
# Re-run the dataset's current augmentation on `img` (assumes `img` is the
# untransformed PIL image and `coco` is still the COCODataset instance).
# DataAugmentationDINO returns a list: two global crops, then the local crops.
transfo_img = coco.transform(img)
# Render the first global crop.
transforms.ToPILImage()(transfo_img[0])

In [None]:
# Index 2 is the first local crop (indices 0 and 1 are the two global crops).
transforms.ToPILImage()(transfo_img[2])

## **CHECK COCO API**

In [None]:
from pycocotools.coco import COCO

In [None]:
# Paths to the COCO train2017 images and their instance annotations.
root = '/gpfswork/rech/uli/ssos027/dino_experience/data/COCO'
img_root = root + '/train2017/'
ann_file = root + '/annotations/instances_train2017.json'
# NOTE: rebinds `coco` -- from here on it is the raw pycocotools COCO API
# object, no longer the COCODataset built in the previous section.
coco = COCO(ann_file)

In [None]:
# Take the first image id listed by the annotation file.
img_id = coco.getImgIds()[0]

In [None]:
# Look up the file name recorded for this image id and open it as RGB.
fname = coco.loadImgs(img_id)[0]['file_name']
image = Image.open(os.path.join(img_root, fname)).convert('RGB')

In [None]:
# Fetch every annotation record attached to this image.
ann_ids = coco.getAnnIds(imgIds=img_id)
anns = coco.loadAnns(ann_ids)

In [None]:
# Build an (H, W) map holding a category id per pixel; 0 means "no annotation".
# PIL's image.size is (W, H), hence the swapped indices.
anns_img = np.zeros((image.size[1],image.size[0]), dtype=np.uint8)
for ann in anns:
    # Scale the annotation's mask by its category id and merge it in;
    # where annotations overlap, the larger category id wins.
    anns_img = np.maximum(anns_img, coco.annToMask(ann)*ann['category_id'])