In [2]:


import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense,Conv2D,MaxPooling2D,Flatten,Dropout,BatchNormalization
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix,f1_score,precision_score,recall_score
import seaborn as sns



# Handle to the MNIST dataset module bundled with Keras.
mnist = tf.keras.datasets.mnist



In [3]:




# load_data() returns two (images, labels) pairs: the training split and the test split.
train_split, test_split = mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split



Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [4]:


# Per-sample shape fed to the network: 28 x 28 pixels, one channel.
input_shape = (28, 28, 1)





In [5]:

# Reshaping for CNN: append an explicit channel axis so every sample is 28 x 28 x 1.
# The leading -1 lets NumPy infer the number of samples.
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)



In [6]:



# Converting to float32 and normalizing: cast first, then divide by 255.
# NOTE: this cell rebinds x_train / x_test, so re-running it without re-running
# the cells above divides the data by 255 again.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255




In [7]:


# Data Augmentation
# Random, label-preserving perturbations applied on the fly to each training batch:
#   rotation_range=10        -> rotations of up to 10 degrees
#   zoom_range=0.1           -> zoom in/out by up to 10%
#   width/height_shift=0.1   -> shifts of up to 10% of the image size
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=10,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
# No datagen.fit(x_train) call: fit() only computes the dataset statistics used by
# featurewise_center, featurewise_std_normalization and zca_whitening. None of those
# options is enabled here, so calling it would only copy the training array.




In [8]:

# Building the model with Batch Normalization
#
# Layer stack:
#   Input(28, 28, 1)
#   -> Conv2D(28 filters, 3x3 kernel)   (no `activation` argument is passed)
#   -> BatchNormalization
#   -> MaxPooling2D(2x2)
#   -> Flatten
#   -> Dense(200, relu) -> Dropout(0.3)
#   -> Dense(10, softmax)               (one probability per class)
#
# The input shape is declared with an explicit Input object rather than an
# `input_shape=` argument on the first layer; Keras warns about the latter
# for Sequential models.
model = Sequential([
    tf.keras.Input(shape=input_shape),
    Conv2D(28, kernel_size=(3, 3)),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(200, activation="relu"),
    Dropout(0.3),
    Dense(10, activation="softmax"),
])




  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [9]:


# Compiling with AdamW optimizer
# The sparse variant of categorical cross-entropy is used, so labels are passed
# as integer class ids rather than one-hot vectors.
adamw_optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)
model.compile(
    optimizer=adamw_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=['accuracy'],
)


In [None]:



# Training the model with data augmentation
# Augmented mini-batches of 32 samples are drawn from the training split.
train_batches = datagen.flow(x_train, y_train, batch_size=32)

# NOTE: the test split is what gets passed as validation_data here, so the
# per-epoch validation metrics are computed on the same data that is used
# for the final evaluation.
history = model.fit(
    train_batches,
    validation_data=(x_test, y_test),
    epochs=6,
    verbose=1,
)



Epoch 1/6


  self._warn_if_super_not_called()


[1m 290/1875[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m1:11[0m 45ms/step - accuracy: 0.5437 - loss: 1.4544

In [None]:


# Plotting Epochs vs. Loss and Accuracy
# Two side-by-side panels: loss on the left, accuracy on the right, each
# showing the training and validation curve recorded by model.fit().
curves = history.history
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(curves['loss'], label='Training Loss')
ax_loss.plot(curves['val_loss'], label='Validation Loss')
ax_loss.set_title('Epochs vs. Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_acc.plot(curves['accuracy'], label='Training Accuracy')
ax_acc.plot(curves['val_accuracy'], label='Validation Accuracy')
ax_acc.set_title('Epochs vs. Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend()

plt.show()



In [None]:


# Evaluation and Metrics Calculation
# evaluate() is unpacked into the loss and the accuracy on the test split;
# verbose=0 suppresses the progress bar.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(
    "\nModel Evaluation:",
    f"Test Loss: {test_loss:.4f}",
    f"Test Accuracy: {test_acc:.4f}",
    sep="\n",
)



In [None]:

# Predictions for F1 Score, Precision, Recall, and Confusion Matrix
# predict() yields one row of class probabilities per test image; the index of
# the largest entry in each row is taken as the predicted class.
y_pred_probs = model.predict(x_test)
y_pred_classes = y_pred_probs.argmax(axis=1)

# Calculating F1 Score, Precision, and Recall
# average='weighted' averages the per-class scores weighted by class support.
averaging = {'average': 'weighted'}
f1 = f1_score(y_test, y_pred_classes, **averaging)
precision = precision_score(y_test, y_pred_classes, **averaging)
recall = recall_score(y_test, y_pred_classes, **averaging)

print(f"\nF1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")


In [None]:

# Calculating Specificity and Sensitivity
# cm[i, j] counts samples whose true label is i and predicted label is j.
cm = confusion_matrix(y_test, y_pred_classes)

# One-vs-rest counts for each class, derived from the confusion matrix:
TP = np.diag(cm)              # diagonal: correctly predicted samples of the class
FP = cm.sum(axis=0) - TP      # column total minus diagonal
FN = cm.sum(axis=1) - TP      # row total minus diagonal
TN = cm.sum() - (FP + FN + TP)  # everything else

sensitivity = TP / (TP + FN)
specificity = TN / (TN + FP)

print("\nSensitivity (True Positive Rate) per class:")
print(sensitivity)
print("\nSpecificity (True Negative Rate) per class:")
print(specificity)

# Plotting Confusion Matrix
# Annotated heatmap of the raw counts: rows are true labels, columns are predictions.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
plt.show()