In [None]:
# data loading
import json
import os

import numpy as np
import scipy.misc # for saving
import torch
import torchvision.transforms.functional as FT
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms

from yjlib.Data import DataLoader
from yjlib.prep.ImgPrep import ImgPrep

class EndoDataset(Dataset):
    """Image dataset indexed by two JSON files stored under ``DATA_PATH``.

    ``<SPLIT>_images.json`` is read into ``self.images`` (image file paths) and
    ``<SPLIT>_labels.json`` into ``self.labels``; both must contain the same
    number of entries.

    Parameters
    ----------
    DATA_PATH : str
        Directory that contains the two JSON index files.
    IMAGE_SIZE : int or sequence
        Forwarded unchanged to ``transforms.Resize``.
    split : str
        One of 'train', 'val' or 'test' (case-insensitive).
    """

    def __init__(self, DATA_PATH, IMAGE_SIZE, split):
        # Project helper that provides ``remove_pad`` (used in __getitem__).
        self.imgPrep = ImgPrep()

        # data path
        self.DATA_PATH = DATA_PATH
        self.IMAGE_SIZE = IMAGE_SIZE

        self.split = split.upper()
        assert self.split in {'TRAIN', 'VAL', 'TEST'}, \
            "split must be one of 'train', 'val', 'test' (got %r)" % (split,)

        # Read data files
        self.images = self._load_json(self.split + '_images.json')  # image paths
        self.labels = self._load_json(self.split + '_labels.json')  # one entry per image

        # Every image needs exactly one label entry.
        assert len(self.images) == len(self.labels), \
            "number of images (%d) and labels (%d) differ" % (len(self.images), len(self.labels))

    def _load_json(self, filename):
        """Return the parsed content of ``filename`` located inside ``DATA_PATH``."""
        with open(os.path.join(self.DATA_PATH, filename), 'r') as j:
            return json.load(j)

    def __getitem__(self, i):
        """Load, preprocess and return the ``i``-th sample.

        Pipeline: PIL (raw image) -> array -> PIL -> Tensor (prep_image)

        Parameters
        ----------
        i : int
            Index into ``self.images`` / ``self.labels``.

        Returns
        -------
        (prep_image, tensor_label)
            The transformed image tensor and ``torch.FloatTensor([label])``.
        """
        # The context manager releases the file handle; np.array() forces the
        # pixel data to be decoded while the file is still open.
        with Image.open(self.images[i], mode='r') as raw_image:
            raw_array = np.array(raw_image)

        # Padding removal is delegated to the project helper (behavior not visible here).
        removed_image = self.imgPrep.remove_pad(raw_array)

        tensor_label = torch.FloatTensor([self.labels[i]])

        # Normalize with mean 0 / std 1 leaves the values unchanged, but its
        # three-element statistics mean it expects a 3-channel image.
        content_transform = transforms.Compose([transforms.Resize(self.IMAGE_SIZE),
                                                transforms.ToTensor(),
                                                transforms.Normalize([0, 0, 0], [1, 1, 1])])

        # The torchvision transforms above operate on a PIL image, so convert back.
        prep_image = content_transform(Image.fromarray(removed_image))

        return prep_image, tensor_label

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.images)

## Double label and read path

In [None]:
# Build the dataset and loader from the saved image paths
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as FT

from yjlib.prep.ImgPrep import ImgPrep 
from PIL import Image
import numpy as np
import json
import scipy.misc # for saving 

