# <font color=#264653><u>Ocular Disease Recognition</u></font>

## <font color=#264653><u>Introduction</u></font>
The eye's fundus is the eye's interior surface opposite the lens, and includes the retina, optic disc, macula, fovea, and posterior pole<sup>[[1]](https://en.wikipedia.org/wiki/Fundus_(eye))</sup>. 

<figure>
<img src="https://cdn.3d4medical.com/media/blog/funduscopy/retina-correlation.jpg" alt="fundus anatomy" width="500" height="500">

<figcaption><i>Image from [3d4medical](https://3d4medical.com/blog/funduscopy)</i></figcaption>
</figure>

Ophthalmologists use fundus photography to detect fundus diseases such as diabetic retinopathy, glaucoma, age-related macular degeneration, cataracts, hypertension, and myopia. Computer-aided diagnosis has been progressively adopted by ophthalmologists, as its accuracy has increased in recent years<sup>[[2]](https://pesquisa.bvsalud.org/portal/resource/pt/wpr-822979)</sup>.

## <font color=#264653><u>Data</u></font>
The data in this notebook is the ODIR binocular fundus image dataset [[3]](https://pesquisa.bvsalud.org/portal/resource/pt/wpr-822979), containing color fundus photographs (CFP), available as multi-class multi-label instances. As we will further discuss while exploring the dataset, there are several different ML problems which can be derived from this dataset - in this notebook, I chose to tackle a binary classification problem, only focusing on a single ocular disease.

## <font color=#264653><u>Imports</u></font>

In [None]:
import os
os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"

import math
import random
import re
from collections import Counter

import tensorflow as tf
import tensorflow_addons as tfa
import pandas as pd
pd.options.mode.chained_assignment = None;
import numpy as np
import scipy.stats as ss

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.manifold import TSNE

from tensorflow.keras.applications import MobileNetV3Small

from kaggle_datasets import KaggleDatasets

import matplotlib as mpl
from matplotlib.colors import ListedColormap
from matplotlib.patches import Rectangle
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

def seed_it_all(seed=42):
    """Seed every random number generator this notebook relies on.

    Sets the ``PYTHONHASHSEED`` environment variable and seeds the ``random``
    module, NumPy and TensorFlow with the same value.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (random.seed, np.random.seed, tf.random.set_seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for apply_seed in seeders:
        apply_seed(seed)

# Seed once, right after the imports, with the default value.
seed_it_all()

print(f"\n...COMPLETED IMPORT...")

## <font color=#264653><u>Accelerator Setup</u></font>

In [None]:
# Detect hardware, return appropriate distribution strategy
try:
    # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()  
except ValueError:
    # No TPU could be resolved: fall back to the default CPU/GPU strategy below.
    TPU = None

if TPU:
    print(f"\n... RUNNING ON TPU - {TPU.master()} ...")
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.experimental.TPUStrategy(TPU)
else:
    print(f"\n... RUNNING ON CPU/GPU ...")
    physical_devices = tf.config.list_physical_devices('GPU')
    # Only touch the first GPU when one is actually present (CPU-only runs
    # simply skip this step instead of relying on a swallowed IndexError).
    if physical_devices:
        try:
            tf.config.experimental.set_memory_growth(physical_devices[0], True)
        except (ValueError, RuntimeError):
            # Invalid device or cannot modify virtual devices once initialized.
            pass
    strategy = tf.distribute.get_strategy()

# Number of replicas the chosen strategy keeps in sync.
N_REPLICAS = strategy.num_replicas_in_sync
    
print(f"... # OF REPLICAS: {N_REPLICAS} ...\n")

## <font color=#264653><u>Data Loading</u></font>

In [None]:
# Name of the Kaggle dataset that holds the images and the metadata table.
DATA_PATH_INIT = "ocular-disease-recognition-odir5k"

if not TPU:
    # Local path to training and validation images
    DATA_DIR = f"/kaggle/input/{DATA_PATH_INIT}"
    save_locally = load_locally = None
else:
    # Google Cloud Dataset path to training and validation images
    DATA_DIR = KaggleDatasets().get_gcs_path(DATA_PATH_INIT)
    # Saved-model I/O is pinned to the local host when running on TPU.
    save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
    load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')

print(f"\n... DATA DIRECTORY PATH IS:\n\t--> {DATA_DIR}")

print(f"\n... IMMEDIATE CONTENTS OF DATA DIRECTORY IS:")
# List the entries directly under the data directory, one per line.
for file in tf.io.gfile.glob(os.path.join(DATA_DIR, "*")):
    print(f"\t--> {file}")

print("\n\n... DATA ACCESS SETUP COMPLETED ...\n")

In [None]:
# Load the metadata table (full_df.csv) from the local Kaggle input directory.
df = pd.read_csv(os.path.join("/kaggle/input", DATA_PATH_INIT, "full_df.csv"))

In [None]:
# Preview the first rows of the metadata table.
df.head()

## <font color=#264653><u>Constants and Helper Functions</u></font>

In [None]:
# Full class names; their position defines the integer class index.
labels_long = ["Normal", "Diabetes", "Glaucoma", "Cataract", "AMD", "Hypertension", "Myopia", "Other"]
# One-letter class codes: the first character of each full name.
labels_short = [name[:1] for name in labels_long]

# One-letter code -> full class name.
class_short2full = dict(zip(labels_short, labels_long))

# One-letter code -> integer index, and the inverse lookup.
class_dict = {code: index for index, code in enumerate(class_short2full)}
class_dict_rev = dict(enumerate(class_dict))

NUM_CLASSES = len(class_dict)

In [None]:
# Seed reused by the samplers, splitters and embeddings further down.
SEED = 42

# Central colour palette for every figure in the notebook.
COLORS = dict(
    fig_bg="#f6f5f5",
    plot_neut="#ddbea9",
    plot_text="#343a40",

    cmap_color_list=["#001219", "#005F73", "#0A9396", "#94D2BD", "#E9D8A6",
                     "#EE9B00", "#CA6702", "#BB3E03", "#AE2012", "#9B2226"],

    split=dict(
        train="#264653",
        val="#2a9d8f",
        test="#e9c46a",
    ),
)

# One colour per class code: classes are paired with the leading palette entries.
COLORS["class"] = dict(zip(class_short2full, COLORS["cmap_color_list"]))
# Continuous colormaps interpolated from fixed colour lists.
COLORS["cmap"] = mpl.colors.LinearSegmentedColormap.from_list("", COLORS["cmap_color_list"])
COLORS["cmap_pos"] = mpl.colors.LinearSegmentedColormap.from_list("", ["#F0F3F8", "#D1DBE9", "#A2B7D2", "#7493BC", "#6487B4", "#3D5A80"])

colors_class_list = [*COLORS["class"].values()]

# Font keyword arguments for the text roles used in the figures.
# Every role shares the serif family and normal style; only weight and size vary.
FONT_KW = {
    role: {
        "fontname": "serif",
        "weight": weight,
        "size": size,
        "style": "normal",
    }
    for role, weight, size in (
        ("plot_title", "bold", "25"),
        ("plot_title_small", "bold", "16"),
        ("plot_subtitle", "bold", "12"),
        ("subplot_title", "bold", "18"),
        ("subplot_title_small", "bold", "12"),
        ("plot_label", "bold", "16"),
        ("plot_label_small", "bold", "12"),
        ("plot_text", "normal", "12"),
        ("plot_text_small", "normal", "8"),
    )
}

In [None]:
def count_values_relative(y):
    """Return the distinct values of ``y`` and the share of each one in percent.

    Returns:
        A pair ``(values, percentages)`` as produced by ``np.unique`` with
        counts, where the counts are rescaled so that they sum to 100.
    """
    distinct, counts = np.unique(y, return_counts=True)
    percentages = 100 * counts / np.sum(counts)
    return distinct, percentages

def ceil_d(n, d=1000):
    """Round ``n`` up to the nearest multiple of ``d`` and return it as an int."""
    multiples_needed = np.ceil(n / d)
    return int(multiples_needed * d)

def get_subplot_dims(N):
    """Return ``(rows, cols)`` of a near-square grid with at least ``N`` cells.

    Rows start at ceil(sqrt(N)) and columns at floor(sqrt(N)); one extra row
    is added when that grid is still too small.
    """
    root = np.sqrt(N)
    rows, cols = np.ceil(root), np.floor(root)
    if rows * cols < N:
        rows += 1
    return int(rows), int(cols)


# human sorting based on https://stackoverflow.com/questions/5967500/how-to-correctly-sort-a-string-with-a-number-inside
def atoi(text):
    """Return ``int(text)`` when ``text`` is all digits, otherwise ``text`` itself."""
    if text.isdigit():
        return int(text)
    return text

def natural_keys(text):
    '''
    Sort key for human ordering: alist.sort(key=natural_keys)
    http://nedbatchelder.com/blog/200712/human_sorting.html

    Splits ``text`` into alternating non-digit / digit chunks and converts
    the digit chunks to integers.
    '''
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

def natural_sort_col_unique(df, colname, missing="NA"):
    """Return the unique values of ``df[colname]`` in human (natural) order.

    Args:
        df: DataFrame holding the column.
        colname: Name of the column whose unique values are wanted.
        missing: Placeholder substituted for missing values before sorting.

    Returns:
        A list of the unique values, sorted with ``natural_keys``.
    """
    # ``np.nan in arr`` only matches by object identity (NaN != NaN), so a
    # missing value could slip through and break the string-based sort key.
    # ``pd.isna`` detects missing entries regardless of the NaN object.
    arr = [
        missing if pd.isna(value) else value
        for value in df[colname].unique().tolist()
    ]
    arr.sort(key=natural_keys)
    return arr

    
def conditional_entropy(x,y):
    """Return the conditional entropy of ``x`` given ``y`` (natural log).

    Args:
        x: Sequence of observed values of the first variable.
        y: Sequence of observed values of the second variable, aligned with ``x``.

    Returns:
        The sum over observed pairs of p(x, y) * log(p(y) / p(x, y)); zero
        when no pairs are observed.
    """
    # Relies on the module-level ``import math`` (previously missing, which
    # made every call fail with NameError).
    y_counter = Counter(y)
    xy_counter = Counter(zip(x, y))
    total_occurrences = sum(y_counter.values())
    entropy = 0.0
    for (_, y_value), joint_count in xy_counter.items():
        p_xy = joint_count / total_occurrences
        p_y = y_counter[y_value] / total_occurrences
        entropy += p_xy * math.log(p_y / p_xy)
    return entropy

def theil_u(x,y):
    """Return Theil's U (uncertainty coefficient) of ``x`` given ``y``.

    Computed as ``(H(x) - H(x | y)) / H(x)``, where ``H(x | y)`` comes from
    ``conditional_entropy`` and ``H(x)`` from ``scipy.stats.entropy``.

    Returns:
        ``1`` when ``x`` has zero entropy, otherwise the ratio above.
    """
    # Relies on the module-level ``import scipy.stats as ss`` (previously
    # missing, which made every call fail with NameError).
    s_xy = conditional_entropy(x,y)
    x_counter = Counter(x)
    total_occurrences = sum(x_counter.values())
    p_x = [count / total_occurrences for count in x_counter.values()]
    s_x = ss.entropy(p_x)
    if s_x == 0:
        # A constant x carries no uncertainty to explain.
        return 1
    return (s_x - s_xy) / s_x
    
def correlation_ratio(categories, measurements):
    """Return the correlation ratio (eta) between a categorical and a numeric variable.

    Args:
        categories: Category label per observation (array or ``pd.Series``).
        measurements: Numeric value per observation (array or ``pd.Series``).

    Returns:
        ``0.0`` when the between-group sum of squares is zero, otherwise the
        square root of between-group over total sum of squares.
    """
    if isinstance(categories, pd.Series):
        categories = categories.values
    if isinstance(measurements, pd.Series):
        measurements = measurements.values

    # Integer code per observation, one code per distinct category.
    codes, _ = pd.factorize(categories)
    n_groups = np.max(codes) + 1
    group_means = np.zeros(n_groups)
    group_sizes = np.zeros(n_groups)
    for code in range(n_groups):
        members = measurements[np.flatnonzero(codes == code)]
        group_sizes[code] = len(members)
        group_means[code] = np.average(members)

    # Size-weighted mean of the group means.
    grand_mean = np.sum(group_means * group_sizes) / np.sum(group_sizes)
    between_ss = np.sum(group_sizes * np.power(group_means - grand_mean, 2))
    total_ss = np.sum(np.power(measurements - grand_mean, 2))
    if between_ss == 0:
        return 0.0
    return np.sqrt(between_ss / total_ss)

In [None]:
# Directory holding the preprocessed fundus images referenced by filename.
DATA_PATH = "/kaggle/input/ocular-disease-recognition-odir5k/preprocessed_images"
# Side length (pixels) every image is resized to, and the matching [H, W] pair.
IMG_SIZE = 224
IMAGE_SIZE = [IMG_SIZE, IMG_SIZE]

def label_image(c):
    """Return a one-hot integer vector of length ``NUM_CLASSES`` with index ``c`` set."""
    one_hot = np.zeros(NUM_CLASSES, dtype=int)
    one_hot[c] = 1
    return one_hot

def get_gaussian_filter_shape(IMG_SIZE):
    """Return the Gaussian kernel side length: a quarter of ``IMG_SIZE``, minus one."""
    quarter = IMG_SIZE // 4
    return quarter - 1

def blur_image(image, sigma=10):
    """Blur ``image`` with a Gaussian filter sized from the global ``IMG_SIZE``.

    Args:
        image: Image passed straight to ``tfa.image.gaussian_filter2d``.
        sigma: Standard deviation of the Gaussian kernel.
    """
    kernel_side = get_gaussian_filter_shape(IMG_SIZE)
    blurred = tfa.image.gaussian_filter2d(image, filter_shape=kernel_side, sigma=sigma)
    return blurred

def weighted_image(image, alpha=4, beta=-4, gamma=128):
    """Return the image-weighted enhancement ``image*alpha + blur(image)*beta + gamma``.

    Args:
        image: Input image.
        alpha: Weight of the original image.
        beta: Weight of the Gaussian-blurred image.
        gamma: Constant offset added to the weighted sum.
    """
    # NOTE(review): no dtype conversion happens here; if ``image`` arrives as
    # an 8-bit array the scaled terms may not fit its range - confirm.
    original_term = image*alpha
    blurred_term = blur_image(image)*beta
    return original_term + blurred_term + gamma

def create_dataset(dataset, img_list, class_label, augment=None):
    """Load, enhance and label images, appending them to ``dataset``.

    Each image is read from ``DATA_PATH``, converted from BGR to RGB, resized
    to ``IMG_SIZE`` x ``IMG_SIZE`` and passed through ``weighted_image``.
    Images that fail any of these steps are skipped.

    Args:
        dataset: List that receives ``[image, one_hot_label]`` pairs; it is
            extended and shuffled in place.
        img_list: Iterable of image filenames relative to ``DATA_PATH``.
        class_label: Integer class index of every image in ``img_list``.
        augment: Optional mapping of class index -> bool. When the entry for
            ``class_label`` is truthy, five extra variants (left-right flip,
            up-down flip, rotations by 90, 180 and -90 degrees) are appended
            for each image, in addition to the image itself.

    Returns:
        The same ``dataset`` list, shuffled.
    """
    # Resolved once: the decision does not depend on the individual image.
    augment_this_class = bool(augment) and bool(augment.get(class_label))

    for img in img_list:
        image_path = os.path.join(DATA_PATH, img)
        image_label = label_image(class_label)
        try:
            image = cv2.imread(image_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image,(IMG_SIZE, IMG_SIZE))
            image = weighted_image(image)
        except Exception:
            # Best effort: skip images that cannot be read or processed, but
            # let KeyboardInterrupt / SystemExit propagate.
            continue

        dataset.append([np.array(image), image_label])

        if augment_this_class:
            variants = (
                tf.image.flip_left_right(image),
                tf.image.flip_up_down(image),
                tf.image.rot90(image, k=1),
                tf.image.rot90(image, k=2),
                tf.image.rot90(image, k=-1),
            )
            for variant in variants:
                dataset.append([np.array(variant), image_label])

    random.shuffle(dataset)
    return dataset

## <font color=#264653><u>Exploratory Data Analysis</u></font>

In [None]:
# Work on a copy so the exploratory columns do not alter the original table.
df_eda = df.copy()

In [None]:
# Derive the class code by keeping only the alphabetic runs of the "labels" string.
df_eda["class"] = df_eda["labels"].apply(lambda x: " ".join(re.findall("[a-zA-Z]+", x)))

In [None]:
# Two side-by-side horizontal bar charts of the class distribution:
# left = one final label per image, right = per-diagnosis indicator columns.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 6), dpi=70, gridspec_kw={"wspace": 0.5})

fig.patch.set_facecolor(COLORS["fig_bg"])

# Count of images per final label, plus each label's share of the total.
value_counts = df_eda["class"].value_counts().rename("num").to_frame()
value_counts["percent"] = value_counts / value_counts.sum()
# NOTE(review): reindex returns a new frame and the result is not assigned,
# so this statement leaves value_counts (and the bar order) unchanged.
value_counts.reindex(index=COLORS["class"].keys())

b1 = ax1.barh(value_counts.index, value_counts["percent"])

# Replace the one-letter tick labels with the full class names.
ax1.set_yticks(
    value_counts.index,
    [class_short2full[i] for i in value_counts.index],
    **FONT_KW["plot_label"], color=COLORS["plot_text"]
)
ax1.tick_params(axis="y", length=0)
ax1.set_title("Label-Based", loc="left", **FONT_KW["subplot_title"], color=COLORS["plot_text"], pad=30)
# Sub-heading positioned at a hard-coded y of 8.2 (presumably just above the top bar).
ax1.text(0, 8.2, "(Multi-class)", **FONT_KW["subplot_title_small"], color=COLORS["plot_text"])

# Annotate every bar with its absolute count and its percentage.
ax1.bar_label(
    b1,
    labels=[str(val) + f"\n({str(np.round(100*pcnt,1))}%)" for val, pcnt in zip(value_counts["num"], value_counts["percent"])],
    color=COLORS["plot_text"],
    **FONT_KW["plot_text"]
)

ax1.set_facecolor(COLORS["fig_bg"])
# Colour each bar and its tick label with the colour assigned to that class.
for i in range(NUM_CLASSES):
    c = COLORS["class"][value_counts.index[i]]
    ax1.get_yticklabels()[i].set_color(c)
    b1[i].set_color(c)

ax1.axes.get_xaxis().set_visible(False)

for spine in ["bottom", "right", "top"]:
    ax1.spines[spine].set_visible(False)


# Column sums of the per-class indicator columns, as a share of all rows;
# rows are aligned to the order used in the left-hand chart.
value_count_diag = df_eda[labels_short].sum().rename("num").to_frame()
value_count_diag["percent"] = value_count_diag / df_eda.shape[0]
value_count_diag = value_count_diag.reindex(index=value_counts.index)

b2 = ax2.barh(value_count_diag.index, value_count_diag["percent"])

ax2.set_yticks(
    value_count_diag.index,
    [class_short2full[i] for i in value_count_diag.index],
    **FONT_KW["plot_label"], color=COLORS["plot_text"]
)
ax2.tick_params(axis="y", length=0)
ax2.set_title("Diagnosis-Based", loc="left", **FONT_KW["subplot_title"], color=COLORS["plot_text"], pad=30)
ax2.text(0, 8.2, "(Multi-class Multi-label)", **FONT_KW["subplot_title_small"], color=COLORS["plot_text"])

ax2.bar_label(
    b2,
    labels=[str(val) + f"\n({str(np.round(100*pcnt,1))}%)" for val, pcnt in zip(value_count_diag["num"], value_count_diag["percent"])],
    color=COLORS["plot_text"],
    **FONT_KW["plot_text"]
)

ax2.set_facecolor(COLORS["fig_bg"])
for i in range(NUM_CLASSES):
    c = COLORS["class"][value_count_diag.index[i]]
    ax2.get_yticklabels()[i].set_color(c)
    b2[i].set_color(c)

ax2.axes.get_xaxis().set_visible(False)

for spine in ["bottom", "right", "top"]:
    ax2.spines[spine].set_visible(False)
    
# Figure-level heading placed above both subplots.
plt.figtext(0, 1.05, "Class Distribution", **FONT_KW["plot_title"], color=COLORS["plot_text"])

plt.show()

We can frame the problem, as presented in the available dataset, in several different ways, e.g.:
- A <b><u>multi-class multi-label classification</u></b> problem, where each instance (image) can contain several classes (e.g., an image belonging to both <i>Diabetes</i> and <i>Other</i> classes).
- A <b><u>multi-class classification</u></b> problem, where each instance can belong to a single class (by using the final labels as classes, e.g.).
- A <b><u>binary classification</u></b> problem, where only instances belonging to 2 classes are samples from the dataset (e.g., only <i>Cataract</i> and <i>Normal</i> instances).

Note that in the above dataframe, each instance contains some information regarding the specific image (e.g., <i>filename</i> and <i>labels</i>) and some information regarding the patient, i.e., each row in the dataframe may contain diagnostic information regarding the right and left eyes, but the final <i>label</i> only codes the class of the image depicted in the <i>filename</i>.

In [None]:
# Derive the class code by keeping only the alphabetic runs of the "labels" string.
df["class"] = df["labels"].apply(lambda x: " ".join(re.findall("[a-zA-Z]+", x)))

In [None]:
# Class code -> array of image filenames whose final label is that class.
dict_img_list = {}
for class_ in class_short2full:
    in_class = df["class"] == class_
    dict_img_list[class_] = df.loc[in_class, "filename"].values

We will generate a small dataset with a few samples from each class.
Then, we will use this smaller dataset to visualize the 2D embeddings calculated by t-SNE, to get a sense of which classes are more likely to be similar and which are more easily differentiable:

In [None]:
# based on https://github.com/ageron/handson-ml2/blob/master/08_dimensionality_reduction.ipynb
from sklearn.preprocessing import MinMaxScaler
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

def plot_classes(X, y, min_distance=0.05, images=None, figsize=(13, 10), cmap=COLORS["cmap"], annot=False):
    """Scatter-plot 2D embeddings coloured by class, with optional annotations.

    Args:
        X: Array of 2D coordinates, one row per sample.
        y: Integer class index per sample (keys of ``class_dict_rev``).
        min_distance: Minimum distance, in the 0-1 normalized space, between
            two annotated points; closer points are not annotated.
        images: Optional per-sample images drawn as annotations. RGB/RGBA
            arrays (3 or 4 channels last) are drawn as given; anything else is
            reshaped to a square grayscale image (e.g. 784 values -> 28x28).
            When ``None``, the class code is written as text instead.
        figsize: Figure size passed to ``plt.subplots``.
        cmap: Kept for interface compatibility; not used by the function.
        annot: Whether to annotate the points at all.
    """
    # Let's scale the input features so that they range from 0 to 1
    X_normalized = MinMaxScaler().fit_transform(X)
    # Now we create the list of coordinates of the annotations plotted so far.
    # We pretend that one is already plotted far away at the start, to
    # avoid `if` statements in the loop below
    neighbors = np.array([[10., 10.]])
    fig, ax = plt.subplots(figsize=figsize)
    classes = np.unique(y)
    for class_ in classes:
        in_class = y == class_
        ax.scatter(
            X_normalized[in_class, 0],
            X_normalized[in_class, 1],
            c=COLORS["class"][class_dict_rev[class_]],
            alpha=0.7,
        )

    if annot:
        for index, image_coord in enumerate(X_normalized):
            closest_distance = np.linalg.norm(neighbors - image_coord, axis=1).min()
            if closest_distance > min_distance:
                neighbors = np.r_[neighbors, [image_coord]]
                if images is None:
                    ax.text(
                        image_coord[0],
                        image_coord[1],
                        class_dict_rev[y[index]],
                        color=COLORS["class"][class_dict_rev[y[index]]],
                        alpha=0.7,
                        **FONT_KW["plot_text_small"]
                    )
                else:
                    image = np.asarray(images[index])
                    is_color = image.ndim == 3 and image.shape[-1] in (3, 4)
                    if not is_color:
                        # Previously hard-coded to 28x28; infer the square
                        # side from the number of pixels instead.
                        side = int(round(np.sqrt(image.size)))
                        image = image.reshape(side, side)
                    imagebox = AnnotationBbox(OffsetImage(image, cmap="binary"), image_coord)
                    ax.add_artist(imagebox)

    fig.patch.set_facecolor(COLORS["fig_bg"])
    ax.set_facecolor(COLORS["fig_bg"])
    ax.axis("off")

    # One legend entry per scatter call above, in the same (sorted) class order.
    ax.legend(
        [class_short2full[class_dict_rev[label]] for label in classes],
        prop={"family": "serif", "size": 8},
        facecolor=COLORS["fig_bg"]
    )

    plt.show()

In [None]:
# Size of the per-class random subset used for the embedding visualizations.
n_images_per_class = 50
rng = np.random.default_rng(seed=SEED)

# Replace each class's filename array with a random sample drawn without replacement.
for class_ in class_dict:
    class_files = dict_img_list[class_]
    ind = rng.choice(len(class_files), n_images_per_class, replace=False)
    dict_img_list[class_] = class_files[ind]

In [None]:
NUM_CLASSES = len(class_dict)
# Accumulates [image, one-hot label] pairs for every class (no augmentation).
dataset_viz = []
print("START building visualization dataset")
for i, (class_, class_index) in enumerate(class_dict.items()):
    print(f"[{i+1}/{len(class_dict)}] adding {class_short2full[class_]} to dataset ...")
    dataset_viz = create_dataset(dataset_viz, dict_img_list[class_], class_index)
print("COMPLETE building visualization dataset")

In [None]:
# Split the [image, label] pairs into an image tensor and a label matrix.
X_viz = np.array([image for image, _ in dataset_viz]).reshape(-1, IMG_SIZE, IMG_SIZE, 3)
y_viz = np.array([label for _, label in dataset_viz])

In [None]:
# Flatten every image into a single feature vector (one row per sample).
X_viz = X_viz.reshape((len(X_viz), np.prod(X_viz.shape[1:])))

# t-SNE settings for the 2D embedding, seeded for repeatability.
tsne_options = dict(
    n_components=2,
    init="pca",
    learning_rate="auto",
    perplexity=50,
    n_iter=5000,
    random_state=SEED,
)
tsne = TSNE(**tsne_options)

X_tsne_reduced = tsne.fit_transform(X_viz)

In [None]:
# Plot the t-SNE embedding of all classes; argmax turns one-hot labels into class indices.
plot_classes(X_tsne_reduced, y_viz.argmax(axis=1), figsize=(8,8), annot=True, min_distance=0.05)

In [None]:
# Class index per sample, recovered from the one-hot labels.
y_viz_label = y_viz.argmax(axis=1)
# Boolean mask keeping only the Normal ("N") and Glaucoma ("G") samples.
ind_selected_samples = np.isin(y_viz_label, [class_dict["N"], class_dict["G"]])
plot_classes(
    X_tsne_reduced[ind_selected_samples],
    y_viz_label[ind_selected_samples],
    figsize=(4,4),
    annot=True,
    min_distance=0.05,
)

We can also examine the 2D embeddings learned by UMAP:

In [None]:
from umap import UMAP

# Seed UMAP like the t-SNE run above so the embedding is repeatable between runs.
reducer = UMAP(random_state=SEED)

# Embed the flattened images in 2D and plot all classes.
X_umap_reduced = reducer.fit_transform(X_viz)
plot_classes(X_umap_reduced, y_viz.argmax(axis=1), figsize=(8,8), annot=True, min_distance=0.05)

In [None]:
# Same UMAP embedding, restricted to the samples selected by ind_selected_samples.
plot_classes(X_umap_reduced[ind_selected_samples], y_viz_label[ind_selected_samples], figsize=(4,4), annot=True, min_distance=0.05)

As a starting point, we will frame the problem as a <b><u>binary classification</u></b> problem, and will only use instances from the <font color="#0A9396"><b><i>Glaucoma</i></b></font> and <font color="#001219"><b><i>Normal</i></b></font> classes. For this purpose, we will use the final <i>label</i> for each instance, to select the images based on their <i>filename</i>.

Now, we will generate the full dataset from the chosen classes:

## <font color=#264653><u>Generating Dataset</u></font>

In [None]:
# Binary problem: keep only the Normal ("N") and Glaucoma ("G") classes.
CLASSES = ["N", "G"]
NUM_CLASSES = len(CLASSES)

# Rebuild the code <-> index lookups for the reduced class set.
class_dict = dict(zip(CLASSES, range(len(CLASSES))))
class_dict_rev = dict(enumerate(CLASSES))

# Drop every row whose final label is not one of the chosen classes.
df = df[df["class"].isin(CLASSES)]

# Class code -> array of image filenames of that class.
dict_img_list = {}
for class_ in CLASSES:
    dict_img_list[class_] = df.loc[df["class"] == class_, "filename"].values

In [None]:
# Two side-by-side horizontal bar charts: the original class balance (left)
# and the balance after augmenting the minority class (right).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(14,6), dpi=70, gridspec_kw={"wspace": 0.5})

fig.patch.set_facecolor(COLORS["fig_bg"])

# Count of images per class, ordered as in CLASSES, plus each class's share.
value_counts = df["class"].value_counts().rename("num").to_frame()
# reindex returns a new frame, so the result has to be assigned to take effect.
value_counts = value_counts.reindex(index=CLASSES)
value_counts["percent"] = value_counts["num"] / value_counts["num"].sum()

b1 = ax1.barh(value_counts.index, value_counts["percent"])

# Replace the one-letter tick labels with the full class names.
ax1.set_yticks(
    value_counts.index,
    [class_short2full[i] for i in value_counts.index],
    **FONT_KW["plot_label_small"], color=COLORS["plot_text"]
)
ax1.tick_params(axis="y", length=0)
ax1.set_title("Original", loc="left", **FONT_KW["subplot_title"], color=COLORS["plot_text"])

# Annotate every bar with its absolute count and its percentage.
ax1.bar_label(
    b1,
    labels=[str(val) + f"\n({str(np.round(100*pcnt,1))}%)" for val, pcnt in zip(value_counts["num"], value_counts["percent"])],
    padding=5,
    color=COLORS["plot_text"],
    **FONT_KW["plot_text"]
)

ax1.set_facecolor(COLORS["fig_bg"])

# Colour each bar and its tick label with the colour assigned to that class.
for i in range(NUM_CLASSES):
    c = COLORS["class"][value_counts.index[i]]
    ax1.get_yticklabels()[i].set_color(c)
    b1[i].set_color(c)

ax1.axes.get_xaxis().set_visible(False)

for spine in ["bottom", "right", "top"]:
    ax1.spines[spine].set_visible(False)


# Number of augmented variants create_dataset adds per minority-class image.
NUM_AUGMENTATIONS = 5
value_counts_aug = value_counts.copy()
# create_dataset keeps the original image and appends NUM_AUGMENTATIONS
# variants of it, so each Glaucoma image contributes 1 + NUM_AUGMENTATIONS samples.
value_counts_aug.loc["G", "num"] *= (NUM_AUGMENTATIONS + 1)
value_counts_aug["percent"] = value_counts_aug["num"] / value_counts_aug["num"].sum()
    
b2 = ax2.barh(value_counts_aug.index, value_counts_aug["percent"])

ax2.set_yticks(
    value_counts_aug.index,
    [class_short2full[i] for i in value_counts_aug.index],
    **FONT_KW["plot_label_small"], color=COLORS["plot_text"]
)
ax2.tick_params(axis="y", length=0)
ax2.set_title("With Minority Class Augmentations", loc="left", **FONT_KW["subplot_title"], color=COLORS["plot_text"])

ax2.bar_label(
    b2,
    labels=[str(val) + f"\n({str(np.round(100*pcnt,1))}%)" for val, pcnt in zip(value_counts_aug["num"], value_counts_aug["percent"])],
    padding=5,
    color=COLORS["plot_text"],
    **FONT_KW["plot_text"]
)

ax2.set_facecolor(COLORS["fig_bg"])

for i in range(NUM_CLASSES):
    c = COLORS["class"][value_counts_aug.index[i]]
    ax2.get_yticklabels()[i].set_color(c)
    b2[i].set_color(c)

ax2.axes.get_xaxis().set_visible(False)

for spine in ["bottom", "right", "top"]:
    ax2.spines[spine].set_visible(False)
    
# Figure-level heading placed above both subplots.
plt.figtext(0, 1.05, "Class Distribution", **FONT_KW["plot_title"], color=COLORS["plot_text"])
plt.show()

When generating the dataset from images of the chosen classes, we will somewhat try to alleviate the class imbalance, as there are predominantly fewer instances of <font color="#0A9396"><b><i>Glaucoma</i></b></font> than <font color="#001219"><b><i>Normal</i></b></font> instances.

One way to do that is by <b><u>data augmentation</u></b>: when loading images, we will add augmented versions of instances from the minority class (i.e., <font color="#0A9396"><b><i>Glaucoma</i></b></font>). The augmentations used in this case are affine image transformations, specifically, flipping and rotating the input image: flipping vertically and horizontally, rotating counter clockwise by 90° and 180°, and rotating clockwise by 90°. In addition to significantly increasing the number of training instances, such augmentations also introduce some positional invariance, as they allow the model to learn different possible locations for the regions of interest (e.g., a lesion, atrophy, etc.). Nonetheless, such traditional data augmentation techniques are often insufficient, as they do not regularize the model enough<sup>[[4]](https://arxiv.org/pdf/1807.10225.pdf)</sup>.

Other image augmentation techniques are possible, such as elastic transformations, pixel-level augmentation (e.g., modifying brightness, sharpening, blurring, and more), generating artificial data (e.g., by using GANs), and more, however, I will not apply those in this notebook.

In their article <i>"Multi-Label Fundus Image Classification Using Attention Mechanisms and Feature Fusion"</i> <sup>[[5]](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9230753/)</sup>, Zhenwei et al. tackle this dataset as a multi-class multi-label problem. They suggest an <i>Image Augmentation Model</i> which performs affine transformations (rotations and flips) on an image-weighted enhanced version of the input images, as:

$I_{weight} = I_{org}\ast\alpha+I_{blur}\ast\beta + \gamma$

$I_{blur}=I_{org}\ast{kernel_{h\times{w}}}$

That is, the input image is first convolved with a gaussian filter, and the image-weighted enhancement is a linear combination of the original image and the blurred image. In their paper, Zhenwei et al. used a gaussian kernel with $h=w=63$, however, they resized the input images to have the dimensions $H=W=256$, whereas I opted for slightly smaller images with $H=W=224$; hence, I adjusted the kernel's dimensions to be $h=w=224*0.25-1 = 55$. 

In [None]:
# Per-class augmentation switch, keyed by class index: only Glaucoma is augmented.
augment = dict([
    (class_dict["N"], False),
    (class_dict["G"], True),
])

# Accumulates [image, one-hot label] pairs for both classes.
dataset = []
print("START building dataset")
for i, class_ in enumerate(CLASSES):
    print(f"[{i+1}/{len(CLASSES)}] adding {class_short2full[class_]} to dataset ...")
    class_files = dict_img_list[class_]
    dataset = create_dataset(dataset, class_files, class_dict[class_], augment=augment)
print("COMPLETE building dataset")

In [None]:
# Split the [image, label] pairs into an image tensor and a label matrix.
X = np.array([image for image, _ in dataset]).reshape(-1, IMG_SIZE, IMG_SIZE, 3)
y = np.array([label for _, label in dataset])

### <font color=#264653><u>Data Sampling</u></font>
We will use 60% of the dataset for training, 20% for validation, and 20% for testing.

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 60/20/20 split: hold out 20% for testing, then take 25% of the
# remaining 80% (i.e. 20% of the whole) for validation.
# NOTE(review): `dataset` already contains the flipped/rotated copies appended
# by create_dataset, so this random split can place variants of the same source
# image in different subsets - consider splitting before augmenting.
X_train, X_test, y_train, y_test = train_test_split(X , y, test_size=0.2, stratify=y, random_state=SEED)
X_train, X_val, y_train, y_val = train_test_split(X_train , y_train, test_size=0.25, stratify=y_train, random_state=SEED)

We can plot some samples from each class to examine the training instances before they are fed into a model:

In [None]:
def plot_samples(n_images_per_class=8, figsize=(8, 8), seed=SEED):
    """Show random training images of class 0 and class 1 in one grid.

    Class-0 images fill the left half of the columns and class-1 images the
    right half; when one class runs out, the remaining cells take the other.

    Args:
        n_images_per_class: Number of images drawn (without replacement) from
            each of the two classes in ``X_train``.
        figsize: Figure size passed to ``plt.figure``.
        seed: Seed of the random generator used to pick the images.
    """
    rng = np.random.default_rng(seed=seed)

    num_images = n_images_per_class*NUM_CLASSES
    nrows, ncols = get_subplot_dims(num_images)

    images_neg = X_train[y_train.argmax(axis=1)==0]
    ind_neg = rng.choice(len(images_neg), n_images_per_class, replace=False)

    images_pos = X_train[y_train.argmax(axis=1)==1]
    ind_pos = rng.choice(len(images_pos), n_images_per_class, replace=False)

    fig = plt.figure(figsize=figsize, tight_layout=True)
    count_neg = 0
    count_pos = 0
    for row in range(nrows):
        for col in range(ncols):
            neg_left = count_neg < n_images_per_class
            pos_left = count_pos < n_images_per_class
            if not (neg_left or pos_left):
                # Grid has more cells than images: leave the rest empty.
                break
            img_ind = row*ncols + col + 1
            ax = plt.subplot(nrows, ncols, img_ind)
            # Decide by column rather than by a precomputed index list whose
            # stride used nrows; that only worked for square grids and
            # overran the sampled indices otherwise.
            show_neg = neg_left and (col < ncols // 2 or not pos_left)
            if show_neg:
                ax.imshow(images_neg[ind_neg[count_neg]])
                count_neg += 1
                ax.set_title(class_short2full[class_dict_rev[0]], color=COLORS["class"][class_dict_rev[0]], **FONT_KW["subplot_title_small"])
            else:
                ax.imshow(images_pos[ind_pos[count_pos]])
                count_pos += 1
                ax.set_title(class_short2full[class_dict_rev[1]], color=COLORS["class"][class_dict_rev[1]], **FONT_KW["subplot_title_small"])
            ax.axis("off")
    plt.show()

In [None]:
# Show a grid of random training samples using the default arguments.
plot_samples()

We can see that the images are indeed enhanced, with the <font color="#0A9396"><b><i>Glaucoma</i></b></font> instances also rotated and flipped, as intended in our pipeline.

## <font color=#264653><u>Model</u></font>
We will use <b><u>Transfer Learning</u></b>, as a commonly used approach for dealing with insufficient labeled data. We will use an <code>EfficientNetB0</code> base, pretrained on ImageNet, and will fine-tune it.

<figure>
<img src="https://1.bp.blogspot.com/-DjZT_TLYZok/XO3BYqpxCJI/AAAAAAAAEKM/BvV53klXaTUuQHCkOXZZGywRMdU9v9T_wCLcBGAs/s1600/image2.png" alt="EfficientNet-B0 anatomy" width="600">

<figcaption><i>EfficientNet-B0 architecture, where MBConv is a mobile inverted bottleneck convolution layer <sup>[[6]](https://ai.googleblog.com/2019/05/efficientnet-improving-accuracy-and.html)</sup>.</i></figcaption>
</figure>

We will freeze the parameters of all layers in the pretrained base, barring the final block (excluding the classifying head as well, of course).

In [None]:
def plot_train(history, start_epoch=0, is_shift_val=True, suptitle=None, subtitle=None, **fig_opts):
    """Plot training and validation curves for every metric that has a ``val_`` column.

    Args:
        history: Object whose ``.history`` attribute can be turned into a
            DataFrame with one column per metric and one row per epoch.
        start_epoch: First epoch (row index) to draw.
        is_shift_val: When true, the validation curve is drawn half an epoch
            to the left of the training curve.
        suptitle: Optional figure title.
        subtitle: Optional text placed just below the figure title.
        **fig_opts: Keyword arguments forwarded to ``plt.figure``.
    """
    history_df = pd.DataFrame(history.history)
    # Metric names are taken from the columns starting with "val", with "val_" removed.
    mets = history_df.columns[history_df.columns.str.startswith("val")].str.replace("val_","").tolist()
    # Number of recorded epochs.
    L = history_df.shape[0]
    
    nrows, ncols = get_subplot_dims(len(mets))

    fig = plt.figure(**fig_opts)
    fig.subplots_adjust(hspace=0.75, wspace=0.75)
    fig.patch.set_facecolor(COLORS["fig_bg"])
    
    for i,met in enumerate(mets):
        # "Best" means the minimum for the loss and the maximum for every other metric.
        if met=="loss":
            value_train = np.round(history_df[met].min(),4)
            value_val = np.round(history_df["val_"+met].min(),4)
        else:
            value_train = np.round(history_df[met].max(),4)
            value_val = np.round(history_df["val_"+met].max(),4)
            
        ax = plt.subplot(nrows, ncols, i+1)
        
        ax.plot(np.arange(start_epoch, L), history_df[met].iloc[start_epoch:], color=COLORS["split"]["train"])
        # Boolean is_shift_val acts as 0/1 here, moving the x values by 0.5 when set.
        ax.plot(np.arange(start_epoch - is_shift_val*0.5, L - is_shift_val*0.5), history_df["val_"+met].iloc[start_epoch:], color=COLORS["split"]["val"])
        
        ax.set_title(met, **FONT_KW["subplot_title_small"], pad=30)
        ax.set_xlabel("epoch", **FONT_KW["plot_text_small"])
        ax.set_ylabel("")
        ax.legend(["train", "validation"], prop={"family": "serif", "size": 8}, facecolor=COLORS["fig_bg"])
            
        # pyplot calls act on the current axes, i.e. the subplot created above.
        plt.xticks(**FONT_KW["plot_text_small"])
        plt.yticks(**FONT_KW["plot_text_small"])
        
        # Boxed summary of the best values, positioned in axes coordinates above the plot.
        ax.text(
            0.5, 1.1,
            f"Best train {met}: {value_train}\n" +\
            f"Best val {met}: {value_val}",
            **FONT_KW["plot_text_small"],
            color=COLORS["plot_text"],
            transform=ax.transAxes,
            ha="center",
            bbox={
                "boxstyle": "Round",
                "fill": False,
                "edgecolor": COLORS["plot_text"]
            }
        )
        
        ax.set_facecolor(COLORS["fig_bg"])
        for spine in ["right", "top"]:
            ax.spines[spine].set_visible(False)

    if suptitle is not None:
        plt.suptitle(
            x=0.5, y=1.005,
            t=suptitle,
            ha="center",
            va="bottom",
            **FONT_KW["plot_title_small"],
            color=COLORS["plot_text"]
        )
        
    if subtitle is not None:
        plt.figtext(
            x=0.5, y=0.98,
            s=subtitle,
            ha="center",
            va="bottom",
            **FONT_KW["plot_subtitle"],
            color=COLORS["plot_text"]
        )
    plt.show()

We can also calculate the initial bias, based on the class distribution in the training data, to allow the model's final layer to have a better starting point.

In [None]:
# Number of training samples per class index (argmax undoes the one-hot encoding).
label_count = np.bincount(y_train.argmax(axis=1))
tot = np.sum(label_count)
# Log of each class's share of the training set, used as the output layer's starting bias.
initial_bias = [np.log(c/tot) for c in label_count]
initial_bias

And also the class weights (note that these were not used in training, as experimenting concluded they do not improve the performance in this case).

In [None]:
# Class index -> weight inversely proportional to the class count, scaled by
# tot/NUM_CLASSES (a perfectly balanced dataset would give every class weight 1).
class_weight = {c: (1/lc)*(tot/NUM_CLASSES) for c,lc in enumerate(label_count)}
class_weight

In [None]:
# Pretrained ImageNet backbone without its classification head.
# NOTE(review): the variable names (and the notebook text) say EfficientNet,
# but the base instantiated here is MobileNetV3Small - confirm which is intended.
effnet_model = MobileNetV3Small(input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet")

# Number of trailing backbone layers left trainable; None freezes every layer.
last_layer_include = 5

layers_to_freeze = (
    effnet_model.layers
    if last_layer_include is None
    else effnet_model.layers[:-last_layer_include]
)
for layer in layers_to_freeze:
    layer.trainable=False

# Preprocessing function that matches the backbone.
preprocess_input_effnet = tf.keras.applications.mobilenet_v3.preprocess_input

The final model will contain the pretrained base, along with its preprocessing unit (in the case of <code>EfficientNet-B0</code>, this is a passthrough function, since for EfficientNet, input preprocessing is included as part of the model). We then add a <b><i>dropout</i></b> layer, followed by <b><i>global averaging</i></b> and <b><i>batch normalization</i></b>. The kernel of the final layer will be regularized by $\ell_1$.

In [None]:
def build_model(base, preprocess_input, output_bias=None, dropout=0.0, L1=1e-3, name="model"):
    """Stack preprocessing, a pretrained base and a softmax classification head.

    Args:
        base: Feature-extractor model placed after the preprocessing layer.
        preprocess_input: Callable wrapped in a ``Lambda`` layer and applied
            to the raw input before ``base``.
        output_bias: Optional initial bias values for the output layer. When
            None, the output bias starts at zero.
        dropout: Dropout rate applied to the base's output feature map.
        L1: L1 regularization factor for the output layer's kernel.
        name: Name given to the resulting ``Sequential`` model.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model with ``NUM_CLASSES``
        softmax outputs, expecting inputs of shape ``(*IMAGE_SIZE, 3)``.
    """
    # Passing bias_initializer=None would override Dense's default ("zeros")
    # and leave the choice to Keras' generic weight fallback, so request
    # zeros explicitly when no initial bias is supplied.
    if output_bias is None:
        bias_initializer = "zeros"
    else:
        bias_initializer = tf.keras.initializers.Constant(output_bias)

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(*IMAGE_SIZE, 3)),

        tf.keras.layers.Lambda(preprocess_input, name='preprocessing'),
        base,

        tf.keras.layers.Dropout(dropout),
        tf.keras.layers.GlobalAveragePooling2D(),
        # Kept so the layer stack stays unchanged; the pooled tensor is
        # already (batch, channels), so this does not reshape anything.
        tf.keras.layers.Flatten(),
        tf.keras.layers.BatchNormalization(),

        tf.keras.layers.Dense(
            NUM_CLASSES,
            activation='softmax',
            bias_initializer=bias_initializer,
            kernel_regularizer=tf.keras.regularizers.l1(L1)
        )
    ], name=name)

    return model

We will train the model with a <b><i>Focal Loss</i></b> function, as it was previously shown to improve performance when dealing with class imbalance <sup>[[7]](https://arxiv.org/pdf/1708.02002.pdf)</sup>. By using focal loss, the model is incentivized to focus on learning the samples it still has difficulty classifying.

We will also monitor several metrics along with the loss function: AUC ROC, accuracy, F1 score, and the AUC PR (area under the precision-recall curve). These metrics will assist in better evaluating the model's performance, given the class imbalance prominent in the data.

In [None]:
# Training hyper-parameters.
learning_rate = 3e-4  # Adam step size
gamma = 2             # focusing parameter of the focal loss
dropout = 0.1         # dropout rate handed to build_model
L1 = 2e-3             # L1 factor for the output layer's kernel

# Build and compile inside the distribution strategy's scope.
with strategy.scope():
    model = build_model(
        effnet_model,
        preprocess_input_effnet,
        output_bias=initial_bias,
        dropout=dropout,
        L1=L1,
        name="effnet_model",
    )

    # Metrics tracked alongside the loss: ROC AUC, accuracy, weighted F1, PR AUC.
    METRICS = [
        tf.keras.metrics.AUC(name="auc"),
        tf.keras.metrics.BinaryAccuracy(name="acc"),
        tfa.metrics.F1Score(num_classes=NUM_CLASSES, average="weighted", name="f1"),
        tf.keras.metrics.AUC(name="prc", curve="PR"),
    ]

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    focal_loss = tf.keras.losses.BinaryFocalCrossentropy(gamma=gamma)

    model.compile(optimizer=optimizer, loss=focal_loss, metrics=METRICS)

In [None]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

We will be using a larger than usual <b><i>batch size</i></b>, to make sure each batch has a higher probability of containing at least some minority class instances.

In [None]:
# Upper bound on training epochs; the early-stopping patience is derived
# from it (EPOCHS // 10).
EPOCHS = 200
# Number of samples per gradient step, passed to model.fit.
BATCH_SIZE = 128

We will also use a learning rate scheduler (reducing the learning rate on plateau), and employ early stopping.

In [None]:
# No `monitor` is passed to any callback below, so each one tracks its
# Keras default quantity.

# Checkpoint: only overwrite the saved file when the monitored value improves.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="/kaggle/working/ocir_model_initial.h5",
    save_best_only=True,
)

# Early stopping: give up after EPOCHS // 10 epochs without improvement and
# roll the model back to its best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=EPOCHS // 10,
    verbose=1,
    restore_best_weights=True,
)

# Reduce learning rate on plateau: multiply it by 0.75 after 10 stagnant
# epochs, never going below 1e-6.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.75,
    patience=10,
    min_delta=0.0001,
    cooldown=0,
    min_lr=1e-6,
    verbose=1,
)

callbacks = [checkpoint_cb, early_stopping_cb, reduce_lr]

In [None]:
# Train on the training split, evaluating on the validation split after
# every epoch; the callbacks handle checkpointing, early stopping and
# learning-rate reduction.
validation_split = (X_val, y_val)
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=validation_split,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Plot the training curves from the second epoch on. The subtitle is a raw
# f-string so that the backslash in the mathtext `$\gamma$` is not parsed
# as an (invalid) string escape sequence.
plot_train(history, start_epoch=1, figsize=(10,8), suptitle="EfficientNet Model", subtitle=rf"Dropout={dropout}, L1={L1}, $\gamma$={gamma}, Batch={BATCH_SIZE}")

### <font color=#264653><u>Model Evaluation</u></font>
We will examine the model's performance on the validation set, and later on the test set:

In [None]:
def plot_confusion_matrix(y_true, y_pred, figsize=(16,6), cmap="Blues", suptitle=None):
    """Draw the raw and the normalized confusion matrix side by side.

    The left panel shows absolute counts; the right panel shows the matrix
    normalized over the true labels (``normalize="true"``), formatted as
    percentages.

    Relies on the notebook-level ``COLORS``, ``FONT_KW``, ``class_dict`` and
    ``class_short2full`` objects. ``class_dict`` is iterated as
    ``{class key: tick index}`` to label and color the ticks.

    Args:
        y_true: True class labels.
        y_pred: Predicted class labels.
        figsize: Size of the whole figure.
        cmap: Colormap used for both heatmaps.
        suptitle: Optional figure-level title.
    """
    counts = confusion_matrix(y_true, y_pred)
    rates = confusion_matrix(y_true, y_pred, normalize="true")

    # Skip the decimal place when every normalized cell is a whole number.
    rate_fmt = ".0%" if np.allclose(rates, rates.astype(int)) else ".1%"

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize)
    fig.patch.set_facecolor(COLORS["fig_bg"])

    def draw_panel(ax, matrix, fmt, title):
        """Render one annotated heatmap and style its labels and ticks."""
        tick_names = [class_short2full[short] for short in class_dict.keys()]
        label_font = FONT_KW["plot_label_small"]

        sns.heatmap(
            matrix,
            ax=ax,
            annot=True,
            annot_kws=FONT_KW["plot_text"],
            fmt=fmt,
            cmap=cmap,
            cbar=False,
            square=True,
            linewidths=3.0,
            linecolor=COLORS["fig_bg"],
            xticklabels=tick_names,
            yticklabels=tick_names,
        )

        ax.set_title(title, **FONT_KW["subplot_title_small"])
        ax.set_xlabel("Predicted Labels", **label_font)
        ax.set_ylabel("True Labels", **label_font)

        ax.set_facecolor(COLORS["fig_bg"])
        ax.tick_params(axis="both", length=0)

        # Re-apply the tick labels at the cell centers with the label font.
        y_text = [tick.get_text() for tick in ax.get_yticklabels()]
        ax.set_yticks(np.arange(len(y_text)) + 0.5, y_text, **label_font)
        x_text = [tick.get_text() for tick in ax.get_xticklabels()]
        ax.set_xticks(np.arange(len(x_text)) + 0.5, x_text, **label_font)

        # Color every class's tick label with that class's color.
        x_ticks = ax.get_xticklabels()
        y_ticks = ax.get_yticklabels()
        for class_, position in class_dict.items():
            class_color = COLORS["class"][class_]
            x_ticks[position].set_color(class_color)
            y_ticks[position].set_color(class_color)

    panels = (
        (counts, "d", "Confusion Matrix"),
        (rates, rate_fmt, "Confusion Matrix (Normalized)"),
    )
    for ax, (matrix, fmt, title) in zip(axes, panels):
        draw_panel(ax, matrix, fmt, title)

    if suptitle is not None:
        plt.suptitle(suptitle, y=0.98, **FONT_KW["plot_title_small"])

    plt.show()

In [None]:
# Predicted class per validation sample: index of the highest model output.
y_val_pred = model.predict(X_val).argmax(axis=1)

In [None]:
# Confusion matrices for the validation split; the true labels are decoded
# with argmax over axis 1 to match the predicted class indices.
plot_confusion_matrix(
    np.argmax(y_val, axis=1), y_val_pred,
    figsize=(8,4),
    cmap=COLORS["cmap_pos"],
    suptitle="Model Performance (Validation)"
)

In [None]:
# Keep the raw model outputs (needed for ROC AUC) and derive the predicted
# class as the index of the highest output.
y_test_prob = model.predict(X_test)
y_test_pred = y_test_prob.argmax(axis=1)

In [None]:
# Decode the true test labels once and reuse them for every metric below.
y_test_true = np.argmax(y_test, axis=1)
class_names = [class_short2full[short] for short in class_dict]

plot_confusion_matrix(
    y_test_true,
    y_test_pred,
    figsize=(8,4),
    cmap=COLORS["cmap_pos"],
    suptitle="Model Performance (Test)",
)

report = classification_report(y_test_true, y_test_pred, target_names=class_names)
# ROC AUC is scored on the model output for class index 1.
test_roc_auc = roc_auc_score(y_test_true, y_test_prob[:, 1])

print(report)
print(f"     roc auc       {np.round(test_roc_auc, 2)}")

In [None]:
# Convert the trained Keras model to TFLite and write the result to disk.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('/kaggle/working/ocular.tflite', 'wb') as tflite_file:
    tflite_file.write(tflite_model)

In [None]:
# Delete everything under /kaggle/working via the shell.
# NOTE(review): this also removes the model checkpoint and the TFLite file
# written to that directory earlier in the notebook -- confirm that
# discarding them is intended.
os.system("rm -rf /kaggle/working/*")