In [1]:
import cv2
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from utils import plot_examples
from PIL import Image
import torch
import torch.nn as nn
import os
import seaborn as sns
# Apply seaborn's default theme globally to all subsequent matplotlib figures.
# set_theme() is the preferred interface; sns.set() is only a legacy alias for it.
sns.set_theme()

# Build a class-per-folder image dataset and its augmentation pipeline

In [2]:
class ImageFolder(torch.utils.data.Dataset):
    """Map-style dataset over a directory laid out as one sub-folder per class.

    Expected layout::

        root_dir/cat/cat1.jpg
        root_dir/dog/dog1.jpg

    ``self.class_names`` holds the sorted sub-folder names; a class's label is
    its index in that list. ``self.data`` is a flat list of
    ``(file_name, label)`` pairs, e.g. ``[("cat1.jpg", 0), ..., ("dog1.jpg", 1)]``.

    Args:
        root_dir: Directory whose sub-directories are the classes.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=array)`` and expected to return a dict with an
            ``"image"`` entry.
    """

    def __init__(self, root_dir, transform=None) -> None:
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is filesystem-dependent: sort so labels are stable
        # across machines, and skip stray files such as .DS_Store.
        self.class_names = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )

        self.data = []
        for index, name in enumerate(self.class_names):
            class_dir = os.path.join(root_dir, name)
            # Only regular files are samples; nested directories are ignored.
            files = sorted(
                file_name
                for file_name in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, file_name))
            )
            self.data.extend((file_name, index) for file_name in files)

    def __len__(self):
        """Return the number of (file, label) samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is an HxWx3 uint8 numpy array, or whatever ``transform``
        produces under its ``"image"`` key when a transform is set. An
        out-of-range index raises IndexError, which also ends plain
        ``for x, y in dataset`` iteration.
        """
        img_file, label = self.data[index]
        img_path = os.path.join(self.root_dir, self.class_names[label], img_file)
        # The context manager closes the file handle. convert("RGB") guarantees
        # 3 channels even for grayscale / RGBA / palette images, which the
        # 3-channel augmentation pipeline requires.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label
            

In [3]:
# Augmentation pipeline applied to each image returned by ImageFolder.
transform = A.Compose(
    [
        A.Resize(width=1920, height=1080),
        A.RandomCrop(width=1280, height=720),
        # BORDER_CONSTANT fills the corners exposed by rotation with a constant
        # value instead of the default reflection of the image content.
        A.Rotate(limit=40, p=0.9, border_mode=cv2.BORDER_CONSTANT),
        A.HorizontalFlip(p=0.1),
        # Pass p by keyword: in older albumentations the first positional
        # parameter is `always_apply`, so VerticalFlip(0.5) flipped every image.
        A.VerticalFlip(p=0.5),
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25),
        # With p=1.0 the OneOf block always runs and applies one of its
        # children, chosen using the children's own p values as weights.
        A.OneOf(
            [
                A.Blur(blur_limit=3, p=0.5),
                A.ColorJitter(p=0.5),
            ],
            p=1.0,
        ),
        # mean=0, std=1, max_pixel_value=255 just rescales uint8 pixels to [0, 1].
        A.Normalize(
            mean=(0, 0, 0),
            std=(1, 1, 1),
            max_pixel_value=255,
        ),
        # HWC numpy array -> CHW torch tensor.
        ToTensorV2(),
    ]
)

In [4]:
# Augmented dataset over the class-per-folder directory "cat_dogs" (path relative to the CWD).
dataset = ImageFolder(transform=transform, root_dir="cat_dogs")

In [5]:
# Sanity check: print each augmented sample's tensor shape next to its class label.
for x, y in dataset:
    print(f"{x.shape} {y}")

torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 0
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
torch.Size([3, 720, 1280]) 1
