In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

def load_dataset(dataset_name, split, shuffle_files=False):
    """Load one split of a TFDS dataset as ``(image, label)`` pairs.

    Args:
        dataset_name: TFDS registry name, e.g. ``'caltech101'``.
        split: Split to load, e.g. ``'train'`` or ``'test'``.
        shuffle_files: Whether TFDS should shuffle the order of its input
            files. Defaults to False because this notebook shuffles the pooled
            examples itself (``shuffle_dataset``) and then carves train/test
            out of the result by position (``take``/``skip``). If the file
            order could change between iterations, the examples landing on
            each side of that positional split could change too, leaking test
            data into training.

    Returns:
        ``(ds, ds_info)``: the ``tf.data.Dataset`` of ``(image, label)``
        tuples (``as_supervised=True``) and the dataset's metadata object
        (used elsewhere for the label count and class names).
    """
    ds, ds_info = tfds.load(
        dataset_name,
        split=split,
        shuffle_files=shuffle_files,
        with_info=True,
        as_supervised=True,
    )
    return ds, ds_info

def shuffle_dataset(ds, seed=None):
    """Return ``ds`` in a random order that stays FIXED across iterations.

    The shuffle buffer spans the whole dataset, so the permutation is uniform,
    but every element is held in the buffer (in memory) while it fills.

    ``reshuffle_each_iteration=False`` is essential here. The result is later
    split by position with ``take``/``skip``. With the default (True), every
    epoch would draw a fresh permutation, so "test" examples would rotate into
    the training half and validation accuracy would be inflated.

    Because the order is frozen, a split taken from the result is also visited
    in the same order every epoch. Shuffle the training half separately if you
    want per-epoch reshuffling.

    Args:
        ds: A finite ``tf.data.Dataset`` whose cardinality is known.
        seed: Optional int seed, making the permutation (and therefore the
            train/test split) reproducible across kernel restarts.

    Returns:
        The shuffled ``tf.data.Dataset``.

    Raises:
        ValueError: If the cardinality of ``ds`` is unknown or infinite.
    """
    num_elements = int(tf.data.experimental.cardinality(ds).numpy())
    if num_elements < 0:
        # cardinality() reports INFINITE (-1) or UNKNOWN (-2) as negatives;
        # passing those on as a buffer size would fail or be meaningless.
        raise ValueError(
            "shuffle_dataset needs a finite dataset with known cardinality, "
            f"got cardinality {num_elements}."
        )
    # shuffle() rejects a zero-sized buffer, so an empty dataset still gets 1.
    buffer_size = max(num_elements, 1)
    return ds.shuffle(buffer_size, seed=seed, reshuffle_each_iteration=False)

