In [None]:
import ssl
import certifi

def create_certifi_https_context(*args, **kwargs):
    """Return a certificate-verifying SSL context that trusts certifi's CA bundle.

    Installed below as the process-wide default HTTPS context so that the
    dataset / pretrained-weight downloads can succeed on Python installs whose
    system CA store is unusable, while certificates are still verified.
    Do NOT swap this for ``ssl._create_unverified_context``: that turns off
    certificate and hostname checks for every HTTPS request in the process.
    """
    # certifi.where() returns the path of the bundled CA file; an explicit
    # caller-supplied cafile still wins.
    kwargs.setdefault("cafile", certifi.where())
    return ssl.create_default_context(*args, **kwargs)


ssl._create_default_https_context = create_certifi_https_context

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

In [None]:
# Fetch CIFAR-10 through Keras as (train, test) pairs of (images, labels),
# then expose the four arrays under the names the rest of the notebook uses.
train_split, test_split = cifar10.load_data()
X_train, y_train = train_split
X_test, y_test = test_split
del train_split, test_split  # keep only the four named arrays in the namespace

In [None]:
# The backbone uses Keras' ImageNet ResNet50 weights, which (per the Keras
# ResNet docs) expect inputs prepared by resnet50.preprocess_input: BGR channel
# order, ImageNet-mean-centred, 0-255 scale. Plain /255 scaling feeds the frozen
# pretrained layers out-of-distribution values.
# Casting to float32 first halves memory versus the float64 that /255.0 produced,
# and hands preprocess_input a fresh copy (it may modify float inputs in place).
X_train = preprocess_input(X_train.astype('float32'))
X_test = preprocess_input(X_test.astype('float32'))

In [None]:
# One-hot encode both label sets into 10 columns, as required by the
# categorical_crossentropy loss used when compiling the model.
y_train, y_test = [to_categorical(labels, 10) for labels in (y_train, y_test)]

In [None]:
# Hold out 10% of the training images as a validation set (it drives the
# LR-schedule / early-stopping callbacks during training). Stratifying on the
# one-hot labels keeps per-class proportions identical in both parts;
# random_state fixes the split for reproducibility.
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train, test_size=0.1, random_state=0, stratify=y_train
)

In [None]:
# ImageNet-pretrained ResNet50 convolutional base (classification head dropped
# via include_top=False), built for the training images' (H, W, C) shape.
# Freezing it means only the layers stacked on top are trained.
# NOTE(review): CIFAR-10 images are small, so the base's final feature map is
# presumably tiny (likely 1x1); upsampling inputs is a common way to get more
# out of the pretrained features — confirm whether that is wanted.
resnet = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=X_train.shape[1:],
)
resnet.trainable = False

In [None]:
# Frozen ResNet50 feature extractor followed by a small trainable classifier head.
model = Sequential([
    resnet,
    # Collapse the backbone's spatial feature maps into one vector per image.
    GlobalAveragePooling2D(),
    # Normalize the pooled features for a more stable head.
    BatchNormalization(),
    # Fully connected layer plus dropout for regularization.
    Dense(256, activation='relu'),
    Dropout(0.5),
    # Softmax over the 10 CIFAR-10 categories.
    Dense(10, activation='softmax'),
])

optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Callback patience must sit well below the epoch budget (20 in the training
# cell). With the previous patience=20, early stopping could never trigger. In
# older tf.keras versions, restore_best_weights then never restored anything,
# because restoration only happens when stopping actually fires. Halve the LR
# first; stop only after the LR cut has also failed to help.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=0.00001, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True, verbose=1)

In [None]:
# Train the classifier head (the ResNet50 base is frozen). The validation split
# steers the LR schedule and early stopping.
history = model.fit(
    X_train, y_train,
    validation_data=(X_valid, y_valid),
    epochs=20,
    batch_size=64,
    callbacks=[reduce_lr, early_stopping],
    verbose=2,
)

# Score once on the untouched test split. Validation metrics alone are
# optimistic because they were used to steer the callbacks above.
test_loss, test_accuracy = model.evaluate(X_test, y_test, batch_size=64, verbose=0)
print(f'Test loss: {test_loss:.4f} - test accuracy: {test_accuracy:.4f}')


In [None]:
# With Jupyter's default inline backend, a figure is closed at the end of the
# cell that created it. A bare plt.figure(...) here would therefore never reach
# the plotting cell, and it would only emit an empty-figure repr. Setting the
# default size instead applies to the next figure pyplot creates.
plt.rcParams['figure.figsize'] = (15, 6)

In [None]:
# Side-by-side training vs. validation curves per epoch: loss (left) and
# accuracy (right). The figure is created here so the cell does not depend on
# figure state left over from another cell.
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

for ax, (metric, name) in zip(axes, [('loss', 'Loss'), ('accuracy', 'Accuracy')]):
    ax.plot(history.history[metric], label=f'Train {name}', color='#8502d1')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {name}', color='darkorange')
    ax.set(title=f'{name} Evolution', xlabel='Epoch', ylabel=name)
    ax.legend()

fig.tight_layout()
plt.show()