In [5]:
# predict_image.py

import argparse
from pathlib import Path
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from prepare_data import create_dataloaders

# ----------------------
# CONFIGURATION
# ----------------------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = 'resnet50_marburglens.pth'

# ----------------------
# TRANSFORM (must match validation transforms)
# ----------------------
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ----------------------
# FUNCTION: Load model
# ----------------------
def load_model(num_classes=29):
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    model.to(DEVICE)
    return model

# ----------------------
# FUNCTION: Predict image
# ----------------------
def predict_image(image_path):
    # Load model
    model = load_model()

    # Load classes
    train_loader, _, _ = create_dataloaders(batch_size=32)
    class_names = train_loader.dataset.classes

    # Check image exists
    image_path = Path(image_path)
    if not image_path.exists():
        print(f"❌ Image not found: {image_path}")
        return

    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # Prediction
    with torch.no_grad():
        outputs = model(input_tensor)
        _, preds = torch.max(outputs, 1)

    predicted_class = class_names[preds.item()]
    print(f"\n✅ Predicted Building: {predicted_class}\n")

# ----------------------
# MAIN
# ----------------------
if __name__ == '__main__':
    import sys
    if 'ipykernel' in sys.modules:
        # Running inside Jupyter
        predict_image('test_images/myphoto.jpg')  # put your image path here manually
    else:
        parser = argparse.ArgumentParser(description='Predict building from an image')
        parser.add_argument('--img', type=str, required=True, help='Path to the input image')
        args = parser.parse_args()
        predict_image(args.img)


  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


Datasets ready: 1212 train, 335 val, 205 test samples.

✅ Predicted Building: Universitätsstraße 25

