<a href="https://colab.research.google.com/github/SandroMuradashvili/CNN/blob/main/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================================
# NOTEBOOK 2: inference.ipynb
# Simpsons Character Classification - Inference Pipeline
# ============================================================================

# ============================================================================
# STEP 1: Import Required Libraries
# ============================================================================
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os
import json

print("✓ Libraries imported")

✓ Libraries imported


In [2]:

# ============================================================================
# STEP 2: Define CNN Architecture (Same as training)
# ============================================================================
class SimpleCNN(nn.Module):
    """Four-stage convolutional classifier for Simpsons character images.

    Every stage is a run of ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU``
    units followed by a 2x2 max-pool, so height and width are halved four
    times. The classifier head expects ``256 * 8 * 8`` flattened features,
    i.e. a 128x128 input image (128 / 2**4 == 8).

    Args:
        num_classes: Number of output logits produced by the last linear layer.
    """

    @staticmethod
    def _stage(channels):
        """Build one stage: conv/BN/ReLU per consecutive channel pair, then max-pool."""
        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Layer order inside the Sequential fixes the state_dict key indices,
        # so it must stay conv, bn, relu, ..., pool.
        return nn.Sequential(*layers)

    def __init__(self, num_classes=41):
        super().__init__()

        # Feature extractor: channel widths 3 -> 32 -> 64 -> 128 -> 256.
        self.conv1 = self._stage((3, 32, 32))
        self.conv2 = self._stage((32, 64, 64))
        self.conv3 = self._stage((64, 128, 128))
        self.conv4 = self._stage((128, 256))

        # Classifier head with dropout before each linear layer.
        hidden_units = 512
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 8 * 8, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        features = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = stage(features)
        batch_size = features.size(0)
        flat = features.view(batch_size, -1)  # Flatten all but the batch dim
        return self.fc(flat)

print("✓ CNN architecture defined")


✓ CNN architecture defined


