# Notebook to extract features for the ODIR-5K dataset using the RETFound model

Steps:
1) Load the ODIR-5K dataset
2) Load the RETFound model
3) Extract features from the RETFound model

## Load the ODIR-5K dataset

In [10]:
import os
import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets.folder import pil_loader

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

Create a Dataset that keeps the image paths, loads each image on demand and applies basic transformations

In [8]:
class Dataset():
    """Map-style dataset over every regular file directly inside a folder.

    Only the file paths are collected at construction time. An image is read
    from disk and transformed (bicubic resize to 224x224, tensor conversion,
    normalisation with IMAGENET_MEAN / IMAGENET_STD) when it is indexed.

    The class exposes ``__len__`` and ``__getitem__`` and keeps the ordered
    list of files in ``self.paths``, so item ``i`` always comes from
    ``self.paths[i]``.
    """

    def __init__(self, root: str):
        """Index the files in ``root`` and build the preprocessing pipeline.

        Args:
            root: Folder to scan. Only direct children that are regular files
                are kept; sub-folders are ignored (no recursion). Files are
                not filtered by extension.
        """
        self.paths = self._list_files(root)

        self.transforms = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
        ])

    @staticmethod
    def _list_files(root: str) -> list:
        """Return the paths of the regular files directly inside ``root``, sorted by name."""
        # os.listdir gives entries in arbitrary order; sorting makes the index -> file
        # mapping (and therefore the order of the extracted features) reproducible.
        candidates = (os.path.join(root, name) for name in sorted(os.listdir(root)))
        return [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.paths)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Load the file at position ``index`` and return it as a transformed tensor.

        Args:
            index: Position in ``self.paths``.

        Returns:
            The output of the transform pipeline for that image.
        """
        img_path = self.paths[index]
        img = pil_loader(img_path)

        return self.transforms(img)

Instantiate Dataset and DataLoader.

Change IMG_FOLDER and BATCH_SIZE depending on user preferences and hardware constraints.

In [12]:
# Folder that is scanned for input files, and the number of items per dataloader batch.
IMG_FOLDER = '../ODIR-5K/ODIR-5K/Training Images'
BATCH_SIZE = 32

data: Dataset = Dataset(IMG_FOLDER)
# The dataset keeps regular files only, so compare against the raw number of folder entries.
_n_folder_entries = len(os.listdir(IMG_FOLDER))
print(f'Will process {len(data):,d} (out of {_n_folder_entries:,d}) images from folder [{IMG_FOLDER}]')

# Keep shuffle=False: batches must follow the order of data.paths so that the
# extracted features can later be matched back to their files.
dl: DataLoader = DataLoader(dataset=data, batch_size=BATCH_SIZE, shuffle=False)
print(f'Created dataloader with {len(dl):,d} batches')

Will process 7,000 (out of 7,000) images from folder [../ODIR-5K/ODIR-5K/Training Images]
Created dataloader with 219 batches


## Load the RETFound model

In [2]:
import models_vit

def prepare_model(chkpt_dir, arch='vit_large_patch16'):
    """Build a ``models_vit`` architecture and load checkpoint weights into it.

    Args:
        chkpt_dir: Path of the checkpoint file passed to ``torch.load``. The
            loaded object must have a ``'model'`` entry holding the state dict.
        arch: Name of the model constructor to look up in ``models_vit``.

    Returns:
        The constructed model with the matching checkpoint weights loaded.
        It is not moved to a device or switched to eval mode here.

    Raises:
        KeyError: If ``arch`` is not defined in ``models_vit`` or the
            checkpoint has no ``'model'`` entry.
    """
    import warnings

    # build model
    model = models_vit.__dict__[arch](
        img_size=224,
        num_classes=2,
        drop_path_rate=0,
        global_pool=True,
    )

    # load model: map tensors to the CPU so the checkpoint can be read without a GPU
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    # strict=False tolerates keys that exist on only one side (model vs checkpoint),
    # so keep the report instead of discarding it.
    load_report = model.load_state_dict(checkpoint['model'], strict=False)

    # With strict=False a checkpoint that matches nothing loads "successfully" and
    # the model silently keeps its initial weights; make that case visible.
    n_expected = len(model.state_dict())
    if n_expected and len(load_report.missing_keys) >= n_expected:
        warnings.warn(
            f'No weights from [{chkpt_dir}] matched the [{arch}] model; '
            'it still has its initial weights.',
            RuntimeWarning,
        )
    return model

In [14]:
# Checkpoint file to load the weights from.
MODEL_PATH = '../RETFound_cfp_weights.pth'

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = prepare_model(MODEL_PATH, arch='vit_large_patch16')
model.to(device)
# Switch to evaluation mode; as the last expression of the cell this also displays the model.
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(


## Extract features from the RETFound model

In [22]:
import pickle as pk
from tqdm import tqdm

def save_features(latent_features, data: 'Dataset', out_folder: str):
    """Pair each feature row with its file path and pickle the result.

    The output file ``features.pkl`` contains a list of dictionaries with the
    keys ``'path'`` and ``'features'``, one per entry of ``data.paths``.

    Args:
        latent_features: Per-batch arrays, in the same order as ``data.paths``;
            they are concatenated along the first axis.
        data: Object whose ``paths`` attribute lists the files the features
            were computed from.
        out_folder: Folder that receives ``features.pkl``; created if missing.

    Raises:
        ValueError: If the number of feature rows differs from the number of
            paths (the pairing would otherwise be silently truncated).
    """
    batches = list(latent_features)
    paths = list(data.paths)

    # np.concatenate rejects an empty sequence, so treat "no batches" as "no rows".
    features = np.concatenate(batches) if batches else []
    if len(features) != len(paths):
        raise ValueError(
            f'Got {len(features)} feature rows for {len(paths)} paths; '
            'features and paths must line up one-to-one.'
        )
    latent_dict: list = [{'path': path, 'features': feats} for path, feats in zip(paths, features)]

    # save features (an empty out_folder means the current directory: nothing to create)
    if out_folder:
        os.makedirs(out_folder, exist_ok=True)
    out_file: str = os.path.join(out_folder, 'features.pkl')
    with open(out_file, 'wb') as f_out:
        pk.dump(latent_dict, f_out)

In [25]:
OUT_FOLDER = '../features'

# Create the output folder before the long extraction loop, so a missing folder
# cannot make the final save fail after all batches have been processed.
os.makedirs(OUT_FOLDER, exist_ok=True)

# extract features using the model
latent_features: list = []
n_batches: int = len(dl)
with torch.no_grad():  # inference only: no gradients are needed
    for img in tqdm(dl, desc='Processing images', total=n_batches):
        img = img.to(device)
        latent = model.forward_features(img.float())
        # move each batch back to the CPU as a numpy array so results do not pile up on the device
        latent_features.append(latent.cpu().numpy())

# batches were produced with shuffle=False, so rows line up with data.paths
save_features(latent_features, data, OUT_FOLDER)

Processing images: 100%|██████████| 219/219 [21:10<00:00,  5.80s/it]
