In [1]:
import os
from functools import partial
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import (
    CIFAR10,
    CIFAR100,
    DTD,
    STL10,
    EuroSAT,
    FGVCAircraft,
    Flowers102,
    Food101,
    ImageFolder,
    OxfordIIITPet,
    StanfordCars,
)
import matplotlib.pyplot as plt
import numpy as np
import torch
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

DATASET_DICT = {
    "cifar10": [
        partial(CIFAR10, train=True, download=True),
        partial(CIFAR10, train=False, download=True),
        partial(CIFAR10, train=False, download=True),
        10,
    ],
    "cifar100": [
        partial(CIFAR100, train=True, download=True),
        partial(CIFAR100, train=False, download=True),
        partial(CIFAR100, train=False, download=True),
        100,
    ],
    "eurosat": [
        partial(EuroSAT, download=True), # eurosat has no attr 'train'
        partial(EuroSAT, download=True),
        partial(EuroSAT, download=True),
        100,
    ],
    "flowers102": [
        partial(Flowers102, split="train", download=True),
        partial(Flowers102, split="val", download=True),
        partial(Flowers102, split="test", download=True),
        102,
    ],
    "pets37": [
        partial(OxfordIIITPet, split="trainval", download=True),
        partial(OxfordIIITPet, split="test", download=True),
        partial(OxfordIIITPet, split="test", download=True),
        37,
    ],
    "dtd": [
        partial(DTD, split="train", download=True),
        partial(DTD, split="val", download=True),
        partial(DTD, split="test", download=True),
        47,
    ],
    "aircraft": [
        partial(FGVCAircraft, split="train", download=True),
        partial(FGVCAircraft, split="val", download=True),
        partial(FGVCAircraft, split="test", download=True),
        100,
    ],
    # "cars": [
    #     partial(StanfordCars,split="train", download=True),
    #     partial(StanfordCars,split="test", download=True),
    #     partial(StanfordCars,split="test", download=True),
    #     196,
    # ],
}

# Function to denormalize and convert tensor to numpy
def im_convert(tensor):
    image = tensor.clone().detach().numpy().transpose(1, 2, 0)
    image = image * 0.5 + 0.5  # Denormalize
    image = np.clip(image, 0, 1)
    return image

# Function to get a subset of data
def get_subset(dataset, num_samples=5):
    return torch.utils.data.Subset(dataset, range(num_samples))

# Visualization function
def visualize_datasets(datasets_to_show, num_samples=5):
    num_datasets = len(datasets_to_show)
    fig, axes = plt.subplots(num_datasets, num_samples, figsize=(num_samples*3, num_datasets*3))

    for row_idx, (dataset_name, dataset) in enumerate(datasets_to_show.items()):
        for col_idx in range(num_samples):
            ax = axes[row_idx, col_idx]
            image, label = dataset[col_idx]
            image = im_convert(image)
            ax.imshow(image)
            ax.set_title(f"Class {label}")
            ax.axis('off')
        axes[row_idx, 0].set_ylabel(dataset_name, size='large')

    plt.tight_layout()
    plt.show()

# Main execution
if __name__ == "__main__":
    root = "data/"
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    datasets_to_show = {}
    for dataset_name, (train_fn, _, _, _) in DATASET_DICT.items():
        dataset = train_fn(root=root, transform=transform)
        datasets_to_show[dataset_name] = get_subset(dataset)

    visualize_datasets(datasets_to_show)


  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


HTTPError: HTTP Error 404: Not Found