In [5]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import numpy as np

# Class labels, ordered by the classifier's output units: output unit i is
# reported as class_names[i].
class_names = [
    'Bread', 'Dairy product', 'Dessert', 'Egg', 'Fried food', 'Meat',
    'Noodles/Pasta', 'Rice', 'Seafood', 'Soup', 'Vegetable/Fruit',
]

# Build the ResNet34 architecture WITHOUT downloading ImageNet weights: every
# parameter and buffer is overwritten by the (strict) load_state_dict below, so
# `pretrained=True` only cost a network download (and that argument is
# deprecated in torchvision >= 0.13). Calling the constructor with no weight
# argument works on both old and new torchvision versions.
model = models.resnet34()
in_features = model.fc.in_features
# Replace the default 1000-way ImageNet head with one output per class. Sizing
# it from class_names keeps the head and the label list from drifting apart.
model.fc = nn.Linear(in_features, len(class_names))

# Load the fine-tuned weights, mapping all tensors to the CPU so this also runs
# on machines without a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust; on torch >= 1.13 consider passing weights_only=True.
model.load_state_dict(torch.load('../../models/classification/simple_model_weights.pth', map_location=torch.device('cpu')))
# Put layers such as BatchNorm into inference mode (use running statistics).
model.eval()

# Inference preprocessing: resize the shorter side to 256 px, take the central
# 224x224 crop, convert to a float tensor in [0, 1], then normalize per channel
# with the ImageNet mean/std.
# NOTE(review): presumably this mirrors the training-time preprocessing -- confirm.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to predict the class of an input image
def predict_image(image_path, model, transform, labels=None):
    """Classify a single image file and return its top-scoring class.

    Args:
        image_path: Path to the image file to classify.
        model: Classifier producing one score per class for a batch of images.
            This function does not call ``model.eval()``; the caller should.
        transform: Callable turning a PIL RGB image into a (C, H, W) tensor,
            e.g. the module-level ``transform``.
        labels: Optional sequence of class names indexed by output unit.
            Defaults to the module-level ``class_names``.

    Returns:
        Tuple ``(predicted_class_index, predicted_class_name)``.
    """
    if labels is None:
        labels = class_names
    # Force three channels: grayscale, palette and RGBA (e.g. PNG) images would
    # otherwise break the 3-channel Normalize step and the network input.
    # The context manager closes the underlying file; convert() loads the
    # pixel data, so the result stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    batch = transform(rgb_image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        outputs = model(batch)
    predicted_class_index = outputs.argmax(dim=1).item()
    return predicted_class_index, labels[predicted_class_index]

# Example usage: classify one of the sample images shipped with the README
# assets and report both the numeric class index and its label.
image_path = '../../assets/readme_assets/examples/rice_example.jpg'
(predicted_class_index,
 predicted_class_name) = predict_image(image_path, model, transform)
print(
    f'Predicted class index: {predicted_class_index}',
    f'Predicted class name: {predicted_class_name}',
    sep='\n',
)


Predicted class index: 7
Predicted class name: Rice
