In [None]:
import os
import random
import cv2
import numpy as np
from read_roi import read_roi_file
import matplotlib.pyplot as plt
import logging

# Root-logger setup for the whole script: emit INFO and above, each record
# prefixed with a timestamp and the level name. Per the standard
# basicConfig contract, this is a no-op if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

def get_roi_coordinates(roi_file_path):
    """Return the centre of the first rectangular ROI in an ImageJ ``.roi`` file.

    Args:
        roi_file_path: Path to the ``.roi`` file, parsed with ``read_roi_file``.

    Returns:
        ``(center_x, center_y)`` computed as ``left + width // 2`` and
        ``top + height // 2`` for the first entry whose ``'type'`` is
        ``'rectangle'``, or ``None`` if no entry is a rectangle.
    """
    # Lazily filter so entries after the first rectangle are never inspected.
    rectangles = (
        entry
        for entry in read_roi_file(roi_file_path).values()
        if entry['type'] == 'rectangle'
    )
    first_rect = next(rectangles, None)
    if first_rect is None:
        return None
    return (
        first_rect['left'] + first_rect['width'] // 2,
        first_rect['top'] + first_rect['height'] // 2,
    )

def visualize_and_save(image, output_path, title):
    """Render ``image`` with a title using Matplotlib and save it to ``output_path``.

    Args:
        image: BGR image array (OpenCV channel order).
        output_path: Destination file passed to ``plt.savefig``.
        title: Title drawn above the image.
    """
    figure = plt.figure(figsize=(10, 5))
    # OpenCV stores channels as BGR while Matplotlib expects RGB.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb_image)
    plt.title(title)
    plt.axis('off')
    plt.savefig(output_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(figure)

def extract_patches(image, center_x, center_y, ideal_patch_size, output_folder, base_filename):
    """Crop one patch around ``(center_x, center_y)``, save it, and return a preview.

    The crop window is shifted (not truncated) so that it stays inside the
    image. Every patch therefore has exactly ``ideal_patch_size`` whenever the
    image is at least that large. Only an image smaller than the patch yields
    a smaller crop.

    Args:
        image: Image array as returned by ``cv2.imread``.
        center_x: Desired horizontal centre of the patch, in pixels.
        center_y: Desired vertical centre of the patch, in pixels.
        ideal_patch_size: ``(height, width)`` of the patch.
        output_folder: Directory the patch is written into.
        base_filename: Stem of the output file ``<stem>_patch.png``.

    Returns:
        A copy of ``image`` with the crop window outlined in green.
    """
    img_h, img_w = image.shape[:2]
    patch_h, patch_w = ideal_patch_size
    # int() guards against float centres, which break slicing and cv2.rectangle.
    cx, cy = int(center_x), int(center_y)

    # Clamp the top-left corner into [0, img - patch] on each axis. Clamping
    # only the low end would shift windows near the left/top edge but silently
    # cut short windows near the right/bottom edge.
    start_x = min(max(cx - patch_w // 2, 0), max(img_w - patch_w, 0))
    start_y = min(max(cy - patch_h // 2, 0), max(img_h - patch_h, 0))
    end_x = min(start_x + patch_w, img_w)
    end_y = min(start_y + patch_h, img_h)

    patch = image[start_y:end_y, start_x:end_x]
    patch_path = os.path.join(output_folder, f"{base_filename}_patch.png")
    if not cv2.imwrite(patch_path, patch):
        logging.warning(f"Failed to write patch: {patch_path}")

    visualization_image = image.copy()
    cv2.rectangle(visualization_image, (start_x, start_y), (end_x, end_y), (0, 255, 0), 3)
    return visualization_image

def extract_normal_patches(image, ideal_patch_size, output_folder, base_filename,
                           num_patches=2, max_attempts=100):
    """Save up to ``num_patches`` non-overlapping random patches and return a preview.

    Candidate centres are drawn uniformly from the central half of the image
    (``[w//4, 3*w//4]`` x ``[h//4, 3*h//4]``). A candidate is kept only if its
    window does not overlap any window already kept. Sampling stops after
    ``num_patches`` patches have been kept or ``max_attempts`` candidates have
    been drawn, whichever comes first.

    Args:
        image: Image array as returned by ``cv2.imread``.
        ideal_patch_size: ``(height, width)`` of each patch.
        output_folder: Directory the patches are written into.
        base_filename: Stem of the output files ``<stem>_normal_patch_<k>.png``,
            where ``k`` starts at 1.
        num_patches: Number of patches to try to extract (default 2).
        max_attempts: Maximum number of random candidates to draw (default 100).

    Returns:
        A copy of ``image`` with every kept window outlined in green.
    """
    img_h, img_w = image.shape[:2]
    patch_h, patch_w = ideal_patch_size
    visualization_image = image.copy()
    kept_boxes = []  # (x, y, width, height) of each saved window

    for _ in range(max_attempts):
        if len(kept_boxes) >= num_patches:
            break
        center_x = random.randint(img_w // 4, 3 * img_w // 4)
        center_y = random.randint(img_h // 4, 3 * img_h // 4)

        # Shift the window inside the image rather than truncating it at the
        # right/bottom edge, so patches keep their full size whenever possible.
        start_x = min(max(center_x - patch_w // 2, 0), max(img_w - patch_w, 0))
        start_y = min(max(center_y - patch_h // 2, 0), max(img_h - patch_h, 0))
        end_x = min(start_x + patch_w, img_w)
        end_y = min(start_y + patch_h, img_h)

        # Axis-aligned rectangle intersection test against every kept window.
        overlaps = any(
            start_x < bx + bw and end_x > bx and start_y < by + bh and end_y > by
            for bx, by, bw, bh in kept_boxes
        )
        if overlaps:
            continue

        kept_boxes.append((start_x, start_y, end_x - start_x, end_y - start_y))
        patch = image[start_y:end_y, start_x:end_x]
        patch_path = os.path.join(
            output_folder, f"{base_filename}_normal_patch_{len(kept_boxes)}.png"
        )
        if not cv2.imwrite(patch_path, patch):
            logging.warning(f"Failed to write patch: {patch_path}")
        cv2.rectangle(visualization_image, (start_x, start_y), (end_x, end_y), (0, 255, 0), 3)

    if len(kept_boxes) < num_patches:
        logging.warning(
            f"Only extracted {len(kept_boxes)} of {num_patches} normal patches for {base_filename}"
        )
    return visualization_image

def process_image(folder, output_folder, visualization_folder, ideal_patch_size=(275, 300), is_abnormal=True):
    """Extract patches from every image in ``folder`` and save one preview per image.

    Args:
        folder: Directory scanned (non-recursively) for ``.jpg``/``.jpeg``/``.png``
            files. The extension match is case-insensitive.
        output_folder: Destination for the patch PNGs (created if missing).
        visualization_folder: Destination for the ``<stem>_visualization.png``
            previews (created if missing).
        ideal_patch_size: ``(height, width)`` of each patch.
        is_abnormal: If true, one patch is cropped around the first rectangle in
            the sibling ``<stem>.roi`` file. Otherwise random "normal" patches
            are taken.

    The following are skipped with a warning instead of aborting the run:
    unreadable images, abnormal images without a ``.roi`` file, and ``.roi``
    files that contain no rectangle.
    """
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(visualization_folder, exist_ok=True)

    # Sorted so the processing order (and the random patch draws) is reproducible.
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        image_path = os.path.join(folder, filename)
        image = cv2.imread(image_path)
        if image is None:
            logging.warning(f"Failed to read image: {image_path}")
            continue

        base_filename = os.path.splitext(filename)[0]
        if is_abnormal:
            roi_path = os.path.splitext(image_path)[0] + ".roi"
            if not os.path.exists(roi_path):
                # Skip explicitly. Falling through would save an undefined or
                # stale visualization_image left over from a previous image.
                logging.warning(f"No ROI file for abnormal image: {image_path}")
                continue
            center = get_roi_coordinates(roi_path)
            if center is None:
                logging.warning(f"No rectangle ROI found in: {roi_path}")
                continue
            center_x, center_y = center
            visualization_image = extract_patches(
                image, center_x, center_y, ideal_patch_size, output_folder, base_filename
            )
        else:
            visualization_image = extract_normal_patches(
                image, ideal_patch_size, output_folder, base_filename
            )

        visualization_path = os.path.join(visualization_folder, f"{base_filename}_visualization.png")
        visualize_and_save(visualization_image, visualization_path, "Extracted Patches")

def process_dirs(data_dir, output_dir, visualization_dir, ideal_patch_size=(275, 300)):
    """Run ``process_image`` on the ``normal`` subfolder, then the ``abnormal`` one.

    For each class name, images are read from ``<data_dir>/<class>``. Results
    go to ``<output_dir>/<class>`` and ``<visualization_dir>/<class>``. Only
    the ``abnormal`` class uses ROI-centred extraction.
    """
    for class_type in ('normal', 'abnormal'):
        process_image(
            os.path.join(data_dir, class_type),
            os.path.join(output_dir, class_type),
            os.path.join(visualization_dir, class_type),
            ideal_patch_size,
            is_abnormal=class_type == 'abnormal',
        )

# Example usage. This is guarded so that importing this module has no side
# effects. It still runs when the file is executed as a script, and in
# IPython/Jupyter, where __name__ is also "__main__".
if __name__ == "__main__":
    data_dir = 'Data'
    output_dir = 'output_patches_new'
    visualization_dir = 'visualizations_patches_new'

    process_dirs(data_dir, output_dir, visualization_dir)