In [25]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from imutils import paths
import numpy as np
import argparse
import imutils
import cv2
import os
from tqdm import tqdm
import random

In [19]:
def image_to_feature_vector(image, size=(150, 150)):
    return cv2.resize(image, size).flatten()

def extract_color_histogram(image, bins=(8, 8, 8)):
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    hist = cv2.calcHist([hsv], [0, 1, 2], None, bins, [0, 180, 0, 256, 0, 256])
    if imutils.is_cv2():
        hist = cv2.normalize(hist)
    else:
        cv2.normalize(hist, hist)
    return hist.flatten()

In [20]:
IMG_SIZE = 150
categories = ["NORMAL", "PNEUMONIA"]

def create_training_data(DATADIR="../chest_xray/train"):
    training_data = []
    for category in categories:

        path = os.path.join(DATADIR,category)
        class_num = categories.index(category)
        count = 0
        for img in tqdm(os.listdir(path)):
            count += 1
            if class_num == 1 and count == 1342:
                break
            try:
                image = cv2.imread(os.path.join(path, img))
                if class_num == 0:
                    label = [1, 0]
                else:
                    label = [0, 1]
                pixels = image_to_feature_vector(image)
                hist = extract_color_histogram(image)

                training_data.append([pixels, hist, label])
            except Exception as e:
                pass

    return training_data

training_data = create_training_data("../chest_xray/train")

100%|██████████| 1341/1341 [00:40<00:00, 33.29it/s]
 35%|███▍      | 1341/3875 [00:14<00:27, 91.87it/s] 


In [21]:
test_data = create_training_data("../chest_xray/test")

100%|██████████| 234/234 [00:05<00:00, 40.05it/s]
100%|██████████| 390/390 [00:03<00:00, 111.57it/s]


In [22]:
val_data = create_training_data("../chest_xray/val")

100%|██████████| 8/8 [00:00<00:00, 45.19it/s]
100%|██████████| 8/8 [00:00<00:00, 98.55it/s]


In [26]:
x_train, x_test, x_val = [], [], []
y_train, y_test, y_val = [], [], []

x_train_pixel, x_test_pixel, x_val_pixel = [], [], []


In [27]:
random.shuffle(training_data)
random.shuffle(test_data)
random.shuffle(val_data)

In [28]:
for pixel,features,label in training_data:
    x_train.append(features)
    y_train.append(label)
    x_train_pixel.append(pixel)
    
for pixel,features,label in test_data:
    x_test.append(features)
    y_test.append(label)
    x_test_pixel.append(pixel)

for pixel,features,label in val_data:
    x_val.append(features)
    y_val.append(label)
    x_val_pixel.append(pixel)

np.save("../saved_data/KNN_saved/x_train_pixel.npy", x_train_pixel)
np.save("../saved_data/KNN_saved/x_train.npy", x_train)
np.save("../saved_data/KNN_saved/y_train.npy", y_train)

np.save("../saved_data/KNN_saved/x_test_pixel.npy", x_test_pixel)
np.save("../saved_data/KNN_saved/x_test.npy", x_test)
np.save("../saved_data/KNN_saved/y_test.npy", y_test)

np.save("../saved_data/KNN_saved/x_val_pixel.npy", x_val_pixel)
np.save("../saved_data/KNN_saved/x_val.npy", x_val)
np.save("../saved_data/KNN_saved/y_val.npy", y_val)

In [34]:
x_train = np.load("../saved_data/KNN_saved/x_train.npy", allow_pickle=True)
x_test = np.load("../saved_data/KNN_saved/x_test.npy", allow_pickle=True)
x_train_pixel = np.load("../saved_data/KNN_saved/x_train_pixel.npy", allow_pickle=True)
x_test_pixel = np.load("../saved_data/KNN_saved/x_test_pixel.npy", allow_pickle=True)
y_train = np.load("../saved_data/KNN_saved/y_train.npy", allow_pickle=True)
y_test = np.load("../saved_data/KNN_saved/y_test.npy", allow_pickle=True)

print("evaluating histogram accuracy...")
model = KNeighborsClassifier(n_neighbors=3,n_jobs=1)
model.fit(x_train, y_train)
acc = model.score(x_test, y_test)
print("histogram accuracy: {:.2f}%".format(acc * 100))

print("evaluating raw pixel accuracy...")
model = KNeighborsClassifier(n_neighbors=1,n_jobs=1)
model.fit(x_train_pixel, y_train)
acc = model.score(x_test_pixel, y_test)
print("raw pixel accuracy: {:.2f}%".format(acc * 100))


evaluating histogram accuracy...
histogram accuracy: 73.72%
evaluating raw pixel accuracy...
raw pixel accuracy: 78.69%
