In [9]:
from typing import List
import numpy as np
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
from PIL import Image

In [3]:
# Root folder of the raw dataset (one subfolder per class). It is the argument
# of the commented-out create_train_val_test_folders(path) call further down.
path = "./archive/"

In [1]:
# Helper function to check whether moving of images and data augmentations preserved 
# the number of images etc.

def walk_through_dir(path: str):
    """Print a summary of the folder tree rooted at ``path``.

    For the root folder and for every folder below it, prints the folder
    path together with the number of direct subfolders and files it
    contains (every file is counted as an image). For non-root folders the
    folder's base name is additionally printed as the class label.

    Args:
        path: Root directory to inspect.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        NotADirectoryError: If ``path`` exists but is not a directory.
    """
    # os.walk silently yields nothing for a missing or non-directory path,
    # which previously surfaced as a bare StopIteration from next().
    if not os.path.exists(path):
        raise FileNotFoundError("No such directory: {}".format(path))
    if not os.path.isdir(path):
        raise NotADirectoryError("Not a directory: {}".format(path))

    is_root = True
    for dirpath, dirnames, filenames in os.walk(path):
        # Sorting in place makes os.walk visit subfolders in a stable order,
        # so the printed report is reproducible between runs.
        dirnames.sort()

        if is_root:
            print("Root folder: ", dirpath)
            print("It has {} subfolders and {} images".format(len(dirnames), len(filenames)))
            print()
            is_root = False
            continue

        print("This is a folder: ", dirpath)
        print("It has {} subfolders and {} images".format(len(dirnames), len(filenames)))
        label = os.path.basename(dirpath)
        print("Class: ", label, "\n")

In [4]:
# To move images to train, val and test folders while also renaming them

def create_train_val_test_folders(root_path: str):
    """Split a class-per-subfolder image tree into train / val / test folders.

    Given a folder structure:
                root_folder:
                    - subfolder_1:
                        -img_1
                        -img_2
                        ...
                        -img_n
                    - subfolder_2:
                    ...
                    - subfolder_n
    this function creates the folders ``train``, ``val`` and ``test`` in the
    current working directory and moves every image into one of them at
    random (about 10% test, 20% val, 70% train). Each image is renamed to
    ``<original name>_<label>_00<ext>``, where the label is the name of the
    subfolder the image came from and ``_00`` marks it as an original
    (non-augmented) image in case data augmentation is used later.

    Files lying directly in ``root_path`` have no label and are left
    untouched. The split uses the global NumPy random state, so call
    ``np.random.seed`` beforehand for a reproducible split.

    Args:
        root_path: Folder whose subfolders each hold the images of one class.
            It should not contain the train/val/test folders themselves.
    """
    root_path = Path(root_path)
    for split in ("train", "val", "test"):
        os.makedirs(split, exist_ok=True)

    dir_tree_generator = os.walk(root_path)
    # Skip the root entry itself; the default avoids a StopIteration when
    # root_path does not exist and os.walk yields nothing.
    next(dir_tree_generator, None)

    for dirpath, dirnames, filenames in dir_tree_generator:
        label = os.path.basename(dirpath)
        # Use the folder os.walk is actually visiting. The previous code
        # built the source from the *global* `path` instead of `root_path`,
        # which failed for any other root and for nested subfolders.
        subfolder = Path(dirpath)

        for original_filename in filenames:
            filename, ext = os.path.splitext(original_filename)
            new_filename = filename + "_" + label + "_00" + ext  # 00 for original

            n = np.random.random()
            if n < 0.1:
                target = "test"
            elif n < 0.3:
                target = "val"
            else:
                target = "train"
            os.rename(subfolder / original_filename, Path(target) / new_filename)
            
            
            
#create_train_val_test_folders(path)

In [5]:
# Helper dataset opener and helper function

def _get_name(path: str):
    name, ext = os.path.splitext(path.split("_")[-2])
    
    return name

class SeaAnimalsDataset_Open:
    """Minimal dataset that opens the images of one folder.

    It only serves as the image source for the offline data augmentation
    done later in the notebook. Indexing returns a
    ``(transformed image, label, file name)`` triple, where the label is
    parsed from the file name by ``_get_name``.
    """

    def __init__(self, imgs_path, transform):
        self.imgs_path = imgs_path
        self.transform = transform
        # Every directory entry is treated as an image file.
        self.imgs = os.listdir(self.imgs_path)
        self.labels = [_get_name(file_name) for file_name in self.imgs]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        file_name = self.imgs[idx]
        full_path = os.path.join(self.imgs_path, file_name)
        # Force three channels so grayscale / palette / RGBA files behave alike.
        rgb_image = Image.open(full_path).convert("RGB")

        return self.transform(rgb_image), self.labels[idx], file_name

