In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

print("✅ Libraries imported successfully!")

### Configuration

In [None]:
# Paths
INPUT_DIR = "./input"
MODEL_DIR = "./model"


def find_latest_model(model_dir, suffix='.pth'):
    """Return the path of the most recently modified checkpoint in ``model_dir``.

    Args:
        model_dir: Directory to search (non-recursive).
        suffix: File extension that identifies a checkpoint file.

    Returns:
        Full path of the newest regular file ending in ``suffix``, or ``None``
        when the directory does not exist or holds no matching file.
    """
    if not os.path.isdir(model_dir):
        return None
    candidates = [
        os.path.join(model_dir, name)
        for name in os.listdir(model_dir)
        if name.endswith(suffix) and os.path.isfile(os.path.join(model_dir, name))
    ]
    if not candidates:
        return None
    # os.listdir() order is arbitrary, so pick by modification time to get the
    # actual latest checkpoint; the path breaks mtime ties deterministically.
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))


# Find the latest model file
MODEL_PATH = find_latest_model(MODEL_DIR)
if MODEL_PATH:
    print(f"✅ Found model: {MODEL_PATH}")
else:
    print("❌ No model found in model/ directory")
    print("   Please copy a trained model (.pth file) to the model/ folder")

# Supported image extensions (matched against the lower-cased file suffix)
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']

### Load Model

In [None]:
def _build_classifier(arch, num_classes):
    """Create an untrained backbone with the notebook's Dropout + Linear head.

    The head layout must match the one used when the checkpoint was saved,
    otherwise ``load_state_dict`` fails on missing/unexpected keys.

    Args:
        arch: ``'resnet50'`` or ``'efficientnet_b3'``.
        num_classes: Number of outputs of the final Linear layer.

    Returns:
        The constructed ``nn.Module`` with randomly initialised weights.

    Raises:
        ValueError: If ``arch`` is not a supported architecture.
    """
    if arch == 'resnet50':
        net = models.resnet50(weights=None)
        in_features = net.fc.in_features
        net.fc = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(in_features, num_classes)
        )
    elif arch == 'efficientnet_b3':
        net = models.efficientnet_b3(weights=None)
        in_features = net.classifier[1].in_features
        net.classifier = nn.Sequential(
            nn.Dropout(0.4, inplace=True),
            nn.Linear(in_features, num_classes)
        )
    else:
        raise ValueError(f"Unknown model architecture: {arch}")
    return net


if MODEL_PATH and os.path.exists(MODEL_PATH):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️  Using device: {device}")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints that come from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=device)

    # Metadata saved next to the weights; each key falls back to a default
    # (ResNet-50, 224x224 input, ImageNet mean/std) when the checkpoint lacks it.
    CLASS_NAMES = checkpoint.get(
        'class_names',
        [f'Class_{i}' for i in range(checkpoint.get('num_classes', 10))],
    )
    NUM_CLASSES = checkpoint.get('num_classes', len(CLASS_NAMES))
    MODEL_ARCH = checkpoint.get('model_arch', 'resnet50')
    IMG_SIZE = checkpoint.get('img_size', (224, 224))
    NORMALIZE_MEAN = checkpoint.get('normalize_mean', [0.485, 0.456, 0.406])
    NORMALIZE_STD = checkpoint.get('normalize_std', [0.229, 0.224, 0.225])

    print("\n📊 Model Info:")
    print(f"   Architecture: {MODEL_ARCH}")
    print(f"   Classes: {NUM_CLASSES}")
    print(f"   Class Names: {CLASS_NAMES}")
    print(f"   Image Size: {IMG_SIZE}")

    # Fewer names than model outputs makes class_names[index] raise IndexError
    # during classification; surface the inconsistency up front.
    if len(CLASS_NAMES) != NUM_CLASSES:
        print(f"⚠️ Checkpoint has {len(CLASS_NAMES)} class names but num_classes={NUM_CLASSES}")

    # Rebuild the architecture, then restore the trained weights
    model = _build_classifier(MODEL_ARCH, NUM_CLASSES)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()  # disable dropout for inference

    print("\n✅ Model loaded successfully!")

    # Preprocessing applied to every input image before inference
    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
    ])
else:
    print("❌ Cannot load model. Please check the model path.")

### Classify Images

