# Prediction
模型精确度预测：由于从头训练的模型参数量过大、训练耗时长、数据集不足，测试意义不大；因此这里仅使用迁移学习版本的 ViT 进行测试。

In [None]:

from going_modular import data_setup, model_builder, engine, utils
from torchvision import transforms
import torch
from torch import nn
import torchvision
from torchinfo import summary

# Pick the compute device once for the whole notebook: the GPU when PyTorch
# can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"running on {DEVICE} ")

In [None]:
# Dataset locations used for evaluation.
TRAIN_DIR = "data/pizza_steak_sushi/train/"
TEST_DIR = "data/pizza_steak_sushi/test/"

# Evaluation preprocessing: resize to 224x224 (the input size passed to
# summary() below) and convert the PIL image to a tensor.
# NOTE(review): no Normalize() step here -- confirm this matches the transform
# used when the checkpoint was trained, otherwise the measured accuracy is off.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Nothing is trained in this notebook, so the train loader simply reuses the
# test transform; only test_dataloader and class_names are used below.
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=TRAIN_DIR,
    test_dir=TEST_DIR,
    train_transform=test_transform,
    test_transform=test_transform,
    batch_size=1,
)

MODEL_PATH = "modelzoo/VitBase_transfer_learning.pth"

# Rebuild the network layout the checkpoint expects: a ViT-B/16 backbone
# (constructed without pretrained weights -- they come from the checkpoint)
# with a 768 -> 128 -> len(class_names) MLP head replacing the default head.
# load_state_dict() is strict by default, so any layout mismatch raises.
model = torchvision.models.vit_b_16()
model.heads = nn.Sequential(
    nn.Linear(in_features=768, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=len(class_names)),
)

# model info (printed on the eager module, before compilation)
print(f"model name: {model.__class__.__name__}")
summary(
    model,
    input_size=(1, 3, 224, 224),
    col_names=(
        "input_size",
        "output_size",
        "mult_adds",
        "trainable",
    ),
)

# The checkpoint is loaded into the *compiled* wrapper, whose state-dict keys
# carry an "_orig_mod." prefix; this ordering therefore assumes the checkpoint
# was also saved from a compiled model.
model = torch.compile(model=model)

# map_location="cpu": the model is still on the CPU at this point, and it lets
# a checkpoint saved on a GPU load on a CPU-only machine.
# weights_only=True: a state dict only needs tensors/containers, so refuse to
# unpickle anything else.
state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model = model.to(DEVICE)

# Everything below is inference only: put the model in eval mode
# (training=False) before benchmarking or predicting.
model.eval()

In [None]:
from calflops import calculate_flops

# `model` is the torch.compile() wrapper (an OptimizedModule), so its class name
# is not the architecture name. Profile the original eager module it wraps
# (`_orig_mod`), falling back to `model` itself if it is not compiled.
base_model = getattr(model, "_orig_mod", model)

batch_size = 1
input_shape = (batch_size, 3, 224, 224)
flops, macs, params = calculate_flops(
    model=base_model,
    input_shape=input_shape,
    output_as_string=True,
    output_precision=4,
)
print(
    f"model {base_model.__class__.__name__}   "
    f"FLOPs:{flops}   MACs:{macs}   Params:{params} \n"
)

In [None]:
# prediction
from going_modular import prediction

# One held-out pizza image from the test split, used as a qualitative spot-check.
sample_image_path = "data/pizza_steak_sushi/test/pizza/1925494.jpg"

# Single-image prediction (plotted, judging by the helper's name), using the
# same preprocessing as the test loader.
# NOTE(review): unlike pred_and_plt_confmat below, no `device` is passed here --
# confirm the helper moves the image onto the model's device itself.
prediction.pred_and_plt_image(
    model,
    sample_image_path,
    class_names=class_names,
    transform=test_transform,
)

# Run the whole test split through the model and plot the confusion matrix.
prediction.pred_and_plt_confmat(
    model,
    test_dataloader=test_dataloader,
    class_names=class_names,
    device=DEVICE,
)