# In [12]:
import os
import random
import cv2
import numpy as np
from read_roi import read_roi_file
import matplotlib.pyplot as plt
import logging

# Root-logger setup: emit INFO and above, each line prefixed with its timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def get_roi_coordinates(roi_file_path):
    """Return centre and size of the first rectangle ROI in an ImageJ ``.roi`` file.

    Args:
        roi_file_path: Path to a file readable by ``read_roi_file``.

    Returns:
        ``(center_x, center_y, width, height)`` for the first entry whose
        ``'type'`` is ``'rectangle'``; the centre is ``left + width // 2`` /
        ``top + height // 2``. Returns ``None`` when no rectangle is present.
    """
    rectangles = (
        entry for entry in read_roi_file(roi_file_path).values()
        if entry['type'] == 'rectangle'
    )
    rect = next(rectangles, None)
    if rect is None:
        return None
    center_x = rect['left'] + rect['width'] // 2
    center_y = rect['top'] + rect['height'] // 2
    return center_x, center_y, rect['width'], rect['height']

def visualize_and_save(image, output_path, title):
    """Write ``image`` to ``output_path`` as a titled, axis-free matplotlib figure.

    Args:
        image: Image in OpenCV's BGR channel order. It is converted to RGB
            before display because matplotlib interprets channels as RGB.
        output_path: File path handed to ``savefig``.
        title: Caption drawn above the image.
    """
    fig = plt.figure(figsize=(10, 5))
    ax = fig.gca()
    ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    ax.set_title(title)
    ax.axis('off')
    fig.savefig(output_path)
    # Release the figure immediately; this runs once per image inside a loop.
    plt.close(fig)

def extract_patches(image, center_x, center_y, ideal_patch_size, output_folder, base_filename):
    """Save up to five patches around an ROI centre and return an annotated copy.

    One patch is centred on ``(center_x, center_y)``. Four more are centred half a
    patch away diagonally (top-left, top-right, bottom-left, bottom-right).
    Every window is shifted back inside the image on all four sides, so an ROI
    near any border still yields full-size patches.

    Args:
        image: Image array indexed ``[row, col, ...]``.
        center_x: ROI centre column. Non-integer values are truncated to int.
        center_y: ROI centre row. Non-integer values are truncated to int.
        ideal_patch_size: ``(patch_height, patch_width)`` in pixels.
        output_folder: Directory the patches are written to.
        base_filename: Prefix for file names; patches are saved as
            ``<base_filename>_patch_<i>.png`` with ``i`` in ``0..4``.

    Returns:
        A copy of ``image`` with a green rectangle around every saved patch.
        If the image is smaller than one patch, nothing is saved and the
        copy is returned unannotated.
    """
    patch_h, patch_w = ideal_patch_size[0], ideal_patch_size[1]
    img_h, img_w = image.shape[:2]
    visualization_image = image.copy()
    if patch_w > img_w or patch_h > img_h:
        # Not even one full patch fits in this image.
        return visualization_image

    # Slicing and cv2 drawing both require integer coordinates.
    center_x, center_y = int(center_x), int(center_y)
    half_w, half_h = patch_w // 2, patch_h // 2
    positions = [
        (center_x, center_y),                    # Center
        (center_x - half_w, center_y - half_h),  # Top left
        (center_x + half_w, center_y - half_h),  # Top right
        (center_x - half_w, center_y + half_h),  # Bottom left
        (center_x + half_w, center_y + half_h),  # Bottom right
    ]

    for i, (pos_x, pos_y) in enumerate(positions):
        # Clamp the window's top-left corner into [0, dim - patch] on both axes.
        start_x = min(max(pos_x - half_w, 0), img_w - patch_w)
        start_y = min(max(pos_y - half_h, 0), img_h - patch_h)
        end_x = start_x + patch_w
        end_y = start_y + patch_h

        patch = image[start_y:end_y, start_x:end_x]
        patch_name = f"{base_filename}_patch_{i}.png"
        cv2.imwrite(os.path.join(output_folder, patch_name), patch)
        # cv2.rectangle's second corner is inclusive, hence the -1.
        cv2.rectangle(visualization_image, (start_x, start_y), (end_x - 1, end_y - 1), (0, 255, 0), 3)

    return visualization_image

