In [1]:
import os
from typing import Callable, Optional

import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import Dataset
from torchvision.utils import save_image

In [2]:
# Select the first CUDA GPU when one is available, otherwise the CPU.
# PyTorch's device type for NVIDIA GPUs is "cuda"; the string "gpu:0" is rejected
# by torch.device with a RuntimeError, so the old code crashed exactly on GPU machines.
# NOTE(review): `device` is not used by any visible cell — confirm it is still needed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Custom dataset: pairs image files listed in a CSV with integer class labels.
class CustomDataset(Dataset):
    """Image-classification dataset backed by an annotations CSV.

    The CSV is loaded with ``pandas.read_csv``. Columns are accessed by
    position, so header names do not matter:

    * column 0 — image file name, relative to ``root_dir``
    * column 1 — integer-convertible class label

    Args:
        csv_file: Path to the annotations CSV.
        root_dir: Directory that the file names in column 0 are relative to.
        transform: Optional callable applied to the image array returned by
            ``skimage.io.imread`` (e.g. a ``torchvision.transforms.Compose``).
            ``None`` returns the decoded array unchanged.
    """

    def __init__(self, csv_file: str, root_dir: str, transform: Optional[Callable] = None) -> None:
        super().__init__()
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transforms = transform

    def __len__(self) -> int:
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.annotations)

    def __getitem__(self, index: int):
        """Load sample ``index`` and return ``(image, label)``.

        ``label`` is a 0-d ``torch.Tensor`` built from column 1. An ``index``
        outside the annotations raises ``IndexError`` from pandas ``iloc``
        before any file is read.
        """
        img_name = self.annotations.iloc[index, 0]
        img_path = os.path.join(self.root_dir, img_name)
        image = io.imread(img_path)
        y_label = torch.tensor(int(self.annotations.iloc[index, 1]))

        # Explicit None check: a transform object that happens to define
        # __len__/__bool__ must never be skipped because it is "falsy".
        if self.transforms is not None:
            image = self.transforms(image)

        return image, y_label
    

In [4]:
# Random augmentation pipeline applied to each image as it is loaded.
# Each call re-samples the random parameters, so repeated passes over the
# dataset produce different augmented variants of the same source image.
RESIZE_HW = (256, 256)  # resize target (height, width) before cropping
CROP_HW = (224, 224)    # random-crop output size (height, width)

my_transforms = transforms.Compose([
    transforms.ToPILImage(),                 # array -> PIL image so the PIL-based ops below apply
    transforms.Resize(RESIZE_HW),            # resize to a fixed 256x256
    transforms.RandomCrop(CROP_HW),          # take a random 224x224 window
    transforms.RandomRotation(degrees=45),   # rotate by a random angle drawn from [-45, 45] degrees
    transforms.ColorJitter(brightness=0.5),  # scale brightness by a random factor drawn from [0.5, 1.5]
    transforms.RandomHorizontalFlip(p=0.5),  # mirror left-right with probability 0.5
    transforms.RandomVerticalFlip(p=0.05),   # flip upside-down with probability 0.05
    transforms.RandomGrayscale(p=0.2),       # convert to grayscale with probability 0.2
    transforms.ToTensor(),                   # PIL image -> float tensor (C, H, W) in [0, 1]
    # Normalize computes (value - mean) / std per channel. With mean=0 and std=1
    # it is an identity op, kept only as a placeholder: substitute the
    # per-channel statistics of your own images if normalization is wanted.
    # NOTE(review): the 3-element lists assume 3-channel (RGB) images — confirm
    # the dataset contains no RGBA or single-channel files.
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
])

In [5]:
# Cats-vs-dogs source data: annotations CSV plus the folder its file names are relative to.
CATS_DOGS_CSV = '../dataset/cats_dogs.csv'
CATS_DOGS_ROOT = '../dataset/cats_dogs/'

# Images are augmented on every access via my_transforms.
dataset = CustomDataset(CATS_DOGS_CSV, CATS_DOGS_ROOT, my_transforms)

In [6]:
# Destination folder for the augmented PNG files produced by the save loop.
augmented_data_dir: str = "../dataset/augmented_data"

In [8]:
# Write NUM_AUGMENTED_PASSES randomly augmented copies of every image into
# augmented_data_dir, plus an annotations CSV so the class labels are kept.
NUM_AUGMENTED_PASSES = 10
SEED = 0

# torchvision's random transforms sample from torch's RNG; seeding it here
# makes re-running this cell reproduce the same augmented images.
torch.manual_seed(SEED)

# save_image writes a single file and does not create missing folders.
os.makedirs(augmented_data_dir, exist_ok=True)

records = []
img_num = 0
for _ in range(NUM_AUGMENTED_PASSES):
    # Index explicitly: the Dataset defines no __iter__, so a bare `for ... in dataset`
    # only stops because pandas iloc happens to raise IndexError past the end.
    for idx in range(len(dataset)):
        img, label = dataset[idx]
        file_name = f"img{img_num}.png"
        save_image(img, os.path.join(augmented_data_dir, file_name))
        records.append({"file_name": file_name, "label": int(label)})
        img_num += 1

# Same positional layout CustomDataset reads (file name in column 0, label in
# column 1), so CustomDataset(<this csv>, augmented_data_dir) can load the output.
pd.DataFrame(records).to_csv(os.path.join(augmented_data_dir, "annotations.csv"), index=False)

print("done")

done
