# Full Tracker: Download, Process, and Upload Data
This notebook demonstrates the full pipeline for handling raw RGB and .csq data:
1. Download data from a cloud bucket.
2. Process the data (e.g., align videos, run detection, and tracking).
3. Upload the processed data back to the cloud bucket.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

# Import Environment Variables
from dotenv import load_dotenv

# Import Utility Functions
from collab_env.data.file_utils import expand_path, get_project_root
from collab_env.data.gcs_utils import GCSClient

# Import Custom Scripts
from collab_env.tracking.alignment_gui import align_videos
from collab_env.tracking.model.local_model_inference import infer_with_yolo
from collab_env.tracking.model.local_model_tracking import (
    generate_thermal_masks_from_bboxes,
    get_detections_from_video,
    output_tracked_bboxes_csv,
    output_tracked_bboxes_csv,
    overlay_tracks_on_video,
    run_tracking,
)
from collab_env.tracking.thermal_processing import (
    process_directory,
    validate_session_structure,
)
from collab_env.tracking.visualization import (
    export_tracks_with_masks,
    overlay_tracks_on_video,
    plot_tracks_at_frame_bbox_from_video,
)

### Set flags for data processing

In [None]:
# Set to True to skip the cloud download cell below.
skip_download = True
# Set to True to skip the thermal file processing cell below.
skip_thermal_extraction = False

### Setup gcloud

In [None]:
# Configuration: pull environment variables in through python-dotenv
load_dotenv()

PROJECT_ID = "collab-data-463313"

# COLLAB_DATA_KEY supplies the credentials location; when it is unset the
# empty string is passed on to expand_path instead.
# NOTE(review): the path is run through expand_path twice, which looks
# redundant -- confirm before simplifying.
raw_key_location = os.getenv("COLLAB_DATA_KEY", "")
data_key = expand_path(raw_key_location, get_project_root())
CREDENTIALS_PATH = expand_path(data_key.as_posix(), get_project_root())

# Client used by every cell below to talk to Google Cloud Storage
gcs_client = GCSClient(project_id=PROJECT_ID, credentials_path=CREDENTIALS_PATH)

Check available buckets

In [None]:
# Verify the connection by printing the buckets reported by the client
print("Available buckets:", gcs_client.list_buckets())

Show files within buckets

In [None]:
# Bucket that the session is downloaded from in the cells below
BUCKET_NAME = "fieldwork_curated"  # Update with your bucket name
# Kept as the last expression so the notebook displays the listing
gcs_client.glob(f"{BUCKET_NAME}/*")

Select a session to download and process

In [None]:
# Select the session to download and the local folders it maps to
SESSION_FOLDER = "2024_02_06-session_0001"  # Update with data folder (session)
CLOUD_PREFIX = f"{BUCKET_NAME}/{SESSION_FOLDER}"

# Raw files are downloaded here; processed outputs are written alongside
LOCAL_DOWNLOAD_DIR = expand_path(f"data/raw/{SESSION_FOLDER}", get_project_root())
LOCAL_PROCESSED_DIR = expand_path(
    f"data/processed/{SESSION_FOLDER}", get_project_root()
)

# List the session contents. A notebook only displays the value of a cell's
# LAST expression, so this call must come last; in the middle of the cell its
# result was silently discarded.
gcs_client.glob(f"{CLOUD_PREFIX}/**")

Download from gcloud

In [None]:
if not skip_download:
    # Both local roots must exist before anything is written into them.
    for local_root in (LOCAL_DOWNLOAD_DIR, LOCAL_PROCESSED_DIR):
        if not local_root.exists():
            local_root.mkdir(parents=True, exist_ok=True)

    for blob in gcs_client.glob(f"{CLOUD_PREFIX}/**"):
        relative_path = Path(blob).relative_to(f"{CLOUD_PREFIX}")
        local_name = relative_path.name
        suffix = relative_path.suffix
        print(f"local_name: {local_name}, suffix: {suffix}")

        # Heuristic: an entry without a file extension is treated as a folder.
        # Folders are only mirrored under the processed directory.
        if not suffix:
            processed_folder = LOCAL_PROCESSED_DIR / relative_path
            if not processed_folder.exists():
                print(f"Creating folder: {processed_folder}")
                processed_folder.mkdir(parents=True, exist_ok=True)
            continue

        # Anything with an extension is a file: make sure its parent folder
        # exists under the download directory, then fetch it.
        download_folder = LOCAL_DOWNLOAD_DIR / relative_path.parent
        if not download_folder.exists():
            print(f"Creating folder: {download_folder}")
            download_folder.mkdir(parents=True, exist_ok=True)

        local_path = download_folder / local_name
        print(f"Downloading file: {blob} to {local_path}")
        gcs_client.gcs.get_file(blob, str(local_path))

Ensure everything worked properly

In [None]:
# Check the downloaded session folder and report whatever the validator returns
print("Validating session structure...")
issues = validate_session_structure(LOCAL_DOWNLOAD_DIR)
issue_report = issues if len(issues) > 0 else "None"
print(f"Issues found: {issue_report}")

### Extract thermal video info

In [None]:
if not skip_thermal_extraction:
    print("Processing thermal files...")

    # With preview=True the user is asked to choose vmin/vmax; pass
    # preview=False to have vmin/vmax chosen automatically.
    # NOTE(review): max_frames=100 here while the alignment cell uses
    # max_frames=None -- confirm the cap is intended.
    thermal_options = {
        "color": "magma",
        "preview": True,
        "max_frames": 100,
        "fps": 30,
    }
    process_directory(
        folder_path=LOCAL_DOWNLOAD_DIR,
        out_path=LOCAL_PROCESSED_DIR,
        **thermal_options,
    )

#### Spatiotemporal alignment of thermal and RGB 

In [None]:
# Default parameters for alignment
frame_size = (640, 480)  # Default frame size
max_frames = None  # Process all frames by default
warp_to = "rgb"  # Default warp to rgb, thermal is changing, not rgb
rotation_angle = 0.0  # Default rotation angle
skip_homography = False  # Default to not skip homography
skip_translation = False  # Default to not skip translation
camera_numbers = [1, 2]


def _find_videos(directory, patterns):
    """Return the files in *directory* matching *patterns*, without duplicates.

    Patterns are tried in the given order and each pattern's matches are
    sorted, so the first element is deterministic. Where glob matching is
    case-insensitive (e.g. Windows), "*.MP4" and "*.mp4" match the same file;
    the dict.fromkeys pass drops those repeats so a single video is not
    reported as "multiple".
    """
    matches = []
    for pattern in patterns:
        matches.extend(sorted(directory.glob(pattern)))
    return list(dict.fromkeys(matches))


for camera in camera_numbers:
    print(f"Processing camera {camera}...")

    # Raw RGB comes from the download folder, thermal MP4s from the processed one
    rgb_dir = LOCAL_DOWNLOAD_DIR / f"rgb_{camera}"
    thermal_dir = LOCAL_PROCESSED_DIR / f"thermal_{camera}"

    # Find the MP4 file in the RGB directory
    rgb_video_files = _find_videos(rgb_dir, ("*.MP4", "*.mp4"))
    print("files in rgb_dir:", rgb_video_files)
    if not rgb_video_files:
        print(f"No MP4 file found in {rgb_dir}. Skipping camera {camera}.")
        continue
    if len(rgb_video_files) > 1:
        print(f"Multiple MP4 files found in {rgb_dir}. Using the first one.")
    rgb_video_path = rgb_video_files[0]

    # Find the MP4 file in the thermal directory
    thermal_video_files = _find_videos(thermal_dir, ("*.mp4", "*.MP4"))
    print("files in thermal_dir:", thermal_video_files)
    if not thermal_video_files:
        print(f"No MP4 file found in {thermal_dir}. Skipping camera {camera}.")
        continue
    if len(thermal_video_files) > 1:
        print(f"Multiple MP4 files found in {thermal_dir}. Using the first one.")
    thermal_video_path = thermal_video_files[0]

    print(f"RGB video path: {rgb_video_path}")
    print(f"Thermal video path: {thermal_video_path}")

    # The detection, tracking and visualization cells all read from
    # "aligned_frames", so the aligned videos must be written there.
    aligned_root = LOCAL_PROCESSED_DIR / "aligned_frames"
    output_dir_rgb = aligned_root / f"rgb_{camera}"
    output_dir_thm = aligned_root / f"thermal_{camera}"
    output_dir_rgb.mkdir(parents=True, exist_ok=True)
    output_dir_thm.mkdir(parents=True, exist_ok=True)

    # Align videos
    print(f"Aligning videos for camera {camera}...")

    align_videos(
        rgb_video_path,
        thermal_video_path,
        output_dir_rgb,
        output_dir_thm,
        frame_size=frame_size,
        max_frames=max_frames,
        warp_to=warp_to,
        rotation_angle=rotation_angle,
        skip_homography=skip_homography,
        skip_translation=skip_translation,
    )

#### Object detection and tracking

In [None]:
# Detection: call infer_with_yolo on each camera's aligned thermal video
print("Running detection...")
for camera in camera_numbers:
    print(f"Running detection and tracking on: thermal_{camera}")

    # Per-camera folders under the aligned output root
    aligned_root = LOCAL_PROCESSED_DIR / "aligned_frames"
    thermal_out = aligned_root / f"thermal_{camera}"

    thermal_video_path = thermal_out / "warped_thermal_adjusted.mp4"
    if not thermal_video_path.exists():
        print(f"Thermal video not found for camera {camera}. Skipping...")
        continue

    # rgb_video_path and output_video_path are assigned but not used in this cell
    rgb_video_path = aligned_root / f"rgb_{camera}" / "cropped_rgb_adjusted.mp4"
    detect_csv = thermal_out / f"detections_{camera}.csv"
    output_video_path = thermal_out / "annotated_warped_thermal.mp4"
    checkpoint_path = (
        LOCAL_DOWNLOAD_DIR / "yolo11_weights.pt"
    )  # Update with your model path

    # A failure for one camera is reported and the loop moves on to the next
    try:
        infer_with_yolo(
            video_path=thermal_video_path,
            model_path=checkpoint_path,
            output_csv_path=detect_csv,
        )
        print(
            f"Object detection completed for camera {camera}. Results saved to {detect_csv}."
        )
    except Exception as e:
        print(f"Error during object detection for camera {camera}: {e}")

In [None]:
# Tracking: render detections, run the tracker, then export and overlay tracks
for camera in camera_numbers:
    # Per-camera folders under the aligned output root
    aligned_root = LOCAL_PROCESSED_DIR / "aligned_frames"
    thermal_out = aligned_root / f"thermal_{camera}"
    rgb_out = aligned_root / f"rgb_{camera}"

    # All three inputs must exist, otherwise this camera is skipped
    thermal_video_path = thermal_out / "warped_thermal_adjusted.mp4"
    if not thermal_video_path.exists():
        print(f"Thermal video not found for camera {camera}. Skipping...")
        continue

    rgb_video_path = rgb_out / "cropped_rgb_adjusted.mp4"
    if not rgb_video_path.exists():
        print(f"RGB video not found for camera {camera}. Skipping...")
        continue

    detect_csv = thermal_out / f"detections_{camera}.csv"
    if not detect_csv.exists():
        print(f"Detection CSV not found for camera {camera}. Skipping tracking.")
        continue

    # Visualize the same detections on the thermal video, then on the RGB video
    visualization_jobs = (
        (thermal_video_path, thermal_out / f"visualized_thermal_{camera}.mp4"),
        (rgb_video_path, rgb_out / f"visualized_rgb_{camera}.mp4"),
    )
    for source_video, visualized_video in visualization_jobs:
        get_detections_from_video(
            csv_path=detect_csv,
            video_path=source_video,
            output_video_path=visualized_video,
        )

    # Run tracking for both modalities, thermal first
    print(f"Running tracking on: camera {camera}")
    for modality in ("thermal", "rgb"):
        run_tracking(aligned_root, modality, camera)

    tracked_csv = rgb_out / f"rgb_{camera}_tracks.csv"
    if not tracked_csv.exists():
        print(f"Tracking CSV not found for camera {camera}. Skipping visualization.")
        continue

    # Output tracked bounding boxes to CSV
    # (lower iou_threshold if tracks are not centered)
    output_tracked_bboxes_csv(
        track_csv=tracked_csv,
        detect_csv=detect_csv,
        output_csv=rgb_out / f"tracked_bboxes_{camera}.csv",
        iou_threshold=0.1,
    )

    print(f"Visualizing tracks for rgb camera {camera}...")
    overlay_tracks_on_video(
        csv_path=tracked_csv,
        frame_dir=rgb_out / "annotated_frames",
        output_video=rgb_out / f"overlayed_tracks_{camera}.mp4",
    )

### Optional: Visualization

In [None]:
frame_number = 100  # Example frame number to visualize
camera = 1  # Example camera number to visualize

# Build every path from `camera` chosen above. Reusing tracked_csv and
# rgb_video_path left over from the tracking loop would silently point at
# whichever camera that loop processed last.
aligned_root = LOCAL_PROCESSED_DIR / "aligned_frames"
rgb_out = aligned_root / f"rgb_{camera}"
thermal_out = aligned_root / f"thermal_{camera}"

tracked_csv = rgb_out / f"rgb_{camera}_tracks.csv"
rgb_video_path = rgb_out / "cropped_rgb_adjusted.mp4"
tracked_bboxes_path = rgb_out / f"tracked_bboxes_{camera}.csv"
# Masks are written to and read back from the same folder
mask_dir = thermal_out / "masks"

print(f"Plotting tracks at frame {frame_number} for camera {camera}...")
# NOTE(review): the parameter is called tracked_bboxes_csv but receives the
# tracks CSV, as before -- confirm whether tracked_bboxes_path is expected.
plot_tracks_at_frame_bbox_from_video(
    tracked_bboxes_csv=tracked_csv,
    video_path=rgb_video_path,
    output_image=rgb_out / f"tracked_bboxes_{camera}_{frame_number}.png",
    frame_number=frame_number,
    max_frame=1000,
)

# If possible, generate pixel masks within each bounding box
print(f"Exporting tracks with masks for camera {camera}...")

generate_thermal_masks_from_bboxes(
    bbox_csv=tracked_bboxes_path,
    # Same file name the detection and tracking cells check for
    video_path=thermal_out / "warped_thermal_adjusted.mp4",
    output_mask_dir=mask_dir,
    temp_threshold=128,  # default threshold for thermal images, half of 255
    mask_value=255,
)
export_tracks_with_masks(
    tracked_bboxes_csv=tracked_csv,
    mask_dir=mask_dir,
    output_csv=rgb_out / f"tracked_bboxes_{camera}_with_masks.csv",
)

### Optional: Push processed files

In [None]:
# Upload Processed Data to Cloud Bucket
BUCKET_NAME = "fieldwork_processed"
CLOUD_PROCESSED_PREFIX = (
    "your-cloud-processed-prefix"  # Update with your processed data prefix
)

# Walk the whole processed tree: the outputs live in nested folders, so
# iterating only the top level would miss them and hand directories to
# upload_file. Relative paths are kept so the folder layout is preserved.
uploaded_paths = []
for local_file in sorted(LOCAL_PROCESSED_DIR.rglob("*")):
    if not local_file.is_file():
        continue
    relative_name = local_file.relative_to(LOCAL_PROCESSED_DIR).as_posix()
    cloud_path = f"{BUCKET_NAME}/{CLOUD_PROCESSED_PREFIX}/{relative_name}"
    gcs_client.upload_file(str(local_file), cloud_path)
    uploaded_paths.append(cloud_path)

# Report what was actually uploaded rather than re-listing the local folder
print("Uploaded processed files:", uploaded_paths)