## How to load images from a class-structured directory in PyTorch (e.g. for training a VAE autoencoder)
### Expected directory structure (the code below uses `cats` and `dogs` as the class folders):

main_directory/ <br>
------------train/<br>
----------------class_a/ <br>
---------------------a_image_1.jpg <br>
---------------------a_image_2.jpg <br>
----------------class_b/ <br>
---------------------b_image_1.jpg <br>
---------------------b_image_2.jpg <br>
------------val/<br>
----------------class_a/ <br>
---------------------a_image_1.jpg <br>
---------------------a_image_2.jpg <br>
----------------class_b/ <br>
---------------------b_image_1.jpg <br>
---------------------b_image_2.jpg <br>
------------test/<br>
----------------class_a/ <br>
---------------------a_image_1.jpg <br>
---------------------a_image_2.jpg <br>
----------------class_b/ <br>
---------------------b_image_1.jpg <br>
---------------------b_image_2.jpg <br>


### Imports

In [1]:
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
from torchvision import transforms
import cv2 as cv
import torch
import glob
import os 

### Prepare the paths for loading the images

In [2]:
# Dataset root; every split lives in its own sub-directory below it.
# NOTE: the validation split is called 'valid' here, while the tree above
# shows 'val/' — adjust the name to match your layout on disk.
data_dir = r'./CatsDogs/'
train_data_dir, validation_data_dir, test_data_dir = (
    os.path.join(data_dir, split_name) for split_name in ('train', 'valid', 'test')
)

### Building a custom dataset that reads images

In [3]:
class yourCustomDatasets(Dataset):
    """Image dataset laid out as ``root_dir/<class_name>/<image file>``.

    Each sample is ``(image, label)``. ``label`` is the index of the image's
    class in ``self.classes``.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to the image. It receives a
            tensor of shape ``(C, H, W)`` with channels in RGB order (the layout
            ``transforms.ToPILImage`` accepts). If ``None``, that tensor is
            returned as-is.
        classes: Class sub-directory names, in label order.
            Defaults to ``['cats', 'dogs']``.
        extensions: File extension (or iterable of extensions) to collect from
            each class directory. Defaults to ``('.png',)``.
    """

    def __init__(self, root_dir, transform=None, classes=None, extensions=('.png',)):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = list(classes) if classes is not None else ['cats', 'dogs']
        if isinstance(extensions, str):
            # Accept a single extension without iterating over its characters.
            extensions = (extensions,)
        # Entries are [path, label] lists, grouped by class in self.classes order.
        self.file_list = []
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(self.root_dir, class_name)
            # glob returns files in arbitrary filesystem order; sort so sample
            # indices are reproducible. The set drops duplicates that can
            # appear when patterns overlap on case-insensitive filesystems.
            paths = sorted({
                path
                for ext in extensions
                for path in glob.glob(os.path.join(class_dir, '*' + ext))
            })
            self.file_list.extend([path, label] for path in paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        Raises:
            OSError: if OpenCV cannot read/decode the image file.
        """
        image_path, label = self.file_list[idx]
        # imread's second argument is an IMREAD_* flag, not a colour-conversion
        # code, so passing cv.COLOR_BGR2RGB there would not convert anything.
        # Force a 3-channel BGR read, then convert to RGB explicitly.
        image = cv.imread(image_path, cv.IMREAD_COLOR)
        if image is None:
            # imread reports failure by returning None instead of raising.
            raise OSError(f'Could not read image: {image_path}')
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        # (H, W, C) ndarray -> (C, H, W) tensor.
        image = torch.from_numpy(image.transpose(2, 0, 1))
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Number of image files found under ``root_dir``."""
        return len(self.file_list)

### Transformation

In [4]:
# Every image is resized to a fixed size so samples can be stacked into batches.
# Mean/std normalization is intentionally omitted: ToTensor already scales
# pixel values to [0, 1].
IMAGE_SIZE = (128, 128)

# Training: a random horizontal flip adds light augmentation.
train_transforms = transforms.Compose([
    transforms.ToPILImage(),  # takes the (C, H, W) tensors produced by the dataset
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation/test: deterministic preprocessing only (no random augmentation).
test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

### Load the images through the custom dataset with DataLoaders

In [5]:
BATCH_SIZE = 64
SEED = 0

# Shuffle the training set: the dataset's file_list is grouped by class, so an
# unshuffled loader would yield batches that contain only one class. A seeded
# generator keeps the shuffle order reproducible across runs.
train_dataLoader = DataLoader(
    dataset=yourCustomDatasets(train_data_dir, transform=train_transforms),
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Validation and test use the deterministic test transforms; random flips here
# would make evaluation results differ from run to run.
valid_dataLoader = DataLoader(
    dataset=yourCustomDatasets(validation_data_dir, transform=test_transforms),
    batch_size=BATCH_SIZE,
)
test_dataLoader = DataLoader(
    dataset=yourCustomDatasets(test_data_dir, transform=test_transforms),
    batch_size=BATCH_SIZE,
)

### Display sample images

In [None]:
# Show the first few images of one training batch to sanity-check loading,
# colour order and orientation.
images, labels = next(iter(train_dataLoader))
class_names = train_dataLoader.dataset.classes

rows, cols = 1, 5
fig, axes = plt.subplots(rows, cols, figsize=(15, 3), dpi=100, squeeze=False)
for j, ax in enumerate(axes.flat):
    ax.axis('off')
    if j >= len(images):
        # The batch can hold fewer images than grid cells; leave the rest blank.
        continue
    # Batch tensors are (C, H, W) but imshow expects (H, W, C).
    # permute(2, 1, 0) would give (W, H, C) and silently transpose square images.
    ax.imshow(images[j].permute(1, 2, 0))
    ax.set_title(f'{j + 1}: {class_names[int(labels[j])]}')
plt.show()