# In [None]:  (Jupyter cell prompt left over from notebook export)
# Script 1 (Revised v4): Large Image Background Extraction & Layout Analysis
# Added per-file statistics printing. Parameters at the end.
# 
# 

import os
import json
from PIL import Image, ImageStat, UnidentifiedImageError # Pillow for large images
import numpy as np
import cv2 # OpenCV for greyscale conversion if needed
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm # Use tqdm.notebook for Jupyter!
import math
import warnings
from collections import defaultdict

# Pillow guards against extremely large images (DecompressionBombWarning / 
# DecompressionBombError). The line below removes that guard; only enable it
# for trusted input.
# Image.MAX_IMAGE_PIXELS = None # Uncomment if needed, use with caution

# Silence UserWarning messages whose originating module name starts with 'cv2'
# (the `module` argument is matched as a regular expression).
warnings.filterwarnings("ignore", category=UserWarning, module='cv2')

# --- Helper Functions ---
# (Helper functions: get_bbox_from_points, check_object_validity, 
#  parse_labelme_json_and_analyze, check_slice_overlap, 
#  is_slice_predominantly_black, save_analysis_plots remain the same as v3)
# --- Re-including them here for completeness of the script block ---

def get_bbox_from_points(points):
    """Return the axis-aligned bounding box enclosing a sequence of points.

    Args:
        points: Sequence of point-like items where index 0 is the x value
            and index 1 is the y value.

    Returns:
        Tuple ``(x1, y1, x2, y2)``; every value is passed through ``int()``,
        so fractional coordinates are truncated toward zero. ``(0, 0, 0, 0)``
        is returned for empty input or when a point is malformed.
    """
    empty_box = (0, 0, 0, 0)
    if not points or len(points) < 1:
        return empty_box

    try:
        xs = [pt[0] for pt in points]
        ys = [pt[1] for pt in points]
        left, top = min(xs), min(ys)
        right, bottom = max(xs), max(ys)
        return int(left), int(top), int(right), int(bottom)
    except (IndexError, TypeError):
        # A point was too short, not subscriptable, or not comparable/convertible.
        return empty_box

def check_object_validity(image_pil, bbox, black_threshold):
    """Decide whether an object is valid by sampling the pixel at its bbox center.

    Args:
        image_pil: Image-like object exposing ``size`` as ``(width, height)``
            and ``getpixel((x, y))``.
        bbox: Sequence ``(x1, y1, x2, y2)``.
        black_threshold: Pixel values less than or equal to this are 'black'.

    Returns:
        False when the bbox is degenerate, its center lies outside the image,
        or the sampled value is not strictly greater than ``black_threshold``.
        True otherwise, including when the pixel format is unrecognised or any
        error occurs while sampling (the check fails open).
    """
    try:
        left, top, right, bottom = bbox
        if left >= right or top >= bottom:
            return False  # Degenerate box

        cx = int((left + right) / 2)
        cy = int((top + bottom) / 2)

        # Sampling outside the image is not possible; treat as invalid.
        width, height = image_pil.size
        center_inside = 0 <= cx < width and 0 <= cy < height
        if not center_inside:
            return False

        sample = image_pil.getpixel((cx, cy))

        if isinstance(sample, (int, float)):
            # Single-channel value
            intensity = sample
        elif isinstance(sample, (tuple, list)) and len(sample) >= 1:
            # Multi-channel value: only the first channel is inspected
            # (assumes R=G=B or that the first channel is representative).
            intensity = sample[0]
        else:
            return True  # Unknown pixel format, assume not black

        # Strictly brighter than the threshold means "not black".
        return intensity > black_threshold

    except Exception:
        return True  # Default to valid if sampling fails

