In [1]:
import os
from functools import partial
import ssl
import torch
from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets import (
    CIFAR10,
    CIFAR100,
    DTD,
    EuroSAT,
    FGVCAircraft,
    Flowers102,
    OxfordIIITPet,
    StanfordCars,
)
import matplotlib.pyplot as plt
import numpy as np

# Disable SSL certificate verification (use with caution).
# This swaps the ssl module's default HTTPS context factory for the unverified
# one, so the change is process-wide, not limited to one request: every HTTPS
# connection opened afterwards through the default context skips certificate
# checks and is exposed to man-in-the-middle tampering.
# NOTE(review): presumably a workaround for certificate errors during the
# torchvision dataset downloads below — confirm, and prefer fixing the local
# CA bundle over shipping this.
ssl._create_default_https_context = ssl._create_unverified_context

# Define datasets and their configurations.
# Each entry maps a display name to:
#   "class_names": placeholder, always None here,
#   "train" / "test": partials that still need `root` and `transform` to build the dataset,
#   "num_classes": number of classes in the dataset.
# NOTE(review): in the code visible in this notebook only the "train" factory is
# ever called (by load_datasets); "class_names", "test" and "num_classes" are not read.
DATASET_CONFIG = {
    "CIFAR-10": {
        "class_names": None,  # Placeholder; load_datasets() collects class names in a separate dict
        "train": partial(CIFAR10, train=True, download=True),
        "test": partial(CIFAR10, train=False, download=True),
        "num_classes": 10,
    },
    "CIFAR-100": {
        "class_names": None,
        "train": partial(CIFAR100, train=True, download=True),
        "test": partial(CIFAR100, train=False, download=True),
        "num_classes": 100,
    },
    "EuroSAT": {
        "class_names": None,
        # No split argument is passed, so "train" and "test" are identical
        # factories and build the same dataset.
        "train": partial(EuroSAT, download=True),
        "test": partial(EuroSAT, download=True),
        "num_classes": 10,  # EuroSAT has 10 classes
    },
    "Flowers102": {
        "class_names": None,
        "train": partial(Flowers102, split="train", download=True),
        # "test" is wired to the validation split, not split="test".
        "test": partial(Flowers102, split="val", download=True),
        "num_classes": 102,
    },
    # Disabled entries (kept for reference; not loaded or visualized):
    # "Oxford-IIIT-Pets": {
    #     "class_names": None,
    #     "train": partial(OxfordIIITPet, split="trainval", download=True),
    #     "test": partial(OxfordIIITPet, split="test", download=True),
    #     "num_classes": 37,
    # },
    # "DTD": {
    #     "class_names": None,
    #     "train": partial(DTD, split="train", download=True),
    #     "test": partial(DTD, split="val", download=True),
    #     "num_classes": 47,
    # },
    "FGVC-Aircraft": {
        "class_names": None,
        "train": partial(FGVCAircraft, split="train", download=True),
        # "test" is wired to the validation split, not split="test".
        "test": partial(FGVCAircraft, split="val", download=True),
        "num_classes": 100,
    },
    "Stanford Cars": {
        "class_names": None,
        "train": partial(StanfordCars, split="train", download=True),
        "test": partial(StanfordCars, split="test", download=True),
        "num_classes": 196,
    },
}

def im_convert(tensor, mean=0.5, std=0.5):
    """
    Denormalize an image tensor and convert it to a NumPy array for visualization.

    Inverts a ``Normalize(mean, std)`` transform (``x * std + mean``), moves the
    channel axis last, and clips the result to the displayable [0, 1] range.

    Parameters:
        tensor (torch.Tensor): Channel-first image tensor with exactly three
            dimensions (C, H, W); the (1, 2, 0) transpose requires it.
        mean (float or sequence of float): Mean used when the image was
            normalized. A scalar applies to every channel; a sequence must
            have one value per channel. Defaults to 0.5.
        std (float or sequence of float): Standard deviation used when the
            image was normalized; same shape rules as ``mean``. Defaults to 0.5.

    Returns:
        numpy.ndarray: Image of shape (H, W, C) with values in [0, 1].
    """
    # No clone() needed: the arithmetic below allocates a new array, so the
    # caller's tensor is never modified even though numpy() shares its memory.
    image = tensor.detach().cpu().numpy().transpose(1, 2, 0)

    # Keep the statistics in the image's own float dtype so a float32 image
    # stays float32; non-float inputs are promoted to float64.
    stat_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float64
    mean_arr = np.asarray(mean, dtype=stat_dtype)
    std_arr = np.asarray(std, dtype=stat_dtype)

    # Per-channel statistics broadcast along the trailing channel axis.
    image = image * std_arr + mean_arr
    return np.clip(image, 0, 1)

