In [1]:
!pip install pillow -q
!pip install datasets -q
!pip install torch -q
!pip install torchvision -q
!pip install matplotlib -q
!pip install datasets -q
!pip install torchmetrics -q

In [2]:
from datasets import load_dataset
import PIL
import torch
from torchvision.transforms import v2
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import os

# Global

In [3]:
# Base data directory (relative path, trailing slash included so names can be
# appended by plain string concatenation).
DATA_DIR = "data/"
# "models/" sub-directory nested under DATA_DIR
# (presumably where model files are stored -- confirm against later cells).
MODEL_DIRS = f"{DATA_DIR}models/"

# Utils

In [4]:
class TorchDatasetWrapper(torch.utils.data.Dataset):
    """Expose an indexable dataset of image/label records as a torch ``Dataset``.

    Indexing the wrapper returns the pair ``(image, label)`` taken from the
    ``'image'`` and ``'label'`` entries of the wrapped record; the optional
    transform is applied to the image only.

    Args:
        hf_dataset: Object supporting ``len()`` and ``[idx]`` where each item
            is a mapping with ``'image'`` and ``'label'`` keys.
        transform: Optional callable applied to every image before it is
            returned. ``None`` (the default) returns the image unchanged.
    """

    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __repr__(self):
        # Show the wrapped dataset's own description.
        return str(self.hf_dataset)

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the record at ``idx``."""
        example = self.hf_dataset[idx]
        image = example['image']
        # Compare against None instead of relying on truthiness: a callable
        # transform that defines __len__ or __bool__ and evaluates as falsy
        # would otherwise be skipped silently.
        if self.transform is not None:
            image = self.transform(image)
        return image, example['label']

In [5]:
def print_patches(image_unfolded):
    """Show the patches of one or more unfolded images as matplotlib grids.

    One 3x3-inch figure is drawn per image; patches are arranged with
    ``n_patches_w`` patches per row and values are normalised for display.

    Args:
        image_unfolded: Tensor laid out as
            ``(channels, n_patches_h, n_patches_w, h, w)`` for a single image,
            or ``(batch, channels, n_patches_h, n_patches_w, h, w)`` for
            several images.

    Raises:
        ValueError: If the tensor is neither 5- nor 6-dimensional.
    """
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid

    # Promote a single image to a batch of one so both cases share one path.
    if image_unfolded.ndim == 5:
        image_unfolded = image_unfolded[None]
    if image_unfolded.ndim != 6:
        raise ValueError(
            "expected a 5-D or 6-D tensor of patches, got shape "
            f"{tuple(image_unfolded.shape)}"
        )

    # The per-image layout is the same for the whole batch; read it once.
    n_channels, _n_patches_h, n_patches_w, h, w = image_unfolded.shape[1:]

    for sample in image_unfolded:
        # (C, Ph, Pw, h, w) -> (Ph, Pw, C, h, w) -> (Ph*Pw, C, h, w)
        patches = sample.permute(1, 2, 0, 3, 4).reshape(-1, n_channels, h, w)
        grid = make_grid(patches, nrow=n_patches_w, padding=2, normalize=True)
        plt.figure(figsize=(3, 3))
        # detach()/cpu(): Tensor.numpy() fails for tensors that require grad
        # or live on an accelerator; matplotlib wants channels-last.
        plt.imshow(grid.detach().cpu().permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.show()

In [6]:
def gen_super_tiny(dataset, split, q=10, p=0, step=500, classes=200):
    """Build a reduced dataset by keeping one slice out of every ``step`` rows.

    The source is walked in strides of ``step`` rows; from the block starting
    at row ``i`` the rows ``i + p`` up to (not including) ``i + q`` are kept.
    Presumably the rows are grouped by class so that each block is one class
    -- confirm against the dataset's actual ordering.

    Args:
        dataset: Object whose slice ``dataset[a:b]`` returns a mapping with
            ``"image"`` and ``"label"`` sequences.
        split: One of ``"train"``, ``"validation"``, ``"valid"`` or ``"test"``;
            selects the assumed number of images per class.
        q: Exclusive end offset of the kept slice inside each block.
        p: Start offset of the kept slice inside each block.
        step: Distance in rows between the starts of consecutive blocks.
        classes: Number of classes; with the per-class count it defines how
            many rows of ``dataset`` are considered.

    Returns:
        A ``datasets.Dataset`` with ``"image"`` and ``"label"`` columns.

    Raises:
        KeyError: If ``split`` is not one of the known split names.
    """
    from datasets import Dataset

    images_per_class = {
        "validation": 50,
        "valid": 50,  # name under which this notebook loads the validation split
        "test": 50,
        "train": 500,
    }
    dataset_length = images_per_class[split] * classes

    # A block is usable when its slice fits: start + q <= dataset_length.
    # The `+ 1` keeps the block whose slice ends exactly at dataset_length
    # (e.g. q == step), which the bound `dataset_length - q` used to drop.
    block_starts = range(0, dataset_length - q + 1, step)
    chunks = [dataset[start + p:start + q] for start in block_starts]

    all_images = [image for chunk in chunks for image in chunk["image"]]
    all_labels = [label for chunk in chunks for label in chunk["label"]]
    return Dataset.from_dict({"image": all_images, "label": all_labels})

In [13]:
# Batch sizes used by the two loaders below.
train_batch_size = 900
val_batch_size = 900

# Evaluation pipeline: force RGB, convert to an image tensor, then to float32
# with value scaling enabled.
val_transform = v2.Compose(
    [
        v2.Lambda(lambda img: img.convert('RGB')),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Training pipeline: identical to the evaluation one, with RandAugment applied
# right after the RGB conversion.
train_transform = v2.Compose(
    [
        v2.Lambda(lambda img: img.convert('RGB')),
        v2.RandAugment(),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Wrap the 'train' and 'valid' splits of Maysee/tiny-imagenet so they yield
# (image, label) pairs with the matching transform applied.
train_dataset = TorchDatasetWrapper(
    load_dataset('Maysee/tiny-imagenet', split='train'),
    transform=train_transform,
)
valid_dataset = TorchDatasetWrapper(
    load_dataset('Maysee/tiny-imagenet', split='valid'),
    transform=val_transform,
)

# Only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=val_batch_size,
    shuffle=False,
)