In [None]:
# Size every image is resized to before it is augmented / saved.
input_size = (128, 128)
# Resize only: the output stays a PIL image, which the augmentation loop
# relies on when it calls .save() on it.
transformation = transforms.Compose([transforms.Resize(input_size)])
# Opener over the "train" folder; this is the source for the augmentation loop.
sea_animals = SeaAnimalsDataset_Open("train", transform=transformation)

In [None]:
# Defining all transformations
# Defining file-name endings and the augmentation each one stands for

# _00: original data
# _01: horizontal flip
# _02: shape transformation (random resized crop)
# _03: brightness transformation
# _04: colour transformation (contrast, saturation and hue jitter)
# _05: gaussian noise
# _06: total: all transformations

# p = 1 means every image is flipped, so this transform is deterministic.
aug_horizontal = transforms.RandomHorizontalFlip(p = 1)

# Random crop, resized back to 128x128.
# NOTE(review): the size repeats input_size from the cell above - keep the two in sync.
shape_aug = transforms.RandomResizedCrop(
    (128, 128), scale=(0.1, 0.9), ratio=(0.5,2))

# Brightness jitter only; the other three ColorJitter components are disabled.
brightness_aug = transforms.ColorJitter(brightness=0.5, contrast=0,
                                       saturation=0, hue=0)

# Despite the name, this jitters contrast AND saturation AND hue (no brightness).
contrast_aug = transforms.ColorJitter(brightness=0, contrast=0.5, saturation=0.2,
                                       hue=0.1)

class AddGaussianNoise(object):
    """Transform that adds i.i.d. Gaussian noise to a tensor.

    Calling the instance returns ``tensor + N(0, 1) * std + mean`` with one
    independent noise sample per element. The result is NOT clamped, so a
    tensor that started in [0, 1] can leave that range.

    Args:
        mean: Constant offset added to every element.
        std: Standard deviation of the noise.
    """

    def __init__(self, mean=0., std=0.001):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Create the noise on the input's device: a CPU-only noise tensor
        # cannot be added to a tensor living on another device.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


# Gaussian-noise augmentation: PIL image -> tensor -> add noise -> PIL image.
transform_gaussian=transforms.Compose([
    transforms.ToTensor(),
    # mean 0 / std 1 leaves the values unchanged; kept as a placeholder.
    transforms.Normalize(0,1),
    # Increase 0.001 if more noise is needed.
    AddGaussianNoise(0., 0.001),
    # The noise can push values slightly outside [0, 1]. ToPILImage scales
    # floats by 255 and casts to uint8, so an unclamped value above 1.0
    # would overflow and show up as a stray dark pixel.
    transforms.Lambda(lambda t: t.clamp(0., 1.)),
    transforms.ToPILImage()
])

# _06: every augmentation applied in sequence.
total_aug_data = transforms.Compose([
    aug_horizontal, shape_aug, brightness_aug, contrast_aug, transform_gaussian])

# Geometric augmentations only (flip + crop), without the colour and noise steps.
total_aug_labels = transforms.Compose([
    aug_horizontal, shape_aug]) 

    Below we create the folder `train_augment`, which will contain the original images together with their augmented versions.

In [None]:
# Output folder holding the originals (_00) plus one file per augmentation (_01.._06).
path_augment = "./train_augment/"
os.makedirs(path_augment, exist_ok=True)

