In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os
import random

In [14]:
# --- Device Configuration ---
# Prefer the Apple-Silicon GPU (MPS backend) when PyTorch was built with it and it is available; otherwise CPU.
device = torch.device("mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built() else "cpu")
print(f"Using {'M1 GPU (MPS backend)' if device.type == 'mps' else 'CPU'} for inference")

# --- Dataset Configuration ---
# CLASS_NAMES must be in alphabetical order to match training (ImageFolder ordering):
# index i of the model output corresponds to CLASS_NAMES[i].
CLASS_NAMES = ['Basalt', 'Coal', 'Granite', 'Limestone', 'Marble', 'Quartzite', 'Sandstone']
assert CLASS_NAMES == sorted(CLASS_NAMES), "CLASS_NAMES must be sorted to match ImageFolder class indices"
# Derived from CLASS_NAMES so the model head size and the label list cannot drift apart.
NUM_CLASSES = len(CLASS_NAMES)
DATA_DIR = './Dataset_Augmented/'

# --- Model Setup ---
def load_fine_tuned_model(model_path: str, num_classes: "int | None" = None) -> models.ResNet:
    """
    Load the fine-tuned ResNet-50 model and prepare it for inference.

    The architecture is rebuilt without pretrained weights, its final fully
    connected layer is replaced to match the training setup, and the saved
    state dict is loaded into it.

    Args:
        model_path (str): Path to the saved model state dict.
        num_classes (int, optional): Output size of the fc head. Defaults to
            the module-level ``NUM_CLASSES``.

    Returns:
        models.ResNet: ResNet-50 on ``device``, in eval mode, with gradients disabled.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the checkpoint's keys/shapes do not match the architecture.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    model = models.resnet50(weights=None)
    # Match the fc layer structure used during training (single Linear layer, no Dropout)
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_classes)

    # Load tensors onto the CPU first: they are copied into the (CPU-resident) model
    # anyway, and the whole model is moved to `device` once afterwards.
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code. PyTorch versions that
    # predate this keyword raise TypeError and fall back to a plain load.
    try:
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    except TypeError:
        state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference only: no parameter needs gradients (the training-time choice of
    # which layers to unfreeze is irrelevant here).
    model.requires_grad_(False)
    model = model.to(device)
    model.eval()
    return model

# --- Image Preprocessing ---
def get_preprocessing_transforms() -> transforms.Compose:
    """
    Build the inference-time image preprocessing pipeline (intended to match training).

    The pipeline resizes to a fixed 224x224, converts the PIL image to a tensor,
    and normalizes each channel with the standard ImageNet statistics used by
    torchvision's pretrained models.

    Returns:
        transforms.Compose: The composed preprocessing transform.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline_steps)

