<a href="https://colab.research.google.com/github/KonradGonrad/PyTorch-deep-learning/blob/main/04_pytorch_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 04. PyTorch Custom Datasets

## 0. Importing PyTorch and setting up device-agnostic code

In [None]:
import torch
from torch import nn

# This notebook expects PyTorch 1.10+; display the installed version string
torch.__version__

In [None]:
# Setup device-agnostic code: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    DEVICE_DESTINATION = 'cuda'
else:
    DEVICE_DESTINATION = 'cpu'
DEVICE_DESTINATION

## 1. Get data

In [None]:
import requests
import zipfile
from pathlib import Path

# Setup path to data folder
data_path = Path("data/")
images_path = data_path / "pizza_steak_sushi"
DATA_URL = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip"

# Only download and extract when the image folder is missing. Previously the
# download/unzip sat outside this `else`, so the archive was re-fetched on every run.
if images_path.is_dir():
  print(f"{images_path} already exist. Skipping download...")
else:
  print(f"Creating {images_path} path")

  # Download pizza, steak and sushi data first, so a failed request does not
  # leave behind an empty folder that makes later runs skip the download.
  print("Downloading pizza, steak, sushi data...")
  request = requests.get(DATA_URL, timeout=60)
  # Fail loudly on HTTP errors instead of writing an error page into the .zip
  request.raise_for_status()

  images_path.mkdir(parents=True, exist_ok=True)
  with open(data_path / "pizza_steak_sushi.zip", 'wb') as f:
    f.write(request.content)

  # Unzip pizza, steak, sushi data
  with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", 'r') as ziprep:
    print("Extracting pizza_steak_sushi data...")
    ziprep.extractall(images_path)


In [None]:
## 2. Becoming one with the data (Data preparation and data exploration)
import os

def walk_through_dir(dir_path):
  """Print how many sub-directories and files exist at every level below `dir_path`.

  The tree is walked top-down with `os.walk`. Every file is reported as an
  "image", so non-image files are counted too.
  """
  for current_dir, subdirs, files in os.walk(dir_path):
    n_dirs, n_files = len(subdirs), len(files)
    print(f"there are {n_dirs} directories and {n_files} images in '{current_dir}'.")

In [None]:
walk_through_dir(dir_path=images_path)  # summarise the extracted dataset folder

In [None]:
# Setup training and testing part: each split lives in its own sub-folder of images_path
train_dir, test_dir = (images_path / split for split in ("train", "test"))

train_dir, test_dir

## 2.1 Visualizing an image

1. Get all of the image paths
2. Pick a random image path using Python's `random.choice()`
3. Get the image class name with `pathlib.Path.parent.stem`
4. Since we're working with images, let's open the image with Python's PIL
5. We'll then show the image and print metadata

In [None]:
import random
from PIL import Image

# Collect every .jpg stored as <split>/<class>/<file>.jpg and pick one at random
# (no seed is set, so each run shows a different image)
image_paths = list(images_path.glob('*/*/*.jpg'))
random_image = random.choice(image_paths)

# The image's parent folder name is its class label
image_class = random_image.parent.stem

# Open with PIL and report the basic metadata
img = Image.open(random_image)
for metadata_line in (f"Random image path: {random_image}",
                      f"Random image class: {image_class}",
                      f"Image height: {img.height}",
                      f"Image width: {img.width}"):
  print(metadata_line)
img

In [None]:
# Visualize image with matplotlib - mine approach
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
import torchvision

# Same random pick as above, but rendered with matplotlib instead of PIL's repr
# (no seed is set, so each run shows a different image)
image_paths = list(images_path.glob('*/*/*.jpg'))
random_image = random.choice(image_paths)
image_class = random_image.parent.stem

# mpimg.imread decodes the file into a NumPy array (H, W, C for an RGB JPEG)
image = mpimg.imread(random_image)

plt.imshow(image)
plt.axis(False)
plt.title(image_class)
plt.show()

In [None]:
# Visualize image with matplotlib - video approach
import numpy as np
import matplotlib.pyplot as plt

# Same idea, following the course video: PIL image -> NumPy array -> matplotlib
image_paths = list(images_path.glob('*/*/*.jpg'))
random_image = random.choice(image_paths)
image_class = random_image.parent.stem

img = Image.open(random_image)
img_as_array = np.asarray(img)