def parse_labelme_json_and_analyze(json_path, image_pil, black_threshold_object_check):
    """Read a LabelMe JSON file, classify each shape, and gather layout stats.

    Args:
        json_path: Path of the LabelMe annotation file.
        image_pil: Image the annotations belong to; its ``size`` is used to
            clip boxes and it is sampled by ``check_object_validity``.
        black_threshold_object_check: Threshold forwarded to
            ``check_object_validity``.

    Returns:
        ``(object_list, image_stats)``. ``object_list`` holds one dict per
        usable shape with keys ``'label'``, ``'bbox'`` and ``'is_valid'``.
        ``image_stats`` is a ``defaultdict(list)`` with per-object lists
        (valid objects only) plus the integer entries ``"num_valid_objects"``
        and ``"num_total_objects"``. If the JSON cannot be read, an empty list
        and an empty stats mapping are returned.
    """
    image_stats = defaultdict(list)
    object_list = []
    img_w, img_h = image_pil.size

    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        print(f"  Error reading JSON {json_path}: {e}")
        return [], image_stats

    for shape in data.get('shapes', []):
        points = shape.get('points', [])
        label = shape.get('label', 'unknown')
        if not points:
            continue

        raw_x1, raw_y1, raw_x2, raw_y2 = get_bbox_from_points(points)

        # Clip the box to the image extent; drop it if nothing is left.
        x1 = max(0, raw_x1)
        y1 = max(0, raw_y1)
        x2 = min(img_w, raw_x2)
        y2 = min(img_h, raw_y2)
        if x1 >= x2 or y1 >= y2:
            continue

        bbox = [x1, y1, x2, y2]
        is_valid = check_object_validity(image_pil, bbox, black_threshold_object_check)
        object_list.append({'label': label, 'bbox': bbox, 'is_valid': is_valid})

        # Statistics are restricted to objects that passed the validity check.
        if not is_valid:
            continue
        width, height = x2 - x1, y2 - y1
        if width <= 0 or height <= 0:
            continue

        image_stats["object_widths"].append(width)
        image_stats["object_heights"].append(height)
        image_stats["object_aspect_ratios"].append(max(width / height, height / width))
        image_stats["object_centers_x"].append((x1 + x2) / 2.0)
        image_stats["object_centers_y"].append((y1 + y2) / 2.0)
        image_stats["object_labels"].append(label)

    image_stats["num_valid_objects"] = len(image_stats["object_widths"])
    image_stats["num_total_objects"] = len(object_list)
    return object_list, image_stats

def check_slice_overlap(slice_bbox, object_list, margin_pixels):
    """Tell whether a slice intersects any valid object once a margin is added.

    Args:
        slice_bbox: ``(x1, y1, x2, y2)`` of the candidate slice.
        object_list: Iterable of dicts with ``'bbox'`` and ``'is_valid'`` keys;
            entries whose ``'is_valid'`` is falsy are ignored.
        margin_pixels: Amount by which each object box is grown on every side
            (the grown left/top edges are clamped at 0).

    Returns:
        True if the slice overlaps at least one grown valid box, else False.
    """
    sx1, sy1, sx2, sy2 = slice_bbox

    def _hits_slice(bbox):
        ox1, oy1, ox2, oy2 = bbox
        left = max(0, ox1 - margin_pixels)
        top = max(0, oy1 - margin_pixels)
        right = ox2 + margin_pixels
        bottom = oy2 + margin_pixels
        # Strict inequalities: boxes that merely touch do not overlap.
        return sx1 < right and sx2 > left and sy1 < bottom and sy2 > top

    return any(_hits_slice(obj['bbox']) for obj in object_list if obj['is_valid'])

def is_slice_predominantly_black(slice_pixels_np_gray, black_threshold, black_percentage_threshold):
    """Report whether too large a share of a slice's pixels is black.

    Args:
        slice_pixels_np_gray: NumPy array of pixel values, or None.
        black_threshold: Values less than or equal to this count as black.
        black_percentage_threshold: Percentage (0-100) of black pixels that
            must be strictly exceeded for the slice to be called black.

    Returns:
        True when the slice is None/empty, when the black percentage exceeds
        the threshold, or when the computation raises; False otherwise.
    """
    # Nothing to inspect: treat as black so the slice is discarded.
    if slice_pixels_np_gray is None or slice_pixels_np_gray.size == 0:
        return True

    try:
        total_pixels = slice_pixels_np_gray.size
        dark_mask = slice_pixels_np_gray <= black_threshold
        dark_count = np.sum(dark_mask)
        if total_pixels == 0:
            return True  # Guard against division by zero

        dark_percentage = (dark_count / total_pixels) * 100.0
        return dark_percentage > black_percentage_threshold

    except Exception as e:
         print(f"  Warning: Error calculating black percentage for slice: {e}")
         return True  # Treat as black on error

