In [1]:
# Reproducibility: seed NumPy's global random generator before anything else runs.
from numpy.random import seed
seed(72)
# Seed TensorFlow's global random generator with the same value.
# (The data loader and train_test_split calls further down pass their own
# explicit seeds, so they do not rely on these global seeds.)
import tensorflow as tf
tf.random.set_seed(72)

In [2]:
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.keras import datasets, layers, models, regularizers
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import seaborn as sns
from tensorflow.keras.callbacks import EarlyStopping

In [3]:
# Read every image under the data directory into a tf.data dataset.
# batch_size (10000) is larger than the number of files found (5856), so the
# dataset yields everything as one single batch; if the dataset ever grows
# past 10000 images, only the first batch would be used downstream.
# NOTE(review): class indices are inferred from the sub-directory names --
# check `data.class_names` to confirm which integer maps to which class.
data = tf.keras.preprocessing.image_dataset_from_directory(
    directory='D:/Flatiron/X-Ray_pneumonia__phase_4/data',
    seed=4356,
    batch_size=10000,
)

Found 5856 files belonging to 3 classes.


In [4]:
# Pull the first batch out of the dataset. Because the loader was created with
# batch_size=10000 and only 5856 files were found, this one batch contains
# every image together with its integer class label.
images, labels = next(iter(data))

In [5]:
# Convert the batch returned by the dataset into plain NumPy arrays before
# they are handed to train_test_split.
images = np.array(images)
labels = np.array(labels)

In [6]:
# Hold out 585 of the 5856 images (about 10%) as the final test set.
# random_state pins the split so it is the same on every run.
(train_images, test_images,
 train_labels, test_labels) = train_test_split(
    images, labels, test_size=585, random_state=42
)

In [7]:
# Carve a further 585 images out of the remaining training data to use as the
# validation set.
# NOTE: this overwrites train_images/train_labels, so re-running this cell
# without re-running the previous one shrinks the training set again.
(train_images, val_images,
 train_labels, val_labels) = train_test_split(
    train_images, train_labels, test_size=585, random_state=42
)

In [8]:
# Scale pixel values by 1/255 (presumably from 0-255 down to 0-1 -- confirm the
# loader's output range).
# NOTE: the arrays are overwritten, so running this cell twice divides twice.
train_images = train_images / 255
test_images = test_images / 255
val_images = val_images / 255

In [9]:
from sklearn.preprocessing import OneHotEncoder

# One-hot encode the integer class labels for use with categorical
# cross-entropy.
#
# The encoder is fitted ONCE, on the training labels only, and then reused to
# transform the test and validation labels. Re-fitting on each split (as was
# done before) lets every split define its own category set: if a split were
# missing a class, its encoded matrix would have fewer columns and/or a
# different column order than the training targets, silently mis-aligning the
# labels with the model's softmax outputs.
ohe = OneHotEncoder()

train_labels_encoded = ohe.fit_transform(train_labels.reshape(-1, 1)).toarray()

# transform() raises if it meets a label that was not present in the training
# data, which is the desired loud failure here.
test_labels_encoded = ohe.transform(test_labels.reshape(-1, 1)).toarray()

val_labels_encoded = ohe.transform(val_labels.reshape(-1, 1)).toarray()

In [10]:
def _per_class_accuracy(cm):
    """Return, for each true class (matrix row), the % of its samples predicted correctly."""
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1)
    correct = np.diag(cm)
    # A class with no true samples gets 0.0 rather than a divide-by-zero NaN.
    return [float(100.0 * hit / total) if total else 0.0
            for hit, total in zip(correct, row_totals)]


def _plot_class_accuracy(ax, names, heights):
    """Draw one bar per class on `ax` and print the percentage on each bar."""
    bars = ax.bar(x=names, height=heights)
    ax.set_ylim(0, 100)
    ax.set_ylabel('Accuracy Percentage')
    ax.set_title('Model Accuracy')
    for bar, pct in zip(bars, heights):
        ax.text(bar.get_x() + bar.get_width() / 2,
                # Keep the annotation inside the axes even for very short bars.
                max(pct - 30, 5),
                str(round(pct, 1)) + '%',
                ha='center',
                bbox=dict(facecolor='white', alpha=.5))


