In [1]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import pandas as pd
import os
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
import warnings

warnings.filterwarnings("ignore")

In [3]:
def _load_model(checkpoint_path, device, num_classes):
    """Build MobileNetV3-Large with a `num_classes`-way head, load the checkpoint, and return it in eval mode."""
    # No pretrained download: the strict load_state_dict below overwrites every
    # parameter and buffer, so ImageNet weights would be discarded anyway.
    model = mobilenet_v3_large(weights=None)
    model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
    # map_location lets a checkpoint saved on CUDA load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()
    return model


def _preprocess(bgr_image, size):
    """Turn an OpenCV BGR image into a (1, 3, size, size) float tensor with values scaled by 1/255."""
    rgb = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (size, size))
    # HWC -> CHW. NOTE(review): no mean/std normalization is applied here;
    # confirm this matches the preprocessing used at training time.
    chw = np.transpose(rgb, (2, 0, 1)) / 255
    return torch.from_numpy(chw).float().unsqueeze(0)


def inference(csv_path="train-1.csv", image_dir="data/train_images",
              checkpoint_path="trained_models/mobilenet_v3/last.pt",
              output_path=None, image_size=224):
    """Classify each image listed in a CSV and write the predictions back as a `predict` column.

    For every row, the image ``<image_dir>/<first column>.png`` is classified
    by a fine-tuned MobileNetV3-Large into one of the categories 0-4. Rows
    whose image file does not exist reuse the value of the row's second column
    (presumably the ground-truth label — TODO confirm against the CSV schema).

    Args:
        csv_path: CSV whose first column holds the image file stem.
        image_dir: directory containing the ``.png`` images.
        checkpoint_path: checkpoint dict with the state_dict under ``"model"``.
        output_path: destination CSV; defaults to ``csv_path`` (overwrites the input).
        image_size: side length images are resized to before inference.

    Raises:
        ValueError: if an existing image file cannot be decoded by OpenCV.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    categories = [0, 1, 2, 3, 4]
    # Loaded lazily, once, so rows without images never require the checkpoint.
    model = None

    test_pd = pd.read_csv(csv_path)
    predict = []
    for index in range(len(test_pd)):
        img_path = os.path.join(image_dir, "{}.png".format(test_pd.iloc[index, 0]))
        if not os.path.exists(img_path):
            # No image for this row: fall back to the CSV's second column.
            predict.append(test_pd.iloc[index, 1])
            continue

        print(img_path)
        ori_image = cv2.imread(img_path)
        if ori_image is None:
            # cv2.imread signals unreadable/corrupt files by returning None.
            raise ValueError("cv2 could not read image: {}".format(img_path))
        image = _preprocess(ori_image, image_size).to(device)

        if model is None:
            model = _load_model(checkpoint_path, device, len(categories))

        with torch.no_grad():
            # Explicit dim=1: softmax over the class axis of the (1, C) logits.
            prob = torch.softmax(model(image), dim=1)

        max_value, max_index = torch.max(prob, dim=1)
        label = categories[int(max_index[0])]
        print(max_value[0], max_index[0], label)
        print("----------------------------------")
        predict.append(label)

    test_pd["predict"] = predict
    test_pd.to_csv(csv_path if output_path is None else output_path, index=False)

if __name__ == "__main__":
    # An IPython kernel runs cells with __name__ == "__main__", so executing
    # this cell (or the exported script) triggers the batch inference.
    inference()
    

data/train_images\000c1434d8d7.png
tensor(0.7655, device='cuda:0') tensor(2, device='cuda:0') 2
----------------------------------
data/train_images\001639a390f0.png
tensor(0.9482, device='cuda:0') tensor(1, device='cuda:0') 1
----------------------------------
data/train_images\0024cdab0c1e.png
tensor(0.5966, device='cuda:0') tensor(1, device='cuda:0') 1
----------------------------------
data/train_images\002c21358ce6.png
tensor(0.9937, device='cuda:0') tensor(0, device='cuda:0') 0
----------------------------------
data/train_images\005b95c28852.png
tensor(0.9812, device='cuda:0') tensor(2, device='cuda:0') 2
----------------------------------
data/train_images\0097f532ac9f.png
tensor(0.9278, device='cuda:0') tensor(1, device='cuda:0') 1
----------------------------------
data/train_images\00a8624548a9.png
tensor(0.9999, device='cuda:0') tensor(2, device='cuda:0') 2
----------------------------------
data/train_images\00b74780d31d.png
tensor(0.7633, device='cuda:0') tensor(1, device