In [None]:
import cv2
import numpy as np
import os
import xml.etree.ElementTree as ET
from datetime import datetime
import logging

# --- Configuration ---
# Tiling parameters, in pixels of the (possibly downscaled) image.
OUTPUT_TILE_SIZE = 512  # side length of each square output tile
STRIDE = 256  # step between tile origins; smaller than the tile, so tiles overlap
# Images whose width or height exceeds this many pixels are downscaled before tiling.
DOWNSCALE_THRESHOLD = 10000
DOWNSCALE_FACTOR = 2.0  # divisor applied to both dimensions (2.0 halves the size)

# Label name -> pixel value written into the mask.
# The 'upright' value also fills every pixel not covered by a polygon.
# 'incomplete' is an error label and deliberately has no pixel value.
CLASS_MAPPING = dict(
    upright=0,
    fallen=1,
    other=2,
    unlabeled=3,
)

# Draw order for overlapping polygons that share a z_order: a larger value is
# drawn later and therefore ends up on top.
CLASS_PRIORITY = dict(
    fallen=3,
    other=2,
    upright=1,  # polygons explicitly labeled 'upright'
    unlabeled=0,
)

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Helper Functions ---

def parse_points_string(points_str, scale_factor=1.0):
    """Parse a CVAT points string into an integer pixel-coordinate array.

    Args:
        points_str: Points in CVAT's ``"x1,y1;x2,y2;..."`` format. Empty
            segments (e.g. from a trailing ``;``) and surrounding whitespace
            are ignored; any values after the second in a pair are ignored.
        scale_factor: Multiplier applied to every coordinate before the
            integer conversion, e.g. ``0.5`` when the image was halved.

    Returns:
        An ``(N, 2)`` ``np.int32`` array of ``[x, y]`` rows, each scaled
        coordinate rounded to the nearest pixel (OpenCV's ``fillPoly`` needs
        integer vertices). An empty string yields a ``(0, 2)`` array.
        Returns ``None`` after logging an error if any pair is malformed.
    """
    coords = []
    try:
        for pair in points_str.split(';'):
            pair = pair.strip()
            if not pair:
                continue
            # Unpacking raises ValueError when a pair has fewer than 2 values.
            x_str, y_str = pair.split(',')[:2]
            coords.append((float(x_str) * scale_factor, float(y_str) * scale_factor))
    except (AttributeError, ValueError) as e:
        logging.error(f"Error parsing points string '{str(points_str)[:50]}...': {e}")
        return None
    # Round rather than truncate: a plain int cast drops the fractional part and
    # shifts every vertex toward the origin by up to one pixel.
    return np.rint(np.array(coords, dtype=np.float64).reshape(-1, 2)).astype(np.int32)

def read_timestamp(file_path):
    """Return the float timestamp stored in *file_path*.

    Returns ``0.0`` when the file does not exist or its stripped contents
    cannot be parsed as a float (``UnicodeDecodeError`` is a ``ValueError``
    subclass, so undecodable files fall back to ``0.0`` as well).
    """
    if not os.path.exists(file_path):
        return 0.0
    timestamp = 0.0
    with open(file_path, 'r') as handle:
        try:
            timestamp = float(handle.read().strip())
        except ValueError:
            pass  # keep the 0.0 fallback
    return timestamp

def write_timestamp(file_path, timestamp):
    """Persist *timestamp* (as ``str(timestamp)``) to *file_path*.

    Parent directories are created as needed. A bare filename with no
    directory component is written to the current working directory.

    The value is first written to ``<file_path>.tmp``, which then replaces
    *file_path* via ``os.replace``. An interrupted write therefore does not
    leave a truncated timestamp file in place.
    """
    parent_dir = os.path.dirname(file_path)
    if parent_dir:  # os.makedirs('') raises FileNotFoundError
        os.makedirs(parent_dir, exist_ok=True)
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, 'w') as f:
        f.write(str(timestamp))
    os.replace(tmp_path, file_path)

def get_file_mod_time(file_path):
    """Return the modification time of *file_path* in seconds since the epoch.

    Returns ``0.0`` when the path cannot be stat'ed, for example when it does
    not exist. These are the same cases in which ``os.path.exists`` reports
    ``False``.
    """
    try:
        return os.path.getmtime(file_path)
    except (OSError, ValueError):
        return 0.0