fig, ax = plt.subplots(figsize=(10,7))
ax.imshow(img_as_array)
ax.set_title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color channels] (HWC)")
ax.axis(False)
plt.show()

## 3. Transforming data

Before we can use our image data with PyTorch:
1. Turn your target data into tensors (in our case, numerical representation of our images)
2. Turn it into a `torch.utils.data.Dataset` and subsequently a `torch.utils.data.DataLoader`; we'll call these `Dataset` and `DataLoader`.

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

### 3.1 Transforming data with `torchvision.transforms`

In [None]:
# Image transform pipeline applied in order: resize -> random flip -> tensor
data_transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),        # resize every image to 64x64
        transforms.RandomHorizontalFlip(p=0.5),  # mirror left-right half of the time
        transforms.ToTensor(),                   # PIL image -> float tensor (C, H, W)
    ]
)

In [None]:
# Apply the resize/flip/ToTensor pipeline to the last PIL image opened (`img`) and show the tensor shape
data_transform(img).shape

In [None]:
def plot_transformed_images(image_paths: list, transform, n=3, seed=None):
  """
  Selects random images from a list of image paths, loads/transforms them,
  then plots the original vs transformed version side by side.

  Args:
    image_paths: list of pathlib.Path image files; the class shown in the
      title is read from each path's parent folder name.
    transform: callable applied to the opened PIL image; assumed to return a
      (C, H, W) tensor, which is permuted to (H, W, C) for matplotlib.
    n: number of images to sample (must not exceed len(image_paths)).
    seed: if given (including 0), seeds Python's `random` module so the same
      images are picked on every call.
  """
  # Honour the caller's seed. It was previously hard-coded to 42, and seed=0 was ignored.
  if seed is not None:
    random.seed(seed)
  random_image_paths = random.sample(image_paths, k=n)
  for random_image in random_image_paths:
    with Image.open(random_image) as f:
      fig, ax = plt.subplots(nrows=1, ncols=2)
      ax[0].imshow(f)
      ax[0].axis("off")
      ax[0].set_title(f"Original size: {f.size}")

      # Transformed tensor is channels-first; matplotlib expects channels-last
      transformed_f = transform(f).permute(1, 2, 0)
      ax[1].imshow(transformed_f)
      ax[1].axis("off")
      ax[1].set_title(f"Shape: {transformed_f.shape}")

      fig.suptitle(f"Class: {random_image.parent.stem}", fontsize=16)
plot_transformed_images(image_paths=image_paths, transform=data_transform, n=3, seed=42)



## Option 1: Loading image data using ImageFolder
We can load image classification data using `torchvision.datasets.ImageFolder`

In [None]:
# Use ImageFolder to create dataset's
from torchvision import datasets
# One ImageFolder per split, both using the same transform and no label transform
train_data, test_data = (
    datasets.ImageFolder(root=split_dir, transform=data_transform, target_transform=None)
    for split_dir in (train_dir, test_dir)
)

print(train_data, test_data)

In [None]:
# Get class names
# Get the class names discovered by ImageFolder for the training split
class_names = train_data.classes
class_names

In [None]:
# Get class names as dict
# Get class names as dict (class name -> integer label)
class_dict = train_data.class_to_idx
class_dict

In [None]:
# Check the lengths of our dataset
# Number of samples in the train and test datasets
len(train_data), len(test_data)

In [None]:
# Inspect the first stored sample entry (presumably an (image path, class index) tuple; check the output)
train_data.samples[0]

In [None]:
# Index on the train_data Dataset to get a single image and label
import random

# Pick a valid random index. randint's upper bound is inclusive, so the old
# randint(0, len(train_data)) could return len(train_data) and raise IndexError.
random_idx = random.randrange(len(train_data))
# Index once and unpack instead of loading/transforming the same sample twice
img, label = train_data[random_idx]
print(f"Image tensor:\n {img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Label datatype: {type(label)}")

In [None]:
# Display the image tensor picked in the previous cell
img

In [None]:
print("Label: {} which one is {}".format(label, class_names[label]))  # integer label and its class name

In [None]:
# matplotlib wants channels last, so move the colour dimension to the end: CHW -> HWC
img_permute = img.permute(1, 2, 0)
print(f"old shape: {img.shape} -> [color_channels, height, width]")
print(f"new shape: {img_permute.shape} -> [height, width, color_channels]")

