
# Final layer updation of a pre-trained model, if new classes come in the target dataset



In [17]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
import numpy as np
from tensorflow.keras.datasets import mnist

# Define the CNN model with a feature extraction layer
class CNNFeatureExtractor(Model):
    def __init__(self, num_classes=8):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = layers.Conv2D(32, kernel_size=(3, 3), activation='relu')
        self.pool1 = layers.MaxPooling2D(pool_size=(2, 2))
        self.conv2 = layers.Conv2D(64, kernel_size=(3, 3), activation='relu')
        self.pool2 = layers.MaxPooling2D(pool_size=(2, 2))
        self.flatten = layers.Flatten()
        self.feature_layer = layers.Dense(128, activation='relu', name="feature_layer")
        self.output_layer = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        features = self.feature_layer(x)
        if training:
            return self.output_layer(features)
        return features


In [18]:
# Load and preprocess MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255

# Split data into training on classes 0-7 and few-shot adaptation on classes 8-9
train_idx_0_7 = np.where(y_train < 8)[0]
train_idx_8_9 = np.where(y_train >= 8)[0]
x_train_0_7, y_train_0_7 = x_train[train_idx_0_7], y_train[train_idx_0_7]
x_train_8_9, y_train_8_9 = x_train[train_idx_8_9], y_train[train_idx_8_9] - 8  # Re-label classes 8,9 to 0,1

# Train the CNN on classes 0-7
cnn_model = CNNFeatureExtractor(num_classes=8)
cnn_model.compile(optimizer=Adam(), loss="sparse_categorical_crossentropy", metrics=["accuracy"])
cnn_model.fit(x_train_0_7, y_train_0_7, epochs=5, batch_size=64, validation_split=0.1)


Epoch 1/5
[1m678/678[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 7ms/step - accuracy: 0.8934 - loss: 0.3409 - val_accuracy: 0.0000e+00 - val_loss: 11.9750
Epoch 2/5
[1m678/678[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.9901 - loss: 0.0332 - val_accuracy: 0.0000e+00 - val_loss: 13.0834
Epoch 3/5
[1m678/678[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - accuracy: 0.9916 - loss: 0.0261 - val_accuracy: 0.0000e+00 - val_loss: 14.3762
Epoch 4/5
[1m678/678[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.9946 - loss: 0.0172 - val_accuracy: 0.0000e+00 - val_loss: 15.5040
Epoch 5/5
[1m678/678[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.9961 - loss: 0.0124 - val_accuracy: 2.0747e-04 - val_loss: 14.9043


<keras.src.callbacks.history.History at 0x7ad8353f0df0>

In [19]:

# Prototypical Networks for Few-shot Learning on Classes 8-9
def compute_prototypes(features, labels):
    prototypes = []
    for label in np.unique(labels):
        class_features = features[labels == label]
        prototype = np.mean(class_features, axis=0)
        prototypes.append(prototype)
    return np.stack(prototypes)

# Split few-shot data into support and query sets
support_idx = np.random.choice(len(x_train_8_9), size=5, replace=False)
query_idx = np.setdiff1d(np.arange(len(x_train_8_9)), support_idx)

x_support, y_support = x_train_8_9[support_idx], y_train_8_9[support_idx]
x_query, y_query = x_train_8_9[query_idx], y_train_8_9[query_idx]

# Extract support and query features using the trained CNN model
support_features = cnn_model(x_support, training=False).numpy()
query_features = cnn_model(x_query, training=False).numpy()
prototypes = compute_prototypes(support_features, y_support)

# Few-shot evaluation function
def few_shot_accuracy(query_features, prototypes, y_query):
    dists = np.linalg.norm(query_features[:, np.newaxis] - prototypes, axis=2)
    preds = dists.argmin(axis=1)
    return np.mean(preds == y_query)

# Evaluate few-shot learning on classes 8-9
accuracy_8_9 = few_shot_accuracy(query_features, prototypes, y_query)
print(f"Few-shot Accuracy on classes 8-9: {accuracy_8_9:.4f}")

# Optionally, evaluate on classes 0-7 and full set
test_idx_0_7 = np.where(y_test < 8)[0]
test_loader_0_7 = x_test[test_idx_0_7]
y_test_0_7 = y_test[test_idx_0_7]

# Evaluate accuracy on classes 0-7
test_features_0_7 = cnn_model(test_loader_0_7, training=False).numpy()
accuracy_0_7 = np.mean(np.argmax(test_features_0_7, axis=1) == y_test_0_7)
print(f"Accuracy on classes 0-7: {accuracy_0_7:.4f}")

Few-shot Accuracy on classes 8-9: 0.9529
Accuracy on classes 0-7: 0.0001
