In [None]:
import torch
from PIL import Image

def load_model(model_path, device, model_cls=None):
    """Instantiate the classifier, restore its trained weights, and prepare it for inference.

    Args:
        model_path: Path to a file written with ``torch.save``. Two layouts are
            accepted: a dict holding the weights under ``"model_state_dict"``
            (any other keys in that dict are ignored), or a bare ``state_dict``.
        device: ``torch.device`` (or device string). Stored tensors are remapped
            to it on load, and the model is moved to it.
        model_cls: Optional zero-argument callable returning a fresh
            ``nn.Module`` whose architecture matches the checkpoint. Defaults to
            ``SimpleCNN``, which must be defined elsewhere in the notebook.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint's keys or
            tensor shapes do not match the architecture.
    """
    # Resolve the default lazily, so callers that pass model_cls never need the
    # SimpleCNN global to exist.
    build = SimpleCNN if model_cls is None else model_cls
    model = build()

    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from trusted sources. On torch >= 1.13, consider
    # weights_only=True to restrict unpickling to tensors and primitive containers.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both the wrapped checkpoint layout and a bare state_dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.to(device)
    # eval() switches train/inference-dependent layers (e.g. Dropout, BatchNorm,
    # if the architecture has any) to inference behaviour.
    model.eval()
    return model


In [None]:
def predict_image(image_path, model, transform, device, threshold=0.5):
    """Classify a single image file as ``"OK"`` or ``"NOT_OK"``.

    Args:
        image_path: Path to an image file readable by PIL. It is converted to RGB.
        model: Module that, for a batch of one image, returns exactly one logit
            (``.item()`` raises if the output has more than one element).
        transform: Callable mapping a PIL image to a tensor. A batch dimension is
            prepended here, so presumably it returns (C, H, W) — TODO confirm.
        device: Device the input batch is moved to; should match the model's.
        threshold: Cut-off on the sigmoid probability; ``prob >= threshold``
            yields ``"OK"``.

    Returns:
        ``{"prediction": "OK" | "NOT_OK", "confidence": prob}``, where ``prob`` is
        ``sigmoid(logit)`` as a Python float. Note that ``confidence`` is the
        probability assigned to "OK", whichever label was predicted.
    """
    # Use a context manager so the file handle is closed as soon as the image is
    # converted, instead of lingering until garbage collection (an open handle
    # also keeps the file locked on Windows). convert() returns a new image, so
    # it stays valid after the source is closed.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")

    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logit = model(batch)
        prob = torch.sigmoid(logit).item()

    # NOTE(review): treats sigmoid(logit) as P(OK); confirm this matches the
    # label encoding used when the model was trained.
    label = "OK" if prob >= threshold else "NOT_OK"

    return {
        "prediction": label,
        "confidence": prob,
    }


In [None]:
# Inference configuration: edit these to point at a different model/image.
# NOTE(review): `SimpleCNN` (used by load_model) and `transform` are not defined
# in any cell shown here. They must be defined in an earlier cell, or this cell
# fails on a fresh kernel — TODO confirm.
MODEL_PATH = "defect_cnn.pkl"
IMAGE_PATH = r"C:\test_images\sample.jpg"  # TODO: absolute local path; make configurable
THRESHOLD = 0.5  # sigmoid probability at or above which an image is labelled "OK"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = load_model(MODEL_PATH, device)

result = predict_image(
    image_path=IMAGE_PATH,
    model=model,
    transform=transform,
    device=device,
    threshold=THRESHOLD,
)

# Last expression: rendered via the notebook's rich display.
result