def generate_full_mask(image_dims, polygons_data, class_mapping, class_priority):
    """
    Rasterize polygon annotations into a single-channel class mask.

    Args:
        image_dims: ``(height, width)`` of the mask to create.
        polygons_data: list of dicts ``{'points': array of [x, y] rows,
            'label': str, 'z_order': int}``.
        class_mapping: label -> pixel value. Must contain ``'upright'``; that
            value fills every pixel not covered by a polygon.
        class_priority: label -> draw priority among polygons that share a
            ``z_order``. Labels missing here get priority -1, so they are drawn first.

    Returns:
        A ``(height, width)`` ``np.uint8`` mask.

    Draw order:
        Polygons are drawn in ascending ``z_order``. Within one ``z_order``
        they are drawn in ascending priority, so higher-priority classes are
        painted last and win where shapes overlap. Ties keep their input order.

    Skipped polygons:
        Polygons labeled ``'incomplete'`` are skipped silently. Other labels
        absent from *class_mapping* are skipped with a warning. Polygons with
        ``None`` or empty points are skipped.
    """
    height, width = image_dims
    full_mask = np.full((height, width), class_mapping['upright'], dtype=np.uint8)

    if not polygons_data:
        return full_mask

    # A single stable sort on (z_order, priority) yields exactly the order of
    # "group by z_order ascending, then sort each group by priority ascending",
    # without rescanning the whole list once per distinct z_order.
    draw_order = sorted(
        polygons_data,
        key=lambda p: (p['z_order'], class_priority.get(p['label'], -1)),
    )

    for poly_info in draw_order:
        label = poly_info['label']
        points = poly_info['points']

        if label not in class_mapping:
            if label != 'incomplete':
                logging.warning(f"Unknown label '{label}' encountered in polygon data. Skipping this polygon.")
            continue
        if points is None or len(points) == 0:
            continue

        # fillPoly expects int32 vertices shaped (N, 1, 2).
        cv_points = np.asarray(points, dtype=np.int32).reshape((-1, 1, 2))
        cv2.fillPoly(full_mask, [cv_points], class_mapping[label])

    return full_mask

def _tile_origins(length, tile_size, stride):
    """Return the start offsets of full-size tiles along one axis of *length* pixels.

    Offsets advance by *stride* from 0. If that grid stops short of the far
    edge, one extra offset flush with the edge (``length - tile_size``) is
    appended, so the trailing strip is still covered. Returns ``[]`` when
    *length* is smaller than *tile_size*. *stride* is assumed to be positive.
    """
    if length < tile_size:
        return []
    last_origin = length - tile_size
    origins = list(range(0, last_origin + 1, stride))
    if origins[-1] != last_origin:
        origins.append(last_origin)
    return origins


def _remove_stale_tiles(output_dirs, stem):
    """Delete ``{stem}_r<row>_c<col>.png`` files in *output_dirs*; return how many were removed.

    Without this, a tile that is no longer produced would survive from an
    earlier run with outdated content. For example, a region can become pure
    background after an annotation edit.
    """
    import re

    tile_pattern = re.compile(re.escape(stem) + r"_r\d+_c\d+\.png")
    removed = 0
    for out_dir in output_dirs:
        for fname in os.listdir(out_dir):
            if tile_pattern.fullmatch(fname):
                os.remove(os.path.join(out_dir, fname))
                removed += 1
    return removed


