# Disease Classification using Medical MNIST

Based on the following references:

* https://cainvas.ai-tech.systems/use-cases/disease-classification-app/
* https://medium.com/ai-techsystems/disease-classification-using-medical-mnist-f468655c0de8

---

## Modified MedMNIST dataset

This dataset was developed in 2017 by Arturo Polanco Lozano. It is also known as the **MedNIST dataset for radiology and medical imaging**. 

To prepare this dataset, images were gathered from several sources, namely TCIA, the RSNA Bone Age Challenge, and the NIH Chest X-ray dataset.

This dataset contains 58954 medical images belonging to 6 categories — AbdomenCT (10000 images), HeadCT (10000 images), Hand (10000 images), CXR (10000 images), BreastMRI (8954 images), and ChestCT (10000 images).

In [None]:
from IPython.display import Image
# Show the dataset overview picture (must be the last expression to render).
Image(width=400, filename='/content/MedMNIST.png')

In [None]:
import os
import numpy as np
import pandas as pd
import random, datetime, os, shutil, math

In [None]:
import matplotlib.pyplot as plt
from matplotlib.image import imread
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.utils import plot_model

### Load the dataset

In [None]:
# Download the MedNIST archive (-N: skip the download if the local copy is not older),
# extract it quietly and overwrite existing files (-qo), then delete the zip.
# NOTE(review): because the zip is removed, -N never finds a local copy, so
# every run re-downloads. The paths in the next cell assume the archive
# extracts to /content/Medical — confirm.
!wget -N "https://cainvas-static.s3.amazonaws.com/media/user_data/cainvas-admin/MedNIST.zip"
!unzip -qo "MedNIST.zip"
!rm "MedNIST.zip"

In [None]:
# Root folders of the extracted dataset: one sub-directory per class in each.
train_dir = "/content/Medical/Medical_train"
test_dir = "/content/Medical/Medical_test"

In [None]:
def prep_test_data(med, train_dir, test_dir, n_samples=2000, move=True):
  """Hold out a random sample of one class's training images as test data.

  Picks ``n_samples`` files at random from ``train_dir/med`` and transfers
  them into ``test_dir/med``, creating that directory if it is missing.

  Args:
    med: Class sub-directory name (e.g. "Hand").
    train_dir: Root directory holding one sub-directory per class.
    test_dir: Root directory that receives the held-out images.
    n_samples: Number of images to hold out for this class.
    move: If True (default), move the files so held-out images are no
      longer part of the training set. If False, copy them. Copying leaves
      every test image in the training set too, which inflates validation
      metrics. Note that with move=True, calling this again holds out
      another ``n_samples`` images.

  Raises:
    ValueError: if the class has fewer than ``n_samples`` images
      (raised by ``random.sample``).
  """
  src_dir = os.path.join(train_dir, med)
  dst_dir = os.path.join(test_dir, med)
  os.makedirs(dst_dir, exist_ok=True)
  transfer = shutil.move if move else shutil.copy
  for fname in random.sample(os.listdir(src_dir), n_samples):
    # Pass the full destination path so an existing file of the same name
    # is replaced instead of making shutil.move raise.
    transfer(os.path.join(src_dir, fname), os.path.join(dst_dir, fname))

In [None]:
# Hold out test images for every class directory found in train_dir.
# Each call samples images from train_dir/<class> into test_dir/<class>;
# re-running this cell samples again.
for medi in os.listdir(train_dir):
  prep_test_data(medi, train_dir, test_dir)

In [None]:
# Class names are the sub-directory names of the training folder.
target_classes = os.listdir(train_dir)
num_classes = len(target_classes)
print('Number of target classes:', num_classes)
# flow_from_directory assigns label indices in alphanumeric order of the
# class sub-directories. os.listdir returns an arbitrary order, so print the
# index -> class mapping sorted to match the labels the model actually uses.
print(list(enumerate(sorted(target_classes))))

In [None]:
# Per-class image counts for the train and test folders. Both lists follow
# the order of target_classes so entry i of each refers to the same class.
# os.listdir(test_dir) may list the classes in a different order than
# os.listdir(train_dir).
training_set_distribution = [
    len(os.listdir(os.path.join(train_dir, class_name)))
    for class_name in target_classes
]
testing_set_distribution = [
    len(os.listdir(os.path.join(test_dir, class_name)))
    for class_name in target_classes
]

## Exploratory data analysis

In [None]:
def show_mri(med):
  """Display image files in a roughly square grid of subplots.

  Args:
    med: Sequence of image file paths, loaded with ``image.load_img``.

  Returns:
    None. Nothing is drawn when ``med`` is empty.
  """
  num = len(med)
  if num == 0:
    return None
  rows = int(math.sqrt(num))
  # Ceiling division guarantees rows * cols >= num. The former
  # (num + 1) // rows gave a 3x3 grid for 10 images and raised IndexError.
  cols = math.ceil(num / rows)
  # squeeze=False keeps axs 2-D even when rows == 1 (1-3 images), so the
  # axs[row, col] indexing below always works.
  _, axs = plt.subplots(rows, cols, squeeze=False)
  for idx, path in enumerate(med):
    row, col = divmod(idx, cols)
    axs[row, col].imshow(image.load_img(path))
  # Hide grid cells that received no image.
  for ax in axs.flat[num:]:
    ax.axis('off')
  plt.show()

