In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.applications.xception import Xception, preprocess_input
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
from tensorflow.keras.layers import BatchNormalization, Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
warnings.filterwarnings( 'ignore' )

In [None]:
# Data locations (Google Drive mount in Colab).
train_dir = '/content/drive/MyDrive/BACHAugment'
test_dir = '/content/drive/MyDrive/BACHtest'

# Define the image size and batch size
image_size = (224, 224)
batch_size = 32

# np.random.seed() returns None, so the seed value is kept in its own variable
# and passed to the generators explicitly; TF's global seed is set as well.
random_seed = 1142
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# Xception's ImageNet weights expect inputs scaled by xception.preprocess_input
# (pixels mapped to [-1, 1]). Train, validation and test all use the same function
# so the model sees identically scaled data at train and test time.
#
# featurewise_center / featurewise_std_normalization are intentionally not used:
# they only take effect after train_datagen.fit(...) on in-memory data, which a
# flow_from_directory pipeline never calls (Keras otherwise skips them with a
# warning that warnings.filterwarnings('ignore') would hide).
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=0.2)

# NOTE(review): validation_split takes the validation subset out of train_dir
# itself. If that directory holds augmented variants of the same source images,
# variants of one image may land on both sides of the split and inflate the
# validation scores — confirm how BACHAugment was built.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    batch_size=batch_size,
    seed=random_seed,
    # flow_from_directory lists files class folder by class folder; without
    # shuffling, most training batches would contain a single class.
    shuffle=True,
    subset='training',
    class_mode='categorical')

# Validation order is fixed so per-sample results are reproducible.
val_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    batch_size=batch_size,
    seed=random_seed,
    shuffle=False,
    subset='validation',
    class_mode='categorical')

# Held-out test data: same preprocessing as training, no split, and
# shuffle=False so predictions line up with test_generator.classes.
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=image_size,
    batch_size=batch_size,
    seed=random_seed,
    shuffle=False,
    class_mode='categorical')


Found 3368 images belonging to 4 classes.
Found 840 images belonging to 4 classes.
Found 103 images belonging to 4 classes.


In [None]:
# base_model = Xception(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# x = base_model.output
# x = GlobalAveragePooling2D()(x)
# x = Dense(512, activation='relu')(x)
# x = Dropout(0.2)(x)
# output = Dense(4, activation='softmax')(x)

# model = Model(inputs=base_model.input, outputs=output)


# optimizer = Adam(learning_rate=1e-2)
# # Compile the model
# model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# model.summary()

def create_model(num_classes=4, learning_rate=0.001, input_shape=None):
    """Build and compile the Xception-based image classifier.

    An Xception backbone (``include_top=False``, ImageNet weights as the Keras
    default) is followed by a small convolutional head and a softmax output.
    No layer is frozen, so the whole network, backbone included, is trained.

    Keras recommends feeding Xception inputs through
    ``tf.keras.applications.xception.preprocess_input`` when using the
    ImageNet weights.

    Args:
        num_classes: Number of output classes. The default of 4 matches the
            "4 classes" reported by the generators.
        learning_rate: Learning rate for the SGD optimizer.
        input_shape: ``(height, width, channels)`` of the input images. When
            omitted, it is ``(image_size[0], image_size[1], 3)`` using the
            notebook-level ``image_size``.

    Returns:
        A compiled ``Sequential`` model using categorical cross-entropy and
        tracking accuracy, precision, recall and AUC.
    """
    if input_shape is None:
        input_shape = (image_size[0], image_size[1], 3)

    # Metric objects are stateful, so a fresh set is created per model.
    metrics = [
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='AUC'),
    ]

    model = Sequential([
        Xception(input_shape=input_shape, include_top=False),
        # With 224x224 inputs the backbone emits 7x7x2048 feature maps; this
        # 'valid' 3x3 conv reduces them to 5x5 and the pooling below to 2x2.
        Conv2D(32, (3, 3), activation="relu"),
        BatchNormalization(),
        MaxPooling2D(2, 2),
        Dropout(0.3),
        Flatten(),
        Dense(64, activation="relu"),
        Dense(num_classes, activation="softmax"),
    ])
    model.compile(optimizer=SGD(learning_rate=learning_rate),
                  loss="categorical_crossentropy",
                  metrics=metrics)
    return model


model = create_model()
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 xception (Functional)       (None, 7, 7, 2048)        20861480  
                                                                 
 conv2d_14 (Conv2D)          (None, 5, 5, 32)          589856    
                                                                 
 batch_normalization_14 (Bat  (None, 5, 5, 32)         128       
 chNormalization)                                                
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 2, 2, 32)         0         
 2D)                                                             
                                                                 
 dropout_2 (Dropout)         (None, 2, 2, 32)          0         
                                                                 
 flatten_2 (Flatten)         (None, 128)              

In [None]:
# Training configuration.
RESULTS_DIR = '/content/drive/MyDrive/BACH-Results/XceptionNet'
MAX_EPOCHS = 100
EARLY_STOP_PATIENCE = 10  # epochs without val_loss improvement before stopping

# CSVLogger opens its file when training begins and fails if the directory is
# missing, so create it up front.
os.makedirs(RESULTS_DIR, exist_ok=True)
csv_logger = CSVLogger(os.path.join(RESULTS_DIR, 'training_results.csv'))

# Stop once validation loss stops improving and roll back to the best epoch's
# weights, so the evaluation cell scores the best model rather than whatever
# state the final epoch left behind.
early_stopping = EarlyStopping(monitor='val_loss',
                               patience=EARLY_STOP_PATIENCE,
                               restore_best_weights=True)

# Train the model
history = model.fit(train_generator,
                    epochs=MAX_EPOCHS,
                    validation_data=val_generator,
                    callbacks=[csv_logger, early_stopping])

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100

In [None]:
# Evaluate the trained model on the held-out test set.
# test_generator was created with shuffle=False, so the order of predictions
# matches test_generator.classes; reset() restarts it from the first batch.
test_generator.reset()
predictions = model.predict(test_generator)

# Softmax probabilities -> predicted class index per image.
predicted_labels = np.argmax(predictions, axis=1)
true_labels = test_generator.classes

# class_indices maps folder name -> label index; order names by their index so
# tick labels line up with matrix rows/columns.
class_labels = sorted(test_generator.class_indices,
                      key=test_generator.class_indices.get)
label_ids = np.arange(len(class_labels))

# labels= keeps the matrix n_classes x n_classes even when a class is never
# predicted (or missing from the test set), so it stays aligned with the ticks.
cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)

# Per-class precision / recall / F1 to accompany the matrix.
print(classification_report(true_labels, predicted_labels,
                            labels=label_ids,
                            target_names=class_labels,
                            zero_division=0))

# Plot the confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig.colorbar(image, ax=ax)
ax.set(title='Confusion Matrix',
       xticks=label_ids,
       yticks=label_ids,
       xlabel='Predicted Label',
       ylabel='True Label')
ax.set_xticklabels(class_labels, rotation=45)
ax.set_yticklabels(class_labels)

# Annotate each cell with its count; switch to white text on dark cells.
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

fig.tight_layout()
plt.show()