def split_dataset(ds, train_ratio=0.9):
    """Split ``ds`` by position into a leading train part and a trailing test part.

    The split uses ``take``/``skip``, so it is only a clean partition if ``ds``
    yields its elements in the same order on every iteration. Otherwise
    examples can migrate between the two halves from one epoch to the next.

    Args:
        ds: A finite ``tf.data.Dataset`` with known cardinality.
        train_ratio: Fraction in [0, 1] of examples assigned to the train
            part; the count is truncated with ``int()``.

    Returns:
        ``(train_ds, test_ds)``.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1], or the cardinality
            of ``ds`` is unknown or infinite.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}.")
    num_examples = int(tf.data.experimental.cardinality(ds).numpy())
    if num_examples < 0:
        # Unknown (-2) or infinite (-1) cardinality would make num_train
        # negative. take(-n) keeps everything and skip(-n) drops everything,
        # which would silently produce an empty test set.
        raise ValueError(
            "split_dataset needs a finite dataset with known cardinality, "
            f"got cardinality {num_examples}."
        )
    num_train = int(train_ratio * num_examples)
    return ds.take(num_train), ds.skip(num_train)

def preprocess(image, label):
    """Turn one ``(image, label)`` example into model-ready input.

    Pixels are cast to float32 and mapped through ``x / 127.5 - 1`` (which
    sends 8-bit values in [0, 255] to [-1, 1], assuming 8-bit input). The
    result is then resized to 224x224. The label is passed through untouched.
    """
    scaled = tf.cast(image, tf.float32) / 127.5 - 1
    resized = tf.image.resize(scaled, (224, 224))
    return resized, label

def prepare_dataset(ds, batch_size=32):
    """Apply ``preprocess`` to every example and group the results into batches.

    ``map`` runs with parallel calls, and the pipeline ends with ``prefetch``
    so that input preparation overlaps with model execution instead of
    stalling each training step. Parallel ``map`` keeps the element order
    (tf.data is deterministic by default), so the batches are the same as a
    serial pipeline would produce.

    Args:
        ds: ``tf.data.Dataset`` of ``(image, label)`` pairs.
        batch_size: Number of examples per batch; the last batch may be smaller.

    Returns:
        A batched, prefetched ``tf.data.Dataset``.
    """
    autotune = tf.data.experimental.AUTOTUNE
    return (
        ds.map(preprocess, num_parallel_calls=autotune)
        .batch(batch_size)
        .prefetch(autotune)
    )

# Load both Caltech-101 splits; the metadata (label count, class names) is kept from the first call.
train_ds, ds_info = load_dataset('caltech101', 'train')
test_ds, _ = load_dataset('caltech101', 'test')

# Pool every example, shuffle the pool, then re-split it with split_dataset's default 90/10 ratio.
all_ds = shuffle_dataset(train_ds.concatenate(test_ds))
train_ds, test_ds = split_dataset(all_ds)

# Preprocess and batch both halves for Keras.
train_ds, test_ds = prepare_dataset(train_ds), prepare_dataset(test_ds)

In [None]:
# Define the model: a plain multilayer perceptron over the flattened image.
# Declaring the input shape up front builds the model (and its weights)
# immediately. model.summary() then works before training, and a mismatch with
# the data pipeline surfaces here instead of inside fit().
model = tf.keras.models.Sequential([
    # Must match the (224, 224) resize in preprocess(); 3 channels assumes RGB
    # images (true for TFDS caltech101 — re-check if the dataset changes).
    tf.keras.Input(shape=(224, 224, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu', kernel_initializer="he_normal"),
    tf.keras.layers.Dense(256, activation='relu', kernel_initializer="he_normal"),
    tf.keras.layers.Dense(128, activation='relu', kernel_initializer="he_normal"),
    # One output unit per class in the dataset.
    tf.keras.layers.Dense(ds_info.features['label'].num_classes, activation='softmax'),
])

# Compile. The output layer already applies softmax, so the loss receives
# probabilities (SparseCategoricalCrossentropy's default from_logits=False)
# together with integer class labels.
model.compile(
    optimizer=tf.keras.optimizers.SGD(momentum=0.9, learning_rate=0.001, nesterov=True),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

In [None]:
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# When val_accuracy has not improved for 2 epochs, multiply the learning rate
# by 0.2 (new_lr = lr * factor), never going below 1e-4.
reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.2,
    patience=2,
    min_lr=0.0001,
)

# Stop training after 5 epochs without a val_accuracy improvement, restoring
# the weights from the best-scoring epoch.
early_stopping = EarlyStopping(
    monitor="val_accuracy",
    patience=5,
    restore_best_weights=True,
)

# Passed to model.fit(); early stopping first, then the LR schedule.
callbacks = [early_stopping, reduce_lr]

In [None]:
# Train the model. batch_size is deliberately NOT passed: train_ds and test_ds
# are tf.data datasets already batched by prepare_dataset(), and Keras
# documents that batch_size must not be specified for dataset inputs
# (the batching comes from the dataset itself).
# The returned History is kept so per-epoch loss/accuracy curves can be inspected later.
history = model.fit(train_ds, epochs=20, validation_data=test_ds, callbacks=callbacks)

In [None]:
# Print the layer-by-layer architecture: output shapes and parameter counts.
model.summary()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def get_random_batches(test_ds, num_batches=4):
    """Draw ``num_batches`` distinct batch indices uniformly at random.

    Args:
        test_ds: A sized collection of batches; ``len()`` must work on it
            (for a ``tf.data.Dataset`` that requires a known, finite cardinality).
        num_batches: How many distinct indices to draw.

    Returns:
        NumPy integer array of unique indices in ``[0, len(test_ds))``.

    Raises:
        ValueError: Raised by NumPy when ``num_batches`` exceeds ``len(test_ds)``.
    """
    total = len(test_ds)
    picks = np.random.choice(total, size=num_batches, replace=False)
    return picks

def get_random_image_and_labels(test_ds, batch):
    """Return one randomly chosen ``(image, label)`` pair from batch number ``batch``.

    The previous ``list(test_ds)[batch]`` decoded every batch of the test set
    and held all of them in memory on every call. This version only iterates
    ``test_ds`` up to the requested batch and stops there.

    Args:
        test_ds: An iterable of ``(images, labels)`` batches, e.g. a batched
            ``tf.data.Dataset`` or a list of ``(array, array)`` pairs.
        batch: Zero-based index of the batch to draw from (Python or NumPy int).
            Negative values count from the end, as list indexing did.

    Returns:
        ``(image, label)`` at one random row of that batch.

    Raises:
        IndexError: If ``test_ds`` has no batch at index ``batch``.
    """
    import itertools

    index = int(batch)  # NumPy ints from get_random_batches -> plain int
    if index < 0:
        # Counting from the end requires the full length, so materialise only
        # in this (previously supported) case.
        images, labels = list(test_ds)[index]
    else:
        try:
            images, labels = next(itertools.islice(test_ds, index, None))
        except StopIteration:
            raise IndexError(f"batch index {index} is out of range") from None
    row = np.random.choice(len(images), size=1)[0]
    return images[row], labels[row]

def predict_labels(model, images):
    """Return the highest-scoring class index for each input in ``images``.

    Args:
        model: Object with a Keras-style ``predict`` that returns per-class
            scores along the last axis.
        images: A batch of inputs accepted by ``model.predict``.

    Returns:
        NumPy array of argmax indices over the last axis of the scores.
    """
    scores = model.predict(images)
    best = np.argmax(scores, axis=-1)
    return best

def plot_image(ax, image, true_label, pred_label):
    """Draw one preprocessed image on ``ax``, titled with its true and predicted labels.

    Args:
        ax: Matplotlib ``Axes`` to draw on.
        image: One image scaled to [-1, 1] (as produced by ``preprocess``),
            either a TF eager tensor or a NumPy array, with shape (H, W, C)
            or (H, W).
        true_label: Ground-truth label shown in the title.
        pred_label: Predicted label shown in the title. The title is green
            when it equals ``true_label`` and red otherwise.
    """
    ax.set_axis_off()
    # Undo preprocess()'s [-1, 1] scaling. Clipping guards against float
    # round-off just outside [0, 1], which imshow would otherwise clip with a warning.
    pixels = np.clip((np.asarray(image) + 1) / 2, 0.0, 1.0)
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        # imshow rejects (H, W, 1); show single-channel images as 2-D grayscale.
        pixels = pixels[..., 0]
    # The colormap only affects 2-D (single-channel) data and is ignored for RGB.
    # Plain 'gray' is used because the old 'gray_r' would show a photographic negative.
    ax.imshow(pixels, cmap='gray', interpolation='nearest')
    title = f'True: {true_label}, Pred: {pred_label}'
    ax.set_title(title, color='green' if true_label == pred_label else 'red')

def plot_images(model, test_ds, class_names, num_batches=4):
    """Plot one random image from each of ``num_batches`` random test batches.

    Each panel is titled with the true and predicted class names (green if
    they match, red otherwise).

    Args:
        model: Trained model; its ``predict`` output is argmax-ed into class indices.
        test_ds: Sized iterable of ``(images, labels)`` batches (e.g. the
            batched test ``tf.data.Dataset``).
        class_names: Sequence mapping integer class index -> display name.
        num_batches: Number of panels (one image per distinct random batch).
            Must be between 1 and ``len(test_ds)``.
    """
    batches = get_random_batches(test_ds, num_batches=num_batches)
    # squeeze=False always yields a 2-D array of Axes, so indexing works even
    # for a single panel (where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(nrows=1, ncols=len(batches), figsize=(20, 6), squeeze=False)
    axes = axes.ravel()
    for ax, batch in zip(axes, batches):
        image, true_idx = get_random_image_and_labels(test_ds, batch)
        # Add a leading batch axis for predict() and take the single prediction back out.
        pred_idx = predict_labels(model, image[None, ...])[0]
        # int() accepts both TF scalar tensors and NumPy integers (.numpy() did not).
        plot_image(ax, image, class_names[int(true_idx)], class_names[int(pred_idx)])
    plt.setp(axes, xticks=[], xticklabels=[], yticklabels=[])
    fig.tight_layout()
    plt.show()

# Human-readable class names from the dataset metadata, indexed by integer label.
class_names = ds_info.features['label'].names

# Show random test images with their true and predicted class names.
plot_images(model, test_ds, class_names=class_names)