In [None]:
# Preview up to nine images from the "Hand" class of the training set.
dir_name = os.path.join(train_dir, "Hand")
all_images = [os.path.join(dir_name, entry) for entry in os.listdir(dir_name)]
show_mri(all_images[0:9])

### Data preprocessing

In [None]:
# Model input shape; image_size[:2] is the target size used when loading.
image_size = (32, 32, 3)
# Training-time pipeline: scale pixels by 1/255 and apply random shear,
# zoom and horizontal flips.
datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

In [None]:
# Training batches must be shuffled. The directory iterator enumerates files
# class by class, so with shuffle=False consecutive batches are dominated by
# a single class, which hurts SGD training.
training_set = datagen.flow_from_directory(train_dir,
                                           target_size=image_size[:2],
                                           batch_size=64,
                                           class_mode='categorical',
                                           shuffle=True,
                                           )

In [None]:
# Validation images are only rescaled (same 1/255 factor as training). The
# random shear/zoom/flip used for training would make validation metrics
# noisy and non-reproducible. shuffle=False keeps predictions aligned with
# validation_set.filenames.
validation_set = ImageDataGenerator(rescale=1./255).flow_from_directory(
    test_dir,
    target_size=image_size[:2],
    batch_size=64,
    class_mode='categorical',
    shuffle=False,
)

### Callback functions 

These callbacks run during training to improve the model's learning: early stopping, checkpointing the best model, and reducing the learning rate when progress plateaus.

In [None]:
# Stop when validation accuracy has not improved for 7 epochs.
# NOTE(review): training below runs for only 5 epochs, so this never fires.
es = EarlyStopping(monitor='val_acc', mode='max', verbose=1, patience=7)
filepath = "modelMedicalMNIST.h5"
# Save the weights with the best *validation* accuracy. Tracking training
# accuracy ('acc') tends to keep the most over-fitted weights.
ckpt = ModelCheckpoint(filepath, monitor='val_acc', verbose=1,
                       save_best_only=True, mode='max')
# Reduce the learning rate when validation accuracy stalls for 3 epochs.
rlp = ReduceLROnPlateau(monitor='val_acc', mode='max', patience=3, verbose=1)

## Model definition

In [None]:
def cnn(image_size, num_classes):
    """Build and compile a small two-stage convolutional classifier.

    Layers: Conv2D(64, 5x5) -> MaxPool(2x2) -> Conv2D(128, 3x3) ->
    MaxPool(2x2) -> Flatten -> Dense(num_classes, softmax). The model is
    compiled with Adam, categorical cross-entropy and the 'acc' metric.

    Args:
        image_size: Input shape given to the first Conv2D layer.
        num_classes: Number of softmax output units.

    Returns:
        The compiled Sequential model.
    """
    stack = [
        Conv2D(64, (5, 5), input_shape=image_size, activation='relu', padding='same'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(num_classes, activation='softmax'),
    ]
    model = Sequential(stack)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
    return model

neuralnetwork_cnn = cnn(image_size=image_size, num_classes=num_classes)
neuralnetwork_cnn.summary()
# Optional architecture diagram:
# plot_model(neuralnetwork_cnn, show_shapes=True)

### Model training

In [None]:
# Model.fit accepts the directory iterators directly. fit_generator is
# deprecated in TF 2.x and no longer exists in Keras 3.
history = neuralnetwork_cnn.fit(
    training_set,
    validation_data=validation_set,
    callbacks=[es, ckpt, rlp],
    epochs=5,
)

In [None]:
fig, ax = plt.subplots(figsize=(20, 6))
# Plot the loss/accuracy curves. Drop the learning-rate column logged by
# ReduceLROnPlateau by name ('lr' in older Keras, 'learning_rate' in newer)
# rather than slicing off the last column. Slicing silently discards a real
# metric whenever that column is missing or not last.
pd.DataFrame(history.history).drop(columns=['lr', 'learning_rate'], errors='ignore').plot(ax=ax)

In [None]:
batch_size = 64  # same batch size as the data generators above
# Predict on the whole validation set. len(validation_set) is the number of
# batches (ceil(samples / batch_size)). The former steps=306/batch_size was a
# fractional magic number that covered only a handful of batches.
pred = neuralnetwork_cnn.predict(validation_set, steps=len(validation_set))
predicted_class_indices = np.argmax(pred, axis=1)

In [None]:
# Invert class_indices ({class_name: index}) so predicted indices can be
# turned back into class names.
labels = {idx: name for name, idx in validation_set.class_indices.items()}
predictions = [labels[idx] for idx in predicted_class_indices]

In [None]:
# One row per predicted image. validation_set.filenames[0] was a single
# string that pandas broadcast to every row. Use the file list instead,
# truncated to the number of predictions so the column lengths always agree.
# shuffle=False on the validation iterator keeps both in the same order.
filenames = validation_set.filenames[:len(predictions)]
results = pd.DataFrame({"Filename": filenames,
                        "Predictions": predictions})

In [None]:
# Show the last 100 prediction rows.
display(results.iloc[-100:])