<a href="https://colab.research.google.com/github/KonradGonrad/PyTorch-deep-learning/blob/main/04_pytorch_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 04. PyTorch Custom Datasets

## 0. Importing PyTorch and setting up device-agnostic code

In [None]:
import torch
from torch import nn

# This notebook is written for PyTorch 1.10+; show the installed version to compare
torch.__version__

In [None]:
# Device-agnostic setup: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
  DEVICE_DESTINATION = 'cuda'
else:
  DEVICE_DESTINATION = 'cpu'
DEVICE_DESTINATION

## 1. Get data

In [None]:
import requests
import zipfile
from pathlib import Path

# Setup path to data folder
data_path = Path("data/")
images_path = data_path / "pizza_steak_sushi"
zip_path = data_path / "pizza_steak_sushi.zip"

# If the image folder doesn't exist, download it and prepare it.
# The download and extraction live inside the `else` branch so that re-running
# this cell does not fetch and unpack the archive again.
if images_path.is_dir():
  print(f"{images_path} already exist. Skipping download...")
else:
  print(f"Creating {images_path} path")
  # Only the parent folder is needed to store the archive at this point
  data_path.mkdir(parents=True, exist_ok=True)

  # Download pizza, steak and sushi data
  print("Downloading pizza, steak, sushi data...")
  request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
                         timeout=60)
  # Fail loudly on an HTTP error instead of saving an error page as the zip file
  request.raise_for_status()
  with open(zip_path, 'wb') as f:
    f.write(request.content)

  # Create the image folder only after the download succeeded, so a failed
  # attempt is retried on the next run instead of being skipped
  images_path.mkdir(parents=True, exist_ok=True)

  # Unzip pizza, steak, sushi data
  with zipfile.ZipFile(zip_path, 'r') as ziprep:
    print("Extracking pizza_steak_sushi data...")
    ziprep.extractall(images_path)


In [None]:
## 2. Becoming one with the data (Data preparation and data exploration)
import os

def walk_through_dir(dir_path):
  """
  Walks through dir_path and prints, for every directory visited,
  how many subdirectories and files it directly contains.
  """
  for current_dir, subdirs, files in os.walk(dir_path):
    n_dirs, n_files = len(subdirs), len(files)
    print(f"there are {n_dirs} directories and {n_files} images in '{current_dir}'.")

In [None]:
# Print the directory and file counts for everything under images_path
walk_through_dir(images_path)

In [None]:
# Build the training and testing folder paths from the dataset root
train_dir, test_dir = (images_path / split for split in ("train", "test"))

train_dir, test_dir

## 2.1 Visualizing an image

1. Get all of the image paths
2. Pick a random image path using python's random.choice()
3. Get the image class name using `pathlib.Path.parent.stem` (the name of the folder that contains the image)
4. Since we're working with images, let's open the image with Python's PIL
5. We'll then show the image and print metadata

In [None]:
import random
from PIL import Image

# Optional: uncomment to make the random pick below reproducible
#random.seed(42)

# 1. Get all image paths (every .jpg file two folder levels below images_path)
image_paths = list(images_path.glob('*/*/*.jpg'))

# 2. Pick a random image path
random_image = random.choice(image_paths)

# 3. Get the image class name (the stem of the folder that contains the image)
image_class = random_image.parent.stem

# 4. Open image with Python PIL
img = Image.open(random_image)

