# In [None]:
## test.py
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn as nnE  # NOTE(review): looks like a typo for `nn`; kept in case other cells use it
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Preprocessing applied to every image fed to the network. It MUST mirror the
# training pipeline: resize to 224x224, convert to a float tensor in [0, 1],
# then normalize with mean 0.5 / std 0.5 (the single value is applied to
# every channel).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Project folder holding the dataset and the trained checkpoint. Set the
# DEFECT_DETECTION_ROOT environment variable to point somewhere else (e.g. a
# copy on another drive) instead of editing the paths below.
PROJECT_ROOT = Path(os.environ.get("DEFECT_DETECTION_ROOT", r"C:\Users\S NEEREJ\Desktop\Defect Dectetion"))
TEST_DIR = PROJECT_ROOT / "dataset_metal_surface" / "test"
MODEL_PATH = PROJECT_ROOT / "best_model.pth"

# Class labels, indexed by model output. ImageFolder numbers classes by the
# sorted names of the class sub-folders, so this list must follow that same
# order (compare with test_dataset.classes).
CLASS_LABELS = ["Crazing", "Inclusion", "Patches", "Pitted", "Rolled", "Scratches"]

# Load test dataset (fail early with an actionable message if the folder is missing)
if not TEST_DIR.is_dir():
    raise FileNotFoundError(
        f"Test dataset folder not found: {TEST_DIR}. "
        "Set DEFECT_DETECTION_ROOT to the folder that contains 'dataset_metal_surface'."
    )
test_dataset = datasets.ImageFolder(str(TEST_DIR), transform=transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Load trained model. Build the architecture without pretrained weights:
# load_state_dict (strict by default) overwrites every parameter from the
# checkpoint, so downloading ImageNet weights would only add a network dependency.
model = models.resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, len(CLASS_LABELS))  # one output per defect class
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device)
model.eval()  # evaluation mode: BatchNorm uses its running statistics

# Grad-CAM implementation
def grad_cam(model, image, target_layer=None):
    """Compute a Grad-CAM heatmap for the model's top-scoring class.

    Args:
        model: CNN classifier. By default the activation is taken from
            ``model.layer4[-1]`` (the last block of a ResNet).
        image: a single preprocessed image tensor of shape (C, H, W); the
            batch dimension is added here.
        target_layer: optional module whose output is used for the CAM.
            Defaults to ``model.layer4[-1]``.

    Returns:
        (cam, class_idx): ``cam`` is an (H, W) float numpy array scaled to
        [0, 1] (all zeros when the map is flat), matching the spatial size of
        ``image``; ``class_idx`` is the predicted class index (int).

    Raises:
        RuntimeError: if the target layer is not executed by the forward pass.
    """
    model.eval()
    # Put the input on the same device as the model's weights.
    model_device = next(model.parameters()).device
    batch = image.unsqueeze(0).to(model_device)
    layer = model.layer4[-1] if target_layer is None else target_layer

    captured = []

    def _save_activation(module, inputs, output):
        captured.append(output)

    # The hook lives only for this call; removing it in `finally` stops
    # repeated calls from stacking hooks on the model.
    handle = layer.register_forward_hook(_save_activation)
    try:
        logits = model(batch)
    finally:
        handle.remove()
    if not captured:
        raise RuntimeError("grad_cam: target layer was not executed during the forward pass")

    activations = captured[0]
    class_idx = torch.argmax(logits, dim=1).item()
    # Gradient of the winning class score w.r.t. the layer output. autograd.grad
    # avoids module backward hooks and leaves the parameters' .grad untouched.
    (grads,) = torch.autograd.grad(logits[0, class_idx], activations)

    with torch.no_grad():
        # Grad-CAM channel weights: gradients averaged over the spatial dims.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((activations * weights).sum(dim=1, keepdim=True))
        cam = F.interpolate(cam, size=tuple(image.shape[-2:]), mode="bilinear", align_corners=False)
    cam = cam[0, 0].cpu().numpy()

    lo, hi = cam.min(), cam.max()
    if hi > lo:
        cam = (cam - lo) / (hi - lo)
    else:
        # Flat map (e.g. all zeros): avoid 0/0 -> NaN, report no localized activation.
        cam = np.zeros_like(cam)
    return cam, class_idx

# Masking function
def apply_mask(heatmap, threshold=0.6):
    """Binarize a heatmap into an 8-bit mask.

    Entries with ``heatmap >= threshold`` become 255; every other entry
    (including NaN) becomes 0.

    Args:
        heatmap: numpy array of scores (grad_cam yields values in [0, 1]).
        threshold: inclusive cut-off for the foreground.

    Returns:
        uint8 array with the same shape as ``heatmap``.
    """
    binary = np.zeros(np.shape(heatmap), dtype=np.uint8)
    binary[heatmap >= threshold] = 255
    return binary

# Bounding box function
def get_bounding_box(mask):
    """Return the bounding box of the largest external contour in a binary mask.

    Args:
        mask: single-channel uint8 mask (e.g. from apply_mask); non-zero
            pixels are treated as foreground.

    Returns:
        ``(x, y, w, h)`` of the largest contour by area, or None when the
        mask contains no foreground.
    """
    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); indexing [-2] selects the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    return cv2.boundingRect(largest)

# Heatmap overlay function
def overlay_heatmap(image, heatmap):
    """Blend a JET-colored heatmap over an RGB image.

    Args:
        image: RGB image (PIL image or HxWx3 uint8 array).
        heatmap: 2-D float array, expected in [0, 1]; it is resized to the
            image size. Values outside [0, 1] are clipped, NaN becomes 0.

    Returns:
        HxWx3 uint8 RGB array: a 50/50 blend of the image and the colorized
        heatmap, ready for ``plt.imshow``.
    """
    image = np.array(image)
    heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    heatmap = np.clip(np.nan_to_num(heatmap, nan=0.0), 0.0, 1.0)
    heatmap = (255 * heatmap).astype(np.uint8)
    # applyColorMap produces BGR; convert it so it blends with (and displays
    # as) the RGB image instead of swapping red and blue.
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(image, 0.5, heatmap, 0.5, 0)

# Inference function
def infer_and_visualize(model, image_path):
    """Classify one image, localize the defect with Grad-CAM and plot the results.

    Shows three panels (original image, Grad-CAM overlay, image with the
    bounding box) and prints the predicted class and the box coordinates in
    original-image pixels.

    Args:
        model: trained classifier accepted by ``grad_cam``.
        image_path: path of the image file to analyse.

    Relies on the module-level ``transform``, ``device`` and ``CLASS_LABELS``.
    """
    # Load eagerly and release the file handle immediately.
    with Image.open(image_path) as raw_image:
        image_pil = raw_image.convert("RGB")
    image_tensor = transform(image_pil).to(device)
    heatmap, class_idx = grad_cam(model, image_tensor)
    defect_name = CLASS_LABELS[class_idx]

    image_rgb = np.array(image_pil)
    height, width = image_rgb.shape[:2]
    # The heatmap is at the network input resolution set by `transform`, not
    # the file's resolution; resize it so the mask and bounding box are in
    # original-image pixel coordinates.
    heatmap = cv2.resize(heatmap, (width, height))

    mask = apply_mask(heatmap)
    bbox = get_bounding_box(mask)

    # Draw directly on the RGB array; green (0, 255, 0) is identical in RGB
    # and BGR, so no colour-space round trip is needed.
    if bbox is not None:
        x, y, w, h = bbox
        cv2.rectangle(image_rgb, (x, y), (x + w, y + h), (0, 255, 0), 2)
        # Keep the label visible when the box touches the top edge.
        text_y = max(y - 10, 12)
        cv2.putText(image_rgb, defect_name, (x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    overlay = overlay_heatmap(image_pil, heatmap)

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(image_pil)
    plt.title("Original Image")
    plt.subplot(1, 3, 2)
    plt.imshow(overlay)
    plt.title(f"Grad-CAM Heatmap: {defect_name}")
    plt.subplot(1, 3, 3)
    plt.imshow(image_rgb)
    plt.title("Defect Detection with Bounding Box")
    plt.show()

    print(f"Prediction: {defect_name}")
    if bbox is not None:
        print(f"Bounding Box: x={bbox[0]}, y={bbox[1]}, width={bbox[2]}, height={bbox[3]}")
    else:
        print("No significant defect detected.")

# Testing function
def test(model, test_loader):
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {correct / total:.4f}")

# Evaluate the model on the whole test split.
test(model, test_loader)

# Run Grad-CAM inference on one sample image and plot the result.
image_path = r"C:\Users\S NEEREJ\Desktop\Defect Dectetion\dataset_metal_surface\test\Pitted\PS_101.bmp"
infer_and_visualize(model, image_path)


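# Output from a previous run (that run pointed at an E:\ dataset path, not the C:\ paths used above):
# FileNotFoundError: [WinError 3] The system cannot find the path specified: 'E:\\Defect Dectetion\\dataset_metal_surface\\test'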