# --- Random Image Selection ---
def get_random_image_path(data_dir: str, rng: "random.Random | None" = None) -> tuple[str, str]:
    """
    Select a random image from a random (non-empty) subcategory in the data directory.

    A subcategory is chosen uniformly among those that contain at least one image,
    then an image is chosen uniformly within it. Directory listings are sorted so
    that a seeded ``rng`` gives the same pick regardless of filesystem ordering.

    Args:
        data_dir (str): Path to the data directory containing one subdirectory per class.
        rng (random.Random, optional): Random generator to draw from. Defaults to
            the global ``random`` module.

    Returns:
        tuple: (path to the random image, name of the subcategory).

    Raises:
        ValueError: If ``data_dir`` has no subdirectories, or none of them contain images.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    chooser = rng if rng is not None else random

    subcategories = sorted(
        d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
    )
    if not subcategories:
        raise ValueError(f"No subcategories found in {data_dir}")

    # Only consider classes that actually contain images, so an empty class folder
    # cannot make the call fail while other classes still have usable data.
    images_by_subcategory = {}
    for subcategory in subcategories:
        subcategory_path = os.path.join(data_dir, subcategory)
        image_files = sorted(
            f for f in os.listdir(subcategory_path)
            if f.lower().endswith(image_extensions)
            and os.path.isfile(os.path.join(subcategory_path, f))
        )
        if image_files:
            images_by_subcategory[subcategory] = image_files

    if not images_by_subcategory:
        raise ValueError(f"No images found in any subcategory of {data_dir}")

    selected_subcategory = chooser.choice(list(images_by_subcategory))
    selected_image = chooser.choice(images_by_subcategory[selected_subcategory])
    image_path = os.path.join(data_dir, selected_subcategory, selected_image)

    return image_path, selected_subcategory

# --- Image Classification ---
def classify_image(image_path: str, model: models.ResNet, verbose: bool = True) -> tuple[str, float, np.ndarray]:
    """
    Classify an image and return the predicted class, confidence, and probabilities.

    The image is converted to RGB, preprocessed with ``get_preprocessing_transforms``,
    and passed through ``model`` as a batch of one on ``device``.

    Args:
        image_path (str): Path to the image file.
        model (models.ResNet): Fine-tuned ResNet-50 model; should already be in eval
            mode (as returned by ``load_fine_tuned_model``).
        verbose (bool): If True (default), print the prediction and per-class scores.

    Returns:
        tuple: (predicted class name, confidence percentage as a float,
        softmax probabilities for all classes in ``CLASS_NAMES`` order).

    Raises:
        ValueError: If the model's number of outputs differs from ``len(CLASS_NAMES)``.
    """
    # Use a context manager so the file handle PIL keeps open after Image.open()
    # is released once convert() has decoded the pixels.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    preprocess = get_preprocessing_transforms()
    image_tensor = preprocess(image).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        predicted_class_idx = torch.argmax(outputs, dim=1).item()

    # zip() below would silently truncate on a size mismatch, so fail loudly instead.
    if probabilities.shape[0] != len(CLASS_NAMES):
        raise ValueError(
            f"Model produced {probabilities.shape[0]} outputs but "
            f"{len(CLASS_NAMES)} class names are defined"
        )

    predicted_class = CLASS_NAMES[predicted_class_idx]
    # Convert from a numpy scalar to a plain float to match the annotated return type.
    predicted_confidence = float(probabilities[predicted_class_idx] * 100)

    if verbose:
        print(f"\nPredicted Class: {predicted_class}")
        print(f"Confidence: {predicted_confidence:.2f}%")
        print("\nConfidence Scores for All Classes:")
        for class_name, prob in zip(CLASS_NAMES, probabilities):
            print(f"{class_name}: {prob * 100:.2f}%")

    return predicted_class, predicted_confidence, probabilities

Using M1 GPU (MPS backend) for inference


In [15]:
# --- Main Execution ---
if __name__ == "__main__":
    # Load the fine-tuned model
    MODEL_PATH = 'finetuned_model_resnet50.pth'
    model = load_fine_tuned_model(MODEL_PATH)

    # Select a random image from the dataset
    try:
        random_image_path, subcategory = get_random_image_path(DATA_DIR)
        print(f"\nSelected Image Path: {random_image_path}")
        print(f"Subcategory (Ground Truth): {subcategory}")

        # Classify the random image
        predicted_class, confidence, probabilities = classify_image(random_image_path, model)

        # Compare prediction with ground truth
        print(f"\nPrediction Matches Ground Truth: {predicted_class.lower() == subcategory.lower()}")
    except (ValueError, OSError) as e:
        # Expected data problems (missing/empty dataset folders, unreadable image files)
        # are reported briefly; any other exception is a bug and is left to propagate
        # with its full traceback instead of being swallowed.
        print(f"Error: {e}")


Selected Image Path: ./Dataset_Augmented/Sandstone/313.jpg
Subcategory (Ground Truth): Sandstone

Predicted Class: Sandstone
Confidence: 99.82%

Confidence Scores for All Classes:
Basalt: 0.00%
Coal: 0.01%
Granite: 0.00%
Limestone: 0.05%
Marble: 0.12%
Quartzite: 0.00%
Sandstone: 99.82%

Prediction Matches Ground Truth: True


## Suggestions for Future Work

Based on the results and the loss plot, here are potential next steps:
- **Address Overfitting:** The gap between training and validation loss suggests overfitting. Add regularization techniques such as weight decay in the optimizer, or apply dropout to the fc layer.
- **Further Fine-Tuning:** Some classes (e.g., Quartzite, Basalt) show limited improvement. Unfreeze additional layers (e.g., layer3) with an even smaller learning rate to adapt more mid-level features.
- **Data Augmentation:** Validation loss plateaus, indicating that the model may need more diverse training data. Extend the training transforms with `RandomRotation` and `RandomAffine` to introduce more variability.
- **Class Imbalance or Data Quality:** Marble still lags, despite improving after fine-tuning and data augmentation. Collect more Marble samples or use a class-weighted loss.
- **Hyperparameter Tuning:** Because of limited GPU resources, this notebook does not include hyperparameter tuning. It is nonetheless a valuable next step: optimizing learning rates, regularization, and scheduler settings could boost performance further.

## Conclusion

Fine-tuning layer4 improved the model’s performance, particularly for challenging classes like Marble, by adapting deeper features. However, overfitting and plateauing validation loss suggest room for improvement through regularization, further fine-tuning, and data augmentation. Future work should focus on enhancing generalization and addressing remaining class-specific challenges.