fig, ax = plt.subplots(figsize=(10, 7))
ax.imshow(img_permute)
ax.axis("off")
ax.set_title(class_names[label], fontsize=14)
plt.show()

## 4.1 Turn loaded images into `DataLoader`s

A `DataLoader` turns our `Dataset`s into iterables, so we can work with `batch_size` images at a time.

In [None]:
from torch.utils.data import DataLoader
import os
BATCH_SIZE = 1

# The training loader shuffles and uses one worker per CPU core;
# the test loader keeps a fixed order and uses a single worker.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

train_dataloader, test_dataloader

In [None]:
# Number of batches in each DataLoader (BATCH_SIZE is 1 above, so one batch per sample)
len(train_dataloader), len(test_dataloader)

In [None]:
img, label = next(iter(train_dataloader))  # one batch from the shuffled training loader

print("image shape: {} -> [batch_size, color_channles, height, width]".format(img.shape))
print("label shape: {}".format(label.shape))

## 5. Option 2: Loading Image data with a custom `dataset`

1. Want to be able to load images from file
2. Want to be able to get class names from the dataset
3. Want to be able to get classes as dictionary from the dataset

Pros:
* Can create a `Dataset` out of almost anything
* Not limited to PyTorch's pre-built `Dataset` functions

Cons:
* Even though you could create `Dataset` out of almost anything, it doesn't mean it will work
* Using a custom `Dataset` often results in us writing more code, which could be prone to errors or performance issues

Custom datasets in PyTorch often subclass `torch.utils.data.Dataset`.

In [None]:
import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List

In [None]:
# Reference values from the torchvision.datasets.ImageFolder() instance; the custom helpers below should reproduce these
train_data.classes, train_data.class_to_idx

## 5.1 Creating a helper function to get class names

We want a function to:
1. Get the class names using `os.scandir()` to traverse a target directory (ideally the directory is in standard image classification format).
2. Raise an error if the class names aren't found (if this happens, there might be something wrong with the directory structure).
3. Turn the class names into a dict and a list and return them

In [None]:
# Setup path for target directory
target_directory = train_dir
print(f"Target dir: {target_directory}")

# Names of every entry directly under the target directory (files included), sorted
class_names_found = sorted(entry.name for entry in os.scandir(target_directory))
class_names_found

In [None]:
# Raw os.DirEntry objects for the same directory, unsorted (order is whatever os.scandir yields)
list(os.scandir(target_directory))

In [None]:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
  """
  Finds the class folder names in a target directory.

  Args:
    directory: path whose immediate sub-directories are the class folders.

  Returns:
    (classes, class_to_idx): the alphabetically sorted folder names, and a dict
    mapping each name to its position in that sorted list.

  Raises:
    FileNotFoundError: if `directory` contains no sub-directories.
  """
  with os.scandir(directory) as entries:
    classes = sorted(entry.name for entry in entries if entry.is_dir())

  if len(classes) == 0:
    raise FileNotFoundError(f"Couldn't find any classes in {directory}...")

  # Integer labels follow the sorted order (computers prefer numbers to strings as labels)
  class_to_idx = dict(zip(classes, range(len(classes))))
  return classes, class_to_idx

In [None]:
find_classes(directory=target_directory)  # should match train_data.classes / class_to_idx

In [None]:
# Toy example of the enumerate pattern used to build class_to_idx
x = ['Konrad', 'Kamil', 'Wojtek']
[(name, i) for i, name in enumerate(x)]

### 5.2 Create a custom `Dataset` to replicate `ImageFolder`

To create our own custom dataset, we want to:

1. Subclass `torch.utils.data.Dataset`
2. Init our subclass with a target directory (the directory we'd like to get data from) as well as a transform if we'd like to transform our data.
3. Create several attributes:
 * paths - paths of our images
 * transform - the transform we'd like to use
 * classes - a list of the target classes
 * class_to_idx - a dict of the target classes mapped to integer labels
4. Create a function to `load_images()`, this function will open an image
5. Override the `__len__()` method to return the length of our dataset
6. Override the `__getitem__()` method to return a given sample when passed an index

In [None]:
# Write a custom dataset class
from torch.utils.data import Dataset
import pathlib

# 1. Subclass torch.utils.data.Dataset
class ImageFolderCustom(Dataset):
  """
  Custom replacement for torchvision.datasets.ImageFolder that loads `.jpg`
  files stored as `targ_dir/<class_name>/<image>.jpg`.

  Attributes:
    paths: sorted list of pathlib.Path objects for every matching image.
    transform: optional callable applied to each loaded PIL image.
    class_names: sorted class folder names (from `find_classes`).
    classes: alias of `class_names`, matching ImageFolder's attribute name.
    class_to_idx: dict mapping class folder name -> integer label.
  """
  # 2. initialize our custom dataset
  def __init__(self,
               targ_dir,
               transform = None):
    # 3. Create class attributes
    # glob() yields files in filesystem order, which is not guaranteed; sort so
    # indices (and therefore seeded sampling) are reproducible across machines.
    self.paths = sorted(pathlib.Path(targ_dir).glob("*/*.jpg"))
    # Setup transform
    self.transform = transform
    # Create classes and class_to_idx
    self.class_names, self.class_to_idx = find_classes(targ_dir)
    # Same list under ImageFolder's attribute name, so both datasets can be used interchangeably
    self.classes = self.class_names

  # 4. Create a function to load images
  def load_images(self, index: int) -> Image.Image:
    "Opens the image at `index`, converts it to RGB and returns it"
    image_path = self.paths[index]
    # Convert to RGB so grayscale or RGBA files still yield 3-channel tensors
    # (the model expects 3 input channels); `with` also closes the file handle.
    with Image.open(image_path) as img:
      return img.convert("RGB")

  # 5. Override __len__()
  def __len__(self) -> int:
    "Returns the total number of samples"
    return len(self.paths)

  # 6. Override __getitem__()
  def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
    "Returns one sample of data, data and label (X, y)"
    img = self.load_images(index)
    # The parent folder name is the class label
    class_name = self.paths[index].parent.name
    class_idx = self.class_to_idx[class_name]

    # Transform if necessary
    if self.transform is not None:
      return self.transform(img), class_idx
    return img, class_idx


In [None]:
# Create a transform
from torchvision import transforms

# The train pipeline adds a random horizontal flip; the test pipeline is deterministic
train_transforms = transforms.Compose(
    [transforms.Resize(size=(64, 64)), transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()]
)
test_transforms = transforms.Compose(
    [transforms.Resize(size=(64, 64)), transforms.ToTensor()]
)

In [None]:
# Test out ImageFolderCustom
# Build the custom datasets: augmented transform for train, plain one for test
train_data_custom = ImageFolderCustom(train_dir, transform=train_transforms)
test_data_custom = ImageFolderCustom(test_dir, transform=test_transforms)

In [None]:
# Show the repr of both custom dataset instances
train_data_custom, test_data_custom

In [None]:
# Sample counts should match if every training image is a .jpg (the custom dataset only globs *.jpg)
len(train_data), len(train_data_custom)

In [None]:
# Same check for the test split
len(test_data), len(test_data_custom)

In [None]:
# Compare class names found by the custom dataset vs ImageFolder
train_data_custom.class_names, train_data.classes

In [None]:
# Compare the class-name -> label mappings
train_data_custom.class_to_idx, train_data.class_to_idx

In [None]:
# Every image path the custom training dataset discovered
train_data_custom.paths

In [None]:
# Check that ImageFolderCustom and torchvision's ImageFolder found the same class names (expect True twice)
print(train_data_custom.class_names == train_data.classes)
print(test_data_custom.class_names == test_data.classes)

### 5.3 Create a function to display random images

1. Take in a `Dataset` and a number of other parameters, such as class names and how many images to visualize
2. To prevent the display getting out of hand, let's cap the number of images to see at 10
3. Set the random seed for reproducibility
4. Get a list of random sample indexes from the target dataset
5. Setup a matplotlib plot
6. Loop through the random sample images and plot them with matplotlib
7. Make sure the dimensions of our image line up with matplotlib (HWC)


In [None]:
# 1. Create a function to take in a dataset
def display_random_images(dataset: torch.utils.data.Dataset,
                          classes: List[str] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: int = None):
  """
  Plot `n` randomly chosen samples from `dataset` in a single row.

  Args:
    dataset: indexable dataset returning (image_tensor, label) pairs; image
      tensors are assumed to be (C, H, W) and are permuted to (H, W, C).
    classes: optional list mapping integer labels to class names for titles.
      When omitted, the raw label is shown instead.
    n: number of images to show; capped at 10 (shapes are then hidden) and
      at the dataset size.
    display_shape: append the HWC image shape to each title.
    seed: if given (including 0), seeds Python's `random` module.
  """
  # 2. Adjust display if n is too high
  if n > 10:
    print(f"Due to too high n, what may cause problems with visualization, n is set to 10 instead of {n}")
    display_shape = False
    n = 10

  # 3. Set the seed (`is not None` so that seed=0 is honoured too)
  if seed is not None:
    random.seed(seed)

  # 4. Get random sample indexes; never request more samples than the dataset has
  random_sample_idx = random.sample(range(len(dataset)), k=min(n, len(dataset)))

  # 5. Setup plot
  plt.figure(figsize=(16, 8))

  # 6. Loop through random indexes and plot them with matplotlib
  for i, targ_image in enumerate(random_sample_idx):
    # Index once instead of twice (avoids loading/transforming the sample twice)
    image, label = dataset[targ_image]

    # 7. Adjust tensor dimensions for plotting: (C, H, W) -> (H, W, C)
    image_adjust = image.permute(1, 2, 0)

    # Plot adjusted samples
    plt.subplot(1, n, i + 1)
    plt.imshow(image_adjust)
    plt.axis("off")

    # Always build a title. Previously `title` was unbound (crash) when `classes` was not given.
    title = f"Class: {classes[label]}" if classes else f"Label: {label}"
    if display_shape:
      title = title + f"\nshape: {image_adjust.shape}"
    plt.title(title)



In [None]:
# Display random images
display_random_images(train_data, classes=class_names, n=4, seed=None)  # unseeded: different picks each run

In [None]:
display_random_images(train_data_custom, classes=class_names, n=5, seed=42)  # seeded for a repeatable pick

### 5.4 Turn Custom loaded images into `DataLoader's`

In [None]:
from torch.utils.data import DataLoader
import os

BATCH_SIZE, NUM_WORKERS = 32, os.cpu_count()

# Shuffle only the training data; keep the evaluation order fixed
train_dataloader_custom = DataLoader(train_data_custom, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
test_dataloader_custom = DataLoader(test_data_custom, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

train_dataloader_custom

In [None]:
# Get image and label from custom dataloader
img_custom, label_custom = next(iter(train_dataloader_custom))

# Print out the shapes
img_custom.shape, label_custom.shape

## 6. Other forms of transforms (data augmentation)

Data augmentation is the process of artificially adding diversity to your training data.

In the case of image data, this may mean applying various image transformations to the training images

*It's a bit like looking at the same image from different perspectives.*

Let's take a look at one particular type of data augmentation used to train PyTorch vision models to state-of-the-art levels.


In [None]:
# Let's look at TrivialAugment
from torchvision import transforms

# The train pipeline applies TrivialAugmentWide; the test pipeline only resizes and converts
train_transform = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.TrivialAugmentWide(num_magnitude_bins=31), transforms.ToTensor()]
)
test_transform = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)