def process_large_image(image_info, task_name, cvat_xml_path, input_image_root_dir,
                        output_dataset_dir, last_proc_timestamp,
                        interpolation_method=cv2.INTER_AREA):
    """
    Process one large image: load, optionally downscale, rasterize its mask, and save tiles.

    Args:
        image_info: dict built from the CVAT XML with keys 'id', 'name',
            'width', 'height', 'task_id' and 'polygons'. Each polygon is
            {'label': str, 'points_str': str, 'z_order': int}.
        task_name: sub-folder of *input_image_root_dir* that holds the image.
            It is also the sub-folder of the output 'images'/'masks' folders
            that receive its tiles.
        cvat_xml_path: annotation file; its mtime takes part in the freshness check.
        input_image_root_dir: root folder of the source images.
        output_dataset_dir: root of the output dataset.
        last_proc_timestamp: time of the previous run. The image is skipped
            when both it and the XML were last modified no later than this.
        interpolation_method: OpenCV interpolation flag used when downscaling.

    Returns:
        False if nothing was attempted: the image file is missing, or the
        image and XML are unchanged since *last_proc_timestamp*.
        True otherwise, including when the image was examined and rejected
        (an 'incomplete' label, only 'unlabeled' polygons, or an unreadable
        file).

    Downscaling:
        Images wider or taller than DOWNSCALE_THRESHOLD are downscaled once
        by DOWNSCALE_FACTOR. This is a single pass, so the result may still
        exceed the threshold.

    Tiling:
        Tiles are laid out every STRIDE pixels, plus one edge-aligned tile
        per axis so the right/bottom borders are covered. Tiles whose mask is
        entirely the 'upright' background value are not written. Tiles left
        by a previous run of this image are deleted before the new ones are
        written.
    """
    image_name = image_info['name']
    image_path = os.path.join(input_image_root_dir, task_name, image_name)
    xml_mod_time = get_file_mod_time(cvat_xml_path)

    if not os.path.exists(image_path):
        logging.warning(f"Image file not found: {image_path}. Skipping.")
        return False  # nothing processed, so this image must not advance the timestamp

    image_mod_time = get_file_mod_time(image_path)
    if image_mod_time <= last_proc_timestamp and xml_mod_time <= last_proc_timestamp:
        logging.info(f"Image {image_name} and XML are older than last processing time. Skipping.")
        return False

    logging.info(f"Processing large image: {image_path}")

    # --- 1. Label sanity checks for the entire image ---
    polygons = image_info['polygons']
    raw_labels = [p['label'] for p in polygons]
    if not polygons:
        # Not an error: every tile will be pure background and thus skipped.
        logging.info(f"Image {image_name} has no polygon annotations. It will likely result in empty tiles only.")
    elif 'incomplete' in raw_labels:
        logging.error(f"Image {image_name} contains 'incomplete' labels. Skipping this image.")
        return True  # checked, but rejected
    elif all(label == 'unlabeled' for label in raw_labels):
        logging.error(f"Image {image_name} contains only 'unlabeled' polygons. Skipping this image.")
        return True  # checked, but rejected

    # --- 2. Load and potentially downscale image ---
    img = cv2.imread(image_path)
    if img is None:
        logging.error(f"Failed to load image: {image_path}. Skipping.")
        return True

    original_height, original_width = img.shape[:2]
    declared_height, declared_width = image_info.get('height'), image_info.get('width')
    if (declared_height is not None and declared_width is not None
            and (declared_height, declared_width) != (original_height, original_width)):
        # The annotations were made against the declared size; a mismatch
        # means the polygons will not line up with the pixels.
        logging.warning(f"Image {image_name}: XML declares {declared_width}x{declared_height} but file is "
                        f"{original_width}x{original_height}. Polygons may be misaligned.")

    scale_factor = 1.0
    if original_height > DOWNSCALE_THRESHOLD or original_width > DOWNSCALE_THRESHOLD:
        logging.info(f"Image {image_name} ({original_width}x{original_height}) exceeds threshold, downscaling by {DOWNSCALE_FACTOR}x.")
        scale_factor = 1.0 / DOWNSCALE_FACTOR
        new_width = int(original_width * scale_factor)
        new_height = int(original_height * scale_factor)
        img = cv2.resize(img, (new_width, new_height), interpolation=interpolation_method)
        logging.info(f"Downscaled to {new_width}x{new_height}.")

    current_height, current_width = img.shape[:2]

    # --- 3. Parse polygon points and scale them to the working resolution ---
    scaled_polygons_data = []
    for poly in polygons:
        parsed_pts = parse_points_string(poly['points_str'], scale_factor)
        if parsed_pts is None:
            continue  # parse error already logged
        if len(parsed_pts) >= 3:  # a fillable polygon needs at least 3 vertices
            scaled_polygons_data.append({
                'points': parsed_pts,
                'label': poly['label'],
                'z_order': poly['z_order'],
            })
        else:
            logging.warning(f"Polygon with label '{poly['label']}' in {image_name} has < 3 points after scaling. Skipping polygon.")

    # --- 4. Generate full mask for the (potentially scaled) image ---
    logging.info(f"Generating full mask for {image_name}...")
    full_mask = generate_full_mask((current_height, current_width), scaled_polygons_data, CLASS_MAPPING, CLASS_PRIORITY)

    # --- 5. Tile and save ---
    output_image_dir = os.path.join(output_dataset_dir, 'images', task_name)
    output_mask_dir = os.path.join(output_dataset_dir, 'masks', task_name)
    os.makedirs(output_image_dir, exist_ok=True)
    os.makedirs(output_mask_dir, exist_ok=True)

    stem = os.path.splitext(image_name)[0]
    removed = _remove_stale_tiles((output_image_dir, output_mask_dir), stem)
    if removed:
        logging.info(f"Removed {removed} previously generated tile files for {image_name}.")

    row_origins = _tile_origins(current_height, OUTPUT_TILE_SIZE, STRIDE)
    col_origins = _tile_origins(current_width, OUTPUT_TILE_SIZE, STRIDE)
    if not row_origins or not col_origins:
        logging.warning(f"Image {image_name} ({current_width}x{current_height}) is smaller than the "
                        f"{OUTPUT_TILE_SIZE}px tile size. No tiles produced.")

    background_value = CLASS_MAPPING['upright']
    tiles_saved_count = 0
    for r in row_origins:
        for c in col_origins:
            mask_tile = full_mask[r:r + OUTPUT_TILE_SIZE, c:c + OUTPUT_TILE_SIZE]
            # Skip tiles that contain only background.
            if np.all(mask_tile == background_value):
                continue
            img_tile = img[r:r + OUTPUT_TILE_SIZE, c:c + OUTPUT_TILE_SIZE]

            base_filename = f"{stem}_r{r}_c{c}"
            img_tile_path = os.path.join(output_image_dir, f"{base_filename}.png")
            mask_tile_path = os.path.join(output_mask_dir, f"{base_filename}.png")
            # cv2.imwrite signals failure through its return value, not an exception.
            if not (cv2.imwrite(img_tile_path, img_tile) and cv2.imwrite(mask_tile_path, mask_tile)):
                logging.error(f"Failed to write tile {base_filename} for {image_name}.")
                continue
            tiles_saved_count += 1

    logging.info(f"Saved {tiles_saved_count} tiles for {image_name}.")
    return True  # processed successfully

