In [1]:
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchinfo import summary
import modules.data_setup as data_setup
from modules.engine import train_step,test_step,train
import os


# Pick the compute device once: the first CUDA GPU when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


In [2]:
# Record the installed PyTorch release so the outputs below can be tied to a
# specific library version.
version = torch.__version__
print(f"PyTorch Version: {version}")

PyTorch Version: 2.8.0


In [3]:
import requests
from pathlib import Path
import zipfile

# Fetch the pizza/steak/sushi (20%) image dataset unless it is already on disk.
#
# The extraction directory is only created after the archive has been
# downloaded successfully and opened as a valid zip. A failed download
# therefore never leaves behind an empty `image_path` that would make the next
# run report "exists" and skip the download.
data_path = Path('data')
image_path = data_path / 'pizza_steak_sushi_20_percent'
zip_path = data_path / 'pizza_steak_sushi_20_percent.zip'
data_url = "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi_20_percent.zip"

if image_path.is_dir():
    print(f'{image_path} exists')
else:
    print(f'{image_path} does not exist, creating...')
    # Only the parent folder is needed to store the archive itself.
    data_path.mkdir(parents=True, exist_ok=True)

    print('Downloading pizza_steak_sushi_20_percent data')
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(data_url, timeout=60)
    # Fail loudly on HTTP errors instead of saving an error page as a ".zip".
    response.raise_for_status()
    zip_path.write_bytes(response.content)

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Reaching this point means the archive opened as a valid zip.
            image_path.mkdir(parents=True, exist_ok=True)
            print("Extracting pizza_steak_sushi_20_percent.zip")
            zip_ref.extractall(image_path)
    finally:
        # Remove the archive whether or not extraction succeeded.
        os.remove(zip_path)

data/pizza_steak_sushi_20_percent exists


