In [None]:
import os
import csv
import shutil
import re
import argparse
from pathlib import Path
from datetime import datetime, timezone
import xml.etree.ElementTree as ET
from xml.dom import minidom

# --- Configuration ---
# Canonical label definitions written into the project's <labels> section.
# Each entry is keyed by its own name; insertion order is the order in which
# the labels are emitted.
STANDARDIZED_LABEL_INFO = {
    label_name: {"name": label_name, "color": hex_color, "attributes": attribute_defs}
    for label_name, hex_color, attribute_defs in (
        ("upright", "#00FF00", []),
        (
            "fallen",
            "#FF0000",
            [
                dict(
                    name="is_crowd",
                    mutable="False",
                    input_type="checkbox",
                    default_value="false",
                    values="",
                ),
            ],
        ),
        ("other", "#0000FF", []),
        ("unlabeled", "#808080", []),
        ("incomplete", "#FFFF00", []),
    )
}

# Lower-cased input label spellings -> canonical label name.
# Both the British ("unlabelled") and American ("unlabeled") spellings are accepted.
LABEL_NAME_MAP = dict(
    upright="upright",
    fallen="fallen",
    other="other",
    unlabelled="unlabeled",
    unlabeled="unlabeled",
    incomplete="incomplete",
)

# Identity of the consolidated project and the owner recorded in its metadata.
PROJECT_ID, PROJECT_NAME = "1", "Consolidated_Tree_Project"
OWNER_USERNAME, OWNER_EMAIL = "admin", "admin@example.com"

# --- Helper Functions ---

def get_cvat_timestamp():
    """Return the current UTC time as 'YYYY-MM-DD HH:MM:SS.ffffff+00:00'."""
    now_utc = datetime.now(timezone.utc)
    # The UTC offset is appended literally rather than via %z, which would
    # render it without the colon.
    return f"{now_utc:%Y-%m-%d %H:%M:%S.%f}+00:00"

def get_location_name_from_xml_filename(xml_filename):
    """Derive the location name from an XML file name.

    The directory part and the extension are dropped, then one trailing
    '_<digits>' suffix (if present) is removed, e.g. 'dugwal_1.xml' -> 'dugwal'.
    """
    stem_without_extension = Path(xml_filename).stem
    return re.sub(r"_\d+$", "", stem_without_extension)

def read_server_paths_csv(csv_filepath):
    """Read the server paths CSV into a ``{location: path}`` dictionary.

    Each usable row supplies the location name in its first column and the
    server folder path in its second column; both are whitespace-stripped.
    Blank rows are ignored. Rows with fewer than two columns, or with an
    empty location or path, are skipped with a warning instead of aborting
    the whole read. If a location appears more than once, the last row wins.

    Args:
        csv_filepath: Path of the CSV file to read.

    Returns:
        The mapping of location name to server path (possibly empty), or
        ``None`` if the file is missing or cannot be read/parsed.
    """
    server_paths = {}
    try:
        # "utf-8-sig" transparently drops a leading BOM (common in CSVs saved
        # from Excel) that would otherwise be glued onto the first location
        # name; newline="" is what the csv module expects for its file object.
        with open(csv_filepath, mode="r", encoding="utf-8-sig", newline="") as f:
            for row_number, row in enumerate(csv.reader(f), start=1):
                if not any(cell.strip() for cell in row):
                    continue  # Blank line or row of empty cells.
                if len(row) < 2:
                    print(
                        f"Warning: Row {row_number} in {csv_filepath} has fewer than "
                        f"two columns. Skipping row: {row}"
                    )
                    continue
                location = row[0].strip()
                path = row[1].strip()
                if not location or not path:
                    print(
                        f"Warning: Row {row_number} in {csv_filepath} has an empty "
                        f"location or path. Skipping row: {row}"
                    )
                    continue
                server_paths[location] = path
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_filepath}")
        return None
    except (OSError, UnicodeDecodeError, csv.Error) as e:
        print(f"Error reading CSV file {csv_filepath}: {e}")
        return None

    if not server_paths:
        print(f"Warning: CSV file {csv_filepath} is empty or contains no valid data.")
    return server_paths


