In [5]:
import torch
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import os




In [8]:

# Set device (use GPU if available, otherwise fallback to CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define U-Net model
ENCODER = 'efficientnet-b7'
CLASSES = ['solar_panel']
# smp applies this activation inside the model, so its outputs are already
# probabilities -- do not apply another sigmoid downstream.
ACTIVATION = 'sigmoid'
IN_CHANNELS = 4  # 4 for all bands; preprocessing must produce this many channels

model = smp.Unet(
    in_channels=IN_CHANNELS,
    encoder_name=ENCODER,
    encoder_weights=None,  # No pretraining, since we are loading trained weights
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# Move model to device
model = model.to(device)

# Load trained weights
weights_path = os.path.join(os.path.expanduser("~"), "satellite-ml-solarp-detection","models", "weights", "u-net_efficientnet-b7_v1", "unet-seed23_weights.pth")

# Fail loudly: continuing past a missing file would leave the model with
# random weights and every later prediction would be meaningless.
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"Error: Weights file not found! ({weights_path})")

# NOTE(review): torch.load unpickles the file, so only load weights from a
# trusted source; on PyTorch versions that support it, weights_only=True
# restricts loading to tensors/state dicts.
model.load_state_dict(torch.load(weights_path, map_location=device))
print("Model weights loaded successfully.")

# Set model to evaluation mode
model.eval()

# Define preprocessing function
def preprocess_image(image_path, in_channels=4, size=(256, 256), mean=None, std=None):
    """Load an image and turn it into a batched float tensor for the U-Net.

    The U-Net in this notebook is built with ``in_channels=4``. The image's
    bands are therefore kept as-is instead of being forced to 3-channel RGB,
    which would make the first convolution fail with a channel mismatch.

    Args:
        image_path: Path to an image readable by PIL, or an already-loaded
            ``numpy.ndarray`` of shape (H, W) or (H, W, C).
        in_channels: Number of bands fed to the model. The first
            ``in_channels`` bands are kept; an image with fewer bands raises.
        size: ``(height, width)`` to resize to; adjust to the model's
            expected input size.
        mean, std: Optional per-band normalisation statistics (one value per
            kept band), applied after scaling. Give both or neither.

    Returns:
        ``torch.float32`` tensor of shape ``(1, in_channels, size[0], size[1])``.

    Raises:
        ValueError: if the image rank is not 2 or 3, the image has fewer
            than ``in_channels`` bands, or only one of ``mean``/``std`` is given.

    NOTE(review): integer pixels are scaled to [0, 1] by their dtype's max.
    Confirm this (and any mean/std) matches the preprocessing used in
    training, which this notebook does not show.

    NOTE(review): PIL cannot decode every multi-band GeoTIFF (e.g. 16-bit
    4-band). If opening fails, read the bands with a raster library and pass
    the resulting array in instead.
    """
    if (mean is None) != (std is None):
        raise ValueError("Provide both mean and std, or neither")

    if isinstance(image_path, np.ndarray):
        bands = image_path
    else:
        # np.array copies the pixel data, so the file can be closed right away.
        with Image.open(image_path) as image:
            bands = np.array(image)

    if bands.ndim == 2:
        bands = bands[..., np.newaxis]  # single-band image -> (H, W, 1)
    if bands.ndim != 3:
        raise ValueError(f"Expected an (H, W) or (H, W, C) image, got shape {bands.shape}")
    if bands.shape[-1] < in_channels:
        raise ValueError(
            f"Image has {bands.shape[-1]} band(s) but the model expects {in_channels}"
        )
    bands = bands[..., :in_channels]

    scale = float(np.iinfo(bands.dtype).max) if np.issubdtype(bands.dtype, np.integer) else 1.0
    tensor = torch.from_numpy(bands.astype(np.float32) / scale)
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
    tensor = torch.nn.functional.interpolate(
        tensor, size=tuple(size), mode="bilinear", align_corners=False
    )

    if mean is not None:
        mean_t = torch.as_tensor(mean, dtype=torch.float32).view(1, -1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).view(1, -1, 1, 1)
        tensor = (tensor - mean_t) / std_t
    return tensor

# Define inference function
def predict(image_path):
    """Run the segmentation model on one image and plot it beside the predicted mask.

    Uses the notebook-level ``model``, ``device`` and ``preprocess_image``.

    Args:
        image_path: Path to the image to segment.

    Returns:
        numpy.ndarray: the model output with the batch and class dimensions
        squeezed away (2-D for a single-class model fed one image).
    """
    input_tensor = preprocess_image(image_path).to(device)
    with torch.no_grad():
        # The model is built with activation='sigmoid', so its output is
        # already a probability. Applying torch.sigmoid again would squash
        # every value into roughly [0.5, 0.73] and wash out the mask.
        mask = model(input_tensor).squeeze().cpu().numpy()

    # Read pixels for display and close the file immediately.
    with Image.open(image_path) as image:
        preview = np.array(image)
    # matplotlib draws a 4-band array as RGBA (band 4 as transparency), so
    # show only the first three bands.
    # NOTE(review): assumes bands 1-3 are R, G, B -- confirm for this imagery.
    if preview.ndim == 3 and preview.shape[-1] > 3:
        preview = preview[..., :3]

    # Show original and predicted mask
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(preview)
    ax[0].set_title("Original Image")
    # Fixed 0-1 scale so a low-confidence mask is not auto-contrasted into
    # looking like a confident one.
    ax[1].imshow(mask, cmap="gray", vmin=0.0, vmax=1.0)
    ax[1].set_title("Predicted Mask")
    for axis in ax:
        axis.axis("off")
    plt.show()
    return mask

# Example usage: point the path below at a real test tile, then re-run this cell.
sample_image_path = "path/to/test/image.tif"  # Change this path
if not os.path.exists(sample_image_path):
    print("Sample image not found. Please provide a valid path.")
else:
    predict(sample_image_path)


Using device: cpu
Model weights loaded successfully.
Sample image not found. Please provide a valid path.
