# prediction
模型精确度预测：由于从头训练的模型参数量过大、训练时长过长且数据集不足，测试意义不大；因此这里仅使用迁移学习版本的 ViT 进行测试。

In [1]:

from going_modular import data_setup, model_builder, engine, utils
from torchvision import transforms
import torch
from torch import nn
import torchvision
from torchinfo import summary

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"running on {DEVICE} ")

running on cuda 


In [None]:
import os
import zipfile

from pathlib import Path

import requests

def download_data(source: str,
                  destination: str,
                  remove_source: bool = True,
                  *,
                  timeout: float = 60.0) -> Path:
    """Download a zipped dataset from ``source`` and unzip it into ``data/<destination>``.

    The download is skipped when ``data/<destination>`` already exists and is
    non-empty. An existing *empty* directory, such as one left behind by an
    earlier failed run, is treated as missing, and the data is fetched again.

    Args:
        source (str): A link to a zipped file containing data.
        destination (str): Name of the directory under ``data/`` to unzip into.
        remove_source (bool): Whether to delete the downloaded .zip after extracting.
        timeout (float): Seconds to wait on the HTTP server before giving up.

    Returns:
        pathlib.Path to the extracted data (``data/<destination>``).

    Raises:
        requests.RequestException: On connection errors, timeouts or an HTTP
            error status (via ``raise_for_status``).
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.

    Example usage:
        download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
                      destination="pizza_steak_sushi")
    """
    import shutil

    # Setup path to data folder
    data_path = Path("data/")
    image_path = data_path / destination

    # Only skip when the folder actually has content. An empty folder means a
    # previous download/extraction did not finish.
    if image_path.is_dir() and any(image_path.iterdir()):
        print(f"[INFO] {image_path} directory exists, skipping download.")
        return image_path

    print(f"[INFO] Did not find {image_path} directory, creating one...")
    data_path.mkdir(parents=True, exist_ok=True)

    target_file = Path(source).name
    zip_path = data_path / target_file

    # Fetch and validate the response *before* writing anything, so an HTTP
    # error page is never saved as a .zip and no empty folder is left behind.
    print(f"[INFO] Downloading {target_file} from {source}...")
    response = requests.get(source, timeout=timeout)
    response.raise_for_status()
    zip_path.write_bytes(response.content)

    image_path.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] Unzipping {target_file} data...")
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(image_path)
    except BaseException:
        # Remove the partially extracted folder so the next call does not
        # mistake it for a finished download, then propagate the error.
        shutil.rmtree(image_path, ignore_errors=True)
        raise

    # Remove .zip file
    if remove_source:
        zip_path.unlink()

    return image_path

# Fetch the pizza/steak/sushi subset into data/pizza_steak_sushi and keep the
# downloaded .zip on disk afterwards.
image_path = download_data(
    "https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
    "pizza_steak_sushi",
    remove_source=False,
)
# Trailing bare expression so the notebook displays the resulting path.
image_path

[INFO] data/pizza_steak_sushi directory exists, skipping download.


PosixPath('data/pizza_steak_sushi')

In [None]:
# Data and model setup for evaluating the transfer-learned ViT-B/16.
TRAIN_DIR = "data/pizza_steak_sushi/train/"
TEST_DIR = "data/pizza_steak_sushi/test/"

# Evaluation preprocessing: resize to 224x224 and convert to a tensor.
# NOTE(review): there is no transforms.Normalize() here. Confirm this matches
# the transform the checkpoint was trained with, otherwise predictions will be skewed.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# This notebook only runs inference, so the train split deliberately reuses the
# deterministic test transform. batch_size=1 yields one image per step.
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    train_transform=test_transform,
    test_transform=test_transform,
    batch_size=1,
)

MODEL_PATH = "modelzoo/VitBase_transfer_learning.pth"

# ViT-B/16 skeleton built without pretrained weights. Its weights are loaded
# from MODEL_PATH below.
model = torchvision.models.vit_b_16()
# Replace the classification head with a 768 -> 128 -> num_classes MLP. These
# shapes must match the checkpoint for load_state_dict to succeed.
model.heads = nn.Sequential(
    nn.Linear(in_features=768, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=len(class_names)),
)

# model info
print(f"model name: {model.__class__.__name__}")
# In a notebook, torchinfo.summary() does not print on its own; it relies on
# being the cell's last expression. It is not the last expression here, so
# request no implicit output and print the returned statistics explicitly.
print(
    summary(
        model,
        input_size=(1, 3, 224, 224),
        col_names=(
            "input_size",
            "output_size",
            "mult_adds",
            "trainable",
        ),
        verbose=0,
    )
)

# Compile before loading the weights. NOTE(review): this ordering presumes the
# checkpoint was saved from a compiled model (keys prefixed with "_orig_mod.").
# Confirm this against the training code.
model = torch.compile(model=model)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# weights_only avoids unpickling arbitrary Python objects from the file.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model = model.to(DEVICE)
# Inference only: put the model in evaluation mode.
model.eval()

FileNotFoundError: [Errno 2] No such file or directory: 'data/pizza_steak_sushi/train/'

In [None]:
from calflops import calculate_flops

# Count FLOPs / MACs / parameters for a single 224x224 RGB image.
# `model` is the torch.compile wrapper (an OptimizedModule), which would make
# the report print "OptimizedModule" as the model name. Profile and name the
# original module instead; it shares the same parameters. Fall back to `model`
# itself if it is not wrapped.
base_model = getattr(model, "_orig_mod", model)

batch_size = 1
input_shape = (batch_size, 3, 224, 224)
flops, macs, params = calculate_flops(model=base_model,
                                      input_shape=input_shape,
                                      output_as_string=True,
                                      output_precision=4)
print(f"model {base_model.__class__.__name__}   FLOPs:{flops}   MACs:{macs}   Params:{params} \n")

In [None]:
# prediction
from going_modular import prediction
# Show the prediction for one known test image, then the confusion matrix
# over the whole test split.
# NOTE(review): no device is passed to pred_and_plt_image. Confirm it moves
# the input to the model's device on its own.
sample_image = "data/pizza_steak_sushi/test/pizza/1925494.jpg"
prediction.pred_and_plt_image(
    model,
    sample_image,
    class_names=class_names,
    transform=test_transform,
)

prediction.pred_and_plt_confmat(
    model,
    test_dataloader=test_dataloader,
    class_names=class_names,
    device=DEVICE,
)