def extract_normal_patches(image, ideal_patch_size, output_folder, base_filename, num_normal_patches=10, max_attempts=100):
    """Save randomly placed, mutually non-overlapping patches from an image.

    Candidate windows are drawn uniformly among all positions that lie fully
    inside the image. A candidate overlapping an already-saved patch is
    rejected. Sampling stops after ``num_normal_patches`` patches or
    ``max_attempts`` candidates, whichever comes first.

    Args:
        image: Image array indexed ``[row, col, ...]``.
        ideal_patch_size: ``(patch_height, patch_width)`` in pixels.
        output_folder: Directory the patches are written to.
        base_filename: Prefix for file names; patches are saved as
            ``<base_filename>_normal_patch_<k>.png`` with ``k`` starting at 1.
        num_normal_patches: Maximum number of patches to save.
        max_attempts: Maximum number of random candidates to evaluate.

    Returns:
        A copy of ``image`` with a green rectangle around every saved patch.
        If the image is smaller than one patch, a warning is logged and the
        copy is returned unannotated.
    """
    img_h, img_w = image.shape[:2]
    patch_h, patch_w = ideal_patch_size[0], ideal_patch_size[1]
    visualization_image = image.copy()
    if patch_w > img_w or patch_h > img_h:
        logging.warning("Image %s is smaller than patch size %s; no normal patches extracted",
                        base_filename, ideal_patch_size)
        return visualization_image

    boxes = []  # (x, y, w, h) of every patch saved so far
    for _ in range(max_attempts):
        if len(boxes) >= num_normal_patches:
            break
        # Sample the top-left corner directly, so every candidate fits and
        # every draw counts toward max_attempts.
        start_x = random.randint(0, img_w - patch_w)
        start_y = random.randint(0, img_h - patch_h)
        end_x = start_x + patch_w
        end_y = start_y + patch_h

        overlaps_saved = any(
            start_x < bx + bw and end_x > bx and start_y < by + bh and end_y > by
            for bx, by, bw, bh in boxes
        )
        if overlaps_saved:
            continue

        boxes.append((start_x, start_y, patch_w, patch_h))
        patch = image[start_y:end_y, start_x:end_x]
        patch_name = f"{base_filename}_normal_patch_{len(boxes)}.png"
        cv2.imwrite(os.path.join(output_folder, patch_name), patch)
        # cv2.rectangle's second corner is inclusive, hence the -1.
        cv2.rectangle(visualization_image, (start_x, start_y), (end_x - 1, end_y - 1), (0, 255, 0), 3)

    return visualization_image

def extract_normal_patches_from_abnormal(image, roi_center_x, roi_center_y, roi_width, roi_height, ideal_patch_size, output_folder, base_filename, num_normal_patches=10, max_attempts=100):
    """Save random patches from an abnormal image that avoid its ROI and each other.

    Candidate windows are drawn uniformly among all positions that lie fully
    inside the image. A candidate is rejected if it intersects the ROI box or
    an already-saved patch. Sampling stops after ``num_normal_patches``
    patches or ``max_attempts`` candidates, whichever comes first.

    Args:
        image: Image array indexed ``[row, col, ...]``.
        roi_center_x: ROI centre column, expected to be
            ``left + width // 2`` (as ``get_roi_coordinates`` computes it).
        roi_center_y: ROI centre row, expected to be ``top + height // 2``.
        roi_width: ROI width in pixels.
        roi_height: ROI height in pixels.
        ideal_patch_size: ``(patch_height, patch_width)`` in pixels.
        output_folder: Directory the patches are written to.
        base_filename: Prefix for file names; patches are saved as
            ``<base_filename>_normal_patch_<k>.png`` with ``k`` starting at 1.
        num_normal_patches: Maximum number of patches to save.
        max_attempts: Maximum number of random candidates to evaluate.

    Returns:
        A copy of ``image`` with a green rectangle around every saved patch.
        If the image is smaller than one patch, a warning is logged and the
        copy is returned unannotated.
    """
    img_h, img_w = image.shape[:2]
    patch_h, patch_w = ideal_patch_size[0], ideal_patch_size[1]
    visualization_image = image.copy()
    if patch_w > img_w or patch_h > img_h:
        logging.warning("Image %s is smaller than patch size %s; no normal patches extracted",
                        base_filename, ideal_patch_size)
        return visualization_image

    # Recover the ROI box once (it does not change per candidate). Using
    # start + size keeps the box exact for odd ROI sizes.
    roi_start_x = roi_center_x - roi_width // 2
    roi_start_y = roi_center_y - roi_height // 2
    roi_end_x = roi_start_x + roi_width
    roi_end_y = roi_start_y + roi_height

    boxes = []  # (x, y, w, h) of every patch saved so far
    for _ in range(max_attempts):
        if len(boxes) >= num_normal_patches:
            break
        # Sample the top-left corner directly, so every candidate fits and
        # every draw counts toward max_attempts.
        start_x = random.randint(0, img_w - patch_w)
        start_y = random.randint(0, img_h - patch_h)
        end_x = start_x + patch_w
        end_y = start_y + patch_h

        overlaps_roi = (start_x < roi_end_x and end_x > roi_start_x
                        and start_y < roi_end_y and end_y > roi_start_y)
        if overlaps_roi:
            continue
        overlaps_saved = any(
            start_x < bx + bw and end_x > bx and start_y < by + bh and end_y > by
            for bx, by, bw, bh in boxes
        )
        if overlaps_saved:
            continue

        boxes.append((start_x, start_y, patch_w, patch_h))
        patch = image[start_y:end_y, start_x:end_x]
        patch_name = f"{base_filename}_normal_patch_{len(boxes)}.png"
        cv2.imwrite(os.path.join(output_folder, patch_name), patch)
        # cv2.rectangle's second corner is inclusive, hence the -1.
        cv2.rectangle(visualization_image, (start_x, start_y), (end_x - 1, end_y - 1), (0, 255, 0), 3)

    return visualization_image