In [None]:
# Get all image paths
# All .jpg paths across both splits (<split>/<class>/<file>.jpg); preview the first ten
image_path_list = list((data_path / "pizza_steak_sushi").glob('*/*/*.jpg'))
image_path_list[:10]

In [None]:
# Plot random transformed images
plot_transformed_images(image_path_list, train_transform, n=3, seed=None)  # original vs augmented

## 7. Model 0: TinyVGG without data augmentation
Let's replicate the TinyVGG architecture from CNN Explainer.

### 7.1 Creating transforms and loading data for Model 0

In [None]:
# Create simple transform
# No augmentation: just resize to 64x64 and convert to a tensor
simple_transform = transforms.Compose([transforms.Resize(size=(64, 64)), transforms.ToTensor()])

In [None]:
train_and_test_path = data_path / "pizza_steak_sushi"

# Raw image paths per split. New names are used so the ImageFolder datasets
# `train_data` / `test_data` built earlier are not silently overwritten with
# lists, which would break re-running the earlier dataset cells.
train_image_paths = list(train_and_test_path.glob('train/*/*.jpg'))
test_image_paths = list(train_and_test_path.glob('test/*/*.jpg'))

In [None]:
# 1. Load and transform data
from torchvision import datasets
# One ImageFolder per split, both with the non-augmented transform
train_data_simple, test_data_simple = (
    datasets.ImageFolder(root=split_dir, transform=simple_transform)
    for split_dir in (train_dir, test_dir)
)

# 2. Turn the datasets into DataLoaders
import os
from torch.utils.data import DataLoader

