In [86]:
# imports
import matplotlib.pyplot as plt
import modules.cosmos_functions as cf
import numpy as np
import torch
import torch.nn.functional as F 
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18, ResNet18_Weights
from modules.MyAQLclass import MyAQLclass


In [87]:
# functions to use in the tests 

# functions to display images
def reverse_normalize(image):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    image = image.clone()
    for i in range(3):
        image[i] = (image[i] * std[i]) + mean[i]
    return image

def show_batch(test_d):
    # Get the first batch of data from the DataLoader
    data_test = next(iter(test_d))

    # Retrieve the first tensor and its corresponding label
    image_test = data_test[0][0]
    label_test = data_test[1][0]

    # Reverse the normalization of the image
    image_test = reverse_normalize(image_test)

    # Convert the image tensor to a NumPy array and transpose the dimensions
    np_image_test = image_test.permute(1, 2, 0).numpy()

    # Display the image
    plt.imshow(np_image_test)
    plt.title(f'{label_test}, {image_test.shape}')
    plt.axis('off')

    # Show the plot
    plt.show()

# function to test the model
def test_model(model, datasetPath):
    model.eval()    


    # Load the test dataset
    dataset_path = datasetPath
    transform = T.Compose([
        T.Resize((128, 128)),
        T.ToTensor()
                
])
    
    #pull the relevant AQL data
    AQLtest = MyAQLclass()
    lotsize = AQLtest.get_lotsize()
    test_inspection_lvl = AQLtest.get_test_inspection_lvl()
    batch_size = AQLtest.batch_size()

    dataset = ImageFolder(dataset_path, transform=transform)
    test_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    labels_dict = dataset.class_to_idx

    # Track the overall test accuracy and accuracy by each type of apple
    overall_correct = 0
    overall_total = 0
    normal_correct = 0
    normal_total = 0
    abnormal_correct = 0
    abnormal_total = 0

    # Initialize the confusion matrix
    num_classes = len(labels_dict)
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=int)

    # Iterate over the test dataset
    # Iterate over the test dataset
    for batch_idx, (images, labels) in enumerate(test_dataloader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)

        # Get predictions
        _, predicted = torch.max(outputs.data, 1)

        # Update accuracy counts
        overall_correct += (predicted == labels).sum().item()
        overall_total += labels.size(0)

        # Calculate accuracy for normal apples vs. abnormal apples
        normal_mask = labels == labels_dict['Normal_Apple']
        abnormal_mask = ~normal_mask
        normal_correct += (predicted[normal_mask] == labels[normal_mask]).sum().item()
        normal_total += normal_mask.sum().item()
        abnormal_correct += (predicted[abnormal_mask] == labels[abnormal_mask]).sum().item()
        abnormal_total += abnormal_mask.sum().item()

        # Update the confusion matrix
        for true_label, predicted_label in zip(labels.cpu().numpy(), predicted.cpu().numpy()):
            confusion_matrix[true_label][predicted_label] += 1
        
        # Break the loop after processing the first batch
        if batch_idx == 0:
            break

    # Calculate overall accuracy
    overall_accuracy = overall_correct / overall_total

    # Calculate accuracy for normal apples and abnormal apples separately
    normal_accuracy = normal_correct / normal_total if normal_total != 0 else 0.0
    abnormal_accuracy = abnormal_correct / abnormal_total if abnormal_total != 0 else 0.0

    # Print overall accuracy
    print(f"Overall accuracy: {overall_accuracy:.4f}")

    # Print accuracy for normal apples and abnormal apples separately
    print(f"Normal Apple accuracy: {normal_accuracy:.4f}")
    print(f"Abnormal Apple accuracy: {abnormal_accuracy:.4f}")

    # Print the confusion matrix
    print()
    print(labels_dict)
    print("Confusion Matrix:")
    print(confusion_matrix)

  

    # get the AQL label
    rejected_apples = np.sum(confusion_matrix)-np.sum(confusion_matrix[1])
    
    
    AQLtest.test_input = rejected_apples
    x=AQLtest.output()
    
    print(f'From a lot of {lotsize} in accordance quality level {test_inspection_lvl},')
    print(f'a batch of {batch_size} has been randomly drawn.')
    print(f'the number of rejected apples is: {rejected_apples}')
    print(f'The AQL label is: Class_{x}')


In [88]:
# set the device

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
x = torch.ones(1, device=device)

print(f"Device is '{device}' Thus a tensor will look like this: {x}")

Device is 'mps' Thus a tensor will look like this: tensor([1.], device='mps:0')


In [89]:
# add the resnet18 model for testing the test dataset with the trained model
# import the resnet18 model

weights = ResNet18_Weights.DEFAULT  #weights=ResNet18_Weights.IMAGENET1K_V1 is the current default
model = resnet18(weights=weights)  

# freeze the model parameters
for param in model.parameters():
    param.requires_grad = False
    
# change the last layer of the model to fit the number of classes in the dataset
model.fc = nn.Linear(512, 4)
    
# change the last layer of the model to fit the number of classes in the dataset
model.fc = nn.Linear(512, 4)



In [90]:

# imported_model_path = "../storage/data/generated/20230605-134750_pinky_acc.pt"  # high accuracy
# imported_model_path = cf.load_pth('20230605_160852_pinky')  # issues; WIP
# imported_model_path 


In [91]:
# Load the test dataset
dataset_path = "../storage/images/apple_resized_128/Test"

# import the model state
imported_model_state_path = "../storage/data/generated/20230612-151403_pinky_loss.pt"   # test to 128x128 accuracy

# load the model state into the model
model_state_import_path = imported_model_state_path
model.load_state_dict(torch.load(model_state_import_path))
model.to(device)





test_model(model, dataset_path)

Overall accuracy: 0.7500
Normal Apple accuracy: 0.1429
Abnormal Apple accuracy: 0.9200

{'Blotch_Apple': 0, 'Normal_Apple': 1, 'Rot_Apple': 2, 'Scab_Apple': 3}
Confusion Matrix:
[[ 9  0  0  2]
 [ 4  1  1  1]
 [ 0  0 10  0]
 [ 0  0  0  4]]
From a lot of 500 in accordance quality level I,
a batch of 32 has been randomly drawn.
the number of rejected apples is: 25
The AQL label is: Class_Rejected
