import os
import random
from random import shuffle

import numpy as np
import torch
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image


class ImageFolder(data.Dataset):
    """Paired image / binary-mask dataset read from a single flat directory.

    Every file in ``root`` whose name does not contain ``'mask'`` is treated as
    an input image named ``<prefix>_<index>_<rest>``. Its ground-truth mask is
    expected at ``mask_<index>_<rest>`` in the same directory. The integer
    ``<index>`` field selects the split:

    * ``index < split_index`` goes to ``mode='train'``;
    * ``index >= split_index`` goes to ``mode='valid'``.

    Any other ``mode`` value yields an empty dataset.
    """

    def __init__(self, root, image_size=256, mode='train', split_index=20):
        """Collect the image/mask path pairs for the requested split.

        Args:
            root: Directory containing both the images and their masks.
            image_size: Stored as ``self.image_size``; no resizing is done here.
            mode: ``'train'`` or ``'valid'``.
            split_index: Smallest ``<index>`` value that belongs to the
                validation split (default 20, the original hard-coded value).

        Raises:
            ValueError: If a non-mask file name has no integer as its second
                ``_``-separated field.
        """
        self.root = root
        self.GT_paths = []
        self.image_paths = []

        # os.listdir order is arbitrary; sort so dataset indices are reproducible.
        for name in sorted(os.listdir(root)):
            if 'mask' in name:
                continue
            case_index = self._parse_index(name)
            in_split = (
                (mode == 'train' and case_index < split_index)
                or (mode == 'valid' and case_index >= split_index)
            )
            if not in_split:
                continue
            # Replace the leading '<prefix>' with 'mask', keeping '_<index>_<rest>'.
            _, _, suffix = name.partition('_')
            self.image_paths.append(os.path.join(root, name))
            self.GT_paths.append(os.path.join(root, 'mask_' + suffix))

        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        # Built once here instead of on every __getitem__ call.
        self._to_tensor = T.ToTensor()
        self._normalize = T.Normalize((0.5,), (0.18,))
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    @staticmethod
    def _parse_index(name):
        """Return the integer second ``_``-separated field of ``name``."""
        try:
            return int(name.split('_')[1])
        except (IndexError, ValueError) as err:
            raise ValueError(
                "cannot parse split index from file name {!r}; expected "
                "'<prefix>_<index>_...'".format(name)
            ) from err

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(image, GT)``:

            * ``image`` is the input image converted by ``T.ToTensor`` and then
              normalized with ``T.Normalize((0.5,), (0.18,))``.
            * ``GT`` is the mask converted by ``T.ToTensor`` and binarized to
              0/1 at threshold 0.5.
        """
        # Context managers release the file handles that Image.open keeps open;
        # ToTensor reads the pixel data before the file is closed.
        with Image.open(self.image_paths[index]) as img:
            image = self._to_tensor(img)
        with Image.open(self.GT_paths[index]) as mask:
            GT = self._to_tensor(mask)

        # Same result as GT[GT > 0.5] = 1; GT[GT <= 0.5] = 0, keeping GT's dtype.
        GT = (GT > 0.5).to(GT.dtype)
        image = self._normalize(image)
        return image, GT

    def __len__(self):
        """Return the number of image/mask pairs in this split."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train', shuffle=True):
    """Build a DataLoader over ``ImageFolder(image_path, image_size, mode)``.

    Args:
        image_path: Directory passed to ``ImageFolder`` as ``root``.
        image_size: Forwarded to ``ImageFolder``.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.
        mode: ``'train'`` or ``'valid'``.
        shuffle: Whether to shuffle samples each epoch (default True, as before).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(image, GT)`` batches.
    """
    dataset = ImageFolder(root=image_path, image_size=image_size, mode=mode)
    return data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )