# In [1]:
import csv
import os
from pathlib import Path
import subprocess
import requests
import init_arkit
import cv2 
import json 
# Print the file path init_arkit was imported from.
print("INIT FILE:", init_arkit.__file__)
import importlib
# Re-execute the init_arkit module so the already-imported module object is refreshed
# (presumably to pick up edits made during a notebook session — confirm).
importlib.reload(init_arkit)

def download_file(file_url: str, dest_path: str | os.PathLike) -> bool:
    """
    Download a single file from `file_url` and save it to `dest_path`.

    Returns:
        True if the download succeeds, False otherwise.
    """
    dest_path = Path(dest_path)
    try:
        print(f"Downloading {file_url} â†’ {dest_path}")
        response = requests.get(file_url, timeout=60)
        response.raise_for_status()
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        with dest_path.open('wb') as f:
            f.write(response.content)
        print(f"âœ“ Downloaded: {dest_path}")
        return True
    except Exception as e:
        print(f"âœ— Failed to download {file_url}: {e}")
        return False
    
# -------------------------------------------------------------
# Download a single ARKit scene's labelmaker
# -------------------------------------------------------------
def download_arkit_labelmaker(video_id: str, split: str, scene_dir: Path) -> None:
    """
    Fetch the labelmaker annotation assets for one ARKit scene.

    Two files are requested and stored directly inside `scene_dir`:
    ``labels.txt`` (from the dataset's ``raw/main`` URL) and
    ``point_lifted_mesh.ply`` (from its ``resolve/main`` URL). A banner with
    the scene id is printed before the downloads and a completion line after.
    The outcome of the individual downloads is not inspected here.

    Args:
        video_id: Scene identifier used in the remote path.
        split: Split name used in the remote path.
        scene_dir: Local directory the two files are written into.
    """
    print("\n===============================================")
    print(f"ðŸ“Œ Downloading ARKit labelmaker: {video_id}  (split: {split})")
    print("===============================================")

    # Text labels come from the "raw" endpoint, the mesh from "resolve".
    labels_base_url = "https://huggingface.co/datasets/labelmaker/arkit_labelmaker/raw/main"
    ply_base_url = "https://huggingface.co/datasets/labelmaker/arkit_labelmaker/resolve/main"

    # The assets that are always needed for processing, in download order.
    for asset_name in ('labels.txt', 'point_lifted_mesh.ply'):
        if asset_name.endswith(".ply"):
            base_url = ply_base_url
        else:
            base_url = labels_base_url
        download_file(f"{base_url}/{split}/{video_id}/{asset_name}", scene_dir / asset_name)

    print(f"ðŸŽ‰ Finished downloading scene's labelmaker: {video_id}")

# -------------------------------------------------------------
# Download a single ARKit scene (simplified)
# -------------------------------------------------------------
def download_arkit_scene(video_id: str, split: str, download_dir: str = "arkitscenes") -> None:
    """
    Run ``download_data.py`` in a subprocess to fetch one raw ARKit scene.

    The script is asked for the ``lowres_wide.traj``, ``vga_wide`` and
    ``vga_wide_intrinsics`` raw assets. The scene id is printed before and
    after the call. The subprocess exit status is not checked, so the
    completion line is printed even if the script fails.

    Args:
        video_id: Scene identifier passed as ``--video_id``.
        split: Split name passed as ``--split``.
        download_dir: Directory passed as ``--download_dir``.
    """
    print("\n===============================================")
    print(f"ðŸ“Œ Downloading ARKit Scene: {video_id}  (split: {split})")
    print("===============================================")

    wanted_assets = ["lowres_wide.traj", "vga_wide", "vga_wide_intrinsics"]

    command = ["python3", "download_data.py", "raw"]
    command += ["--video_id", video_id]
    command += ["--split", split]
    command += ["--download_dir", download_dir]
    command += ["--raw_dataset_assets", *wanted_assets]

    # check=False: a failing download script does not raise here.
    subprocess.run(command, check=False)

    print(f"ðŸŽ‰ Finished downloading scene: {video_id}")