# 5. Print metadata
print(f"Random image path: {random_image}")
print(f"Random image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
# Leave the image as the cell's last expression so the notebook can render it
img

In [None]:
# Visualize image with matplotlib - my approach
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
import torchvision
# NOTE(review): torch and torchvision are not used anywhere in this cell

# Optional: uncomment to make the random pick below reproducible
#random.seed(42)

# 1. Get all image paths (every .jpg file two folder levels below images_path)
image_paths = list(images_path.glob('*/*/*.jpg'))

# 2. Pick a random image path
random_image = random.choice(image_paths)

# 3. Get the image class name (the stem of the folder that contains the image)
image_class = random_image.parent.stem

# 4. Read the image file with matplotlib so it can be passed to plt.imshow
image = mpimg.imread(random_image)

# Visualize image with matplotlib

plt.imshow(image)
# Hide the axes; the class name is shown as the title instead
plt.axis(False)
plt.title(image_class)
plt.show()

In [None]:
# Visualize image with matplotlib - approach shown in the video
import numpy as np
import matplotlib.pyplot as plt

# 1. Get all image paths (every .jpg file two folder levels below images_path)
image_paths = list(images_path.glob('*/*/*.jpg'))

# 2. Pick a random image path
random_image = random.choice(image_paths)

# 3. Get the image class name (the stem of the folder that contains the image)
image_class = random_image.parent.stem

# 4. Open image with Python PIL
img = Image.open(random_image)

# 5. Turn the image into a NumPy array
img_as_array = np.asarray(img)

# 6. Plot the array, with the class name and array shape in the title
plt.figure(figsize=(10,7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color channels] (HWC)")
plt.axis(False)
plt.show()

## 3. Transforming data

Before we can use our image data with PyTorch:
1. Turn your target data into tensors (in our case, numerical representation of our images)
2. Turn it into a `torch.utils.data.Dataset` and subsequently a `torch.utils.data.DataLoader`; we'll call these `Dataset` and `DataLoader`

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

### 3.1 Transforming data with `torchvision.transforms`

In [None]:
# Image preprocessing pipeline; the steps run in the order they are listed
data_transform = transforms.Compose(transforms=[
    transforms.Resize(size=(64, 64)),        # resize our images to 64x64
    transforms.RandomHorizontalFlip(p=0.5),  # flip the images randomly on the horizontal
    transforms.ToTensor(),                   # turn the image into a torch tensor
])

In [None]:
# Run the transform on the image opened earlier and inspect the shape of the result
data_transform(img).shape

In [None]:
def plot_transformed_images(image_paths: list, transform, n=3, seed=None):
  """
  Selects random images from a path of images and loads/transforms them
  then plots the original vs transformed version.

  Args:
    image_paths: Image file paths to sample from.
    transform: Callable applied to each opened PIL image. Its result must
      support `.permute(1, 2, 0)` because it is rearranged before plotting.
    n: Number of images to sample and plot.
    seed: Optional seed for Python's `random` module. `None` leaves the
      random state untouched.
  """
  # Compare against None so that seed=0 is honoured, and seed with the value
  # that was actually passed in rather than a hard-coded 42
  if seed is not None:
    random.seed(seed)
  random_image_paths = random.sample(image_paths, k=n)
  for random_image in random_image_paths:
    with Image.open(random_image) as f:
      fig, ax = plt.subplots(nrows=1, ncols=2)
      # Left: the original image as loaded from disk
      ax[0].imshow(f)
      ax[0].axis(False)
      ax[0].set_title(f"Original size: {f.size}")

      # Right: the transformed image, with its first dimension moved to the
      # end before it is handed to imshow
      transformed_f = transform(f).permute(1, 2, 0)
      ax[1].imshow(transformed_f)
      ax[1].axis("off")
      ax[1].set_title(f"Shape: {transformed_f.shape}")

      # The parent folder name is used as the class label
      fig.suptitle(f"Class: {random_image.parent.stem}", fontsize=16)
# Plot three randomly sampled images next to their transformed versions (seeded call)
plot_transformed_images(image_paths,
                        data_transform,
                        n=3,
                        seed=42)



## 4. Option 1: Loading image data using `ImageFolder`
We can load image classification data using `torchvision.datasets.ImageFolder`.

In [None]:
# Build the train and test datasets from their folders with ImageFolder
from torchvision import datasets

# Both datasets share the same image transform and leave the labels untransformed
train_data, test_data = (
    datasets.ImageFolder(root=split_dir,
                         transform=data_transform,
                         target_transform=None)
    for split_dir in (train_dir, test_dir)
)

print(train_data, test_data)

In [None]:
# Get the class names from the training dataset (reused later to label plots)
class_names = train_data.classes
class_names

In [None]:
# Get the training dataset's class-name-to-index mapping
class_dict = train_data.class_to_idx
class_dict

In [None]:
# Check how many samples the train and test datasets contain
len(train_data), len(test_data)

In [None]:
# Inspect the first entry of the training dataset's `samples` list
train_data.samples[0]

In [None]:
# Index on the train_data Dataset to get a single image and label
import random

# randrange excludes the upper bound, so the index is always valid;
# randint(0, len(train_data)) could return len(train_data) and raise IndexError
random_idx = random.randrange(len(train_data))
# Index the dataset once and unpack the (image, label) pair, instead of
# fetching the same item twice
img, label = train_data[random_idx]
print(f"Image tensor:\n {img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Label datatype: {type(label)}")

In [None]:
# Display the image tensor picked in the previous cell
img

In [None]:
# Translate the numeric label into its class name via class_names
print(f"Label: {label} which one is {class_names[label]}")

In [None]:
# Rearrange the order of dimensions: permute(1, 2, 0) moves the first dimension to the end
img_permute = img.permute(1, 2, 0)
print(f"old shape: {img.shape} -> [color_channels, height, width]")
print(f"new shape: {img_permute.shape} -> [height, width, color_channels]")

# Plot the permuted image with its class name as the title
plt.figure(figsize=(10, 7))
plt.imshow(img_permute)
plt.axis("off")
plt.title(class_names[label], fontsize=14)
plt.show()

## 4.1 Turn loaded images into `DataLoader`s

A `DataLoader` turns our `Dataset`s into iterables, so that we can look at `batch_size` images at a time.

In [None]:
from torch.utils.data import DataLoader
import os
BATCH_SIZE = 1
# os.cpu_count() returns None when the CPU count cannot be determined;
# fall back to a single worker so DataLoader always receives an integer
NUM_WORKERS = os.cpu_count() or 1

# Training batches are reshuffled on every pass over the loader
train_dataloader = DataLoader(dataset=train_data,
                              batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              shuffle=True)

# Test batches keep the dataset order
test_dataloader = DataLoader(dataset=test_data,
                             batch_size=BATCH_SIZE,
                             num_workers=1,
                             shuffle=False)

train_dataloader, test_dataloader

In [None]:
# Number of batches in each DataLoader
len(train_dataloader), len(test_dataloader)

In [None]:
# Take one batch from a fresh iterator over the training DataLoader
img, label = next(iter(train_dataloader))

# The leading dimension of the image tensor is the batch dimension
print(f"image shape: {img.shape} -> [batch_size, color_channles, height, width]")
print(f"label shape: {label.shape}")

## 5. Option 2: Loading Image data with a custom `dataset`

1. Want to be able to load images from file
2. Want to be able to get class names from the dataset
3. Want to be able to get classes as dictionary from the dataset

Pros:
* Can create a `Dataset` out of almost anything
* Not limited to PyTorch pre-built `Dataset` functions

Cons:
* Even though you could create `Dataset` out of almost anything, it doesn't mean it will work
* Using a custom `Dataset` often results in us writing more code, which could be prone to errors or performance issues

Custom datasets in PyTorch usually subclass `torch.utils.data.Dataset`.

In [None]:
import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List

In [None]:
# train_data is an instance of torchvision.datasets.ImageFolder();
# these are the two attributes the custom dataset should reproduce
train_data.classes, train_data.class_to_idx

## 5.1 Creating a helper function to get class names

We want a function to:
1. Get the class names using `os.scandir()` to traverse a target directory (ideally the directory is in standard image classification format).
2. Raise an error if the class names aren't found (if this happens, there might be something wrong with the directory structure).
3. Turn the class names into a dict and a list and return them.

In [None]:
# Use the training folder as the directory to scan
target_directory = train_dir
print(f"Target dir: {target_directory}")

# Collect the name of every entry in the target directory, in sorted order
class_names_found = sorted(entry.name for entry in os.scandir(target_directory))
class_names_found

In [None]:
# Show the raw entries that os.scandir yields for the target directory
list(os.scandir(target_directory))

In [None]:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
  """
  Looks up the class folders inside `directory`.

  Returns a tuple of (sorted class names, mapping of class name -> index).
  Raises FileNotFoundError when `directory` contains no subdirectories.
  """
  # Collect the names of the subdirectories; plain files are ignored
  classes = []
  with os.scandir(directory) as entries:
    for entry in entries:
      if entry.is_dir():
        classes.append(entry.name)
  classes.sort()

  # No class folders means the directory layout is not usable
  if len(classes) == 0:
    raise FileNotFoundError(f"Couldn't find any classes in {directory}...")

  # Number the classes in sorted order so each name gets an integer label
  class_to_idx = dict(zip(classes, range(len(classes))))

  return classes, class_to_idx

In [None]:
# Try the helper on the training directory
find_classes(target_directory)

In [None]:
# Scratch example: pair each name with its position as (name, index) tuples
x = ['Konrad', 'Kamil', 'Wojtek']
[(name, i) for i, name in enumerate(x)]