class PathDataset(Dataset):
    """Image dataset built from in-memory lists of image paths and labels.

    Parameters
    ----------
    IMAGE_SIZE : int or sequence
        Forwarded unchanged to ``transforms.Resize``.
    IMAGE_PATH : list, optional
        Image file paths. Defaults to a fresh empty list per instance.
    LABELS : list, optional
        Labels, one per image path. Defaults to a fresh empty list per instance.
    get_path_flag : bool
        When truthy, ``__getitem__`` also returns the image path.
    """

    def __init__(self, IMAGE_SIZE, IMAGE_PATH=None, LABELS=None, get_path_flag=True):
        # Project helper that provides ``remove_pad`` (used in __getitem__).
        self.imgPrep = ImgPrep()

        # data path
        self.IMAGE_SIZE = IMAGE_SIZE
        # ``None`` sentinels instead of ``[]`` defaults: a mutable default would
        # be shared by every instance created without explicit lists.
        self.images = [] if IMAGE_PATH is None else IMAGE_PATH
        self.labels = [] if LABELS is None else LABELS
        self.get_path_flag = get_path_flag

        # Every image needs exactly one label entry.
        assert len(self.images) == len(self.labels), \
            "number of images (%d) and labels (%d) differ" % (len(self.images), len(self.labels))

    def __getitem__(self, i):
        """Load, preprocess and return the ``i``-th sample.

        Pipeline: PIL (raw image) -> array -> PIL -> Tensor (prep_image)

        Parameters
        ----------
        i : int
            Index into ``self.images`` / ``self.labels``.

        Returns
        -------
        (prep_image, tensor_label) or (prep_image, tensor_label, path)
            The three-element form is returned when ``get_path_flag`` is truthy.
        """
        image_path = self.images[i]

        # The context manager releases the file handle; np.array() forces the
        # pixel data to be decoded while the file is still open.
        with Image.open(image_path, mode='r') as raw_image:
            raw_array = np.array(raw_image)

        # Padding removal is delegated to the project helper (behavior not visible here).
        removed_image = self.imgPrep.remove_pad(raw_array)

        tensor_label = torch.FloatTensor([self.labels[i]])

        # Normalize with mean 0 / std 1 leaves the values unchanged, but its
        # three-element statistics mean it expects a 3-channel image.
        content_transform = transforms.Compose([transforms.Resize(self.IMAGE_SIZE),
                                                transforms.ToTensor(),
                                                transforms.Normalize([0, 0, 0], [1, 1, 1])])

        # The torchvision transforms above operate on a PIL image, so convert back.
        prep_image = content_transform(Image.fromarray(removed_image))

        # Plain truthiness check: the former ``== False`` / ``== True`` chain
        # silently returned None for any other flag value.
        if self.get_path_flag:
            return prep_image, tensor_label, image_path
        return prep_image, tensor_label

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.images)

## Return path

In [None]:
# Build the dataset and loader from the saved image paths
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as FT

from yjlib.prep.ImgPrep import ImgPrep 
from PIL import Image
import numpy as np
import json
import scipy.misc # for saving 

class PathDataset(Dataset):
    """Image dataset built from in-memory lists of image paths and labels.

    Parameters
    ----------
    IMAGE_SIZE : int or sequence
        Forwarded unchanged to ``transforms.Resize``.
    IMAGE_PATH : list, optional
        Image file paths. Defaults to a fresh empty list per instance.
    LABELS : list, optional
        Labels, one per image path. Defaults to a fresh empty list per instance.
    get_path_flag : bool
        When truthy, ``__getitem__`` also returns the image path.
    """

    def __init__(self, IMAGE_SIZE, IMAGE_PATH=None, LABELS=None, get_path_flag=True):
        # Project helper that provides ``remove_pad`` (used in __getitem__).
        self.imgPrep = ImgPrep()

        # data path
        self.IMAGE_SIZE = IMAGE_SIZE
        # ``None`` sentinels instead of ``[]`` defaults: a mutable default would
        # be shared by every instance created without explicit lists.
        self.images = [] if IMAGE_PATH is None else IMAGE_PATH
        self.labels = [] if LABELS is None else LABELS
        self.get_path_flag = get_path_flag

        # Every image needs exactly one label entry.
        assert len(self.images) == len(self.labels), \
            "number of images (%d) and labels (%d) differ" % (len(self.images), len(self.labels))

    def __getitem__(self, i):
        """Load, preprocess and return the ``i``-th sample.

        Pipeline: PIL (raw image) -> array -> PIL -> Tensor (prep_image)

        Parameters
        ----------
        i : int
            Index into ``self.images`` / ``self.labels``.

        Returns
        -------
        (prep_image, tensor_label) or (prep_image, tensor_label, path)
            The three-element form is returned when ``get_path_flag`` is truthy.
        """
        image_path = self.images[i]

        # The context manager releases the file handle; np.array() forces the
        # pixel data to be decoded while the file is still open.
        with Image.open(image_path, mode='r') as raw_image:
            raw_array = np.array(raw_image)

        # Padding removal is delegated to the project helper (behavior not visible here).
        removed_image = self.imgPrep.remove_pad(raw_array)

        tensor_label = torch.FloatTensor([self.labels[i]])

        # Normalize with mean 0 / std 1 leaves the values unchanged, but its
        # three-element statistics mean it expects a 3-channel image.
        content_transform = transforms.Compose([transforms.Resize(self.IMAGE_SIZE),
                                                transforms.ToTensor(),
                                                transforms.Normalize([0, 0, 0], [1, 1, 1])])

        # The torchvision transforms above operate on a PIL image, so convert back.
        prep_image = content_transform(Image.fromarray(removed_image))

        # Plain truthiness check: the former ``== False`` / ``== True`` chain
        # silently returned None for any other flag value.
        if self.get_path_flag:
            return prep_image, tensor_label, image_path
        return prep_image, tensor_label

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.images)
    