In [None]:
def classify_image(image_path, model, transform, device, class_names, top_k=3):
    """Classify a single image file.

    Args:
        image_path: Path to an image file that PIL can open.
        model: Classifier called on a batch of one image; its output is
            soft-maxed over dim 1 (presumably shape (1, num_classes) logits).
        transform: Callable turning an RGB PIL image into a tensor; a batch
            dimension is added before inference.
        device: Device the input tensor is moved to (should match the model).
        class_names: Sequence mapping class index -> display name.
        top_k: Maximum number of ranked predictions to return (default 3).

    Returns:
        Tuple ``(pred_class, pred_confidence, top_predictions, image)``:
        the top-1 class name, its probability in percent, a list of
        ``(class_name, percent)`` tuples in descending probability, and the
        RGB PIL image that was classified.
    """
    # Use a context manager so the file handle is closed; convert() returns a
    # fully loaded copy, so `image` stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidence, predicted = probabilities.max(1)

    pred_class = class_names[predicted.item()]
    pred_confidence = confidence.item() * 100

    # Clamp k so topk() never asks for more entries than the model outputs
    # or than there are names to label them with.
    k = min(top_k, probabilities.shape[1], len(class_names))
    top_probs, top_indices = probabilities.topk(k, dim=1)
    top_predictions = [
        (class_names[idx], prob * 100)
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    ]

    return pred_class, pred_confidence, top_predictions, image


if MODEL_PATH and os.path.exists(MODEL_PATH):
    # Collect supported images from the input folder. Sorting makes the output
    # order reproducible (os.listdir order is arbitrary); a missing folder is
    # treated the same as an empty one instead of raising FileNotFoundError.
    if os.path.isdir(INPUT_DIR):
        input_images = sorted(
            os.path.join(INPUT_DIR, name)
            for name in os.listdir(INPUT_DIR)
            if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
        )
    else:
        input_images = []

    if not input_images:
        print(f"⚠️ No images found in {INPUT_DIR}")
        print(f"   Supported formats: {IMAGE_EXTENSIONS}")
    else:
        print(f"📂 Found {len(input_images)} image(s) to classify\n")

        # Grid layout: at most 3 columns, as many rows as needed
        n_images = len(input_images)
        cols = min(3, n_images)
        rows = (n_images + cols - 1) // cols

        # squeeze=False always returns a 2-D array of Axes, so a single
        # ravel() handles one image, one row, and full grids alike.
        fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
        axes = axes.ravel()

        for ax, img_path in zip(axes, input_images):
            img_name = os.path.basename(img_path)
            try:
                pred_class, confidence, top3, image = classify_image(
                    img_path, model, transform, device, CLASS_NAMES
                )
            except OSError as err:
                # Unreadable/corrupt file (PIL's UnidentifiedImageError is an
                # OSError): report it and continue with the rest of the batch.
                ax.axis('off')
                ax.set_title(f"Could not read\n{img_name}", fontsize=12, color='red')
                print(f"📷 {img_name}")
                print(f"   ⚠️ Skipped: {err}")
                print()
                continue

            # Display image
            ax.imshow(image)
            ax.axis('off')

            # Title with prediction, coloured by confidence band
            title = f"Predicted: {pred_class}\nConfidence: {confidence:.1f}%"
            color = 'green' if confidence > 80 else 'orange' if confidence > 50 else 'red'
            ax.set_title(title, fontsize=12, fontweight='bold', color=color)

            # Print details
            print(f"📷 {img_name}")
            print(f"   Prediction: {pred_class} ({confidence:.1f}%)")
            print(f"   Top-3: {top3}")
            print()

        # Hide the unused trailing subplots of the last row
        for ax in axes[n_images:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

        print("✅ Classification complete!")

### Single Image Classification (Interactive)

In [None]:
# Classify a specific image
# Change this path to classify a specific image
SPECIFIC_IMAGE = "input/your_image.jpg"  # Change this

# Same model check as the batch cell, so a missing checkpoint is not
# misreported as a missing image.
model_ready = bool(MODEL_PATH) and os.path.exists(MODEL_PATH)

if model_ready and os.path.exists(SPECIFIC_IMAGE):
    pred_class, confidence, top3, image = classify_image(
        SPECIFIC_IMAGE, model, transform, device, CLASS_NAMES
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the image with its top-1 prediction
    axes[0].imshow(image)
    axes[0].axis('off')
    axes[0].set_title(f"Predicted: {pred_class} ({confidence:.1f}%)", fontsize=14)

    # Right panel: confidence bars for the ranked predictions
    classes = [name for name, _ in top3]
    probs = [prob for _, prob in top3]
    colors = ['green' if p > 80 else 'orange' if p > 50 else 'lightcoral' for p in probs]

    axes[1].barh(classes, probs, color=colors)
    # barh draws the first entry at the bottom; flip so the best guess is on top
    axes[1].invert_yaxis()
    axes[1].set_xlabel('Confidence (%)')
    # top3 may hold fewer than 3 entries when the model has fewer classes
    axes[1].set_title(f'Top-{len(top3)} Predictions')
    axes[1].set_xlim(0, 100)

    plt.tight_layout()
    plt.show()
elif model_ready:
    print(f"⚠️ Image not found: {SPECIFIC_IMAGE}")