In [1]:
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Pre-trained classifiers keyed by display name. Each constructor requests its
# DEFAULT torchvision weights, which are downloaded into the local torch hub
# cache the first time they are needed (see the download log in the output).
# NOTE(review): the CERTIFICATE_VERIFY_FAILED URLError in the output comes from
# the Python install lacking a CA bundle (presumably a macOS python.org build),
# not from this code -- fix the environment's certificates, then re-run.
model_names = dict(
    AlexNet=models.alexnet(weights='DEFAULT'),
    VGG16=models.vgg16(weights='DEFAULT'),
    ResNet50=models.resnet50(weights='DEFAULT'),
    InceptionV3=models.inception_v3(weights='DEFAULT', aux_logits=True),
    DenseNet121=models.densenet121(weights='DEFAULT'),
    MobileNetV2=models.mobilenet_v2(weights='DEFAULT'),
)

# Switch every network to inference mode; `train(False)` is exactly what
# `eval()` does internally.
for model in model_names.values():
    model.train(False)

# Define the image transformations.
# Both pipelines resize the shorter side, center-crop to the network's input
# size, convert to a tensor and normalize with the ImageNet channel statistics
# expected by torchvision's pre-trained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _make_transform(resize_size, crop_size):
    """Build a resize -> center-crop -> to-tensor -> normalize pipeline.

    Args:
        resize_size: length the image's shorter side is resized to before
            cropping (an int passed to ``transforms.Resize``).
        crop_size: side length of the square center crop fed to the model.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized float
        tensor of shape (C, crop_size, crop_size) -- C is 3 for RGB input.
    """
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# 224x224 crop: the standard ImageNet evaluation input size.
# NOTE(review): torchvision's own `weights.transforms()` for some DEFAULT
# weights (e.g. the IMAGENET1K_V2 weights of ResNet50 / MobileNetV2) resize to
# 232 rather than 256 -- verify, and consider per-model `weights.transforms()`
# if an exact match to the reference preprocessing matters.
transform_224 = _make_transform(256, 224)

# 299x299 crop for InceptionV3. torchvision's reference preprocessing for
# Inception_V3_Weights resizes the shorter side to 342 before cropping.
transform_299 = _make_transform(342, 299)

# Load the image files.
# os.listdir returns entries in arbitrary order, so sort them: the processing
# order, and therefore the row order of `results` and of the saved CSV, is then
# reproducible. The extension match is case-insensitive so files such as
# "photo.PNG" are not silently skipped, and directories whose name happens to
# end in ".png" are excluded because Image.open cannot read them.
image_dir = '/media'  # Update this to your images directory
image_files = sorted(
    f for f in os.listdir(image_dir)
    if f.lower().endswith('.png') and os.path.isfile(os.path.join(image_dir, f))
)
results = []  # one record per (image, model) pair, appended by the classification loop

# Fixed bar color for each model, so a given model is drawn in the same color
# in every per-image chart. Keys must match the model names used above.
colors = dict(
    AlexNet='blue',
    VGG16='orange',
    ResNet50='green',
    InceptionV3='red',
    DenseNet121='purple',
    MobileNetV2='cyan',
)

# Classify each image with every model and save one bar chart per image.

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('/home/outcome', exist_ok=True)

# Models whose pre-trained weights expect a 299x299 input; every other model
# (including MobileNetV2) is fed the 224x224 crop.
models_needing_299 = {'InceptionV3'}

# Class names indexed by output position, used to make the charts readable.
# NOTE(review): assumes all models above share the ImageNet-1k category list
# of ResNet50's DEFAULT weights -- confirm if a model/weights is added.
imagenet_categories = models.ResNet50_Weights.DEFAULT.meta['categories']

for image_file in image_files:
    image_path = os.path.join(image_dir, image_file)
    # Open the file once and close it right away: `convert` returns a new,
    # fully loaded image, so the file handle is not needed afterwards.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    image_224 = transform_224(rgb_image).unsqueeze(0)  # Add batch dimension
    image_299 = transform_299(rgb_image).unsqueeze(0)  # Add batch dimension

    model_predictions = {}

    for model_name, model in model_names.items():
        batch = image_299 if model_name in models_needing_299 else image_224
        with torch.no_grad():
            logits = model(batch)
            # Batch holds a single image: softmax over its class scores.
            class_probs = torch.nn.functional.softmax(logits[0], dim=0)

        # Get the top-1 prediction.
        top_prob, top_index = torch.max(class_probs, 0)
        class_index = top_index.item()
        model_predictions[model_name] = (class_index, top_prob.item())
        results.append({
            'image': image_file,
            'model': model_name,
            'predictions': model_predictions[model_name],
            'label': imagenet_categories[class_index],
        })

    # Horizontal bar chart of each model's top-1 probability; the predicted
    # class is shown next to the model name on the y axis.
    model_names_list = list(model_predictions)
    top_probs = [model_predictions[name][1] for name in model_names_list]
    tick_labels = [
        f'{name}: {imagenet_categories[model_predictions[name][0]]} '
        f'({model_predictions[name][0]})'
        for name in model_names_list
    ]
    y_pos = np.arange(len(model_names_list))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(y_pos, top_probs, align='center',
            color=[colors[name] for name in model_names_list])
    ax.set_yticks(y_pos)
    ax.set_yticklabels(tick_labels)
    ax.set_xlim(0, 1)
    ax.set_xlabel('Probability')
    ax.set_title(f'Predictions for {image_file}')

    # Save the chart under /home/outcome, then close this figure so memory
    # does not grow with the number of images.
    fig.tight_layout()
    fig.savefig(f'/home/outcome/predictions_{image_file}')
    plt.close(fig)

# Create a DataFrame with one row per (image, model) record.
results_df = pd.DataFrame(results)

# Save results to a CSV file. Create the output directory first so to_csv
# does not fail with FileNotFoundError on a machine where it is missing.
os.makedirs('/home/outcome', exist_ok=True)
results_df.to_csv('/home/outcome/classification_results.csv', index=False)

# Print the results. `predictions` holds a (class_index, probability) tuple,
# so unpack it rather than printing the whole tuple as the "class".
for row in results_df.itertuples(index=False):
    class_index, probability = row.predictions
    print(f"Image: {row.image}, Model: {row.model}, "
          f"Predicted Class: {class_index}, Probability: {probability:.4f}")


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /Users/mzkk/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1000)>