In [4]:
def set_seed(seed:int=42):
    """Seed PyTorch's random number generators.

    Calls ``torch.manual_seed`` and then ``torch.cuda.manual_seed`` with the
    same value.

    Args:
        seed: Integer seed passed to both calls. Defaults to 42.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

In [5]:
# Locations of the two dataset splits inside the image folder.
train_dir, test_dir = (image_path / split for split in ('train', 'test'))

# Per-channel normalisation shared by every pipeline in this notebook.
# NOTE(review): these look like the usual ImageNet statistics expected by the
# pretrained weights - confirm against the weights' own transforms.
normalize = transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])

# Baseline pipeline without augmentation: resize -> tensor -> normalise.
simple_steps = [
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    normalize,
]
simple_transform = transforms.Compose(simple_steps)
print(simple_transform)


Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)


In [6]:
# Training pipeline with augmentation: same steps as the plain pipeline, with
# TrivialAugmentWide applied after the resize and before tensor conversion.
augmentation_steps = [
    transforms.Resize((224,224)),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    normalize,
]
data_augmentation_transform = transforms.Compose(augmentation_steps)
print(data_augmentation_transform)

Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    TrivialAugmentWide(num_magnitude_bins=31, interpolation=InterpolationMode.NEAREST, fill=None)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)


In [7]:
# Loaders without augmentation: both splits use simple_transform.
loader_config = dict(
    train_dir=train_dir,
    test_dir=test_dir,
    train_transform=simple_transform,
    test_transform=simple_transform,
    batch_size=32,
)
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(**loader_config)
# Bare final expression so the notebook displays what was created.
train_dataloader, test_dataloader, class_names

(<torch.utils.data.dataloader.DataLoader at 0x15e067e00>,
 <torch.utils.data.dataloader.DataLoader at 0x10e8f53a0>,
 ['pizza', 'steak', 'sushi'])

In [8]:
# Loaders for the augmented run: only the training split uses the augmentation
# pipeline; the test split keeps simple_transform.
# Note: test_dataloader and class_names are rebound by this call.
aug_loader_config = dict(
    train_dir=train_dir,
    test_dir=test_dir,
    train_transform=data_augmentation_transform,
    test_transform=simple_transform,
    batch_size=32,
)
train_dataloader_with_aug, test_dataloader, class_names = data_setup.create_dataloaders(**aug_loader_config)
# Bare final expression so the notebook displays what was created.
train_dataloader_with_aug, test_dataloader, class_names

(<torch.utils.data.dataloader.DataLoader at 0x105efaea0>,
 <torch.utils.data.dataloader.DataLoader at 0x15e43d3a0>,
 ['pizza', 'steak', 'sushi'])

In [9]:
# Instantiate EfficientNet-B1 with the DEFAULT entry of its weights enum.
weights = torchvision.models.EfficientNet_B1_Weights.DEFAULT
model = torchvision.models.efficientnet_b1(weights=weights)

# Summarise the unmodified network for a batch of 32 RGB images at 224x224.
summary_columns = ['input_size', 'output_size', 'num_params','trainable']
model_summary = summary(
    model,
    input_size=(32,3,224,224),
    col_names=summary_columns,
    col_width=20,
    row_settings=['var_names'],
    verbose=0,
)
model_summary

Downloading: "https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth" to /Users/alex/.cache/torch/hub/checkpoints/efficientnet_b1-c27df63c.pth


100%|██████████| 30.1M/30.1M [00:00<00:00, 101MB/s] 


Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
EfficientNet (EfficientNet)                                  [32, 3, 224, 224]    [32, 1000]           --                   True
├─Sequential (features)                                      [32, 3, 224, 224]    [32, 1280, 7, 7]     --                   True
│    └─Conv2dNormActivation (0)                              [32, 3, 224, 224]    [32, 32, 112, 112]   --                   True
│    │    └─Conv2d (0)                                       [32, 3, 224, 224]    [32, 32, 112, 112]   864                  True
│    │    └─BatchNorm2d (1)                                  [32, 32, 112, 112]   [32, 32, 112, 112]   64                   True
│    │    └─SiLU (2)                                         [32, 32, 112, 112]   [32, 32, 112, 112]   --                   --
│    └─Sequential (1)                                        [32, 32, 112, 112]   [32, 16, 112

In [10]:
# Freeze the feature extractor: turn off gradient tracking for every parameter
# under model.features.
model.features.requires_grad_(False)

# Seed before building the new head so its random initialisation is repeatable.
set_seed()

# Replace the classifier with dropout plus a linear layer that has one output
# per class in class_names.
num_classes = len(class_names)
model.classifier = nn.Sequential(
    nn.Dropout(p=0.2,inplace=True),
    nn.Linear(in_features=1280,out_features=num_classes,bias=True),
)

In [11]:
# Summarise the model again with the same settings, to inspect which layers are
# reported as trainable and the new output shape.
summary_options = dict(
    input_size=(32,3,224,224),
    verbose=0,
    col_names=['input_size', 'output_size', 'num_params','trainable'],
    col_width=20,
    row_settings=['var_names'],
)
model_summary = summary(model, **summary_options)
model_summary

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
EfficientNet (EfficientNet)                                  [32, 3, 224, 224]    [32, 3]              --                   Partial
├─Sequential (features)                                      [32, 3, 224, 224]    [32, 1280, 7, 7]     --                   False
│    └─Conv2dNormActivation (0)                              [32, 3, 224, 224]    [32, 32, 112, 112]   --                   False
│    │    └─Conv2d (0)                                       [32, 3, 224, 224]    [32, 32, 112, 112]   (864)                False
│    │    └─BatchNorm2d (1)                                  [32, 32, 112, 112]   [32, 32, 112, 112]   (64)                 False
│    │    └─SiLU (2)                                         [32, 32, 112, 112]   [32, 32, 112, 112]   --                   --
│    └─Sequential (1)                                        [32, 32, 112, 112]   [32, 