In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory


data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

image_datasets = {x: datasets.ImageFolder(os.path.join('dataset', x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

class_names = image_datasets['train'].classes
print(class_names)

['ants', 'bees']


In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def resnet_detect(images_path, model_name=None):
    
    if not model_name:
        print(f"Model name is not provided, use default model")
        model_path = "models/resnet_epoch_3.pt"
    else:
        model_path = os.path.join("models", model_name)
        if not os.path.exists(model_path):
            print(f"{model_name} does not exist, use default model")
            model_path = "models/resnet_epoch_3.pt"
        
    # 加载模型
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, len(class_names))
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    model.eval()
    
    # Todo: 加 log
    print('Model loaded')
    
    # 加载图像
    res = []
   
    all_images = os.listdir(images_path)
    img_count = len(all_images)
    for idx, img_name in enumerate(all_images):
        # Todo: 加日志 idx/img_count 

        image_path = os.path.join(images_path, img_name)
        image = Image.open(image_path)
        image = data_transforms['val'](image).unsqueeze(0).to(device)
    
        # 预测
        outputs = model(image)
        _, preds = torch.max(outputs, 1)

        if preds:
            label = class_names[preds.item()]
        else:
            label = "unknown"

        res.append({"img_name": img_name, "class": {label}})
        
        # Todo: 加日志 log  f"{img}" 处理已完成，预测类别为 {label}
        print(f'The file {img_name} predicted class is: {label}')     

        
    print("res", res)
    
    # Todo: 结果保存到 MongoDB 
    
if __name__ == "__main__":
    images_path = 'dataset/val/bees'
    res = resnet_detect(image_path)


Model name is not provided, use default model
Model loaded
The file 1297972485_33266a18d9.jpg predicted class is: bees
The file 2321144482_f3785ba7b2.jpg predicted class is: bees
The file 144098310_a4176fd54d.jpg predicted class is: bees
The file 1328423762_f7a88a8451.jpg predicted class is: bees
The file 26589803_5ba7000313.jpg predicted class is: bees
The file 1032546534_06907fe3b3.jpg predicted class is: bees
The file 1519368889_4270261ee3.jpg predicted class is: unknown
The file 348291597_ee836fbb1a.jpg predicted class is: bees
The file 151603988_2c6f7d14c7.jpg predicted class is: bees
The file 2501530886_e20952b97d.jpg predicted class is: bees
The file 215512424_687e1e0821.jpg predicted class is: bees
The file 2525379273_dcb26a516d.jpg predicted class is: unknown
The file 1355974687_1341c1face.jpg predicted class is: bees
The file 2815838190_0a9889d995.jpg predicted class is: bees
The file 65038344_52a45d090d.jpg predicted class is: bees
The file 353266603_d3eac7e9a0.jpg predicted

['1297972485_33266a18d9.jpg',
 '2321144482_f3785ba7b2.jpg',
 '144098310_a4176fd54d.jpg',
 '1328423762_f7a88a8451.jpg',
 '26589803_5ba7000313.jpg',
 '1032546534_06907fe3b3.jpg',
 '1519368889_4270261ee3.jpg',
 '348291597_ee836fbb1a.jpg',
 '151603988_2c6f7d14c7.jpg',
 '2501530886_e20952b97d.jpg',
 '215512424_687e1e0821.jpg',
 '2525379273_dcb26a516d.jpg',
 '1355974687_1341c1face.jpg',
 '2815838190_0a9889d995.jpg',
 '65038344_52a45d090d.jpg',
 '353266603_d3eac7e9a0.jpg',
 '57459255_752774f1b2.jpg',
 '177677657_a38c97e572.jpg',
 '152789693_220b003452.jpg',
 '2509402554_31821cb0b6.jpg',
 '603709866_a97c7cfc72.jpg',
 '2470492902_3572c90f75.jpg',
 '372228424_16da1f8884.jpg',
 '44105569_16720a960c.jpg',
 '2668391343_45e272cd07.jpg',
 '2407809945_fb525ef54d.jpg',
 '149973093_da3c446268.jpg',
 '1181173278_23c36fac71.jpg',
 '759745145_e8bc776ec8.jpg',
 '2685605303_9eed79d59d.jpg',
 '2809496124_5f25b5946a.jpg',
 '2103637821_8d26ee6b90.jpg',
 '2741763055_9a7bb00802.jpg',
 '2745389517_250a397f31.jpg',