def get_subset(dataset, num_samples=5):
    """
    Return the leading samples of a dataset as a Subset.

    Parameters:
        dataset (torch.utils.data.Dataset): Dataset to take samples from.
        num_samples (int): Upper bound on how many samples to keep.

    Returns:
        torch.utils.data.Subset: View over the first ``num_samples`` items,
        or over the whole dataset when it holds fewer than that.
    """
    # Never ask for more items than the dataset actually contains.
    keep = len(dataset)
    if num_samples < keep:
        keep = num_samples
    leading_indices = [position for position in range(keep)]
    return Subset(dataset, leading_indices)

def get_class_names(dataset, dataset_name):
    """
    Look up the human-readable class names of a dataset.

    Parameters:
        dataset (torch.utils.data.Dataset): Dataset instance to inspect.
        dataset_name (str): Display name, used only in the fallback message.

    Returns:
        list or None: The dataset's ``classes`` attribute when it exists,
        otherwise None (a notice is printed in that case).
    """
    # Guard clause: without a 'classes' attribute, callers fall back to indices.
    if not hasattr(dataset, 'classes'):
        print(f"{dataset_name} doesn't have 'classes' attribute, using class indices instead.")
        return None
    return dataset.classes

def load_datasets(root, transform, num_samples=5):
    """
    Build a small preview subset of every configured dataset.

    Each entry of DATASET_CONFIG has its "train" factory called with ``root``
    and ``transform``. A dataset that fails to load is reported on stdout and
    skipped, so one broken dataset does not abort the others.

    Parameters:
        root (str): Root directory where datasets are stored.
        transform (torchvision.transforms.Compose): Transforms to apply.
        num_samples (int): Number of samples to keep per dataset.

    Returns:
        tuple: (dict mapping dataset name to its subset,
                dict mapping dataset name to its class names or None).
    """
    previews, labels_by_dataset = {}, {}
    for dataset_name in DATASET_CONFIG:
        entry = DATASET_CONFIG[dataset_name]
        try:
            train_split = entry["train"](root=root, transform=transform)
            previews[dataset_name] = get_subset(train_split, num_samples)
            labels_by_dataset[dataset_name] = get_class_names(train_split, dataset_name)
        except Exception as err:
            # Best-effort loading: report the failure and move on.
            print(f"Error loading dataset {dataset_name}: {err}")
    return previews, labels_by_dataset

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.font_manager import FontProperties
import seaborn as sns

def visualize_datasets(datasets_to_show, class_names_dict, num_samples=5, save_path=None):
    """
    Create a publication-quality visualization of samples from multiple datasets.

    The figure has one row per dataset: a narrow first column holding the
    dataset name, followed by ``num_samples`` image columns, each captioned
    with the sample's class name (or ``Class <index>`` when no names are known).

    The plotting style is applied only for the duration of the call, so the
    global matplotlib/seaborn configuration is left untouched afterwards.

    Parameters:
        datasets_to_show (dict): Dictionary mapping dataset names to dataset subsets.
            Each subset must yield ``(image_tensor, label)`` pairs when indexed.
        class_names_dict (dict): Dictionary mapping dataset names to lists of class
            names; missing or None entries fall back to class indices.
        num_samples (int): Number of sample columns per dataset. Datasets with
            fewer samples leave the remaining cells blank.
        save_path (str or None): Path to save the image. If None, display the image.

    Returns:
        None
    """
    # GridSpec rejects a grid with zero rows, and load_datasets() returns an
    # empty dict when every dataset failed to load, so bail out early instead
    # of crashing with an unrelated-looking ValueError.
    if not datasets_to_show:
        print("No datasets to visualize.")
        return

    num_datasets = len(datasets_to_show)

    # DejaVu Sans is used as the fallback when Times New Roman is not installed.
    title_font = FontProperties(family='DejaVu Sans', weight='bold', size=18)
    dataset_font = FontProperties(family='DejaVu Sans', weight='bold', size=14)
    class_font = FontProperties(family='DejaVu Sans', size=10)

    # Context managers restore the previous rcParams on exit, so this style
    # does not leak into figures drawn by later cells.
    with plt.style.context('seaborn-v0_8-paper'), sns.axes_style("whitegrid", {'axes.grid': False}):
        fig = plt.figure(figsize=(15, 2.5 * num_datasets))
        try:
            # Narrow leading column for the dataset name, then one column per sample.
            gs = gridspec.GridSpec(
                num_datasets,
                num_samples + 1,
                width_ratios=[0.2] + [1] * num_samples,
                hspace=0.3,
            )

            fig.suptitle("Comprehensive Visualization of Dataset Samples",
                         fontproperties=title_font, y=0.98)

            for row_idx, (dataset_name, subset) in enumerate(datasets_to_show.items()):
                # Dataset name cell.
                ax = fig.add_subplot(gs[row_idx, 0])
                ax.text(0.5, 0.5, dataset_name, fontproperties=dataset_font,
                        ha='center', va='center', rotation=0)
                ax.axis('off')

                class_names = class_names_dict.get(dataset_name, None)

                for col_idx in range(num_samples):
                    ax = fig.add_subplot(gs[row_idx, col_idx + 1])
                    ax.set_xticks([])
                    ax.set_yticks([])

                    # Only the sample lookup may legitimately run out of range
                    # (subset shorter than num_samples); keeping the try narrow
                    # avoids hiding unrelated IndexErrors from the drawing code.
                    try:
                        image, label = subset[col_idx]
                    except IndexError:
                        ax.axis('off')
                        continue

                    ax.imshow(im_convert(image))

                    if class_names and label < len(class_names):
                        class_name = class_names[label]
                    else:
                        class_name = f"Class {label}"

                    # Caption sits just below the image; clip_on=False keeps it
                    # visible even though it lies outside the axes area.
                    ax.text(0.5, -0.1, class_name, fontproperties=class_font,
                            ha='center', va='bottom', transform=ax.transAxes,
                            wrap=True, clip_on=False, color='black')

            fig.tight_layout()
            # Explicit margins override the outer margins chosen by tight_layout.
            fig.subplots_adjust(top=0.92, bottom=0.05, left=0.05, right=0.95)

            if save_path:
                # Save this figure explicitly rather than pyplot's "current" figure.
                fig.savefig(save_path, dpi=300, bbox_inches='tight')
                print(f"Visualization saved to {save_path}")
            else:
                plt.show()
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

