In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten, Lambda, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNetV2
from sklearn.datasets import fetch_lfw_people
from sklearn.model_selection import train_test_split
import numpy as np
import cv2
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler

# Load the LFW people subset: keep only identities with at least 70 pictures,
# with each face image scaled by a ratio of 0.4.
lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)
n_samples, h, w = lfw_people.images.shape
X = lfw_people.data
y = lfw_people.target
target_names = lfw_people.target_names
n_classes = target_names.shape[0]

# Fixed seed so the train/validation/test partition is reproducible.
SPLIT_SEED = 42

# Resize every face to 96x96 and replicate the single grayscale channel three
# times, because the MobileNetV2 backbone expects 3-channel input.
# lfw_people.data is flattened to (n_samples, h * w); cv2.resize needs the
# 2-D (h, w) pictures, which live in lfw_people.images.
X_images = np.stack([
    np.repeat(cv2.resize(face, (96, 96))[..., np.newaxis], 3, axis=-1)
    for face in lfw_people.images
]).astype(np.float32)

# MobileNetV2's ImageNet weights were trained on pixels in [-1, 1]. Min-max
# scaling maps whatever float range fetch_lfw_people returns onto that range.
pixel_min, pixel_max = X_images.min(), X_images.max()
if pixel_max > pixel_min:
    X_images = (X_images - pixel_min) / (pixel_max - pixel_min) * 2.0 - 1.0

# 70% train, 15% validation, 15% test. Stratifying keeps each person's share
# of samples the same in every split.
X_train, X_test, y_train, y_test = train_test_split(
    X_images, y, test_size=0.3, stratify=y, random_state=SPLIT_SEED)
X_val, X_test, y_val, y_test = train_test_split(
    X_test, y_test, test_size=0.5, stratify=y_test, random_state=SPLIT_SEED)

# One-hot encode the targets.
y_train = to_categorical(y_train, num_classes=n_classes)
y_val = to_categorical(y_val, num_classes=n_classes)
y_test = to_categorical(y_test, num_classes=n_classes)

def am_softmax_loss(margin=0.35, scale=30.0):
    """Build an additive-margin softmax (AM-softmax) loss for Keras.

    The returned ``loss(y_true, y_pred)`` subtracts ``margin`` from the score
    of the true class, multiplies all scores by ``scale`` and applies softmax
    cross-entropy to the result.

    Args:
        margin: Additive margin subtracted from the true-class score.
        scale: Factor applied to the margin-adjusted scores before softmax.

    Returns:
        A Keras-compatible loss function returning one value per sample.

    Notes:
        ``y_pred`` is meant to hold cosine similarities in [-1, 1]; the margin
        only has its intended geometric meaning in that case.
        ``y_true`` may be one-hot (same shape as ``y_pred``), which is what
        ``to_categorical`` produces, or integer class ids of shape
        ``(batch,)`` / ``(batch, 1)``.
    """
    def loss(y_true, y_pred):
        y_pred = tf.cast(y_pred, tf.float32)
        y_true = tf.convert_to_tensor(y_true)

        # Only re-encode genuinely sparse labels. One-hot labels must be used
        # as-is: calling tf.one_hot on them would add a third dimension.
        true_rank = y_true.shape.rank
        pred_rank = y_pred.shape.rank
        is_sparse = (
            true_rank is not None
            and pred_rank is not None
            and (true_rank < pred_rank
                 or (y_true.shape[-1] == 1 and y_pred.shape[-1] != 1))
        )
        if is_sparse:
            class_ids = tf.reshape(tf.cast(y_true, tf.int32), [-1])
            y_true = tf.one_hot(class_ids, depth=tf.shape(y_pred)[-1])
        else:
            y_true = tf.cast(y_true, tf.float32)

        # Equivalent to y_true * (y_pred - margin) + (1 - y_true) * y_pred,
        # i.e. the margin is subtracted only from the true-class score.
        logits = scale * (y_pred - margin * y_true)
        return tf.keras.losses.categorical_crossentropy(
            y_true, logits, from_logits=True)
    return loss

# Classifier on top of an ImageNet-pretrained MobileNetV2 backbone.
# Passing input_shape makes Keras load the weights trained at 96x96 instead
# of falling back to the 224x224 ones.
base_model = MobileNetV2(weights='imagenet', include_top=False,
                         input_shape=(96, 96, 3))
# Freeze the backbone; only the head below is trained.
base_model.trainable = False

inputs = Input(shape=(96, 96, 3))
# training=False keeps the backbone's BatchNormalization layers in inference
# mode, which remains correct if the backbone is later unfrozen for fine-tuning.
x = base_model(inputs, training=False)
x = Flatten()(x)
x = BatchNormalization()(x)
# AM-softmax expects cosine-similarity scores. L2-normalise the embedding, then
# apply a bias-free Dense layer whose per-class weight columns (axis 0 of the
# (features, n_classes) kernel) are projected back to unit norm after every
# update. Each output is then the cosine between the embedding and one class
# weight vector.
x = Lambda(lambda t: tf.math.l2_normalize(t, axis=1))(x)
outputs = Dense(
    n_classes,
    use_bias=False,
    kernel_constraint=tf.keras.constraints.UnitNorm(axis=0),
)(x)

model = Model(inputs, outputs)

# Compile with Adam and an exponentially decaying learning rate: the rate is
# multiplied by 0.96 once per epoch. A fixed decay_steps=100000 with
# staircase=True would need more than 5,000 steps per epoch (160,000 images
# at batch size 32) to decay even once in 20 epochs, so it was effectively a
# constant learning rate.
initial_learning_rate = 0.01
train_batch_size = 32  # ImageDataGenerator.flow's default batch size
steps_per_epoch = max(1, int(np.ceil(len(X_train) / train_batch_size)))

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=steps_per_epoch,
    decay_rate=0.96,
    staircase=True)

model.compile(optimizer=Adam(learning_rate=lr_schedule),
              loss=am_softmax_loss(margin=0.35, scale=30.0),
              metrics=['accuracy'])

# Random rotations, shifts and horizontal flips for training-time augmentation.
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# Stop after 3 epochs without validation-loss improvement, and checkpoint the
# best epoch. Only the weights are restored below, so only the weights are
# saved; this avoids serialising the Lambda layer and the closure-based
# custom loss. The ".weights.h5" suffix is what Keras 3 requires for
# weights-only checkpoints, and Keras 2 also accepts it.
checkpoint_path = "best_model.weights.h5"
early_stopping_cb = EarlyStopping(monitor='val_loss', patience=3)
checkpoint_cb = ModelCheckpoint(checkpoint_path, monitor='val_loss',
                                save_best_only=True, save_weights_only=True)

model.fit(datagen.flow(X_train, y_train, batch_size=32), epochs=20,
          validation_data=(X_val, y_val),
          callbacks=[early_stopping_cb, checkpoint_cb])

# Evaluate on the held-out test set with the best checkpointed weights.
# evaluate() returns [loss, accuracy] given the single compiled metric.
model.load_weights(checkpoint_path)
results = model.evaluate(X_test, y_test)
accuracy = results[1]

print(f"Accuracy on test set: {accuracy}")

Exception ignored in: <function _xla_gc_callback at 0x793a7e8d5900>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 97, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50

KeyboardInterrupt: ignored