def _collect_task_names(root):
    """Map CVAT task id -> task name from the XML meta section.

    Reads both project exports (``meta/project/tasks/task``) and single-task
    exports (``meta/task``). Entries missing an id or name are logged and skipped.
    """
    task_id_to_name = {}
    for task_elem in root.findall("./meta/project/tasks/task") + root.findall("./meta/task"):
        task_id = task_elem.findtext("id")
        task_name = task_elem.findtext("name")
        if task_id is None or not task_name:
            logging.warning("Skipping task entry with missing id or name in XML meta section.")
            continue
        task_id_to_name[task_id] = task_name
    return task_id_to_name


def _infer_task_name(img_name, task_names):
    """Infer a task name from an image name such as "22_Crozier_461000_5385000.tif".

    The second '_'-separated part is tried first (the ID_Location_...
    convention), then the first part (Location_...). Matching is
    case-insensitive, because task folders may be lower-case (e.g. 'crozier')
    while image names are capitalized.

    Returns the task name as spelled in the XML, or None.
    """
    parts = img_name.split('_')
    if len(parts) < 2:
        return None
    by_lower = {name.lower(): name for name in task_names}

    match = by_lower.get(parts[1].lower())
    if match:
        logging.warning(f"Inferred task name '{match}' from image name for {img_name} as task_id was not directly mapped.")
        return match
    match = by_lower.get(parts[0].lower())
    if match:
        logging.warning(f"Inferred task name '{match}' (first part) from image name for {img_name} as task_id was not directly mapped.")
        return match
    return None


