# Utilities for Datasets

## Compute Mean & Std of the real training dataset

### Uncomment only the dataset you want to process

In [None]:
# Select the dataset to process: uncomment exactly one of the lines below.
# NOTE: `dset`, `v2`, `transform`, DermaMNIST/BloodMNIST and the custom
# classes OxfordPetsDataset, TinyImageNetDataset, ImagewoofDataset and
# Caltech101Dataset are imported/defined in the "Compute Mean & Std" cell
# further down. Run that cell once first (it stops with a NameError at
# `DataLoader(dataset, ...)` while `dataset` is still undefined), then run
# this cell, then re-run the compute cell.
# The MedMNIST entries pass `v2.ToTensor()` directly instead of `transform`
# and request size=224 images.
# dataset = dset.CIFAR10(root='../datasets/cifar10', train=True, download=True, transform=transform)
# dataset = dset.CIFAR100(root='../datasets/cifar100', train=True, download=True, transform=transform)
# dataset = OxfordPetsDataset(root='../datasets/oxfordpets', split='train', transform=transform) 
# dataset = dset.StanfordCars(root='../datasets/stanfordcars', split='train', download=False, transform=transform)
# dataset = dset.Food101(root='../datasets/food101', split='train', download=True, transform=transform)
# dataset = TinyImageNetDataset(root='../datasets/tinyimagenet', split='train', transform=transform)
# dataset = DermaMNIST(root='../datasets/dermamnist', split='train', size=224, as_rgb=True, download=True, transform=v2.ToTensor())
# dataset = BloodMNIST(root='../datasets/bloodmnist', split='train', size=224, as_rgb=True, download=True, transform=v2.ToTensor())
# dataset = dset.STL10(root='../datasets/stl10', split='train', download=True, transform=transform)
# dataset = dset.Imagenette(root='../datasets/imagenette', split='train', transform=transform, download=True)
# dataset = ImagewoofDataset(root='../datasets/imagewoof', split='train', transform=transform)
# dataset = Caltech101Dataset(root='../datasets/caltech101', split='train', transform=transform)

### Set the path to the output file

In [None]:
# Text file that receives the computed statistics; the compute cell opens it
# in append mode, so results of successive runs accumulate.
output_file = "mean_std.txt"

### Compute Mean & Std and save them to the output file

In [None]:
import torch
import numpy
import torchvision.datasets as dset
from torchvision import transforms
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, load_from_disk
import scipy
import time
from PIL import Image
from medmnist import DermaMNIST, BloodMNIST


class OxfordPetsDataset(Dataset):
    """Oxford Pets split stored as a single file loaded with ``torch.load``.

    ``root/train85.pth`` backs the 'train' split; any other ``split`` value
    reads ``root/test15.pth``. The file is expected to hold a sequence of
    ``(image, label)`` pairs whose labels are tensors (``.item()`` is called
    on them to build ``classes``).
    """

    def __init__(self, root='../datasets/oxfordpets', split='train', transform=None):
        self.split = split
        self.transform = transform
        file_name = '/train85.pth' if split == 'train' else '/test15.pth'
        self.data_dir = root + file_name
        # NOTE(review): torch.load unpickles the file -- only load trusted data.
        self.data = torch.load(self.data_dir)
        # Sorted distinct label ids, expected to be 0, 1, 2, ..., 36.
        label_ids = {lbl.item() for _, lbl in self.data}
        self.classes = sorted(label_ids)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)``, applying ``transform`` to the image when set."""
        sample, target = self.data[index]
        return (self.transform(sample) if self.transform else sample), target


class _HFDiskImageDataset(Dataset):
    """Image-classification dataset saved with Hugging Face ``save_to_disk``.

    Shared implementation for the concrete dataset classes below, which
    differ only in their default ``root`` and in the folder name used for
    the non-training split.
    """

    def __init__(self, root, split, transform, eval_subdir):
        """Load ``root/train`` for ``split == 'train'``, else ``root/<eval_subdir>``.

        Args:
            root: directory that contains the split sub-folders.
            split: 'train' selects the training folder; any other value
                selects ``eval_subdir``.
            transform: optional callable applied to each RGB image.
            eval_subdir: folder name of the evaluation split.
        """
        self.split = split
        self.transform = transform
        subdir = 'train' if split == 'train' else eval_subdir
        self.data_dir = f"{root}/{subdir}"
        self.data = load_from_disk(self.data_dir)
        # The 'label' column exposes `.names` (presumably a ClassLabel feature):
        # position in that list is the integer label.
        self.classes = self.data.features['label'].names
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` with the image converted to RGB and transformed."""
        example = self.data[index]
        img, label = example['image'], example['label']

        # Non-RGB images (e.g. grayscale) are converted so every sample has
        # the same number of channels and batches can be stacked.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label


