In [3]:
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import torch.nn.functional as F
import os

# Set device: prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the saved model.
# The ResNet-50 architecture is built WITHOUT pretrained weights (weights=None).
# load_state_dict below is strict by default and overwrites every parameter and
# buffer, so downloading the ImageNet checkpoint first would be wasted work.
model_path = 'streak_detector.pth'
model = resnet50(weights=None)
num_features = model.fc.in_features
model.fc = torch.nn.Linear(num_features, 2)  # Replace the classification head with a 2-class head
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
# On torch >= 1.13 consider passing weights_only=True.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # Switch BatchNorm layers to inference behaviour

# Define the transformation applied to every input image.
# NOTE(review): this should match the preprocessing used at training time — confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Fixed 224x224 input size
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet normalization
])

# Prediction function
def predict_image(image_path, model, transform):
    """
    Predict the class of a single image using the trained model.

    The input tensor is placed on the same device as the model's parameters,
    so the function does not depend on any module-level device setting.

    Args:
        image_path (str): Path to the PNG image file.
        model: Loaded PyTorch model with a 2-way output. It should already be
            in eval mode; this function does not call ``model.eval()``.
        transform: Image transformation function. It is expected to return a
            (C, H, W) tensor; a batch dimension is added here.

    Returns:
        Tuple: Prediction label (str), probability of streak, probability of no streak.
    """
    # Run on whichever device the model's weights live on.
    model_device = next(model.parameters()).device

    # Close the underlying file promptly. convert() returns a new, fully
    # loaded image, which remains valid after the `with` block exits.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    batch = transform(image).unsqueeze(0).to(model_device)  # Add batch dimension

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        logits = model(batch)
    probabilities = F.softmax(logits, dim=1)  # Convert logits to probabilities

    # NOTE(review): assumes output index 0 is "streak" and index 1 is "no_streak".
    # If training used torchvision's ImageFolder, classes are sorted alphabetically
    # ("no_streak" before "streak") and this order would be reversed — confirm.
    prob_streak, prob_no_streak = probabilities[0].tolist()
    # Ties resolve to "no_streak".
    prediction = "streak" if prob_streak > prob_no_streak else "no_streak"
    return prediction, prob_streak, prob_no_streak

# Example usage: set image_path to the PNG file you want to classify.
image_path = 'C:/Users/Administrator/Documents/GitHub/satellite-detecting/Data/Nowy folder/tic13.png'  # Replace with your PNG file path
# isfile (rather than exists) so a directory path is reported here instead of
# raising inside Image.open.
if os.path.isfile(image_path):
    prediction, prob_streak, prob_no_streak = predict_image(image_path, model, transform)
    print(f"Prediction: {prediction}")
    print(f"Probability of streak: {prob_streak:.2f}")
    print(f"Probability of no streak: {prob_no_streak:.2f}")
else:
    print(f"Error: File '{image_path}' does not exist.")


Prediction: streak
Probability of streak: 1.00
Probability of no streak: 0.00