# Setup batch size and number of workers
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()

# Only the training loader shuffles; the test order stays fixed for evaluation
train_dataloader_simple = DataLoader(train_data_simple, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_dataloader_simple = DataLoader(test_data_simple, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
### 7.2 Create TinyVGG model class
class TinyVGG(nn.Module):
  """
  Model architecture copying TinyVGG from CNN Explainer.

  Two conv blocks (3x3 conv -> ReLU -> 3x3 conv -> ReLU -> 2x2 max-pool, no
  padding) followed by a flatten + linear classifier that outputs logits.

  Args:
    input_shape: number of input colour channels.
    hidden_units: number of filters in every conv layer.
    output_shape: number of output classes.
    image_size: height/width of the square input images. The default of 64
      reproduces the previously hard-coded 13*13 flattened feature map.

  Raises:
    ValueError: if `image_size` is too small to survive both conv blocks.
  """
  def __init__(self,
               input_shape: int,
               hidden_units: int,
               output_shape: int,
               image_size: int = 64) -> None:
    super().__init__()
    self.conv_block_1 = nn.Sequential(
        nn.Conv2d(in_channels=input_shape,
                  out_channels=hidden_units,
                  kernel_size=3,
                  stride=1,
                  padding=0),
        nn.ReLU(),
        nn.Conv2d(in_channels=hidden_units,
                  out_channels=hidden_units,
                  kernel_size=3,
                  stride=1,
                  padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,
                     stride=2)
    )
    self.conv_block_2 = nn.Sequential(
        nn.Conv2d(in_channels=hidden_units,
                  out_channels=hidden_units,
                  kernel_size=3,
                  stride=1,
                  padding=0),
        nn.ReLU(),
        nn.Conv2d(in_channels=hidden_units,
                  out_channels=hidden_units,
                  kernel_size=3,
                  stride=1,
                  padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,
                     stride=2)
    )
    # Derive the flattened size from the input size instead of hard-coding 13*13,
    # so the model also works with e.g. the 224x224 transforms used elsewhere.
    feature_size = self._feature_map_size(image_size)
    self.classifier_layer = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=hidden_units * feature_size * feature_size,
                  out_features=output_shape)
    )

  @staticmethod
  def _feature_map_size(image_size: int) -> int:
    """Spatial size after both conv blocks for a square `image_size` input."""
    size = image_size
    # Each block: two unpadded 3x3 convs (-2 px each), then a 2x2/stride-2 pool (floor halve).
    # For 64x64: 64 -> 60 -> 30 -> 26 -> 13.
    for _ in range(2):
      size = (size - 4) // 2
    if size < 1:
      raise ValueError(f"image_size={image_size} is too small for TinyVGG")
    return size

  def forward(self, x):
    x = self.conv_block_1(x)
    x = self.conv_block_2(x)
    return self.classifier_layer(x)



In [None]:
# Seed PyTorch so the initial weights are reproducible, then move the model to the target device
torch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, hidden_units=10, output_shape=len(class_names))
model_0 = model_0.to(DEVICE_DESTINATION)
model_0

In [None]:
### 7.3 Try a forward pass on a single batch (to test the model)
# Grab one batch from the shuffled simple training loader and check its shape
image_batch, label_batch = next(iter(train_dataloader_simple))
image_batch.shape

In [None]:
# Try a forward pass
# Move the batch to the model's device and display the raw logits (one row per image)
model_0(image_batch.to(DEVICE_DESTINATION))

In [None]:
### 7.4 Use `torchinfo` to get an idea of the shapes going through our model

In [None]:
# Install torchinfo, import if it's available
try:
  import torchinfo
except:
  !pip install torchinfo
  import torchinfo

from torchinfo import summary
summary(model_0, input_size=[1, 3, 64, 64])

## 7.5 Create train and test loop functions

* `train_step()` - takes in a model and dataloader and trains the model on the dataloader
* `test_step()` - takes in a model and dataloader and evaluates the model on the dataloader

