Few-shot learning to classify the CIFAR-100 dataset

In [1]:
#import library
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Seed NumPy and TensorFlow up front so that a Restart & Run All draws the same
# random numbers (episode sampling, weight initialisation) on every run.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load and preprocess CIFAR-100 dataset.
# Cast to float32 *before* scaling: dividing the raw integer images by 255.0
# directly would promote them to float64 and double the memory footprint.
(x_train, y_train), (x_test, y_test) = cifar100.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Define data augmentation: random rotations (up to 20 degrees), horizontal and
# vertical shifts (up to 20% of the image size) and horizontal flips.
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
)

# Create Few-Shot Learning support and query sets
def create_few_shot_sets(x, y, n_classes, n_support, n_query, rng=None):
    """Sample one few-shot episode (support set + query set) from a labelled dataset.

    ``n_classes`` distinct classes are drawn at random; for each of them
    ``n_support + n_query`` distinct samples are drawn and split into a support
    part and a query part. The original class ids are replaced by episode-local
    labels ``0 .. n_classes - 1`` (in the order the classes were drawn).

    Args:
        x: Array of samples, indexed along its first axis.
        y: Labels for ``x``; one label per sample, either shape ``(N,)`` or
            ``(N, 1)`` (the layout returned by ``cifar100.load_data()``).
        n_classes: Number of classes in the episode.
        n_support: Number of support samples per class.
        n_query: Number of query samples per class.
        rng: Optional random source exposing ``choice`` (for example
            ``np.random.default_rng(seed)``). When ``None`` the global
            ``np.random`` state is used, exactly as before this parameter existed.

    Returns:
        Tuple ``(support_set, support_labels, query_set, query_labels)``. The
        sets are grouped by class, and the label arrays hold the matching
        episode-local class index for every row.

    Raises:
        ValueError: If the labels do not match ``x`` in length, if any count is
            out of range, or if a drawn class has fewer than
            ``n_support + n_query`` samples.
    """
    # Flatten so that (N,) and (N, 1) label arrays behave identically.
    labels = np.asarray(y).reshape(-1)
    if labels.shape[0] != len(x):
        raise ValueError(
            f"x and y disagree on the number of samples: {len(x)} vs {labels.shape[0]}"
        )
    if n_classes < 1:
        raise ValueError("n_classes must be at least 1")
    if n_support < 0 or n_query < 0:
        raise ValueError("n_support and n_query must be non-negative")

    available = np.unique(labels)
    if n_classes > available.size:
        raise ValueError(
            f"requested {n_classes} classes but only {available.size} are present"
        )

    # Fall back to the module-level RNG so seeding via np.random.seed still works.
    sampler = np.random if rng is None else rng
    per_class = n_support + n_query

    classes = sampler.choice(available, n_classes, replace=False)
    support_idx = []
    query_idx = []
    for cls in classes:
        cls_indices = np.flatnonzero(labels == cls)
        if cls_indices.size < per_class:
            raise ValueError(
                f"class {cls!r} has {cls_indices.size} samples, "
                f"but {per_class} are needed (n_support + n_query)"
            )
        # Draw without replacement so support and query never share a sample.
        picked = sampler.choice(cls_indices, per_class, replace=False)
        support_idx.append(picked[:n_support])
        query_idx.append(picked[n_support:])

    support_idx = np.concatenate(support_idx, axis=0)
    query_idx = np.concatenate(query_idx, axis=0)

    # Episode-local labels: class k occupies a contiguous run of rows.
    episode_ids = np.arange(n_classes)
    support_labels = np.repeat(episode_ids, n_support)
    query_labels = np.repeat(episode_ids, n_query)

    return x[support_idx], support_labels, x[query_idx], query_labels

# Few-Shot Learning hyperparameters:
# classes per episode, support samples per class, query samples per class
n_classes, n_support, n_query = 5, 5, 15

# Create Few-Shot Learning sets by sampling one episode from the training split
x_support, y_support, x_query, y_query = create_few_shot_sets(
    x=x_train,
    y=y_train,
    n_classes=n_classes,
    n_support=n_support,
    n_query=n_query,
)

# Define a custom CNN model with pooling and fully connected layers
def create_custom_cnn_model(input_shape, num_classes):
    """Build an uncompiled Sequential CNN classifier.

    The network stacks three 3x3 ReLU convolutions (32, 64 and 128 filters),
    each followed by 2x2 max pooling, then flattens into a 128-unit ReLU
    dense layer and a softmax output layer.

    Args:
        input_shape: Shape of a single input sample, passed to the first
            convolution as ``input_shape``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The assembled ``Sequential`` model; compiling it is left to the caller.
    """
    net = models.Sequential()

    # The first convolution also declares the input shape for the whole model.
    net.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    net.add(layers.MaxPooling2D((2, 2)))

    # Remaining conv/pool stages differ only in their filter count.
    for n_filters in (64, 128):
        net.add(layers.Conv2D(n_filters, (3, 3), activation='relu'))
        net.add(layers.MaxPooling2D((2, 2)))

    # Classifier head
    net.add(layers.Flatten())
    net.add(layers.Dense(128, activation='relu'))
    net.add(layers.Dense(num_classes, activation='softmax'))
    return net

# Create and compile the CNN model: one sample's shape in, one output unit per episode class
input_shape = x_train.shape[1:]
num_classes = n_classes
model = create_custom_cnn_model(input_shape=input_shape, num_classes=num_classes)
model.compile(
    loss='sparse_categorical_crossentropy',  # labels are integer class indices
    optimizer='adam',
    metrics=['accuracy'],
)

# Train the model with data augmentation:
# augmented batches are drawn from the support set, the query set is used for validation
validation_data = (x_query, y_query)
train_generator = datagen.flow(x_support, y_support, batch_size=16)

model.fit(train_generator, validation_data=validation_data, epochs=20)

# Evaluate the model on query set
loss, accuracy = model.evaluate(*validation_data)
print(f'Few-Shot Learning Accuracy: {accuracy:.4f}')


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Few-Shot Learning Accuracy: 0.4533