# -------------------------------------------------------------
# Read the ARKit split CSV and download scenes sequentially
# -------------------------------------------------------------
def download_arkit_dataset(csv_path: str = "raw_train_val_splits.csv",
                           download_dir: str = "arkitscenes",
                           output_fol: str = "segmentation_summary") -> None:
    """
    Read the ARKitScenes split CSV and download + process every scene in turn.

    Each CSV row must provide a scene id (column ``video_id``, ``id`` or
    ``scene_id``) and a split (column ``split``, ``fold`` or ``scene_type``);
    rows missing either are skipped. Every remaining scene is downloaded and
    processed sequentially. A scene that cannot be processed (missing files,
    unreadable trajectory) is skipped and the run continues with the next row.

    Args:
        csv_path: Path to the split CSV. If it does not exist, a message is
            printed and nothing else happens.
        download_dir: Root directory the raw scenes are downloaded into.
        output_fol: Root directory for per-scene outputs; each scene gets the
            sub-folder ``<output_fol>/<video_id>``.
    """
    csv_path = Path(csv_path)
    if not csv_path.exists():
        print(f"CSV not found: {csv_path}")
        return

    print(f"Reading split file: {csv_path}")

    # newline="" is the opening mode the csv module documents for its readers.
    with csv_path.open("r", newline="") as f:
        for row in csv.DictReader(f):
            print("Loaded row:", row)  # TEMP DEBUG

            video_id = row.get("video_id") or row.get("id") or row.get("scene_id")
            split = row.get("split") or row.get("fold") or row.get("scene_type")

            if not video_id or not split:
                print("Skipping row because missing video_id or split:", row)
                continue

            _process_arkit_scene(video_id, split, download_dir, Path(output_fol) / video_id)