In [3]:
# ============================================================================
# STEP 3: Define Inference Function
# ============================================================================
def infer(data_dir, model_path):
    """
    Load trained model and generate predictions for all images in data_dir.

    The directory is walked recursively in sorted order; files ending in
    .jpg/.jpeg/.png (case-insensitive) whose names do not start with '.' are
    classified one at a time. Images that fail to load or classify are
    reported and skipped. Predictions are also written to 'results.json' in
    the current working directory.

    Args:
        data_dir (str): Path to directory containing test images
        model_path (str): Path to saved model file (.pth). The checkpoint must
            contain 'model_state_dict' plus either 'class_names' or
            'num_classes'.

    Returns:
        dict: Dictionary mapping filename to predicted class name. If images
            in different subdirectories share a filename, the first one (in
            sorted traversal order) keeps the bare filename and later ones
            are keyed by their path relative to data_dir, so no prediction
            is silently overwritten.

    Raises:
        FileNotFoundError: If data_dir or model_path does not exist.
        ValueError: If the checkpoint has no class information, or no images
            are found under data_dir.
    """

    print(f"\n{'='*70}")
    print("INFERENCE PIPELINE")
    print(f"{'='*70}")
    print(f"Data directory: {data_dir}")
    print(f"Model path: {model_path}")

    # Check if paths exist
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Data directory not found: {data_dir}")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load model checkpoint
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints that come from a trusted source.
    print("\nLoading model...")
    checkpoint = torch.load(model_path, map_location=device)

    # Extract model info
    if 'class_names' in checkpoint:
        class_names = checkpoint['class_names']
        num_classes = len(class_names)
    elif 'num_classes' in checkpoint:
        num_classes = checkpoint['num_classes']
        # If class_names not saved, create generic ones
        class_names = [f"class_{i}" for i in range(num_classes)]
    else:
        raise ValueError("Model checkpoint doesn't contain class information!")

    print(f"✓ Model loaded with {num_classes} classes")

    # Initialize model
    model = SimpleCNN(num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()  # disable dropout and use BatchNorm running statistics
    print("✓ Model ready for inference")

    # Define image transforms (same as validation)
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Get all image files from data_dir (directly or in subdirectories)
    print(f"\nScanning directory: {data_dir}")
    image_extensions = ('.jpg', '.jpeg', '.png')
    image_files = []

    for root, dirs, files in os.walk(data_dir):
        # Sorting dirs in place makes os.walk descend in a stable order, so
        # results (and results.json) are reproducible across filesystems.
        dirs.sort()
        for file in sorted(files):
            if file.lower().endswith(image_extensions) and not file.startswith('.'):
                image_files.append(os.path.join(root, file))

    if len(image_files) == 0:
        raise ValueError(f"No images found in {data_dir}")

    print(f"✓ Found {len(image_files)} images to process")

    # Run inference on all images
    results = {}
    print("\nRunning inference...")

    with torch.no_grad():
        for idx, img_path in enumerate(image_files):
            try:
                # Load and preprocess image; the context manager releases the
                # file handle as soon as the RGB copy has been made.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
                image_tensor = transform(image).unsqueeze(0).to(device)

                # Get prediction
                output = model(image_tensor)
                predicted = output.argmax(dim=1)
                predicted_class = class_names[predicted.item()]

                # Store result with just filename (not full path). When the
                # same filename was already stored for another file, fall
                # back to the path relative to data_dir instead of
                # overwriting the earlier prediction.
                filename = os.path.basename(img_path)
                key = filename
                if key in results:
                    key = os.path.relpath(img_path, data_dir)
                    print(f"  Warning: Duplicate filename {filename}; stored as {key}")
                results[key] = predicted_class

                # Print progress every 100 images
                if (idx + 1) % 100 == 0:
                    print(f"  Processed {idx + 1}/{len(image_files)} images...")

            except Exception as e:
                # Best effort: report the unreadable/failed image and keep going.
                print(f"  Warning: Failed to process {img_path}: {str(e)}")
                continue

    print(f"✓ Inference complete! Processed {len(results)} images")

    # Save results to JSON
    output_file = 'results.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"✓ Results saved to: {output_file}")
    print(f"{'='*70}\n")

    return results

print("✓ Inference function defined")

✓ Inference function defined


In [4]:
# ============================================================================
# STEP 4: Run Inference (Example Usage)
# ============================================================================
# Example: Run inference on test data
# Modify these paths according to your setup

# Path to test data directory (will be provided by professor)
# Directory holding the images to classify (will be provided by professor).
TEST_DATA_DIR = './test_data'  # Change this to actual test directory

# Trained model checkpoint to load.
MODEL_PATH = 'simpson_model.pth'  # Or 'simpson_model_best.pth'

if not os.path.exists(TEST_DATA_DIR):
    # No test data yet: print setup instructions instead of running inference.
    setup_help = (
        f"\n⚠️  Test data directory not found: {TEST_DATA_DIR}",
        "This is normal if you haven't created test data yet.",
        "\nTo use this notebook:",
        "1. Set TEST_DATA_DIR to your test images directory",
        "2. Make sure MODEL_PATH points to your trained model",
        "3. Run all cells",
        "\nThe infer() function is ready to use!",
    )
    for help_line in setup_help:
        print(help_line)
else:
    print("Running inference on test data...")
    results = infer(TEST_DATA_DIR, MODEL_PATH)

    # Show at most the first five predictions as a sanity check.
    print("\nSample predictions:")
    preview = list(results.items())[:5]
    for image_name, label in preview:
        print(f"  {image_name} → {label}")

    print(f"\nTotal predictions: {len(results)}")
    print("Results saved to 'results.json'")

# ============================================================================
# STEP 5: Verification
# ============================================================================
# Closing summary: what the notebook provides and how to run it.
rule = "=" * 70
summary_lines = [
    "\n" + rule,
    "INFERENCE NOTEBOOK READY!",
    rule,
    "\nThis notebook provides:",
    "  ✓ infer(data_dir, model_path) function",
    "  ✓ Automatic results.json generation",
    "  ✓ Support for GPU acceleration",
    "  ✓ Progress tracking during inference",
    "\nTo use:",
    "  1. Update TEST_DATA_DIR to point to your test images",
    "  2. Update MODEL_PATH if needed",
    "  3. Run all cells",
    rule,
]
# One print of the joined lines emits the same text as printing each line.
print("\n".join(summary_lines))


⚠️  Test data directory not found: ./test_data
This is normal if you haven't created test data yet.

To use this notebook:
1. Set TEST_DATA_DIR to your test images directory
2. Make sure MODEL_PATH points to your trained model
3. Run all cells

The infer() function is ready to use!

INFERENCE NOTEBOOK READY!

This notebook provides:
  ✓ infer(data_dir, model_path) function
  ✓ Automatic results.json generation
  ✓ Support for GPU acceleration
  ✓ Progress tracking during inference

To use:
  1. Update TEST_DATA_DIR to point to your test images
  2. Update MODEL_PATH if needed
  3. Run all cells