def process_image(folder, normal_output_folder, abnormal_output_folder, visualization_folder, ideal_patch_size=(275, 300), is_abnormal=True, num_normal_patches=10):
    """Extract patches from every image in ``folder`` and save one visualization per image.

    Abnormal images are processed using the rectangle ROI stored next to the
    image (same stem, ``.roi`` suffix):
    - ROI-centred patches go to ``abnormal_output_folder``.
    - Random patches that avoid the ROI go to ``normal_output_folder``.

    Normal images yield only random non-overlapping patches, written to
    ``normal_output_folder``.

    Args:
        folder: Directory scanned (non-recursively) for ``.jpg``/``.jpeg``/``.png``
            files. The extension match is case-insensitive.
        normal_output_folder: Destination for normal patches (created if missing).
        abnormal_output_folder: Destination for abnormal patches (created if missing).
        visualization_folder: Destination for ``<stem>_visualization.png`` files
            (created if missing).
        ideal_patch_size: ``(patch_height, patch_width)`` in pixels.
        is_abnormal: Whether ``folder`` holds ROI-annotated abnormal images.
        num_normal_patches: Maximum number of random normal patches per image.

    Unreadable images, and abnormal images without a usable rectangle ROI,
    are logged and skipped.
    """
    for directory in (normal_output_folder, abnormal_output_folder, visualization_folder):
        os.makedirs(directory, exist_ok=True)

    # Sorted so processing order (and thus random placement under a fixed seed) is reproducible.
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        image_path = os.path.join(folder, filename)
        image = cv2.imread(image_path)
        if image is None:
            logging.warning(f"Failed to read image: {image_path}")
            continue

        base_filename = os.path.splitext(filename)[0]
        if is_abnormal:
            roi_path = os.path.splitext(image_path)[0] + ".roi"
            # These skips are required: falling through would save a
            # visualization that does not belong to this image.
            if not os.path.exists(roi_path):
                logging.warning("No ROI file found for abnormal image, skipping: %s", image_path)
                continue
            roi = get_roi_coordinates(roi_path)
            if roi is None:
                logging.warning("No rectangle ROI in %s, skipping: %s", roi_path, image_path)
                continue
            center_x, center_y, width, height = roi
            abnormal_vis = extract_patches(image, center_x, center_y, ideal_patch_size, abnormal_output_folder, base_filename)
            normal_vis = extract_normal_patches_from_abnormal(
                image, center_x, center_y, width, height, ideal_patch_size,
                normal_output_folder, base_filename, num_normal_patches=num_normal_patches,
            )
            # Blend both annotated copies so every saved box appears in one image.
            visualization_image = cv2.addWeighted(abnormal_vis, 0.5, normal_vis, 0.5, 0)
        else:
            visualization_image = extract_normal_patches(
                image, ideal_patch_size, normal_output_folder, base_filename,
                num_normal_patches=num_normal_patches,
            )

        visualization_path = os.path.join(visualization_folder, f"{base_filename}_visualization.png")
        visualize_and_save(visualization_image, visualization_path, "Extracted Patches")

def process_dirs(data_dir, output_dir, visualization_dir, ideal_patch_size=(275, 300)):
    """Run ``process_image`` on ``data_dir/abnormal`` and then ``data_dir/normal``.

    Patches from both classes are pooled into ``<output_dir>/normal`` and
    ``<output_dir>/abnormal``. Abnormal images also contribute normal
    patches, taken from outside their ROI. Visualizations are kept per
    source class under ``<visualization_dir>/<class>``.
    """
    patch_dirs = {label: os.path.join(output_dir, label) for label in ('normal', 'abnormal')}
    for class_type in ('abnormal', 'normal'):
        process_image(
            os.path.join(data_dir, class_type),
            patch_dirs['normal'],
            patch_dirs['abnormal'],
            os.path.join(visualization_dir, class_type),
            ideal_patch_size,
            is_abnormal=class_type == 'abnormal',
        )

# Example usage: read images from src/Data/{abnormal,normal} and write the
# extracted patches and their visualizations to the folders below.
data_dir = 'src/Data'
output_dir = 'output_patches_new_3'
visualization_dir = 'visualizations_patches_new_3'

process_dirs(data_dir=data_dir, output_dir=output_dir, visualization_dir=visualization_dir)
