In [None]:
import json
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm

# Folder that holds inference artifacts; create it up front and tolerate
# the case where it is already present.
results_dir = './Simpsons/models'
os.makedirs(name=results_dir, exist_ok=True)

class CNN(nn.Module):
    """Four convolutional stages followed by a three-layer fully connected head.

    Every stage applies a 3x3 convolution (stride 1, padding 1), batch
    normalisation, ReLU and a 2x2 max-pool, so each stage halves the spatial
    size. ``fc1`` expects a flattened 512 x 8 x 8 feature map, which is what
    four halvings of a 128 x 128 input produce.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Sub-modules are registered as conv1/bn1 ... conv4/bn4, in that
        # order, so the state_dict keys are conv{i}.* and bn{i}.*.
        widths = (3, 64, 128, 256, 512)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))

        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        flattened_size = 512 * 8 * 8
        self.fc1 = nn.Linear(flattened_size, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        # Feature extractor: conv -> batch norm -> ReLU -> max-pool, four times.
        for stage in range(1, 5):
            conv = getattr(self, f"conv{stage}")
            norm = getattr(self, f"bn{stage}")
            x = self.pool(torch.relu(norm(conv(x))))

        # Collapse everything except the batch dimension.
        batch_size = x.size(0)
        x = x.view(batch_size, -1)

        # Two hidden layers, each followed by ReLU and dropout.
        for hidden in (self.fc1, self.fc2):
            x = self.dropout(torch.relu(hidden(x)))

        return self.fc3(x)

def find_images(data_dir):
    """Return the paths of all image files found anywhere below ``data_dir``.

    A file counts as an image when its name ends in ``.jpg``, ``.jpeg`` or
    ``.png``, compared case-insensitively. Paths are built by joining the
    directory reported by ``os.walk`` with the file name and are returned in
    walk order.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')
    return [
        os.path.join(folder, name)
        for folder, _, names in os.walk(data_dir)
        for name in names
        if name.lower().endswith(image_suffixes)
    ]

def infer(data_dir, model_path, classes_file='./Simpsons/model/classes.json', results_file='./Simpsons/models/results.json'):
    """Classify every image under ``data_dir`` and save the predictions as JSON.

    Args:
        data_dir: Root directory that is searched recursively with
            ``find_images``.
        model_path: Path to the saved ``state_dict`` of the 2-class ``CNN``.
        classes_file: JSON file whose decoded content is indexed with the
            predicted class index to obtain the label.
        results_file: Path of the JSON file the predictions are written to.

    Returns:
        A dict mapping each image's path relative to ``data_dir`` (written
        with forward slashes) to its predicted label. For images that sit
        directly in ``data_dir`` the key is simply the file name. Images
        that fail to process are reported on stdout and left out.
    """
    with open(classes_file, 'r', encoding='utf-8') as f:
        classes = json.load(f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN(num_classes=2).to(device)  # Use 2 classes for the binary classification model
    # NOTE(security): torch.load unpickles the checkpoint, so only load files
    # from a trusted source (consider weights_only=True if the installed
    # PyTorch version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # 128x128 inputs match the 512 * 8 * 8 feature size CNN.fc1 expects.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    results = {}
    image_files = find_images(data_dir)  # Get all images from subdirectories
    print(f"Found {len(image_files)} images to process...")

    for image_path in tqdm(image_files, desc="Processing Images"):
        # Deliberately best-effort: one unreadable image is reported and
        # skipped instead of aborting the whole run.
        try:
            # The context manager closes the file handle right away instead
            # of leaving it open until garbage collection.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            image = transform(image).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(image)

            predicted_class = classes[output.argmax(dim=1).item()]

            # Key by the path relative to data_dir rather than the bare file
            # name: the search is recursive, so identical names in different
            # subfolders would otherwise overwrite each other.
            result_key = os.path.relpath(image_path, data_dir).replace(os.sep, '/')
            results[result_key] = predicted_class

        except Exception as e:
            print(f"Error processing {image_path}: {e}")

    # Ensure the results directory exists before saving the results.
    # os.makedirs('') raises, so skip it when results_file is a bare file name.
    results_parent = os.path.dirname(results_file)
    if results_parent:
        os.makedirs(results_parent, exist_ok=True)

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"Inference complete! Processed {len(results)} images and saved results to '{results_file}'.")
    return results

# Classify the images under ./Simpsons/archive with the saved checkpoint;
# results_file is left at infer's default.
infer(
    data_dir='./Simpsons/archive',
    model_path='./Simpsons/model/model_best.pth',
    classes_file='./Simpsons/model/classes.json',
)


Found 16765 images to process...


Processing Images:   8%|▊         | 1373/16765 [00:32<06:14, 41.07it/s]