In [None]:
# ======================================================
# 1. Load Model & Classes (Safe Loader)
# ======================================================
import os, pickle, torch
import torch.nn as nn
from torchvision import models
from google.colab import files

model_path = "/content/best_model.pth"
classes_path = "/content/classes.pkl"


def _ensure_file(path):
    """Ensure ``path`` exists, asking the user to upload it via Colab if not.

    ``files.upload()`` returns a ``{filename: bytes}`` dict. The entry whose
    name matches ``os.path.basename(path)`` is written to ``path`` explicitly,
    so the file lands where the loaders below look for it even if the
    notebook's working directory is not the directory of ``path``.

    Raises:
        FileNotFoundError: the user did not upload a file with the expected
            basename.
    """
    if os.path.exists(path):
        return
    name = os.path.basename(path)
    print(f"📥 {name} not found. Please upload it.")
    uploaded = files.upload()
    if name not in uploaded:
        raise FileNotFoundError(f"❌ {name} not provided")
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as fh:
        fh.write(uploaded[name])
    print(f"✅ {name} uploaded successfully")


# --- Make sure both artifacts are present (prompts for upload if missing) ---
_ensure_file(model_path)
_ensure_file(classes_path)

# --- Load classes ---
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load a classes.pkl you produced yourself; a JSON/text list would be safer.
with open(classes_path, "rb") as f:
    classes = pickle.load(f)
print("📂 Classes loaded:", classes)

# --- Setup model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using device: {device}")

model = models.resnet50(weights=None)  # same arch used in training
# Output layer sized to the number of class labels loaded above.
model.fc = nn.Linear(model.fc.in_features, len(classes))
# The checkpoint is a plain state_dict. weights_only=True restricts unpickling
# to tensors and primitive containers and refuses arbitrary Python objects.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: freezes batch-norm statistics, disables dropout

print("✅ Model and classes loaded successfully!")

In [None]:
# ======================================================
# 2. Upload & Test Images
# ======================================================
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms

import io

# Preprocessing (must match training time)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

print("📤 Please upload one or more images to test:")
uploaded = files.upload()  # {filename: raw bytes}
if not uploaded:
    print("⚠️ No images uploaded.")

for filename, data in uploaded.items():
    try:
        # Decode from the uploaded bytes (no dependence on the cwd) and close
        # the source handle; convert() returns an independent, loaded copy.
        with Image.open(io.BytesIO(data)) as src:
            image = src.convert("RGB")
        # Add a batch dimension of 1 before moving to the model's device.
        input_tensor = transform(image).unsqueeze(0).to(device)

        # Predict: index of the highest-scoring output for the single image.
        with torch.no_grad():
            outputs = model(input_tensor)
        pred_class = classes[outputs.argmax(dim=1).item()]

        # Show result. The title avoids emoji: matplotlib's default font has
        # no glyph for them and would draw a box plus a UserWarning.
        plt.figure(figsize=(5, 5))
        plt.imshow(image)
        plt.axis("off")
        plt.title(f"Predicted: {pred_class}", fontsize=14, color="red")
        plt.show()

        print(f"✅ Prediction for {filename}: {pred_class}")

    except Exception as e:
        # Best-effort per file: report and move on to the next upload.
        print(f"❌ Error processing {filename}: {e}")