def _parse_images(root, task_id_to_name):
    """Build one dict per ``<image>`` element.

    Each dict has the keys id, name, width, height, task_id, task_name and
    polygons.

    Images are logged and skipped when:
        - they have no name,
        - their width/height attributes are missing or not integers, or
        - their task cannot be determined.
    """
    task_names = list(task_id_to_name.values())
    images_data = []
    for image_elem in root.findall("./image"):
        img_id = image_elem.get("id")
        img_name = image_elem.get("name")
        if not img_name:
            logging.error(f"Image element with id {img_id} has no name. Skipping this image.")
            continue
        try:
            img_width = int(image_elem.get("width"))
            img_height = int(image_elem.get("height"))
        except (TypeError, ValueError):
            logging.error(f"Image {img_name} has missing or invalid width/height attributes. Skipping this image.")
            continue
        task_id = image_elem.get("task_id")  # present on <image> in CVAT-for-images project exports

        if task_id is None:
            job_id = image_elem.get("job_id")
            if job_id:
                logging.warning(f"Image {img_name} missing direct task_id, has job_id {job_id}. Task mapping might be indirect.")
                # TODO: map job_id -> task_id here if the export provides that relation.

        current_task_name = task_id_to_name.get(task_id)
        if not current_task_name and task_id is None and len(task_id_to_name) == 1:
            # Only one task in the export: every image must belong to it.
            current_task_name = task_names[0]
        if not current_task_name:
            # Fallback only; relying on task_id is more reliable.
            current_task_name = _infer_task_name(img_name, task_names)
        if not current_task_name:
            logging.error(f"Could not determine task name for image {img_name} (task_id: {task_id}). Skipping this image.")
            continue

        polygons = []
        for poly_elem in image_elem.findall("./polygon"):
            label = poly_elem.get("label")
            points_str = poly_elem.get("points")
            try:
                z_order = int(poly_elem.get("z_order", "0"))
            except ValueError:
                logging.warning(f"Invalid z_order on polygon in image {img_name}; using 0.")
                z_order = 0

            if label and points_str:
                polygons.append({'label': label, 'points_str': points_str, 'z_order': z_order})
            else:
                logging.warning(f"Skipping polygon with missing label or points in image {img_name}")

        images_data.append({
            'id': img_id, 'name': img_name, 'width': img_width, 'height': img_height,
            'task_id': task_id, 'task_name': current_task_name, 'polygons': polygons,
        })
    return images_data


def main(cvat_xml_path, input_image_root_dir, output_dir_base, interpolation_str="INTER_AREA"):
    """Convert a CVAT XML export plus its source images into tiled image/mask pairs.

    Args:
        cvat_xml_path: CVAT "for images" XML export (project or single task).
        input_image_root_dir: folder containing one sub-folder per task name.
        output_dir_base: the 'dataset' output folder is created inside it.
        interpolation_str: name of the OpenCV interpolation flag used for
            downscaling. Unknown names fall back to INTER_AREA with a warning.

    Incremental runs:
        The run's start time is stored in ``dataset/lastmodified.txt``. On
        the next run, images are skipped when both they and the XML are older
        than that time.
    """
    interpolation_methods = {
        "INTER_NEAREST": cv2.INTER_NEAREST,
        "INTER_LINEAR": cv2.INTER_LINEAR,
        "INTER_AREA": cv2.INTER_AREA,
        "INTER_CUBIC": cv2.INTER_CUBIC,
        "INTER_LANCZOS4": cv2.INTER_LANCZOS4,
    }
    if interpolation_str not in interpolation_methods:
        logging.warning(f"Unknown interpolation method '{interpolation_str}', falling back to INTER_AREA.")
        interpolation_str = "INTER_AREA"
    interpolation_method = interpolation_methods[interpolation_str]
    logging.info(f"Using interpolation method: {interpolation_str}")

    output_dataset_dir = os.path.join(output_dir_base, "dataset")
    os.makedirs(output_dataset_dir, exist_ok=True)

    timestamp_file = os.path.join(output_dataset_dir, "lastmodified.txt")
    last_processed_timestamp = read_timestamp(timestamp_file)
    logging.info(f"Last processed timestamp: {datetime.fromtimestamp(last_processed_timestamp) if last_processed_timestamp > 0 else 'N/A (full run)'}")

    # Record the run's *start* time. A file modified while this run is in
    # progress then still compares newer than the stored timestamp, and is
    # picked up next run instead of being hidden by an end-of-run time.
    run_start_timestamp = datetime.now().timestamp()

    # --- Parse CVAT XML ---
    logging.info(f"Parsing CVAT XML: {cvat_xml_path}")
    try:
        tree = ET.parse(cvat_xml_path)
        root = tree.getroot()
    except ET.ParseError as e:
        logging.error(f"Error parsing XML file: {e}")
        return
    except FileNotFoundError:
        logging.error(f"CVAT XML file not found: {cvat_xml_path}")
        return

    task_id_to_name = _collect_task_names(root)
    for task_name in task_id_to_name.values():
        # Create per-task output folders up front.
        os.makedirs(os.path.join(output_dataset_dir, 'images', task_name), exist_ok=True)
        os.makedirs(os.path.join(output_dataset_dir, 'masks', task_name), exist_ok=True)

    images_data = _parse_images(root, task_id_to_name)
    logging.info(f"Parsed {len(images_data)} image entries from XML.")

    # Deliberately a plain loop: every image must be processed, so any()
    # short-circuiting would be wrong here.
    something_processed = False
    for image_info_dict in images_data:
        if process_large_image(image_info_dict, image_info_dict['task_name'], cvat_xml_path,
                               input_image_root_dir, output_dataset_dir,
                               last_processed_timestamp, interpolation_method):
            something_processed = True

    # Advance the timestamp if any image was processed or checked, or if the
    # XML itself changed since the last run.
    xml_mod_time = get_file_mod_time(cvat_xml_path)
    if something_processed or xml_mod_time > last_processed_timestamp:
        write_timestamp(timestamp_file, run_start_timestamp)
        logging.info(f"Processing complete. Updated timestamp to: {datetime.fromtimestamp(run_start_timestamp)}")
    else:
        logging.info("No new images or XML modifications found. Processing complete.")


