In [1]:
#import all the necessary modules
import os

from fastai.vision.all import *
from fastai.vision import *

import pandas as pd

In [2]:
def create_labels_csv_file(dataset_path: str) -> None:
    """Write ``labels.csv`` into ``dataset_path``, listing each file with an integer class label.

    Every immediate sub-directory of ``dataset_path`` is treated as one class.
    Classes are numbered from 0 in sorted order of their folder names, so the
    numbering is the same on every run and for every dataset that has the same
    class folders.

    Args:
        dataset_path: Directory whose sub-directories each hold the files of
            one class.

    Returns:
        None. The result is written to ``<dataset_path>/labels.csv`` with two
        columns: ``name`` (``dataset_path``, class folder and file name joined
        into a path) and ``label`` (the class index).
    """
    # Only sub-directories count as classes. This skips a labels.csv left by a
    # previous run (which would otherwise take up an index) and any other stray
    # file that os.listdir could not descend into. Sorting pins the numbering,
    # because os.listdir returns entries in arbitrary order.
    class_names = sorted(
        entry
        for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )

    image_paths = []
    labels = []

    # walk the class folders and gather file paths with their class index
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(dataset_path, class_name)

        # sorted so the row order of the CSV is reproducible as well
        for image_name in sorted(os.listdir(class_dir)):
            image_paths.append(os.path.join(class_dir, image_name))
            labels.append(label)

    # create a DataFrame from the collected data
    df = pd.DataFrame(
        {
            'name': image_paths,
            'label': labels
        }
    )

    # save the DataFrame in a CSV file
    df.to_csv(
        path_or_buf=os.path.join(
            dataset_path,
            'labels.csv'
        ),
        index=False
    )

In [3]:
# Build labels.csv for each dataset split in turn: train first, then valid
for split_name in ('train', 'valid'):
    split_dir = os.path.join('data', split_name)
    create_labels_csv_file(dataset_path=split_dir)

In [4]:
# Define data loaders for the train and valid splits under data/
dls = ImageDataLoaders.from_folder(
    path='data',         # Path to the main dataset folder
    train='train',       # Sub-folder holding the training set
    valid='valid',       # Sub-folder holding the validation set
    item_tfms=Resize(224),  # Resize each image to 224x224 pixels
    bs=16,                  # Batch size
    # Batch transforms: fastai's standard augmentation set followed by
    # ImageNet normalization. A bare RandTransform() defines no encodes, so it
    # passed batches through unchanged and no augmentation was applied.
    batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)]
)

  return getattr(torch, 'has_mps', False)
