In [121]:
import medmnist
from medmnist import OrganAMNIST, INFO, Evaluator
import random
from timeit import default_timer as timer
import numpy as np
import scipy
import sklearn.neighbors
import sklearn.metrics
import skimage.transform
import skimage.util
import matplotlib.pyplot as plt
from mlxtend.plotting import plot_confusion_matrix
import torch
import torch.nn as nn
from torch import optim, inference_mode
from torch.utils.data import Subset, DataLoader
import torchvision
from torchvision import transforms
from torchmetrics import ConfusionMatrix
from tqdm.auto import tqdm


In [122]:
# download train, validation, and test datasets
train_data = OrganAMNIST(split="train", download=True, size=128)
val_data = OrganAMNIST(split="val", download=True, size=128)
test_data = OrganAMNIST(split="test", download=True, size = 128)
classes = {0: "Bladder", 1: "Femur-left", 2: "Femur-right", 3: "Heart", 4: "Kidney-left", 5: "Kidney-right", 6: "Liver", 7: "Lung-left", 8: "Lung-right", 9: "Pancreas", 10: "Spleen"}
non_lr_classes = {0, 3, 6, 9, 10}

In [None]:

# helper functions
def knn_feature_extraction(data, optimizeFlag=False):
    features = []
    for k in range(len(data.imgs)):
        img = data.imgs[k]
        label = data.labels[k][0]

        if optimizeFlag:
            # don't horizontally flip images that have a left/right component
            notLrFlag = label in non_lr_classes
            augmented_imgs = augment_image(img, notLrFlag)
        else:
            augmented_imgs = [img]

        feature_vectors = []
        for aug_img in augmented_imgs:
            # Image processing
            if optimizeFlag:
                aug_img = skimage.filters.sobel(aug_img)
                aug_img = skimage.morphology.dilation(aug_img, skimage.morphology.disk(3))

            vector = []
            vector.append(np.mean(aug_img))
            vector.append(np.std(aug_img))
            for i in [0, 64]:
                for j in [0, 64]:
                    quadrant = aug_img[i:i+64, j:j+64]
                    vector.append(np.mean(quadrant))
                    vector.append(np.std(quadrant))
            feature_vectors.append(vector)

        # Combine features from augmentations
        feature_avg = np.mean(feature_vectors, axis=0)
        features.append(feature_avg)

    return np.asarray(features)


def augment_image(img, notLrFlag):
    augmented = []

    # Original image
    augmented.append(img)

    # Horizontal flip
    if notLrFlag:
        augmented.append(np.fliplr(img))

    # Rotations
    # takes too long to run
    # for angle in [15, -15]:
    #     rotated = skimage.transform.rotate(img, angle, preserve_range=True)
    #     augmented.append(rotated.astype(img.dtype))

    return augmented


In [124]:

# perform knn algorithm
# extract features
train_feats = knn_feature_extraction(train_data)
val_feats = knn_feature_extraction(val_data)

# normalize features
mu = np.mean(train_feats, axis=0)
sig = np.std(train_feats, axis=0)
norm_train_feats = (train_feats - mu)/sig
norm_val_feats = (val_feats - mu)/sig

# compute dist matrix
D = scipy.spatial.distance_matrix(norm_train_feats, norm_val_feats, )
D.shape

# use sklearn to classify data
classifier = sklearn.neighbors.KNeighborsClassifier(n_neighbors=3, metric="euclidean")
classifier.fit(norm_train_feats, train_data.labels.flatten())
predictions = classifier.predict(norm_val_feats)

# report metrics
s = sklearn.metrics.classification_report(val_data.labels.flatten(), predictions)
print(s)


              precision    recall  f1-score   support

           0       0.89      0.90      0.90       321
           1       0.96      0.95      0.95       233
           2       0.95      0.95      0.95       225
           3       0.81      0.58      0.68       392
           4       0.49      0.49      0.49       568
           5       0.61      0.68      0.64       637
           6       0.61      0.73      0.66      1033
           7       0.98      0.93      0.95      1033
           8       0.99      0.87      0.93      1009
           9       0.40      0.31      0.35       529
          10       0.41      0.51      0.46       511

    accuracy                           0.72      6491
   macro avg       0.73      0.72      0.72      6491
weighted avg       0.73      0.72      0.72      6491



In [125]:

# perform knn algorithm
# extract features
train_feats = knn_feature_extraction(train_data, True)
val_feats = knn_feature_extraction(val_data, True)

# normalize features
mu = np.mean(train_feats, axis=0)
sig = np.std(train_feats, axis=0)
norm_train_feats = (train_feats - mu)/sig
norm_val_feats = (val_feats - mu)/sig

# compute dist matrix
D = scipy.spatial.distance_matrix(norm_train_feats, norm_val_feats, )
D.shape

# use sklearn to classify data
classifier = sklearn.neighbors.KNeighborsClassifier(n_neighbors=3, metric="euclidean")
classifier.fit(norm_train_feats, train_data.labels.flatten())
predictions = classifier.predict(norm_val_feats)

# report metrics
s = sklearn.metrics.classification_report(val_data.labels.flatten(), predictions)
print(s)


              precision    recall  f1-score   support

           0       0.47      0.67      0.55       321
           1       0.65      0.58      0.61       233
           2       0.78      0.57      0.66       225
           3       0.90      0.80      0.85       392
           4       0.63      0.61      0.62       568
           5       0.76      0.59      0.66       637
           6       0.88      0.97      0.92      1033
           7       0.93      0.97      0.95      1033
           8       0.98      0.87      0.92      1009
           9       0.56      0.46      0.50       529
          10       0.45      0.64      0.53       511

    accuracy                           0.77      6491
   macro avg       0.73      0.70      0.71      6491
weighted avg       0.78      0.77      0.77      6491

