## 0. Getting setup

In [34]:
import torch
import torchvision
import matplotlib.pyplot as plt

from torch import nn
from torchvision import transforms

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

try:
    from torchinfo import summary
except:
    print("[INFO] Coudnl't find torchinfo... installing it")
    !pip install -q torchinfo
    from torchinfo import summary

try:
    from going_modular import data_setup, engine
except:
    print("[INFO] Couldn't find going_modular scripts...")


torch version: 1.10.0
torchvision version: 0.11.1


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [6]:
def set_seeds(seed: int=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# 1. Get data

In [13]:
import os
import zipfile
import requests

from pathlib import Path

def download_data(source: str,
                  destination: str,
                  remove_source: bool = True) -> Path:
    data_path = Path("data/")
    image_path = data_path / destination

    if image_path.is_dir():
        print(f"[INFO] {image_path} directory exists, skipping download.")
    else:
        print(f"[INFO] did not find {image_path} directory, creating one...")

    # Downlaading pizza, steak, sushi
    target_file = Path(source).name
    with open(data_path / target_file, 'wb') as f:
        request = requests.get(source)
        print(f"[INFO] Downloading {target_file} from {source}...")
        f.write(request.content)

    with zipfile.ZipFile(data_path / target_file, 'r') as zip_ref:
        print(f"[INFO] Unzipping {target_file} data...")
        zip_ref.extractall(image_path)

    if remove_source:
        os.remove(data_path / target_file)

    return image_path

image_path = download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
                           destination="pizza_steak_sushi")
image_path

[INFO] data\pizza_steak_sushi directory exists, skipping download.
[INFO] Downloading pizza_steak_sushi.zip from https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip...
[INFO] Unzipping pizza_steak_sushi.zip data...


WindowsPath('data/pizza_steak_sushi')

# 2. Create Datasets and DataLoaders

In [20]:
train_dir = image_path / "train"
test_dir = image_path / "test"

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.226])

manual_transforms = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize])
print(f"Manually created transforms: {manual_transforms}")

train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=manual_transforms,
    batch_size=32
)

train_dataloader, test_dataloader, class_names

Manually created transforms: Compose(
    Resize(size=(224, 224), interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.226])
)


(<torch.utils.data.dataloader.DataLoader at 0x2199e9fee50>,
 <torch.utils.data.dataloader.DataLoader at 0x2199e9fed60>,
 ['pizza', 'steak', 'sushi'])

In [22]:
model = torchvision.models.efficientnet_b0(pretrained=True).to(device) # for torchvision < 13
# weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT for torchvision > 0.13
# model = torchvision.models.efficientnet_b0(weights=weights)

In [23]:
# Freeze all base layers
for param in model.features.parameters():
    param.requires_grad = False

set_seeds()

model.classfier = torch.nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=1280,
              out_features=len(class_names),
              bias=True).to(device))

In [28]:
# from torchinfo import summary

# summary(model,
#         input_size=(32, 3, 224, 224),
#         verbose=0,
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"]
#        )

# 4. Train model and track results

In [30]:
criterion = nn.CrossEntropyLoss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [39]:
try:
    import wandb
except:
    !pip install wandb
    import wandb

In [None]:
from typing import Dict, List
from tqdm.auto import tqdm
from going_modular.engine import train_step, test_step

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.Dataloader,
          test_dataloader: torch.utils.data.Dataloader,
          optimizer: torch.optim.Optimizer,
          criterion: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