def prettify_xml(element):
    """Return a pretty-printed XML string for the Element.

    The element is serialized, re-parsed with minidom and re-indented with
    two spaces. Elements that were parsed from already-indented files carry
    whitespace-only text between their children; minidom would print each of
    those as an extra blank/whitespace line, so such nodes are removed from
    the re-parsed document first. The passed-in element itself is not
    modified.

    Args:
        element: The ``xml.etree.ElementTree.Element`` to serialize.

    Returns:
        The indented document as a ``str``, starting with an XML declaration
        that names UTF-8 as the encoding.
    """
    rough_string = ET.tostring(element, "utf-8")
    document = minidom.parseString(rough_string)

    # Iterative walk (no recursion limit concerns) over all element nodes.
    pending = [document.documentElement]
    while pending:
        node = pending.pop()
        children = list(node.childNodes)
        has_element_child = any(c.nodeType == c.ELEMENT_NODE for c in children)
        for child in children:
            if child.nodeType == child.ELEMENT_NODE:
                pending.append(child)
            elif (
                has_element_child
                and child.nodeType == child.TEXT_NODE
                and not child.data.strip()
            ):
                # Pure formatting whitespace between sibling elements; text
                # that is an element's only content is left untouched.
                node.removeChild(child)
                child.unlink()

    return document.toprettyxml(indent="  ", encoding="utf-8").decode("utf-8")


# --- Main Processing Logic ---

def _standardize_image_annotations(image_elem, image_filename):
    """Standardize the polygon/box labels of one <image> element.

    Labels are rewritten in place on the annotation elements. Children that
    are neither <polygon> nor <box> are passed through unchanged; polygon/box
    annotations without a ``label`` attribute are dropped.

    Returns:
        ``(kept_children, has_polygons, has_incomplete_label)`` where
        ``has_polygons`` is True only if at least one *kept* polygon exists.
    """
    kept_children = []
    has_polygons = False
    has_incomplete_label = False

    for annot_elem in list(image_elem):
        if annot_elem.tag not in ("polygon", "box"):
            kept_children.append(annot_elem)  # Keep other elements if any
            continue

        original_label = annot_elem.get("label")
        if original_label is None:  # Should not happen in valid CVAT
            print(f"Warning: Annotation in {image_filename} missing label. Skipping annotation.")
            continue

        standardized_label_name = LABEL_NAME_MAP.get(original_label.lower())
        if not standardized_label_name:
            print(f"Warning: Unknown label '{original_label}' in {image_filename}. Treating as 'other'.")
            standardized_label_name = "other"

        annot_elem.set("label", standardized_label_name)

        # Only polygons that actually survive count; a polygon dropped for a
        # missing label must not make an otherwise polygon-less image "valid".
        if annot_elem.tag == "polygon":
            has_polygons = True
        if standardized_label_name == "incomplete":
            has_incomplete_label = True

        kept_children.append(annot_elem)

    return kept_children, has_polygons, has_incomplete_label


def _build_task_meta(original_task_element, location_name, task_id, segment_id, image_count):
    """Build the metadata dict for one project task from the source <task> element."""
    return {
        "id": str(task_id),  # New ID for the project
        "name": original_task_element.findtext("name", default=location_name),
        "size": str(image_count),
        "mode": original_task_element.findtext("mode", default="annotation"),
        "overlap": original_task_element.findtext("overlap", default="0"),
        "bugtracker": original_task_element.findtext("bugtracker", default=""),
        "created": original_task_element.findtext("created", default=get_cvat_timestamp()),
        "updated": original_task_element.findtext("updated", default=get_cvat_timestamp()),
        "subset": "default",  # Project uses 'default' subset
        "start_frame": "0",
        "stop_frame": str(image_count - 1),
        "frame_filter": original_task_element.findtext("frame_filter", default=""),
        "segment_id": str(segment_id),
        "segment_url": f"http://localhost:8080/api/jobs/{segment_id}",
    }


