In [1]:
import torch
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from glob import glob
import os
import pandas as pd
from torchvision import io
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from PIL import Image, ImageFont, ImageDraw 

from components.dataset import SoilDataset, SoilDataset_phone
from components.mymodel import get_model

In [2]:
def save_result(img:Image.Image, label:str, predict:str, picname:str):
    """Write the label and the prediction onto ``img`` and save it as ``picname``.

    The text is drawn in black near the top-left corner of the image. ``img``
    is modified in place before being saved.

    Args:
        img: Image to annotate.
        label: Ground-truth value; formatted into the text with an f-string,
            so any object with a sensible string form works.
        predict: Predicted value; formatted the same way as ``label``.
        picname: Destination path passed to ``img.save``.
    """
    font_size = 30
    try:
        # The TrueType font is looked up inside the installed OpenCV package.
        # That file is not guaranteed to exist (OpenCV may be missing, be an
        # old single-module build without ``__path__``, or ship without the
        # ``qt/fonts`` directory), so every failure mode falls back below.
        import cv2
        font_path = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
        font = ImageFont.truetype(font_path, size=font_size)
    except (ImportError, AttributeError, IndexError, OSError):
        # Pillow's built-in font is always available; the text may be smaller
        # than ``font_size`` but the annotation is still written.
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(img)
    draw.text((10, 10),f"Label:{label}",(0,0,0), font=font)
    # The second line is placed below the first using the requested font size.
    draw.text((10, 40 + font_size),f"Predict:{predict}",(0,0,0), font=font)
    img.save(picname)

In [3]:
def run_result(model_name:str, image_set:str):
    """Run a model over one image set and save an annotated copy of every image.

    For each sample yielded by the loader, the first element of the model
    output is taken as the prediction, the source image is re-opened from
    disk, and ``save_result`` writes the label and prediction onto it. The
    annotated images go to ``./result/<model_name>/<image_set>/``.

    Args:
        model_name: Name forwarded to ``get_model``; also used as a result
            sub-directory.
        image_set: Sub-directory of ``./dataset/phones`` to evaluate; also
            forwarded to ``get_model`` and used as a result sub-directory.
    """
    dataset_root = './dataset/phones'
    # Inference-time preprocessing must be deterministic, so no random
    # augmentation (e.g. random flips) is applied here.
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    model = get_model(model_name=model_name, image_set=image_set)
    # Assumes get_model returns a torch.nn.Module (it is called on a tensor
    # below) -- eval mode disables dropout and freezes batch-norm statistics.
    model.eval()
    dataset = SoilDataset_phone(dataset_root, image_set, transform=preprocess)
    # Order is irrelevant for a one-pass export, so keep it deterministic.
    loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1)
    result_path = os.path.join('./result/',model_name,image_set)
    os.makedirs(result_path, exist_ok=True)

    for X,y,pic_name in loader:
        # No gradients are needed for prediction.
        with torch.no_grad():
            yhat = model(X)[0][0]
        # batch_size is 1, so index 0 selects the only sample in the batch.
        file_name = pic_name[0]
        source_path = os.path.join(dataset_root, image_set, file_name)
        target_path = os.path.join(result_path, file_name)
        # The context manager closes the underlying file once it is saved.
        with Image.open(source_path) as img:
            save_result(img, y[0].numpy(), yhat.numpy(), target_path)

In [4]:
# Every immediate sub-directory of the phones dataset is one image set.
phones_top_level = next(os.walk('./dataset/phones/'))
image_set_list = phones_top_level[1]

# Model names that were switched between in this cell:
#   mobilenet_v3_large, resnet50, efficientnet_v2_l, alexnet
model_name = 'alexnet'

# Export annotated predictions for each image set in turn.
for image_set in image_set_list:
    run_result(model_name, image_set)

Found 775 images in ./dataset/phones/Apple.


KeyboardInterrupt: 