if __name__ == '__main__':
    # --- User Configuration ---
    # ! Adjust these paths according to your system !
    # os.path.join builds the separator for the current OS, so the defaults
    # work on both Windows and POSIX.
    CVAT_XML_FILE = os.path.join("..", "dataset", "project_combined.xml")  # E.g., "C:/cvat_exports/project_trees_annotations.xml"
    INPUT_IMAGE_ROOT = os.path.join("..", "dataset")  # Base folder containing 'centreglassville', 'crozier', etc.
    OUTPUT_DIRECTORY_BASE = os.path.join("..", "dataset_processed")  # The 'dataset' folder will be created inside this directory

    # Optional: Change interpolation method string if needed.
    # Options: "INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"
    SCALING_INTERPOLATION = "INTER_AREA"

    # --- Run ---
    # Fail fast with a clear message when the configured inputs do not exist,
    # instead of comparing against placeholder strings.
    if not os.path.isfile(CVAT_XML_FILE) or not os.path.isdir(INPUT_IMAGE_ROOT):
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("!!! PLEASE UPDATE 'CVAT_XML_FILE' AND 'INPUT_IMAGE_ROOT' IN THE SCRIPT !!!")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    else:
        main(CVAT_XML_FILE, INPUT_IMAGE_ROOT, OUTPUT_DIRECTORY_BASE, SCALING_INTERPOLATION)

2025-05-29 15:21:40,042 - INFO - Using interpolation method: INTER_AREA
2025-05-29 15:21:40,044 - INFO - Last processed timestamp: N/A (full run)
2025-05-29 15:21:40,045 - INFO - Parsing CVAT XML: ..\dataset\project_combined.xml
2025-05-29 15:21:40,131 - INFO - Parsed 36 image entries from XML.
2025-05-29 15:21:40,132 - INFO - Processing large image: ..\dataset\crozier\22_Crozier_461000_5385000.tif
2025-05-29 15:21:41,880 - INFO - Generating full mask for 22_Crozier_461000_5385000.tif...
2025-05-29 15:22:02,829 - INFO - Saved 1311 tiles for 22_Crozier_461000_5385000.tif.
2025-05-29 15:22:02,874 - INFO - Processing large image: ..\dataset\dugwal\22_Dugwal_499000_5383000.tif
2025-05-29 15:22:04,666 - INFO - Generating full mask for 22_Dugwal_499000_5383000.tif...
2025-05-29 15:22:07,098 - INFO - Saved 137 tiles for 22_Dugwal_499000_5383000.tif.
2025-05-29 15:22:07,131 - INFO - Processing large image: ..\dataset\dugwal\22_Dugwal_499000_5384000.tif
2025-05-29 15:22:08,656 - INFO - Generati