for i in range(len(sea_animals)):
    # Load the sample ONCE: every sea_animals[i] access re-opens and
    # re-resizes the image from disk.
    img, label, img_name = sea_animals[i]
    name, ext = os.path.splitext(img_name)  # ext = .jpg, .png etc.

    # Stem shared by all variants: the file name without its "_00" ending.
    # The guard keeps unexpected names intact instead of chopping 3 characters.
    base = name[:-3] if name.endswith("_00") else name

    # Save original img
    img.save(path_augment + name + ext)

    # Horizontal flip
    img_flip = aug_horizontal(img)
    img_flip.save(path_augment + base + "_01" + ext)

    # Crop and resize (torch is re-seeded so the crop can be reproduced from rand_1)
    rand_1 = np.random.randint(1000)
    torch.manual_seed(rand_1)
    img_crop = shape_aug(img)
    img_crop.save(path_augment + base + "_02" + ext)

    # Brightness
    img_bright = brightness_aug(img)
    img_bright.save(path_augment + base + "_03" + ext)

    # Contrast, saturation and hue
    img_contrast = contrast_aug(img)
    img_contrast.save(path_augment + base + "_04" + ext)

    # Gaussian noise
    img_gaussian = transform_gaussian(img)
    img_gaussian.save(path_augment + base + "_05" + ext)

    # Combined augmentation (re-seeded like the crop above)
    rand_2 = np.random.randint(1000)
    torch.manual_seed(rand_2)
    img_combined = total_aug_data(img)
    img_combined.save(path_augment + base + "_06" + ext)

    Finally, we define the `SeaAnimalsDataset` class that we will use during training.

In [7]:
class SeaAnimalsDataset:
    """Dataset over one image folder, yielding ``(image tensor, label)`` pairs.

    File names are expected to look like ``<name>_<label>_<code><ext>``,
    where ``<code>`` is the two-character augmentation code ("00" = original
    image). The label is parsed from the file name by ``_get_name``.

    Args:
        img_path: Folder containing the image files.
        transform: Callable applied to each PIL image after loading. It may
            return a PIL image (converted to a tensor afterwards) or a tensor.
        train: If True, keep only files whose augmentation code is listed in
            ``augmentations``; if False, use every file in the folder.
        augmentations: Augmentation codes to keep when ``train`` is True.
            The argument is only read, never modified.
    """

    def __init__(self, img_path, transform,
                 train: bool = True, augmentations: List[str] = ['00']):
        self.img_path = img_path
        self.transform = transform
        # Built once here instead of on every __getitem__ call.
        self._to_tensor = transforms.ToTensor()
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.total_imgs = sorted(os.listdir(self.img_path))

        if train:
            self.total_imgs = [
                file_name for file_name in self.total_imgs
                if self._augmentation_code(file_name) in augmentations]
        self.total_labels = [_get_name(file_name) for file_name in self.total_imgs]

    @staticmethod
    def _augmentation_code(file_name: str) -> str:
        """Return the two characters that precede the file extension.

        Works for extensions of any length; a fixed ``[-6:-4]`` slice only
        fits three-letter extensions and silently rejects e.g. ``.jpeg``.
        """
        stem, _ext = os.path.splitext(file_name)
        return stem[-2:]

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.img_path, self.total_imgs[idx])
        image = Image.open(img_loc).convert("RGB")
        label = self.total_labels[idx]
        out_image = self.transform(image)
        # Convert only when needed so a transform that already ends in
        # ToTensor() does not fail on the second conversion.
        if not isinstance(out_image, torch.Tensor):
            out_image = self._to_tensor(out_image)

        return out_image, label

    Create the training, validation and test datasets and their data loaders.

In [None]:
# Augmentation codes to train on; '00' alone selects only the original images.
# Add '01'..'06' to include the augmented copies stored in train_augment.
augmentations = ['00']
train_imgs_path = 'train_augment'
# Training set: read from the augmented folder, filtered by augmentation code.
sea_animals_train = SeaAnimalsDataset(
        img_path=train_imgs_path,
        transform=transformation,
        train=True,
        augmentations=augmentations)
# Validation / test sets: train=False disables the filter, so every file is used.
sea_animals_val = SeaAnimalsDataset(
    img_path="val",
    transform=transformation,
    train=False)
sea_animals_test = SeaAnimalsDataset(
    img_path="test",
    transform=transformation,
    train=False)

In [None]:
# One sample per batch, reshuffled every epoch; no batch is dropped.
# NOTE(review): shuffle=True is also set for the val and test loaders. Evaluation
# usually does not need shuffling - confirm this is intended.
train_loader = DataLoader(sea_animals_train, batch_size=1, shuffle=True, drop_last=False)
val_loader = DataLoader(sea_animals_val, batch_size=1, shuffle=True, drop_last=False)
test_loader = DataLoader(sea_animals_test, batch_size=1, shuffle=True, drop_last=False)