In [1]:
# This script trains and runs the model along with computing the evalutation criteria. 
# This script can take several hours to run if you do not have access to strong computing power.

In [10]:
# imports 
import tensorflow as tf
import keras 
import numpy as np
from tensorflow.keras import optimizers
from tensorflow.keras.applications import EfficientNetB7
from tensorflow.keras import models, layers
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import os

In [3]:
# set up a base model
base_model = EfficientNetB7(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# freeze the base model layer 
base_model.trainable = False

In [11]:
# splitting the data into train, validation, and test
img_size = (224, 224)
shuffle_value = True
batch_size = 32
seed = 123
validation_split = 0.3

# creating a directory of the data 
data_dir = "Data"
try:
    os.makedirs(data_dir)
    print(f"Directory '{data_dir}' created successfully.")
except FileExistsError:
    print(f"Directory '{data_dir}' already exists.")
except Exception as e:
    print(f"An error occurred: {e}")


Directory 'Data' already exists.


In [5]:
# loading all of the data
full_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    image_size=img_size,
    batch_size=batch_size,
    shuffle=True
)

# define class names/number of classes for model later
class_names = full_ds.class_names
num_classes = len(class_names)


# get test/validation/train sizes
total_size = tf.data.experimental.cardinality(full_ds).numpy()
test_size = int(0.15 * total_size)
val_size = int((15/85) * (total_size - test_size))


Found 9952 files belonging to 2 classes.


In [6]:
# splitting the data

test_ds = full_ds.take(test_size)
train_val_ds = full_ds.skip(test_size)

val_ds = train_val_ds.take(val_size)
train_ds = train_val_ds.skip(val_size)


In [7]:
# building the model 

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.5),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(num_classes, activation='softmax')
])


In [8]:
# compiling model 
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)


In [9]:
# model fitting 
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10
)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [12]:
# unfreezing the base model layer for fine tuning

base_model.trainable = True
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # Lower LR for fine-tuning
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)



In [13]:
# evaluating accuracy  on the test set 
test_loss, test_acc = model.evaluate(test_ds)
print(f"Test Accuracy: {test_acc:.4f}")



Test Accuracy: 0.7948


In [14]:
# evaluating precision on the test set 

# define true and predicted y values 
y_true = []
y_pred = []

# run the model on the test images 
for images, labels in test_ds:
    preds = model.predict(images)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(preds, axis=1))

# get the classification report 
print(classification_report(y_true, y_pred, target_names=class_names))

              precision    recall  f1-score   support

  Non-Toxic1       0.85      0.69      0.76       712
      Toxic        0.75      0.89      0.82       760

    accuracy                           0.79      1472
   macro avg       0.80      0.79      0.79      1472
weighted avg       0.80      0.79      0.79      1472



In [15]:
# making a confusion matrix 
cm = confusion_matrix(y_true, y_pred)

# ploting confusion matrix
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')

# save confusion matrix as a png
plt.savefig('confusion_matrix.png')
plt.close()