In [None]:
# Create train_step()
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
  """
  Run one training epoch over `dataloader`.

  Args:
    model: model to train; switched to train mode and updated in place.
    dataloader: iterable of (X, y) batches that also supports len().
    loss_fn: loss taking (logits, targets).
    optimizer: optimizer over `model`'s parameters.
    device: device each batch is moved to (the model should already be there).

  Returns:
    (train_loss, train_acc): mean loss and mean accuracy per batch, as floats.
  """
  # Put the model in train mode
  model.train()

  # Running totals (Python floats, detached from the autograd graph)
  train_loss, train_acc = 0.0, 0.0

  for X, y in dataloader:
    # Send data to the target device
    X, y = X.to(device), y.to(device)

    # 1. Forward pass
    y_pred = model(X)

    # 2. Calculate the loss. The batch loss tensor (needed for backward) is kept
    # separate from the running total. Previously both were named `loss`, so the
    # total was overwritten each batch and a graph-attached tensor was returned.
    batch_loss = loss_fn(y_pred, y)
    train_loss += batch_loss.item()

    # 3. Optimizer zero grad
    optimizer.zero_grad()

    # 4. Loss backward
    batch_loss.backward()

    # 5. Optimizer step
    optimizer.step()

    # Accuracy: argmax of the logits (softmax is monotonic, so it wouldn't change it)
    y_pred_class = y_pred.argmax(dim=1)
    train_acc += (y_pred_class == y).sum().item() / len(y_pred)

  # Adjust metrics to get average loss and accuracy per batch
  train_loss /= len(dataloader)
  train_acc /= len(dataloader)
  return train_loss, train_acc

In [None]:
# Create a test step
def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device):
  # Put model in eval mode
  model.eval()

  # Setup test loss and test accuracy values
  loss, acc = 0, 0

  # Turn on inference mode
  with torch.inference_mode():
    # Loop through dataloader batches
    for batch, (X, y) in enumerate(dataloader):
      # Send data to device
      X, y = X.to(device), y.to(device)

      # 1. Forward pass
      test_pred_logits = model(X)

      # 2. Calculate the loss
      loss += loss_fn(test_pred_logits, y).item()

      # 3. Calculate the accuracy
      test_pred_labels = test_pred_logits.argmax(dim=1)
      acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))

  # Adjust metrics to get average loss
  loss /= len(dataloader)
  acc /= len(dataloader)
  return loss, acc

### 7.6 Creating `train()` function to combine `train_step()` and `test_step()`


In [None]:
from tqdm.auto import tqdm

# 1. Create a train function that takes in various model parameters + optimizer + dataloader + etc

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          epochs: int = 5,
          device: torch.device = DEVICE_DESTINATION):
  """
  Alternate `train_step` and `test_step` for `epochs` epochs.

  Prints one summary line per epoch and returns a dict holding one list per
  metric ("train_loss", "train_acc", "test_loss", "test_acc"), with one entry per epoch.
  Note: the `loss_fn` and `device` defaults are evaluated once, when this cell runs.
  """
  metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
  results = {name: [] for name in metric_names}

  for epoch in tqdm(range(epochs)):
    train_loss, train_acc = train_step(model=model,
                                       dataloader=train_dataloader,
                                       loss_fn=loss_fn,
                                       optimizer=optimizer,
                                       device=device)
    test_loss, test_acc = test_step(model=model,
                                    dataloader=test_dataloader,
                                    loss_fn=loss_fn,
                                    device=device)

    print(f"Epoch: {epoch} train_loss: {train_loss:.2f} train_acc: {train_acc:.2f} | test_loss: {test_loss:.2f} test_acc: {test_acc:.2f}")

    # Append this epoch's metrics in the same order as metric_names
    for name, value in zip(metric_names, (train_loss, train_acc, test_loss, test_acc)):
      results[name].append(value)

  return results

### 7.7 Train and evaluate model 0


In [None]:
# Seed the CPU and CUDA generators so the run is reproducible
torch.manual_seed(42)
torch.cuda.manual_seed(42)

NUM_EPOCHS = 10

# Fresh TinyVGG with one output logit per class
model_0 = TinyVGG(input_shape=3, hidden_units=10, output_shape=len(train_data_simple.classes))
model_0 = model_0.to(DEVICE_DESTINATION)

# Cross-entropy over logits, optimised with Adam
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)

# Time the whole training run
from timeit import default_timer as timer
start_time = timer()

model_0_results = train(model=model_0,
                        train_dataloader=train_dataloader_simple,
                        test_dataloader=test_dataloader_simple,
                        optimizer=optimizer,
                        loss_fn=loss_fn,
                        epochs=NUM_EPOCHS,
                        device=DEVICE_DESTINATION)

end_time = timer()
print(f"Total training time: {end_time-start_time:.2f}")