In [None]:
import random
import os
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt

In [None]:

# Image pre-processing shared by the data readers
def transform_img(img):
    """Resize an image to 384x384 and rescale it for the network.

    Args:
        img: image array, or ``None``. Assumed to be a 3-D array with the
            channel axis last (what ``cv2.imread`` normally returns) --
            TODO confirm against callers.

    Returns:
        ``None`` when ``img`` is ``None``; otherwise a ``float32`` array whose
        former last axis has been moved to the front and whose values have
        been mapped with ``x / 255 * 2 - 1`` (0..255 becomes -1.0..1.0).
    """
    # Nothing to process: hand the missing image straight back to the caller.
    if img is None:
        return img

    resized = cv2.resize(img, (384, 384))
    # Move the last axis to the front, then switch to float32
    # before doing any arithmetic on the pixel values.
    channels_first = np.transpose(resized, (2, 0, 1)).astype('float32')
    # First scale to [0, 1], then stretch and shift to [-1, 1].
    unit_range = channels_first / 255.
    return unit_range * 2.0 - 1.0

# Reader factory for the training images
def data_loader(datadir, batch_size=10, mode='train'):
    """Build a generator function that yields mini-batches read from ``datadir``.

    The label of each sample comes from the first character of its file name:
    ``H`` (high myopia) and ``N`` (normal) map to 0, ``P`` maps to 1
    (presumably pathological myopia -- confirm against the dataset docs).

    Files whose name starts with any other character, and files that
    ``cv2.imread`` cannot decode, are reported with ``print`` and skipped, so
    they never end up in a batch with a stale or missing label.

    Args:
        datadir: directory that contains the image files.
        batch_size: number of samples per yielded mini-batch.
        mode: when equal to ``'train'`` the file order is shuffled at the
            start of every pass.

    Returns:
        A zero-argument function. Each call returns a generator of
        ``(imgs_array, labels_array)`` pairs: ``imgs_array`` is a ``float32``
        array stacking the pre-processed images, ``labels_array`` is a
        ``float32`` array of shape ``(n, 1)``. The last pair of a pass may
        hold fewer than ``batch_size`` samples.
    """
    # The directory is listed once, when the loader is built.
    filenames = os.listdir(datadir)

    def to_batch(imgs, labels):
        # Stack the collected samples; labels become a column vector.
        imgs_array = np.array(imgs).astype('float32')
        labels_array = np.array(labels).astype('float32').reshape(-1, 1)
        return imgs_array, labels_array

    def reader():
        if mode == 'train':
            # Randomize the sample order for each training pass (in place).
            random.shuffle(filenames)
        batch_imgs = []
        batch_labels = []
        for name in filenames:
            # Decide the label first so unrelated files are never decoded.
            if name.startswith(('H', 'N')):
                label = 0
            elif name.startswith('P'):
                label = 1
            else:
                print('Not expected file name, skipping: {}'.format(name))
                continue

            filepath = os.path.join(datadir, name)
            img = cv2.imread(filepath)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                print('Cannot read image, skipping: {}'.format(filepath))
                continue

            batch_imgs.append(transform_img(img))
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                # A full mini-batch is ready: emit it and start a new one.
                yield to_batch(batch_imgs, batch_labels)
                batch_imgs = []
                batch_labels = []

        if len(batch_imgs) > 0:
            # Leftover samples that do not fill a whole batch form the last one.
            yield to_batch(batch_imgs, batch_labels)

    return reader

# Reader factory for the validation images
def valid_data_loader(datadir, csvfile, batch_size=10, mode='valid', max_rows=400):
    """Build a generator function that yields validation mini-batches.

    Samples are taken from ``csvfile``: the first line is skipped as a header,
    and for every following line column 1 is the image file name (relative to
    ``datadir``) and column 2 is the integer label.

    Blank lines are ignored. Images that ``cv2.imread`` cannot decode are
    reported with ``print`` and skipped instead of corrupting the batch.

    Args:
        datadir: directory that contains the image files.
        csvfile: path of the comma-separated label file; it is read once,
            when the loader is built.
        batch_size: number of samples per yielded mini-batch.
        mode: unused; kept so existing calls keep working.
        max_rows: at most this many lines after the header are considered
            (default 400, the limit that used to be hard-coded). ``None``
            means every line.

    Returns:
        A zero-argument function. Each call returns a generator of
        ``(imgs_array, labels_array)`` pairs: ``imgs_array`` is a ``float32``
        array stacking the pre-processed images, ``labels_array`` is a
        ``float32`` array of shape ``(n, 1)``. The last pair may hold fewer
        than ``batch_size`` samples.
    """
    # Read the label file eagerly and close the handle right away.
    with open(csvfile) as label_file:
        filelists = label_file.readlines()

    # Drop the header line, then apply the optional row limit.
    if max_rows is None:
        rows = filelists[1:]
    else:
        rows = filelists[1:1 + max_rows]

    def to_batch(imgs, labels):
        # Stack the collected samples; labels become a column vector.
        imgs_array = np.array(imgs).astype('float32')
        labels_array = np.array(labels).astype('float32').reshape(-1, 1)
        return imgs_array, labels_array

    def reader():
        batch_imgs = []
        batch_labels = []
        for line in rows:
            if not line.strip():
                # e.g. an empty line at the end of the file
                continue
            fields = [field.strip() for field in line.strip().split(',')]
            name = fields[1]
            label = int(fields[2])

            # Load the image named in the CSV row and pre-process it.
            filepath = os.path.join(datadir, name)
            img = cv2.imread(filepath)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                print('Cannot read image, skipping: {}'.format(filepath))
                continue

            batch_imgs.append(transform_img(img))
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                # A full mini-batch is ready: emit it and start a new one.
                yield to_batch(batch_imgs, batch_labels)
                batch_imgs = []
                batch_labels = []

        if len(batch_imgs) > 0:
            # Leftover samples that do not fill a whole batch form the last one.
            yield to_batch(batch_imgs, batch_labels)

    return reader

In [2]:
# NOTE(review): neither torch nor torchvision is used by the cells visible here, and
# the recorded run of this cell stopped with "ModuleNotFoundError: No module named
# 'torchvision'" -- install torchvision or confirm these imports are still needed.
import torch
import torchvision

# Directory holding the training images.
DATADIR = './data/eye/training/PALM-Training400/PALM-Training400/'
# Build the training reader: mini-batches of 5, shuffled because mode is 'train'.
train_loader = data_loader(DATADIR, 
                           batch_size=5, mode='train')
# Calling the loader starts a fresh pass and returns the batch generator.
data_reader = train_loader()
# Smoke test: pull the first mini-batch and display the image/label array shapes.
data = next(data_reader)
data[0].shape, data[1].shape

ModuleNotFoundError: No module named 'torchvision'

In [None]:
# Directory holding the validation images.
DATADIR2 = './data/eye/validation/PALM-Validation400'
# CSV file with the validation labels (read by valid_data_loader).
CSVFILE = './data/eye/valid_gt/PALM-Validation-GT/labels.csv'
# Build the validation reader with mini-batches of 5.
valid_loader = valid_data_loader(DATADIR2, CSVFILE,batch_size=5,mode='valid')
# NOTE: data_reader and data are rebound here, replacing the training-cell values.
data_reader = valid_loader()
# Smoke test: pull the first mini-batch and display the image/label array shapes.
data = next(data_reader)
data[0].shape, data[1].shape