def save_analysis_plots(stats_data, output_dir, img_w, img_h):
    """Render the aggregated layout statistics as PNG plots.

    Plots are written to ``<output_dir>/statistics_plots``. Each plot is only
    produced when the data it needs is non-empty. Any exception raised while
    plotting is reported and the remaining plots are skipped.

    Args:
        stats_data: Mapping with the aggregated lists
            (``"num_valid_objects_per_image"``, ``"object_widths"``,
            ``"object_heights"``, ``"object_aspect_ratios"``,
            ``"object_centers_x"``, ``"object_centers_y"``).
        output_dir: Directory under which the plot folder is created.
        img_w: Width used for the x-range of the center heatmap.
        img_h: Height used for the y-range of the center heatmap.
    """
    stats_plot_dir = os.path.join(output_dir, "statistics_plots")
    os.makedirs(stats_plot_dir, exist_ok=True)
    print(f"\nGenerating analysis plots in: {stats_plot_dir}")

    def _series(key):
        return stats_data.get(key, [])

    def _finalize(filename):
        # Common tail of every plot: layout, write to disk, release the figure.
        plt.tight_layout()
        plt.savefig(os.path.join(stats_plot_dir, filename))
        plt.close()

    objects_per_image = _series("num_valid_objects_per_image")
    widths = _series("object_widths")
    heights = _series("object_heights")
    aspect_ratios = _series("object_aspect_ratios")
    centers_x = _series("object_centers_x")
    centers_y = _series("object_centers_y")

    plt.style.use('seaborn-v0_8-darkgrid')

    try:
        # Histogram: valid objects per image
        if objects_per_image:
            plt.figure(figsize=(10, 6))
            largest = max(objects_per_image)
            bin_count = max(1, (largest // 2 if largest > 1 else 1))
            sns.histplot(objects_per_image, kde=False, bins=bin_count)
            plt.title('Distribution of Valid Objects per Image')
            plt.xlabel('Number of Valid Objects')
            plt.ylabel('Frequency (Images)')
            _finalize("valid_objects_per_image_hist.png")

        # Side-by-side histograms: object widths and heights
        if widths and heights:
            plt.figure(figsize=(12, 5))
            panels = (
                (1, widths, 'Distribution of Valid Object Widths', 'Width (pixels)'),
                (2, heights, 'Distribution of Valid Object Heights', 'Height (pixels)'),
            )
            for position, values, heading, axis_label in panels:
                plt.subplot(1, 2, position)
                sns.histplot(values, kde=True, bins=50)
                plt.title(heading)
                plt.xlabel(axis_label)
            _finalize("object_dimensions_hist.png")

        # Histogram: aspect ratios
        if aspect_ratios:
            plt.figure(figsize=(10, 6))
            sns.histplot(aspect_ratios, kde=True, bins=50)
            plt.title('Distribution of Valid Object Aspect Ratios (max(W/H, H/W))')
            plt.xlabel('Aspect Ratio')
            plt.ylabel('Frequency')
            _finalize("object_aspect_ratios_hist.png")

        # 2D histogram: object center positions in image coordinates
        if centers_x and centers_y and img_w > 0 and img_h > 0:
            figure_height = 10 * img_h / img_w  # Match the image's proportions
            plt.figure(figsize=(10, figure_height))
            sns.histplot(x=centers_x, y=centers_y, bins=100, cbar=True)
            plt.title('Heatmap of Valid Object Center Positions')
            plt.xlabel('Center X (pixels)')
            plt.ylabel('Center Y (pixels)')
            plt.xlim(0, img_w)
            plt.ylim(img_h, 0)  # Origin at the top-left, like image coordinates
            plt.gca().set_aspect('equal', adjustable='box')
            _finalize("object_centers_heatmap.png")

        print("Analysis plots saved.")

    except Exception as e:
        print(f"Error generating plots: {e}. Skipping plot generation.")


# --- Main Processing Function ---

def process_large_images_for_backgrounds(input_path, output_path,
                                         slice_size,
                                         step_size,
                                         margin_pixels,
                                         black_threshold,  # Combined threshold
                                         black_slice_percentage_threshold,  # % threshold for slice
                                         generate_plots):
    """
    Extract object-free background patches from labelled images and collect
    layout statistics.

    For every image in ``input_path`` that has a same-named ``.json`` LabelMe
    file, the annotations are parsed, a square window is slid over the image,
    and each window that neither overlaps a valid object (plus margin) nor is
    predominantly black is saved as a greyscale PNG under
    ``<output_path>/backgrounds``. Aggregated statistics are written to
    ``<output_path>/statistics/layout_statistics.json`` when at least one
    valid object was found. Images without a JSON file are skipped silently.

    Args:
        input_path: Directory containing images and their LabelMe JSON files.
        output_path: Directory that receives the ``backgrounds`` and
            ``statistics`` folders (and the plot folder when enabled).
        slice_size: Edge length, in pixels, of the square sliding window.
        step_size: Stride, in pixels, of the sliding window in x and y.
        margin_pixels: Safety margin added around valid object boxes.
        black_threshold: Pixel values <= this are black; used for both the
            object-center check and the slice check.
        black_slice_percentage_threshold: A slice is discarded when its
            percentage of black pixels exceeds this value.
        generate_plots: When truthy, ``save_analysis_plots`` is called after
            the statistics file has been written.

    Returns:
        None. Progress, per-file summaries and errors are printed.
    """
    print("--- Starting Large Image Layout Analysis and Background Extraction (v4) ---")
    print(f"Input Path: {input_path}")
    print(f"Output Path: {output_path}")
    print(f"Slice Size: {slice_size}x{slice_size}")
    print(f"Step Size: {step_size} (Note: Small step size significantly increases runtime!)")
    print(f"Margin Pixels: {margin_pixels}")
    print(f"Black Threshold: <= {black_threshold}")
    print(f"Black Slice Discard %: > {black_slice_percentage_threshold}%")
    print(f"Generate Plots: {generate_plots}")
    print("---------------------\n")

    output_bg_dir = os.path.join(output_path, "backgrounds")
    output_stats_dir = os.path.join(output_path, "statistics")
    output_stats_file = os.path.join(output_stats_dir, "layout_statistics.json")

    # --- Create output directories ---
    try:
        os.makedirs(output_bg_dir, exist_ok=True)
        os.makedirs(output_stats_dir, exist_ok=True)
        print("Created/Ensured output directories exist.")
    except OSError as e:
        print(f"Error creating output directories: {e}")
        return

    # Statistics aggregated over all processed images
    global_stats = defaultdict(list)

    # --- Find image/JSON pairs ---
    try:
        all_files = os.listdir(input_path)
    except FileNotFoundError:
        print(f"Error: Input path not found: {input_path}")
        return
    except Exception as e:
        print(f"Error listing files in input path: {e}")
        return

    image_extensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.tif')
    image_filenames = sorted(f for f in all_files if f.lower().endswith(image_extensions))

    if not image_filenames:
        print("Error: No image files found in the input directory.")
        return

    total_bg_extracted_all_images = 0
    processed_image_count = 0
    max_img_w, max_img_h = 0, 0  # Largest dimensions seen; used for the heatmap axes

    print(f"\nFound {len(image_filenames)} potential images. Starting processing...")
    # --- Loop through image/JSON pairs ---
    for img_filename in tqdm(image_filenames, desc="Processing Images"):
        base_name = os.path.splitext(img_filename)[0]
        img_path = os.path.join(input_path, img_filename)
        json_path = os.path.join(input_path, base_name + ".json")

        if not os.path.exists(json_path):
            continue  # Silently skip images without JSON labels

        # --- Initialize Per-File Stats ---
        num_total_objects_this_image = 0
        num_valid_objects_this_image = 0
        extracted_for_this_image = 0
        slice_failures = 0          # Slices that raised or could not be written
        first_slice_error = None    # Kept so the warning can show a concrete cause

        try:
            # --- Open image and get basic info ---
            with Image.open(img_path) as img:
                img.load()
                img_w, img_h = img.size
                max_img_w, max_img_h = max(max_img_w, img_w), max(max_img_h, img_h)

                # --- Parse Labels & Analyze ---
                object_list, image_stats = parse_labelme_json_and_analyze(json_path, img, black_threshold)

                # Update per-file counts
                num_total_objects_this_image = image_stats.get("num_total_objects", 0)
                num_valid_objects_this_image = image_stats.get("num_valid_objects", 0)

                # Aggregate global stats
                processed_image_count += 1
                global_stats["image_filenames"].append(img_filename)
                global_stats["num_valid_objects_per_image"].append(num_valid_objects_this_image)
                for key in ["object_widths", "object_heights", "object_aspect_ratios",
                            "object_centers_x", "object_centers_y", "object_labels"]:
                    global_stats[key].extend(image_stats.get(key, []))

                # --- Iterate through slices ---
                # Iterate the window positions lazily; with a small step size
                # the full coordinate list can hold millions of entries.
                # (No tqdm here by default to avoid excessive output.)
                for sy in range(0, img_h - slice_size + 1, step_size):
                    for sx in range(0, img_w - slice_size + 1, step_size):
                        slice_box = (sx, sy, sx + slice_size, sy + slice_size)

                        if check_slice_overlap(slice_box, object_list, margin_pixels):
                            continue

                        try:
                            slice_np = np.array(img.crop(slice_box))

                            # Ensure greyscale
                            if slice_np.ndim == 3 and slice_np.shape[2] >= 3:
                                slice_gray = cv2.cvtColor(slice_np, cv2.COLOR_RGB2GRAY)
                            elif slice_np.ndim == 2:
                                slice_gray = slice_np
                            else:
                                continue  # Unsupported channel layout

                            # Check if the slice itself is predominantly black
                            if is_slice_predominantly_black(slice_gray,
                                                            black_threshold,
                                                            black_slice_percentage_threshold):
                                continue

                            # Save the clean background patch
                            output_filename = f"{base_name}_bg_{sx}_{sy}.png"
                            output_filepath = os.path.join(output_bg_dir, output_filename)
                            # cv2.imwrite signals failure through its return
                            # value, so only count patches that were written.
                            if not cv2.imwrite(output_filepath, slice_gray):
                                raise OSError(f"cv2.imwrite could not write {output_filepath}")
                            extracted_for_this_image += 1

                        except Exception as slice_err:
                            # Best effort: one bad slice must not abort the
                            # image, but the failure is recorded and reported.
                            slice_failures += 1
                            if first_slice_error is None:
                                first_slice_error = slice_err

            # --- Print Per-File Summary ---
            print(f"Finished: {img_filename} ({img_w}x{img_h}) | "
                  f"Total Objects: {num_total_objects_this_image} | "
                  f"Valid Objects: {num_valid_objects_this_image} | "
                  f"Backgrounds Extracted: {extracted_for_this_image}")
            if slice_failures:
                print(f"  Warning: {slice_failures} slice(s) of {img_filename} could not be "
                      f"processed or saved (first error: {first_slice_error})")
            total_bg_extracted_all_images += extracted_for_this_image

        except FileNotFoundError:
            print(f"  Error: Image file not found at {img_path}, skipping.")
        except UnidentifiedImageError:
            print(f"  Error: Cannot identify image file (corrupt or unsupported format): {img_filename}. Skipping.")
        except Image.DecompressionBombError:
            print(f"  Error: Image {img_filename} is too large or corrupt (DecompressionBombError). Skipping.")
            print("  Consider uncommenting 'Image.MAX_IMAGE_PIXELS = None' if needed, but be cautious.")
        except Exception as img_err:
            print(f"  Error processing image {img_filename} or its JSON: {img_err}")

    # --- Final Summary & Saving Stats ---
    print("\n--- Finished Processing All Images ---")
    print(f"Processed {processed_image_count} images with labels found.")
    print(f"Total background patches extracted across all images: {total_bg_extracted_all_images}")

    summary_stats = {}
    if global_stats.get("object_widths"):  # Check if any valid objects were found
        try:
            print("\nCalculating final summary statistics for valid objects...")
            stats_keys = ["num_valid_objects_per_image", "object_widths", "object_heights",
                          "object_aspect_ratios", "object_centers_x", "object_centers_y"]
            for key in stats_keys:
                data_list = global_stats.get(key, [])
                if data_list:
                    data_array = np.array(data_list)
                    summary_stats[key] = {
                        'mean': float(np.mean(data_array)),
                        'std': float(np.std(data_array)),
                        'min': float(np.min(data_array)),
                        'max': float(np.max(data_array)),
                        'median': float(np.median(data_array)),
                        'count': len(data_array)
                    }
                else:
                    summary_stats[key] = {'count': 0}

            # Keep at most the first 10000 raw values per numeric series so
            # the JSON file stays manageable.
            raw_data_truncated = {
                k: v[:10000]
                for k, v in global_stats.items()
                if k not in ["image_filenames", "object_labels"]
            }

            stats_to_save = {
                "parameters_used": {
                    "slice_size": slice_size, "step_size": step_size, "margin_pixels": margin_pixels,
                    "black_threshold": black_threshold,
                    "black_slice_percentage_threshold": black_slice_percentage_threshold
                },
                "summary": summary_stats,
                "raw_data_truncated": raw_data_truncated,
                "image_count_processed": processed_image_count,
                "total_valid_objects_found": summary_stats.get("object_widths", {}).get("count", 0),
                "total_background_patches_extracted": total_bg_extracted_all_images
            }

            # Save layout statistics to JSON file
            with open(output_stats_file, 'w', encoding='utf-8') as f:
                json.dump(stats_to_save, f, indent=4, default=float)
            print(f"Layout statistics saved to {output_stats_file}")

            # Generate Plots (optional)
            if generate_plots:
                save_analysis_plots(global_stats, output_path, max_img_w, max_img_h)

        except Exception as e:
            print(f"Error calculating or saving final statistics: {e}")
    else:
        print("No valid object data found across all images to calculate statistics.")

    print("--- Script Finished ---")


# ==============================================================================
# --- User Configuration Section ---
# ==============================================================================

# 1. Define input and output paths
#    (Use raw strings r'...' or double backslashes '\\' for Windows paths)
#    Input path should contain both the large image files and the corresponding
#    LabelMe .json files (same base name as the image).
input_path = r"C:\Users\praam\Desktop\havetai+vetcyto\task-04_dataset\50_pc_expanded_automated_labels_with_aligned_HM_images_T"
output_path = r"C:\Users\praam\Desktop\havetai+vetcyto\task-05_dataset\backgrounds_layouts_automated_labels_pc_150"

# 2. Define Slicing Parameters
slice_size = 640       # Edge length in pixels of the square background patches to extract

# Step size for moving the slicing window.
# The current value (320) is half of slice_size, i.e. neighbouring windows
# overlap by 50%. Smaller values (e.g. 10) search more positions but are
# very slow; larger values run faster.
step_size = 320

margin_pixels = 15     # Safety margin added around valid object bounding boxes.
                       # Slices overlapping this margin area won't be saved as background.

# 3. Define Black Area Detection Parameters
# Pixel value <= this is considered 'black'. User found '1' optimal.
black_threshold = 1

# Discard a slice if its percentage of black pixels EXCEEDS this threshold.
# 1.0 means a slice is discarded as soon as more than 1% of its pixels are black.
black_slice_percentage_threshold = 1.0

# 4. Optional Plotting
# Set to True to generate and save plots visualizing the layout statistics.
generate_plots = True

# ==============================================================================
# --- Script Execution ---
# ==============================================================================

# Echo the configuration before starting (the processing function prints a
# similar summary again when it begins).
print(f"--- Configuration ---")
print(f"Input Path: {input_path}")
print(f"Output Path: {output_path}")
print(f"Slice Size: {slice_size}")
print(f"Step Size: {step_size} (Note: Small step size significantly increases runtime!)")
print(f"Margin Pixels: {margin_pixels}")
print(f"Black Threshold: <= {black_threshold}")
print(f"Black Slice Discard %: > {black_slice_percentage_threshold}%")
print(f"Generate Plots: {generate_plots}")
print(f"---------------------\n")

# Run the main processing function with the configured parameters.
# All helper functions must already be defined (above this call, or in
# previously executed notebook cells).
process_large_images_for_backgrounds(
    input_path=input_path,
    output_path=output_path,
    slice_size=slice_size,
    step_size=step_size,
    margin_pixels=margin_pixels,
    black_threshold=black_threshold, # Used for both object check & slice check
    black_slice_percentage_threshold=black_slice_percentage_threshold,
    generate_plots=generate_plots
)

# After execution, check 'output_path' for the 'backgrounds' and 'statistics' folders.
