# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models,optimizers
from tensorflow.keras.applications import ResNet50V2
import matplotlib.pyplot as plt
from functools import partial
from typing import Union

def visualize(ds, grid_size=(1, 1), pdf_name: Union[str, None] = None):
    """
    Plot the first ``rows * cols`` samples of a dataset on a subplot grid.

    If either dimension of ``grid_size`` is 0, nothing is done.

    Args:
        ds (tf.data.Dataset): Dataset yielding ``(image, label)`` pairs.
        grid_size (tuple, optional): ``(rows, cols)`` of the subplot grid. Defaults to (1,1).
        pdf_name (str, optional): If given, the figure is also saved to this file
            before being shown. Defaults to None.
    """
    if 0 in grid_size:
        return

    rows, cols = grid_size
    n = rows * cols

    plt.figure(figsize=(12, 12))
    for i, (image, label) in enumerate(ds.take(n)):
        # An eager tensor formats as "tf.Tensor(3, shape=(), dtype=int64)" inside
        # an f-string; unwrap it so the title reads "Label: 3".
        if hasattr(label, "numpy"):
            label = label.numpy()
        plt.subplot(rows, cols, i + 1)
        plt.title(f'Label: {label}')
        plt.axis('off')
        plt.imshow(image)
    if pdf_name:
        plt.savefig(pdf_name)
    plt.show()


def filter_fn(example, chosen_labels_ids):
    """
    Predicate for ``tf.data.Dataset.filter``: keep an example only when its
    label matches one of the chosen cifar10 label ids.

    Args:
        example (dict): Dictionary containing the image and the labels.
        chosen_labels_ids (list): List of the chosen labels ids in cifar10 dataset.

    Returns:
        bool: True if the example contains one of the chosen labels, False otherwise.
    """
    # Flatten the label into a column so it broadcasts against every candidate
    # id in one comparison: (num_label_entries, 1) vs (num_candidates,).
    label_column = tf.reshape(example["label"], [-1, 1])
    hit_mask = tf.equal(label_column, list(chosen_labels_ids))
    # The example is kept if any label entry equals any candidate id.
    return tf.reduce_any(hit_mask)

def preprocess_fn(example, chosen_labels_ids, image_size=(224, 224)):
    """
    Resize and normalize an image and remap its cifar10 label to a new id.

    The image is resized to ``image_size`` (224x224 by default), cast to
    float32 and divided by 255. The original cifar10 label id is replaced by
    its position in ``chosen_labels_ids``, so the new ids start from 0.

    Args:
        example (dict): Dictionary containing the image and the labels.
        chosen_labels_ids (list): List of the chosen labels ids in cifar10 dataset.
        image_size (tuple, optional): Target ``(height, width)`` for the resize.
            Defaults to (224, 224).

    Returns:
        tuple: Tuple containing the preprocessed image and the new label id.

    Note:
        A label that matches none of ``chosen_labels_ids`` produces an all-zero
        match vector, and its argmax is silently returned as the new id. Filter
        the dataset (e.g. with ``filter_fn``) before mapping this function.
    """
    image = tf.image.resize(example['image'], image_size)
    image = tf.cast(image, tf.float32) / 255.0

    labels = example['label']
    # One flag per chosen id; the position of the matching id becomes the new label.
    matches = [tf.reduce_any(tf.equal(labels, chosen_id)) for chosen_id in chosen_labels_ids]
    new_label = tf.argmax(tf.cast(matches, tf.int32))
    return image, new_label

def createModel(n_classes=10, input_shape=(224, 224, 3), learning_rate=0.0001):
    """
    Build and compile a classifier on top of a frozen ImageNet ResNet50V2.

    Architecture: ResNet50V2 (no top, frozen) -> GlobalAveragePooling2D ->
    Dense(100, relu) -> Dense(n_classes, softmax). The model is compiled with
    Adam, sparse categorical cross-entropy (integer labels) and accuracy.

    Args:
        n_classes (int, optional): Number of classes. Defaults to 10.
        input_shape (tuple, optional): Input image shape passed to ResNet50V2.
            Defaults to (224, 224, 3).
        learning_rate (float, optional): Adam learning rate. Defaults to 0.0001.

    Returns:
        tf.keras.Model: Compiled model with ResNet50V2 as base model and two dense layers on top.
    """
    # NOTE(review): Keras documents that ResNet50V2 inputs should go through
    # tf.keras.applications.resnet_v2.preprocess_input (pixels in [-1, 1]),
    # while preprocess_fn in this file produces [0, 1]. Confirm the frozen
    # backbone is meant to receive [0, 1] inputs.
    base_model = ResNet50V2(weights='imagenet', include_top=False, input_shape=input_shape)
    # Freeze the pretrained backbone so only the new head is trained.
    base_model.trainable = False

    model = models.Sequential()
    model.add(base_model)
    model.add(layers.GlobalAveragePooling2D())
    model.add(layers.Dense(100, activation='relu'))
    model.add(layers.Dense(n_classes, activation='softmax'))
    model.compile(
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model