In [None]:
import math
import random
from pathlib import Path

import numpy as np
from rich import print
from PIL import Image
import matplotlib.pyplot as plt

from cataract_classifier.utils import display_dir_items

In [None]:
# Root of the preprocessed dataset, resolved relative to the notebook's working directory.
dataset_path = Path("..") / "input" / "processed_images"

# Peek at what the cataract class folder of the training split contains.
display_dir_items(dataset_path.joinpath("train", "cataract"))

In [None]:
def _list_files(root):
    """Return every regular file under ``root`` (searched recursively), sorted by path.

    Sorting makes the result independent of the filesystem's directory-listing
    order, so anything that iterates over or samples from these lists behaves
    the same on every machine.
    """
    return sorted(path for path in root.rglob("*") if path.is_file())


train_img_paths = _list_files(dataset_path / "train")
test_img_paths = _list_files(dataset_path / "test")

# Fail loudly on a misconfigured ``dataset_path`` instead of producing empty plots later.
assert train_img_paths, f"No training images found under {dataset_path / 'train'}"
assert test_img_paths, f"No test images found under {dataset_path / 'test'}"

### Visualise images

In [None]:
# Randomly pick a subset of training images to eyeball. A dedicated, seeded RNG over a
# sorted copy of the paths keeps the grid identical across re-runs and machines,
# without touching the global `random` state.
DISPLAY_NUM_IMAGES = 40
SAMPLE_SEED = 42
sample_image_paths = random.Random(SAMPLE_SEED).sample(
    sorted(train_img_paths),
    k=min(DISPLAY_NUM_IMAGES, len(train_img_paths)),  # don't fail on small datasets
)
# The class label is the name of each image's parent directory.
sample_labels = [path.parent.stem for path in sample_image_paths]

# Grid layout: floor(sqrt(n)) columns and just enough rows to hold every image
# (ceil division). The old "+1 row" rule under-allocated for many n (e.g. 43 -> 7x6 = 42).
n_cols = max(1, math.floor(math.sqrt(len(sample_image_paths))))
n_rows = max(1, math.ceil(len(sample_image_paths) / n_cols))

def bordered_image(img, label, border_width=5):
    """Return a copy of ``img`` surrounded by a solid border that encodes its class.

    The border is red for the ``"cataract"`` label (compared case-insensitively)
    and green for any other label, so the classes stand out in an image grid.

    Args:
        img: Image array of shape ``(height, width, channels)``. The border colour is
            written to channel 0 (red) or channel 1 (green). An RGB ``uint8`` array,
            as produced by ``np.array(pil_image.convert('RGB'))``, is expected.
        label: Class label; compared case-insensitively against ``"cataract"``.
        border_width: Border thickness in pixels on every side. ``0`` returns an
            unbordered ``uint8`` copy of ``img``.

    Returns:
        ``uint8`` array of shape
        ``(height + 2*border_width, width + 2*border_width, channels)``.

    Raises:
        ValueError: if ``border_width`` is negative.
    """
    if border_width < 0:
        raise ValueError(f"border_width must be non-negative, got {border_width}")

    height, width, channels = img.shape[0], img.shape[1], img.shape[2]
    b = border_width
    # Allocate directly as uint8 (no float64 canvas that is cast back afterwards).
    framed_img = np.zeros((height + 2 * b, width + 2 * b, channels), dtype=np.uint8)

    # Paint the whole canvas in the border colour; the interior is overwritten below.
    border_channel = 0 if label.lower() == "cataract" else 1
    framed_img[:, :, border_channel] = 255

    # Explicit end indices instead of ``-b``: with b == 0, ``b:-b`` is the empty
    # slice ``0:0`` and the assignment would fail with a shape mismatch.
    framed_img[b:b + height, b:b + width] = img
    return framed_img

# One subplot per sampled image. ``squeeze=False`` keeps ``axs`` a 2-D array even for
# a 1x1 grid (otherwise a bare Axes is returned and ``.flat`` would not exist).
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 12), squeeze=False)

for i, ax in enumerate(axs.flat):
    # Grid cells beyond the number of sampled images are left blank.
    if i < len(sample_image_paths):
        label = sample_labels[i]
        # The context manager closes the file once the pixels are copied into the array.
        with Image.open(sample_image_paths[i]) as pil_img:
            rgb = np.array(pil_img.convert("RGB"))
        ax.imshow(bordered_image(rgb, label, border_width=2))
        ax.set_title(label)

    # Hide ticks and spines on every cell, including the blank ones.
    ax.axis("off")

plt.tight_layout()
plt.show()

From this sample, the color of the eyeball looks like the main distinguishing factor between the classes (red border = cataract, green border = non-cataract).

In [None]:
def _image_sizes(paths):
    """Return ``(widths, heights)`` lists for the image files at ``paths``.

    Each file is opened once and closed immediately. ``Image.open`` only parses the
    file header, so this is cheap but does not decode or validate the pixel data.
    """
    widths, heights = [], []
    for path in paths:
        with Image.open(path) as im:
            widths.append(im.width)
            heights.append(im.height)
    return widths, heights


fig, axs = plt.subplots(1, 2, figsize=(10, 4))

# One panel per split, with identical styling.
for ax, (split_name, split_paths) in zip(axs, [("Train", train_img_paths), ("Test", test_img_paths)]):
    widths, heights = _image_sizes(split_paths)
    ax.scatter(x=widths, y=heights, marker="o", s=3.0, alpha=0.5)
    ax.set_xlabel("Width")
    ax.set_ylabel("Height")
    ax.set_title(f"{split_name} Images Aspect Ratio (Width vs Height)")

plt.show()

Only a few images have a very high resolution. We're going to use mini-batch gradient descent for optimization, so this small number of samples is unlikely to affect model performance.

Data augmentation will be applied during training through PyTorch data loaders and Albumentations transforms.

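Also note that no image file is unreadable: the scatter plots above were built by opening every training and test image. However, `Image.open` only parses the file header, so this does not rule out corrupted pixel data.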

### Exploring available pretrained models

In [None]:
import timm
print(timm.list_models("*efficientnet*", pretrained=True))

In [None]:
from timm.models import efficientnet

# Fetch the registered config entry for the efficientnet_b0 architecture, then
# convert its default pretrained config to a plain dict for inspection.
b0_cfg_entry = efficientnet.default_cfgs["efficientnet_b0"]
model_cfg = b0_cfg_entry.default.to_dict()
print(model_cfg)