def _build_project_root(project_tasks_meta):
    """Create the <annotations> root with <version> and a fully populated <meta>."""
    project_root_out = ET.Element("annotations")
    ET.SubElement(project_root_out, "version").text = "1.1"

    meta_out = ET.SubElement(project_root_out, "meta")
    project_out = ET.SubElement(meta_out, "project")
    ET.SubElement(project_out, "id").text = PROJECT_ID
    ET.SubElement(project_out, "name").text = PROJECT_NAME
    ET.SubElement(project_out, "bugtracker")
    ET.SubElement(project_out, "created").text = get_cvat_timestamp()
    ET.SubElement(project_out, "updated").text = get_cvat_timestamp()

    # <project><tasks>
    tasks_out = ET.SubElement(project_out, "tasks")
    simple_task_fields = (
        "id", "name", "size", "mode", "overlap", "bugtracker", "created",
        "updated", "subset", "start_frame", "stop_frame", "frame_filter",
    )
    for task_meta in project_tasks_meta:
        task_out = ET.SubElement(tasks_out, "task")
        for field in simple_task_fields:
            ET.SubElement(task_out, field).text = task_meta[field]

        # A single segment covers all frames of the task.
        segments_out = ET.SubElement(task_out, "segments")
        segment_out = ET.SubElement(segments_out, "segment")
        ET.SubElement(segment_out, "id").text = task_meta["segment_id"]
        ET.SubElement(segment_out, "start").text = task_meta["start_frame"]
        ET.SubElement(segment_out, "stop").text = task_meta["stop_frame"]
        ET.SubElement(segment_out, "url").text = task_meta["segment_url"]

        owner_task_out = ET.SubElement(task_out, "owner")
        ET.SubElement(owner_task_out, "username").text = OWNER_USERNAME
        ET.SubElement(owner_task_out, "email").text = OWNER_EMAIL
        ET.SubElement(task_out, "assignee")

    ET.SubElement(project_out, "subsets").text = "default"
    owner_proj_out = ET.SubElement(project_out, "owner")
    ET.SubElement(owner_proj_out, "username").text = OWNER_USERNAME
    ET.SubElement(owner_proj_out, "email").text = OWNER_EMAIL
    ET.SubElement(project_out, "assignee")

    # <project><labels>
    labels_out = ET.SubElement(project_out, "labels")
    for info in STANDARDIZED_LABEL_INFO.values():
        label_out = ET.SubElement(labels_out, "label")
        ET.SubElement(label_out, "name").text = info["name"]
        ET.SubElement(label_out, "color").text = info["color"]
        ET.SubElement(label_out, "type").text = "any"
        attributes_out = ET.SubElement(label_out, "attributes")
        for attr_def in info["attributes"]:
            attribute_out = ET.SubElement(attributes_out, "attribute")
            for field in ("name", "mutable", "input_type", "default_value", "values"):
                ET.SubElement(attribute_out, field).text = attr_def[field]

    ET.SubElement(meta_out, "dumped").text = get_cvat_timestamp()
    return project_root_out


def _copy_tifs(download_tasks, output_tifs_dir_path):
    """Copy every ``(server_path, local_path, tif_name)`` entry, reporting per-file results."""
    if not download_tasks:
        print("\nNo TIF files to download.")
        return

    print(f"\nStarting TIF downloads to {output_tifs_dir_path}...")
    successful_downloads = 0
    failed_downloads = 0
    for server_path, local_path, tif_name in download_tasks:
        print(f"  Attempting to download {tif_name} from {server_path} to {local_path}")
        try:
            # Also covers image names that contain sub-directories.
            Path(local_path).parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(server_path, local_path)  # copy2 preserves metadata
        except FileNotFoundError:
            print(f"    Error: Server file not found - {server_path}")
            failed_downloads += 1
        except PermissionError:
            print(f"    Error: Permission denied for {server_path} or {local_path}")
            failed_downloads += 1
        except Exception as e:
            # Deliberately broad: one bad file must not stop the remaining copies.
            print(f"    Error downloading {server_path}: {e}")
            failed_downloads += 1
        else:
            print(f"    Successfully downloaded {tif_name}")
            successful_downloads += 1
    print(f"TIF Download complete. Successful: {successful_downloads}, Failed: {failed_downloads}")


