In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image

In [23]:
# Define the CNN architecture
class CNN(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images.

    Layout: two (3x3 conv -> ReLU -> 2x2 max-pool) stages followed by two
    fully connected layers. Each pooling stage halves the spatial size, so a
    32x32 input reaches the first linear layer as 64 feature maps of 8x8.

    Args:
        num_classes: Number of output logits. Defaults to 10, which keeps the
            layer shapes identical to the original fixed-size definition.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys),
        # so they must not be renamed.
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalized logits.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to 8x8
                after the two pooling stages (i.e. the input is not 32x32).
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, N),
        # this never folds surplus spatial elements into extra batch rows, so
        # a wrongly sized input fails loudly at fc1 instead of silently
        # producing the wrong number of predictions.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# CIFAR-10 class names. Position i is used as the label for model output
# index i, so the order presumably has to match the label order used when the
# checkpoint was trained -- TODO confirm against the training code.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Load the saved model
model = CNN()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine. weights_only=True restricts unpickling to tensors and
# plain containers, so a tampered checkpoint file cannot run arbitrary code.
model.load_state_dict(
    torch.load('cnn_cifar10.pth', map_location='cpu', weights_only=True)
)
# Switch to inference mode before making any predictions.
model.eval()

# Preprocess the custom image
def preprocess_image(image_path):
    """Load an image file and convert it into a single-image input batch.

    The image is forced to RGB, resized to 32x32, converted to a float tensor
    and normalized per channel with mean 0.5 and std 0.5.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        torch.Tensor: Tensor of shape (1, 3, 32, 32).

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        PIL.UnidentifiedImageError: If the file cannot be decoded as an image.
    """
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # The context manager closes the underlying file handle. convert('RGB')
    # fully decodes the pixels and guarantees exactly three channels, so
    # grayscale, palette and RGBA files (common for PNG/WebP) no longer break
    # the 3-channel Normalize and the model's first convolution.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')

    # unsqueeze(0) adds the leading batch dimension.
    return transform(rgb_image).unsqueeze(0)

# Make predictions
def predict_image(model, image_path):
    """Classify one image file, print the predicted label and return it.

    Args:
        model: Callable that maps the preprocessed image batch to a
            (1, n_classes) tensor of scores.
        image_path: Location of the image, forwarded to ``preprocess_image``.

    Returns:
        str: The entry of ``classes`` at the highest-scoring output index.
    """
    batch = preprocess_image(image_path)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        scores = model(batch)
        best = scores.max(dim=1).indices

    # The batch holds a single image, so its prediction is at position 0.
    predicted_class = classes[int(best[0])]
    print(f'Predicted Class: {predicted_class}')
    return predicted_class

  model.load_state_dict(torch.load('cnn_cifar10.pth'))


In [24]:
# Example usage: classify a custom image given by a relative file path.
# predict_image prints the label and also returns it; because the call is the
# last expression in the cell, the returned string is echoed as the cell's
# output as well, so the label appears twice.
image_path = 'frog.webp'
predict_image(model, image_path)

Predicted Class: frog


'frog'