In [None]:
from torchvision import datasets, transforms, models
from torch import nn, optim
import torch.nn.functional as F
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
import shutil 
# Pick the best available accelerator: Apple-silicon MPS, then CUDA, then CPU.
# Hard-coding 'mps' made this cell fail on any machine without an MPS backend.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

CHECKPOINT_PATH = 'plantvillage_resnet50.pth'
NUM_CLASSES = 38  # must equal the number of labels in classifier_indices

# ResNet-50 backbone with a custom two-layer head ending in LogSoftmax.
# ImageNet weights are not downloaded: load_state_dict below is strict by
# default, so every parameter and buffer is overwritten by the checkpoint anyway.
model = models.resnet50(pretrained=False)
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 512)),
                                        ('relu', nn.ReLU()),
                                        ('fc2', nn.Linear(512, NUM_CLASSES)),
                                        ('output', nn.LogSoftmax(dim=1))]))
model.fc = classifier

# Load tensors directly onto the selected device (previously a second hard-coded 'mps').
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model_dict_loader = torch.load(CHECKPOINT_PATH, map_location=torch.device(device))
state_dict = model_dict_loader['state_dict']
model.load_state_dict(state_dict)
model.to(device)
model.eval();  # trailing ';' suppresses the (very long) model repr in the cell output

In [None]:
# Class-folder name -> output index of the classifier head (fc2 has 38 outputs).
# Each name's position in the list below is its index. The order looks like
# sorted folder-name order (as torchvision's ImageFolder assigns indices) --
# TODO confirm against the training run that produced the checkpoint.
classifier_indices = {
    label: position
    for position, label in enumerate([
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
        "Blueberry___healthy",
        "Cherry_(including_sour)___Powdery_mildew",
        "Cherry_(including_sour)___healthy",
        "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn_(maize)___Common_rust_",
        "Corn_(maize)___Northern_Leaf_Blight",
        "Corn_(maize)___healthy",
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
        "Orange___Haunglongbing_(Citrus_greening)",
        "Peach___Bacterial_spot",
        "Peach___healthy",
        "Pepper,_bell___Bacterial_spot",
        "Pepper,_bell___healthy",
        "Potato___Early_blight",
        "Potato___Late_blight",
        "Potato___healthy",
        "Raspberry___healthy",
        "Soybean___healthy",
        "Squash___Powdery_mildew",
        "Strawberry___Leaf_scorch",
        "Strawberry___healthy",
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ])
}

In [None]:
# Check the classifier's accuracy on the generated images.
# Evaluation preprocessing: resize, center-crop to 224x224, scale to [0, 1] and
# normalise with the standard ImageNet mean/std used by torchvision models
# (presumably the same preprocessing used at training time -- confirm).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def predict_image(image_path):
    """Classify one image file with the module-level ``model``.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        The predicted class index (an int, matching the values of
        ``classifier_indices``), or -1 if the file could not be opened or
        preprocessed.
    """
    try:
        # ``with`` closes the file handle deterministically. convert('RGB')
        # drops an alpha channel / expands grayscale so Normalize always
        # receives the 3 channels it was configured for.
        with Image.open(image_path) as img:
            tensor = transform(img.convert('RGB'))
    except Exception:
        # Best effort: unreadable or non-image files (e.g. stray metadata
        # files) yield the -1 sentinel instead of aborting the whole run.
        # Unlike a bare ``except:``, this still lets KeyboardInterrupt through.
        return -1
    batch = tensor.unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        outputs = model(batch)
    # Index of the highest log-probability for the single image in the batch.
    return int(outputs.argmax(dim=1).item())

In [None]:
# Classify each image in a class directory; when an image is predicted as that directory's class, optionally copy it to a new location.
def predict_image_dir(content_data_dir, verified_dir_path, is_save):
    """Classify every image in one class directory and count correct predictions.

    The directory's basename is the ground-truth label; it is looked up in
    ``classifier_indices`` and compared with ``predict_image``'s output.

    Args:
        content_data_dir: Directory holding the images of a single class. Its
            basename (a trailing path separator is tolerated) must be a key of
            ``classifier_indices``.
        verified_dir_path: Root directory that receives copies of correctly
            classified images under ``<verified_dir_path>/<class name>/``.
        is_save: If True, copy each correctly classified image there;
            if False, only count.

    Returns:
        ``(total, correct)``: the number of files that were successfully
        classified, and how many of them matched the directory's class.
        Sub-directories and files ``predict_image`` cannot read (it returns
        -1) are excluded from both counts.

    Raises:
        KeyError: if the directory's basename is not a key of
            ``classifier_indices``.
    """
    # normpath strips a trailing separator; basename is separator-agnostic,
    # unlike the previous rfind("/") slicing.
    class_name = os.path.basename(os.path.normpath(content_data_dir))
    expected_index = classifier_indices[class_name]
    target_dir = os.path.join(verified_dir_path, class_name)

    total = 0
    correct = 0
    # Sorted for a deterministic processing/copy order across runs.
    for image_file in sorted(os.listdir(content_data_dir)):
        image_path = os.path.join(content_data_dir, image_file)
        if not os.path.isfile(image_path):
            continue  # sub-directories are not images; don't count them
        result = predict_image(image_path)  # exactly one inference per image
        if result == -1:
            continue  # unreadable / non-image file: not part of the accuracy
        total += 1
        if result != expected_index:
            continue
        correct += 1
        if is_save:
            os.makedirs(target_dir, exist_ok=True)
            shutil.copy(image_path, target_dir)
    return total, correct

In [None]:
# Evaluate every class directory of the augmented dataset, copy correctly
# classified images into verified_path, and report per-class and overall accuracy.
total_correct = 0
total_images = 0
data_path = './nst-augmented-dataset'
verified_path = './classified-augmented-original'
augmented_dir_path_list = sorted(os.listdir(data_path))  # deterministic report order
for class_dir in augmented_dir_path_list:
    class_dir_path = os.path.join(data_path, class_dir)
    if not os.path.isdir(class_dir_path):
        continue
    if class_dir not in classifier_indices:
        # Not one of the model's labels, so there is no ground truth to compare against.
        print("Skipping unknown class directory : " + class_dir)
        continue
    total, correct = predict_image_dir(class_dir_path, verified_path, True)
    total_correct += correct
    total_images += total
    print("For Leaf type : " + class_dir)
    print("Correct = " + str(correct) + " Total = " + str(total))
    if total:
        accuracy = (correct / total) * 100
        print("Total accuracy : " + str(accuracy))
    else:
        # Guard against ZeroDivisionError for directories with no readable images.
        print("Total accuracy : n/a (no readable images)")

print("Total Images Augmented : " + str(total_images))
print("Total Images Correctly Augmented : " + str(total_correct))
if total_images:
    print("Overall accuracy : " + str((total_correct / total_images) * 100))