def process_xml_files(xmls_dir_path, server_paths_csv_path, output_tifs_dir_path, output_project_xml_path):
    """
    Processes XML files, creates a project XML, and downloads relevant TIFs.

    Every ``*.xml`` file in ``xmls_dir_path`` (processed in sorted order so
    that generated task/image IDs are reproducible) becomes at most one task
    of the output project. An image is kept only when it has at least one
    labelled polygon, carries no 'incomplete' label, has a ``name`` and its
    location has a server path in the CSV. Kept images get their polygon/box
    labels standardized, are appended to the project XML, and their TIF files
    are copied into ``output_tifs_dir_path/<location>/``.

    Args:
        xmls_dir_path: Directory containing the per-location annotation XMLs.
        server_paths_csv_path: CSV mapping location name to server TIF folder.
        output_tifs_dir_path: Local directory that receives the copied TIFs.
        output_project_xml_path: File path of the project XML to write.

    Returns:
        None. Problems are reported via printed messages; the function
        returns early if the XML directory or the CSV cannot be used, or if
        the directory contains no XML files.
    """
    if not os.path.isdir(xmls_dir_path):
        print(f"Error: XMLs directory not found: {xmls_dir_path}")
        return

    server_paths_map = read_server_paths_csv(server_paths_csv_path)
    if server_paths_map is None:
        return

    Path(output_tifs_dir_path).mkdir(parents=True, exist_ok=True)

    project_images_data = []  # One dict per kept image (element + download info)
    project_tasks_meta = []   # Metadata for each task in the project

    project_task_id_counter = 1     # For new task IDs in the project XML
    project_segment_id_counter = 1  # For new segment IDs

    # os.listdir order is arbitrary; sort so IDs are stable between runs.
    xml_files = sorted(f for f in os.listdir(xmls_dir_path) if f.lower().endswith(".xml"))
    if not xml_files:
        print(f"No XML files found in {xmls_dir_path}")
        return

    for xml_filename in xml_files:
        xml_filepath = os.path.join(xmls_dir_path, xml_filename)
        location_name = get_location_name_from_xml_filename(xml_filename)
        print(f"\nProcessing {xml_filename} for location: {location_name}...")

        server_tif_folder_path = server_paths_map.get(location_name)
        if not server_tif_folder_path:
            print(f"Warning: No server path found for location '{location_name}' in CSV. Skipping TIF downloads for this file.")

        try:
            root = ET.parse(xml_filepath).getroot()
        except ET.ParseError as e:
            print(f"Error parsing {xml_filepath}: {e}. Skipping this file.")
            continue

        original_task_element = root.find("meta/task")
        if original_task_element is None:
            print(f"Warning: No <task> element found in meta for {xml_filename}. Skipping.")
            continue

        current_task_images = []

        for image_elem in root.findall("image"):
            image_filename = image_elem.get("name")
            processed_annotations, has_polygons, has_incomplete_label = (
                _standardize_image_annotations(image_elem, image_filename)
            )

            if not has_polygons:
                print(f"  Skipping image {image_filename}: No polygon annotations.")
                continue
            if has_incomplete_label:
                print(f"  Skipping image {image_filename}: Contains an 'incomplete' label.")
                continue
            if not image_filename:
                print("  Skipping image: Missing 'name' attribute.")
                continue
            if not server_tif_folder_path:
                print(f"  Skipping image {image_filename}: No server path for location '{location_name}'.")
                continue

            # This image is kept: replace its children with the processed ones.
            for child in list(image_elem):
                image_elem.remove(child)
            for pa in processed_annotations:
                image_elem.append(pa)
            current_task_images.append(image_elem)

        if not current_task_images:
            print(f"No valid images included from {xml_filename}.")
            continue

        project_tasks_meta.append(
            _build_task_meta(
                original_task_element,
                location_name,
                project_task_id_counter,
                project_segment_id_counter,
                len(current_task_images),
            )
        )
        for img_elem in current_task_images:
            project_images_data.append({
                "element": img_elem,
                "location_name": location_name,
                "task_id": str(project_task_id_counter),  # Link to the new task ID
                "server_tif_folder": server_tif_folder_path,
                "tif_filename": img_elem.get("name"),
            })

        project_task_id_counter += 1
        project_segment_id_counter += 1

    # --- Build Project XML ---
    if not project_images_data:
        # A minimal project XML (meta only, no images) is still written.
        print("\nNo images to include in the project XML. Output will be empty or minimal.")

    project_root_out = _build_project_root(project_tasks_meta)

    # Add <image> elements to root <annotations> and collect the copy jobs.
    download_tasks = []  # (server_full_path, local_full_path, tif_filename)

    for global_image_id, img_data in enumerate(project_images_data):
        image_elem = img_data["element"]
        image_elem.set("id", str(global_image_id))
        image_elem.set("subset", "default")  # Standardize subset for project
        image_elem.set("task_id", img_data["task_id"])
        project_root_out.append(image_elem)

        tif_filename = img_data["tif_filename"]
        server_base_path = img_data["server_tif_folder"]
        if not (server_base_path and tif_filename):
            continue

        local_location_dir = Path(output_tifs_dir_path) / img_data["location_name"]
        local_location_dir.mkdir(parents=True, exist_ok=True)

        # Leading slashes would make the name look absolute, which would make
        # both joins below discard their base directory.
        tif_filename_cleaned = tif_filename.lstrip('\\/')
        server_full_tif_path = os.path.join(server_base_path, tif_filename_cleaned)
        local_tif_path = local_location_dir / tif_filename_cleaned
        download_tasks.append((server_full_tif_path, local_tif_path, tif_filename))

    # Write Project XML file
    xml_string_pretty = prettify_xml(project_root_out)
    try:
        with open(output_project_xml_path, "w", encoding="utf-8") as f:
            f.write(xml_string_pretty)
        print(f"\nSuccessfully created project XML: {output_project_xml_path}")
    except OSError as e:
        print(f"Error writing project XML to {output_project_xml_path}: {e}")

    # --- Download TIFs ---
    _copy_tifs(download_tasks, output_tifs_dir_path)