def _process_arkit_scene(video_id: str, split: str, download_dir: str,
                         output_dir: Path, frame_stride: int = 30) -> None:
    """Download one scene, project labels onto sampled frames, save stats, run molmo.

    Returns early (skipping only this scene) when required inputs are missing
    or the trajectory file cannot be read.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    download_arkit_scene(video_id, split, download_dir)

    # Locations of the downloaded assets for this scene
    scene_dir = Path(download_dir) / "raw" / split / video_id
    intrinsics_dir = scene_dir / "vga_wide_intrinsics"
    traj_path = scene_dir / "lowres_wide.traj"
    image_dir = scene_dir / "vga_wide"

    download_arkit_labelmaker(video_id, split, scene_dir)

    mesh_path = scene_dir / "point_lifted_mesh.ply"
    labels_path = scene_dir / "labels.txt"

    # Quick sanity check before continuing. The frame directory is included
    # because listing a missing directory would otherwise abort the whole run.
    required_files = (mesh_path, labels_path, traj_path)
    if not all(p.exists() for p in required_files) or not image_dir.is_dir():
        print(f"Missing required files for {video_id}, skipping.")
        return

    # --- Read camera poses from the trajectory file ---
    try:
        poses = init_arkit.read_traj(traj_path).items()
    except Exception as e:
        print(f"Failed to read traj for {video_id}: {e}")
        return

    # --- Prepare stats container and where we'll later save it ---
    stats_path = output_dir / f"{video_id}.json"
    class_pixel_stats = {
        "door": {},
        "switch": {}
    }

    # --- Enumerate frames (every `frame_stride`-th PNG in sorted order) ---
    png_files = sorted(f for f in os.listdir(image_dir) if f.endswith(".png"))
    print(f"Found {len(png_files)} PNG frames in {image_dir}")

    for filename in png_files[::frame_stride]:
        extracted_ts = init_arkit.extract_timestamp_from_filename(filename)
        if extracted_ts is None:
            print(f"Invalid timestamp in filename: {filename}")
            continue

        closest_ts, (rotvec, transvec) = init_arkit.get_pose_for_nearest_timestamp(extracted_ts, poses)
        print(f"Image: {filename} â†’ Extracted TS: {extracted_ts} â†’ Closest Pose TS: {closest_ts}")

        # NOTE(review): the frame path is rebuilt from the parsed timestamp
        # rather than taken from `filename`; confirm the two always agree.
        frame_name = f"{video_id}_{extracted_ts}"
        pincam_path = intrinsics_dir / f"{frame_name}.pincam"
        frame_path = image_dir / f"{frame_name}.png"

        # Skip if either the intrinsics file or the image is missing
        if not pincam_path.exists() or not frame_path.exists():
            print(f"Missing data for frame {extracted_ts} in scene {video_id}, skipping frame.")
            continue

        # Per-frame work is best-effort: any failure is reported and the
        # next frame is processed.
        try:
            rgb_img = cv2.imread(str(frame_path))
            if rgb_img is None:
                print(f"Failed to read image {frame_path}")
                continue

            # Orientation fix, applied once per frame: when a snapped roll is
            # available the image is warped with the affine from init_arkit.
            raw_roll = init_arkit.pixel_roll(rotvec)
            snap_roll = init_arkit.snap_roll_to_canonical(raw_roll)

            if snap_roll is not None:
                A, nW, nH = init_arkit.compute_roll_affine(*rgb_img.shape[:2], snap_roll)
                rgb_img = cv2.warpAffine(
                    rgb_img,
                    A,
                    (nW, nH),
                    flags=cv2.INTER_LINEAR,
                    borderMode=cv2.BORDER_CONSTANT,
                    borderValue=(0, 0, 0)   # black background
                )

            projection, contains_target, label_counts, total_pixels, door_instances_2d = init_arkit.project_instance(
                mesh_path=mesh_path,
                labels_path=labels_path,
                pincam_path=str(pincam_path),
                rotation_vec=rotvec,
                translation_vec=transvec,
                rgb_frame=rgb_img,
                alpha=0.6,
            )

            # Update stats only for frames that contain the target category
            for obj_type, target_present in zip(["door"], contains_target):
                if target_present:
                    _record_instance_stats(
                        class_pixel_stats, obj_type, extracted_ts,
                        door_instances_2d, label_counts, total_pixels,
                    )

            # Save visualization overlay
            overlay_path = output_dir / f"overlay_{extracted_ts}.png"
            cv2.imwrite(str(overlay_path), projection)

        except Exception as e:
            print(f"Error in scene {video_id}, frame {extracted_ts}: {e}")
            continue

    # --- After loop: persist the collected statistics for this scene ---
    with stats_path.open("w") as jsonfile:
        json.dump(class_pixel_stats, jsonfile, indent=4)
    print(f"Saved statistics for scene {video_id} to {stats_path}")

    init_arkit.run_molmo(
        video_id=video_id,
        json_path=stats_path,
        image_path=image_dir,
        output_dir=output_dir,
        poses=poses
    )


def _record_instance_stats(class_pixel_stats: dict, obj_type: str, ts,
                           instances_2d, label_counts, total_pixels) -> None:
    """Append the frame's per-label pixel counts/percentages under every instance id."""
    if not instances_2d:
        return

    for inst_id, _inst_data in instances_2d.items():
        # Instance ids are stored as strings; one bucket per instance, then
        # one entry per frame timestamp inside it.
        per_instance = class_pixel_stats[obj_type].setdefault(str(inst_id), {})
        frame_entry = per_instance.setdefault(ts, {"labels": []})

        # Record pixel counts + percentages for each label id in the projection
        for label_id, count in label_counts.items():
            # max(..., 1) guards against division by zero for empty frames
            percentage = (count / max(total_pixels, 1)) * 100.0
            frame_entry["labels"].append({
                "label_id": int(label_id),
                "pixel_count": int(count),
                "pixel_percentage": percentage
            })
# -------------------------------------------------------------
# Main execution
# -------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: parse the three path options and process the dataset.
    import argparse

    parser = argparse.ArgumentParser(description="Download ARKitScenes dataset one by one")
    parser.add_argument("--csv_path", default="raw_train_val_splits.csv")
    parser.add_argument("--download_dir", default="arkitscenes")
    parser.add_argument("--output_dir", default="segmentation_summary")

    # parse_known_args (instead of parse_args) tolerates the extra arguments
    # Jupyter passes, like --f=...; the unrecognised ones are ignored.
    args, _unknown = parser.parse_known_args()

    download_arkit_dataset(
        csv_path=args.csv_path,
        download_dir=args.download_dir,
        # Forward --output_dir; without this the option has no effect.
        output_fol=args.output_dir,
    )



# Notebook cell output:
# IndentationError: expected an indented block after 'else' statement on line 1421 (init_arkit.py, line 1424)