def evaluate(model, results, final=False, eval_images=None, eval_labels=None):
    """Plot training curves and confusion-matrix diagnostics for a fitted model.

    Draws a 3x2 grid of plots:
      * row 1: loss and accuracy per epoch (training vs. validation history),
      * row 2: the 3-class confusion matrix and the per-class accuracy,
      * row 3: a binary (class 0 vs. classes 1+2) confusion matrix and its
        per-class accuracy.

    "Per-class accuracy" is computed row-wise: of all samples whose TRUE class
    is X, the percentage predicted as X (i.e. recall). Dividing by column
    totals instead would give precision, which is unaffected by false
    negatives -- the error this notebook cares most about.

    Parameters
    ----------
    model : object with a ``predict`` method returning one score per class.
    results : the history object returned by ``model.fit``; its ``history``
        dict must hold 'loss', 'val_loss', 'accuracy' and 'val_accuracy', and
        its ``epoch`` attribute supplies the x axis.
    final : bool, default False
        Only changes the legend label of the validation curves
        ("test" instead of "validation").
    eval_images, eval_labels : optional
        Images and integer labels used for the confusion matrices. When both
        are omitted the notebook-level ``test_images`` / ``test_labels`` are
        used, as before.
        NOTE(review): evaluating on the test set while still iterating on the
        model leaks test information into model selection; consider passing
        ``val_images`` / ``val_labels`` until the final run.

    Raises
    ------
    ValueError
        If only one of ``eval_images`` / ``eval_labels`` is supplied.
    """
    val_label = "test" if final else "validation"

    if (eval_images is None) != (eval_labels is None):
        raise ValueError("eval_images and eval_labels must be given together")
    if eval_images is None:
        # Backward-compatible default: the notebook's global test split.
        eval_images, eval_labels = test_images, test_labels

    history = results.history

    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(20, 10))

    # Row 1: training curves for loss and accuracy.
    for ax, title, metric in ((ax1, "Loss", 'loss'), (ax2, "Accuracy", 'accuracy')):
        ax.set_title(title)
        sns.lineplot(x=results.epoch, y=history[metric], ax=ax, label="train")
        sns.lineplot(x=results.epoch, y=history['val_' + metric], ax=ax, label=val_label)
        ax.legend()

    # Row 2: multiclass confusion matrix (rows = true class, columns = predicted).
    # labels= pins the matrix to 3x3 even if some class never occurs in either
    # the true or the predicted labels.
    y_pred = model.predict(eval_images)
    cm = confusion_matrix(eval_labels, np.argmax(y_pred, axis=1), labels=[0, 1, 2])

    sns.heatmap(cm, ax=ax3, annot=True, cmap='Blues', fmt='0.5g')
    ax3.set_title('Confusion Matrix (3 classes)')
    ax3.set_xlabel('Predicted class')
    ax3.set_ylabel('True class')

    # NOTE(review): these names assume class index 0 = healthy, 1 = bacterial,
    # 2 = viral -- confirm against the loader's `class_names`.
    label = ['Healthy Accuracy',
             'Bacterial Accuracy',
             'Viral Accuracy']
    _plot_class_accuracy(ax4, label, _per_class_accuracy(cm))

    # Row 3: collapse classes 1 and 2 into a single "pneumonia" class.
    # Layout matches `cm`: [true][predicted].
    cm_simple = np.array([[cm[0, 0], cm[0, 1:].sum()],
                          [cm[1:, 0].sum(), cm[1:, 1:].sum()]])

    sns.heatmap(cm_simple, ax=ax5, annot=True, cmap='Blues', fmt='0.5g')
    ax5.set_title('Confusion Matrix (healthy vs. pneumonia)')
    ax5.set_xlabel('Predicted class')
    ax5.set_ylabel('True class')

    simple_label = ['Healthy\n Accuracy',
                    'Pneumonia\n Accuracy']
    _plot_class_accuracy(ax6, simple_label, _per_class_accuracy(cm_simple))

    # Axis labels were added to the heatmaps; avoid overlap between rows.
    fig.tight_layout()

In [24]:
# Small convolutional network: two conv layers with one max-pool in between,
# then two dense layers and a 3-way softmax (one output per class found by
# the loader). The input shape is 256x256 with 3 channels.
model = models.Sequential([
    layers.Conv2D(256, 3, activation="relu", input_shape=(256, 256, 3)),
    layers.MaxPooling2D(2),
    layers.Conv2D(64, 3, padding="same", activation="relu"),
    layers.Flatten(),
    layers.Dense(32, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(3, activation="softmax"),
])

In [None]:
# The targets are one-hot encoded, hence categorical cross-entropy.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Train for 10 epochs in mini-batches of 128, tracking loss/accuracy on the
# validation split after every epoch. The returned history object is what
# evaluate() plots.
results = model.fit(
    x=train_images,
    y=train_labels_encoded,
    batch_size=128,
    epochs=10,
    validation_data=(val_images, val_labels_encoded),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10

In [None]:
# Evaluate the network trained above. `results` is the history returned by
# model.fit, so it must be paired with `model`; the previously used name
# `baseline` is not defined anywhere in this notebook and raised a NameError
# on a fresh kernel.
evaluate(model, results)