In [6]:
# import torch
# import torch.nn as nn
# from torchvision.models import resnet18
# from torch.utils.data import DataLoader
# from torchvision.transforms import transforms
# from torchvision.datasets import ImageFolder
# from pathlib import Path

# # Constants
# IMG_SIZE = 224
# NUM_CLASSES = 10
# MODEL_FOLDER_PATH = Path.cwd()  # Use current working directory
# model_path = Path("model_resnet18_20240322_125552.pt")  # Assuming model.pt is in the current directory

# # Define neural network architecture
# class Net(nn.Module):
#     def __init__(self, backbone, num_classes):
#         super(Net, self).__init__()
#         # Resnet Backbone (includes avg pooling layer, takes off last FC layer)
#         self.features = nn.Sequential(*list(backbone.children())[:-1])
#         self.out = nn.Linear(backbone.fc.in_features, num_classes)

#     def forward(self, x):
#         x = self.features(x)
#         x = torch.flatten(x, 1)
#         x = self.out(x)
#         return x

# # Load the model
# model = Net(resnet18(pretrained=True), NUM_CLASSES)
# model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# model.eval()

# # Define data preprocessing transformation
# transform = transforms.Compose([
#     transforms.Resize((IMG_SIZE, IMG_SIZE)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# # Define test dataset and DataLoader
# test_data_path = "/mnt/d/FY2023/DataSets/milVehs/dataset/test"
# test_dataset = ImageFolder(test_data_path, transform=transform)
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# # Define the test function
# def test(model, test_loader):
#     model.eval()
#     correct = 0
#     total = 0
#     with torch.no_grad():
#         for images, labels in test_loader:
#             outputs = model(images)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#     accuracy = 100 * correct / total
#     return accuracy

# # Perform predictions on the test dataset and display accuracy
# test_accuracy = test(model, test_loader)
# print(f"Test Accuracy: {test_accuracy:.2f}%")


In [12]:
import torch
from torchvision.models import resnet18
from torchvision import transforms
from PIL import Image
import json

# Define the predict_dict function
def predict_dict(json_file, test_image_path, model):
    # Load the JSON file
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    # Define the transformation to apply to the images
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # Open and preprocess the test image
    test_image = Image.open(test_image_path).convert('RGB')
    test_image = preprocess(test_image)
    test_image = test_image.unsqueeze(0)  # Add batch dimension
    
    predictions = {}
    
    # Iterate through each entry in the JSON data
    for entry in data:
        image_path = entry['image_path']
        label = entry['label']
        
        # Perform prediction
        with torch.no_grad():
            output = model(test_image)
            _, predicted = torch.max(output, 1)
            predicted_label = predicted.item()
        
        # Store the prediction result
        predictions[image_path] = {
            'true_label': label,
            'predicted_label': predicted_label
        }
    
    return predictions

# Load the fine-tuned model
model_path = "model_resnet18_20240322_125552.pt"
model = resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)  # Modify the last layer to match the number of classes
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

# Set the model to evaluation mode
model.eval()

# Specify the path to the JSON file
json_file = "test_data_milVeh.json"

# Specify the path to the test image
test_image_path = "/mnt/d/FY2023/DataSets/milVehs/testImgs/tank4.jpg"

# Predict using the updated JSON file and test image
predictions = predict_dict(json_file, test_image_path, model)
print(predictions)


RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.0.weight", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.downsample.0.weight", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer4.0.conv1.weight", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.conv2.weight", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv1.weight", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.conv2.weight", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "features.0.weight", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.1.num_batches_tracked", "features.4.0.conv1.weight", "features.4.0.bn1.weight", "features.4.0.bn1.bias", "features.4.0.bn1.running_mean", "features.4.0.bn1.running_var", "features.4.0.bn1.num_batches_tracked", "features.4.0.conv2.weight", "features.4.0.bn2.weight", "features.4.0.bn2.bias", "features.4.0.bn2.running_mean", "features.4.0.bn2.running_var", "features.4.0.bn2.num_batches_tracked", "features.4.1.conv1.weight", "features.4.1.bn1.weight", "features.4.1.bn1.bias", "features.4.1.bn1.running_mean", "features.4.1.bn1.running_var", "features.4.1.bn1.num_batches_tracked", "features.4.1.conv2.weight", "features.4.1.bn2.weight", "features.4.1.bn2.bias", "features.4.1.bn2.running_mean", "features.4.1.bn2.running_var", "features.4.1.bn2.num_batches_tracked", "features.5.0.conv1.weight", "features.5.0.bn1.weight", "features.5.0.bn1.bias", "features.5.0.bn1.running_mean", "features.5.0.bn1.running_var", "features.5.0.bn1.num_batches_tracked", "features.5.0.conv2.weight", "features.5.0.bn2.weight", "features.5.0.bn2.bias", "features.5.0.bn2.running_mean", "features.5.0.bn2.running_var", "features.5.0.bn2.num_batches_tracked", "features.5.0.downsample.0.weight", "features.5.0.downsample.1.weight", "features.5.0.downsample.1.bias", "features.5.0.downsample.1.running_mean", "features.5.0.downsample.1.running_var", "features.5.0.downsample.1.num_batches_tracked", "features.5.1.conv1.weight", "features.5.1.bn1.weight", "features.5.1.bn1.bias", "features.5.1.bn1.running_mean", "features.5.1.bn1.running_var", "features.5.1.bn1.num_batches_tracked", "features.5.1.conv2.weight", "features.5.1.bn2.weight", "features.5.1.bn2.bias", "features.5.1.bn2.running_mean", "features.5.1.bn2.running_var", "features.5.1.bn2.num_batches_tracked", "features.6.0.conv1.weight", "features.6.0.bn1.weight", "features.6.0.bn1.bias", "features.6.0.bn1.running_mean", "features.6.0.bn1.running_var", "features.6.0.bn1.num_batches_tracked", "features.6.0.conv2.weight", "features.6.0.bn2.weight", "features.6.0.bn2.bias", "features.6.0.bn2.running_mean", "features.6.0.bn2.running_var", "features.6.0.bn2.num_batches_tracked", "features.6.0.downsample.0.weight", "features.6.0.downsample.1.weight", "features.6.0.downsample.1.bias", "features.6.0.downsample.1.running_mean", "features.6.0.downsample.1.running_var", "features.6.0.downsample.1.num_batches_tracked", "features.6.1.conv1.weight", "features.6.1.bn1.weight", "features.6.1.bn1.bias", "features.6.1.bn1.running_mean", "features.6.1.bn1.running_var", "features.6.1.bn1.num_batches_tracked", "features.6.1.conv2.weight", "features.6.1.bn2.weight", "features.6.1.bn2.bias", "features.6.1.bn2.running_mean", "features.6.1.bn2.running_var", "features.6.1.bn2.num_batches_tracked", "features.7.0.conv1.weight", "features.7.0.bn1.weight", "features.7.0.bn1.bias", "features.7.0.bn1.running_mean", "features.7.0.bn1.running_var", "features.7.0.bn1.num_batches_tracked", "features.7.0.conv2.weight", "features.7.0.bn2.weight", "features.7.0.bn2.bias", "features.7.0.bn2.running_mean", "features.7.0.bn2.running_var", "features.7.0.bn2.num_batches_tracked", "features.7.0.downsample.0.weight", "features.7.0.downsample.1.weight", "features.7.0.downsample.1.bias", "features.7.0.downsample.1.running_mean", "features.7.0.downsample.1.running_var", "features.7.0.downsample.1.num_batches_tracked", "features.7.1.conv1.weight", "features.7.1.bn1.weight", "features.7.1.bn1.bias", "features.7.1.bn1.running_mean", "features.7.1.bn1.running_var", "features.7.1.bn1.num_batches_tracked", "features.7.1.conv2.weight", "features.7.1.bn2.weight", "features.7.1.bn2.bias", "features.7.1.bn2.running_mean", "features.7.1.bn2.running_var", "features.7.1.bn2.num_batches_tracked", "out.weight", "out.bias". 