In [5]:
import torch
from torchvision import models, transforms
from torch import nn
from PIL import Image

# Paths
# Checkpoint file handed to load_model(), which loads it as a state dict.
model_path = "melanoma_segmentation/results/saved_models_classification/mobilenet_melanoma.pth"
# Single image handed to preprocess_image() and then classified.
test_image_path = "data/ISIC-2017_Validation_Data/ISIC_0001769.jpg"

# Load MobileNetV2 with the correct classifier
def load_model(model_path):
    """Build a MobileNetV2 binary classifier and load trained weights into it.

    The stock classifier head is replaced by a single-output linear layer
    followed by a sigmoid, so the model emits one value in [0, 1] per image.

    Args:
        model_path: Path to a checkpoint containing the model's state dict.

    Returns:
        The MobileNetV2 model on CPU, in evaluation mode.

    Raises:
        RuntimeError: If the checkpoint keys do not match this architecture
            (``load_state_dict`` is strict by default).
    """
    # weights=None: the strict load_state_dict() below overwrites every
    # parameter and buffer, so fetching the ImageNet weights first would only
    # add a pointless download (and a failure when offline).
    mobilenet = models.mobilenet_v2(weights=None)

    # Modify the classifier to match training
    num_features = mobilenet.last_channel
    mobilenet.classifier = nn.Sequential(
        nn.Linear(num_features, 1),
        nn.Sigmoid(),
    )

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source; consider weights_only=True where the installed
    # torch version supports it.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    mobilenet.load_state_dict(state_dict)
    mobilenet.eval()  # Disable dropout / use running batch-norm statistics
    return mobilenet

# Preprocess the image
def preprocess_image(image_path):
    """Read an image file and convert it into a model-ready batch tensor.

    The image is forced to RGB, resized to 224x224, converted to a tensor and
    normalized per channel; a leading batch axis is then added, giving a
    tensor of shape (1, 3, 224, 224).

    Args:
        image_path: Path of the image file to load.

    Returns:
        The preprocessed image as a single-item batch tensor.
    """
    # Individual steps, applied in this order by the composed pipeline.
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    pipeline = transforms.Compose([resize, to_tensor, normalize])

    rgb_image = Image.open(image_path).convert("RGB")
    single_image = pipeline(rgb_image)

    # The model expects a batch, so prepend a batch dimension of size one.
    return single_image.unsqueeze(0)

# Perform inference
def classify_image(model, image_tensor):
    """Run the model on one input and return its output as a Python float.

    Args:
        model: Callable (e.g. an ``nn.Module``) applied to ``image_tensor``.
        image_tensor: Input passed to ``model`` unchanged.

    Returns:
        The model's output converted with ``.item()``; the output must
        therefore contain exactly one element.
    """
    # Gradients are not needed for inference, so skip building the graph.
    with torch.no_grad():
        score = model(image_tensor)
    return score.item()

# Main Function
if __name__ == "__main__":
    # End-to-end run: restore the model, prepare the image, score it.
    model = load_model(model_path)
    image_tensor = preprocess_image(test_image_path)
    prediction = classify_image(model, image_tensor)

    # Scores above 0.5 are reported as melanoma; otherwise the complement of
    # the score is shown as the confidence in the non-melanoma label.
    print(
        f"Prediction: Melanoma (Confidence: {prediction:.4f})"
        if prediction > 0.5
        else f"Prediction: Non-Melanoma (Confidence: {1 - prediction:.4f})"
    )


Prediction: Non-Melanoma (Confidence: 0.8979)
