# In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
from typing import Union

def visualize(ds, grid_size=(1,1), pdf_name:Union[str, None]=None, figsize=(12, 12)):
    """
    Plot the first samples of a dataset of (image, label) pairs on a grid.

    At most ``grid_size[0] * grid_size[1]`` samples are taken from ``ds``; if the
    dataset yields fewer, the remaining grid cells are left empty. If either grid
    dimension is 0, nothing is plotted.

    Args:
        ds (tf.data.Dataset): Dataset whose elements unpack into ``(image, label)``.
        grid_size (tuple, optional): ``(rows, cols)`` of the subplot grid. Defaults to (1,1).
        pdf_name (str, optional): If given, the figure is also saved to this path
            (via ``plt.savefig``) before being shown. Defaults to None.
        figsize (tuple, optional): Figure size passed to ``plt.figure``.
            Defaults to (12, 12).
    """
    if 0 in grid_size:
        return

    rows, cols = grid_size
    n = rows * cols

    plt.figure(figsize=figsize)
    for i, (image, label) in enumerate(ds.take(n)):
        plt.subplot(rows, cols, i + 1)
        plt.title(f'Label: {label}')
        plt.axis('off')
        # imshow rejects (H, W, 1) arrays, so drop the singleton channel axis and
        # draw such single-channel images with a gray colormap.
        if len(image.shape) == 3 and image.shape[-1] == 1:
            plt.imshow(tf.squeeze(image, axis=-1), cmap='gray')
        else:
            plt.imshow(image)
    if pdf_name:
        plt.savefig(pdf_name)
    plt.show()


def filter_fn(example, chosen_labels_ids):
    """
    Predicate (e.g. for ``tf.data.Dataset.filter``) keeping examples with a chosen label.

    Args:
        example (dict): Dictionary containing the image and the labels; only the
            ``"label"`` entry is read.
        chosen_labels_ids (list): List of the chosen labels ids in cifar10 dataset.
            May be empty, in which case every example is rejected.

    Returns:
        tf.Tensor: Scalar boolean tensor, True if any entry of ``example["label"]``
        equals one of the chosen label ids, False otherwise.
    """
    labels = example["label"]
    # Cast the ids to the label dtype so that e.g. int32 ids compare cleanly with
    # int64 labels, and so an empty list becomes an empty integer tensor instead of
    # float32 (which the bool-only reduce_any would reject).
    chosen = tf.cast(chosen_labels_ids, labels.dtype)
    # Single broadcast comparison: the trailing size-1 axis added to the labels
    # pairs every label with every chosen id.
    matches = tf.equal(tf.expand_dims(labels, -1), chosen)
    return tf.reduce_any(matches)