In [None]:
import math
import random
from pathlib import Path

import timm
import numpy as np
from rich import print
from PIL import Image
import matplotlib.pyplot as plt

from cataract_classifier.utils import display_dir_items

In [None]:
# Root folder of the preprocessed dataset; the "train" split lives underneath it.
dataset_path = Path("../input") / "processed_images"

# Quick look at what the training split contains.
display_dir_items(dataset_path.joinpath("train"))

In [None]:
# Every image file in the training split (presumably one sub-folder per class,
# since later cells use the parent folder name as the label).
# Sorted so the ordering -- and therefore any seeded sampling downstream -- does
# not depend on the filesystem's arbitrary directory-listing order.
# Hidden files (e.g. .DS_Store) are skipped because PIL cannot open them.
img_filepaths = sorted(
    filepath
    for filepath in (dataset_path / "train").rglob("*")
    if filepath.is_file() and not filepath.name.startswith(".")
)
assert img_filepaths, f"No image files found under {dataset_path / 'train'}"

In [None]:
# Show a random sample of training images in a grid, each titled with its class
# label (the name of the folder the image sits in).
DISPLAY_NUM_IMAGES = 20
SAMPLE_SEED = 42  # fixed seed so the same images are shown on every re-run

# Never request more images than exist: random.sample raises ValueError otherwise.
num_display = min(DISPLAY_NUM_IMAGES, len(img_filepaths))
# A private Random instance avoids touching global random state; sorting makes
# the sample independent of the order in which the files were listed.
sample_image_paths = random.Random(SAMPLE_SEED).sample(sorted(img_filepaths), k=num_display)
sample_labels = [path.parent.name for path in sample_image_paths]

# Near-square grid guaranteed to have at least num_display cells
# (n_rows * n_cols >= num_display); at least 1x1 so plt.subplots is valid.
n_cols = max(1, math.ceil(math.sqrt(num_display)))
n_rows = max(1, math.ceil(num_display / n_cols))

# squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid,
# so iterating over axs.flat is safe for any grid size.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 12), squeeze=False)

for i, ax in enumerate(axs.flat):
    if i < num_display:
        # Context manager closes the file handle once the pixels are converted.
        with Image.open(sample_image_paths[i]) as img:
            ax.imshow(np.asarray(img.convert("RGB")))
        ax.set_title(sample_labels[i])
    # Hide ticks/frames on every cell, including unused trailing ones.
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Scatter every training image's (width, height) to see how uniform the
# resolutions are. Each file is opened once; Image.open is lazy, so reading
# .size only parses the header, and the context manager closes the handle.
image_sizes = []
for filepath in img_filepaths:
    with Image.open(filepath) as img:
        image_sizes.append(img.size)  # PIL size is (width, height)

widths = [width for width, _ in image_sizes]
heights = [height for _, height in image_sizes]

fig, ax = plt.subplots()
ax.scatter(widths, heights, marker="o", s=3.0, alpha=0.5)
ax.set(xlabel="Image Width", ylabel="Image Height", title="Training image resolutions")
plt.show()

In [None]:
# Candidate backbones: EfficientNet variants for which timm has pretrained weights.
pretrained_efficientnets = timm.list_models(filter="*efficientnet*", pretrained=True)
print(pretrained_efficientnets)

In [None]:
# EfficientNet-B0 backbone with timm's default pretrained weights.
# num_classes=1 gives the classifier head a single output unit
# (presumably one logit for a binary decision -- TODO confirm against training code).
model = timm.create_model(
    model_name="efficientnet_b0",
    pretrained=True,
    num_classes=1,
)

In [None]:
from timm.models import efficientnet

# Default pretrained config registered for efficientnet_b0, converted to a plain
# dict for display (presumably useful for matching input size / normalization).
b0_default_cfg = efficientnet.default_cfgs["efficientnet_b0"].default
model_cfg = b0_default_cfg.to_dict()
print(model_cfg)