# Transfer learning

## Imports, device setup and utility functions

In [41]:
import torch
import torchvision

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch torchvision version: {torchvision.__version__}")

device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")
print(f"Torch device: {device}")
if torch.cuda.is_available():
    for device_id in range(torch.cuda.device_count()):
        print(f"Found CUDA device: cuda:{device_id} - {torch.cuda.get_device_name(device_id)}")


PyTorch version: 2.8.0+cu126
PyTorch torchvision version: 0.23.0+cu126
Torch device: cpu


## Getting the base model

In [42]:
from torchvision.models import efficientnet_b3 as base_model_class, EfficientNet_B3_Weights as BaseModelWeights
from pathlib import Path
from PIL import Image

base_transforms = BaseModelWeights.DEFAULT.transforms();

custom_image_path = Path("data")/"04-pizza-dad.jpg"

if not custom_image_path.is_file():
  with open(custom_image_path, "wb") as f:
    request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/refs/heads/main/images/04-pizza-dad.jpeg")
    print(f"Downloading {custom_image_path}....")
    f.write(request.content)
else:
  print(f"Custom image already downloaded")

sample_image = Image.open(custom_image_path)

print(f"Loaded {custom_image_path}")

sample_image_transformed = base_transforms(sample_image)
print(f"Transformed image shape: {sample_image_transformed.shape}, dtype: {sample_image_transformed.dtype}")

base_model = base_model_class(weights = BaseModelWeights.DEFAULT)

try:
  import torchinfo
except:
  !pip install torchinfo
  import torchinfo

torchinfo.summary(base_model, input_size=[32,sample_image_transformed.size(dim=0),
                                       sample_image_transformed.size(dim=1),
                                       sample_image_transformed.size(dim=2)],
                  col_names=["input_size", "output_size", "num_params", "trainable"],
                  col_width=20,
                  row_settings=["var_names"])

Custom image already downloaded
Loaded data/04-pizza-dad.jpg
Transformed image shape: torch.Size([3, 300, 300]), dtype: torch.float32


Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
EfficientNet (EfficientNet)                                  [32, 3, 300, 300]    [32, 1000]           --                   True
├─Sequential (features)                                      [32, 3, 300, 300]    [32, 1536, 10, 10]   --                   True
│    └─Conv2dNormActivation (0)                              [32, 3, 300, 300]    [32, 40, 150, 150]   --                   True
│    │    └─Conv2d (0)                                       [32, 3, 300, 300]    [32, 40, 150, 150]   1,080                True
│    │    └─BatchNorm2d (1)                                  [32, 40, 150, 150]   [32, 40, 150, 150]   80                   True
│    │    └─SiLU (2)                                         [32, 40, 150, 150]   [32, 40, 150, 150]   --                   --
│    └─Sequential (1)                                        [32, 40, 150, 150]   [32, 24, 150

## Get the extra data

In [43]:
import requests
import zipfile

data_path= Path("data/")
image_path= data_path / "pizza_steak_sushi"

# If the data already exists, don't download again
if image_path.is_dir():
  print(f"{image_path} directory already exists, not downloading")
else:
  print(f"{image_path} does not exist, creating")
  image_path.mkdir(parents=True, exist_ok=True)

with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
  request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/refs/heads/main/data/pizza_steak_sushi.zip")
  print("Downloading pizza, steak, sushi data")
  f.write(request.content)

with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
  print("Unzipping pizza_steak_sushi data")
  zip_ref.extractall(image_path)

data/pizza_steak_sushi directory already exists, not downloading
Downloading pizza, steak, sushi data
Unzipping pizza_steak_sushi data


In [44]:
import os
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

train_dir = image_path / "train"
test_dir = image_path / "test"

train_data = ImageFolder(root=train_dir,
                     transform=base_transforms
                     )
test_data = ImageFolder(root=test_dir,
                     transform=base_transforms
                     )

BATCH_SIZE = 32
SPARE_THREADS = 2

train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=os.cpu_count())

print(f"Loaded {len(train_data)} training samples and {len(test_data)} testing samples.")
print(f"Additional classes: {train_data.classes}")

Loaded 225 training samples and 75 testing samples.
Additional classes: ['pizza', 'steak', 'sushi']


## Freezing layers and replacing others

The summary information above shows the parameter name of the layer in the output, the feature layer is called `features`, so we can freeze all of the layers that make the `features` `Sequential` up in one hit

In [53]:
# Freeze the layers we want to keep unchanged
for layer in base_model.features.parameters():
  layer.requires_grad = False

torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Replace the layers we want to modify.
output_shape = len(test_data.classes)

base_model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=1536,
                    out_features=output_shape, # same number of output units as our number of classes
                    bias=True)).to(device)

torchinfo.summary(base_model, input_size=[32,sample_image_transformed.size(dim=0),
                                       sample_image_transformed.size(dim=1),
                                       sample_image_transformed.size(dim=2)],
                  col_names=["input_size", "output_size", "num_params", "trainable"],
                  col_width=20,
                  row_settings=["var_names"])

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
EfficientNet (EfficientNet)                                  [32, 3, 300, 300]    [32, 3]              --                   Partial
├─Sequential (features)                                      [32, 3, 300, 300]    [32, 1536, 10, 10]   --                   False
│    └─Conv2dNormActivation (0)                              [32, 3, 300, 300]    [32, 40, 150, 150]   --                   False
│    │    └─Conv2d (0)                                       [32, 3, 300, 300]    [32, 40, 150, 150]   (1,080)              False
│    │    └─BatchNorm2d (1)                                  [32, 40, 150, 150]   [32, 40, 150, 150]   (80)                 False
│    │    └─SiLU (2)                                         [32, 40, 150, 150]   [32, 40, 150, 150]   --                   --
│    └─Sequential (1)                                        [32, 40, 150, 150]   [32, 