In [14]:
import os
import random
import urllib

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
from torchvision import transforms

import torch
import torchvision
from torchvision import datasets, transforms

from customDataset import ISICDataset

from data_exploration_helper import dataset_overview

In [15]:
#########################################################
# ==================== CONSTANTS ========================
#########################################################
# Training set 2018
TRAIN_2018_LABELS: str = "./data/ISIC2018_Training_GroundTruth.csv"
TRAIN_2018_ROOT_DIR: str = "./data/ISIC2018_Training_Input"

TEST_2018_LABELS: str = "./data/ISIC2018_Validation_GroundTruth.csv"
TEST_2018_ROOT_DIR: str = "./data/ISIC2018_Validation_Input"

IMAGE_FILE_TYPE: str = "jpg"
TRAIN_NROWS: int = None # SET TO None if you want all samples
TEST_NROWS: int = None # SET TO None if you want all samples

In [16]:
# Define image pre-processing steps
# Define the transforms to apply to the training data
augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=45),
])

pil_image_transform = transforms.Compose([
    transforms.ToPILImage(), # Removes error
])

In [17]:
# ================= DATASETS ================= #
print("Loading in datasets...")
# Training set 2018 - custom class
train_dataset_2018 = ISICDataset(
    csv_file=TRAIN_2018_LABELS, 
    root_dir=TRAIN_2018_ROOT_DIR, 
    transform=pil_image_transform,
    image_file_type=IMAGE_FILE_TYPE,
    nrows=TRAIN_NROWS # defines the number of rows used, utilized this for testing purposes
    )

# Test set 2018 - custom class
test_dataset_2018 = ISICDataset(
    csv_file=TEST_2018_LABELS, 
    root_dir=TEST_2018_ROOT_DIR, 
    transform=pil_image_transform,
    image_file_type=IMAGE_FILE_TYPE,
    nrows=TEST_NROWS
    )


Loading in datasets...


In [28]:

def augment_images(transform, csv_file:str, num_augmentations_per_class:int, root_dir:str, image_file_type:str="jpg"):
    # Load the annotations from the CSV file
    annotations = pd.read_csv(csv_file)
    
    # Get the names of all the different classes
    class_names = annotations.columns[1:]
    images = annotations.iloc[:, 0]
    
    # Create a dictionary to store the images for each class
    class_images = {}
    for class_name in class_names:
        class_images[class_name] = []
    
    # Iterate over the images and assign them to their corresponding class in the dictionary
    for image_path in images:
        img_path = f"{img_path}.{image_path}"
        image_name = os.path.basename(image_path)
        image_annotation = annotations[annotations['image'] == image_name]
        if len(image_annotation) > 0:
            image_class = image_annotation.iloc[0, 1:].idxmax()
            class_images[image_class].append(image_path)
    
    # Create a dictionary to store the augmented images for each class
    augmented_images = {}
    
    # Iterate over the classes
    for class_name in class_names:
        # Get the images for the class
        class_image_paths = class_images[class_name]
        
        # Create a list to store the augmented images for the class
        class_augmented_images = []
        
        # Iterate over the images for the class and augment them
        for image_path in class_image_paths:
            image = Image.open(image_path)
            for i in range(num_augmentations_per_class):
                augmented_image = transform(image)
                class_augmented_images.append(augmented_image)
        
        # Add the augmented images for the class to the dictionary
        augmented_images[class_name] = class_augmented_images
    
    return augmented_images


In [29]:

augment_images(augmentation_transform, TRAIN_2018_LABELS, 10)

0        ISIC_0024306
1        ISIC_0024307
2        ISIC_0024308
3        ISIC_0024309
4        ISIC_0024310
             ...     
10010    ISIC_0034316
10011    ISIC_0034317
10012    ISIC_0034318
10013    ISIC_0034319
10014    ISIC_0034320
Name: image, Length: 10015, dtype: object


TypeError: reduction operation 'argmax' not allowed for this dtype