In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets, models
from torchvision.transforms import v2
from torchvision.io import read_image

from sklearn.model_selection import train_test_split

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

from PIL import Image

# Run on the first CUDA GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seed PyTorch's default random number generator so random draws repeat between runs.
torch.manual_seed(42)

<torch._C.Generator at 0x16dff534b70>

In [15]:
class AkuDataset(Dataset):
    """
    Map-style dataset pairing each input in `X` with the label at the same position in `y`.

    Every item is returned as a dict with the keys "X" (the input, passed through
    `transform` when one was given) and "y" (the label, returned untouched).
    """

    def __init__(self, X, y, transform=None):
        """
        @param X: indexable collection of inputs
        @param y: indexable collection of labels, aligned with X by position
        @param transform: optional callable applied to an input each time it is fetched
        """
        self.X, self.y = X, y
        self.transform = transform

    def __len__(self):
        # the number of samples is driven by the inputs only
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]

        # the transform is applied lazily, on access, and only to the input
        if self.transform:
            features = self.transform(features)

        return {"X": features, "y": target}

def _load_images_with_labels(image_dir, labels_df):
    """Read every file in `image_dir` as an image and look up its label row in `labels_df`."""
    images, labels = [], []

    # os.listdir returns entries in arbitrary order; sorting keeps the seeded split
    # below identical from one run/machine to the next
    for filename in sorted(os.listdir(image_dir)):
        path = os.path.join(image_dir, filename)
        # Image.open is lazy and keeps the file open: copy() decodes the pixels into
        # memory and the context manager then releases the file handle
        with Image.open(path) as img:
            images.append(img.copy())
        labels.append(labels_df.loc[filename].values)

    return images, labels


def loadAkuDataset(labels_file, root_dir, data_split=(0.7, 0.2, 0.1), transform = None):
    """
    Load the original and augmented images with their labels and split them into
    training, validation and test datasets.

    @param labels_file: path to the csv file containing the labels for each image
        (rows are looked up by image filename in the "Unnamed: 0" column)
    @param root_dir: path to the dir containing the "original" and "augmented" image folders
    @param data_split: three elements tuple containing the ratio of images for train, validation and test sets;
        the ratios are normalised by their sum, so they do not have to add up to exactly 1
    @param transform: a concatenation of transformations to apply to the images
    @return a three element tuple containing the datasets for the training, validation and test sets
    @raise ValueError: if data_split does not hold exactly three positive ratios
    """
    ratios = tuple(data_split)
    if len(ratios) != 3 or min(ratios) <= 0:
        raise ValueError(f"data_split must contain three positive ratios, got {data_split!r}")
    train_ratio, val_ratio, test_ratio = ratios
    ratio_sum = train_ratio + val_ratio + test_ratio

    # read the labels from the csv file
    labels_df = pd.read_csv(labels_file).set_index("Unnamed: 0")

    # read the original images first, then the augmented ones
    X = []
    y = []
    for subdir in ("original", "augmented"):
        images, labels = _load_images_with_labels(os.path.join(root_dir, subdir), labels_df)
        X.extend(images)
        y.extend(labels)

    # Turn the ratios into absolute sample counts: passing integers to train_test_split
    # avoids float round-off (e.g. 0.2 + 0.1 != 0.3) shifting a sample between sets.
    n_samples = len(X)
    n_val = round(n_samples * val_ratio / ratio_sum)
    n_test = round(n_samples * test_ratio / ratio_sum)

    # split the data in train, validation and test sets:
    # first carve out validation + test together, then separate the two
    X_train, X_holdout, y_train, y_holdout = train_test_split(
        X, y, test_size=n_val + n_test, random_state=42
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_holdout, y_holdout, test_size=n_test, random_state=42
    )

    return AkuDataset(X_train, y_train, transform), AkuDataset(X_val, y_val, transform), AkuDataset(X_test, y_test, transform)

In [16]:
labels_file = "./data/overall_labels.csv"
root_dir = "./data"
BATCH_SIZE = 4

# creating the datasets
train_dataset, val_dataset, test_dataset = loadAkuDataset(
    labels_file,
    root_dir,
    transform = v2.Compose([
        # geometric steps run on the PIL image as loaded from disk
        v2.Resize(400),
        v2.CenterCrop(256),
        # the images are loaded as PIL images: convert them to tensor images here,
        # because ToDtype only acts on tensors and Normalize does not accept PIL input
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        # per-channel mean / std used for normalisation (three channels expected)
        v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

# creating the data loaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# validation only measures performance, so a fixed batch order keeps runs comparable
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)