if __name__ == "__main__":
    # Script entry point: run the consolidation with this machine's fixed
    # input and output locations.
    process_xml_files(
        xmls_dir_path=r"C:\Users\kevin\Documents\xmls",
        server_paths_csv_path=r"C:\Users\kevin\dev\tornado-tree-destruction-ef\tornado-tree-destruction-ef\scripts\serverpaths.csv",
        output_tifs_dir_path=r"C:\Users\kevin\Documents\d",
        output_project_xml_path=r"C:\Users\kevin\Documents\project_combined_final.xml",
    )


Processing centreglassville.xml for location: centreglassville...
No valid images included from centreglassville.xml.

Processing crozier.xml for location: crozier...
  Skipping image 22_Crozier_461000_5386000.tif: No polygon annotations.

Processing dugwal.xml for location: dugwal...

Processing gooseberry.xml for location: gooseberry...

Processing gowanmarsh.xml for location: gowanmarsh...

Processing jasper.xml for location: jasper...
  Skipping image SofWabasso2.TIF: No polygon annotations.
  Skipping image SofWabasso3.TIF: No polygon annotations.
  Skipping image SofWabasso6.TIF: No polygon annotations.
  Skipping image SofWabasso8.TIF: No polygon annotations.
  Skipping image SofWabasso9.TIF: No polygon annotations.

Processing kaoskauta.xml for location: kaoskauta...
  Skipping image 22_Kaoskauta_492000_5470000.tif: No polygon annotations.
  Skipping image 22_Kaoskauta_493000_5469000.tif: No polygon annotations.
  Skipping image 22_Kaoskauta_493000_5470000.tif: No polygon anno