In [31]:
from io import StringIO
from xmlrpc.client import Boolean
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import VisionDataset
from torch.utils.data import Dataset, DataLoader
from enum import Enum, IntEnum, unique, auto
from typing import Tuple, List, Union, Any, Optional, Callable
import pandas as pd
import numpy as np
import functools
import skimage.io as io
from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean
from skimage.filters import threshold_mean, sobel
from skimage.color import rgb2gray
from skimage import feature
import os.path as ospath
from PIL import Image

from eye_classifier import *

In [32]:
@unique
class TargetType(IntEnum):
    Normal = 0,
    Diabetes = 1,
    Glaucoma = 2,
    Cataract = 3,
    AgeRelatedMacularDegeneration = 4,
    Hypertension = 5,
    PathologicalMyopia = 6,
    Other = 7

class ImageDataset (VisionDataset):
    files = []
    data = None
    classes = []
    targets = []

    def __init__(self, root: str, transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None) -> None:

        super(ImageDataset, self).__init__(root, transform=transform, target_transform=target_transform)
        
    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:           
            img, target = self.data[index], self.targets[index]

            # doing this so that it is consistent with all other datasets
            # to return a PIL Image
            img = Image.fromarray(img)

            if self.transform is not None:
                img = self.transform(img)

            if self.target_transform is not None:
                target = self.target_transform(target)

            return img, target


def read_images(base_dir: str, image_path: str, data_info_csv_file: str) -> ImageDataset:
    dataset = ImageDataset(base_dir)
    dataset.classes = [e.name for e in TargetType]

    data = []
    limit = 10
    for idx, row in pd.read_csv(data_info_csv_file).iterrows():
        left_file = f"{image_path}/{row['Left-Fundus']}"
        right_file = f"{image_path}/{row['Right-Fundus']}"
        
        if (limit == 0): break
        limit -= 1

        if ospath.exists(left_file):
            dataset.files.append(left_file)
            data.append(io.imread(left_file))
            
        if ospath.exists(right_file):
            dataset.files.append(right_file)
            data.append(io.imread(right_file))
        
        if row['N'] == 1:
            dataset.targets.append(int(TargetType.Normal))
        elif row['D'] == 1:
            dataset.targets.append(int(TargetType.Diabetes))
        elif row['G'] == 1:
            dataset.targets.append(int(TargetType.Glaucoma))
        elif row['C'] == 1:
            dataset.targets.append(int(TargetType.Cataract))
        elif row['A'] == 1:
            dataset.targets.append(int(TargetType.AgeRelatedMacularDegeneration))
        elif row['H'] == 1:
            dataset.targets.append(int(TargetType.Hypertension))
        elif row['M'] == 1:
            dataset.targets.append(int(TargetType.PathologicalMyopia))
        else:
            dataset.targets.append(int(TargetType.Other))
    

    dataset.data = np.asarray(data, dtype=np.uint8)
    return dataset


In [39]:
base_dir = "../../data"
image_dir = f"{base_dir}/preprocessed_images"
csv_file = f'{base_dir}/ODIR-5K/data.csv'

ds = read_images(base_dir, image_path=image_dir, data_info_csv_file=csv_file)


In [38]:
#nn = EyeClassifier([512, 512, 3])
#nn.train(ds)


train_loader = DataLoader(ds, batch_size=4, shuffle=True)
for i,j in enumerate(train_loader):
    k = 0

IndexError: list index out of range

In [None]:
ds.data.shape