## PyTorch Custom Datasets

In [3]:
# Main import statements
import torch
from torch import nn

torch.__version__

'2.3.0+cu118'

In [2]:
# Device-agnostic code: use the GPU if one is available, otherwise the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

## 1. Get Data

We will work with a subset of the Food101 dataset.

Food101 is a large dataset with 101 classes of food. We will use a subset with 3 classes (pizza, steak and sushi) and 10% of the images from each of those classes.

That gives about 75 images per class for training and about 25 images per class for testing (the exact per-class counts vary slightly).

It's important to note that the dataset is already split into training and test sets.

Also, when working on ML projects, it is important to start with a small dataset to test the code and the model. After that, we can increase the dataset size.

The point is to make sure that the code works properly before running it on a large dataset.
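
The subset used in this notebook is downloaded ready-made, so the following is only a minimal sketch of how a comparable 10% sample could be drawn from one class folder of a larger dataset. The function name `sample_subset` and its `fraction` and `seed` parameters are hypothetical and not part of the original notebook.

```python
import random
from pathlib import Path


def sample_subset(class_dir, fraction=0.1, seed=42):
    """Return a reproducible random sample of the files in class_dir.

    Hypothetical helper: fraction is the share of files to keep,
    seed makes the selection repeatable between runs.
    """
    # Sort first so the result does not depend on filesystem ordering
    files = sorted(p for p in Path(class_dir).iterdir() if p.is_file())
    if not files:
        return []
    # Keep at least one file, even for very small folders
    k = max(1, round(len(files) * fraction))
    rng = random.Random(seed)
    return sorted(rng.sample(files, k))
```

Fixing the seed keeps the small dataset identical between runs, which makes it easier to tell whether a change in results comes from the code or from the data.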

In [7]:
import requests
import zipfile
from pathlib import Path

# Set up the path to the data folder
DATA_PATH = Path("data")
image_path = DATA_PATH / "pizza_steak_sushi"

# If the image folder does not exist, download the data
if image_path.is_dir():
    print("Data already downloaded.")
else:
    print("Downloading data...")
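    # Download the data first, so a failed download does not leave behind
    # an empty image folder that looks like existing data on the next run
    url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"
    response = requests.get(url)
    response.raise_for_status()  # Stop here if the request failed

    # Save the zip file
    DATA_PATH.mkdir(parents=True, exist_ok=True)
    zip_path = DATA_PATH / "pizza_steak_sushi.zip"
    with open(zip_path, "wb") as f:
        f.write(response.content)

    # Make the image directory and unzip the data into it
    image_path.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        print("Unzipping...")
        zip_ref.extractall(image_path)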

Downloading data...
Unzipping...


## 2. Data preparation and exploration

In [8]:
import os


def walk_through_dir(dir_path):
    """
    Walks through dir_path, printing the number of directories and files in each directory.

    dir_path: (str or pathlib.Path) valid path to target directory
    """
    # os.walk yields, for every directory under dir_path:
    #   dirpath: (str) path to current directory
    #   dirnames: (list) names of subdirectories in current directory
    #   filenames: (list) names of files in current directory
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

In [9]:
walk_through_dir(image_path)

There are 2 directories and 0 images in 'data/pizza_steak_sushi'.
There are 3 directories and 0 images in 'data/pizza_steak_sushi/train'.
There are 0 directories and 72 images in 'data/pizza_steak_sushi/train/sushi'.
There are 0 directories and 78 images in 'data/pizza_steak_sushi/train/pizza'.
There are 0 directories and 75 images in 'data/pizza_steak_sushi/train/steak'.
There are 3 directories and 0 images in 'data/pizza_steak_sushi/test'.
There are 0 directories and 31 images in 'data/pizza_steak_sushi/test/sushi'.
There are 0 directories and 25 images in 'data/pizza_steak_sushi/test/pizza'.
There are 0 directories and 19 images in 'data/pizza_steak_sushi/test/steak'.
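
The printed listing is easy to read but hard to reuse in code. As a small sketch, the same per-class counts can be collected into a dictionary with `pathlib`. The name `count_images_per_class` is hypothetical, and the `.jpg` suffix is an assumption about the file type that the notebook does not state.

```python
from pathlib import Path


def count_images_per_class(split_dir, suffix=".jpg"):
    """Map each class subfolder of split_dir to its number of files ending in suffix.

    Hypothetical helper; suffix=".jpg" is an assumption about the image format.
    """
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            # Count only files whose name ends with the given suffix
            counts[class_dir.name] = sum(1 for _ in class_dir.glob(f"*{suffix}"))
    return counts
```

Calling `count_images_per_class(image_path / "train")` should then mirror the training counts printed above, provided every file in those folders really is a `.jpg`.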


In [11]:
# Setup train and testing directories
train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir

(PosixPath('data/pizza_steak_sushi/train'),
 PosixPath('data/pizza_steak_sushi/test'))
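
Because each class lives in its own subfolder of `train_dir` and `test_dir`, the class names can be read straight from the folder names. The sketch below shows one way to do that and to assign each class an integer label. The helper name `find_class_names` is hypothetical and is not taken from the notebook.

```python
import os


def find_class_names(directory):
    """Return (class_names, class_to_idx) derived from the subfolders of directory.

    Hypothetical helper: assumes one subfolder per class, as in the
    train/ and test/ layout shown above.
    """
    with os.scandir(directory) as entries:
        class_names = sorted(entry.name for entry in entries if entry.is_dir())
    if not class_names:
        raise FileNotFoundError(f"Couldn't find any class folders in {directory}.")
    # Labels follow alphabetical order of the folder names
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    return class_names, class_to_idx
```

Called on `train_dir`, this should list pizza, steak and sushi in alphabetical order, since those are the three class folders found by `walk_through_dir`.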