In [None]:
import os
import random
import cv2
import numpy as np
from read_roi import read_roi_file
import matplotlib.pyplot as plt
import logging

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def get_roi_coordinates(roi_file_path):
    roi = read_roi_file(roi_file_path)
    for box_info in roi.values():
        if box_info['type'] == 'rectangle':
            center_x = box_info['left'] + box_info['width'] // 2
            center_y = box_info['top'] + box_info['height'] // 2
            return center_x, center_y
    return None

def visualize_and_save(image, output_path, title):
    plt.figure(figsize=(10, 5))
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title(title)
    plt.axis('off')
    plt.savefig(output_path)
    plt.close()

def extract_patches(image, center_x, center_y, ideal_patch_size, output_folder, base_filename):
    visualization_image = image.copy()
    sp_start_x = max(center_x - ideal_patch_size[1] // 2, 0)
    sp_start_y = max(center_y - ideal_patch_size[0] // 2, 0)
    sp_end_x = min(sp_start_x + ideal_patch_size[1], image.shape[1])
    sp_end_y = min(sp_start_y + ideal_patch_size[0], image.shape[0])

    patch = image[sp_start_y:sp_end_y, sp_start_x:sp_end_x]
    patch_name = f"{base_filename}_patch.png"
    cv2.imwrite(os.path.join(output_folder, patch_name), patch)
    cv2.rectangle(visualization_image, (sp_start_x, sp_start_y), (sp_end_x, sp_end_y), (0, 255, 0), 3)

    return visualization_image

def extract_normal_patches(image, ideal_patch_size, output_folder, base_filename):
    h, w, _ = image.shape
    visualization_image = image.copy()
    attempts = 0
    centers = []

    while len(centers) < 2 and attempts < 100:
        center_x = random.randint(w // 4, 3 * w // 4)
        center_y = random.randint(h // 4, 3 * h // 4)

        sp_start_x = max(center_x - ideal_patch_size[1] // 2, 0)
        sp_start_y = max(center_y - ideal_patch_size[0] // 2, 0)
        sp_end_x = min(sp_start_x + ideal_patch_size[1], w)
        sp_end_y = min(sp_start_y + ideal_patch_size[0], h)

        if not any([sp_start_x < ex + ew and sp_end_x > ex and sp_start_y < ey + eh and sp_end_y > ey for ex, ey, ew, eh in centers]):
            centers.append((sp_start_x, sp_start_y, sp_end_x - sp_start_x, sp_end_y - sp_start_y))
            patch = image[sp_start_y:sp_end_y, sp_start_x:sp_end_x]
            patch_name = f"{base_filename}_normal_patch_{len(centers)}.png"
            cv2.imwrite(os.path.join(output_folder, patch_name), patch)
            cv2.rectangle(visualization_image, (sp_start_x, sp_start_y), (sp_end_x, sp_end_y), (0, 255, 0), 3)
        attempts += 1

    return visualization_image

def process_image(folder, output_folder, visualization_folder, ideal_patch_size=(275, 300), is_abnormal=True):
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(visualization_folder, exist_ok=True)

    for filename in os.listdir(folder):
        if filename.endswith(('.jpg', '.jpeg', '.png')):
            image_path = os.path.join(folder, filename)
            image = cv2.imread(image_path)
            if image is None:
                logging.warning(f"Failed to read image: {image_path}")
                continue

            base_filename = os.path.splitext(filename)[0]
            if is_abnormal:
                roi_path = os.path.splitext(image_path)[0] + ".roi"
                if os.path.exists(roi_path):
                    center_x, center_y = get_roi_coordinates(roi_path)
                    visualization_image = extract_patches(image, center_x, center_y, ideal_patch_size, output_folder, base_filename)
            else:
                visualization_image = extract_normal_patches(image, ideal_patch_size, output_folder, base_filename)

            visualization_path = os.path.join(visualization_folder, f"{base_filename}_visualization.png")
            visualize_and_save(visualization_image, visualization_path, "Extracted Patches")

def process_dirs(data_dir, output_dir, visualization_dir, ideal_patch_size=(275, 300)):
    for class_type in ['normal', 'abnormal']:
        input_dir = os.path.join(data_dir, class_type)
        output_sub_dir = os.path.join(output_dir, class_type)
        visualization_sub_dir = os.path.join(visualization_dir, class_type)
        process_image(input_dir, output_sub_dir, visualization_sub_dir, ideal_patch_size, is_abnormal=(class_type == 'abnormal'))

# Example usage
data_dir = 'Data'
output_dir = 'output_patches_new'
visualization_dir = 'visualizations_patches_new'

process_dirs(data_dir, output_dir, visualization_dir)

In [2]:
import os
import random
import cv2
import numpy as np
from read_roi import read_roi_file
import matplotlib.pyplot as plt
import logging

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def get_roi_coordinates(roi_file_path):
    roi = read_roi_file(roi_file_path)
    for box_info in roi.values():
        if box_info['type'] == 'rectangle':
            center_x = box_info['left'] + box_info['width'] // 2
            center_y = box_info['top'] + box_info['height'] // 2
            return center_x, center_y
    return None

def visualize_and_save(image, output_path, title):
    plt.figure(figsize=(10, 5))
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title(title)
    plt.axis('off')
    plt.savefig(output_path)
    plt.close()

def extract_patches(image, center_x, center_y, ideal_patch_size, output_folder, base_filename, num_shifts=4, shift_amount=30):
    visualization_image = image.copy()
    shifts = [(0, 0)]  # Initial center
    for _ in range(num_shifts):
        shifts.extend([
            (shift_amount, 0),  # right
            (-shift_amount, 0), # left
            (0, shift_amount),  # down
            (0, -shift_amount)  # up
        ])

    for i, (shift_x, shift_y) in enumerate(shifts):
        sp_start_x = max(center_x + shift_x - ideal_patch_size[1] // 2, 0)
        sp_start_y = max(center_y + shift_y - ideal_patch_size[0] // 2, 0)
        sp_end_x = min(sp_start_x + ideal_patch_size[1], image.shape[1])
        sp_end_y = min(sp_start_y + ideal_patch_size[0], image.shape[0])

        patch = image[sp_start_y:sp_end_y, sp_start_x:sp_end_x]
        patch_name = f"{base_filename}_patch_{i}.png"
        cv2.imwrite(os.path.join(output_folder, patch_name), patch)
        cv2.rectangle(visualization_image, (sp_start_x, sp_start_y), (sp_end_x, sp_end_y), (0, 255, 0), 3)

    return visualization_image

def extract_normal_patches(image, ideal_patch_size, output_folder, base_filename):
    h, w, _ = image.shape
    visualization_image = image.copy()
    attempts = 0
    centers = []

    while len(centers) < 2 and attempts < 100:
        center_x = random.randint(w // 4, 3 * w // 4)
        center_y = random.randint(h // 4, 3 * h // 4)

        sp_start_x = max(center_x - ideal_patch_size[1] // 2, 0)
        sp_start_y = max(center_y - ideal_patch_size[0] // 2, 0)
        sp_end_x = min(sp_start_x + ideal_patch_size[1], w)
        sp_end_y = min(sp_start_y + ideal_patch_size[0], h)

        if not any([sp_start_x < ex + ew and sp_end_x > ex and sp_start_y < ey + eh and sp_end_y > ey for ex, ey, ew, eh in centers]):
            centers.append((sp_start_x, sp_start_y, sp_end_x - sp_start_x, sp_end_y - sp_start_y))
            patch = image[sp_start_y:sp_end_y, sp_start_x:sp_end_x]
            patch_name = f"{base_filename}_normal_patch_{len(centers)}.png"
            cv2.imwrite(os.path.join(output_folder, patch_name), patch)
            cv2.rectangle(visualization_image, (sp_start_x, sp_start_y), (sp_end_x, sp_end_y), (0, 255, 0), 3)
        attempts += 1

    return visualization_image

def process_image(folder, output_folder, visualization_folder, ideal_patch_size=(275, 300), is_abnormal=True):
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(visualization_folder, exist_ok=True)

    for filename in os.listdir(folder):
        if filename.endswith(('.jpg', '.jpeg', '.png')):
            image_path = os.path.join(folder, filename)
            image = cv2.imread(image_path)
            if image is None:
                logging.warning(f"Failed to read image: {image_path}")
                continue

            base_filename = os.path.splitext(filename)[0]
            if is_abnormal:
                roi_path = os.path.splitext(image_path)[0] + ".roi"
                if os.path.exists(roi_path):
                    center_x, center_y = get_roi_coordinates(roi_path)
                    visualization_image = extract_patches(image, center_x, center_y, ideal_patch_size, output_folder, base_filename)
            else:
                visualization_image = extract_normal_patches(image, ideal_patch_size, output_folder, base_filename)

            visualization_path = os.path.join(visualization_folder, f"{base_filename}_visualization.png")
            visualize_and_save(visualization_image, visualization_path, "Extracted Patches")

def process_dirs(data_dir, output_dir, visualization_dir, ideal_patch_size=(275, 300)):
    for class_type in ['normal', 'abnormal']:
        input_dir = os.path.join(data_dir, class_type)
        output_sub_dir = os.path.join(output_dir, class_type)
        visualization_sub_dir = os.path.join(visualization_dir, class_type)
        process_image(input_dir, output_sub_dir, visualization_sub_dir, ideal_patch_size, is_abnormal=(class_type == 'abnormal'))

# Example usage
data_dir = 'Data'
output_dir = 'output_patches_new_1'
visualization_dir = 'visualizations_patches_new_1'

process_dirs(data_dir, output_dir, visualization_dir)


In [5]:
import os
import random
import cv2
import numpy as np
from read_roi import read_roi_file
import matplotlib.pyplot as plt
import logging

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def get_roi_coordinates(roi_file_path):
    roi = read_roi_file(roi_file_path)
    for box_info in roi.values():
        if box_info['type'] == 'rectangle':
            center_x = box_info['left'] + box_info['width'] // 2
            center_y = box_info['top'] + box_info['height'] // 2
            width = box_info['width']
            height = box_info['height']
            return center_x, center_y, width, height
    return None

def visualize_and_save(image, output_path, title):
    plt.figure(figsize=(10, 5))
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title(title)
    plt.axis('off')
    plt.savefig(output_path)
    plt.close()

def extract_patches(image, center_x, center_y, ideal_patch_size, output_folder, base_filename, num_shifts=4, shift_amount=30):
    visualization_image = image.copy()
    shifts = [(0, 0)]  # Initial center
    for _ in range(num_shifts):
        shifts.extend([
            (shift_amount, 0),  # right
            (-shift_amount, 0), # left
            (0, shift_amount),  # down
            (0, -shift_amount)  # up
        ])

    for i, (shift_x, shift_y) in enumerate(shifts):
        sp_start_x = max(center_x + shift_x - ideal_patch_size[1] // 2, 0)
        sp_start_y = max(center_y + shift_y - ideal_patch_size[0] // 2, 0)
        sp_end_x = sp_start_x + ideal_patch_size[1]
        sp_end_y = sp_start_y + ideal_patch_size[0]

        if sp_end_x > image.shape[1] or sp_end_y > image.shape[0]:
            continue

        patch = image[sp_start_y:sp_end_y, sp_start_x:sp_end_x]
        patch_name = f"{base_filename}_patch_{i}.png"
        cv2.imwrite(os.path.join(output_folder, patch_name), patch)
        cv2.rectangle(visualization_image, (sp_start_x, sp_start_y), (sp_end_x, sp_end_y), (0, 255, 0), 3)

    return visualization_image

def extract_normal_patches(image, ideal_patch_size, output_folder, base_filename, num_normal_patches=10):
    h, w, _ = image.shape
    visualization_image = image.copy()
    attempts = 0
    centers = []

    while len(centers) < num_normal_patches and attempts < 100:
        center_x = random.randint(ideal_patch_size[1] // 2, w - ideal_patch_size[1] // 2)
        center_y = random.randint(ideal_patch_size[0] // 2, h - ideal_patch_size[0] // 2)

        sp_start_x = center_x - ideal_patch_size[1] // 2
        sp_start_y = center_y - ideal_patch_size[0] // 2
        sp_end_x = sp_start_x + ideal_patch_size[1]
        sp_end_y = sp_start_y + ideal_patch_size[0]

        if sp_end_x > w or sp_end_y > h:
            continue

        if not any([sp_start_x < ex + ew and sp_end_x > ex and sp_start_y < ey + eh and sp_end_y > ey for ex, ey, ew, eh in centers]):
            centers.append((sp_start_x, sp_start_y, sp_end_x - sp_start_x, sp_end_y - sp_start_y))
            patch = image[sp_start_y:sp_end_y, sp_start_x:sp_end_x]
            patch_name = f"{base_filename}_normal_patch_{len(centers)}.png"
            cv2.imwrite(os.path.join(output_folder, patch_name), patch)
            cv2.rectangle(visualization_image, (sp_start_x, sp_start_y), (sp_end_x, sp_end_y), (0, 255, 0), 3)
        attempts += 1

    return visualization_image

def extract_normal_patches_from_abnormal(image, roi_center_x, roi_center_y, roi_width, roi_height, ideal_patch_size, output_folder, base_filename, num_normal_patches=10):
    h, w, _ = image.shape
    visualization_image = image.copy()
    attempts = 0
    centers = []

    while len(centers) < num_normal_patches and attempts < 100:
        center_x = random.randint(ideal_patch_size[1] // 2, w - ideal_patch_size[1] // 2)
        center_y = random.randint(ideal_patch_size[0] // 2, h - ideal_patch_size[0] // 2)

        sp_start_x = center_x - ideal_patch_size[1] // 2
        sp_start_y = center_y - ideal_patch_size[0] // 2
        sp_end_x = sp_start_x + ideal_patch_size[1]
        sp_end_y = sp_start_y + ideal_patch_size[0]

        if sp_end_x > w or sp_end_y > h:
            continue

        # Ensure the patch does not overlap with the ROI area
        roi_start_x = roi_center_x - roi_width // 2
        roi_start_y = roi_center_y - roi_height // 2
        roi_end_x = roi_center_x + roi_width // 2
        roi_end_y = roi_center_y + roi_height // 2

        if not (sp_end_x > roi_start_x and sp_start_x < roi_end_x and sp_end_y > roi_start_y and sp_start_y < roi_end_y):
            if not any([sp_start_x < ex + ew and sp_end_x > ex and sp_start_y < ey + eh and sp_end_y > ey for ex, ey, ew, eh in centers]):
                centers.append((sp_start_x, sp_start_y, sp_end_x - sp_start_x, sp_end_y - sp_start_y))
                patch = image[sp_start_y:sp_end_y, sp_start_x:sp_end_x]
                patch_name = f"{base_filename}_normal_patch_{len(centers)}.png"
                cv2.imwrite(os.path.join(output_folder, patch_name), patch)
                cv2.rectangle(visualization_image, (sp_start_x, sp_start_y), (sp_end_x, sp_end_y), (0, 255, 0), 3)
        attempts += 1

    return visualization_image

def process_image(folder, normal_output_folder, abnormal_output_folder, visualization_folder, ideal_patch_size=(275, 300), is_abnormal=True):
    os.makedirs(normal_output_folder, exist_ok=True)
    os.makedirs(abnormal_output_folder, exist_ok=True)
    os.makedirs(visualization_folder, exist_ok=True)

    for filename in os.listdir(folder):
        if filename.endswith(('.jpg', '.jpeg', '.png')):
            image_path = os.path.join(folder, filename)
            image = cv2.imread(image_path)
            if image is None:
                logging.warning(f"Failed to read image: {image_path}")
                continue

            base_filename = os.path.splitext(filename)[0]
            if is_abnormal:
                roi_path = os.path.splitext(image_path)[0] + ".roi"
                if os.path.exists(roi_path):
                    center_x, center_y, width, height = get_roi_coordinates(roi_path)
                    visualization_image = extract_patches(image, center_x, center_y, ideal_patch_size, abnormal_output_folder, base_filename)
                    # Extract normal patches from abnormal images
                    normal_visualization_image = extract_normal_patches_from_abnormal(image, center_x, center_y, width, height, ideal_patch_size, normal_output_folder, base_filename, num_normal_patches=10)
                    visualization_image = cv2.addWeighted(visualization_image, 0.5, normal_visualization_image, 0.5, 0)
            else:
                visualization_image = extract_normal_patches(image, ideal_patch_size, normal_output_folder, base_filename, num_normal_patches=10)

            visualization_path = os.path.join(visualization_folder, f"{base_filename}_visualization.png")
            visualize_and_save(visualization_image, visualization_path, "Extracted Patches")

def process_dirs(data_dir, output_dir, visualization_dir, ideal_patch_size=(275, 300)):
    for class_type in ['normal', 'abnormal']:
        input_dir = os.path.join(data_dir, class_type)
        normal_output_sub_dir = os.path.join(output_dir, 'normal')
        abnormal_output_sub_dir = os.path.join(output_dir, 'abnormal')
        visualization_sub_dir = os.path.join(visualization_dir, class_type)
        process_image(input_dir, normal_output_sub_dir, abnormal_output_sub_dir, visualization_sub_dir, ideal_patch_size, is_abnormal=(class_type == 'abnormal'))

# Example usage
data_dir = 'Data'
output_dir = 'output_patches_new_2'
visualization_dir = 'visualizations_patches_new_2'

process_dirs(data_dir, output_dir, visualization_dir)


In [2]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torchvision.models as models

# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations for testing
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your pre-trained PyTorch models for patches and full images
model_patch = torch.load('Models/denseNet_redo_patches', map_location=torch.device('cpu'))
model_patch.eval()
model_full = torch.load('Models/denseNet_redo_full_images', map_location=torch.device('cpu'))
model_full.eval()

# Transformation pipeline for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_heatmap(input_tensor, model, target_layers):
    cam = GradCAMPlusPlus(model=model, target_layers=target_layers)
    pred = model(input_tensor)
    _, predicted_class = pred.max(1)
    targets = [ClassifierOutputTarget(predicted_class.item())]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    return grayscale_cam, predicted_class.item(), torch.softmax(pred, dim=1)[0, predicted_class].item()

def adaptive_thresholding(heatmap, predicted_label):
    """Determine the threshold based on the predicted label.
    - If the label is abnormal (1), apply a stricter threshold (top 25% of values).
    - If the label is normal (0), apply a broader threshold (top 50% of values)."""
    threshold_percentile = 75 if predicted_label == 1 else 95
    threshold_value = np.percentile(heatmap, threshold_percentile)
    return (heatmap >= threshold_value).astype('uint8')

def detect_features_in_channel(channel, mask=None):
    features = cv2.goodFeaturesToTrack(channel, mask=mask, maxCorners=100, qualityLevel=0.01, minDistance=150)
    return np.int0(features).reshape(-1, 2) if features is not None else np.array([])

def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
    """Perform weighted voting based on the confidence scores."""
    vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

    # Weight adjustments based on the full image prediction
    image_weight = 5
    if image_pred == 1:
        vote_count[1] += image_conf * image_weight
    else:
        vote_count[0] += (1 - image_conf) * image_weight

    # Weight adjustments based on patch predictions
    for pred, conf in zip(patch_predictions, patch_confidences):
        patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
        if pred == 1:
            vote_count[1] += conf * patch_weight
        else:
            vote_count[0] += (1 - conf) * patch_weight

    return 1 if vote_count[1] > vote_count[0] else 0

def extract_and_visualize(image_path, label, model_patch, model_full, transform, output_dir):
    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)
    input_tensor = transform(image).unsqueeze(0)
    
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    # Full image attention map
    heatmap_full, pred_full, conf_full = get_heatmap(input_tensor, model_full, [model_full.features.norm5])
    heatmap_resized_full = cv2.resize(heatmap_full, (image_np.shape[1], image_np.shape[0]))
    heatmap_normalized_full = heatmap_resized_full / np.max(heatmap_resized_full)

    # Apply adaptive thresholding based on predicted label
    binary_mask_full = adaptive_thresholding(heatmap_normalized_full, pred_full)

    # Overlay Image
    overlay_img_full = cv2.addWeighted(image_np, 0.6, cv2.applyColorMap(np.uint8(255 * heatmap_normalized_full), cv2.COLORMAP_JET), 0.4, 0)

    # Detect salient points
    salient_points_all = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY))  # Without mask
    salient_points_filtered = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY), binary_mask_full)  # With mask

    # Image with all salient points
    image_with_all_salient_points = image_np.copy()
    for pt in salient_points_all:
        cv2.circle(image_with_all_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with filtered salient points
    image_with_salient_points = image_np.copy()
    for pt in salient_points_filtered:
        cv2.circle(image_with_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with patch classifications
    image_with_classified_patches = image_np.copy()
    patch_predictions = []
    patch_confidences = []
    for pt in salient_points_filtered:
        top_left_x = max(pt[0] - 137, 0)
        top_left_y = max(pt[1] - 150, 0)
        bottom_right_x = min(top_left_x + 275, image_with_classified_patches.shape[1])
        bottom_right_y = min(top_left_y + 300, image_with_classified_patches.shape[0])

        patch = image_with_classified_patches[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        patch_tensor = torch.from_numpy(np.transpose(cv2.resize(patch, (224, 224)), (2, 0, 1)).astype('float32') / 255.0).unsqueeze(0)
        if torch.cuda.is_available():
            patch_tensor = patch_tensor.cuda()
        with torch.no_grad():
            output = model_patch(patch_tensor)
            prob = torch.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            conf = prob[0][pred].item()

        patch_predictions.append(pred.item())
        patch_confidences.append(conf)

        color = (0, 255, 0) if pred.item() == 1 else (0, 0, 255)
        cv2.rectangle(image_with_classified_patches, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), color, 2)
        text = f"{conf:.2f}"  # Confidence score
        cv2.putText(image_with_classified_patches, text, (top_left_x, top_left_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

    final_prediction = weighted_voting(patch_predictions, patch_confidences, pred_full, conf_full)

    # 3x2 Grid Visualization
    fig, axs = plt.subplots(3, 2, figsize=(16, 24))
    axs[0, 0].imshow(heatmap_normalized_full, cmap='jet')
    axs[0, 0].axis('off')
    axs[0, 0].set_title(f'Heatmap Full (DenseNet) - Pred: {pred_full}, GT: {label}')

    axs[0, 1].imshow(overlay_img_full)
    axs[0, 1].axis('off')
    axs[0, 1].set_title('Overlay Image Full')

    axs[1, 0].imshow(image_with_all_salient_points)
    axs[1, 0].axis('off')
    axs[1, 0].set_title('All Salient Points')

    axs[1, 1].imshow(image_with_salient_points)
    axs[1, 1].axis('off')
    axs[1, 1].set_title('Filtered Salient Points')

    axs[2, 0].imshow(image_with_classified_patches)
    axs[2, 0].axis('off')
    axs[2, 0].set_title('Patch Classifications')

    # Leave the last grid empty if not needed
    axs[2, 1].axis('off')

    output_filename = f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png"
    plt.savefig(os.path.join(output_dir, output_filename))
    plt.close()

    return final_prediction, label

# Directory and processing setup
root_dir = 'Data/Testing/Images'
output_dir = 'Processed_Images_pred_700'
os.makedirs(output_dir, exist_ok=True)

all_predictions = []
all_labels = []

for subfolder in ['normal_testing_images', 'abnormal_testing_images']:
    folder_path = os.path.join(root_dir, subfolder)
    label = 0 if subfolder == 'normal_testing_images' else 1
    for image_file in os.listdir(folder_path):
        if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(folder_path, image_file)
            final_prediction, true_label = extract_and_visualize(image_path, label, model_patch, model_full, transform, output_dir)
            all_predictions.append(final_prediction)
            all_labels.append(true_label)

# Calculate and print accuracy, precision, and recall
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)

print(f"Completed processing. Check the output directory: {output_dir}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")


Completed processing. Check the output directory: Processed_Images_pred_700
Accuracy: 0.8293
Precision: 0.7733
Recall: 0.9779


In [3]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torchvision.models as models

# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations for testing
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your pre-trained PyTorch models for patches and full images
model_patch = torch.load('Models/denseNet_redo_patches', map_location=torch.device('cpu'))
model_patch.eval()
model_full = torch.load('Models/denseNet_redo_full_images', map_location=torch.device('cpu'))
model_full.eval()

# Transformation pipeline for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_heatmap(input_tensor, model, target_layers):
    cam = GradCAMPlusPlus(model=model, target_layers=target_layers)
    pred = model(input_tensor)
    _, predicted_class = pred.max(1)
    targets = [ClassifierOutputTarget(predicted_class.item())]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    return grayscale_cam, predicted_class.item(), torch.softmax(pred, dim=1)[0, predicted_class].item()

def adaptive_thresholding(heatmap, predicted_label):
    """Determine the threshold based on the predicted label.
    - If the label is abnormal (1), apply a stricter threshold (top 25% of values).
    - If the label is normal (0), apply a broader threshold (top 50% of values)."""
    threshold_percentile = 75 if predicted_label == 1 else 95
    threshold_value = np.percentile(heatmap, threshold_percentile)
    return (heatmap >= threshold_value).astype('uint8')

def detect_features_in_channel(channel, mask=None):
    features = cv2.goodFeaturesToTrack(channel, mask=mask, maxCorners=100, qualityLevel=0.01, minDistance=150)
    return np.int0(features).reshape(-1, 2) if features is not None else np.array([])

def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
    """Perform weighted voting based on the confidence scores."""
    vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

    # Weight adjustments based on the full image prediction
    image_weight = 5
    if image_pred == 1:
        vote_count[1] += image_conf * image_weight
    else:
        vote_count[0] += (1 - image_conf) * image_weight

    # Weight adjustments based on patch predictions
    for pred, conf in zip(patch_predictions, patch_confidences):
        patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
        if pred == 1:
            vote_count[1] += conf * patch_weight
        else:
            vote_count[0] += (1 - conf) * patch_weight

    return 1 if vote_count[1] > vote_count[0] else 0

def extract_and_visualize(image_path, label, model_patch, model_full, transform, correct_dir, incorrect_dir):
    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)
    input_tensor = transform(image).unsqueeze(0)
    
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    # Full image attention map
    heatmap_full, pred_full, conf_full = get_heatmap(input_tensor, model_full, [model_full.features.norm5])
    heatmap_resized_full = cv2.resize(heatmap_full, (image_np.shape[1], image_np.shape[0]))
    heatmap_normalized_full = heatmap_resized_full / np.max(heatmap_resized_full)

    # Apply adaptive thresholding based on predicted label
    binary_mask_full = adaptive_thresholding(heatmap_normalized_full, pred_full)

    # Overlay Image
    overlay_img_full = cv2.addWeighted(image_np, 0.6, cv2.applyColorMap(np.uint8(255 * heatmap_normalized_full), cv2.COLORMAP_JET), 0.4, 0)

    # Detect salient points
    salient_points_all = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY))  # Without mask
    salient_points_filtered = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY), binary_mask_full)  # With mask

    # Image with all salient points
    image_with_all_salient_points = image_np.copy()
    for pt in salient_points_all:
        cv2.circle(image_with_all_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with filtered salient points
    image_with_salient_points = image_np.copy()
    for pt in salient_points_filtered:
        cv2.circle(image_with_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with patch classifications
    image_with_classified_patches = image_np.copy()
    patch_predictions = []
    patch_confidences = []
    for pt in salient_points_filtered:
        top_left_x = max(pt[0] - 137, 0)
        top_left_y = max(pt[1] - 150, 0)
        bottom_right_x = min(top_left_x + 275, image_with_classified_patches.shape[1])
        bottom_right_y = min(top_left_y + 300, image_with_classified_patches.shape[0])

        patch = image_with_classified_patches[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        patch_tensor = torch.from_numpy(np.transpose(cv2.resize(patch, (224, 224)), (2, 0, 1)).astype('float32') / 255.0).unsqueeze(0)
        if torch.cuda.is_available():
            patch_tensor = patch_tensor.cuda()
        with torch.no_grad():
            output = model_patch(patch_tensor)
            prob = torch.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            conf = prob[0][pred].item()

        patch_predictions.append(pred.item())
        patch_confidences.append(conf)

        color = (0, 255, 0) if pred.item() == 1 else (0, 0, 255)
        cv2.rectangle(image_with_classified_patches, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), color, 2)
        text = f"{conf:.2f}"  # Confidence score
        cv2.putText(image_with_classified_patches, text, (top_left_x, top_left_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

    final_prediction = weighted_voting(patch_predictions, patch_confidences, pred_full, conf_full)

    # 3x2 Grid Visualization
    fig, axs = plt.subplots(3, 2, figsize=(16, 24))
    axs[0, 0].imshow(heatmap_normalized_full, cmap='jet')
    axs[0, 0].axis('off')
    axs[0, 0].set_title(f'Heatmap Full (DenseNet) - Pred: {pred_full}, GT: {label}')

    axs[0, 1].imshow(overlay_img_full)
    axs[0, 1].axis('off')
    axs[0, 1].set_title('Overlay Image Full')

    axs[1, 0].imshow(image_with_all_salient_points)
    axs[1, 0].axis('off')
    axs[1, 0].set_title('All Salient Points')

    axs[1, 1].imshow(image_with_salient_points)
    axs[1, 1].axis('off')
    axs[1, 1].set_title('Filtered Salient Points')

    axs[2, 0].imshow(image_with_classified_patches)
    axs[2, 0].axis('off')
    axs[2, 0].set_title('Patch Classifications')

    # Leave the last grid empty if not needed
    axs[2, 1].axis('off')

    if final_prediction == label:
        output_filename = os.path.join(correct_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")
    else:
        output_filename = os.path.join(incorrect_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")

    plt.savefig(output_filename)
    plt.close()

    return final_prediction, label

# Directory and processing setup
root_dir = 'Data/Testing/Images'
output_dir_correct = 'Processed_Images_pred_correct'
output_dir_incorrect = 'Processed_Images_pred_incorrect'
os.makedirs(output_dir_correct, exist_ok=True)
os.makedirs(output_dir_incorrect, exist_ok=True)

all_predictions = []
all_labels = []

for subfolder in ['normal_testing_images', 'abnormal_testing_images']:
    folder_path = os.path.join(root_dir, subfolder)
    label = 0 if subfolder == 'normal_testing_images' else 1
    for image_file in os.listdir(folder_path):
        if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(folder_path, image_file)
            final_prediction, true_label = extract_and_visualize(image_path, label, model_patch, model_full, transform, output_dir_correct, output_dir_incorrect)
            all_predictions.append(final_prediction)
            all_labels.append(true_label)

# Calculate and print accuracy, precision, and recall
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)

print(f"Completed processing. Check the output directory: {output_dir_correct} and {output_dir_incorrect}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")


Completed processing. Check the output directory: Processed_Images_pred_correct and Processed_Images_pred_incorrect
Accuracy: 0.8293
Precision: 0.7733
Recall: 0.9779


In [5]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torchvision.models as models

# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations for testing
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your pre-trained PyTorch models for patches and full images
model_patch = torch.load('Models/denseNet_redo_patches', map_location=torch.device('cpu'))
model_patch.eval()
model_full = torch.load('Models/denseNet_redo_full_images', map_location=torch.device('cpu'))
model_full.eval()

# Transformation pipeline for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_heatmap(input_tensor, model, target_layers):
    cam = GradCAMPlusPlus(model=model, target_layers=target_layers)
    pred = model(input_tensor)
    _, predicted_class = pred.max(1)
    targets = [ClassifierOutputTarget(predicted_class.item())]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    return grayscale_cam, predicted_class.item(), torch.softmax(pred, dim=1)[0, predicted_class].item()

def adaptive_thresholding(heatmap, predicted_label):
    """Determine the threshold based on the predicted label.
    - If the label is abnormal (1), apply a stricter threshold (top 25% of values).
    - If the label is normal (0), apply a broader threshold (top 50% of values)."""
    threshold_percentile = 75 if predicted_label == 1 else 95
    threshold_value = np.percentile(heatmap, threshold_percentile)
    return (heatmap >= threshold_value).astype('uint8')

def detect_features_in_channel(channel, mask=None):
    features = cv2.goodFeaturesToTrack(channel, mask=mask, maxCorners=100, qualityLevel=0.01, minDistance=150)
    return np.int0(features).reshape(-1, 2) if features is not None else np.array([])

def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
    """Perform weighted voting based on the confidence scores."""
    vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

    # Weight adjustments based on the full image prediction
    image_weight = 5
    if image_pred == 1:
        vote_count[1] += image_conf * image_weight
    else:
        vote_count[0] += (1 - image_conf) * image_weight

    # Weight adjustments based on patch predictions
    for pred, conf in zip(patch_predictions, patch_confidences):
        patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
        if pred == 1:
            vote_count[1] += conf * patch_weight
        else:
            vote_count[0] += (1 - conf) * patch_weight

    return 1 if vote_count[1] > vote_count[0] else 0

def extract_and_save_patches(image_path, extracted_patches, patch_predictions, patch_confidences, output_dir_patches):
    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    patch_dir = os.path.join(output_dir_patches, base_filename)
    os.makedirs(patch_dir, exist_ok=True)

    for i, (patch, pred, conf) in enumerate(zip(extracted_patches, patch_predictions, patch_confidences)):
        patch_filename = os.path.join(patch_dir, f"patch_{i}_pred_{pred}_conf_{conf:.2f}.png")
        patch.save(patch_filename)

def extract_and_visualize(image_path, label, model_patch, model_full, transform, correct_dir, incorrect_dir, output_dir_patches):
    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)
    input_tensor = transform(image).unsqueeze(0)
    
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    # Full image attention map
    heatmap_full, pred_full, conf_full = get_heatmap(input_tensor, model_full, [model_full.features.norm5])
    heatmap_resized_full = cv2.resize(heatmap_full, (image_np.shape[1], image_np.shape[0]))
    heatmap_normalized_full = heatmap_resized_full / np.max(heatmap_resized_full)

    # Apply adaptive thresholding based on predicted label
    binary_mask_full = adaptive_thresholding(heatmap_normalized_full, pred_full)

    # Overlay Image
    overlay_img_full = cv2.addWeighted(image_np, 0.6, cv2.applyColorMap(np.uint8(255 * heatmap_normalized_full), cv2.COLORMAP_JET), 0.4, 0)

    # Detect salient points
    salient_points_all = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY))  # Without mask
    salient_points_filtered = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY), binary_mask_full)  # With mask

    # Image with all salient points
    image_with_all_salient_points = image_np.copy()
    for pt in salient_points_all:
        cv2.circle(image_with_all_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with filtered salient points
    image_with_salient_points = image_np.copy()
    for pt in salient_points_filtered:
        cv2.circle(image_with_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with patch classifications
    image_with_classified_patches = image_np.copy()
    patch_predictions = []
    patch_confidences = []
    extracted_patches = []  # Store patches for saving
    for pt in salient_points_filtered:
        top_left_x = max(pt[0] - 137, 0)
        top_left_y = max(pt[1] - 150, 0)
        bottom_right_x = min(top_left_x + 275, image_with_classified_patches.shape[1])
        bottom_right_y = min(top_left_y + 300, image_with_classified_patches.shape[0])

        patch = image_with_classified_patches[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        extracted_patches.append(Image.fromarray(patch))  # Add patch to list
        patch_tensor = torch.from_numpy(np.transpose(cv2.resize(patch, (224, 224)), (2, 0, 1)).astype('float32') / 255.0).unsqueeze(0)
        if torch.cuda.is_available():
            patch_tensor = patch_tensor.cuda()
        with torch.no_grad():
            output = model_patch(patch_tensor)
            prob = torch.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            conf = prob[0][pred].item()

        patch_predictions.append(pred.item())
        patch_confidences.append(conf)

        color = (0, 255, 0) if pred.item() == 1 else (0, 0, 255)
        cv2.rectangle(image_with_classified_patches, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), color, 2)
        text = f"{conf:.2f}"  # Confidence score
        cv2.putText(image_with_classified_patches, text, (top_left_x, top_left_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

    final_prediction = weighted_voting(patch_predictions, patch_confidences, pred_full, conf_full)

    # Save extracted patches with predictions and confidences
    extract_and_save_patches(image_path, extracted_patches, patch_predictions, patch_confidences, output_dir_patches)

    # 3x2 Grid Visualization
    fig, axs = plt.subplots(3, 2, figsize=(16, 24))
    axs[0, 0].imshow(heatmap_normalized_full, cmap='jet')
    axs[0, 0].axis('off')
    axs[0, 0].set_title(f'Heatmap Full (DenseNet) - Pred: {pred_full}, GT: {label}')

    axs[0, 1].imshow(overlay_img_full)
    axs[0, 1].axis('off')
    axs[0, 1].set_title('Overlay Image Full')

    axs[1, 0].imshow(image_with_all_salient_points)
    axs[1, 0].axis('off')
    axs[1, 0].set_title('All Salient Points')

    axs[1, 1].imshow(image_with_salient_points)
    axs[1, 1].axis('off')
    axs[1, 1].set_title('Filtered Salient Points')

    axs[2, 0].imshow(image_with_classified_patches)
    axs[2, 0].axis('off')
    axs[2, 0].set_title('Patch Classifications')

    # Leave the last grid empty if not needed
    axs[2, 1].axis('off')

    if final_prediction == label:
        output_filename = os.path.join(correct_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")
    else:
        output_filename = os.path.join(incorrect_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")

    plt.savefig(output_filename)
    plt.close()

    return final_prediction, label

# Directory and processing setup
root_dir = 'Data/Testing/Images'
output_dir_correct = 'Processed_Images_pred_correct'
output_dir_incorrect = 'Processed_Images_pred_incorrect'
output_dir_patches = 'Processed_Patches'
os.makedirs(output_dir_correct, exist_ok=True)
os.makedirs(output_dir_incorrect, exist_ok=True)
os.makedirs(output_dir_patches, exist_ok=True)

all_predictions = []
all_labels = []

for subfolder in ['normal_testing_images', 'abnormal_testing_images']:
    folder_path = os.path.join(root_dir, subfolder)
    label = 0 if subfolder == 'normal_testing_images' else 1
    for image_file in os.listdir(folder_path):
        if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(folder_path, image_file)
            final_prediction, true_label = extract_and_visualize(image_path, label, model_patch, model_full, transform, output_dir_correct, output_dir_incorrect, output_dir_patches)
            all_predictions.append(final_prediction)
            all_labels.append(true_label)

# Calculate and print accuracy, precision, and recall
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)

print(f"Completed processing. Check the output directories: {output_dir_correct}, {output_dir_incorrect}, and {output_dir_patches}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")


Completed processing. Check the output directories: Processed_Images_pred_correct, Processed_Images_pred_incorrect, and Processed_Patches
Accuracy: 0.8293
Precision: 0.7733
Recall: 0.9779


In [19]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torchvision.models as models

# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations for testing
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your pre-trained PyTorch models for patches and full images
model_patch = torch.load('Models/Patch_Model_All_Data_DenseNet', map_location=torch.device('cpu'))
model_patch.eval()
model_full = torch.load('Models/denseNet_redo_full_images', map_location=torch.device('cpu'))
model_full.eval()

# Transformation pipeline for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_heatmap(input_tensor, model, target_layers):
    cam = GradCAMPlusPlus(model=model, target_layers=target_layers)
    pred = model(input_tensor)
    _, predicted_class = pred.max(1)
    targets = [ClassifierOutputTarget(predicted_class.item())]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    return grayscale_cam, predicted_class.item(), torch.softmax(pred, dim=1)[0, predicted_class].item()

def adaptive_thresholding(heatmap, predicted_label):
    """Determine the threshold based on the predicted label.
    - If the label is abnormal (1), apply a stricter threshold (top 25% of values).
    - If the label is normal (0), apply a broader threshold (top 50% of values)."""
    threshold_percentile = 75 if predicted_label == 1 else 95
    threshold_value = np.percentile(heatmap, threshold_percentile)
    return (heatmap >= threshold_value).astype('uint8')

def detect_features_in_channel(channel, mask=None):
    features = cv2.goodFeaturesToTrack(channel, mask=mask, maxCorners=100, qualityLevel=0.01, minDistance=150)
    return np.int0(features).reshape(-1, 2) if features is not None else np.array([])

# def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
#     """Perform weighted voting based on the confidence scores."""
#     vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

#     # Calculate the percentage of patches predicted as abnormal
#     abnormal_patches = sum(1 for pred in patch_predictions if pred == 1)
#     normal_patches = len(patch_predictions) - abnormal_patches

#     # Adjust the weight of the full image prediction based on patch consistency
#     if abnormal_patches > normal_patches:
#         # Increase weight of full image prediction if majority of patches are abnormal
#         image_weight = 5 + (abnormal_patches - normal_patches)
#     else:
#         # Decrease weight of full image prediction if majority of patches are normal
#         image_weight = max(1, 5 - (normal_patches - abnormal_patches))

#     if image_pred == 1:
#         vote_count[1] += image_conf * image_weight
#     else:
#         vote_count[0] += (1 - image_conf) * image_weight

#     # Weight adjustments based on patch predictions
#     for pred, conf in zip(patch_predictions, patch_confidences):
#         patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
#         if pred == 1:
#             vote_count[1] += conf * patch_weight
#         else:
#             vote_count[0] += (1 - conf) * patch_weight

#     return 1 if vote_count[1] > vote_count[0] else 0

# def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
#     """Perform weighted voting based on the confidence scores."""
#     vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

#     # Weight adjustments based on the full image prediction
#     image_weight = 5
#     if image_pred == 1:
#         vote_count[1] += image_conf * image_weight
#     else:
#         vote_count[0] += (1 - image_conf) * image_weight 

#     # Weight adjustments based on patch predictions
#     for pred, conf in zip(patch_predictions, patch_confidences):
#         patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
#         if pred == 1:
#             vote_count[1] += conf * patch_weight
#         else:
#             vote_count[0] += (1 - conf) * patch_weight

#     return 1 if vote_count[1] > vote_count[0] else 0

def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
    """Perform weighted voting based on the confidence scores."""
    vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

    # Weight adjustments based on the full image prediction
    image_weight = 5
    if image_pred == 1:
        vote_count[1] += image_conf * image_weight

        for pred, conf in zip(patch_predictions, patch_confidences):
            patch_weight = 2 if conf > 0.8 else 1
            if pred == 1:
                vote_count[1] += conf * patch_weight
            else:
                vote_count[0] += (1 - conf) * patch_weight
    else:
        vote_count[0] += (1 - image_conf) * image_weight * 5

        for pred, conf in zip(patch_predictions, patch_confidences):
            patch_weight = 3 if conf > 0.8 else 2
            if pred == 1:
                vote_count[1] += conf * 1.5
            else:
                vote_count[0] += (1 - conf) * patch_weight

    return 1 if vote_count[1] > vote_count[0] else 0



def save_composite_patches(image_path, extracted_patches, patch_predictions, patch_confidences, output_dir_patches):
    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    num_patches = len(extracted_patches)
    num_cols = 5
    num_rows = (num_patches + num_cols - 1) // num_cols  # Ensure we have enough rows

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 4 * num_rows))
    axs = axs.ravel()  # Flatten the array for easy indexing

    for i, (patch, pred, conf) in enumerate(zip(extracted_patches, patch_predictions, patch_confidences)):
        axs[i].imshow(patch)
        axs[i].axis('off')
        axs[i].set_title(f'Pred: {pred}, Conf: {conf:.2f}')

    for j in range(i + 1, len(axs)):  # Hide any unused subplots
        axs[j].axis('off')

    plt.tight_layout()
    patch_output_filename = os.path.join(output_dir_patches, f"{base_filename}_patches.png")
    plt.savefig(patch_output_filename)
    plt.close()

def extract_and_visualize(image_path, label, model_patch, model_full, transform, correct_dir, incorrect_dir, output_dir_patches):
    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)
    input_tensor = transform(image).unsqueeze(0)
    
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    # Full image attention map
    heatmap_full, pred_full, conf_full = get_heatmap(input_tensor, model_full, [model_full.features.norm5])
    heatmap_resized_full = cv2.resize(heatmap_full, (image_np.shape[1], image_np.shape[0]))
    heatmap_normalized_full = heatmap_resized_full / np.max(heatmap_resized_full)

    # Apply adaptive thresholding based on predicted label
    binary_mask_full = adaptive_thresholding(heatmap_normalized_full, pred_full)

    # Overlay Image
    overlay_img_full = cv2.addWeighted(image_np, 0.6, cv2.applyColorMap(np.uint8(255 * heatmap_normalized_full), cv2.COLORMAP_JET), 0.4, 0)

    # Detect salient points
    salient_points_all = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY))  # Without mask
    salient_points_filtered = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY), binary_mask_full)  # With mask

    # Image with all salient points
    image_with_all_salient_points = image_np.copy()
    for pt in salient_points_all:
        cv2.circle(image_with_all_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with filtered salient points
    image_with_salient_points = image_np.copy()
    for pt in salient_points_filtered:
        cv2.circle(image_with_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with patch classifications
    image_with_classified_patches = image_np.copy()
    patch_image = image_np.copy()
    patch_predictions = []
    patch_confidences = []
    extracted_patches = []  # Store patches for saving
    composite_patches = []
    for pt in salient_points_filtered:
        top_left_x = max(pt[0] - 137, 0)
        top_left_y = max(pt[1] - 150, 0)
        bottom_right_x = min(top_left_x + 275, image_with_classified_patches.shape[1])
        bottom_right_y = min(top_left_y + 300, image_with_classified_patches.shape[0])

        patch = image_with_classified_patches[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        patch_ = patch_image[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        extracted_patches.append(patch.copy())  # Add patch to list
        composite_patches.append(patch_)
        patch_tensor = torch.from_numpy(np.transpose(cv2.resize(patch, (224, 224)), (2, 0, 1)).astype('float32') / 255.0).unsqueeze(0)
        if torch.cuda.is_available():
            patch_tensor = patch_tensor.cuda()
        with torch.no_grad():
            output = model_patch(patch_tensor)
            prob = torch.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            conf = prob[0][pred].item()

        patch_predictions.append(pred.item())
        patch_confidences.append(conf)

        color = (0, 255, 0) if pred.item() == 1 else (0, 0, 255)
        cv2.rectangle(image_with_classified_patches, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), color, 2)
        text = f"{conf:.2f}"  # Confidence score
        cv2.putText(image_with_classified_patches, text, (top_left_x, top_left_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

    final_prediction = weighted_voting(patch_predictions, patch_confidences, pred_full, conf_full)

    # Save composite patches with predictions and confidences
    save_composite_patches(image_path, composite_patches, patch_predictions, patch_confidences, output_dir_patches)

    # 3x2 Grid Visualization
    fig, axs = plt.subplots(3, 2, figsize=(16, 24))
    axs[0, 0].imshow(heatmap_normalized_full, cmap='jet')
    axs[0, 0].axis('off')
    axs[0, 0].set_title(f'Heatmap Full (DenseNet) - Pred: {pred_full}, GT: {label}')

    axs[0, 1].imshow(overlay_img_full)
    axs[0, 1].axis('off')
    axs[0, 1].set_title('Overlay Image Full')

    axs[1, 0].imshow(image_with_all_salient_points)
    axs[1, 0].axis('off')
    axs[1, 0].set_title('All Salient Points')

    axs[1, 1].imshow(image_with_salient_points)
    axs[1, 1].axis('off')
    axs[1, 1].set_title('Filtered Salient Points')

    axs[2, 0].imshow(image_with_classified_patches)
    axs[2, 0].axis('off')
    axs[2, 0].set_title('Patch Classifications')

    # Leave the last grid empty if not needed
    axs[2, 1].axis('off')

    if final_prediction == label:
        output_filename = os.path.join(correct_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")
    else:
        output_filename = os.path.join(incorrect_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")

    plt.savefig(output_filename)
    plt.close()

    return final_prediction, label

# Directory and processing setup
root_dir = 'Data/Testing/Images'
output_dir_correct = 'Processed_Images_pred_correct_3'
output_dir_incorrect = 'Processed_Images_pred_incorrect_3'
output_dir_patches = 'Processed_Patches_3'
os.makedirs(output_dir_correct, exist_ok=True)
os.makedirs(output_dir_incorrect, exist_ok=True)
os.makedirs(output_dir_patches, exist_ok=True)

all_predictions = []
all_labels = []

for subfolder in ['abnormal_testing_images', 'normal_testing_images']:
    folder_path = os.path.join(root_dir, subfolder)
    label = 0 if subfolder == 'normal_testing_images' else 1
    for image_file in os.listdir(folder_path):
        if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(folder_path, image_file)
            final_prediction, true_label = extract_and_visualize(image_path, label, model_patch, model_full, transform, output_dir_correct, output_dir_incorrect, output_dir_patches)
            all_predictions.append(final_prediction)
            all_labels.append(true_label)

# Calculate and print accuracy, precision, and reca
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)

print(f"Completed processing. Check the output directories: {output_dir_correct}, {output_dir_incorrect}, and {output_dir_patches}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")


# This one seems to classify patches better. Just work on the weighted voting now.

Completed processing. Check the output directories: Processed_Images_pred_correct_3, Processed_Images_pred_incorrect_3, and Processed_Patches_3
Accuracy: 0.8130
Precision: 0.7528
Recall: 0.9853


In [22]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torchvision.models as models

# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations for testing
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your pre-trained PyTorch models for patches and full images
model_patch = torch.load('Models/Patch_Model_All_Data_DenseNet', map_location=torch.device('cpu'))
model_patch.eval()
model_full = torch.load('Models/denseNet_redo_full_images', map_location=torch.device('cpu'))
model_full.eval()

# Transformation pipeline for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_heatmap(input_tensor, model, target_layers):
    cam = GradCAMPlusPlus(model=model, target_layers=target_layers)
    pred = model(input_tensor)
    _, predicted_class = pred.max(1)
    targets = [ClassifierOutputTarget(predicted_class.item())]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    return grayscale_cam, predicted_class.item(), torch.softmax(pred, dim=1)[0, predicted_class].item()

def adaptive_thresholding(heatmap, predicted_label):
    """Determine the threshold based on the predicted label.
    - If the label is abnormal (1), apply a stricter threshold (top 25% of values).
    - If the label is normal (0), apply a broader threshold (top 50% of values)."""
    threshold_percentile = 75 if predicted_label == 1 else 95
    threshold_value = np.percentile(heatmap, threshold_percentile)
    return (heatmap >= threshold_value).astype('uint8')

def detect_features_in_channel(channel, mask=None):
    features = cv2.goodFeaturesToTrack(channel, mask=mask, maxCorners=100, qualityLevel=0.01, minDistance=150)
    return np.int0(features).reshape(-1, 2) if features is not None else np.array([])

# def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
#     """Perform weighted voting based on the confidence scores."""
#     vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

#     # Calculate the percentage of patches predicted as abnormal
#     abnormal_patches = sum(1 for pred in patch_predictions if pred == 1)
#     normal_patches = len(patch_predictions) - abnormal_patches

#     # Adjust the weight of the full image prediction based on patch consistency
#     if abnormal_patches > normal_patches:
#         # Increase weight of full image prediction if majority of patches are abnormal
#         image_weight = 5 + (abnormal_patches - normal_patches)
#     else:
#         # Decrease weight of full image prediction if majority of patches are normal
#         image_weight = max(1, 5 - (normal_patches - abnormal_patches))

#     if image_pred == 1:
#         vote_count[1] += image_conf * image_weight
#     else:
#         vote_count[0] += (1 - image_conf) * image_weight

#     # Weight adjustments based on patch predictions
#     for pred, conf in zip(patch_predictions, patch_confidences):
#         patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
#         if pred == 1:
#             vote_count[1] += conf * patch_weight
#         else:
#             vote_count[0] += (1 - conf) * patch_weight

#     return 1 if vote_count[1] > vote_count[0] else 0

# def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
#     """Perform weighted voting based on the confidence scores."""
#     vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

#     # Weight adjustments based on the full image prediction
#     image_weight = 5
#     if image_pred == 1:
#         vote_count[1] += image_conf * image_weight
#     else:
#         vote_count[0] += (1 - image_conf) * image_weight 

#     # Weight adjustments based on patch predictions
#     for pred, conf in zip(patch_predictions, patch_confidences):
#         patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
#         if pred == 1:
#             vote_count[1] += conf * patch_weight
#         else:
#             vote_count[0] += (1 - conf) * patch_weight

#     return 1 if vote_count[1] > vote_count[0] else 0

def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
    """Perform weighted voting based on the confidence scores."""
    vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

    # Weight adjustments based on the full image prediction
    image_weight = 2
    if image_pred == 1:
        vote_count[1] += image_conf * image_weight

        for pred, conf in zip(patch_predictions, patch_confidences):
            patch_weight = 2 if conf > 0.8 else 1
            if pred == 1:
                vote_count[1] += conf * patch_weight
            else:
                vote_count[0] += (1 - conf) * patch_weight
    else:
        vote_count[0] += (1 - image_conf) * image_weight

        for pred, conf in zip(patch_predictions, patch_confidences):
            patch_weight = 3 if conf > 0.8 else 2
            if pred == 1:
                vote_count[1] += conf * 1.5
            else:
                vote_count[0] += (1 - conf) * patch_weight

    return 1 if vote_count[1] > vote_count[0] else 0



def save_composite_patches(image_path, extracted_patches, patch_predictions, patch_confidences, output_dir_patches):
    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    num_patches = len(extracted_patches)
    num_cols = 5
    num_rows = (num_patches + num_cols - 1) // num_cols  # Ensure we have enough rows

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 4 * num_rows))
    axs = axs.ravel()  # Flatten the array for easy indexing

    for i, (patch, pred, conf) in enumerate(zip(extracted_patches, patch_predictions, patch_confidences)):
        axs[i].imshow(patch)
        axs[i].axis('off')
        axs[i].set_title(f'Pred: {pred}, Conf: {conf:.2f}')

    for j in range(i + 1, len(axs)):  # Hide any unused subplots
        axs[j].axis('off')

    plt.tight_layout()
    patch_output_filename = os.path.join(output_dir_patches, f"{base_filename}_patches.png")
    plt.savefig(patch_output_filename)
    plt.close()

def extract_and_visualize(image_path, label, model_patch, model_full, transform, correct_dir, incorrect_dir, output_dir_patches):
    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)
    input_tensor = transform(image).unsqueeze(0)
    
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    # Full image attention map
    heatmap_full, pred_full, conf_full = get_heatmap(input_tensor, model_full, [model_full.features.norm5])
    heatmap_resized_full = cv2.resize(heatmap_full, (image_np.shape[1], image_np.shape[0]))
    heatmap_normalized_full = heatmap_resized_full / np.max(heatmap_resized_full)

    # Apply adaptive thresholding based on predicted label
    binary_mask_full = adaptive_thresholding(heatmap_normalized_full, pred_full)

    # Overlay Image
    overlay_img_full = cv2.addWeighted(image_np, 0.6, cv2.applyColorMap(np.uint8(255 * heatmap_normalized_full), cv2.COLORMAP_JET), 0.4, 0)

    # Detect salient points
    salient_points_all = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY))  # Without mask
    salient_points_filtered = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY), binary_mask_full)  # With mask

    # Image with all salient points
    image_with_all_salient_points = image_np.copy()
    for pt in salient_points_all:
        cv2.circle(image_with_all_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with filtered salient points
    image_with_salient_points = image_np.copy()
    for pt in salient_points_filtered:
        cv2.circle(image_with_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with patch classifications
    image_with_classified_patches = image_np.copy()
    patch_image = image_np.copy()
    patch_predictions = []
    patch_confidences = []
    extracted_patches = []  # Store patches for saving
    composite_patches = []
    for pt in salient_points_filtered:
        top_left_x = max(pt[0] - 137, 0)
        top_left_y = max(pt[1] - 150, 0)
        bottom_right_x = min(top_left_x + 275, image_with_classified_patches.shape[1])
        bottom_right_y = min(top_left_y + 300, image_with_classified_patches.shape[0])

        patch = image_with_classified_patches[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        patch_ = patch_image[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        extracted_patches.append(patch.copy())  # Add patch to list
        composite_patches.append(patch_)
        patch_tensor = torch.from_numpy(np.transpose(cv2.resize(patch, (224, 224)), (2, 0, 1)).astype('float32') / 255.0).unsqueeze(0)
        if torch.cuda.is_available():
            patch_tensor = patch_tensor.cuda()
        with torch.no_grad():
            output = model_patch(patch_tensor)
            prob = torch.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            conf = prob[0][pred].item()

        patch_predictions.append(pred.item())
        patch_confidences.append(conf)

        color = (0, 255, 0) if pred.item() == 1 else (0, 0, 255)
        cv2.rectangle(image_with_classified_patches, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), color, 2)
        text = f"{conf:.2f}"  # Confidence score
        cv2.putText(image_with_classified_patches, text, (top_left_x, top_left_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

    final_prediction = weighted_voting(patch_predictions, patch_confidences, pred_full, conf_full)

    # Save composite patches with predictions and confidences
    save_composite_patches(image_path, composite_patches, patch_predictions, patch_confidences, output_dir_patches)

    # 3x2 Grid Visualization
    fig, axs = plt.subplots(3, 2, figsize=(16, 24))
    axs[0, 0].imshow(heatmap_normalized_full, cmap='jet')
    axs[0, 0].axis('off')
    axs[0, 0].set_title(f'Heatmap Full (DenseNet) - Pred: {pred_full}, GT: {label}, Final_Pred: {final_prediction}')

    axs[0, 1].imshow(overlay_img_full)
    axs[0, 1].axis('off')
    axs[0, 1].set_title('Overlay Image Full')

    axs[1, 0].imshow(image_with_all_salient_points)
    axs[1, 0].axis('off')
    axs[1, 0].set_title('All Salient Points')

    axs[1, 1].imshow(image_with_salient_points)
    axs[1, 1].axis('off')
    axs[1, 1].set_title('Filtered Salient Points')

    axs[2, 0].imshow(image_with_classified_patches)
    axs[2, 0].axis('off')
    axs[2, 0].set_title('Patch Classifications')

    # Leave the last grid empty if not needed
    axs[2, 1].axis('off')

    if final_prediction == label:
        output_filename = os.path.join(correct_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")
    else:
        output_filename = os.path.join(incorrect_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")

    plt.savefig(output_filename)
    plt.close()

    return final_prediction, label

# Directory and processing setup
root_dir = 'Data/Testing/Images'
output_dir_correct = 'Processed_Images_pred_correct_4'
output_dir_incorrect = 'Processed_Images_pred_incorrect_4'
output_dir_patches = 'Processed_Patches_4'
os.makedirs(output_dir_correct, exist_ok=True)
os.makedirs(output_dir_incorrect, exist_ok=True)
os.makedirs(output_dir_patches, exist_ok=True)

all_predictions = []
all_labels = []

for subfolder in ['abnormal_testing_images', 'normal_testing_images']:
    folder_path = os.path.join(root_dir, subfolder)
    label = 0 if subfolder == 'normal_testing_images' else 1
    for image_file in os.listdir(folder_path):
        if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(folder_path, image_file)
            final_prediction, true_label = extract_and_visualize(image_path, label, model_patch, model_full, transform, output_dir_correct, output_dir_incorrect, output_dir_patches)
            all_predictions.append(final_prediction)
            all_labels.append(true_label)

# Calculate and print accuracy, precision, and reca
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)

print(f"Completed processing. Check the output directories: {output_dir_correct}, {output_dir_incorrect}, and {output_dir_patches}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

# This one seems to classify patches better. Just work on the weighted voting now.

Completed processing. Check the output directories: Processed_Images_pred_correct_4, Processed_Images_pred_incorrect_4, and Processed_Patches_4
Accuracy: 0.8130
Precision: 0.7528
Recall: 0.9853


In [23]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torchvision.models as models

# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations for testing
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load your pre-trained PyTorch models for patches and full images
model_patch = torch.load('Models/Patch_Model_All_Data_DenseNet_2', map_location=torch.device('cpu'))
model_patch.eval()
model_full = torch.load('Models/denseNet_redo_full_images', map_location=torch.device('cpu'))
model_full.eval()

# Transformation pipeline for image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_heatmap(input_tensor, model, target_layers):
    cam = GradCAMPlusPlus(model=model, target_layers=target_layers)
    pred = model(input_tensor)
    _, predicted_class = pred.max(1)
    targets = [ClassifierOutputTarget(predicted_class.item())]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    return grayscale_cam, predicted_class.item(), torch.softmax(pred, dim=1)[0, predicted_class].item()

def adaptive_thresholding(heatmap, predicted_label):
    """Determine the threshold based on the predicted label.
    - If the label is abnormal (1), apply a stricter threshold (top 25% of values).
    - If the label is normal (0), apply a broader threshold (top 50% of values)."""
    threshold_percentile = 75 if predicted_label == 1 else 95
    threshold_value = np.percentile(heatmap, threshold_percentile)
    return (heatmap >= threshold_value).astype('uint8')

def detect_features_in_channel(channel, mask=None):
    features = cv2.goodFeaturesToTrack(channel, mask=mask, maxCorners=100, qualityLevel=0.01, minDistance=150)
    return np.int0(features).reshape(-1, 2) if features is not None else np.array([])

# def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
#     """Perform weighted voting based on the confidence scores."""
#     vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

#     # Calculate the percentage of patches predicted as abnormal
#     abnormal_patches = sum(1 for pred in patch_predictions if pred == 1)
#     normal_patches = len(patch_predictions) - abnormal_patches

#     # Adjust the weight of the full image prediction based on patch consistency
#     if abnormal_patches > normal_patches:
#         # Increase weight of full image prediction if majority of patches are abnormal
#         image_weight = 5 + (abnormal_patches - normal_patches)
#     else:
#         # Decrease weight of full image prediction if majority of patches are normal
#         image_weight = max(1, 5 - (normal_patches - abnormal_patches))

#     if image_pred == 1:
#         vote_count[1] += image_conf * image_weight
#     else:
#         vote_count[0] += (1 - image_conf) * image_weight

#     # Weight adjustments based on patch predictions
#     for pred, conf in zip(patch_predictions, patch_confidences):
#         patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
#         if pred == 1:
#             vote_count[1] += conf * patch_weight
#         else:
#             vote_count[0] += (1 - conf) * patch_weight

#     return 1 if vote_count[1] > vote_count[0] else 0

# def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
#     """Perform weighted voting based on the confidence scores."""
#     vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

#     # Weight adjustments based on the full image prediction
#     image_weight = 5
#     if image_pred == 1:
#         vote_count[1] += image_conf * image_weight
#     else:
#         vote_count[0] += (1 - image_conf) * image_weight 

#     # Weight adjustments based on patch predictions
#     for pred, conf in zip(patch_predictions, patch_confidences):
#         patch_weight = 2 if conf > 0.8 else 1  # Give higher weight to high confidence abnormal patches
#         if pred == 1:
#             vote_count[1] += conf * patch_weight
#         else:
#             vote_count[0] += (1 - conf) * patch_weight

#     return 1 if vote_count[1] > vote_count[0] else 0

def weighted_voting(patch_predictions, patch_confidences, image_pred, image_conf):
    """Perform weighted voting based on the confidence scores."""
    vote_count = {0: 0, 1: 0}  # Initialize voting counters for each class

    # Weight adjustments based on the full image prediction
    image_weight = 2
    if image_pred == 1:
        vote_count[1] += image_conf * image_weight

        for pred, conf in zip(patch_predictions, patch_confidences):
            patch_weight = 2 if conf > 0.8 else 1
            if pred == 1:
                vote_count[1] += conf * patch_weight
            else:
                vote_count[0] += (1 - conf) * patch_weight
    else:
        vote_count[0] += (1 - image_conf) * image_weight

        for pred, conf in zip(patch_predictions, patch_confidences):
            patch_weight = 3 if conf > 0.8 else 2
            if pred == 1:
                vote_count[1] += conf * 1.5
            else:
                vote_count[0] += (1 - conf) * patch_weight

    return 1 if vote_count[1] > vote_count[0] else 0



def save_composite_patches(image_path, extracted_patches, patch_predictions, patch_confidences, output_dir_patches):
    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    num_patches = len(extracted_patches)
    num_cols = 5
    num_rows = (num_patches + num_cols - 1) // num_cols  # Ensure we have enough rows

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 4 * num_rows))
    axs = axs.ravel()  # Flatten the array for easy indexing

    for i, (patch, pred, conf) in enumerate(zip(extracted_patches, patch_predictions, patch_confidences)):
        axs[i].imshow(patch)
        axs[i].axis('off')
        axs[i].set_title(f'Pred: {pred}, Conf: {conf:.2f}')

    for j in range(i + 1, len(axs)):  # Hide any unused subplots
        axs[j].axis('off')

    plt.tight_layout()
    patch_output_filename = os.path.join(output_dir_patches, f"{base_filename}_patches.png")
    plt.savefig(patch_output_filename)
    plt.close()

def extract_and_visualize(image_path, label, model_patch, model_full, transform, correct_dir, incorrect_dir, output_dir_patches):
    image = Image.open(image_path).convert('RGB')
    image_np = np.array(image)
    input_tensor = transform(image).unsqueeze(0)
    
    if torch.cuda.is_available():
        input_tensor = input_tensor.cuda()

    # Full image attention map
    heatmap_full, pred_full, conf_full = get_heatmap(input_tensor, model_full, [model_full.features.norm5])
    heatmap_resized_full = cv2.resize(heatmap_full, (image_np.shape[1], image_np.shape[0]))
    heatmap_normalized_full = heatmap_resized_full / np.max(heatmap_resized_full)

    # Apply adaptive thresholding based on predicted label
    binary_mask_full = adaptive_thresholding(heatmap_normalized_full, pred_full)

    # Overlay Image
    overlay_img_full = cv2.addWeighted(image_np, 0.6, cv2.applyColorMap(np.uint8(255 * heatmap_normalized_full), cv2.COLORMAP_JET), 0.4, 0)

    # Detect salient points
    salient_points_all = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY))  # Without mask
    salient_points_filtered = detect_features_in_channel(cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY), binary_mask_full)  # With mask

    # Image with all salient points
    image_with_all_salient_points = image_np.copy()
    for pt in salient_points_all:
        cv2.circle(image_with_all_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with filtered salient points
    image_with_salient_points = image_np.copy()
    for pt in salient_points_filtered:
        cv2.circle(image_with_salient_points, (pt[0], pt[1]), 5, (0, 0, 255), 5)

    # Image with patch classifications
    image_with_classified_patches = image_np.copy()
    patch_image = image_np.copy()
    patch_predictions = []
    patch_confidences = []
    extracted_patches = []  # Store patches for saving
    composite_patches = []
    for pt in salient_points_filtered:
        top_left_x = max(pt[0] - 137, 0)
        top_left_y = max(pt[1] - 150, 0)
        bottom_right_x = min(top_left_x + 275, image_with_classified_patches.shape[1])
        bottom_right_y = min(top_left_y + 300, image_with_classified_patches.shape[0])

        patch = image_with_classified_patches[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        patch_ = patch_image[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
        extracted_patches.append(patch.copy())  # Add patch to list
        composite_patches.append(patch_)
        patch_tensor = torch.from_numpy(np.transpose(cv2.resize(patch, (224, 224)), (2, 0, 1)).astype('float32') / 255.0).unsqueeze(0)
        if torch.cuda.is_available():
            patch_tensor = patch_tensor.cuda()
        with torch.no_grad():
            output = model_patch(patch_tensor)
            prob = torch.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            conf = prob[0][pred].item()

        patch_predictions.append(pred.item())
        patch_confidences.append(conf)

        color = (0, 255, 0) if pred.item() == 1 else (0, 0, 255)
        cv2.rectangle(image_with_classified_patches, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), color, 2)
        text = f"{conf:.2f}"  # Confidence score
        cv2.putText(image_with_classified_patches, text, (top_left_x, top_left_y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)

    final_prediction = weighted_voting(patch_predictions, patch_confidences, pred_full, conf_full)

    # Save composite patches with predictions and confidences
    save_composite_patches(image_path, composite_patches, patch_predictions, patch_confidences, output_dir_patches)

    # 3x2 Grid Visualization
    fig, axs = plt.subplots(3, 2, figsize=(16, 24))
    axs[0, 0].imshow(heatmap_normalized_full, cmap='jet')
    axs[0, 0].axis('off')
    axs[0, 0].set_title(f'Heatmap Full (DenseNet) - Pred: {pred_full}, GT: {label}, Final_Pred: {final_prediction}')

    axs[0, 1].imshow(overlay_img_full)
    axs[0, 1].axis('off')
    axs[0, 1].set_title('Overlay Image Full')

    axs[1, 0].imshow(image_with_all_salient_points)
    axs[1, 0].axis('off')
    axs[1, 0].set_title('All Salient Points')

    axs[1, 1].imshow(image_with_salient_points)
    axs[1, 1].axis('off')
    axs[1, 1].set_title('Filtered Salient Points')

    axs[2, 0].imshow(image_with_classified_patches)
    axs[2, 0].axis('off')
    axs[2, 0].set_title('Patch Classifications')

    # Leave the last grid empty if not needed
    axs[2, 1].axis('off')

    if final_prediction == label:
        output_filename = os.path.join(correct_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")
    else:
        output_filename = os.path.join(incorrect_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}_composite.png")

    plt.savefig(output_filename)
    plt.close()

    return final_prediction, label

# Directory and processing setup
root_dir = 'Data/Testing/Images'
output_dir_correct = 'Processed_Images_pred_correct_5'
output_dir_incorrect = 'Processed_Images_pred_incorrect_5'
output_dir_patches = 'Processed_Patches_5'
os.makedirs(output_dir_correct, exist_ok=True)
os.makedirs(output_dir_incorrect, exist_ok=True)
os.makedirs(output_dir_patches, exist_ok=True)

all_predictions = []
all_labels = []

for subfolder in ['abnormal_testing_images', 'normal_testing_images']:
    folder_path = os.path.join(root_dir, subfolder)
    label = 0 if subfolder == 'normal_testing_images' else 1
    for image_file in os.listdir(folder_path):
        if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(folder_path, image_file)
            final_prediction, true_label = extract_and_visualize(image_path, label, model_patch, model_full, transform, output_dir_correct, output_dir_incorrect, output_dir_patches)
            all_predictions.append(final_prediction)
            all_labels.append(true_label)

# Calculate and print accuracy, precision, and reca
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
recall = recall_score(all_labels, all_predictions)

print(f"Completed processing. Check the output directories: {output_dir_correct}, {output_dir_incorrect}, and {output_dir_patches}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

# This one seems to classify patches better. Just work on the weighted voting now.

Completed processing. Check the output directories: Processed_Images_pred_correct_5, Processed_Images_pred_incorrect_5, and Processed_Patches_5
Accuracy: 0.8089
Precision: 0.7514
Recall: 0.9779
