Loading the saved model weights into the model architecture:

In [1]:
import torch
import sys
sys.path.append("../src/models/")

from load_model import load_resnet50_model


# Build the network architecture; load_resnet50_model is imported from
# ../src/models/load_model.py and is expected to return the un-trained network.
model = load_resnet50_model()

# Read the saved checkpoint onto the CPU so the notebook also runs without a GPU.
# NOTE(review): torch.load unpickles the file -- only open checkpoints you trust.
model_path = "../models/ResNet50-Plant-model-Final.pth"
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

if isinstance(checkpoint, torch.nn.Module):
    # The file contains a whole pickled model: use it as-is.
    model = checkpoint
else:
    # The file contains a state_dict: copy the weights into the architecture
    # built above instead of replacing the model with a plain dictionary.
    model.load_state_dict(checkpoint)

# Inference mode: Dropout is disabled and BatchNorm uses its stored statistics.
model.eval()

Evaluate Model Performance:

1. Iterate over the validation dataset and record the model's performance.

In [2]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Root directory of the validation images read by ImageFolder.
VALID_DATA_PATH = '../.kaggle/valid'

# Per-image preprocessing: force a 224x224 size, then convert to a tensor.
valid_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

valid_data = ImageFolder(VALID_DATA_PATH, transform=valid_transform)

# Batches of 32 in a fixed (unshuffled) order, loaded by 4 worker processes.
valid_loader = DataLoader(
    valid_data,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=True,
)


Estimating the time to process the validation dataset (optional):

In [7]:
import time

# Rough estimate (in seconds) of a full validation pass, extrapolated from the
# first few batches.
num_batches_to_profile = 10

# Eval mode so the timing matches real evaluation and the forward passes do not
# update BatchNorm running statistics.
model.eval()

batches_timed = 0
# perf_counter is monotonic, so the interval is not affected by clock changes.
# The interval starts before the first batch is fetched, so it includes the
# one-off cost of starting the loader (presumably worker start-up; the
# extrapolation is therefore likely an over-estimate).
start_time = time.perf_counter()
with torch.no_grad():
    for i, (images, labels) in enumerate(valid_loader):
        if i >= num_batches_to_profile:
            break
        outputs = model(images)
        batches_timed += 1
end_time = time.perf_counter()

# Average over the batches that actually ran: the loader may hold fewer than
# num_batches_to_profile batches, and may be empty.
avg_time_per_batch = (end_time - start_time) / batches_timed if batches_timed else 0.0
total_batches = len(valid_loader)
estimated_total_time = avg_time_per_batch * total_batches
print(estimated_total_time)

11421.464232206345


In [6]:
# Top-1 accuracy of `model` over every batch in `valid_loader`.
correct = 0
total = 0

# Switch to inference behaviour (Dropout off, BatchNorm uses stored statistics);
# without this the reported accuracy does not reflect the trained model.
model.eval()

with torch.no_grad():  # No need to calculate gradients for evaluation
    for images, labels in valid_loader:
        outputs = model(images)
        # Index of the highest score along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on the validation dataset: {accuracy:.2f}%")

KeyboardInterrupt: 