In [1]:
import os
import torch
import data_setup, engine, model_builder, utils

from torchvision import transforms

# Train TinyVGG on the pizza/steak/sushi images and save the weights.
# Pipeline: build transforms -> dataloaders (data_setup) -> model (model_builder)
# -> train (engine) -> save (utils).

# --- Configuration: everything a reader might tune ---
NUM_EPOCHS = 5
BATCH_SIZE = 32
HIDDEN_UNITS = 10
LEARNING_RATE = 0.001
IMAGE_SIZE = (64, 64)     # every train/test image is resized to this (height, width)
NUM_MAGNITUDE_BINS = 31   # number of augmentation-strength bins for TrivialAugmentWide
SEED = 42                 # seeds torch's RNG so Restart & Run All reproduces the metrics

# setup directories
train_dir = "data/pizza_steak_sushi/train"
test_dir = "data/pizza_steak_sushi/test"

# setup target device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed before anything that consumes randomness: the augmentation, the
# dataloader shuffling and the model's weight initialisation all draw from
# torch's RNG. torch.manual_seed seeds the CPU and all CUDA devices; some CUDA
# kernels can still be nondeterministic, so GPU runs may differ slightly.
torch.manual_seed(SEED)

# create transforms (train gets random augmentation, test only resizes)
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.TrivialAugmentWide(num_magnitude_bins=NUM_MAGNITUDE_BINS),
    transforms.ToTensor()
])

test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

# create dataloader with data_setup.py
train_dataloader, test_dataloader, class_names = data_setup.create_dataloader(
    train_dir=train_dir,
    test_dir=test_dir,
    train_transform=train_transform,
    test_transform=test_transform,
    batch_size=BATCH_SIZE
)

# create model with help from model_builder.py
# input_shape=3: ToTensor yields 3-channel tensors for RGB images
# (assumes the dataset images are RGB - TODO confirm)
model = model_builder.TinyVGG(
    input_shape=3,
    hidden_units=HIDDEN_UNITS,
    output_shape=len(class_names)
).to(device)

# set up loss and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE
)

# Keep whatever engine.train returns (presumably per-epoch metrics - TODO
# confirm against engine.py) so later cells can plot/inspect it.
results = engine.train(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=NUM_EPOCHS,
    device=device
)

# Save the model with utils.py
utils.save_model(
    model=model,
    target_dir="models",
    model_name="05_TinyVGG_model_1.pth"
)

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 1.1030 | train_accuracy: 0.2461 | test_loss: 1.1045 | test_accuracy: 0.1979
Epoch: 2 | train_loss: 1.1044 | train_accuracy: 0.2695 | test_loss: 1.1049 | test_accuracy: 0.2604
Epoch: 3 | train_loss: 1.0997 | train_accuracy: 0.3047 | test_loss: 1.0993 | test_accuracy: 0.2604
Epoch: 4 | train_loss: 1.0922 | train_accuracy: 0.4219 | test_loss: 1.0856 | test_accuracy: 0.2396
Epoch: 5 | train_loss: 1.0843 | train_accuracy: 0.4258 | test_loss: 1.0765 | test_accuracy: 0.2604
[INFO] Saving model to: models\05_TinyVGG_model_1.pth
