In [None]:
import operator

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.font_manager import FontProperties
import seaborn as sns

def _label_caption(label, class_names):
    """Return the caption shown under a sample whose label is ``label``.

    Integer-like labels (``int``, NumPy integers, or anything implementing
    ``__index__``) that fall inside ``class_names`` map to their class name.
    Everything else falls back to ``"Class <label>"``. That covers a missing or
    empty ``class_names``, out-of-range or negative indices, and labels that
    cannot be used as an index.
    """
    try:
        idx = operator.index(label)
    except TypeError:
        # Not usable as a list index (e.g. a float label); just show it raw.
        return f"Class {label}"
    # Reject negative indices explicitly: class_names[-1] would silently
    # return the *last* class name instead of flagging an unexpected label.
    if class_names and 0 <= idx < len(class_names):
        return class_names[idx]
    return f"Class {idx}"


def visualize_datasets(datasets_to_show, class_names_dict, num_samples=5, save_path=None):
    """
    Create a publication-quality visualization of samples from multiple datasets.

    Each dataset gets one row. A narrow left column holds the dataset name,
    followed by ``num_samples`` image cells, each captioned with its class
    name. When a subset has fewer than ``num_samples`` items, the remaining
    cells are left blank.

    Images are passed through ``im_convert``, which must be defined elsewhere
    (presumably it converts a dataset image, e.g. a tensor, into an array that
    ``imshow`` accepts -- confirm against its definition).

    Note: this function changes global Matplotlib/Seaborn styling
    (``plt.style.use`` and ``sns.set_style``), and that styling persists
    after the call.

    Parameters:
        datasets_to_show (dict): Dictionary mapping dataset names to dataset subsets.
            Each subset is indexed as ``subset[i]`` and must yield an
            ``(image, label)`` pair; ``IndexError`` marks the end of the subset.
        class_names_dict (dict): Dictionary mapping dataset names to lists of class names.
            Datasets without an entry are captioned ``"Class <label>"``.
        num_samples (int): Number of samples per dataset.
        save_path (str or None): Path to save the image. If None, display the image.

    Raises:
        ValueError: If ``datasets_to_show`` is empty.
    """
    num_datasets = len(datasets_to_show)
    if num_datasets == 0:
        raise ValueError("datasets_to_show must contain at least one dataset")

    # Set up the plot style. Matplotlib 3.6 renamed the 'seaborn-*' styles to
    # 'seaborn-v0_8-*', and 3.8 removed the old names (using them raises
    # OSError). Try the new name first and fall back for older Matplotlib.
    for style_name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    sns.set_style("whitegrid", {'axes.grid': False})

    # Create figure and gridspec: a narrow label column, then one column per sample.
    fig = plt.figure(figsize=(15, 3 * num_datasets))
    gs = gridspec.GridSpec(num_datasets, num_samples + 1, width_ratios=[0.2] + [1] * num_samples)

    # Use a nice font
    title_font = FontProperties(family='sans-serif', weight='bold', size=16)
    dataset_font = FontProperties(family='sans-serif', weight='bold', size=12)
    class_font = FontProperties(family='sans-serif', size=10)

    # Set title
    fig.suptitle("Sample Visualization Across Multiple Datasets", fontproperties=title_font, y=0.98)

    for row_idx, (dataset_name, subset) in enumerate(datasets_to_show.items()):
        # Add dataset name
        name_ax = fig.add_subplot(gs[row_idx, 0])
        name_ax.text(0.5, 0.5, dataset_name, fontproperties=dataset_font,
                     ha='center', va='center')
        name_ax.axis('off')

        class_names = class_names_dict.get(dataset_name)

        for col_idx in range(num_samples):
            ax = fig.add_subplot(gs[row_idx, col_idx + 1])
            # Guard only the indexing. An IndexError from im_convert/imshow is
            # a real bug and must not be mistaken for "subset exhausted".
            try:
                image, label = subset[col_idx]
            except IndexError:
                ax.axis('off')
            else:
                ax.imshow(im_convert(image))
                # Add class name as caption below the image.
                ax.text(0.5, -0.1, _label_caption(label, class_names),
                        fontproperties=class_font, ha='center', va='top',
                        transform=ax.transAxes, wrap=True)
            ax.set_xticks([])
            ax.set_yticks([])

    # tight_layout sets inter-subplot spacing; subplots_adjust then pins the
    # outer margins so the suptitle keeps room at the top.
    fig.tight_layout()
    fig.subplots_adjust(top=0.94, bottom=0.02, left=0.02, right=0.98)

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {save_path}")
    else:
        plt.show()

    plt.close(fig)