class EndoDataset(Dataset):
    """Image dataset indexed by two JSON files stored under ``DATA_PATH``.

    ``<SPLIT>_images.json`` is read into ``self.images`` (image file paths) and
    ``<SPLIT>_labels.json`` into ``self.labels``; both must contain the same
    number of entries.

    Parameters
    ----------
    DATA_PATH : str
        Directory that contains the two JSON index files.
    IMAGE_SIZE : int or sequence
        Forwarded unchanged to ``transforms.Resize``.
    split : str
        One of 'train', 'val' or 'test' (case-insensitive).
    """

    def __init__(self, DATA_PATH, IMAGE_SIZE, split):
        # Project helper that provides ``remove_pad`` (used in __getitem__).
        self.imgPrep = ImgPrep()

        # data path
        self.DATA_PATH = DATA_PATH
        self.IMAGE_SIZE = IMAGE_SIZE

        self.split = split.upper()
        assert self.split in {'TRAIN', 'VAL', 'TEST'}, \
            "split must be one of 'train', 'val', 'test' (got %r)" % (split,)

        # Read data files
        self.images = self._load_json(self.split + '_images.json')  # image paths
        self.labels = self._load_json(self.split + '_labels.json')  # one entry per image

        # Every image needs exactly one label entry.
        assert len(self.images) == len(self.labels), \
            "number of images (%d) and labels (%d) differ" % (len(self.images), len(self.labels))

    def _load_json(self, filename):
        """Return the parsed content of ``filename`` located inside ``DATA_PATH``."""
        with open(os.path.join(self.DATA_PATH, filename), 'r') as j:
            return json.load(j)

    def __getitem__(self, i):
        """Load, preprocess and return the ``i``-th sample.

        Pipeline: PIL (raw image) -> array -> PIL -> Tensor (prep_image)

        Parameters
        ----------
        i : int
            Index into ``self.images`` / ``self.labels``.

        Returns
        -------
        (prep_image, tensor_label)
            The transformed image tensor and ``torch.FloatTensor([label])``.
        """
        # The context manager releases the file handle; np.array() forces the
        # pixel data to be decoded while the file is still open.
        with Image.open(self.images[i], mode='r') as raw_image:
            raw_array = np.array(raw_image)

        # Padding removal is delegated to the project helper (behavior not visible here).
        removed_image = self.imgPrep.remove_pad(raw_array)

        tensor_label = torch.FloatTensor([self.labels[i]])

        # Normalize with mean 0 / std 1 leaves the values unchanged, but its
        # three-element statistics mean it expects a 3-channel image.
        content_transform = transforms.Compose([transforms.Resize(self.IMAGE_SIZE),
                                                transforms.ToTensor(),
                                                transforms.Normalize([0, 0, 0], [1, 1, 1])])

        # The torchvision transforms above operate on a PIL image, so convert back.
        prep_image = content_transform(Image.fromarray(removed_image))

        return prep_image, tensor_label

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.images)

## Balanced dataset loading
Reference: https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703

In [None]:
def make_weights_for_balanced_classes(images, nclasses):
    """Compute one sampling weight per sample so that classes are drawn evenly.

    Each sample receives ``N / count[class]``, where ``N`` is the total number
    of samples and ``count[class]`` is the number of samples in its class, so
    rarer classes get proportionally larger weights.

    Parameters
    ----------
    images : sequence
        Items whose element at index 1 is the integer class index (for example
        the ``(path, class_index)`` pairs passed in by the caller).
    nclasses : int
        Total number of classes.

    Returns
    -------
    list of float
        One weight per entry of ``images``, in the same order.
    """
    class_counts = [0] * nclasses
    for item in images:
        class_counts[item[1]] += 1

    total = float(sum(class_counts))
    # A class without samples gets weight 0.0 instead of raising
    # ZeroDivisionError; no sample refers to it, so the result is unaffected.
    weight_per_class = [total / n if n else 0.0 for n in class_counts]

    return [weight_per_class[item[1]] for item in images]

In [None]:
# Build the training set from a class-per-subfolder image directory.
# NOTE(review): no `transform` is passed here, so samples stay PIL images;
# confirm a tensor-producing transform is supplied before batching.
dataset_train = datasets.ImageFolder(traindir)

# For unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)

# Draw len(weights) samples per epoch, each picked according to its weight.
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

# `shuffle` is intentionally not set: DataLoader raises ValueError when both
# `sampler` and `shuffle=True` are given, and the weighted sampler already
# draws indices at random.
train_loader = torch.utils.data.DataLoader(dataset_train,
                                           batch_size=args.batch_size,
                                           sampler=sampler,
                                           num_workers=args.workers,
                                           pin_memory=True)