# 04. Pytorch Custom Datasets

**Domain libraries:**
* torchvision
* torchtext
* torchaudio
* TorchRec

## 0. Importing PyTorch and setting up device-agnostic code

In [1]:
import torch
from torch import nn

# Device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"

torch_version = torch.__version__
index = torch_version.index("+")

print(f"Torch version: {torch.__version__[:index]}")
print(f"Device: {device.upper()}")

Torch version: 2.0.1
Device: CUDA


In [2]:
# Checking the GPU
if torch.cuda.is_available():
  !nvidia-smi
else:
  print("No GPU available at the moment.")

Mon Jun 19 22:21:16 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P8    13W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## 1. Get DATA

Food101 - 101 different calsses of food - 1000 images pers class (750 for training, 250 for testing)

Our dataset - 3 classes of food (pizza, steak, sushi) - 10% of Food101 images for each class (~75% for training, ~25% for testing)

In [3]:
import requests
import zipfile
from pathlib import Path

# Set up to a data folder
data_path = Path('data')
image_path = data_path / "pizza_steak_sushi"

if image_path.is_dir():
  print(f"{image_path} directory already exists, skipping download...")
else:
  print(f"{image_path} does not exist, creating one...")
  image_path.mkdir(parents = True, exist_ok = True)

# Download pizza_steak_sushi.zip data
with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
  request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
  print("Downloading pizza_steak_sushi.zip data...")
  f.write(request.content)

with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
  print("Unzipping pizza_steak_sushi.zip data...")
  zip_ref.extractall(image_path)

data/pizza_steak_sushi does not exist, creating one...
Downloading pizza_steak_sushi.zip data...
Unzipping pizza_steak_sushi.zip data...


## 2. Data preparation and exploration

In [4]:
import os
def walk_through_dir(dir_path):
  """Walks through dir_path returning it's contents."""
  for dirpath, dirnames, filenames in os.walk(dir_path):
    print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

In [5]:
walk_through_dir(image_path)

There are 2 directories and 0 images in 'data/pizza_steak_sushi'.
There are 3 directories and 0 images in 'data/pizza_steak_sushi/train'.
There are 0 directories and 72 images in 'data/pizza_steak_sushi/train/sushi'.
There are 0 directories and 75 images in 'data/pizza_steak_sushi/train/steak'.
There are 0 directories and 78 images in 'data/pizza_steak_sushi/train/pizza'.
There are 3 directories and 0 images in 'data/pizza_steak_sushi/test'.
There are 0 directories and 31 images in 'data/pizza_steak_sushi/test/sushi'.
There are 0 directories and 19 images in 'data/pizza_steak_sushi/test/steak'.
There are 0 directories and 25 images in 'data/pizza_steak_sushi/test/pizza'.


In [6]:
# Setup train and testing paths

train_dir = image_path / "train"
test_dir = image_path/ "test"

train_dir, test_dir

(PosixPath('data/pizza_steak_sushi/train'),
 PosixPath('data/pizza_steak_sushi/test'))

### 2.1 Visualizing and image

1. Get all of the image paths
2. Pick a random image path using Python's random.choice()
3. Get the image class name using `pathlib.Path.parent.stem`
4. Open the image with Python PIL
5. Show the image and print metadata

In [7]:
import random
from PIL import Image

# Set seed
random.seed(42)

# 1. Get all image paths
image_path_list = list(image_path.glob("*/*/*.jpg"))

# 2. Pick a random image path
random

<module 'random' from '/usr/lib/python3.10/random.py'>