# PyTorch Custom Datasets

In [None]:
#!pip install torch
#!pip install Pillow

## 0. Importing PyTorch and setting up device-agnostic code

In [None]:
import torch
from torch import nn

# Note: this notebook requires torch >= 1.10.0.
# Display the installed version so it can be checked against that requirement.
torch.__version__

In [None]:
# Device-agnostic setup: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

## 1. Get data

In [None]:
# Import Operating System
import os
from pathlib import Path

# Setup path to data folder.
# `os.path.realpath("__file__")` resolves a dummy file name against the current
# working directory (in Jupyter this is normally the notebook's folder), so
# `ipynb_path` is that directory with symlinks resolved.
ipynb_path = os.path.dirname(os.path.realpath("__file__"))
# The datasets folder sits two levels above the notebook directory. Walk up
# with pathlib rather than splitting on "\\": that split only works on Windows,
# and on POSIX it collapsed to an empty path, silently pointing at ./datasets.
data_path = Path(ipynb_path).parents[1] / "datasets"
image_path = data_path / "custom"

# If the image folder doesn't exist, create it (including any missing parents).
if image_path.is_dir():
    print(f"{image_path} directory exists.")
else:
    print(f"Did not find {image_path} directory, creating one...")
    image_path.mkdir(parents=True, exist_ok=True)

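## 2. Become one with the data (data preparation)

The dataset is expected to follow the standard image-classification layout below: one folder per data split, with one sub-folder per class inside each split.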
```
custom/ <- overall dataset folder
    train/ <- training images
        class01/ <- class name as folder name
            image01.jpeg
            image02.jpeg
            ...
        class02/
            image24.jpeg
            image25.jpeg
            ...
        class03/
            image37.jpeg
            ...
    test/ <- testing images
        class01/
            image101.jpeg
            image102.jpeg
            ...
        class02/
            image154.jpeg
            image155.jpeg
            ...
        class03/
            image167.jpeg
            ...
```
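To check what the dataset folder actually contains, we can walk through each of its directories with Python's [`os.walk()`](https://docs.python.org/3/library/os.html#os.walk) and count what we find.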

In [None]:
import os
def walk_through_dir(dir_path):
    """
    Print a summary of ``dir_path`` and every directory beneath it.

    For ``dir_path`` itself and each of its subdirectories (recursively, via
    ``os.walk``), print one line giving the number of immediate
    subdirectories and the number of files it contains.

    Args:
        dir_path (str or pathlib.Path): target directory.

    Returns:
        None. The summary is printed, not returned. If ``dir_path`` does not
        exist or is not a directory, a single notice is printed instead.
    """
    # os.walk ignores scandir errors by default and yields nothing for a
    # missing path, which would make an absent dataset look like an empty one.
    if not os.path.isdir(dir_path):
        print(f"'{dir_path}' does not exist or is not a directory.")
        return

    for dirpath, dirnames, filenames in os.walk(dir_path):
        # Every file is reported as an "image"; no extension filtering is done.
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

In [None]:
# Summarize the dataset folder: subdirectory and file counts for every level.
walk_through_dir(image_path)

In [None]:
# Train and test splits each live in their own sub-folder of the dataset folder
train_dir, test_dir = (image_path / split for split in ("train", "test"))

train_dir, test_dir

### 2.1 Visualize an image

In [None]:
import random
from PIL import Image

# Set seed
random.seed(42) # <- try changing this and see what happens

# Image file types to look for. The dataset layout above uses .jpeg files, so
# matching only "*.png" would silently find nothing for such a dataset.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}

# 1. Get all image paths laid out as <image_path>/<split>/<class>/<file>.
#    Sorting makes the seeded random.choice below pick the same file on every
#    machine (raw glob order depends on the filesystem).
image_path_list = sorted(
    path for path in image_path.glob("*/*/*")
    if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
)

if image_path_list:
    # 2. Get random image path
    random_image_path = random.choice(image_path_list)

    # 3. Get image class from path name (the image class is the name of the
    #    directory where the image is stored). `.name` keeps the full folder
    #    name; `.stem` would drop anything after a dot ("class.v2" -> "class").
    image_class = random_image_path.parent.name

    # 4. Open image
    img = Image.open(random_image_path)

    # 5. Print metadata
    print(f"Random image path: {random_image_path}")
    print(f"Image class: {image_class}")
    print(f"Image height: {img.height}") 
    print(f"Image width: {img.width}")
    display(img)
else:
    # Say so explicitly instead of producing no output at all.
    print(f"No images found under {image_path}.")