def main(root="../data/", num_samples=5, save_path="publication_quality_dataset_samples.png"):
    """
    Main function to execute dataset visualization.

    Loads a few samples from every configured dataset and renders them into a
    single figure.

    Parameters:
        root (str): Root directory where datasets are stored or downloaded to.
        num_samples (int): Number of samples shown per dataset. The same value
            is used for loading and for the figure layout so the two cannot
            drift apart.
        save_path (str or None): Where to save the figure; None displays it instead.
    """
    # Normalization uses mean 0.5 / std 0.5 per channel, which is what
    # im_convert() undoes by default before display.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    datasets_to_show, class_names_dict = load_datasets(root, transform, num_samples=num_samples)

    visualize_datasets(
        datasets_to_show,
        class_names_dict=class_names_dict,
        num_samples=num_samples,
        save_path=save_path,
    )


if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified
Flowers102 doesn't have 'classes' attribute, using class indices instead.


  plt.tight_layout()


Visualization saved to publication_quality_dataset_samples.png


In [2]:
import matplotlib.font_manager as fm

# List the TrueType font files matplotlib can locate on this system.
available_fonts = fm.findSystemFonts(fontpaths=None, fontext='ttf')
print(available_fonts)

# Check whether Times New Roman is available.
# Compare against font family names rather than file paths: the font file is
# often not named after the family (e.g. "times.ttf" or "Times_New_Roman.ttf"),
# so a substring test on the paths reports false negatives.
available_families = {entry.name for entry in fm.fontManager.ttflist}
if 'Times New Roman' in available_families:
    print("Times New Roman is available.")
else:
    print("Times New Roman is not available.")


['/usr/share/fonts/truetype/dejavu/DejaVuSerif-Bold.ttf', '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', '/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', '/usr/share/fonts/truetype/dejavu/DejaVuSansMono-Bold.ttf', '/usr/share/fonts/truetype/dejavu/DejaVuSerif.ttf', '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf']
Times New Roman is not available.


In [3]:
# List the style names matplotlib knows about, to pick a valid argument for
# plt.style.use (relies on `plt` imported in the first cell).
print(plt.style.available)


['Solarize_Light2', '_classic_test_patch', '_mpl-gallery', '_mpl-gallery-nogrid', 'bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight', 'ggplot', 'grayscale', 'seaborn-v0_8', 'seaborn-v0_8-bright', 'seaborn-v0_8-colorblind', 'seaborn-v0_8-dark', 'seaborn-v0_8-dark-palette', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8-deep', 'seaborn-v0_8-muted', 'seaborn-v0_8-notebook', 'seaborn-v0_8-paper', 'seaborn-v0_8-pastel', 'seaborn-v0_8-poster', 'seaborn-v0_8-talk', 'seaborn-v0_8-ticks', 'seaborn-v0_8-white', 'seaborn-v0_8-whitegrid', 'tableau-colorblind10']