class TinyImageNetDataset(_HFDiskImageDataset):
    """Tiny ImageNet from disk: 'train' -> ``root/train``, otherwise ``root/valid``."""

    def __init__(self, root='../datasets/tinyimagenet', split='train', transform=None):
        super().__init__(root, split, transform, eval_subdir='valid')


class Caltech101Dataset(_HFDiskImageDataset):
    """Caltech-101 from disk: 'train' -> ``root/train``, otherwise ``root/test``."""

    def __init__(self, root='../datasets/caltech101', split='train', transform=None):
        super().__init__(root, split, transform, eval_subdir='test')


class ImagewoofDataset(_HFDiskImageDataset):
    """Imagewoof from disk: 'train' -> ``root/train``, otherwise ``root/validation``."""

    def __init__(self, root='../datasets/imagewoof', split='train', transform=None):
        super().__init__(root, split, transform, eval_subdir='validation')


# Resize to a fixed square size: the default DataLoader collate stacks the
# samples of a batch with torch.stack, which fails on mismatched shapes.
# v2.ToTensor() is deprecated in torchvision.transforms.v2; ToImage followed by
# ToDtype(float32, scale=True) is its documented replacement and yields
# float32 values in [0, 1] (uint8 inputs are rescaled).
transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

batch_size = 96
log_every = 1000  # print a progress line roughly every `log_every` images

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Per-channel statistics are computed over *all pixels* of the dataset:
#   mean = E[x],  std = sqrt(E[x^2] - E[x]^2)
# Averaging per-image stds instead would ignore the variance of the per-image
# means and therefore underestimate the dataset std.
# Sums are kept in float64 so the sum of squares stays accurate over
# millions of pixels.
channel_sum = None
channel_sq_sum = None
pixels_per_channel = 0
total_images = 0
next_log_at = log_every

start_time = time.time()

print("Starting computation of mean and standard deviation...")

for images, _ in data_loader:
    batch_samples, num_channels = images.size(0), images.size(1)
    # (batch, channels, H, W) -> (batch, channels, H * W)
    flat = images.reshape(batch_samples, num_channels, -1)

    if channel_sum is None:
        # Sized from the data rather than hard-coded to 3 channels.
        channel_sum = torch.zeros(num_channels, dtype=torch.float64)
        channel_sq_sum = torch.zeros(num_channels, dtype=torch.float64)

    channel_sum += flat.sum(dim=(0, 2), dtype=torch.float64)
    channel_sq_sum += flat.square().sum(dim=(0, 2), dtype=torch.float64)
    pixels_per_channel += batch_samples * flat.size(2)
    total_images += batch_samples

    # Batches of 96 rarely land exactly on a multiple of 1000, so compare
    # against a moving threshold instead of testing total_images % 1000 == 0.
    if total_images >= next_log_at:
        elapsed_time = time.time() - start_time
        print(f"Processed {total_images}/{len(dataset)} images. Time elapsed: {elapsed_time:.2f} seconds")
        next_log_at = (total_images // log_every + 1) * log_every

elapsed_time = time.time() - start_time
print(f"Finished processing {total_images} images in {elapsed_time:.2f} seconds.")

if total_images == 0:
    raise ValueError("The selected dataset is empty; cannot compute mean and std.")

# Final computation (population statistics over all pixels)
mean = channel_sum / pixels_per_channel
variance = channel_sq_sum / pixels_per_channel - mean.square()
std = variance.clamp_min(0).sqrt()  # clamp guards against tiny negative rounding error

mean_values = ', '.join(f'{v:.4f}' for v in mean.tolist())
std_values = ', '.join(f'{v:.4f}' for v in std.tolist())

# Save results to a file (append mode, so successive runs accumulate)
with open(output_file, 'a', encoding='utf-8') as f:
    f.write(f'Number of images: {total_images}\n')
    f.write(f'Mean (R, G, B): {mean_values}\n')
    f.write(f'Std (R, G, B): {std_values}\n')

print(f"Mean and standard deviation have been computed and saved to the file {output_file}.")

## Save a Hugging Face dataset to disk

In [None]:
from datasets import load_dataset
import os

# Define the path where you want to save the dataset
save_path = '../datasets/tinyimagenet'
hf_dataset_name = 'zh-plus/tiny-imagenet'
# Hugging Face split names to download; each one is saved to save_path/<split>,
# which is where TinyImageNetDataset reads 'train' and 'valid' from.
splits = ('train', 'valid')

# Decide per split folder rather than on save_path alone: a previous run that
# created save_path but failed before saving a split (e.g. a network error
# during download) is retried instead of being reported as already present.
missing_splits = [s for s in splits if not os.path.isdir(os.path.join(save_path, s))]

if missing_splits:
    os.makedirs(save_path, exist_ok=True)
    for split in missing_splits:
        split_dataset = load_dataset(hf_dataset_name, split=split)
        split_dataset.save_to_disk(os.path.join(save_path, split))

    print(f"HF dataset saved to {save_path}")
else:
    print(f"HF dataset already exists at {save_path}")