## Image Classification with PyTorch: Simple CNN Model

In this JupyterLab notebook, we will walk through the process of building and training a simple Convolutional Neural Network (CNN) model for image classification using PyTorch. Image classification is a fundamental task in computer vision, where the goal is to assign a label to an image based on its content.

### Objective
In this example, we use a dataset of images stored in the `dataset` folder, organized with one subfolder per category (the layout that torchvision's `ImageFolder` expects). The goal is to train a CNN model that can accurately classify images into their respective categories.
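`ImageFolder` numbers the categories by sorting the subfolder names and skipping plain files, so class indices depend only on the folder names. The sketch below mimics that rule with the standard library. `index_classes` is a hypothetical helper written for illustration (it is not a torchvision function), and the folder names are made up. It also shows why an unsorted `os.listdir()` listing, which can include stray files, is not a safe way to reproduce the model's class indices.

```python
import os
import tempfile

def index_classes(root):
    """Hypothetical helper: map each subfolder of `root` to an integer index,
    following the sorted-directory rule used by torchvision's ImageFolder."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}

# Build a throwaway folder tree with made-up category names
with tempfile.TemporaryDirectory() as root:
    for name in ["Panthera_leo", "Acinonyx_jubatus", "Apis_mellifera"]:
        os.makedirs(os.path.join(root, name))
    # A stray file that os.listdir() would also return
    open(os.path.join(root, ".DS_Store"), "w").close()

    mapping = index_classes(root)
    print(mapping)  # {'Acinonyx_jubatus': 0, 'Apis_mellifera': 1, 'Panthera_leo': 2}
```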

### Key steps
The code provided below covers the following steps:
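1. Importing the libraries and modules needed for data handling and deep learning with PyTorch.
2. Defining a basic CNN architecture as a subclass of `nn.Module`.
3. Preprocessing the image data with transformations and setting up a data loader.
4. Building a mapping from class indices to common animal names and saving it as `names_dict.json`.
5. Initializing the model, loss function, and optimizer for training.
6. Training the model on the image dataset for a specified number of epochs, or loading previously saved weights if a `.pth` file already exists.
7. Saving the trained model's state dictionary as a `.pth` file.
8. Loading the saved weights in a second cell and classifying a new image.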
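The `16 * 112 * 112` input size of the CNN's linear classifier is the flattened size of the feature map produced from a 224×224 image. The sketch below applies the usual output-size formula for convolution and pooling layers, floor((n + 2p - k) / s) + 1 (with dilation 1), to the layer settings used in this notebook. `conv_out` is a hypothetical helper for illustration, not a PyTorch function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Hypothetical helper: spatial output size of a Conv2d or MaxPool2d layer
    (dilation 1), computed as floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 224                                            # transforms.Resize((224, 224))
size = conv_out(size, kernel=3, stride=1, padding=1)  # Conv2d(3, 16, 3, padding=1) keeps 224
size = conv_out(size, kernel=2, stride=2)             # MaxPool2d(2, 2) halves it to 112
channels = 16                                         # out_channels of the Conv2d layer

flat_features = channels * size * size
print(size, flat_features)  # 112 200704
```

If the `Resize` target changed to 128×128, the same arithmetic gives 16 * 64 * 64 = 65536, so the `nn.Linear` layer would need to be updated to match.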

In [4]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Define a simple CNN model
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 3x224x224 -> 16x224x224
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),       # 16x224x224 -> 16x112x112
        )
        # The input size must equal the flattened feature map, which depends on
        # the 224x224 resize below; update it if the image size changes.
        self.classifier = nn.Linear(16 * 112 * 112, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # (batch, 16 * 112 * 112)
        x = self.classifier(x)
        return x

# Set up transformations and data loaders
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = ImageFolder(root="dataset", transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Map each class index to a common name. ImageFolder assigns indices to the
# sorted subfolder names, so use dataset.classes rather than os.listdir(),
# whose order is arbitrary. Folder names are scientific names; fall back to
# the folder name when translation.json has no entry for it.
with open("./animal_name_translation/translation.json") as f:
    class_translation = json.load(f)

names_dict = {idx: class_translation.get(folder, folder)
              for idx, folder in enumerate(dataset.classes)}

# JSON object keys are always strings, so index 0 is saved as "0";
# use json.load to read the mapping back.
with open("./animal_name_translation/names_dict.json", "w") as file:
    json.dump(names_dict, file)

# Initialize the model, loss function, and optimizer
num_classes = len(dataset.classes)
model = SimpleCNN(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model_path = "./animal_models/cnn_model_trained.pth"

if os.path.exists(model_path):
    # Load the trained model's parameters instead of retraining
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    print("Pre-trained model loaded.")
else:
    # Training loop
    num_epochs = 10
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        # Report the mean loss over the whole epoch, not just the last batch
        epoch_loss = running_loss / len(dataset)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # Save the trained model's state dictionary as a .pth file
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)

Pre-trained model loaded.
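The inference cell looks up predictions with `str(class_idx.item())` rather than the integer index. That is because JSON object keys are always strings: `json.dump` converts the integer keys of `names_dict` to strings, and `json.load` does not convert them back. A minimal standard-library illustration, using made-up entries:

```python
import json

# Integer keys, like the names_dict built in the training cell (entries made up)
names_dict = {0: "Cheetah", 1: "Plains Zebra"}

text = json.dumps(names_dict)
print(text)  # {"0": "Cheetah", "1": "Plains Zebra"}

restored = json.loads(text)
print(restored.get(0, "Unknown"))    # Unknown: the int key no longer exists
print(restored.get("0", "Unknown"))  # Cheetah

# Optional: rebuild integer keys so lookups can use the index directly
restored_int = {int(k): v for k, v in restored.items()}
print(restored_int[1])  # Plains Zebra
```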


In [3]:
import json
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# Load the class-index-to-common-name mapping written by the training cell
with open("./animal_name_translation/names_dict.json") as f:
    class_translation = json.load(f)

# Define the model (must match the architecture used for training)
class SimpleCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Linear(16 * 112 * 112, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Rebuild the model. This assumes names_dict.json has one entry per class;
# otherwise load_state_dict reports a shape mismatch for the classifier.
model = SimpleCNN(num_classes=len(class_translation))

# Load the model's trained weights and switch to evaluation mode
model.load_state_dict(torch.load("./animal_models/cnn_model_trained.pth", map_location="cpu"))
model.eval()

# Load and preprocess the input image with the same transforms used in training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# convert("RGB") guards against grayscale or RGBA images, which would not
# match the model's 3 input channels
input_image = Image.open("./animal_images/cheetah.jpg").convert("RGB")
input_tensor = transform(input_image).unsqueeze(0)  # add a batch dimension

# Perform inference
with torch.no_grad():
    output = model(input_tensor)

# Convert logits to probabilities and keep the top predictions
probabilities = nn.functional.softmax(output[0], dim=0)
k = min(10, probabilities.numel())  # torch.topk fails if k exceeds the class count
top_probabilities, top_class_indices = torch.topk(probabilities, k=k)

# Display top predictions with translated names
print(f"Top {k} predicted animals:")
for prob, class_idx in zip(top_probabilities, top_class_indices):
    # JSON keys are strings, so convert the index before looking it up
    class_name = class_translation.get(str(class_idx.item()), "Unknown")
    print(f"{class_name}: {prob.item() * 100:.1f}%")

Top 10 predicted animals:
Cheetah: 88.0%
Poison Dart Frog: 2.7%
Red-bellied Woodpecker: 2.0%
Common Lionfish: 1.4%
Western Honey Bee: 1.0%
Orchard Oriole: 0.9%
Mediterranean Fruit Fly: 0.7%
Scarlet Macaw: 0.5%
Eastern Tiger Swallowtail: 0.4%
Plains Zebra: 0.4%
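The percentages above come from applying softmax to the model's raw outputs (logits) and keeping the `k` largest values. The plain-Python sketch below reproduces that computation on a made-up logit vector so the arithmetic is easy to follow. `softmax` and `top_k` are illustrative helpers, not PyTorch functions. The sketch also caps `k` at the number of classes, since `torch.topk` raises an error when `k` is larger than its input.

```python
import math

def softmax(logits):
    """Illustrative helper: turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Illustrative helper: (probability, index) pairs for the k largest values."""
    k = min(k, len(probs))  # never ask for more classes than exist
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return [(p, i) for i, p in ranked[:k]]

logits = [2.0, 0.0, 0.0, -1.0]  # made-up scores for four classes
probs = softmax(logits)
results = top_k(probs, k=10)
for p, i in results:
    print(f"class {i}: {p * 100:.1f}%")
```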
