diff --git a/Changelog.md b/Changelog.md index d13dc2f..6af7b58 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,68 +1,164 @@ -# VisionDepth3D v3.8 – Changelog +# VisionDepth3D v3.8.2 - Changelog + +--- + +> This release delivers major performance improvements to both live 3D preview and offline rendering, alongside new depth engines, stability fixes, and encoding reliability upgrades. --- ## 1) Depth Estimation Tab -### Depth Models +### UI Depth Tab Labelling -- Fixed ONNX model loading: - - Distill-Any-Depth (inference resolution 518×518, batch size 8) - - Video Depth Anything (inference resolution 512×288, batch size 8) -- Implemented LBM depth model (development version). Thanks to Aether for the implementation fix. -- Removed depth models from the dropdown that returned no `d_type`. -- Fixed Hugging Face model downloads and caching so zoo models consistently save inside the app `weights/` directory (no more extra `.cache` downloads). -- Updated Transformers image processor loading to prefer `use_fast=True` when available (with automatic fallback when unsupported). +- Renamed the Depth Estimation tab to **Depth Engine** to better reflect multi-backend depth processing. +- Reduced console warning spam related to sequential pipeline usage by suppressing the specific Hugging Face warning message. -### Depth Backend +### Depth Anything 3 (DA3) Adapter Integration -- Implemented temporal smoothing in the depth pipeline to reduce flicker and improve temporal stability of depth map output. -- Packaged VisionDepth3D.exe with Distill-Any-Depth (ONNX), Video Depth Anything (ONNX), and Depth Anything v2 Giant weights. +- Added native Depth Anything 3 backend support via a dedicated DA3 adapter (separate from Hugging Face pipeline models). +- Implemented DA3 model loading through Hugging Face `from_pretrained` with VD3D cache routing into the `weights/` directory. +- Added DA3 model entries to the model selector (DA3-SMALL / BASE / LARGE / GIANT and DA3METRIC variants). +- Wired DA3 inference into the unified depth pipeline so it works with both image and video depth workflows. +- Mapped the UI “Inference Resolution” dropdown into DA3’s `process_res` logic (single max-side target resolution), with a video-friendly cap applied to prevent excessive internal upscaling. +- Normalized DA3 depth outputs into a consistent 0–1 range to match existing VD3D depth handling and export logic. +- Depth polarity handling for DA3 metric models remains user-controlled via the “Invert Depth” toggle. +- Improved DA3 batching compatibility by supporting list-of-PIL inference and ensuring returned depth frame counts match input batch size (with a per-image fallback if needed). +- Added a DA3 warm-up pass during model load to reduce first-frame hitching and confirm the backend is initialized correctly. ---- +### Video Depth Anything (VDA) Adapter Integration -## 2) 3D Render Tab +- Added native **Video Depth Anything** backend support via a dedicated VDA adapter for sequence-based video depth inference. +- Implemented VDA model loading directly from Hugging Face repositories (e.g. `depth-anything/Video-Depth-Anything-*`) with automatic checkpoint download and caching. +- Integrated VDA into the unified depth pipeline so it can be selected and used alongside DA3, ONNX, and Hugging Face depth models. +- Enabled sequence-aware inference for video input, allowing VDA to process temporal frame batches instead of independent per-frame depth estimation. 
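+
+As a hedged illustration of what the sequence-aware path means for callers (the helper name, window size, and chunking policy below are assumptions, not the shipped pipeline; only the adapter's list-in / list-of-`{"predicted_depth": ...}`-out contract comes from this release):
+
+```python
+def depth_for_video(frames, vda_infer, window=32):
+    """Run a sequence depth backend on ordered temporal chunks, not lone frames."""
+    depths = []
+    for start in range(0, len(frames), window):
+        chunk = frames[start:start + window]   # preserves temporal order
+        results = vda_infer(chunk)             # adapter returns one dict per frame
+        depths.extend(r["predicted_depth"] for r in results)
+    return depths
+```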
+- Added configurable target FPS handling for VDA to reduce inference load on high-FPS sources by running depth inference at a lower temporal rate. +- Ensured VDA output depth frames are normalized into VD3D’s standard 0–1 depth range for compatibility with existing export, blending, and 3D rendering logic. +- Wired VDA output into the same post-processing, temporal normalization, and letterbox-handling pipeline used by other depth engines. +- Added VDA model warm-up during load to verify backend initialization and reduce first-inference latency. +- Depth polarity for VDA models remains user-controlled via the existing “Invert Depth” toggle for consistency across all depth engines. -### UI Fixes +### ONNX Model Fixes & Stability Improvements -- Added buttons for encoder settings and processing options. -- Implemented multi-language support and tooltips for new dialog boxes. -- Adjusted preview image window size and video info layout to prevent window overflow. -- 3D tab columns now stack correctly when resizing the window on smaller screens. +- Fixed Distill-Any-Depth ONNX models (Small / Base / Large) failing to run due to internal tensor shape mismatch. +- Distill-Any-Depth ONNX models now correctly use a fixed 518×518 inference size, matching their exported positional embedding grid. +- Added automatic detection for Distill-Any-Depth ONNX models and enforced fixed input resolution internally. +- Updated ONNX image preprocessing to preserve aspect ratio using padding instead of stretching, improving depth stability and quality on widescreen content. +- ONNX warm-up now succeeds reliably for Distill-Any-Depth models without broadcast or Add-node errors. +- Enabled safe ONNX Runtime graph optimizations to reduce unnecessary memory copies and warning spam. +- Added clearer ONNX model identification output in the console so users can see exactly which ONNX model is being loaded. -### 3D Backend +### Model List Consistency -- Reworked Auto Crop Black Bars to use first-frame detection with cached crop reuse. -- Prevents per-frame crop jitter and depth/frame misalignment. -- Improves stability for cinema content with subtle letterboxing. -- Keep Audio checkbox now respects the user-selected output container instead of forcing MP4. +- Fixed missing Distill-Any-Depth ONNX models in the depth inference script while still being listed in the UI. +- Ensured ONNX model availability in the UI now correctly matches backend support. ---- +### Video Encoding / Codec Handling -## Frametool Backend +- Fixed CPU and GPU FFmpeg codecs (libx264, libx265, NVENC, AMF, QSV) being incorrectly routed through OpenCV’s VideoWriter. +- Non-OpenCV-safe codecs are now encoded via FFmpeg piping, preventing OpenH264 DLL errors and codec initialization failures. +- OpenCV VideoWriter is now limited to compatible FourCC codecs (mp4v, XVID, DIVX) with automatic fallback handling. -- Reworked Frametool backend to support SSResNet models for feature model integration. +### Depth Inference Performance & Pipeline Optimizations ---- +- Reduced redundant image resizing during video depth inference to avoid double-scaling overhead. +- Consolidated resize to a single pass per frame, reducing CPU overhead. +- Enabled CUDA-optimized memory layout (`channels_last`) for Hugging Face depth models when running on GPU. +- Improved FP16 inference handling for supported Hugging Face models to increase throughput on CUDA devices. +- Optimized ONNX Runtime session configuration using safe graph optimizations and memory arena usage. 
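+
+For reference, a minimal sketch of the session setup the two ONNX Runtime bullets above describe (the exact option values VD3D uses are not part of this diff; reading "safe graph optimizations" as `ORT_ENABLE_BASIC` is an assumption):
+
+```python
+import onnxruntime as ort
+
+so = ort.SessionOptions()
+# conservative graph rewrites (constant folding etc.), no aggressive fusions
+so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
+so.enable_cpu_mem_arena = True   # reuse one arena rather than per-run allocations
+sess = ort.InferenceSession(
+    "weights/DistillAnyDepthLarge.onnx",   # hypothetical weights path
+    sess_options=so,
+    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+)
+```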
+- Improved batch handling logic to reduce per-frame overhead during video processing. +- FFmpeg piping is now preferred by default for video output, significantly reducing encoding bottlenecks. -## Console Improvements +### Letterbox & Black Bar Handling (Video) -- Standardized startup console messages to clearly reflect which subsystems are initializing (Torch, depth estimation, upscaler, external 3D pipeline, language, settings). -- Unified compute device reporting across pipelines for consistent and clearer console output. -- Suppressed optional xFormers dependency warning on startup. -- Prevented duplicate language loading during settings restore. +- Fixed letterbox (black bar) regions incorrectly contributing to depth inference. +- Depth estimation now consistently ignores top and bottom letterbox bars instead of assigning artificial depth. +- Improved letterbox detection with multi-frame fallback probing and stabilization to prevent flicker. +- Letterbox regions are now filled with a neutral depth value, preventing pop-out artifacts and white banding in 3D renders. --- -## Summary +## 2) 3D Video Generator Tab + +### 3D Rendering Pipeline Performance & Stability + +- Implemented full render-state reset at the start of each video and image render to prevent temporal drift and accumulated smoothing artifacts between sessions. +- Reset internal pixel shift EMA buffers per render, ensuring clean disparity initialization and improved real-time stability. +- Reset floating window convergence trackers and easing states to eliminate carry-over offsets and unintended masking behavior across renders. +- Reinitialized depth percentile normalization per render, allowing depth range calibration to adapt cleanly to each clip for more consistent parallax response. +- Improved convergence and floating window behavior during the first frames of each render, eliminating “settling” artifacts and jitter. +- Resulted in significantly smoother live 3D playback and notable FPS improvements during real-time rendering. + +### Output Geometry & Eye Mode Fixes + +- Fixed output sizing logic for VR, Passive Interlaced, and single-eye export modes. +- Ensured per-eye resolution handling remains consistent across all 3D formats. +- Corrected floating window width calculations to always operate on per-eye dimensions instead of SBS frame width. +- Added safety resizing to guarantee encoded frames always match target output resolution. + +### Preview GUI + +- Preview GUI now supports an optional Convergence Crosshairs overlay for faster convergence tuning. + +### UI Label Consistency + +- Fixed mismatched labels for Foreground Shift and Background Shift. +- Sliders now correctly match their tooltips. + +### Encoding Settings Layout + +- Reworked the Encoding Settings dialog layout for improved spacing and readability. +- Grouped checkboxes, dropdowns, and quality controls into clearer rows. -v3.8 focuses on stabilizing depth estimation, improving model compatibility, -and refining the 3D Render tab UI with better layout behavior, clearer diagnostics, and improved localization support. +### Processing Options -> Back up your `weights/` and `presets/` folders before uninstalling v3.7. -> Then run VisionDepth3D_Setup_Downloader to download the official -> VisionDepth3D v3.8 Windows installer and required `.bin` files. +- Moved Clip Range (start/end time) controls into the Processing Options dialog. +- Clip range settings respect the selected UI language and include translated labels and tooltips. 
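+
+Clip range entries accept `HH:MM:SS[.ms]`, `MM:SS[.ms]`, or plain seconds. A sketch of how such a string reduces to seconds (a hypothetical reimplementation; the shipped `parse_timecode()` may differ):
+
+```python
+import re
+from typing import Optional
+
+def parse_timecode(text: str) -> Optional[float]:
+    """'HH:MM:SS(.ms)', 'MM:SS(.ms)', or bare seconds -> seconds; None if blank."""
+    text = text.strip()
+    if not text:
+        return None
+    parts = text.split(":")
+    if not all(re.fullmatch(r"[0-9]+(?:\.[0-9]+)?", p) for p in parts):
+        raise ValueError(f"Unrecognized timecode: {text!r}")
+    seconds = 0.0
+    for part in parts:            # Horner-style: "1:02:03.5" -> 3723.5
+        seconds = seconds * 60 + float(part)
+    return seconds
+```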
+ +### Menu Fixes, Presets, and Updater Integration + +- Help → Check Updates now launches the bundled **VisionDepth3D Updater** window (`VisionDepth3D_Updater.exe`) to download and install the latest official Windows release. +- Added a confirmation prompt before launching the updater, since VisionDepth3D closes itself to allow safe updating. +- Fixed **File → Load Preset** failing from the dropdown due to the preset apply function not being available in scope. +- Fixed **File → Output Path** dropdown not opening the save dialog while the hotkey worked, by routing the menu action through the same handler used by `Ctrl+O`. +- Removed **Save Settings** and **Load Settings** from the File menu since preset save/load already covers the same workflow and simplifies the UI. + +## 3) VD3D Live 3D (Real-Time Depth + SBS Pipeline) + +### Live Depth Inference Performance Overhaul + +- Implemented persistent GPU tensor staging for live frame uploads, eliminating per-frame CUDA allocations and significantly reducing memory transfer overhead. +- Optimized live depth input preprocessing to reuse GPU buffers instead of recreating tensors each inference cycle. +- Reduced redundant CPU to GPU conversions during live depth updates. +- Improved FP16 autocast handling for Depth Anything V2 live inference to ensure stable mixed-precision execution on CUDA. + +### Real-Time Pixel Shift Pipeline Optimization + +- Added persistent CUDA frame buffers for the live pixel-shift SBS renderer to avoid per-frame GPU reallocations. +- Reduced per-frame normalization overhead by using in-place GPU operations. +- Improved handling of mixed return types from `pixel_shift_cuda` (CUDA tensors or NumPy fallback), ensuring stable live output without crashes. +- Prevented pipeline stalls caused by repeated tensor construction and shape revalidation. + +### Live Depth Update Scheduling & Stability + +- Implemented controlled depth refresh rate (Depth FPS) to decouple depth inference from preview frame rate for smoother live playback. +- Improved EMA depth smoothing behavior for live mode to reduce temporal jitter while preserving responsiveness. +- Reduced live preview hitching caused by first-frame warm-up and inference spikes. + +### Live Capture & Preview Improvements + +- Reduced capture overhead by allowing lower capture FPS without affecting SBS rendering smoothness. +- Improved screen capture pacing using high-precision timers to prevent uneven frame delivery. +- Improved live preview stability when mixing screen capture and GPU depth inference. + +### Overall Live Mode Gains + +- Live 3D preview performance increased by approximately 40 to 70 percent depending on GPU and inference resolution. +- Significantly reduced stutter caused by GPU memory churn. +- More consistent frame pacing for real-time SBS output. + +--- -> (Optional but recommended) Clear the Hugging Face cache to free space and -> avoid duplicate model downloads: -> `C:\Users\YOUR_USERNAME\.cache\huggingface` +> **Upgrade Note** +> Back up your `weights/` and `presets/` folders before uninstalling v3.8.1 +> Then run **VisionDepth3D_Setup_Downloader** to download the official +> VisionDepth3D v3.8.2 Windows installer and required `.bin` files. diff --git a/LICENSE.txt b/LICENSE.txt index 767cc81..22503f9 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2025 Johnathan Carpenter. All rights reserved +Copyright (c) 2026 Johnathan Carpenter. 
All rights reserved This License Agreement ("Agreement") is a legal agreement between you ("User") and VisionDepth ("Licensor") regarding the use of VisionDepth3D ("Software"). By downloading, installing, or using the Software, you acknowledge and agree to be bound by the terms of this Agreement. @@ -45,3 +45,4 @@ This Agreement shall be governed by and interpreted in accordance with the laws 9. Contact For inquiries regarding this Agreement, contact: redsky90@gmail.com + diff --git a/VisionDepth3D.py b/VisionDepth3D.py index e31a72e..9816d6f 100644 --- a/VisionDepth3D.py +++ b/VisionDepth3D.py @@ -9,6 +9,7 @@ from threading import Event from core.audio import launch_audio_gui import queue +import subprocess # ── External Libraries ─────────────────────────── @@ -81,6 +82,10 @@ from core.preview_gui import open_3d_preview_window from core.models.depth_anything_v2.dpt import DepthAnythingV2 +from transformers import logging +logging.set_verbosity_error() + + # At the top of GUI.py cancel_requested = threading.Event() os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -133,6 +138,45 @@ TORCH_AVAILABLE = False TORCH_DEVICE_NAME = "cpu" +def _app_dir(): + if getattr(sys, "frozen", False): + return os.path.dirname(sys.executable) + return os.path.dirname(os.path.abspath(__file__)) + +def launch_setup_downloader(): + if not messagebox.askyesno( + "Update VisionDepth3D", + "This will close VisionDepth3D and open the updater.\n\nContinue?" + ): + return + + base = _app_dir() + exe_name = "VisionDepth3D_Updater.exe" + + path = os.path.join(base, exe_name) + if not os.path.exists(path): + path2 = os.path.join(base, "tools", exe_name) + if os.path.exists(path2): + path = path2 + + if not os.path.exists(path): + messagebox.showerror( + "Updater not found", + f"Could not find:\n{exe_name}\n\nLooked in:\n{base}\n{os.path.join(base,'tools')}" + ) + return + + try: + subprocess.Popen([path], cwd=os.path.dirname(path)) + except Exception as e: + messagebox.showerror("Failed to launch updater", str(e)) + return + + # close VD3D after updater starts + try: + root.after(150, root.quit) # root.destroy also works, quit is safer for Tk loops + except Exception: + pass def gpu_available(): try: @@ -202,6 +246,39 @@ def ui_select_output_path(): except Exception: pass +def _apply_preset_config(config: dict): + """ + Applies preset JSON values to existing Tk variables safely. + Ignores unknown keys so older/newer presets never crash. 
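+    Example preset entry: {"fg_shift": 4.0} sets the global Tk variable fg_shift.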
+ """ + if not isinstance(config, dict): + raise ValueError("Preset must be a JSON object") + + missing = [] + + for key, value in config.items(): + var = globals().get(key) + + if var is None: + missing.append(key) + continue + + try: + if hasattr(var, "set"): + var.set(value) + except Exception: + pass + + # Optional: refresh UI after applying + try: + refresh_ui_labels() + except Exception: + pass + + if missing: + print(f"Preset keys not bound to UI vars (ignored): {missing}") + + def load_preset_dialog(): """Pick any preset JSON (defaulting to PRESET_DIR) and apply it.""" @@ -1070,7 +1147,7 @@ def _on_shift_wheel(self, event): # --- Window Setup --- root = tk.Tk() -root.title("VisionDepth3D v3.8") +root.title("VisionDepth3D v3.8.2") screen_w = root.winfo_screenwidth() screen_h = root.winfo_screenheight() @@ -1266,15 +1343,9 @@ def mk_menu(): file_btn = ttk.Menubutton(hdr, text="File", style="VD.Menu.TMenubutton") file_menu = mk_menu() - _menu_add(file_menu, "Save Settings", save_settings, "Ctrl+S") - MENUS["FILE_IDX"]["save_settings"] = file_menu.index("end") - - _menu_add(file_menu, "Save Preset As…", prompt_and_save_preset, "Ctrl+Shift+S") + _menu_add(file_menu, "Save Preset As…", prompt_and_save_preset, "Ctrl+S") MENUS["FILE_IDX"]["save_preset_as"] = file_menu.index("end") - _menu_add(file_menu, "Load Settings", load_settings, "Ctrl+L") - MENUS["FILE_IDX"]["load_settings"] = file_menu.index("end") - _menu_add(file_menu, "Load Preset…", load_preset_dialog, "Ctrl+P") MENUS["FILE_IDX"]["load_preset"] = file_menu.index("end") @@ -1286,7 +1357,12 @@ def mk_menu(): _menu_add(file_menu, "Depth Map", ui_select_depth_map, "Ctrl+D") MENUS["FILE_IDX"]["depth_map"] = file_menu.index("end") - _menu_add(file_menu, "Output Path", select_output_video, "Ctrl+O") + _menu_add( + file_menu, + "Output Path", + lambda: select_output_video(output_sbs_video_path), + "Ctrl+O" + ) MENUS["FILE_IDX"]["output_path"] = file_menu.index("end") _menu_add(file_menu, "Generate 3D", handle_generate_3d, "Shift+Enter") @@ -1323,7 +1399,7 @@ def mk_menu(): command=lambda: messagebox.showinfo( "About VisionDepth3D", ( - "VisionDepth3D v3.8\n" + "VisionDepth3D v3.8.2\n" "----------------------------\n" "A hybrid 2D-to-3D conversion suite for cinema and VR.\n\n" "Features:\n" @@ -1334,7 +1410,7 @@ def mk_menu(): " • Real-time preview & batch processing\n\n" "Website: " + VD_WEBSITE + "\n" "GitHub: " + VD_GITHUB + "\n" - "© 2025 VisionDepth3D" + "© 2026 VisionDepth3D" ) ) ) @@ -1347,7 +1423,9 @@ def mk_menu(): help_menu.add_separator() - _menu_add(help_menu, t("Help.CheckUpdates"), open_releases, "F6"); MENUS["HELP_IDX"]["updates"] = help_menu.index("end") + _menu_add(help_menu, t("Help.CheckUpdates"), launch_setup_downloader, "F6") + MENUS["HELP_IDX"]["updates"] = help_menu.index("end") + _menu_add(help_menu, t("Help.ReportBug"), open_issues, "F7"); MENUS["HELP_IDX"]["report"] = help_menu.index("end") _menu_add(help_menu, t("Help.AspectCheat"), open_aspect_ratio_CheatSheet,"F8");MENUS["HELP_IDX"]["aspect"] = help_menu.index("end") _menu_add(help_menu, t("Help.PreviewGUI"), handle_open_preview, "F9"); MENUS["HELP_IDX"]["preview"] = help_menu.index("end") @@ -1391,9 +1469,7 @@ def refresh_menu_labels(): if MENUS["help_btn"]: MENUS["help_btn"].config(text=t("Menu.Help")) # File - _menu_set(fm, f["save_settings"], t("Menu.SaveSettings"), "Ctrl+S") - _menu_set(fm, f["save_preset_as"], t("Menu.SavePresetAs"), "Ctrl+Shift+S") - _menu_set(fm, f["load_settings"], t("Menu.LoadSettings"), "Ctrl+L") + _menu_set(fm, 
f["save_preset_as"], t("Menu.SavePresetAs"), "Ctrl+S") _menu_set(fm, f["load_preset"], t("Menu.LoadPreset"), "Ctrl+P") _menu_set(fm, f["video"], t("Menu.Video"), "Ctrl+I") _menu_set(fm, f["depth_map"], t("Menu.DepthMap"), "Ctrl+D") @@ -1442,7 +1518,7 @@ def _on_language_change(code): # Shortcuts root.bind_all("", lambda e: root.quit()) -root.bind_all("", lambda e: messagebox.showinfo("About", "VisionDepth3D v3.8\n" +root.bind_all("", lambda e: messagebox.showinfo("About", "VisionDepth3D v3.8.2\n" "----------------------------\n" "A hybrid 2D-to-3D conversion suite for cinema and VR.\n\n" "Features:\n" @@ -1453,7 +1529,7 @@ def _on_language_change(code): " • Real-time preview & batch processing\n\n" "Created by: Johnathan Carpenter\n" "Website: https://github.com/VisionDepth/VisionDepth3D\n" - "© 2025 VisionDepth3D. All rights reserved.",)) + "© 2026 VisionDepth3D. All rights reserved.",)) root.bind_all("", lambda e: open_website()) root.bind_all("", lambda e: open_reddit()) root.bind_all("", lambda e: open_github()) @@ -1470,9 +1546,7 @@ def _on_language_change(code): root.bind_all("", lambda e: select_output_video(output_sbs_video_path)) # Output file # Presets -root.bind_all("", lambda e: save_settings()) # Save current settings -root.bind_all("", lambda e: load_settings()) # Load saved settings -root.bind_all("", lambda e: prompt_and_save_preset()) # Save as preset +root.bind_all("", lambda e: prompt_and_save_preset()) # Save as preset root.bind_all("", lambda e: load_preset_dialog()) # Load preset # Render @@ -1508,7 +1582,7 @@ def _on_language_change(code): # --- Depth Estimation GUI --- depth_estimation_frame = ttk.Frame(tab_control, style="VD3D.TFrame") -tab_control.add(depth_estimation_frame, text="Depth Estimation") +tab_control.add(depth_estimation_frame, text="Depth Engine") depth_tab_index = tab_control.index("end") - 1 depth_content_frame = tk.Frame(depth_estimation_frame, bg=BG_MAIN, highlightthickness=0, bd=0) @@ -2259,6 +2333,8 @@ def load_supported_models(): # "Distill-Any-Depth Large (keetrap)": "keetrap/Distill-Any-Depth-Large-hf", # "Distill-Any-Depth Small (keetrap)": "keetrap/Distill-Any-Depth-Small-hf", + "Video Depth Anything Large": "vda:depth-anything/Video-Depth-Anything-Large", + "Video Depth Anything Small": "vda:depth-anything/Video-Depth-Anything-Small", # in load_supported_models() "Video Depth Anything (ONNX)": "onnx:VideoDepthAnything", @@ -2266,10 +2342,17 @@ def load_supported_models(): "Distill-Any-Depth Base(ONNX)": "onnx:DistillAnyDepthBase", "Distill-Any-Depth Small(ONNX)": "onnx:DistillAnyDepthSmall", -# "DA3-GIANT": "depth-anything/DA3-GIANT", -# "DA3-LARGE": "depth-anything/DA3-LARGE", -# "DA3-BASE": "depth-anything/DA3-BASE", -# "DA3-SMALL": "depth-anything/DA3-SMALL", + "DA3METRIC-LARGE": "da3:depth-anything/DA3METRIC-LARGE", + "DA3MONO-LARGE": "da3:depth-anything/DA3MONO-LARGE", + "DA3-LARGE": "da3:depth-anything/DA3-LARGE", + "DA3-LARGE-1.1": "da3:depth-anything/DA3-LARGE-1.1", + "DA3-BASE": "da3:depth-anything/DA3-BASE", + "DA3-SMALL": "da3:depth-anything/DA3-SMALL", + "DA3-GIANT": "da3:depth-anything/DA3-GIANT", + "DA3-GIANT-1.1": "da3:depth-anything/DA3-GIANT-1.1", + "DA3NESTED-GIANT-LARGE": "da3:depth-anything/DA3NESTED-GIANT-LARGE", + "DA3NESTED-GIANT-LARGE-1.1": "da3:depth-anything/DA3NESTED-GIANT-LARGE-1.1", + # Depth Anything v2 "Depth Anything v2 Large": "depth-anything/Depth-Anything-V2-Large-hf", @@ -2291,7 +2374,7 @@ def load_supported_models(): # Other popular models # "DA-2 (Haodongli)": "haodongli/DA-2", -# "Bridge (Dingning)": 
"Dingning/BRIDGE", +# "Pixel-Perfect-Depth": "gangweix/Pixel-Perfect-Depth", "LBM Depth": "jasperai/LBM_depth", "DepthPro (Apple)": "apple/DepthPro-hf", "ZoeDepth (NYU+KITTI)": "Intel/zoedepth-nyu-kitti", @@ -3586,6 +3669,7 @@ def clear_clip(): bg="#1c1c1c", fg="white" ) bg_push_label.grid(row=2, column=2, sticky="w") + bg_push_scale = tk.Scale( options_frame, from_=1.00, to=1.40, @@ -3643,20 +3727,19 @@ def _commit_pop_entries(): stretch_lo_entry.bind("", lambda _e: _commit_pop_entries()) stretch_hi_entry.bind("", lambda _e: _commit_pop_entries()) - - -bg_shift_label = tk.Label( +fg_shift_label = tk.Label( options_frame, text=t("Foreground Shift"), - bg="#1c1c1c", fg="white" + bg="#1c1c1c", + fg="white" ) -bg_shift_label.grid(row=4, column=0, sticky="w") +fg_shift_label.grid(row=4, column=0, sticky="w") tk.Scale( options_frame, from_=-20, to=20, resolution=0.1, orient=tk.HORIZONTAL, - variable=bg_shift, bg="#1c1c1c", fg="white", + variable=fg_shift, bg="#1c1c1c", fg="white", cursor="sb_h_double_arrow" ).grid(row=4, column=1, sticky="ew") @@ -3714,13 +3797,12 @@ def _commit_pop_entries(): ).grid(row=5, column=3, sticky="ew") # Row 6 -fg_shift_label = tk.Label( +bg_shift_label = tk.Label( options_frame, text=t("Background Shift"), - bg="#1c1c1c", - fg="white" + bg="#1c1c1c", fg="white" ) -fg_shift_label.grid(row=6, column=0, sticky="w") +bg_shift_label.grid(row=6, column=0, sticky="w") tk.Scale( options_frame, @@ -3728,7 +3810,7 @@ def _commit_pop_entries(): to=20, resolution=0.1, orient=tk.HORIZONTAL, - variable=fg_shift, + variable=bg_shift, bg="#1c1c1c", fg="white", cursor="sb_h_double_arrow" ).grid(row=6, column=1, sticky="ew") @@ -3888,6 +3970,20 @@ def open_processing_dialog(): ) pop_frame.pack(fill="x", expand=False, padx=10, pady=10) + # --- Clip Range UI --- + clip_frame = tk.LabelFrame( + dlg, + text=t("Clip Range (optional)"), + bg="#1c1c1c", + fg="white", + font=("Segoe UI", 10, "bold"), + labelanchor="nw", + padx=10, + pady=10 + ) + clip_frame.pack(fill="x", expand=False, padx=10, pady=10) # adjust placement + + # Make columns evenly resize & give a minimum so controls don't squash for i in range(3): pop_frame.columnconfigure(i, weight=1, minsize=110) @@ -3986,18 +4082,54 @@ def open_processing_dialog(): justify="left" ) use_dfw_checkbox.grid(row=3, column=2, sticky="w", padx=5) + + def _time_validate(s: str) -> bool: + """ + Allow digits, colon, dot, spaces so users can type 'HH:MM:SS(.ms)', 'MM:SS(.ms)', or 'SS(.ms)'. + Actual parsing is done by parse_timecode(); this just keeps the entry clean-ish. 
+ """ + return bool(re.match(r'^[0-9:\.\s]*$', s)) + + vcmd = (clip_frame.register(_time_validate), "%P") + + start_clip_range_label = tk.Label( + clip_frame, text=t("Start (HH:MM:SS[.ms] or seconds):") + ) + start_clip_range_label.grid(row=5, column=0, sticky="w", padx=6, pady=4) + + start_entry = tk.Entry(clip_frame, textvariable=clip_start_var, width=18, validate="key", validatecommand=vcmd) + start_entry.grid(row=5, column=1, sticky="w", padx=6, pady=4) + + end_clip_range_label = tk.Label( + clip_frame, text=t("End (HH:MM:SS[.ms] or seconds):") + ) + end_clip_range_label.grid(row=6, column=0, sticky="w", padx=6, pady=4) + + end_entry = tk.Entry(clip_frame, textvariable=clip_end_var, width=18, validate="key", validatecommand=vcmd) + end_entry.grid(row=6, column=1, sticky="w", padx=6, pady=4) + + btns = tk.Frame(clip_frame) + btns.grid(row=6, column=2, rowspan=1, padx=6, pady=4, sticky="e") + + clear_button_label = tk.Button( + btns, text=t("Clear"), command=clear_clip + ) + clear_button_label.grid(row=0, column=0, padx=4) # 🔹 Tooltips inside the dialog, using your existing language keys - CreateToolTip(preserve_aspect_checkbox, lambda: t("Tooltip.PreserveAspect")) - CreateToolTip(auto_crop_checkbox, lambda: t("Tooltip.AutoCrop")) - CreateToolTip(use_subject_tracking_checkbox, lambda: t("Tooltip.SubjectTracking")) - CreateToolTip(use_dfw_checkbox, lambda: t("Tooltip.FloatingWindow")) - CreateToolTip(enable_edge_checkbox, lambda: t("Tooltip.EdgeMasking")) - CreateToolTip(enable_feathering_checkbox, lambda: t("Tooltip.Feathering")) - CreateToolTip(skip_blank_frames_checkbox, lambda: t("Tooltip.SkipBlankFrames")) -# CreateToolTip(use_ffmpeg_checkbox, lambda: t("Tooltip.SelectedCodec")) - CreateToolTip(enable_dynamic_convergence_checkbox, lambda: t("Tooltip.EnableDynConvergence")) - CreateToolTip(ipd_toggle, lambda: t("Tooltip.EnableIPD")) + CreateToolTip(preserve_aspect_checkbox, lambda: t("Tooltip.PreserveAspect")) + CreateToolTip(auto_crop_checkbox, lambda: t("Tooltip.AutoCrop")) + CreateToolTip(use_subject_tracking_checkbox, lambda: t("Tooltip.SubjectTracking")) + CreateToolTip(use_dfw_checkbox, lambda: t("Tooltip.FloatingWindow")) + CreateToolTip(enable_edge_checkbox, lambda: t("Tooltip.EdgeMasking")) + CreateToolTip(enable_feathering_checkbox, lambda: t("Tooltip.Feathering")) + CreateToolTip(skip_blank_frames_checkbox, lambda: t("Tooltip.SkipBlankFrames")) + CreateToolTip(enable_dynamic_convergence_checkbox, lambda: t("Tooltip.EnableDynConvergence")) + CreateToolTip(ipd_toggle, lambda: t("Tooltip.EnableIPD")) + CreateToolTip(clip_frame, lambda: t("Tooltip.ClipRangeLabel")) + CreateToolTip(start_clip_range_label, lambda: t("Tooltip.StartClipRangeLabel")) + CreateToolTip(end_clip_range_label, lambda: t("Tooltip.EndClipRangeLabel")) + # Close button tk.Button( @@ -4101,27 +4233,22 @@ def open_encoding_dialog(): ) frame.pack(fill="both", expand=True, padx=10, pady=10) - # Make columns evenly resize & give a minimum so controls don't squash - for i in range(7): - frame.columnconfigure(i, weight=1, minsize=110) - - # ───────── Row 0: Stereo output + Renderer + Keep Audio + Delete SBS + HDR ───────── - StereoOutput_label = tk.Label( - frame, - text=t("Left/Right Output"), - bg="#1c1c1c", - fg="white" - ) - StereoOutput_label.grid(row=0, column=0, sticky="w", padx=6, pady=4) + # Main frame grid: 3 rows (checkbox row, dropdown row, slider row) + frame.columnconfigure(0, weight=1) + frame.rowconfigure(0, weight=0) + frame.rowconfigure(1, weight=0) + frame.rowconfigure(2, weight=0) - 
tk.OptionMenu( - frame, - stereo_out_var, - "sbs", "left", "right", "both" - ).grid(row=0, column=1, sticky="ew", padx=6, pady=4) + # ========================= + # Row 0: Checkboxes band + # ========================= + checks = tk.Frame(frame, bg="#1c1c1c") + checks.grid(row=0, column=0, sticky="ew", padx=2, pady=(2, 8)) + for i in range(4): + checks.columnconfigure(i, weight=1, minsize=160) use_ffmpeg_checkbox = tk.Checkbutton( - frame, + checks, text=t("Use FFmpeg Renderer"), bg="#1c1c1c", fg="white", @@ -4129,10 +4256,10 @@ def open_encoding_dialog(): variable=use_ffmpeg, anchor="w" ) - use_ffmpeg_checkbox.grid(row=0, column=2, sticky="w", padx=5) + use_ffmpeg_checkbox.grid(row=0, column=0, sticky="w", padx=5) keep_audio_checkbox = tk.Checkbutton( - frame, + checks, text=t("Keep Original Audio"), variable=keep_original_audio, bg="#1c1c1c", @@ -4142,10 +4269,10 @@ def open_encoding_dialog(): anchor="w", justify="left" ) - keep_audio_checkbox.grid(row=0, column=3, sticky="w", padx=5) + keep_audio_checkbox.grid(row=0, column=1, sticky="w", padx=5) DeleteSBS_label = tk.Checkbutton( - frame, + checks, text=t("Delete SBS after"), variable=delete_fsbs_var, bg="#1c1c1c", @@ -4155,10 +4282,10 @@ def open_encoding_dialog(): anchor="w", justify="left" ) - DeleteSBS_label.grid(row=0, column=4, sticky="w", padx=5) + DeleteSBS_label.grid(row=0, column=2, sticky="w", padx=5) hdr_checkbox = tk.Checkbutton( - frame, + checks, text=t("Preserve HDR10"), variable=preserve_hdr10_var, onvalue=True, @@ -4170,39 +4297,85 @@ def open_encoding_dialog(): anchor="w", justify="left" ) - hdr_checkbox.grid(row=0, column=5, sticky="w", padx=5) + hdr_checkbox.grid(row=0, column=3, sticky="w", padx=5) + + # ========================= + # Row 1: Dropdowns band + # ========================= + opts = tk.Frame(frame, bg="#1c1c1c") + opts.grid(row=1, column=0, sticky="ew", padx=2, pady=(0, 10)) + + # 10 columns = 5 label+control pairs + for i in range(10): + opts.columnconfigure(i, weight=1, minsize=110) + + # Pair 1: 3D Format + format_button = tk.Label( + opts, text=t("3D Format"), + bg="#1c1c1c", fg="white" + ) + format_button.grid(row=0, column=0, sticky="w", padx=6, pady=4) + + option_menu = tk.OptionMenu( + opts, + output_format, + "Full-SBS", + "Half-SBS", + "VR", + "Red-Cyan Anaglyph", + "Passive Interlaced", + ) + option_menu.config(width=14, cursor="hand2") + option_menu.grid(row=0, column=1, sticky="ew", padx=6, pady=4) + + # Pair 2: Left/Right Output + StereoOutput_label = tk.Label( + opts, + text=t("Left/Right Output"), + bg="#1c1c1c", + fg="white" + ) + StereoOutput_label.grid(row=0, column=2, sticky="w", padx=6, pady=4) - # ───────── Row 1: Aspect • FFmpeg Codec • Codec ───────── + tk.OptionMenu( + opts, + stereo_out_var, + "sbs", "left", "right", "both" + ).grid(row=0, column=3, sticky="ew", padx=6, pady=4) + + # Pair 3: Aspect Ratio selected_aspect_ratio_label = tk.Label( - frame, + opts, text=t("Aspect Ratio:"), bg="#1c1c1c", fg="white" ) - selected_aspect_ratio_label.grid(row=1, column=0, sticky="w", padx=6, pady=4) + selected_aspect_ratio_label.grid(row=0, column=4, sticky="w", padx=6, pady=4) tk.OptionMenu( - frame, + opts, selected_aspect_ratio, *aspect_ratios.keys() - ).grid(row=1, column=1, sticky="ew", padx=6, pady=4) + ).grid(row=0, column=5, sticky="ew", padx=6, pady=4) + # Pair 4: FFmpeg Codec selected_ffmpeg_codec_label = tk.Label( - frame, + opts, text=t("FFmpeg Codec:"), bg="#1c1c1c", fg="white" ) - selected_ffmpeg_codec_label.grid(row=1, column=2, sticky="w", padx=6, pady=4) + 
selected_ffmpeg_codec_label.grid(row=1, column=0, sticky="w", padx=6, pady=4) tk.OptionMenu( - frame, + opts, selected_ffmpeg_codec, *FFMPEG_CODEC_MAP.keys() - ).grid(row=1, column=3, sticky="ew", padx=6, pady=4) + ).grid(row=1, column=1, columnspan=3, sticky="ew", padx=6, pady=4) + # Pair 5: Codec selected_codec_label = tk.Label( - frame, + opts, text=t("Codec:"), bg="#1c1c1c", fg="white" @@ -4210,57 +4383,65 @@ def open_encoding_dialog(): selected_codec_label.grid(row=1, column=4, sticky="w", padx=6, pady=4) tk.OptionMenu( - frame, + opts, selected_codec, *codec_options - ).grid(row=1, column=5, sticky="ew", padx=6, pady=4) + ).grid(row=1, column=5, columnspan=3, sticky="ew", padx=6, pady=4) + + # ========================= + # Row 2: Sliders band + # ========================= + sliders = tk.Frame(frame, bg="#1c1c1c") + sliders.grid(row=2, column=0, sticky="ew", padx=2, pady=(0, 6)) + for i in range(6): + sliders.columnconfigure(i, weight=1, minsize=110) - # ───────── Row 2: CRF • NVENC CQ ───────── crf_value_label = tk.Label( - frame, + sliders, text=t("CRF"), bg="#1c1c1c", fg="white" ) - crf_value_label.grid(row=2, column=0, sticky="w", padx=6, pady=6) + crf_value_label.grid(row=0, column=3, sticky="w", padx=6, pady=6) tk.Scale( - frame, + sliders, from_=0, to=51, resolution=1, orient=tk.HORIZONTAL, variable=crf_value, - length=150, + length=220, bg="#2b2b2b", fg="white", troughcolor="#444", highlightthickness=0, bd=0 - ).grid(row=2, column=1, columnspan=2, sticky="ew", padx=6, pady=6) + ).grid(row=0, column=4, columnspan=2, sticky="ew", padx=6, pady=6) nvenc_cq_value_label = tk.Label( - frame, + sliders, text=t("NVENC CQ"), bg="#1c1c1c", fg="white" ) - nvenc_cq_value_label.grid(row=2, column=3, sticky="w", padx=6, pady=6) + nvenc_cq_value_label.grid(row=0, column=0, sticky="w", padx=6, pady=6) tk.Scale( - frame, + sliders, from_=0, to=51, resolution=1, orient=tk.HORIZONTAL, variable=nvenc_cq_value, - length=150, + length=220, bg="#2b2b2b", fg="white", troughcolor="#444", highlightthickness=0, bd=0 - ).grid(row=2, column=4, columnspan=2, sticky="ew", padx=6, pady=6) + ).grid(row=0, column=1, columnspan=2, sticky="ew", padx=6, pady=6) + # 🔹 Tooltips inside the dialog, using your existing language keys CreateToolTip(StereoOutput_label, lambda: t("Tooltip.LROutput")) @@ -4273,6 +4454,7 @@ def open_encoding_dialog(): CreateToolTip(selected_codec_label, lambda: t("Tooltip.SelectedCodec")) CreateToolTip(crf_value_label, lambda: t("Tooltip.CRF")) CreateToolTip(nvenc_cq_value_label, lambda: t("Tooltip.NVENCCQ")) + CreateToolTip(format_button, lambda: t("Tooltip.OptionMenu")) # Close button tk.Button( @@ -4314,52 +4496,6 @@ def open_encoding_dialog(): ) processing_button.grid(row=1, column=1,pady=5, padx=5, sticky="ew") -# --- Clip Range UI --- -clip_frame = tk.LabelFrame( - right_col, - text="Clip Range (optional)", - bg="#1c1c1c", - fg="white", - font=("Segoe UI", 10, "bold"), - labelanchor="nw", - padx=10, - pady=10 -) -clip_frame.grid(row=4, column=0, padx=8, pady=8, sticky="we") # adjust placement - - -def _time_validate(s: str) -> bool: - """ - Allow digits, colon, dot, spaces so users can type 'HH:MM:SS(.ms)', 'MM:SS(.ms)', or 'SS(.ms)'. - Actual parsing is done by parse_timecode(); this just keeps the entry clean-ish. 
- """ - return bool(re.match(r'^[0-9:\.\s]*$', s)) - -vcmd = (clip_frame.register(_time_validate), "%P") - -start_clip_range_label = tk.Label( - clip_frame, text="Start (HH:MM:SS[.ms] or seconds):" -) -start_clip_range_label.grid(row=0, column=0, sticky="w", padx=6, pady=4) - -start_entry = tk.Entry(clip_frame, textvariable=clip_start_var, width=18, validate="key", validatecommand=vcmd) -start_entry.grid(row=0, column=1, sticky="w", padx=6, pady=4) - -end_clip_range_label = tk.Label( - clip_frame, text="End (HH:MM:SS[.ms] or seconds):" -) -end_clip_range_label.grid(row=1, column=0, sticky="w", padx=6, pady=4) - -end_entry = tk.Entry(clip_frame, textvariable=clip_end_var, width=18, validate="key", validatecommand=vcmd) -end_entry.grid(row=1, column=1, sticky="w", padx=6, pady=4) - -btns = tk.Frame(clip_frame) -btns.grid(row=0, column=2, rowspan=2, padx=6, pady=4, sticky="e") -clear_button_label = tk.Button( - btns, text="Clear", command=clear_clip -) -clear_button_label.grid(row=0, column=0, padx=4) - # ── INPUT SOURCES (own frame) ────────────────────────────────────────────── inputs_frame = tk.LabelFrame( right_col, @@ -4598,25 +4734,6 @@ def process_next_in_batch(): button_frame = tk.Frame(right_col, bg="#1c1c1c") button_frame.grid(row=3, column=0, columnspan=2, padx=10, pady=5, sticky="nsew") -# 3D Format Label and Dropdown (Inside button_frame) -format_button = tk.Label( - button_frame, text=t("3D Format"), - bg="#1c1c1c", fg="white" -) -format_button.grid(row=0, column=0, pady=5, padx=5, sticky="ew") - -option_menu = tk.OptionMenu( - button_frame, - output_format, - "Full-SBS", - "Half-SBS", - "VR", - "Red-Cyan Anaglyph", - "Passive Interlaced", -) -option_menu.config(width=10, cursor="hand2") # Adjust width to keep consistent look -option_menu.grid(row=0, column=1, pady=5, padx=5, sticky="ew") - # Buttons Inside button_frame to Keep Everything on One Line start_button = tk.Button( button_frame, @@ -4791,17 +4908,17 @@ def on_render_finished(created_files: list[str]): tooltip_refs["CancelButton"] = CreateToolTip(cancel_button, lambda: t("Tooltip.CancelButton")) tooltip_refs["ResetButton"] = CreateToolTip(reset_button, lambda: t("Tooltip.ResetButton")) tooltip_refs["ColorResetButton"] = CreateToolTip(color_reset_button, lambda: t("Tooltip.ColorResetButton")) -tooltip_refs["ClipRangeLabel"] = CreateToolTip(clip_frame, lambda: t("Tooltip.ClipRangeLabel")) -tooltip_refs["StartClipRangeLabel"] = CreateToolTip(start_clip_range_label, lambda: t("Tooltip.StartClipRangeLabel")) -tooltip_refs["EndClipRangeLabel"] = CreateToolTip(end_clip_range_label, lambda: t("Tooltip.EndClipRangeLabel")) +#tooltip_refs["ClipRangeLabel"] = CreateToolTip(clip_frame, lambda: t("Tooltip.ClipRangeLabel")) +#tooltip_refs["StartClipRangeLabel"] = CreateToolTip(start_clip_range_label, lambda: t("Tooltip.StartClipRangeLabel")) +#tooltip_refs["EndClipRangeLabel"] = CreateToolTip(end_clip_range_label, lambda: t("Tooltip.EndClipRangeLabel")) -tooltip_refs["OptionMenu"] = CreateToolTip(option_menu, lambda: t("Tooltip.OptionMenu")) +#tooltip_refs["OptionMenu"] = CreateToolTip(option_menu, lambda: t("Tooltip.OptionMenu")) tooltip_refs["AspectPreview"] = CreateToolTip(aspect_preview_label, lambda: t("Tooltip.AspectPreview")) # Sliders -tooltip_refs["FGShift"] = CreateToolTip(bg_shift_label, lambda: t("Tooltip.FGShift")) +tooltip_refs["FGShift"] = CreateToolTip(fg_shift_label, lambda: t("Tooltip.FGShift")) tooltip_refs["MGShift"] = CreateToolTip(mg_shift_label, lambda: t("Tooltip.MGShift")) -tooltip_refs["BGShift"] = 
CreateToolTip(fg_shift_label, lambda: t("Tooltip.BGShift")) +tooltip_refs["BGShift"] = CreateToolTip(bg_shift_label, lambda: t("Tooltip.BGShift")) tooltip_refs["Sharpness"] = CreateToolTip(sharpness_factor_label, lambda: t("Tooltip.Sharpness")) tooltip_refs["ZeroParallaxStrength"] = CreateToolTip(zero_parallax_strength_label, lambda: t("Tooltip.ZeroParallaxStrength")) tooltip_refs["ParallaxBalance"] = CreateToolTip(parallax_balance_label, lambda: t("Tooltip.ParallaxBalance")) @@ -4885,7 +5002,7 @@ def _cfg(w, **kw): pass # Tabs - tab_control.tab(depth_tab_index, text=t("Depth Estimation")) + tab_control.tab(depth_tab_index, text=t("Depth Engine")) tab_control.tab(visiondepth_tab_index, text=t("3D Video Generator")) tab_control.tab(frametools_tab_index, text=t("FrameTools")) tab_control.tab(depth_blend_index, text=t("Depth Blender")) @@ -4926,7 +5043,7 @@ def _cfg(w, **kw): _cfg(image_select_input_button, text=t("Select Input Image")) _cfg(image_select_depth_button, text=t("Select Depth Map Image")) _cfg(image_select_output_button, text=t("Select Output Image")) - _cfg(format_button, text=t("3D Format")) +# _cfg(format_button, text=t("3D Format")) _cfg(start_button, text=t("Generate 3D")) _cfg(batch_start_button, text=t("Start Batch Render")) _cfg(preview_button, text=t("Open Preview")) @@ -4944,9 +5061,9 @@ def _cfg(w, **kw): # Parallax/quality sliders (existing) - _cfg(bg_shift_label, text=t("Foreground Shift")) + _cfg(fg_shift_label, text=t("Foreground Shift")) _cfg(mg_shift_label, text=t("Midground Shift")) - _cfg(fg_shift_label, text=t("Background Shift")) + _cfg(bg_shift_label, text=t("Background Shift")) _cfg(sharpness_factor_label, text=t("Sharpness Factor")) _cfg(zero_parallax_strength_label, text=t("Zero Parallax Strength")) _cfg(parallax_balance_label, text=t("Parallax Balance")) @@ -4992,10 +5109,10 @@ def _cfg(w, **kw): # _cfg(DeleteSBS_label, text=t("Delete SBS after")) # _cfg(hdr_checkbox, text=t("Preserve HDR10")) # _cfg(keep_audio_checkbox , text=t("Keep Original Audio")) - _cfg(clip_frame, text=t("Clip Range (optional)")) - _cfg(start_clip_range_label, text=t("Start (HH:MM:SS[.ms] or seconds):")) - _cfg(end_clip_range_label, text=t("End (HH:MM:SS[.ms] or seconds):")) - _cfg(clear_button_label, text=t("Clear")) +# _cfg(clip_frame, text=t("Clip Range (optional)")) +# _cfg(start_clip_range_label, text=t("Start (HH:MM:SS[.ms] or seconds):")) +# _cfg(end_clip_range_label, text=t("End (HH:MM:SS[.ms] or seconds):")) +# _cfg(clear_button_label, text=t("Clear")) # FrameTools tab _cfg(extract_frames_button, text=t("Extract Frames from Video")) diff --git a/VisionDepth3D_Updater.exe b/VisionDepth3D_Updater.exe new file mode 100644 index 0000000..9d2ee28 Binary files /dev/null and b/VisionDepth3D_Updater.exe differ diff --git a/core/adapters/depthanything3_adapter.py b/core/adapters/depthanything3_adapter.py new file mode 100644 index 0000000..270b5bb --- /dev/null +++ b/core/adapters/depthanything3_adapter.py @@ -0,0 +1,128 @@ +# core/adapters/depthanything3_adapter.py +import torch +import numpy as np +from PIL import Image +from contextlib import nullcontext + +def _process_res_from_inference_size(inference_size, default=504): + # DA3 default in their API is 504 + if inference_size is None: + return default + w, h = inference_size + max_dim = max(int(w), int(h)) + if max_dim <= 512: return 512 + if max_dim <= 640: return 640 + if max_dim <= 768: return 768 + if max_dim <= 896: return 896 + if max_dim <= 1024: return 1024 + if max_dim <= 1280: return 1280 + return 1536 + +def 
_normalize_depth_percentile(d: torch.Tensor, q_lo=0.02, q_hi=0.98) -> torch.Tensor: + # d is [H,W] + flat = d.flatten() + lo = torch.quantile(flat, q_lo) + hi = torch.quantile(flat, q_hi) + d = (d - lo) / (hi - lo + 1e-6) + return d.clamp(0, 1) + +def load_da3_adapter(spec: str, cache_dir: str, use_fp16: bool = False): + from core.models.depth_anything_3.api import DepthAnything3 + + device = "cuda" if torch.cuda.is_available() else "cpu" + use_amp = (device == "cuda" and use_fp16) + + model = DepthAnything3.from_pretrained(spec, cache_dir=cache_dir) + # Move the WHOLE wrapper so processors + model are consistent + model.to(device) + model.eval() + + try: + model.device = torch.device(device) + except Exception: + pass + + @torch.no_grad() + def da3_infer(images, inference_size=None, **kw): + if not isinstance(images, list): + images = [images] + + # Prefer DA3 defaults unless you override + default_pr = int(kw.get("process_res", 504)) + process_res = _process_res_from_inference_size(inference_size, default=default_pr) + process_res_method = kw.get("process_res_method", "upper_bound_resize") + + # If you want a cap, make it optional and higher during quality tests + cap = kw.get("process_res_cap", None) # e.g. 1024 + if cap is not None: + process_res = min(int(cap), int(process_res)) + + cleaned = [] + for im in images: + if isinstance(im, Image.Image): + if im.mode != "RGB": + im = im.convert("RGB") + cleaned.append(im) + else: + cleaned.append(im) + + ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if use_amp else nullcontext() + + # Run as batch (DA3 supports list input) + with ctx: + pred = model.inference( + cleaned, + process_res=process_res, + process_res_method=process_res_method, + export_dir=None, + export_format="mini_npz", + ) + + depths = getattr(pred, "depth", None) + if depths is None: + raise RuntimeError("DA3 returned Prediction with no .depth") + + # Convert to torch [N,H,W] + if isinstance(depths, np.ndarray): + depths_t = torch.from_numpy(depths).float() + else: + depths_t = depths.detach().float().cpu() + + if depths_t.ndim == 2: + depths_t = depths_t.unsqueeze(0) + + outputs = [] + for k in range(depths_t.shape[0]): + d = depths_t[k] + d = _normalize_depth_percentile(d, q_lo=0.02, q_hi=0.98) + outputs.append({"predicted_depth": d}) + + # If mismatch count, do per-image using SAME normalization + if len(outputs) != len(cleaned): + outputs = [] + for im in cleaned: + with ctx: + pred1 = model.inference( + [im], + process_res=process_res, + process_res_method=process_res_method, + export_dir=None, + export_format="mini_npz", + ) + d1 = pred1.depth[0] + if isinstance(d1, np.ndarray): + d1 = torch.from_numpy(d1).float() + else: + d1 = d1.detach().float().cpu() + d1 = _normalize_depth_percentile(d1, q_lo=0.02, q_hi=0.98) + outputs.append({"predicted_depth": d1}) + + return outputs + + caps = { + "kind": "da3", + "has_builtin_processor": True, + "supports_multi_view": True, + "supports_metric_models": True, + } + return da3_infer, caps diff --git a/core/adapters/videodepthanything_adapter.py b/core/adapters/videodepthanything_adapter.py new file mode 100644 index 0000000..c29e2ea --- /dev/null +++ b/core/adapters/videodepthanything_adapter.py @@ -0,0 +1,132 @@ +# core/adapters/video_depth_anything_adapter.py +import torch +import numpy as np +from PIL import Image + +def _frame_to_np(x): + # Convert PIL → np(H,W,3) uint8 (what VDA expects) + if isinstance(x, Image.Image): + if x.mode != "RGB": + x = x.convert("RGB") + return np.array(x) + + # Torch tensor → 
numpy + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + return x + + # Already numpy + return x + + +def _pick_encoder_from_repo(repo_id: str) -> str: + r = (repo_id or "").lower() + # simple heuristic + if "small" in r: + return "vits" + if "base" in r: + return "vitb" + return "vitl" # Large default + +def _ckpt_filename(encoder: str, metric: bool) -> str: + # matches upstream naming style + if metric: + return f"metric_video_depth_anything_{encoder}.pth" + return f"video_depth_anything_{encoder}.pth" + +def _input_size_from_inference_size(inference_size, default=518) -> int: + # VDA is square input_size in their CLI. Keep it stable. + if inference_size is None: + return int(default) + w, h = inference_size + m = max(int(w), int(h)) + # clamp to something sane + return 518 if m >= 518 else 392 if m >= 392 else 256 + +def load_vda_adapter(spec: str, cache_dir: str, use_fp16: bool = False): + """ + spec example: + - "depth-anything/Video-Depth-Anything-Large" + returns: (callable, caps) + """ + # your vendored source (or installed package) should expose this + from core.models.video_depth_anything.video_depth import VideoDepthAnything + + from huggingface_hub import hf_hub_download + + device = "cuda" if torch.cuda.is_available() else "cpu" + fp16_ok = (use_fp16 and device == "cuda") + + metric = ("metric" in (spec or "").lower()) + + encoder = _pick_encoder_from_repo(spec) + ckpt_name = _ckpt_filename(encoder, metric) + + ckpt_path = hf_hub_download( + repo_id=spec, + filename=ckpt_name, + cache_dir=cache_dir, + ) + + model_configs = { + "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]}, + "vitb": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]}, + "vitl": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]}, + } + + vda = VideoDepthAnything(**model_configs[encoder], metric=metric) + sd = torch.load(ckpt_path, map_location="cpu") + vda.load_state_dict(sd, strict=True) + vda.to(device).eval() + +# if fp16_ok: +# vda.half() + + @torch.no_grad() + def vda_infer(images, inference_size=None, **kw): + if not isinstance(images, list): + images = [images] + + frames = [] + for im in images: + if isinstance(im, Image.Image): + if im.mode != "RGB": + im = im.convert("RGB") + arr = np.array(im, dtype=np.uint8) # (H,W,3) + else: + arr = np.asarray(im) + # if someone passed a torch tensor, you may want: + # if isinstance(im, torch.Tensor): arr = im.detach().cpu().numpy() + frames.append(arr) + + # ✅ VDA expects (T,H,W,3) array (not list) + frames_np = np.stack(frames, axis=0) + + input_size = int(kw.get("input_size", 518)) + target_fps = int(kw.get("target_fps", -1)) + fp32 = bool(kw.get("fp32", False)) + + + depths, fps_out = vda.infer_video_depth( + frames_np, + target_fps, + input_size=input_size, + device=device, + fp32=fp32, + ) + + # return list of {"predicted_depth": tensor} per frame + d = np.asarray(depths, dtype=np.float32) + if d.ndim == 2: + d = d[None, ...] 
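+        # one {"predicted_depth": tensor} dict per frame keeps parity with the other depth backends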
+ return [{"predicted_depth": torch.from_numpy(d[i]).float()} for i in range(d.shape[0])] + + caps = { + "kind": "vda", + "has_builtin_processor": True, + "supports_multi_view": True, # sequence model + "supports_metric_models": True, + "is_video_model": True, + "prefers_sequence": True, + } + return vda_infer, caps diff --git a/core/models/depth_anything_3/api.py b/core/models/depth_anything_3/api.py new file mode 100644 index 0000000..d6a5408 --- /dev/null +++ b/core/models/depth_anything_3/api.py @@ -0,0 +1,447 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Depth Anything 3 API module. + +This module provides the main API for Depth Anything 3, including model loading, +inference, and export capabilities. It supports both single and nested model architectures. +""" + +from __future__ import annotations + +import time +from typing import Optional, Sequence +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin +from PIL import Image + +from .cfg import create_object, load_config +from .registry import MODEL_REGISTRY +from .specs import Prediction +from .utils.export import export +from .utils.geometry import affine_inverse +from .utils.io.input_processor import InputProcessor +from .utils.io.output_processor import OutputProcessor +from .utils.logger import logger +from .utils.pose_align import align_poses_umeyama + + +torch.backends.cudnn.benchmark = False +# logger.info("CUDNN Benchmark Disabled") + +SAFETENSORS_NAME = "model.safetensors" +CONFIG_NAME = "config.json" + + +class DepthAnything3(nn.Module, PyTorchModelHubMixin): + """ + Depth Anything 3 main API class. + + This class provides a high-level interface for depth estimation using Depth Anything 3. + It supports both single and nested model architectures with metric scaling capabilities. + + Features: + - Hugging Face Hub integration via PyTorchModelHubMixin + - Support for multiple model presets (vitb, vitg, nested variants) + - Automatic mixed precision inference + - Export capabilities for various formats (GLB, PLY, NPZ, etc.) + - Camera pose estimation and metric depth scaling + + Usage: + # Load from Hugging Face Hub + model = DepthAnything3.from_pretrained("huggingface/model-name") + + # Or create with specific preset + model = DepthAnything3(preset="vitg") + + # Run inference + prediction = model.inference(images, export_dir="output", export_format="glb") + """ + + _commit_hash: str | None = None # Set by mixin when loading from Hub + + def __init__(self, model_name: str = "da3-large", **kwargs): + """ + Initialize DepthAnything3 with specified preset. + + Args: + model_name: The name of the model preset to use. + Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'. + **kwargs: Additional keyword arguments (currently unused). 
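+            Example: model = DepthAnything3(model_name="da3-large"); model.to("cuda").eval()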
+ """ + super().__init__() + self.model_name = model_name + + # Build the underlying network + self.config = load_config(MODEL_REGISTRY[self.model_name]) + self.model = create_object(self.config) + self.model.eval() + + # Initialize processors + self.input_processor = InputProcessor() + self.output_processor = OutputProcessor() + + # Device management (set by user) + self.device = None + + @torch.inference_mode() + def forward( + self, + image: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + export_feat_layers: list[int] | None = None, + infer_gs: bool = False, + use_ray_pose: bool = False, + ref_view_strategy: str = "saddle_balanced", + ) -> dict[str, torch.Tensor]: + """ + Forward pass through the model. + + Args: + image: Input batch with shape ``(B, N, 3, H, W)`` on the model device. + extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``. + intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``. + export_feat_layers: Layer indices to return intermediate features for. + infer_gs: Enable Gaussian Splatting branch. + use_ray_pose: Use ray-based pose estimation instead of camera decoder. + ref_view_strategy: Strategy for selecting reference view from multiple views. + + Returns: + Dictionary containing model predictions + """ + # Determine optimal autocast dtype + autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + with torch.no_grad(): + with torch.autocast(device_type=image.device.type, dtype=autocast_dtype): + return self.model( + image, extrinsics, intrinsics, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy + ) + + def inference( + self, + image: list[np.ndarray | Image.Image | str], + extrinsics: np.ndarray | None = None, + intrinsics: np.ndarray | None = None, + align_to_input_ext_scale: bool = True, + infer_gs: bool = False, + use_ray_pose: bool = False, + ref_view_strategy: str = "saddle_balanced", + render_exts: np.ndarray | None = None, + render_ixts: np.ndarray | None = None, + render_hw: tuple[int, int] | None = None, + process_res: int = 504, + process_res_method: str = "upper_bound_resize", + export_dir: str | None = None, + export_format: str = "mini_npz", + export_feat_layers: Sequence[int] | None = None, + # GLB export parameters + conf_thresh_percentile: float = 40.0, + num_max_points: int = 1_000_000, + show_cameras: bool = True, + # Feat_vis export parameters + feat_vis_fps: int = 15, + # Other export parameters, e.g., gs_ply, gs_video + export_kwargs: Optional[dict] = {}, + ) -> Prediction: + """ + Run inference on input images. + + Args: + image: List of input images (numpy arrays, PIL Images, or file paths) + extrinsics: Camera extrinsics (N, 4, 4) + intrinsics: Camera intrinsics (N, 3, 3) + align_to_input_ext_scale: whether to align the input pose scale to the prediction + infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports) + use_ray_pose: Use ray-based pose estimation instead of camera decoder (default: False) + ref_view_strategy: Strategy for selecting reference view from multiple views. + Options: "first", "middle", "saddle_balanced", "saddle_sim_range". + Default: "saddle_balanced". For single view input (S ≤ 2), no reordering is performed. 
+ render_exts: Optional render extrinsics for Gaussian video export + render_ixts: Optional render intrinsics for Gaussian video export + render_hw: Optional render resolution for Gaussian video export + process_res: Processing resolution + process_res_method: Resize method for processing + export_dir: Directory to export results + export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video) + export_feat_layers: Layer indices to export intermediate features from + conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501 + num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000) + show_cameras: [GLB] Show camera wireframes in the exported scene (default: True) + feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15) + export_kwargs: additional arguments to export functions. + + Returns: + Prediction object containing depth maps and camera parameters + """ + if "gs" in export_format: + assert infer_gs, "must set `infer_gs=True` to perform gs-related export." + + if "colmap" in export_format: + assert isinstance(image[0], str), "`image` must be image paths for COLMAP export." + + # Preprocess images + imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs( + image, extrinsics, intrinsics, process_res, process_res_method + ) + + # Prepare tensors for model + imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics) + + # Normalize extrinsics + ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None) + + # Run model forward pass + export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else [] + + raw_output = self._run_model_forward( + imgs, ex_t_norm, in_t, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy + ) + + # Convert raw output to prediction + prediction = self._convert_to_prediction(raw_output) + + # Align prediction to extrinsincs + prediction = self._align_to_input_extrinsics_intrinsics( + extrinsics, intrinsics, prediction, align_to_input_ext_scale + ) + + # Add processed images for visualization + prediction = self._add_processed_images(prediction, imgs_cpu) + + # Export if requested + if export_dir is not None: + + if "gs" in export_format: + if infer_gs and "gs_video" not in export_format: + export_format = f"{export_format}-gs_video" + if "gs_video" in export_format: + if "gs_video" not in export_kwargs: + export_kwargs["gs_video"] = {} + export_kwargs["gs_video"].update( + { + "extrinsics": render_exts, + "intrinsics": render_ixts, + "out_image_hw": render_hw, + } + ) + # Add GLB export parameters + if "glb" in export_format: + if "glb" not in export_kwargs: + export_kwargs["glb"] = {} + export_kwargs["glb"].update( + { + "conf_thresh_percentile": conf_thresh_percentile, + "num_max_points": num_max_points, + "show_cameras": show_cameras, + } + ) + # Add Feat_vis export parameters + if "feat_vis" in export_format: + if "feat_vis" not in export_kwargs: + export_kwargs["feat_vis"] = {} + export_kwargs["feat_vis"].update( + { + "fps": feat_vis_fps, + } + ) + # Add COLMAP export parameters + if "colmap" in export_format: + if "colmap" not in export_kwargs: + export_kwargs["colmap"] = {} + export_kwargs["colmap"].update( + { + "image_paths": image, + "conf_thresh_percentile": conf_thresh_percentile, + "process_res_method": process_res_method, + } + ) + self._export_results(prediction, export_format, export_dir, **export_kwargs) + + return prediction + + def _preprocess_inputs( + self, + image: 
+    def _preprocess_inputs(
+        self,
+        image: list[np.ndarray | Image.Image | str],
+        extrinsics: np.ndarray | None = None,
+        intrinsics: np.ndarray | None = None,
+        process_res: int = 504,
+        process_res_method: str = "upper_bound_resize",
+    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+        """Preprocess input images using the input processor."""
+        start_time = time.time()
+        imgs_cpu, extrinsics, intrinsics = self.input_processor(
+            image,
+            extrinsics.copy() if extrinsics is not None else None,
+            intrinsics.copy() if intrinsics is not None else None,
+            process_res,
+            process_res_method,
+        )
+        end_time = time.time()
+        logger.info(
+            f"Processed Images Done. Time: {end_time - start_time} seconds. "
+            f"Shape: {imgs_cpu.shape}"
+        )
+        return imgs_cpu, extrinsics, intrinsics
+
+    def _prepare_model_inputs(
+        self,
+        imgs_cpu: torch.Tensor,
+        extrinsics: torch.Tensor | None,
+        intrinsics: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+        """Prepare tensors for model input."""
+        device = self._get_model_device()
+
+        # Move images to model device
+        imgs = imgs_cpu.to(device, non_blocking=True)[None].float()
+
+        # Convert camera parameters to tensors
+        ex_t = (
+            extrinsics.to(device, non_blocking=True)[None].float()
+            if extrinsics is not None
+            else None
+        )
+        in_t = (
+            intrinsics.to(device, non_blocking=True)[None].float()
+            if intrinsics is not None
+            else None
+        )
+
+        return imgs, ex_t, in_t
+
+    def _normalize_extrinsics(self, ex_t: torch.Tensor | None) -> torch.Tensor | None:
+        """Normalize extrinsics: anchor poses to the first view, then rescale
+        translations by their (clamped) median camera distance."""
+        if ex_t is None:
+            return None
+        transform = affine_inverse(ex_t[:, :1])
+        ex_t_norm = ex_t @ transform
+        c2ws = affine_inverse(ex_t_norm)
+        translations = c2ws[..., :3, 3]
+        dists = translations.norm(dim=-1)
+        median_dist = torch.median(dists)
+        median_dist = torch.clamp(median_dist, min=1e-1)
+        ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
+        return ex_t_norm
+
+    def _align_to_input_extrinsics_intrinsics(
+        self,
+        extrinsics: torch.Tensor | None,
+        intrinsics: torch.Tensor | None,
+        prediction: Prediction,
+        align_to_input_ext_scale: bool = True,
+        ransac_view_thresh: int = 10,
+    ) -> Prediction:
+        """Align predicted poses (and depth scale) to the input extrinsics
+        via Umeyama alignment."""
+        if extrinsics is None:
+            return prediction
+        prediction.intrinsics = intrinsics.numpy()
+        _, _, scale, aligned_extrinsics = align_poses_umeyama(
+            prediction.extrinsics,
+            extrinsics.numpy(),
+            ransac=len(extrinsics) >= ransac_view_thresh,
+            return_aligned=True,
+            random_state=42,
+        )
+        if align_to_input_ext_scale:
+            prediction.extrinsics = extrinsics[..., :3, :].numpy()
+            prediction.depth /= scale
+        else:
+            prediction.extrinsics = aligned_extrinsics
+        return prediction
+
+    def _run_model_forward(
+        self,
+        imgs: torch.Tensor,
+        ex_t: torch.Tensor | None,
+        in_t: torch.Tensor | None,
+        export_feat_layers: Sequence[int] | None = None,
+        infer_gs: bool = False,
+        use_ray_pose: bool = False,
+        ref_view_strategy: str = "saddle_balanced",
+    ) -> dict[str, torch.Tensor]:
+        """Run model forward pass."""
+        device = imgs.device
+        need_sync = device.type == "cuda"
+        if need_sync:
+            torch.cuda.synchronize(device)
+        start_time = time.time()
+        feat_layers = list(export_feat_layers) if export_feat_layers is not None else None
+        output = self.forward(
+            imgs, ex_t, in_t, feat_layers, infer_gs, use_ray_pose, ref_view_strategy
+        )
+        if need_sync:
+            torch.cuda.synchronize(device)
+        end_time = time.time()
+        logger.info(f"Model Forward Pass Done. 
Time: {end_time - start_time} seconds") + return output + + def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction: + """Convert raw model output to Prediction object.""" + start_time = time.time() + output = self.output_processor(raw_output) + end_time = time.time() + logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds") + return output + + def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction: + """Add processed images to prediction for visualization.""" + # Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize + processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() # (N, H, W, 3) + + # Denormalize from ImageNet normalization + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + processed_imgs = processed_imgs * std + mean + processed_imgs = np.clip(processed_imgs, 0, 1) + processed_imgs = (processed_imgs * 255).astype(np.uint8) + + prediction.processed_images = processed_imgs + return prediction + + def _export_results( + self, prediction: Prediction, export_format: str, export_dir: str, **kwargs + ) -> None: + """Export results to specified format and directory.""" + start_time = time.time() + export(prediction, export_format, export_dir, **kwargs) + end_time = time.time() + logger.info(f"Export Results Done. Time: {end_time - start_time} seconds") + + def _get_model_device(self) -> torch.device: + """ + Get the device where the model is located. + + Returns: + Device where the model parameters are located + + Raises: + ValueError: If no tensors are found in the model + """ + if self.device is not None: + return self.device + + # Find device from parameters + for param in self.parameters(): + self.device = param.device + return param.device + + # Find device from buffers + for buffer in self.buffers(): + self.device = buffer.device + return buffer.device + + raise ValueError("No tensor found in model") \ No newline at end of file diff --git a/core/models/depth_anything_3/cfg.py b/core/models/depth_anything_3/cfg.py new file mode 100644 index 0000000..4165bd2 --- /dev/null +++ b/core/models/depth_anything_3/cfg.py @@ -0,0 +1,144 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration utility functions +""" + +import importlib +from pathlib import Path +from typing import Any, Callable, List, Union +from omegaconf import DictConfig, ListConfig, OmegaConf + +try: + OmegaConf.register_new_resolver("eval", eval) +except Exception as e: + # if eval is not available, we can just pass + print(f"Error registering eval resolver: {e}") + + +def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: + """ + Load a configuration. Will resolve inheritance. + Supports both file paths and module paths (e.g., depth_anything_3.configs.giant). + """ + # Check if path is a module path (contains dots but no slashes and doesn't end with .yaml) + if "." 
in path and "/" not in path and not path.endswith(".yaml"):
+        # It's a module path, load from package resources
+        path_parts = path.split(".")[1:]
+        config_path = Path(__file__).resolve().parent
+        for part in path_parts:
+            config_path = config_path.joinpath(part)
+        config_path = config_path.with_suffix(".yaml")
+        config = OmegaConf.load(str(config_path))
+    else:
+        # It's a file path (absolute, relative, or with .yaml extension)
+        config = OmegaConf.load(path)
+
+    if argv is not None:
+        config_argv = OmegaConf.from_dotlist(argv)
+        config = OmegaConf.merge(config, config_argv)
+    config = resolve_recursive(config, resolve_inheritance)
+    return config
+
+
+def resolve_recursive(
+    config: Any,
+    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
+) -> Any:
+    config = resolver(config)
+    if isinstance(config, DictConfig):
+        for k in config.keys():
+            v = config.get(k)
+            if isinstance(v, (DictConfig, ListConfig)):
+                config[k] = resolve_recursive(v, resolver)
+    if isinstance(config, ListConfig):
+        for i in range(len(config)):
+            v = config.get(i)
+            if isinstance(v, (DictConfig, ListConfig)):
+                config[i] = resolve_recursive(v, resolver)
+    return config
+
+
+def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
+    """
+    Recursively resolve inheritance if the config contains:
+    __inherit__: path/to/parent.yaml or a ListConfig of such paths.
+    """
+    if isinstance(config, DictConfig):
+        inherit = config.pop("__inherit__", None)
+
+        if inherit:
+            inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
+
+            parent_config = None
+            for parent_path in inherit_list:
+                assert isinstance(parent_path, str)
+                parent_config = (
+                    load_config(parent_path)
+                    if parent_config is None
+                    else OmegaConf.merge(parent_config, load_config(parent_path))
+                )
+
+            if len(config.keys()) > 0:
+                config = OmegaConf.merge(parent_config, config)
+            else:
+                config = parent_config
+    return config
+
+
+def import_item(path: str, name: str) -> Any:
+    """
+    Import a Python item. Example: import_item("path.to.file", "MyClass") -> MyClass
+    """
+    return getattr(importlib.import_module(path), name)
+
+
+def create_object(config: DictConfig) -> Any:
+    """
+    Create an object from config. The config is expected to contain the following:
+    __object__:
+        path: path.to.module
+        name: MyClass
+        args: as_config | as_params (defaults to as_config)
+    """
+    config = DictConfig(config)
+    item = import_item(
+        path=config.__object__.path,
+        name=config.__object__.name,
+    )
+    args = config.__object__.get("args", "as_config")
+    if args == "as_config":
+        return item(config)
+    if args == "as_params":
+        config = OmegaConf.to_object(config)
+        config.pop("__object__")
+        return item(**config)
+    raise NotImplementedError(f"Unknown args type: {args}")
+
+
+def create_dataset(path: str, *args, **kwargs) -> Any:
+    """
+    Create a dataset. Requires the file to contain a "create_dataset" function.
+ """ + return import_item(path, "create_dataset")(*args, **kwargs) + + +def to_dict_recursive(config_obj): + if isinstance(config_obj, DictConfig): + return {k: to_dict_recursive(v) for k, v in config_obj.items()} + elif isinstance(config_obj, ListConfig): + return [to_dict_recursive(item) for item in config_obj] + return config_obj \ No newline at end of file diff --git a/core/models/depth_anything_3/configs/da3-base.yaml b/core/models/depth_anything_3/configs/da3-base.yaml new file mode 100644 index 0000000..ff594c0 --- /dev/null +++ b/core/models/depth_anything_3/configs/da3-base.yaml @@ -0,0 +1,45 @@ +__object__: + path: core.models.depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: core.models.depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitb + out_layers: [5, 7, 9, 11] + alt_start: 4 + qknorm_start: 4 + rope_start: 4 + cat_token: True + +head: + __object__: + path: core.models.depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 1536 + output_dim: 2 + features: &head_features 128 + out_channels: &head_out_channels [96, 192, 384, 768] + + +cam_enc: + __object__: + path: core.models.depth_anything_3.model.cam_enc + name: CameraEnc + args: as_params + + dim_out: 768 + +cam_dec: + __object__: + path: core.models.depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 1536 diff --git a/core/models/depth_anything_3/configs/da3-giant.yaml b/core/models/depth_anything_3/configs/da3-giant.yaml new file mode 100644 index 0000000..a9d4d31 --- /dev/null +++ b/core/models/depth_anything_3/configs/da3-giant.yaml @@ -0,0 +1,71 @@ +__object__: + path: core.models.depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: core.models.depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitg + out_layers: [19, 27, 33, 39] + alt_start: 13 + qknorm_start: 13 + rope_start: 13 + cat_token: True + +head: + __object__: + path: core.models.depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 3072 + output_dim: 2 + features: &head_features 256 + out_channels: &head_out_channels [256, 512, 1024, 1024] + + +cam_enc: + __object__: + path: core.models.depth_anything_3.model.cam_enc + name: CameraEnc + args: as_params + + dim_out: 1536 + +cam_dec: + __object__: + path: core.models.depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 3072 + + +gs_head: + __object__: + path: core.models.depth_anything_3.model.gsdpt + name: GSDPT + args: as_params + + dim_in: *head_dim_in + output_dim: 38 # should align with gs_adapter's setting, for gs params + features: *head_features + out_channels: *head_out_channels + + +gs_adapter: + __object__: + path: core.models.depth_anything_3.model.gs_adapter + name: GaussianAdapter + args: as_params + + sh_degree: 2 + pred_color: false # predict SH coefficient if false + pred_offset_depth: true + pred_offset_xy: true + gaussian_scale_min: 1e-5 + gaussian_scale_max: 30.0 diff --git a/core/models/depth_anything_3/configs/da3-large.yaml b/core/models/depth_anything_3/configs/da3-large.yaml new file mode 100644 index 0000000..653e336 --- /dev/null +++ b/core/models/depth_anything_3/configs/da3-large.yaml @@ -0,0 +1,45 @@ +__object__: + path: core.models.depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: core.models.depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + 
args: as_params + + name: vitl + out_layers: [11, 15, 19, 23] + alt_start: 8 + qknorm_start: 8 + rope_start: 8 + cat_token: True + +head: + __object__: + path: core.models.depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 2048 + output_dim: 2 + features: &head_features 256 + out_channels: &head_out_channels [256, 512, 1024, 1024] + + +cam_enc: + __object__: + path: core.models.depth_anything_3.model.cam_enc + name: CameraEnc + args: as_params + + dim_out: 1024 + +cam_dec: + __object__: + path: core.models.depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 2048 diff --git a/core/models/depth_anything_3/configs/da3-small.yaml b/core/models/depth_anything_3/configs/da3-small.yaml new file mode 100644 index 0000000..051b51c --- /dev/null +++ b/core/models/depth_anything_3/configs/da3-small.yaml @@ -0,0 +1,45 @@ +__object__: + path: core.models.depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: core.models.depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vits + out_layers: [5, 7, 9, 11] + alt_start: 4 + qknorm_start: 4 + rope_start: 4 + cat_token: True + +head: + __object__: + path: core.models.depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 768 + output_dim: 2 + features: &head_features 64 + out_channels: &head_out_channels [48, 96, 192, 384] + + +cam_enc: + __object__: + path: core.models.depth_anything_3.model.cam_enc + name: CameraEnc + args: as_params + + dim_out: 384 + +cam_dec: + __object__: + path: core.models.depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 768 diff --git a/core/models/depth_anything_3/configs/da3metric-large.yaml b/core/models/depth_anything_3/configs/da3metric-large.yaml new file mode 100644 index 0000000..823297b --- /dev/null +++ b/core/models/depth_anything_3/configs/da3metric-large.yaml @@ -0,0 +1,28 @@ +__object__: + path: core.models.depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: core.models.depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitl + out_layers: [4, 11, 17, 23] + alt_start: -1 # -1 means disable + qknorm_start: -1 + rope_start: -1 + cat_token: False + +head: + __object__: + path: core.models.depth_anything_3.model.dpt + name: DPT + args: as_params + + dim_in: 1024 + output_dim: 1 + features: 256 + out_channels: [256, 512, 1024, 1024] diff --git a/core/models/depth_anything_3/configs/da3mono-large.yaml b/core/models/depth_anything_3/configs/da3mono-large.yaml new file mode 100644 index 0000000..823297b --- /dev/null +++ b/core/models/depth_anything_3/configs/da3mono-large.yaml @@ -0,0 +1,28 @@ +__object__: + path: core.models.depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: core.models.depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitl + out_layers: [4, 11, 17, 23] + alt_start: -1 # -1 means disable + qknorm_start: -1 + rope_start: -1 + cat_token: False + +head: + __object__: + path: core.models.depth_anything_3.model.dpt + name: DPT + args: as_params + + dim_in: 1024 + output_dim: 1 + features: 256 + out_channels: [256, 512, 1024, 1024] diff --git a/core/models/depth_anything_3/configs/da3nested-giant-large.yaml b/core/models/depth_anything_3/configs/da3nested-giant-large.yaml new file mode 100644 index 0000000..b5f36ee --- /dev/null +++ 
b/core/models/depth_anything_3/configs/da3nested-giant-large.yaml @@ -0,0 +1,10 @@ +__object__: + path: core.models.depth_anything_3.model.da3 + name: NestedDepthAnything3Net + args: as_params + +anyview: + __inherit__: core.models.depth_anything_3.configs.da3-giant + +metric: + __inherit__: core.models.depth_anything_3.configs.da3metric-large diff --git a/core/models/depth_anything_3/model/__init__.py b/core/models/depth_anything_3/model/__init__.py new file mode 100644 index 0000000..ce3cc5c --- /dev/null +++ b/core/models/depth_anything_3/model/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from core.models.depth_anything_3.model.da3 import DepthAnything3Net, NestedDepthAnything3Net + +__export__ = [ + NestedDepthAnything3Net, + DepthAnything3Net, +] diff --git a/core/models/depth_anything_3/model/__pycache__/__init__.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..44f31cb Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/__init__.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/__pycache__/cam_dec.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/cam_dec.cpython-313.pyc new file mode 100644 index 0000000..3af4f75 Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/cam_dec.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/__pycache__/cam_enc.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/cam_enc.cpython-313.pyc new file mode 100644 index 0000000..8499356 Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/cam_enc.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/__pycache__/da3.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/da3.cpython-313.pyc new file mode 100644 index 0000000..f1c8c79 Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/da3.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/__pycache__/dpt.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/dpt.cpython-313.pyc new file mode 100644 index 0000000..eb20c1a Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/dpt.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/__pycache__/dualdpt.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/dualdpt.cpython-313.pyc new file mode 100644 index 0000000..a23e5b7 Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/dualdpt.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/__pycache__/gs_adapter.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/gs_adapter.cpython-313.pyc new file mode 100644 index 0000000..2e1eae4 Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/gs_adapter.cpython-313.pyc differ diff --git 
a/core/models/depth_anything_3/model/__pycache__/gsdpt.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/gsdpt.cpython-313.pyc new file mode 100644 index 0000000..af65e66 Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/gsdpt.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/__pycache__/reference_view_selector.cpython-313.pyc b/core/models/depth_anything_3/model/__pycache__/reference_view_selector.cpython-313.pyc new file mode 100644 index 0000000..7ea7bb7 Binary files /dev/null and b/core/models/depth_anything_3/model/__pycache__/reference_view_selector.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/cam_dec.py b/core/models/depth_anything_3/model/cam_dec.py new file mode 100644 index 0000000..3353b40 --- /dev/null +++ b/core/models/depth_anything_3/model/cam_dec.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +class CameraDec(nn.Module): + def __init__(self, dim_in=1536): + super().__init__() + output_dim = dim_in + self.backbone = nn.Sequential( + nn.Linear(output_dim, output_dim), + nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.ReLU(), + ) + self.fc_t = nn.Linear(output_dim, 3) + self.fc_qvec = nn.Linear(output_dim, 4) + self.fc_fov = nn.Sequential(nn.Linear(output_dim, 2), nn.ReLU()) + + def forward(self, feat, camera_encoding=None, *args, **kwargs): + B, N = feat.shape[:2] + feat = feat.reshape(B * N, -1) + feat = self.backbone(feat) + out_t = self.fc_t(feat.float()).reshape(B, N, 3) + if camera_encoding is None: + out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4) + out_fov = self.fc_fov(feat.float()).reshape(B, N, 2) + else: + out_qvec = camera_encoding[..., 3:7] + out_fov = camera_encoding[..., -2:] + pose_enc = torch.cat([out_t, out_qvec, out_fov], dim=-1) + return pose_enc diff --git a/core/models/depth_anything_3/model/cam_enc.py b/core/models/depth_anything_3/model/cam_enc.py new file mode 100644 index 0000000..3c8ff2a --- /dev/null +++ b/core/models/depth_anything_3/model/cam_enc.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
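Before moving on to `cam_enc.py`: a quick shape check for the `CameraDec` defined just above (a sketch, not part of the diff). The decoder flattens per-view tokens, runs them through the MLP trunk, and concatenates translation, quaternion, and field-of-view terms into a 9-D pose encoding.

```python
import torch
from core.models.depth_anything_3.model.cam_dec import CameraDec

dec = CameraDec(dim_in=1536)
feat = torch.randn(2, 4, 1536)   # (B, N, dim_in) per-view tokens
pose_enc = dec(feat)             # concatenation of t(3) | qvec(4) | fov(2)
assert pose_enc.shape == (2, 4, 9)
```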
+ +import torch.nn as nn + +from core.models.depth_anything_3.model.utils.attention import Mlp +from core.models.depth_anything_3.model.utils.block import Block +from core.models.depth_anything_3.model.utils.transform import extri_intri_to_pose_encoding +from core.models.depth_anything_3.utils.geometry import affine_inverse + + +class CameraEnc(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_out: int = 1024, + dim_in: int = 9, + trunk_depth: int = 4, + target_dim: int = 9, + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + **kwargs, + ): + super().__init__() + self.target_dim = target_dim + self.trunk_depth = trunk_depth + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_out, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + self.token_norm = nn.LayerNorm(dim_out) + self.trunk_norm = nn.LayerNorm(dim_out) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_out // 2, + out_features=dim_out, + drop=0, + ) + + def forward( + self, + ext, + ixt, + image_size, + ) -> tuple: + c2ws = affine_inverse(ext) + pose_encoding = extri_intri_to_pose_encoding( + c2ws, + ixt, + image_size, + ) + pose_tokens = self.pose_branch(pose_encoding) + pose_tokens = self.token_norm(pose_tokens) + pose_tokens = self.trunk(pose_tokens) + pose_tokens = self.trunk_norm(pose_tokens) + return pose_tokens diff --git a/core/models/depth_anything_3/model/da3.py b/core/models/depth_anything_3/model/da3.py new file mode 100644 index 0000000..6cdd7db --- /dev/null +++ b/core/models/depth_anything_3/model/da3.py @@ -0,0 +1,442 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +from addict import Dict +from omegaconf import DictConfig, OmegaConf + +from core.models.depth_anything_3.cfg import create_object +from core.models.depth_anything_3.model.utils.transform import pose_encoding_to_extri_intri +from core.models.depth_anything_3.utils.alignment import ( + apply_metric_scaling, + compute_alignment_mask, + compute_sky_mask, + least_squares_scale_scalar, + sample_tensor_for_quantile, + set_sky_regions_to_max_depth, +) +from core.models.depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, map_pdf_to_opacity +from core.models.depth_anything_3.utils.ray_utils import get_extrinsic_from_camray + + +def _wrap_cfg(cfg_obj): + return OmegaConf.create(cfg_obj) + + +class DepthAnything3Net(nn.Module): + """ + Depth Anything 3 network for depth estimation and camera pose estimation. 
+ + This network consists of: + - Backbone: DinoV2 feature extractor + - Head: DPT or DualDPT for depth prediction + - Optional camera decoders for pose estimation + - Optional GSDPT for 3DGS prediction + + Args: + preset: Configuration preset containing network dimensions and settings + + Returns: + Dictionary containing: + - depth: Predicted depth map (B, H, W) + - depth_conf: Depth confidence map (B, H, W) + - extrinsics: Camera extrinsics (B, N, 4, 4) + - intrinsics: Camera intrinsics (B, N, 3, 3) + - gaussians: 3D Gaussian Splats (world space), type: model.gs_adapter.Gaussians + - aux: Auxiliary features for specified layers + """ + + # Patch size for feature extraction + PATCH_SIZE = 14 + + def __init__(self, net, head, cam_dec=None, cam_enc=None, gs_head=None, gs_adapter=None): + """ + Initialize DepthAnything3Net with given yaml-initialized configuration. + """ + super().__init__() + self.backbone = net if isinstance(net, nn.Module) else create_object(_wrap_cfg(net)) + self.head = head if isinstance(head, nn.Module) else create_object(_wrap_cfg(head)) + self.cam_dec, self.cam_enc = None, None + if cam_dec is not None: + self.cam_dec = ( + cam_dec if isinstance(cam_dec, nn.Module) else create_object(_wrap_cfg(cam_dec)) + ) + self.cam_enc = ( + cam_enc if isinstance(cam_enc, nn.Module) else create_object(_wrap_cfg(cam_enc)) + ) + self.gs_adapter, self.gs_head = None, None + if gs_head is not None and gs_adapter is not None: + self.gs_adapter = ( + gs_adapter + if isinstance(gs_adapter, nn.Module) + else create_object(_wrap_cfg(gs_adapter)) + ) + gs_out_dim = self.gs_adapter.d_in + 1 + if isinstance(gs_head, nn.Module): + assert ( + gs_head.out_dim == gs_out_dim + ), f"gs_head.out_dim should be {gs_out_dim}, got {gs_head.out_dim}" + self.gs_head = gs_head + else: + assert ( + gs_head["output_dim"] == gs_out_dim + ), f"gs_head output_dim should set to {gs_out_dim}, got {gs_head['output_dim']}" + self.gs_head = create_object(_wrap_cfg(gs_head)) + + def forward( + self, + x: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + export_feat_layers: list[int] | None = [], + infer_gs: bool = False, + use_ray_pose: bool = False, + ref_view_strategy: str = "saddle_balanced", + ) -> Dict[str, torch.Tensor]: + """ + Forward pass through the network. 
+ + Args: + x: Input images (B, N, 3, H, W) + extrinsics: Camera extrinsics (B, N, 4, 4) + intrinsics: Camera intrinsics (B, N, 3, 3) + feat_layers: List of layer indices to extract features from + infer_gs: Enable Gaussian Splatting branch + use_ray_pose: Use ray-based pose estimation + ref_view_strategy: Strategy for selecting reference view + + Returns: + Dictionary containing predictions and auxiliary features + """ + # Extract features using backbone + if extrinsics is not None: + with torch.autocast(device_type=x.device.type, enabled=False): + cam_token = self.cam_enc(extrinsics, intrinsics, x.shape[-2:]) + else: + cam_token = None + + feats, aux_feats = self.backbone( + x, cam_token=cam_token, export_feat_layers=export_feat_layers, ref_view_strategy=ref_view_strategy + ) + # feats = [[item for item in feat] for feat in feats] + H, W = x.shape[-2], x.shape[-1] + + # Process features through depth head + with torch.autocast(device_type=x.device.type, enabled=False): + output = self._process_depth_head(feats, H, W) + if use_ray_pose: + output = self._process_ray_pose_estimation(output, H, W) + else: + output = self._process_camera_estimation(feats, H, W, output) + if infer_gs: + output = self._process_gs_head(feats, H, W, output, x, extrinsics, intrinsics) + + output = self._process_mono_sky_estimation(output) + + # Extract auxiliary features if requested + output.aux = self._extract_auxiliary_features(aux_feats, export_feat_layers, H, W) + + return output + + def _process_mono_sky_estimation( + self, output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Process mono sky estimation.""" + if "sky" not in output: + return output + non_sky_mask = compute_sky_mask(output.sky, threshold=0.3) + if non_sky_mask.sum() <= 10: + return output + if (~non_sky_mask).sum() <= 10: + return output + + non_sky_depth = output.depth[non_sky_mask] + if non_sky_depth.numel() > 100000: + idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device) + sampled_depth = non_sky_depth[idx] + else: + sampled_depth = non_sky_depth + non_sky_max = torch.quantile(sampled_depth, 0.99) + + # Set sky regions to maximum depth and high confidence + output.depth, _ = set_sky_regions_to_max_depth( + output.depth, None, non_sky_mask, max_depth=non_sky_max + ) + return output + + def _process_ray_pose_estimation( + self, output: Dict[str, torch.Tensor], height: int, width: int + ) -> Dict[str, torch.Tensor]: + """Process ray pose estimation if ray pose decoder is available.""" + if "ray" in output and "ray_conf" in output: + pred_extrinsic, pred_focal_lengths, pred_principal_points = get_extrinsic_from_camray( + output.ray, + output.ray_conf, + output.ray.shape[-3], + output.ray.shape[-2], + ) + pred_extrinsic = affine_inverse(pred_extrinsic) # w2c -> c2w + pred_extrinsic = pred_extrinsic[:, :, :3, :] + pred_intrinsic = torch.eye(3, 3)[None, None].repeat(pred_extrinsic.shape[0], pred_extrinsic.shape[1], 1, 1).clone().to(pred_extrinsic.device) + pred_intrinsic[:, :, 0, 0] = pred_focal_lengths[:, :, 0] / 2 * width + pred_intrinsic[:, :, 1, 1] = pred_focal_lengths[:, :, 1] / 2 * height + pred_intrinsic[:, :, 0, 2] = pred_principal_points[:, :, 0] * width * 0.5 + pred_intrinsic[:, :, 1, 2] = pred_principal_points[:, :, 1] * height * 0.5 + del output.ray + del output.ray_conf + output.extrinsics = pred_extrinsic + output.intrinsics = pred_intrinsic + return output + + def _process_depth_head( + self, feats: list[torch.Tensor], H: int, W: int + ) -> Dict[str, torch.Tensor]: + """Process 
features through the depth prediction head.""" + return self.head(feats, H, W, patch_start_idx=0) + + def _process_camera_estimation( + self, feats: list[torch.Tensor], H: int, W: int, output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Process camera pose estimation if camera decoder is available.""" + if self.cam_dec is not None: + pose_enc = self.cam_dec(feats[-1][1]) + # Remove ray information as it's not needed for pose estimation + if "ray" in output: + del output.ray + if "ray_conf" in output: + del output.ray_conf + + # Convert pose encoding to extrinsics and intrinsics + c2w, ixt = pose_encoding_to_extri_intri(pose_enc, (H, W)) + output.extrinsics = affine_inverse(c2w) + output.intrinsics = ixt + + return output + + def _process_gs_head( + self, + feats: list[torch.Tensor], + H: int, + W: int, + output: Dict[str, torch.Tensor], + in_images: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + ) -> Dict[str, torch.Tensor]: + """Process 3DGS parameters estimation if 3DGS head is available.""" + if self.gs_head is None or self.gs_adapter is None: + return output + assert output.get("depth", None) is not None, "must provide MV depth for the GS head." + + # The depth is defined in the DA3 model's camera space, + # so even with provided GT camera poses, + # we instead use the predicted camera poses for better alignment. + ctx_extr = output.get("extrinsics", None) + ctx_intr = output.get("intrinsics", None) + assert ( + ctx_extr is not None and ctx_intr is not None + ), "must process camera info first if GT is not available" + + gt_extr = extrinsics + # homo the extr if needed + ctx_extr = as_homogeneous(ctx_extr) + if gt_extr is not None: + gt_extr = as_homogeneous(gt_extr) + + # forward through the gs_dpt head to get 'camera space' parameters + gs_outs = self.gs_head( + feats=feats, + H=H, + W=W, + patch_start_idx=0, + images=in_images, + ) + raw_gaussians = gs_outs.raw_gs + densities = gs_outs.raw_gs_conf + + # convert to 'world space' 3DGS parameters; ready to export and render + # gt_extr could be None, and will be used to align the pose scale if available + gs_world = self.gs_adapter( + extrinsics=ctx_extr, + intrinsics=ctx_intr, + depths=output.depth, + opacities=map_pdf_to_opacity(densities), + raw_gaussians=raw_gaussians, + image_shape=(H, W), + gt_extrinsics=gt_extr, + ) + output.gaussians = gs_world + + return output + + def _extract_auxiliary_features( + self, feats: list[torch.Tensor], feat_layers: list[int], H: int, W: int + ) -> Dict[str, torch.Tensor]: + """Extract auxiliary features from specified layers.""" + aux_features = Dict() + assert len(feats) == len(feat_layers) + for feat, feat_layer in zip(feats, feat_layers): + # Reshape features to spatial dimensions + feat_reshaped = feat.reshape( + [ + feat.shape[0], + feat.shape[1], + H // self.PATCH_SIZE, + W // self.PATCH_SIZE, + feat.shape[-1], + ] + ) + aux_features[f"feat_layer_{feat_layer}"] = feat_reshaped + + return aux_features + + +class NestedDepthAnything3Net(nn.Module): + """ + Nested Depth Anything 3 network with metric scaling capabilities. + + This network combines two DepthAnything3Net branches: + - Main branch: Standard depth estimation + - Metric branch: Metric depth estimation for scaling alignment + + The network performs depth alignment using least squares scaling + and handles sky region masking for improved depth estimation. 
+ + Args: + preset: Configuration for the main depth estimation branch + second_preset: Configuration for the metric depth branch + """ + + def __init__(self, anyview: DictConfig, metric: DictConfig): + """ + Initialize NestedDepthAnything3Net with two branches. + + Args: + preset: Configuration for main depth estimation branch + second_preset: Configuration for metric depth branch + """ + super().__init__() + self.da3 = create_object(anyview) + self.da3_metric = create_object(metric) + + def forward( + self, + x: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + export_feat_layers: list[int] | None = [], + infer_gs: bool = False, + use_ray_pose: bool = False, + ref_view_strategy: str = "saddle_balanced", + ) -> Dict[str, torch.Tensor]: + """ + Forward pass through both branches with metric scaling alignment. + + Args: + x: Input images (B, N, 3, H, W) + extrinsics: Camera extrinsics (B, N, 4, 4) - unused + intrinsics: Camera intrinsics (B, N, 3, 3) - unused + feat_layers: List of layer indices to extract features from + infer_gs: Enable Gaussian Splatting branch + use_ray_pose: Use ray-based pose estimation + ref_view_strategy: Strategy for selecting reference view + + Returns: + Dictionary containing aligned depth predictions and camera parameters + """ + # Get predictions from both branches + output = self.da3( + x, extrinsics, intrinsics, export_feat_layers=export_feat_layers, infer_gs=infer_gs, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy + ) + metric_output = self.da3_metric(x) + + # Apply metric scaling and alignment + output = self._apply_metric_scaling(output, metric_output) + output = self._apply_depth_alignment(output, metric_output) + output = self._handle_sky_regions(output, metric_output) + + return output + + def _apply_metric_scaling( + self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Apply metric scaling to the metric depth output.""" + # Scale metric depth based on camera intrinsics + metric_output.depth = apply_metric_scaling( + metric_output.depth, + output.intrinsics, + ) + return output + + def _apply_depth_alignment( + self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Apply depth alignment using least squares scaling.""" + # Compute non-sky mask + non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3) + + # Ensure we have enough non-sky pixels + assert non_sky_mask.sum() > 10, "Insufficient non-sky pixels for alignment" + + # Sample depth confidence for quantile computation + depth_conf_ns = output.depth_conf[non_sky_mask] + depth_conf_sampled = sample_tensor_for_quantile(depth_conf_ns, max_samples=100000) + median_conf = torch.quantile(depth_conf_sampled, 0.5) + + # Compute alignment mask + align_mask = compute_alignment_mask( + output.depth_conf, non_sky_mask, output.depth, metric_output.depth, median_conf + ) + + # Compute scale factor using least squares + valid_depth = output.depth[align_mask] + valid_metric_depth = metric_output.depth[align_mask] + scale_factor = least_squares_scale_scalar(valid_metric_depth, valid_depth) + + # Apply scaling to depth and extrinsics + output.depth *= scale_factor + output.extrinsics[:, :, :3, 3] *= scale_factor + output.is_metric = 1 + output.scale_factor = scale_factor.item() + + return output + + def _handle_sky_regions( + self, + output: Dict[str, torch.Tensor], + metric_output: Dict[str, torch.Tensor], + sky_depth_def: 
float = 200.0, + ) -> Dict[str, torch.Tensor]: + """Handle sky regions by setting them to maximum depth.""" + non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3) + + # Compute maximum depth for non-sky regions + # Use sampling to safely compute quantile on large tensors + non_sky_depth = output.depth[non_sky_mask] + if non_sky_depth.numel() > 100000: + idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device) + sampled_depth = non_sky_depth[idx] + else: + sampled_depth = non_sky_depth + non_sky_max = min(torch.quantile(sampled_depth, 0.99), sky_depth_def) + + # Set sky regions to maximum depth and high confidence + output.depth, output.depth_conf = set_sky_regions_to_max_depth( + output.depth, output.depth_conf, non_sky_mask, max_depth=non_sky_max + ) + + return output diff --git a/core/models/depth_anything_3/model/dinov2/__pycache__/dinov2.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/__pycache__/dinov2.cpython-313.pyc new file mode 100644 index 0000000..1c2c02d Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/__pycache__/dinov2.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/__pycache__/vision_transformer.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/__pycache__/vision_transformer.cpython-313.pyc new file mode 100644 index 0000000..bbb84a4 Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/__pycache__/vision_transformer.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/dinov2.py b/core/models/depth_anything_3/model/dinov2/dinov2.py new file mode 100644 index 0000000..96ac751 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/dinov2.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
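Stepping back to the metric alignment in `NestedDepthAnything3Net` above: `least_squares_scale_scalar` is imported from `utils.alignment` and its implementation is not part of this diff. A closed-form scalar fit consistent with how `_apply_depth_alignment` uses it might look like the sketch below; the argument order is an assumption.

```python
import torch

def least_squares_scale_sketch(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Scalar s minimizing ||s * pred - target||^2, i.e. s = <pred, target> / <pred, pred>."""
    return (pred * target).sum() / (pred * pred).sum().clamp_min(1e-12)

# Mirrors the alignment step: fit a single scale between the two depth
# branches over high-confidence, non-sky pixels, then apply it to the
# depth map and the camera translations.
relative = torch.rand(10_000) + 0.1
metric = 3.7 * relative
scale = least_squares_scale_sketch(metric, relative)   # ~3.7
```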
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + + +from typing import List +import torch.nn as nn + +from core.models.depth_anything_3.model.dinov2.vision_transformer import ( + vit_base, + vit_giant2, + vit_large, + vit_small, +) + + +class DinoV2(nn.Module): + def __init__( + self, + name: str, + out_layers: List[int], + alt_start: int = -1, + qknorm_start: int = -1, + rope_start: int = -1, + cat_token: bool = True, + **kwargs, + ): + super().__init__() + assert name in {"vits", "vitb", "vitl", "vitg"} + self.name = name + self.out_layers = out_layers + self.alt_start = alt_start + self.qknorm_start = qknorm_start + self.rope_start = rope_start + self.cat_token = cat_token + encoder_map = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2, + } + encoder_fn = encoder_map[self.name] + ffn_layer = "swiglufused" if self.name == "vitg" else "mlp" + self.pretrained = encoder_fn( + img_size=518, + patch_size=14, + ffn_layer=ffn_layer, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + ) + + def forward(self, x, **kwargs): + return self.pretrained.get_intermediate_layers( + x, + self.out_layers, + **kwargs, + ) diff --git a/core/models/depth_anything_3/model/dinov2/layers/__init__.py b/core/models/depth_anything_3/model/dinov2/layers/__init__.py new file mode 100644 index 0000000..97dfba9 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
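To connect the `DinoV2` wrapper above back to the YAML configs earlier in this diff: the `net` section of `configs/da3-large.yaml` maps directly onto its constructor. A sketch follows; the multi-view input contract of the customized ViT (including `cam_token` handling) lives in `vision_transformer.py` and is not shown here.

```python
from core.models.depth_anything_3.model.dinov2.dinov2 import DinoV2

backbone = DinoV2(
    name="vitl",                  # selects vit_large from encoder_map
    out_layers=[11, 15, 19, 23],  # transformer blocks tapped for the DPT head
    alt_start=8,                  # per the configs, the *_start indices enable
    qknorm_start=8,               # these backbone features from that block on
    rope_start=8,                 # (-1 disables, as in da3metric-large.yaml)
    cat_token=True,
)
```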
+ +# from .attention import MemEffAttention +from .block import Block +from .layer_scale import LayerScale +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .rope import PositionGetter, RotaryPositionEmbedding2D +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused + +__all__ = [ + Mlp, + PatchEmbed, + SwiGLUFFN, + SwiGLUFFNFused, + Block, + # MemEffAttention, + LayerScale, + PositionGetter, + RotaryPositionEmbedding2D, +] diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/__init__.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..50229d6 Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/__init__.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/attention.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/attention.cpython-313.pyc new file mode 100644 index 0000000..a07a8ca Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/attention.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/block.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/block.cpython-313.pyc new file mode 100644 index 0000000..f6c0e75 Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/block.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/drop_path.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/drop_path.cpython-313.pyc new file mode 100644 index 0000000..bf87f2f Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/drop_path.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/layer_scale.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/layer_scale.cpython-313.pyc new file mode 100644 index 0000000..85fbb97 Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/layer_scale.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/mlp.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/mlp.cpython-313.pyc new file mode 100644 index 0000000..9c895b4 Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/mlp.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/patch_embed.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/patch_embed.cpython-313.pyc new file mode 100644 index 0000000..74d5565 Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/patch_embed.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/rope.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/rope.cpython-313.pyc new file mode 100644 index 0000000..f288176 Binary files /dev/null and b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/rope.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/__pycache__/swiglu_ffn.cpython-313.pyc b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/swiglu_ffn.cpython-313.pyc new file mode 100644 index 0000000..cede724 Binary files /dev/null and 
b/core/models/depth_anything_3/model/dinov2/layers/__pycache__/swiglu_ffn.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/dinov2/layers/attention.py b/core/models/depth_anything_3/model/dinov2/layers/attention.py new file mode 100644 index 0000000..096b9d4 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/attention.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import torch.nn.functional as F +from torch import Tensor, nn + +logger = logging.getLogger("dinov2") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q), self.k_norm(k) + if self.rope is not None and pos is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=( + (attn_mask)[:, None].repeat(1, self.num_heads, 1, 1) + if attn_mask is not None + else None + ), + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/core/models/depth_anything_3/model/dinov2/layers/block.py b/core/models/depth_anything_3/model/dinov2/layers/block.py new file mode 100644 index 0000000..731519b --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/block.py @@ -0,0 +1,143 @@ +# flake8: noqa: F821 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
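A small equivalence check for the `Attention` module above (a sketch, not part of the diff): with dropout at zero, the fused `scaled_dot_product_attention` path and the manual softmax path should agree to floating-point tolerance, since SDPA's default scale matches `head_dim ** -0.5`.

```python
import torch
from core.models.depth_anything_3.model.dinov2.layers.attention import Attention

attn = Attention(dim=64, num_heads=4, fused_attn=True).eval()
x = torch.randn(2, 10, 64)
with torch.no_grad():
    y_fused = attn(x)
    attn.fused_attn = False   # fall back to explicit softmax(QK^T * scale) @ V
    y_manual = attn(x)
assert torch.allclose(y_fused, y_manual, atol=1e-5)
```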
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, Optional +import torch +from torch import Tensor, nn + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +logger = logging.getLogger("dinov2") +XFORMERS_AVAILABLE = True + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + ln_eps: float = 1e-6, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim, eps=ln_eps) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim, eps=ln_eps) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + pos=pos, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, + pos: Optional[Tensor] = None, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not 
None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor diff --git a/core/models/depth_anything_3/model/dinov2/layers/drop_path.py b/core/models/depth_anything_3/model/dinov2/layers/drop_path.py new file mode 100644 index 0000000..1c2cc94 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/core/models/depth_anything_3/model/dinov2/layers/layer_scale.py b/core/models/depth_anything_3/model/dinov2/layers/layer_scale.py new file mode 100644 index 0000000..898ee12 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/layer_scale.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
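A numeric sanity sketch for `drop_path` above: survivors are rescaled by `1 / keep_prob`, so the expected activation is unchanged, which is why the op can be skipped entirely at inference (`training=False` returns the input as-is).

```python
import torch
from core.models.depth_anything_3.model.dinov2.layers.drop_path import drop_path

x = torch.ones(10_000, 8)
out = drop_path(x, drop_prob=0.3, training=True)
print(out.mean().item())                             # ~1.0 in expectation
print((out.sum(dim=1) == 0).float().mean().item())   # ~0.3 of rows dropped
```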
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa: E501 + +from typing import Union +import torch +from torch import Tensor, nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.dim = dim + self.inplace = inplace + self.init_values = init_values + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + def extra_repr(self) -> str: + return f"{self.dim}, init_values={self.init_values}, inplace={self.inplace}" diff --git a/core/models/depth_anything_3/model/dinov2/layers/mlp.py b/core/models/depth_anything_3/model/dinov2/layers/mlp.py new file mode 100644 index 0000000..78ad0d8 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/core/models/depth_anything_3/model/dinov2/layers/patch_embed.py b/core/models/depth_anything_3/model/dinov2/layers/patch_embed.py new file mode 100644 index 0000000..64bf6be --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/patch_embed.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
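+        flatten_embedding: If False, the output keeps its (B, H', W', D) spatial
+            layout instead of being flattened to (B, N, D).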
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/core/models/depth_anything_3/model/dinov2/layers/rope.py b/core/models/depth_anything_3/model/dinov2/layers/rope.py new file mode 100644 index 0000000..f75ba37 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/rope.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +from typing import Dict, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__( + self, batch_size: int, height: int, width: int, device: torch.device + ) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. 
+ + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, + tokens: torch.Tensor, + positions: torch.Tensor, + cos_comp: torch.Tensor, + sin_comp: torch.Tensor, + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. 
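+            The rotation computed is tokens * cos + rotate_half(tokens) * sin,
+            with the cos/sin components looked up per position index.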
+ """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert ( + positions.ndim == 3 and positions.shape[-1] == 2 + ), "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components( + feature_dim, max_position, tokens.device, tokens.dtype + ) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope( + vertical_features, positions[..., 0], cos_comp, sin_comp + ) + horizontal_features = self._apply_1d_rope( + horizontal_features, positions[..., 1], cos_comp, sin_comp + ) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/core/models/depth_anything_3/model/dinov2/layers/swiglu_ffn.py b/core/models/depth_anything_3/model/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000..c8f58e5 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/core/models/depth_anything_3/model/dinov2/vision_transformer.py b/core/models/depth_anything_3/model/dinov2/vision_transformer.py new file mode 100644 index 0000000..814fac0 --- /dev/null +++ b/core/models/depth_anything_3/model/dinov2/vision_transformer.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
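+# This file defines DinoVisionTransformer, the DA3 backbone: a DINOv2-style ViT
+# with optional alternating local/global attention (alt_start), qk-norm, 2D RoPE,
+# and reference-view selection for multi-view input.
+# Illustrative usage sketch (shapes follow the forward pass below):
+#     vit = vit_large(patch_size=14, rope_start=0)
+#     feats, aux = vit.get_intermediate_layers(images, n=4)  # images: [B, S, 3, H, W]
+#     # feats: tuple of (patch_tokens, camera_token) pairs, one per returned layer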
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import math +from typing import Callable, List, Sequence, Tuple, Union +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange + +from core.models.depth_anything_3.utils.logger import logger + +from .layers import LayerScale # noqa: F401 +from .layers import Mlp # noqa: F401 +from .layers import ( # noqa: F401 + Block, + PatchEmbed, + PositionGetter, + RotaryPositionEmbedding2D, + SwiGLUFFNFused, +) +from core.models.depth_anything_3.model.reference_view_selector import ( + RefViewStrategy, + select_reference_view, + reorder_by_reference, + restore_original_order, +) +from core.models.depth_anything_3.utils.constants import THRESH_FOR_REF_SELECTION + +# logger = logging.getLogger("dinov2") + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=1.0, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + alt_start=-1, + qknorm_start=-1, + rope_start=-1, + rope_freq=100, + plus_cam_token=False, + cat_token=True, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + 
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating + positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating + positional embeddings + """ + super().__init__() + self.patch_start_idx = 1 + norm_layer = nn.LayerNorm + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.alt_start = alt_start + self.qknorm_start = qknorm_start + self.rope_start = rope_start + self.cat_token = cat_token + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim + ) + num_patches = self.patch_embed.num_patches + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if self.alt_start != -1: + self.camera_token = nn.Parameter(torch.randn(1, 2, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + if self.rope_start != -1: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + else: + self.rope = None + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=i >= qknorm_start if qknorm_start != -1 else False, + rope=self.rope if i >= rope_start and rope_start != -1 else None, + ) + for i in range(depth) + ] + self.blocks = nn.ModuleList(blocks_list) + self.norm = norm_layer(embed_dim) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the + # interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying 
operators are using + # both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_cls_token(self, B, S): + cls_token = self.cls_token.expand(B, S, -1) + cls_token = cls_token.reshape(B * S, -1, self.embed_dim) + return cls_token + + def prepare_tokens_with_masks(self, x, masks=None, cls_token=None, **kwargs): + B, S, nc, w, h = x.shape + x = rearrange(x, "b s c h w -> (b s) c h w") + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + cls_token = self.prepare_cls_token(B, S) + x = torch.cat((cls_token, x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S) + return x + + def _prepare_rope(self, B, S, H, W, device): + pos = None + pos_nodiff = None + if self.rope is not None: + pos = self.position_getter( + B * S, H // self.patch_size, W // self.patch_size, device=device + ) + pos = rearrange(pos, "(b s) n c -> b s n c", b=B) + pos_nodiff = torch.zeros_like(pos).to(pos.dtype) + if self.patch_start_idx > 0: + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(device).to(pos.dtype) + pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B) + pos = torch.cat([pos_special, pos], dim=2) + pos_nodiff = pos_nodiff + 1 + pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2) + return pos, pos_nodiff + + def _get_intermediate_layers_not_chunked(self, x, n=1, export_feat_layers=[], **kwargs): + B, S, _, H, W = x.shape + x = self.prepare_tokens_with_masks(x) + output, total_block_len, aux_output = [], len(self.blocks), [] + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device) + + for i, blk in enumerate(self.blocks): + if i < self.rope_start or self.rope is None: + g_pos, l_pos = None, None + else: + g_pos = pos_nodiff + l_pos = pos + + if self.alt_start != -1 and (i == self.alt_start - 1) and x.shape[1] >= THRESH_FOR_REF_SELECTION and kwargs.get("cam_token", None) is None: + # Select reference view using configured strategy + strategy = kwargs.get("ref_view_strategy", "saddle_balanced") + logger.info(f"Selecting reference view using strategy: {strategy}") + b_idx = select_reference_view(x, strategy=strategy) + # Reorder views to place reference view first + x = reorder_by_reference(x, b_idx) + local_x = reorder_by_reference(local_x, b_idx) + + if self.alt_start != -1 and i == self.alt_start: + if kwargs.get("cam_token", None) is not None: + logger.info("Using camera conditions provided by the user") + cam_token = kwargs.get("cam_token") + else: + ref_token = self.camera_token[:, :1].expand(B, -1, -1) + src_token = self.camera_token[:, 1:].expand(B, S - 1, -1) + cam_token = 
torch.cat([ref_token, src_token], dim=1) + x[:, :, 0] = cam_token + + if self.alt_start != -1 and i >= self.alt_start and i % 2 == 1: + x = self.process_attention( + x, blk, "global", pos=g_pos, attn_mask=kwargs.get("attn_mask", None) + ) + else: + x = self.process_attention(x, blk, "local", pos=l_pos) + local_x = x + + if i in blocks_to_take: + out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x + # Restore original view order if reordering was applied + if x.shape[1] >= THRESH_FOR_REF_SELECTION and self.alt_start != -1 and 'b_idx' in locals(): + out_x = restore_original_order(out_x, b_idx) + output.append((out_x[:, :, 0], out_x)) + if i in export_feat_layers: + aux_output.append(x) + return output, aux_output + + def process_attention(self, x, block, attn_type="global", pos=None, attn_mask=None): + b, s, n = x.shape[:3] + if attn_type == "local": + x = rearrange(x, "b s n c -> (b s) n c") + if pos is not None: + pos = rearrange(pos, "b s n c -> (b s) n c") + elif attn_type == "global": + x = rearrange(x, "b s n c -> b (s n) c") + if pos is not None: + pos = rearrange(pos, "b s n c -> b (s n) c") + else: + raise ValueError(f"Invalid attention type: {attn_type}") + + x = block(x, pos=pos, attn_mask=attn_mask) + + if attn_type == "local": + x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s) + elif attn_type == "global": + x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s) + return x + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + export_feat_layers: List[int] = [], + **kwargs, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + outputs, aux_outputs = self._get_intermediate_layers_not_chunked( + x, n, export_feat_layers=export_feat_layers, **kwargs + ) + camera_tokens = [out[0] for out in outputs] + if outputs[0][1].shape[-1] == self.embed_dim: + outputs = [self.norm(out[1]) for out in outputs] + elif outputs[0][1].shape[-1] == (self.embed_dim * 2): + outputs = [ + torch.cat( + [out[1][..., : self.embed_dim], self.norm(out[1][..., self.embed_dim :])], + dim=-1, + ) + for out in outputs + ] + else: + raise ValueError(f"Invalid output shape: {outputs[0][1].shape}") + aux_outputs = [self.norm(out) for out in aux_outputs] + outputs = [out[..., 1 + self.num_register_tokens :, :] for out in outputs] + aux_outputs = [out[..., 1 + self.num_register_tokens :, :] for out in aux_outputs] + return tuple(zip(outputs, camera_tokens)), aux_outputs + + +def vit_small(patch_size=16, num_register_tokens=0, depth=12, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=depth, + num_heads=6, + mlp_ratio=4, + # block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, depth=12, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=depth, + num_heads=12, + mlp_ratio=4, + # block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, depth=24, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=depth, + num_heads=16, + mlp_ratio=4, + # block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, depth=40, **kwargs): + """ + Close 
to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=depth, + num_heads=24, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/core/models/depth_anything_3/model/dpt.py b/core/models/depth_anything_3/model/dpt.py new file mode 100644 index 0000000..cef7893 --- /dev/null +++ b/core/models/depth_anything_3/model/dpt.py @@ -0,0 +1,458 @@ +# flake8: noqa E501 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict as TyDict +from typing import List, Sequence, Tuple +import torch +import torch.nn as nn +from addict import Dict +from einops import rearrange + +from core.models.depth_anything_3.model.utils.head_utils import ( + Permute, + create_uv_grid, + custom_interpolate, + position_grid_to_embed, +) + + +class DPT(nn.Module): + """ + DPT for dense prediction (main head + optional sky head, sky always 1 channel). + + Returns: + - Main head: + * If output_dim>1: { head_name, f"{head_name}_conf" } + * If output_dim==1: { head_name } + - Sky head (if use_sky_head=True): { sky_name } # [B, S, 1, H/down_ratio, W/down_ratio] + """ + + def __init__( + self, + dim_in: int, + *, + patch_size: int = 14, + output_dim: int = 1, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = False, + down_ratio: int = 1, + head_name: str = "depth", + # ---- sky head (fixed 1 channel) ---- + use_sky_head: bool = True, + sky_name: str = "sky", + sky_activation: str = "relu", # 'sigmoid' / 'relu' / 'linear' + use_ln_for_heads: bool = False, # If needed, apply LayerNorm on intermediate features of both heads + norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer" + fusion_block_inplace: bool = False, + ) -> None: + super().__init__() + + # -------------------- configuration -------------------- + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + + # Names + self.head_main = head_name + self.sky_name = sky_name + + # Main head: output dimension and confidence switch + self.out_dim = output_dim + self.has_conf = output_dim > 1 + + # Sky head parameters (always 1 channel) + self.use_sky_head = use_sky_head + self.sky_activation = sky_activation + + # Fixed 4 intermediate outputs + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + # -------------------- token pre-norm + per-stage projection -------------------- + if norm_type == "layer": + self.norm = nn.LayerNorm(dim_in) + elif norm_type == "idt": + self.norm = nn.Identity() + else: + raise Exception(f"Unknown norm_type {norm_type}, should be 'layer' or 'idt'.") + self.projects = nn.ModuleList( + [nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + 
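+        # Each of the four intermediate token maps is first projected from
+        # dim_in to its stage width (out_channels[i]) with a 1x1 conv before
+        # the spatial re-sizing below.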
+ # -------------------- Spatial re-size (align to common scale before fusion) -------------------- + # Design consistent with original: relative to patch grid (x4, x2, x1, /2) + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1), + ] + ) + + # -------------------- scratch: stage adapters + main fusion chain -------------------- + self.scratch = _make_scratch(list(out_channels), features, expand=False) + + # Main fusion chain + self.scratch.refinenet1 = _make_fusion_block(features, inplace=fusion_block_inplace) + self.scratch.refinenet2 = _make_fusion_block(features, inplace=fusion_block_inplace) + self.scratch.refinenet3 = _make_fusion_block(features, inplace=fusion_block_inplace) + self.scratch.refinenet4 = _make_fusion_block( + features, has_residual=False, inplace=fusion_block_inplace + ) + + # Heads (shared neck1; then split into two heads) + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + + ln_seq = ( + [Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))] + if use_ln_for_heads + else [] + ) + + # Main head + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + *ln_seq, + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + # Sky head (fixed 1 channel) + if self.use_sky_head: + self.scratch.sky_output_conv2 = nn.Sequential( + nn.Conv2d( + head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1 + ), + *ln_seq, + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + ) + + # ------------------------------------------------------------------------- + # Public forward (supports frame chunking to save memory) + # ------------------------------------------------------------------------- + def forward( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + chunk_size: int = 8, + **kwargs, + ) -> Dict: + """ + Args: + feats: List of 4 entries, each entry is a tensor like [B, S, T, C] (or the 0th element of tuple/list is that tensor). + H, W: Original image dimensions + patch_start_idx: Starting index of patch tokens in sequence (for cropping non-patch tokens) + chunk_size: Chunk size along time dimension S + + Returns: + Dict[str, Tensor] + """ + B, S, N, C = feats[0][0].shape + feats = [feat[0].reshape(B * S, N, C) for feat in feats] + + # update image info, used by the GS-DPT head + extra_kwargs = {} + if "images" in kwargs: + extra_kwargs.update({"images": rearrange(kwargs["images"], "B S ... 
-> (B S) ...")}) + + if chunk_size is None or chunk_size >= S: + out_dict = self._forward_impl(feats, H, W, patch_start_idx, **extra_kwargs) + out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()} + return Dict(out_dict) + + out_dicts: List[TyDict[str, torch.Tensor]] = [] + for s0 in range(0, S, chunk_size): + s1 = min(s0 + chunk_size, S) + kw = {} + if "images" in extra_kwargs: + kw.update({"images": extra_kwargs["images"][s0:s1]}) + out_dicts.append( + self._forward_impl([f[s0:s1] for f in feats], H, W, patch_start_idx, **kw) + ) + out_dict = {k: torch.cat([od[k] for od in out_dicts], dim=0) for k in out_dicts[0].keys()} + out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()} + return Dict(out_dict) + + # ------------------------------------------------------------------------- + # Internal forward (single chunk) + # ------------------------------------------------------------------------- + def _forward_impl( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + ) -> TyDict[str, torch.Tensor]: + B, _, C = feats[0].shape + ph, pw = H // self.patch_size, W // self.patch_size + resized_feats = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C] + x = self.norm(x) + # permute -> contiguous before reshape to keep conv input contiguous + x = x.permute(0, 2, 1).contiguous().reshape(B, C, ph, pw) # [B*S, C, ph, pw] + + x = self.projects[stage_idx](x) + if self.pos_embed: + x = self._add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) # Align scale + resized_feats.append(x) + + # 2) Fusion pyramid (main branch only) + fused = self._fuse(resized_feats) + + # 3) Upsample to target resolution, optionally add position encoding again + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused = self.scratch.output_conv1(fused) + fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True) + if self.pos_embed: + fused = self._add_pos_embed(fused, W, H) + + # 4) Shared neck1 + feat = fused + + # 5) Main head: logits -> activation + main_logits = self.scratch.output_conv2(feat) + outs: TyDict[str, torch.Tensor] = {} + if self.has_conf: + fmap = main_logits.permute(0, 2, 3, 1) + pred = self._apply_activation_single(fmap[..., :-1], self.activation) + conf = self._apply_activation_single(fmap[..., -1], self.conf_activation) + outs[self.head_main] = pred.squeeze(1) + outs[f"{self.head_main}_conf"] = conf.squeeze(1) + else: + outs[self.head_main] = self._apply_activation_single( + main_logits, self.activation + ).squeeze(1) + + # 6) Sky head (fixed 1 channel) + if self.use_sky_head: + sky_logits = self.scratch.sky_output_conv2(feat) + outs[self.sky_name] = self._apply_sky_activation(sky_logits).squeeze(1) + + return outs + + # ------------------------------------------------------------------------- + # Subroutines + # ------------------------------------------------------------------------- + def _fuse(self, feats: List[torch.Tensor]) -> torch.Tensor: + """ + 4-layer top-down fusion, returns finest scale features (after fusion, before neck1). 
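+        Fusion order is refinenet4 -> refinenet3 -> refinenet2 -> refinenet1,
+        each step upsampling to the spatial size of the next lateral input.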
+ """ + l1, l2, l3, l4 = feats + + l1_rn = self.scratch.layer1_rn(l1) + l2_rn = self.scratch.layer2_rn(l2) + l3_rn = self.scratch.layer3_rn(l3) + l4_rn = self.scratch.layer4_rn(l4) + + # 4 -> 3 -> 2 -> 1 + out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:]) + out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:]) + out = self.scratch.refinenet1(out, l1_rn) + return out + + def _apply_activation_single( + self, x: torch.Tensor, activation: str = "linear" + ) -> torch.Tensor: + """ + Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case. + Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1 + """ + act = activation.lower() if isinstance(activation, str) else activation + if act == "exp": + return torch.exp(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "expm1": + return torch.expm1(x) + if act == "relu": + return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return torch.nn.functional.softplus(x) + if act == "tanh": + return torch.tanh(x) + # Default linear + return x + + def _apply_sky_activation(self, x: torch.Tensor) -> torch.Tensor: + """ + Sky head activation (fixed 1 channel): + * 'sigmoid' -> Sigmoid probability map + * 'relu' -> ReLU positive domain output + * 'linear' -> Original value (logits) + """ + act = ( + self.sky_activation.lower() + if isinstance(self.sky_activation, str) + else self.sky_activation + ) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "relu": + return torch.relu(x) + # 'linear' + return x + + def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Simple UV position encoding directly added to feature map.""" + pw, ph = x.shape[-1], x.shape[-2] + pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = position_grid_to_embed(pe, x.shape[1]) * ratio + pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pe + + +# ----------------------------------------------------------------------------- +# Building blocks (preserved, consistent with original) +# ----------------------------------------------------------------------------- +def _make_fusion_block( + features: int, + size: Tuple[int, int] = None, + has_residual: bool = True, + groups: int = 1, + inplace: bool = False, +) -> nn.Module: + return FeatureFusionBlock( + features=features, + activation=nn.ReLU(inplace=inplace), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch( + in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False +) -> nn.Module: + scratch = nn.Module() + # Optional expansion by stage + c1 = out_shape + c2 = out_shape * (2 if expand else 1) + c3 = out_shape * (4 if expand else 1) + c4 = out_shape * (8 if expand else 1) + + scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups) + scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups) + return scratch + + +class ResidualConvUnit(nn.Module): + """Lightweight residual convolution block for fusion""" + + def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> 
None: + super().__init__() + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) + self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) + self.norm1 = None + self.norm2 = None + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Top-down fusion block: (optional) residual merge + upsampling + 1x1 contraction""" + + def __init__( + self, + features: int, + activation: nn.Module, + deconv: bool = False, + bn: bool = False, + expand: bool = False, + align_corners: bool = True, + size: Tuple[int, int] = None, + has_residual: bool = True, + groups: int = 1, + ) -> None: + super().__init__() + self.align_corners = align_corners + self.size = size + self.has_residual = has_residual + + self.resConfUnit1 = ( + ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None + ) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups) + + out_features = (features // 2) if expand else features + self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override] + """ + xs: + - xs[0]: Top branch input + - xs[1]: Lateral input (can do residual addition with top branch) + """ + y = xs[0] + if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None: + y = self.skip_add.add(y, self.resConfUnit1(xs[1])) + + y = self.resConfUnit2(y) + + # Upsampling + if (size is None) and (self.size is None): + up_kwargs = {"scale_factor": 2} + elif size is None: + up_kwargs = {"size": self.size} + else: + up_kwargs = {"size": size} + + y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners) + y = self.out_conv(y) + return y diff --git a/core/models/depth_anything_3/model/dualdpt.py b/core/models/depth_anything_3/model/dualdpt.py new file mode 100644 index 0000000..9ef30c2 --- /dev/null +++ b/core/models/depth_anything_3/model/dualdpt.py @@ -0,0 +1,488 @@ +# flake8: noqa E501 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
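+# DualDPT mirrors the DPT fusion pyramid above but adds a fully separate
+# auxiliary refinement chain (no shared fusion weights); only the final
+# auxiliary level is projected to a 7-channel output alongside the main head.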
+ +from typing import List, Sequence, Tuple +import torch +import torch.nn as nn +from addict import Dict + +from core.models.depth_anything_3.model.dpt import _make_fusion_block, _make_scratch +from core.models.depth_anything_3.model.utils.head_utils import ( + Permute, + create_uv_grid, + custom_interpolate, + position_grid_to_embed, +) + + +class DualDPT(nn.Module): + """ + Dual-head DPT for dense prediction with an always-on auxiliary head. + + Architectural notes: + - Sky/object branches are removed. + - `intermediate_layer_idx` is fixed to (0, 1, 2, 3). + - Auxiliary head has its **own** fusion blocks (no fusion_inplace / no sharing). + - Auxiliary head is internally multi-level; **only the final level** is returned. + - Returns a **dict** with keys from `head_names`, e.g.: + { main_name, f"{main_name}_conf", aux_name, f"{aux_name}_conf" } + - `feature_only` is fixed to False. + """ + + def __init__( + self, + dim_in: int, + *, + patch_size: int = 14, + output_dim: int = 2, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = True, + down_ratio: int = 1, + aux_pyramid_levels: int = 4, + aux_out1_conv_num: int = 5, + head_names: Tuple[str, str] = ("depth", "ray"), + ) -> None: + super().__init__() + + # -------------------- configuration -------------------- + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + + self.aux_levels = aux_pyramid_levels + self.aux_out1_conv_num = aux_out1_conv_num + + # names ONLY come from config (no hard-coded strings elsewhere) + self.head_main, self.head_aux = head_names + + # Always expect 4 scales; enforce intermediate idx = (0, 1, 2, 3) + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + # -------------------- token pre-norm + per-stage projection -------------------- + self.norm = nn.LayerNorm(dim_in) + self.projects = nn.ModuleList( + [nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + + # -------------------- spatial re-sizers (align to common scale before fusion) -------------------- + # design: stage strides (x4, x2, x1, /2) relative to patch grid to align to a common pivot scale + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1), + ] + ) + + # -------------------- scratch: stage adapters + fusion (main & aux are separate) -------------------- + self.scratch = _make_scratch(list(out_channels), features, expand=False) + + # Main fusion chain (independent) + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + # Primary head neck + head (independent) + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + 
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+        )
+
+        # Auxiliary fusion chain (completely separate; no sharing, i.e., "fusion_inplace=False")
+        self.scratch.refinenet1_aux = _make_fusion_block(features)
+        self.scratch.refinenet2_aux = _make_fusion_block(features)
+        self.scratch.refinenet3_aux = _make_fusion_block(features)
+        self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False)
+
+        # Aux pre-head per level (we will only *return final level*)
+        self.scratch.output_conv1_aux = nn.ModuleList(
+            [self._make_aux_out1_block(head_features_1) for _ in range(self.aux_levels)]
+        )
+
+        # Aux final projection per level
+        use_ln = True
+        ln_seq = (
+            [Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
+            if use_ln
+            else []
+        )
+        self.scratch.output_conv2_aux = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(
+                        head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
+                    ),
+                    *ln_seq,
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0),
+                )
+                for _ in range(self.aux_levels)
+            ]
+        )
+
+    # -------------------------------------------------------------------------
+    # Public forward (supports frame chunking for memory)
+    # -------------------------------------------------------------------------
+
+    def forward(
+        self,
+        feats: List[torch.Tensor],
+        H: int,
+        W: int,
+        patch_start_idx: int,
+        chunk_size: int = 8,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Args:
+            feats: List of 4 entries; the 0th element of each is a tensor
+                [B, S, N, C] from the transformer.
+            H, W: Original image dimensions.
+            patch_start_idx: Patch-token start in the token sequence (to drop non-patch tokens).
+            chunk_size: Optional chunk size along the flattened (B*S) frame axis for memory.
+
+        Returns:
+            Dict[str, Tensor] with keys based on `head_names`, e.g.:
+                self.head_main, f"{self.head_main}_conf",
+                self.head_aux, f"{self.head_aux}_conf"
+            Shapes:
+                main:    [B, S, out_dim, H/down_ratio, W/down_ratio]
+                main_cf: [B, S, 1, H/down_ratio, W/down_ratio]
+                aux:     [B, S, 7, H/down_ratio, W/down_ratio]
+                aux_cf:  [B, S, 1, H/down_ratio, W/down_ratio]
+        """
+        B, S, N, C = feats[0][0].shape
+        feats = [feat[0].reshape(B * S, N, C) for feat in feats]
+        if chunk_size is None or chunk_size >= S:
+            out_dict = self._forward_impl(feats, H, W, patch_start_idx)
+            out_dict = {k: v.reshape(B, S, *v.shape[1:]) for k, v in out_dict.items()}
+            return Dict(out_dict)
+        out_dicts = []
+        for s0 in range(0, B * S, chunk_size):
+            s1 = min(s0 + chunk_size, B * S)
+            out_dict = self._forward_impl(
+                [feat[s0:s1] for feat in feats],
+                H,
+                W,
+                patch_start_idx,
+            )
+            out_dicts.append(out_dict)
+        out_dict = {
+            k: torch.cat([out_dict[k] for out_dict in out_dicts], dim=0)
+            for k in out_dicts[0].keys()
+        }
+        out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
+        return Dict(out_dict)
+
+    # -------------------------------------------------------------------------
+    # Internal forward (single chunk)
+    # -------------------------------------------------------------------------
+
+    def _forward_impl(
+        self,
+        feats: List[torch.Tensor],
+        H: int,
+        W: int,
+        patch_start_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        B, _, C = feats[0].shape
+        ph, pw = H // self.patch_size, W // self.patch_size
+        resized_feats = []
+        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
+            x = feats[take_idx][:, patch_start_idx:]
+            x = self.norm(x)
+            x = x.permute(0, 2, 1).reshape(B, C, ph, pw)  # [B*S, C, ph, pw]
+
+            x = self.projects[stage_idx](x)
+            if self.pos_embed:
+                x = 
self._add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) # align scales + resized_feats.append(x) + + # 2) Fuse pyramid (main & aux are completely independent) + fused_main, fused_aux_pyr = self._fuse(resized_feats) + + # 3) Upsample to target resolution and (optional) add pos-embed again + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused_main = custom_interpolate( + fused_main, (h_out, w_out), mode="bilinear", align_corners=True + ) + if self.pos_embed: + fused_main = self._add_pos_embed(fused_main, W, H) + + # Primary head: conv1 -> conv2 -> activate + # fused_main = self.scratch.output_conv1(fused_main) + main_logits = self.scratch.output_conv2(fused_main) + fmap = main_logits.permute(0, 2, 3, 1) + main_pred = self._apply_activation_single(fmap[..., :-1], self.activation) + main_conf = self._apply_activation_single(fmap[..., -1], self.conf_activation) + + # Auxiliary head (multi-level inside) -> only last level returned (after activation) + last_aux = fused_aux_pyr[-1] + if self.pos_embed: + last_aux = self._add_pos_embed(last_aux, W, H) + # neck (per-level pre-conv) then final projection (only for last level) + # last_aux = self.scratch.output_conv1_aux[-1](last_aux) + last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux) + fmap_last = last_aux_logits.permute(0, 2, 3, 1) + aux_pred = self._apply_activation_single(fmap_last[..., :-1], "linear") + aux_conf = self._apply_activation_single(fmap_last[..., -1], self.conf_activation) + return { + self.head_main: main_pred.squeeze(-1), + f"{self.head_main}_conf": main_conf, + self.head_aux: aux_pred, + f"{self.head_aux}_conf": aux_conf, + } + + # ------------------------------------------------------------------------- + # Subroutines + # ------------------------------------------------------------------------- + + def _fuse(self, feats: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Feature pyramid fusion. 
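+        Main and auxiliary branches consume the same lateral features
+        (l1..l4) but run through independent refinenet chains.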
+ Returns: + fused_main: Tensor at finest scale (after refinenet1) + aux_pyr: List of aux tensors at each level (pre out_conv1_aux) + """ + l1, l2, l3, l4 = feats + + l1_rn = self.scratch.layer1_rn(l1) + l2_rn = self.scratch.layer2_rn(l2) + l3_rn = self.scratch.layer3_rn(l3) + l4_rn = self.scratch.layer4_rn(l4) + + # level 4 -> 3 + out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + aux_out = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:]) + aux_list: List[torch.Tensor] = [] + if self.aux_levels >= 4: + aux_list.append(aux_out) + + # level 3 -> 2 + out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:]) + aux_out = self.scratch.refinenet3_aux(aux_out, l3_rn, size=l2_rn.shape[2:]) + if self.aux_levels >= 3: + aux_list.append(aux_out) + + # level 2 -> 1 + out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:]) + aux_out = self.scratch.refinenet2_aux(aux_out, l2_rn, size=l1_rn.shape[2:]) + if self.aux_levels >= 2: + aux_list.append(aux_out) + + # level 1 (final) + out = self.scratch.refinenet1(out, l1_rn) + aux_out = self.scratch.refinenet1_aux(aux_out, l1_rn) + aux_list.append(aux_out) + + out = self.scratch.output_conv1(out) + aux_list = [self.scratch.output_conv1_aux[i](aux) for i, aux in enumerate(aux_list)] + + return out, aux_list + + def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Simple UV positional embedding added to feature maps.""" + pw, ph = x.shape[-1], x.shape[-2] + pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = position_grid_to_embed(pe, x.shape[1]) * ratio + pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pe + + def _make_aux_out1_block(self, in_ch: int) -> nn.Sequential: + """Factory for the aux pre-head stack before the final 1x1 projection.""" + if self.aux_out1_conv_num == 5: + return nn.Sequential( + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1), + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1), + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + ) + if self.aux_out1_conv_num == 3: + return nn.Sequential( + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1), + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + ) + if self.aux_out1_conv_num == 1: + return nn.Sequential(nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1)) + raise ValueError(f"aux_out1_conv_num {self.aux_out1_conv_num} not supported") + + def _apply_activation_single( + self, x: torch.Tensor, activation: str = "linear" + ) -> torch.Tensor: + """ + Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case. 
+ Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1 + """ + act = activation.lower() if isinstance(activation, str) else activation + if act == "exp": + return torch.exp(x) + if act == "expm1": + return torch.expm1(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "relu": + return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return torch.nn.functional.softplus(x) + if act == "tanh": + return torch.tanh(x) + # Default linear + return x + + +# # ----------------------------------------------------------------------------- +# # Building blocks (tidy) +# # ----------------------------------------------------------------------------- + + +# def _make_fusion_block( +# features: int, +# size: Tuple[int, int] = None, +# has_residual: bool = True, +# groups: int = 1, +# inplace: bool = False, # <- activation uses inplace=True by default; not related to "fusion_inplace" +# ) -> nn.Module: +# return FeatureFusionBlock( +# features=features, +# activation=nn.ReLU(inplace=inplace), +# deconv=False, +# bn=False, +# expand=False, +# align_corners=True, +# size=size, +# has_residual=has_residual, +# groups=groups, +# ) + + +# def _make_scratch( +# in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False +# ) -> nn.Module: +# scratch = nn.Module() +# # optionally expand widths by stage +# c1 = out_shape +# c2 = out_shape * (2 if expand else 1) +# c3 = out_shape * (4 if expand else 1) +# c4 = out_shape * (8 if expand else 1) + +# scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups) +# scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups) +# scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups) +# scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups) +# return scratch + + +# class ResidualConvUnit(nn.Module): +# """Lightweight residual conv block used within fusion.""" + +# def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None: +# super().__init__() +# self.bn = bn +# self.groups = groups +# self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) +# self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) +# self.norm1 = None +# self.norm2 = None +# self.activation = activation +# self.skip_add = nn.quantized.FloatFunctional() + +# def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] +# out = self.activation(x) +# out = self.conv1(out) +# if self.norm1 is not None: +# out = self.norm1(out) + +# out = self.activation(out) +# out = self.conv2(out) +# if self.norm2 is not None: +# out = self.norm2(out) + +# return self.skip_add.add(out, x) + + +# class FeatureFusionBlock(nn.Module): +# """Top-down fusion block: (optional) residual merge + upsample + 1x1 shrink.""" + +# def __init__( +# self, +# features: int, +# activation: nn.Module, +# deconv: bool = False, +# bn: bool = False, +# expand: bool = False, +# align_corners: bool = True, +# size: Tuple[int, int] = None, +# has_residual: bool = True, +# groups: int = 1, +# ) -> None: +# super().__init__() +# self.align_corners = align_corners +# self.size = size +# self.has_residual = has_residual + +# self.resConfUnit1 = ( +# ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None +# ) +# self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups) + +# out_features = (features // 2) if expand else features +# 
self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups) +# self.skip_add = nn.quantized.FloatFunctional() + +# def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override] +# """ +# xs: +# - xs[0]: top input +# - xs[1]: (optional) lateral (to be added with residual) +# """ +# y = xs[0] +# if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None: +# y = self.skip_add.add(y, self.resConfUnit1(xs[1])) + +# y = self.resConfUnit2(y) + +# # upsample +# if (size is None) and (self.size is None): +# up_kwargs = {"scale_factor": 2} +# elif size is None: +# up_kwargs = {"size": self.size} +# else: +# up_kwargs = {"size": size} + +# y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners) +# y = self.out_conv(y) +# return y diff --git a/core/models/depth_anything_3/model/gs_adapter.py b/core/models/depth_anything_3/model/gs_adapter.py new file mode 100644 index 0000000..a764899 --- /dev/null +++ b/core/models/depth_anything_3/model/gs_adapter.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from einops import einsum, rearrange, repeat +from torch import nn + +from core.models.depth_anything_3.model.utils.transform import cam_quat_xyzw_to_world_quat_wxyz +from core.models.depth_anything_3.specs import Gaussians +from core.models.depth_anything_3.utils.geometry import affine_inverse, get_world_rays, sample_image_grid +from core.models.depth_anything_3.utils.pose_align import batch_align_poses_umeyama +from core.models.depth_anything_3.utils.sh_helpers import rotate_sh + + +class GaussianAdapter(nn.Module): + + def __init__( + self, + sh_degree: int = 0, + pred_color: bool = False, + pred_offset_depth: bool = False, + pred_offset_xy: bool = True, + gaussian_scale_min: float = 1e-5, + gaussian_scale_max: float = 30.0, + ): + super().__init__() + self.sh_degree = sh_degree + self.pred_color = pred_color + self.pred_offset_depth = pred_offset_depth + self.pred_offset_xy = pred_offset_xy + self.gaussian_scale_min = gaussian_scale_min + self.gaussian_scale_max = gaussian_scale_max + + # Create a mask for the spherical harmonics coefficients. This ensures that at + # initialization, the coefficients are biased towards having a large DC + # component and small view-dependent components. 
+        if not pred_color:
+            self.register_buffer(
+                "sh_mask",
+                torch.ones((self.d_sh,), dtype=torch.float32),
+                persistent=False,
+            )
+            for degree in range(1, sh_degree + 1):
+                self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree
+
+    def forward(
+        self,
+        extrinsics: torch.Tensor,  # "*#batch 4 4"
+        intrinsics: torch.Tensor,  # "*#batch 3 3"
+        depths: torch.Tensor,  # "*#batch"
+        opacities: torch.Tensor,  # "*#batch" | "*#batch _"
+        raw_gaussians: torch.Tensor,  # "*#batch _"
+        image_shape: tuple[int, int],
+        eps: float = 1e-8,
+        gt_extrinsics: Optional[torch.Tensor] = None,  # "*#batch 4 4"
+        **kwargs,
+    ) -> Gaussians:
+        device = extrinsics.device
+        dtype = raw_gaussians.dtype
+        H, W = image_shape
+        b, v = raw_gaussians.shape[:2]
+
+        # get cam2worlds and intr_normed to adapt to the 3DGS codebase
+        cam2worlds = affine_inverse(extrinsics)
+        intr_normed = intrinsics.clone().detach()
+        intr_normed[..., 0, :] /= W
+        intr_normed[..., 1, :] /= H
+
+        # 1. compute 3DGS means
+        # 1.1) offset the predicted depth if needed
+        if self.pred_offset_depth:
+            gs_depths = depths + raw_gaussians[..., -1]
+            raw_gaussians = raw_gaussians[..., :-1]
+        else:
+            gs_depths = depths
+        # 1.2) align predicted poses with GT if needed
+        if gt_extrinsics is not None and not torch.equal(extrinsics, gt_extrinsics):
+            try:
+                _, _, pose_scales = batch_align_poses_umeyama(
+                    gt_extrinsics.detach().float(),
+                    extrinsics.detach().float(),
+                )
+            except Exception:
+                pose_scales = torch.ones_like(extrinsics[:, 0, 0, 0])
+            pose_scales = torch.clamp(pose_scales, min=1 / 3.0, max=3.0)
+            cam2worlds[:, :, :3, 3] = cam2worlds[:, :, :3, 3] * rearrange(
+                pose_scales, "b -> b () ()"
+            )  # [b, i, j]
+            gs_depths = gs_depths * rearrange(pose_scales, "b -> b () () ()")  # [b, v, h, w]
+        # 1.3) cast xy into image space
+        xy_ray, _ = sample_image_grid((H, W), device)
+        xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1)  # b v h w xy
+        # offset xy if needed
+        if self.pred_offset_xy:
+            pixel_size = 1 / torch.tensor((W, H), dtype=xy_ray.dtype, device=device)
+            offset_xy = raw_gaussians[..., :2]
+            xy_ray = xy_ray + offset_xy * pixel_size
+            raw_gaussians = raw_gaussians[..., 2:]  # skip the offset_xy
+        # 1.4) unproject depth + xy to world rays
+        origins, directions = get_world_rays(
+            xy_ray,
+            repeat(cam2worlds, "b v i j -> b v h w i j", h=H, w=W),
+            repeat(intr_normed, "b v i j -> b v h w i j", h=H, w=W),
+        )
+        gs_means_world = origins + directions * gs_depths[..., None]
+        gs_means_world = rearrange(gs_means_world, "b v h w d -> b (v h w) d")
+
+        # 2. compute other GS attributes
+        scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
+
+        # 2.1) 3DGS scales
+        # make the scale invariant to resolution
+        scale_min = self.gaussian_scale_min
+        scale_max = self.gaussian_scale_max
+        scales = scale_min + (scale_max - scale_min) * scales.sigmoid()
+        pixel_size = 1 / torch.tensor((W, H), dtype=dtype, device=device)
+        multiplier = self.get_scale_multiplier(intr_normed, pixel_size)
+        gs_scales = scales * gs_depths[..., None] * multiplier[..., None, None, None]
+        gs_scales = rearrange(gs_scales, "b v h w d -> b (v h w) d")
+
+        # 2.2) 3DGS quaternion (world space)
+        # for historical reasons, quaternions are predicted in xyzw order, not wxyz
+        # Normalize the quaternion features to yield a valid quaternion.
+ rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps) + # rotate them to world space + cam_quat_xyzw = rearrange(rotations, "b v h w c -> b (v h w) c") + c2w_mat = repeat( + cam2worlds, + "b v i j -> b (v h w) i j", + h=H, + w=W, + ) + world_quat_wxyz = cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w_mat) + gs_rotations_world = world_quat_wxyz # b (v h w) c + + # 2.3) 3DGS color / SH coefficient (world space) + sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) + if not self.pred_color: + sh = sh * self.sh_mask + + if self.pred_color or self.sh_degree == 0: + # predict pre-computed color or predict only DC band, no need to transform + gs_sh_world = sh + else: + gs_sh_world = rotate_sh(sh, cam2worlds[:, :, None, None, None, :3, :3]) + gs_sh_world = rearrange(gs_sh_world, "b v h w xyz d_sh -> b (v h w) xyz d_sh") + + # 2.4) 3DGS opacity + gs_opacities = rearrange(opacities, "b v h w ... -> b (v h w) ...") + + return Gaussians( + means=gs_means_world, + harmonics=gs_sh_world, + opacities=gs_opacities, + scales=gs_scales, + rotations=gs_rotations_world, + ) + + def get_scale_multiplier( + self, + intrinsics: torch.Tensor, # "*#batch 3 3" + pixel_size: torch.Tensor, # "*#batch 2" + multiplier: float = 0.1, + ) -> torch.Tensor: # " *batch" + xy_multipliers = multiplier * einsum( + intrinsics[..., :2, :2].float().inverse().to(intrinsics), + pixel_size, + "... i j, j -> ... i", + ) + return xy_multipliers.sum(dim=-1) + + @property + def d_sh(self) -> int: + return 1 if self.pred_color else (self.sh_degree + 1) ** 2 + + @property + def d_in(self) -> int: + # provided as reference to the gs_dpt output dim + raw_gs_dim = 0 + if self.pred_offset_xy: + raw_gs_dim += 2 + raw_gs_dim += 3 # scales + raw_gs_dim += 4 # quaternion + raw_gs_dim += 3 * self.d_sh # color + if self.pred_offset_depth: + raw_gs_dim += 1 + + return raw_gs_dim diff --git a/core/models/depth_anything_3/model/gsdpt.py b/core/models/depth_anything_3/model/gsdpt.py new file mode 100644 index 0000000..aaec112 --- /dev/null +++ b/core/models/depth_anything_3/model/gsdpt.py @@ -0,0 +1,133 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
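+
+# Editor's note: a minimal, hedged sanity check (not part of the upstream DA3
+# API) tying this head to the GaussianAdapter channel budget in gs_adapter.py.
+# The numbers assume the adapter defaults (sh_degree=0, pred_color=False,
+# pred_offset_xy=True, pred_offset_depth=False):
+#
+#   from core.models.depth_anything_3.model.gs_adapter import GaussianAdapter
+#   adapter = GaussianAdapter()
+#   assert adapter.d_sh == 1                   # DC band only
+#   assert adapter.d_in == 2 + 3 + 4 + 3 * 1   # offset_xy + scales + quat + SH = 12
+#   # GSDPT's output_dim should then cover adapter.d_in plus conf channel(s).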
+ +from typing import Dict as TyDict +from typing import List, Sequence +import torch +import torch.nn as nn + +from core.models.depth_anything_3.model.dpt import DPT +from core.models.depth_anything_3.model.utils.head_utils import activate_head_gs, custom_interpolate + + +class GSDPT(DPT): + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "linear", + conf_activation: str = "sigmoid", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + conf_dim: int = 1, + norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer" + fusion_block_inplace: bool = False, + ) -> None: + super().__init__( + dim_in=dim_in, + patch_size=patch_size, + output_dim=output_dim, + activation=activation, + conf_activation=conf_activation, + features=features, + out_channels=out_channels, + pos_embed=pos_embed, + down_ratio=down_ratio, + head_name="raw_gs", + use_sky_head=False, + norm_type=norm_type, + fusion_block_inplace=fusion_block_inplace, + ) + self.conf_dim = conf_dim + if conf_dim and conf_dim > 1: + assert ( + conf_activation == "linear" + ), "use linear prediction when using view-dependent opacity" + + merger_out_dim = features if feature_only else features // 2 + self.images_merger = nn.Sequential( + nn.Conv2d(3, merger_out_dim // 4, 3, 1, 1), # fewer channels first + nn.GELU(), + nn.Conv2d(merger_out_dim // 4, merger_out_dim // 2, 3, 1, 1), + nn.GELU(), + nn.Conv2d(merger_out_dim // 2, merger_out_dim, 3, 1, 1), + nn.GELU(), + ) + + # ------------------------------------------------------------------------- + # Internal forward (single chunk) + # ------------------------------------------------------------------------- + def _forward_impl( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + images: torch.Tensor, + ) -> TyDict[str, torch.Tensor]: + B, _, C = feats[0].shape + ph, pw = H // self.patch_size, W // self.patch_size + resized_feats = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C] + x = self.norm(x) + x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw] + + x = self.projects[stage_idx](x) + if self.pos_embed: + x = self._add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) # Align scale + resized_feats.append(x) + + # 2) Fusion pyramid (main branch only) + fused = self._fuse(resized_feats) + fused = self.scratch.output_conv1(fused) + + # 3) Upsample to target resolution, optionally add position encoding again + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True) + + # inject the image information here + fused = fused + self.images_merger(images) + + if self.pos_embed: + fused = self._add_pos_embed(fused, W, H) + + # 4) Shared neck1 + # feat = self.scratch.output_conv1(fused) + feat = fused + + # 5) Main head: logits -> activate_head or single channel activation + main_logits = self.scratch.output_conv2(feat) + outs: TyDict[str, torch.Tensor] = {} + if self.has_conf: + pred, conf = activate_head_gs( + main_logits, + activation=self.activation, + conf_activation=self.conf_activation, + conf_dim=self.conf_dim, + ) + outs[self.head_main] = pred.squeeze(1) + outs[f"{self.head_main}_conf"] = conf.squeeze(1) + else: + outs[self.head_main] = 
self._apply_activation_single(main_logits).squeeze(1) + + return outs diff --git a/core/models/depth_anything_3/model/reference_view_selector.py b/core/models/depth_anything_3/model/reference_view_selector.py new file mode 100644 index 0000000..f406f5f --- /dev/null +++ b/core/models/depth_anything_3/model/reference_view_selector.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Reference View Selection Strategies + +This module provides different strategies for selecting a reference view +from multiple input views in multi-view depth estimation. +""" + +import torch +from typing import Literal + + +RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"] + + +def select_reference_view( + x: torch.Tensor, + strategy: RefViewStrategy = "saddle_balanced", +) -> torch.Tensor: + """ + Select a reference view from multiple views using the specified strategy. + + Args: + x: Input tensor of shape (B, S, N, C) where + B = batch size + S = number of views + N = number of tokens + C = channel dimension + strategy: Selection strategy, one of: + - "first": Always select the first view + - "middle": Select the middle view + - "saddle_balanced": Select view with balanced features across multiple metrics + - "saddle_sim_range": Select view with largest similarity range + + Returns: + b_idx: Tensor of shape (B,) containing the selected view index for each batch + """ + B, S, N, C = x.shape + + # For single view, no reordering needed + if S <= 1: + return torch.zeros(B, dtype=torch.long, device=x.device) + + # Simple position-based strategies + if strategy == "first": + return torch.zeros(B, dtype=torch.long, device=x.device) + + elif strategy == "middle": + return torch.full((B,), S // 2, dtype=torch.long, device=x.device) + + # Feature-based strategies require normalized class tokens + # Extract and normalize class tokens (first token of each view) + img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # B S C + + if strategy == "saddle_balanced": + # Select view with balanced features across multiple metrics + # Compute similarity matrix + sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # B S S + sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0) + sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # B S + + feat_norm = x[:, :, 0].norm(dim=-1) # B S + feat_var = img_class_feat.var(dim=-1) # B S + + # Normalize all metrics to [0, 1] + def normalize_metric(metric): + min_val = metric.min(dim=1, keepdim=True).values + max_val = metric.max(dim=1, keepdim=True).values + return (metric - min_val) / (max_val - min_val + 1e-8) + + sim_score_norm = normalize_metric(sim_score) + norm_norm = normalize_metric(feat_norm) + var_norm = normalize_metric(feat_var) + + # Select view closest to the median (0.5) across all metrics + balance_score = ( + (sim_score_norm - 0.5).abs() + + (norm_norm - 0.5).abs() + + (var_norm - 0.5).abs() + ) + b_idx = 
balance_score.argmin(dim=1) + + elif strategy == "saddle_sim_range": + # Select view with largest similarity range (max - min) + sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # B S S + sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0) + + sim_max = sim_no_diag.max(dim=-1).values # B S + sim_min = sim_no_diag.min(dim=-1).values # B S + sim_range = sim_max - sim_min + b_idx = sim_range.argmax(dim=1) + + else: + raise ValueError( + f"Unknown reference view selection strategy: {strategy}. " + f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'" + ) + + return b_idx + + +def reorder_by_reference( + x: torch.Tensor, + b_idx: torch.Tensor, +) -> torch.Tensor: + """ + Reorder views to place the selected reference view first. + + Args: + x: Input tensor of shape (B, S, N, C) + b_idx: Reference view indices of shape (B,) + + Returns: + Reordered tensor with reference view at position 0 + + Example: + If b_idx = [2] and S = 5 (views [0,1,2,3,4]), + result order is [2,0,1,3,4] (ref_idx first, then others in order) + """ + B, S = x.shape[0], x.shape[1] + + # For single view, no reordering needed + if S <= 1: + return x + + # Create position indices: (B, S) where each row is [0, 1, 2, ..., S-1] + positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) # B S + + # For each position, determine which original index it should take + # Position 0 gets ref_idx + # Position 1 to ref_idx gets indices 0 to ref_idx-1 + # Position ref_idx+1 to S-1 gets indices ref_idx+1 to S-1 + + b_idx_expanded = b_idx.unsqueeze(1) # B 1 + + # Create the reordering indices + # For positions 1 to ref_idx: map to indices 0 to ref_idx-1 (shift by -1) + # For positions > ref_idx: keep the same + reorder_indices = positions.clone() + reorder_indices = torch.where( + (positions > 0) & (positions <= b_idx_expanded), + positions - 1, + positions + ) + # Set position 0 to ref_idx + reorder_indices[:, 0] = b_idx + + # Gather using advanced indexing + batch_indices = torch.arange(B, device=x.device).unsqueeze(1) # B 1 + x_reordered = x[batch_indices, reorder_indices] + + return x_reordered + + +def restore_original_order( + x: torch.Tensor, + b_idx: torch.Tensor, +) -> torch.Tensor: + """ + Restore original view order after processing. + + Args: + x: Reordered tensor of shape (B, S, ...) + b_idx: Original reference view indices of shape (B,) + + Returns: + Tensor with original view order restored + + Example: + If original order was [0, 1, 2, 3, 4] and b_idx=2, + reordered becomes [2, 0, 1, 3, 4] (reference at position 0), + restore should return [0, 1, 2, 3, 4] (original order). 
+ """ + B, S = x.shape[0], x.shape[1] + + # For single view, no restoration needed + if S <= 1: + return x + + # Create target position indices: (B, S) where each row is [0, 1, 2, ..., S-1] + target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) # B S + + # For each target position, determine which current position it comes from + # Target position 0 to ref_idx-1 <- Current position 1 to ref_idx (shift by +1) + # Target position ref_idx <- Current position 0 + # Target position ref_idx+1 to S-1 <- Current position ref_idx+1 to S-1 (no change) + + b_idx_expanded = b_idx.unsqueeze(1) # B 1 + + # Create the restore indices + restore_indices = torch.where( + target_positions < b_idx_expanded, + target_positions + 1, # Positions before ref_idx come from current position + 1 + target_positions # Positions after ref_idx stay the same + ) + # Target position = ref_idx comes from current position 0 + # Use scatter to set specific positions + restore_indices = torch.scatter( + restore_indices, + dim=1, + index=b_idx_expanded, + src=torch.zeros_like(b_idx_expanded) + ) + + # Gather using advanced indexing + batch_indices = torch.arange(B, device=x.device).unsqueeze(1) # B 1 + x_restored = x[batch_indices, restore_indices] + + return x_restored + diff --git a/core/models/depth_anything_3/model/utils/__pycache__/attention.cpython-313.pyc b/core/models/depth_anything_3/model/utils/__pycache__/attention.cpython-313.pyc new file mode 100644 index 0000000..8d1c087 Binary files /dev/null and b/core/models/depth_anything_3/model/utils/__pycache__/attention.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/utils/__pycache__/block.cpython-313.pyc b/core/models/depth_anything_3/model/utils/__pycache__/block.cpython-313.pyc new file mode 100644 index 0000000..c3b6160 Binary files /dev/null and b/core/models/depth_anything_3/model/utils/__pycache__/block.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/utils/__pycache__/gs_renderer.cpython-313.pyc b/core/models/depth_anything_3/model/utils/__pycache__/gs_renderer.cpython-313.pyc new file mode 100644 index 0000000..5d63c0a Binary files /dev/null and b/core/models/depth_anything_3/model/utils/__pycache__/gs_renderer.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/utils/__pycache__/head_utils.cpython-313.pyc b/core/models/depth_anything_3/model/utils/__pycache__/head_utils.cpython-313.pyc new file mode 100644 index 0000000..f4f959c Binary files /dev/null and b/core/models/depth_anything_3/model/utils/__pycache__/head_utils.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/utils/__pycache__/transform.cpython-313.pyc b/core/models/depth_anything_3/model/utils/__pycache__/transform.cpython-313.pyc new file mode 100644 index 0000000..2439e2e Binary files /dev/null and b/core/models/depth_anything_3/model/utils/__pycache__/transform.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/model/utils/attention.py b/core/models/depth_anything_3/model/utils/attention.py new file mode 100644 index 0000000..49c07a8 --- /dev/null +++ b/core/models/depth_anything_3/model/utils/attention.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa + +from typing import Callable, Optional, Union +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + # Debug breakpoint removed for production + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + q = self.rope(q, pos) if self.rope is not None else q + k = self.rope(k, pos) if self.rope is not None else k + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=attn_mask, + ) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/core/models/depth_anything_3/model/utils/block.py b/core/models/depth_anything_3/model/utils/block.py new file mode 100644 index 0000000..993fb4c --- /dev/null +++ b/core/models/depth_anything_3/model/utils/block.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 ByteDance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable +from torch import Tensor, nn + +from .attention import Attention, LayerScale, Mlp + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + + self.sample_drop_ratio = 0.0 # Equivalent to always having drop_path=0 + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + # drop_path is always 0, so always take the else branch + x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask) + x = x + ffn_residual_func(x) + return x diff --git a/core/models/depth_anything_3/model/utils/gs_renderer.py b/core/models/depth_anything_3/model/utils/gs_renderer.py new file mode 100644 index 0000000..2929076 --- /dev/null +++ b/core/models/depth_anything_3/model/utils/gs_renderer.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
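+
+# Editor's note: a hedged usage sketch of the chunked renderer defined below
+# (shapes follow the annotations in this file; `gaussians` is assumed to come
+# from the GS head + GaussianAdapter, and `gsplat` must be installed):
+#
+#   colors, depths = run_renderer_in_chunk_w_trj_mode(
+#       gaussians,                  # Gaussians, world space
+#       extrinsics,                 # (B, V, 4, 4) world2cam
+#       intrinsics,                 # (B, V, 3, 3) unnormalized, in pixels
+#       image_shape=(H, W),
+#       chunk_size=8,
+#       trj_mode="smooth",
+#   )
+#   # colors: (B, V', 3, H, W); depths: (B, V', H, W). V' depends on trj_mode.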
+ +import math +from math import isqrt +from typing import Literal, Optional +import torch +from einops import rearrange, repeat +from tqdm import tqdm + +from core.models.depth_anything_3.specs import Gaussians +from core.models.depth_anything_3.utils.camera_trj_helpers import ( + interpolate_extrinsics, + interpolate_intrinsics, + render_dolly_zoom_path, + render_stabilization_path, + render_wander_path, + render_wobble_inter_path, +) +from core.models.depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, get_fov +from core.models.depth_anything_3.utils.logger import logger + +try: + from gsplat import rasterization +except ImportError: + logger.warn( + "Dependency `gsplat` is required for rendering 3DGS. " + "Install via: pip install git+https://github.com/nerfstudio-project/" + "gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70" + ) + + +def render_3dgs( + extrinsics: torch.Tensor, # "batch_views 4 4", w2c + intrinsics: torch.Tensor, # "batch_views 3 3", normalized + image_shape: tuple[int, int], + gaussian: Gaussians, + background_color: Optional[torch.Tensor] = None, # "batch_views 3" + use_sh: bool = True, + num_view: int = 1, + color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+D", + **kwargs, +) -> tuple[ + torch.Tensor, # "batch_views 3 height width" + torch.Tensor, # "batch_views height width" +]: + # extract gaussian params + gaussian_means = gaussian.means + gaussian_scales = gaussian.scales + gaussian_quats = gaussian.rotations + gaussian_opacities = gaussian.opacities + gaussian_sh_coefficients = gaussian.harmonics + b, _, _ = extrinsics.shape + + if background_color is None: + background_color = repeat(torch.tensor([0.0, 0.0, 0.0]), "c -> b c", b=b).to( + gaussian_sh_coefficients + ) + + if use_sh: + _, _, _, n = gaussian_sh_coefficients.shape + degree = isqrt(n) - 1 + shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() + else: # use color + shs = ( + gaussian_sh_coefficients.squeeze(-1).sigmoid().contiguous() + ) # (b, g, c), normed to (0, 1) + + h, w = image_shape + + fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1) + tan_fov_x = (0.5 * fov_x).tan() + tan_fov_y = (0.5 * fov_y).tan() + focal_length_x = w / (2 * tan_fov_x) + focal_length_y = h / (2 * tan_fov_y) + + view_matrix = extrinsics.float() + + all_images = [] + all_radii = [] + all_depths = [] + # render view in a batch based, each batch contains one scene + # assume the Gaussian parameters are originally repeated along the view dim + batch_scene = b // num_view + + def index_i_gs_attr(full_attr, idx): + # return rearrange(full_attr, "(b v) ... -> b v ...", v=num_view)[idx, 0] + return full_attr[idx] + + for i in range(batch_scene): + K = repeat( + torch.tensor( + [ + [0, 0, w / 2.0], + [0, 0, h / 2.0], + [0, 0, 1], + ] + ), + "i j -> v i j", + v=num_view, + ).to(gaussian_means) + K[:, 0, 0] = focal_length_x.reshape(batch_scene, num_view)[i] + K[:, 1, 1] = focal_length_y.reshape(batch_scene, num_view)[i] + + i_means = index_i_gs_attr(gaussian_means, i) # [N, 3] + i_scales = index_i_gs_attr(gaussian_scales, i) + i_quats = index_i_gs_attr(gaussian_quats, i) + i_opacities = index_i_gs_attr(gaussian_opacities, i) # [N,] + i_colors = index_i_gs_attr(shs, i) # [N, K, 3] + i_viewmats = rearrange(view_matrix, "(b v) ... -> b v ...", v=num_view)[i] # [v, 4, 4] + i_backgrounds = rearrange(background_color, "(b v) ... 
-> b v ...", v=num_view)[ + i + ] # [v, 3] + + render_colors, render_alphas, info = rasterization( + means=i_means, + quats=i_quats, # [N, 4] + scales=i_scales, # [N, 3] + opacities=i_opacities, + colors=i_colors, + viewmats=i_viewmats, # [v, 4, 4] + Ks=K, # [v, 3, 3] + backgrounds=i_backgrounds, + render_mode=color_mode, + width=w, + height=h, + packed=False, + sh_degree=degree if use_sh else None, + ) + depth = render_colors[..., -1].unbind(dim=0) + + image = rearrange(render_colors[..., :3], "v h w c -> v c h w").unbind(dim=0) + radii = info["radii"].unbind(dim=0) + try: + info["means2d"].retain_grad() # [1, N, 2] + except Exception: + pass + all_images.extend(image) + all_depths.extend(depth) + all_radii.extend(radii) + + return torch.stack(all_images), torch.stack(all_depths) + + +def run_renderer_in_chunk_w_trj_mode( + gaussians: Gaussians, + extrinsics: torch.Tensor, # world2cam, "batch view 4 4" | "batch view 3 4" + intrinsics: torch.Tensor, # unnormed intrinsics, "batch view 3 3" + image_shape: tuple[int, int], + chunk_size: Optional[int] = 8, + trj_mode: Literal[ + "original", + "smooth", + "interpolate", + "interpolate_smooth", + "wander", + "dolly_zoom", + "extend", + "wobble_inter", + ] = "smooth", + input_shape: Optional[tuple[int, int]] = None, + enable_tqdm: Optional[bool] = False, + **kwargs, +) -> tuple[ + torch.Tensor, # color, "batch view 3 height width" + torch.Tensor, # depth, "batch view height width" +]: + cam2world = affine_inverse(as_homogeneous(extrinsics)) + if input_shape is not None: + in_h, in_w = input_shape + else: + in_h, in_w = image_shape + intr_normed = intrinsics.clone().detach() + intr_normed[..., 0, :] /= in_w + intr_normed[..., 1, :] /= in_h + if extrinsics.shape[1] <= 1: + assert trj_mode in [ + "wander", + "dolly_zoom", + ], "Please set trj_mode to 'wander' or 'dolly_zoom' when n_views=1" + + def _smooth_trj_fn_batch(raw_c2ws, k_size=50): + try: + smooth_c2ws = torch.stack( + [render_stabilization_path(c2w_i, k_size) for c2w_i in raw_c2ws], + dim=0, + ) + except Exception as e: + print(f"[DEBUG] Path smoothing failed with error: {e}.") + smooth_c2ws = raw_c2ws + return smooth_c2ws + + # get rendered trj + if trj_mode == "original": + tgt_c2w = cam2world + tgt_intr = intr_normed + elif trj_mode == "smooth": + tgt_c2w = _smooth_trj_fn_batch(cam2world) + tgt_intr = intr_normed + elif trj_mode in ["interpolate", "interpolate_smooth", "extend"]: + inter_len = 8 + total_len = (cam2world.shape[1] - 1) * inter_len + if total_len > 24 * 18: # no more than 18s + inter_len = max(1, 24 * 10 // (cam2world.shape[1] - 1)) + if total_len < 24 * 2: # no less than 2s + inter_len = max(1, 24 * 2 // (cam2world.shape[1] - 1)) + + if inter_len > 2: + t = torch.linspace(0, 1, inter_len, dtype=torch.float32, device=cam2world.device) + t = (torch.cos(torch.pi * (t + 1)) + 1) / 2 + tgt_c2w_b = [] + tgt_intr_b = [] + for b_idx in range(cam2world.shape[0]): + tgt_c2w = [] + tgt_intr = [] + for cur_idx in range(cam2world.shape[1] - 1): + tgt_c2w.append( + interpolate_extrinsics( + cam2world[b_idx, cur_idx], cam2world[b_idx, cur_idx + 1], t + )[(0 if cur_idx == 0 else 1) :] + ) + tgt_intr.append( + interpolate_intrinsics( + intr_normed[b_idx, cur_idx], intr_normed[b_idx, cur_idx + 1], t + )[(0 if cur_idx == 0 else 1) :] + ) + tgt_c2w_b.append(torch.cat(tgt_c2w)) + tgt_intr_b.append(torch.cat(tgt_intr)) + tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4 + tgt_intr = torch.stack(tgt_intr_b) # b v 3 3 + else: + tgt_c2w = cam2world + tgt_intr = intr_normed + if trj_mode in 
["interpolate_smooth", "extend"]: + tgt_c2w = _smooth_trj_fn_batch(tgt_c2w) + if trj_mode == "extend": + # apply dolly_zoom and wander in the middle frame + assert cam2world.shape[0] == 1, "extend only supports for batch_size=1 currently." + mid_idx = tgt_c2w.shape[1] // 2 + c2w_wd, intr_wd = render_wander_path( + tgt_c2w[0, mid_idx], + tgt_intr[0, mid_idx], + h=in_h, + w=in_w, + num_frames=max(36, min(60, mid_idx // 2)), + max_disp=24.0, + ) + c2w_dz, intr_dz = render_dolly_zoom_path( + tgt_c2w[0, mid_idx], + tgt_intr[0, mid_idx], + h=in_h, + w=in_w, + num_frames=max(36, min(60, mid_idx // 2)), + ) + tgt_c2w = torch.cat( + [ + tgt_c2w[:, :mid_idx], + c2w_wd.unsqueeze(0), + c2w_dz.unsqueeze(0), + tgt_c2w[:, mid_idx:], + ], + dim=1, + ) + tgt_intr = torch.cat( + [ + tgt_intr[:, :mid_idx], + intr_wd.unsqueeze(0), + intr_dz.unsqueeze(0), + tgt_intr[:, mid_idx:], + ], + dim=1, + ) + elif trj_mode in ["wander", "dolly_zoom"]: + if trj_mode == "wander": + render_fn = render_wander_path + extra_kwargs = {"max_disp": 24.0} + else: + render_fn = render_dolly_zoom_path + extra_kwargs = {"D_focus": 30.0, "max_disp": 2.0} + tgt_c2w = [] + tgt_intr = [] + for b_idx in range(cam2world.shape[0]): + c2w_i, intr_i = render_fn( + cam2world[b_idx, 0], intr_normed[b_idx, 0], h=in_h, w=in_w, **extra_kwargs + ) + tgt_c2w.append(c2w_i) + tgt_intr.append(intr_i) + tgt_c2w = torch.stack(tgt_c2w) + tgt_intr = torch.stack(tgt_intr) + elif trj_mode == "wobble_inter": + tgt_c2w, tgt_intr = render_wobble_inter_path( + cam2world=cam2world, + intr_normed=intr_normed, + inter_len=10, + n_skip=3, + ) + else: + raise Exception(f"trj mode [{trj_mode}] is not implemented.") + + _, v = tgt_c2w.shape[:2] + tgt_extr = affine_inverse(tgt_c2w) + if chunk_size is None: + chunk_size = v + chunk_size = min(v, chunk_size) + all_colors = [] + all_depths = [] + for chunk_idx in tqdm( + range(math.ceil(v / chunk_size)), + desc="Rendering novel views", + disable=(not enable_tqdm), + leave=False, + ): + s = int(chunk_idx * chunk_size) + e = int((chunk_idx + 1) * chunk_size) + cur_n_view = tgt_extr[:, s:e].shape[1] + color, depth = render_3dgs( + extrinsics=rearrange(tgt_extr[:, s:e], "b v ... -> (b v) ..."), # w2c + intrinsics=rearrange(tgt_intr[:, s:e], "b v ... -> (b v) ..."), # normed + image_shape=image_shape, + gaussian=gaussians, + num_view=cur_n_view, + **kwargs, + ) + all_colors.append(rearrange(color, "(b v) ... -> b v ...", v=cur_n_view)) + all_depths.append(rearrange(depth, "(b v) ... -> b v ...", v=cur_n_view)) + all_colors = torch.cat(all_colors, dim=1) + all_depths = torch.cat(all_depths, dim=1) + + return all_colors, all_depths diff --git a/core/models/depth_anything_3/model/utils/head_utils.py b/core/models/depth_anything_3/model/utils/head_utils.py new file mode 100644 index 0000000..c120958 --- /dev/null +++ b/core/models/depth_anything_3/model/utils/head_utils.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Tuple, Union
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# -----------------------------------------------------------------------------
+# Activation functions
+# -----------------------------------------------------------------------------
+
+
+def activate_head_gs(out, activation="norm_exp", conf_activation="expp1", conf_dim=None):
+    """
+    Process network output to extract GS params and density values.
+    Density can be view-dependent, predicted as SH coefficients.
+
+    Args:
+        out: Network output tensor (B, C, H, W)
+        activation: Activation type for the 3D point / GS value channels
+        conf_activation: Activation type for confidence values
+        conf_dim: Number of confidence channels (defaults to 1)
+
+    Returns:
+        Tuple of (3D points tensor, confidence tensor)
+    """
+    # Move channels to the last dimension => (B, H, W, C)
+    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C expected
+
+    # Split into values (first C - conf_dim channels) and confidence (last conf_dim channels)
+    conf_dim = 1 if conf_dim is None else conf_dim
+    xyz = fmap[:, :, :, :-conf_dim]
+    conf = fmap[:, :, :, -1] if conf_dim == 1 else fmap[:, :, :, -conf_dim:]
+
+    if activation == "norm_exp":
+        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+        xyz_normed = xyz / d
+        pts3d = xyz_normed * torch.expm1(d)
+    elif activation == "norm":
+        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+    elif activation == "exp":
+        pts3d = torch.exp(xyz)
+    elif activation == "relu":
+        pts3d = F.relu(xyz)
+    elif activation == "sigmoid":
+        pts3d = torch.sigmoid(xyz)
+    elif activation == "linear":
+        pts3d = xyz
+    else:
+        raise ValueError(f"Unknown activation: {activation}")
+
+    if conf_activation == "expp1":
+        conf_out = 1 + conf.exp()
+    elif conf_activation == "expp0":
+        conf_out = conf.exp()
+    elif conf_activation == "sigmoid":
+        conf_out = torch.sigmoid(conf)
+    elif conf_activation == "linear":
+        conf_out = conf
+    else:
+        raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+    return pts3d, conf_out
+
+
+# -----------------------------------------------------------------------------
+# Other utilities
+# -----------------------------------------------------------------------------
+
+
+class Permute(nn.Module):
+    """nn.Module wrapper around Tensor.permute for cleaner nn.Sequential usage."""
+
+    dims: Tuple[int, ...]
+ + def __init__(self, dims: Tuple[int, ...]) -> None: + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] + return x.permute(*self.dims) + + +def position_grid_to_embed( + pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100 +) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. # noqa + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, + height: int, + aspect_ratio: float = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. 
+ """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid + + +# ----------------------------------------------------------------------------- +# Interpolation (safe interpolation, avoid INT_MAX overflow) +# ----------------------------------------------------------------------------- +def custom_interpolate( + x: torch.Tensor, + size: Union[Tuple[int, int], None] = None, + scale_factor: Union[float, None] = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Safe interpolation implementation to avoid INT_MAX overflow in torch.nn.functional.interpolate. + """ + if size is None: + assert scale_factor is not None, "Either size or scale_factor must be provided." + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + total = size[0] * size[1] * x.shape[0] * x.shape[1] + + if total > INT_MAX: + chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0) + outs = [ + nn.functional.interpolate(c, size=size, mode=mode, align_corners=align_corners) + for c in chunks + ] + return torch.cat(outs, dim=0).contiguous() + + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/core/models/depth_anything_3/model/utils/transform.py b/core/models/depth_anything_3/model/utils/transform.py new file mode 100644 index 0000000..8d732b0 --- /dev/null +++ b/core/models/depth_anything_3/model/utils/transform.py @@ -0,0 +1,208 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
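+
+# Editor's note: a hedged round-trip check for the quaternion helpers defined
+# below (xyzw, scalar-last convention, per the docstrings):
+#
+#   import torch
+#   R = torch.eye(3).expand(2, 3, 3)    # batch of identity rotations
+#   q = mat_to_quat(R)                  # (2, 4), xyzw with non-negative w
+#   assert torch.allclose(quat_to_mat(q), R, atol=1e-6)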
+ +import torch +import torch.nn.functional as F + + +def extri_intri_to_pose_encoding( + extrinsics, + intrinsics, + image_size_hw=None, +): + """Convert camera extrinsics and intrinsics to a compact pose encoding.""" + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, + image_size_hw=None, +): + """Convert a pose encoding back to camera extrinsics and intrinsics.""" + + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + H, W = image_size_hw + fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6) + fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + + return extrinsics, intrinsics + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). 
+ Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape( + batch_dim + (4,) + ) + + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w): + # cam_quat_xyzw: (b, n, 4) in xyzw + # c2w: (b, n, 4, 4) + b, n = cam_quat_xyzw.shape[:2] + # 1. xyzw -> wxyz + cam_quat_wxyz = torch.cat( + [ + cam_quat_xyzw[..., 3:4], # w + cam_quat_xyzw[..., 0:1], # x + cam_quat_xyzw[..., 1:2], # y + cam_quat_xyzw[..., 2:3], # z + ], + dim=-1, + ) + # 2. Quaternion to matrix + cam_quat_wxyz_flat = cam_quat_wxyz.reshape(-1, 4) + rotmat_cam = quat_to_mat(cam_quat_wxyz_flat).reshape(b, n, 3, 3) + # 3. Transform to world space + rotmat_c2w = c2w[..., :3, :3] + rotmat_world = torch.matmul(rotmat_c2w, rotmat_cam) + # 4. Matrix to quaternion (wxyz) + rotmat_world_flat = rotmat_world.reshape(-1, 3, 3) + world_quat_wxyz_flat = mat_to_quat(rotmat_world_flat) + world_quat_wxyz = world_quat_wxyz_flat.reshape(b, n, 4) + return world_quat_wxyz diff --git a/core/models/depth_anything_3/registry.py b/core/models/depth_anything_3/registry.py new file mode 100644 index 0000000..bbc26f0 --- /dev/null +++ b/core/models/depth_anything_3/registry.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from pathlib import Path + + +def get_all_models() -> OrderedDict: + """ + Scans all YAML files in the configs directory and returns a sorted dictionary where: + - Keys are model names (YAML filenames without the .yaml extension) + - Values are absolute paths to the corresponding YAML files + """ + # Get path to the configs directory within the da3 package + # Works both in development and after pip installation + # configs_dir = files("depth_anything_3").joinpath("configs") + configs_dir = Path(__file__).resolve().parent / "configs" + + # Ensure path is a Path object for consistent cross-platform handling + configs_dir = Path(configs_dir) + + model_entries = [] + # Iterate through all items in the configs directory + for item in configs_dir.iterdir(): + # Filter for YAML files (excluding directories) + if item.is_file() and item.suffix == ".yaml": + # Extract model name (filename without .yaml extension) + model_name = item.stem + # Get absolute path (resolve() handles symlinks) + file_abs_path = str(item.resolve()) + model_entries.append((model_name, file_abs_path)) + + # Sort entries by model name and convert to OrderedDict + sorted_entries = sorted(model_entries, key=lambda x: x[0]) + return OrderedDict(sorted_entries) + + +# Global registry for external imports +MODEL_REGISTRY = get_all_models() \ No newline at end of file diff --git a/core/models/depth_anything_3/specs.py b/core/models/depth_anything_3/specs.py new file mode 100644 index 0000000..d5ef333 --- /dev/null +++ b/core/models/depth_anything_3/specs.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
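A minimal usage sketch for the registry above. The available keys depend entirely on which YAML files ship in `configs/`, and because `MODEL_REGISTRY` is built once at import time, configs added afterwards only appear after a re-import:

```python
from core.models.depth_anything_3.registry import MODEL_REGISTRY

# Keys are YAML filename stems, values are absolute config paths
for name, cfg_path in MODEL_REGISTRY.items():
    print(f"{name} -> {cfg_path}")
```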
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional +import numpy as np +import torch + + +@dataclass +class Gaussians: + """3DGS parameters, all in world space""" + + means: torch.Tensor # world points, "batch gaussian dim" + scales: torch.Tensor # scales_std, "batch gaussian 3" + rotations: torch.Tensor # world_quat_wxyz, "batch gaussian 4" + harmonics: torch.Tensor # world SH, "batch gaussian 3 d_sh" + opacities: torch.Tensor # opacity | opacity SH, "batch gaussian" | "batch gaussian 1 d_sh" + + +@dataclass +class Prediction: + depth: np.ndarray # N, H, W + is_metric: int + sky: np.ndarray | None = None # N, H, W + conf: np.ndarray | None = None # N, H, W + extrinsics: np.ndarray | None = None # N, 4, 4 + intrinsics: np.ndarray | None = None # N, 3, 3 + processed_images: np.ndarray | None = None # N, H, W, 3 - processed images for visualization + gaussians: Gaussians | None = None # 3D gaussians + aux: dict[str, Any] = None # + scale_factor: Optional[float] = None # metric scale \ No newline at end of file diff --git a/core/models/depth_anything_3/utils/__pycache__/alignment.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/alignment.cpython-313.pyc new file mode 100644 index 0000000..6aa496d Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/alignment.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/camera_trj_helpers.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/camera_trj_helpers.cpython-313.pyc new file mode 100644 index 0000000..18fa455 Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/camera_trj_helpers.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/constants.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/constants.cpython-313.pyc new file mode 100644 index 0000000..ff48bef Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/constants.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/geometry.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/geometry.cpython-313.pyc new file mode 100644 index 0000000..135431d Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/geometry.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/gsply_helpers.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/gsply_helpers.cpython-313.pyc new file mode 100644 index 0000000..4f1aa85 Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/gsply_helpers.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/layout_helpers.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/layout_helpers.cpython-313.pyc new file mode 100644 index 0000000..154aa12 Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/layout_helpers.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/logger.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/logger.cpython-313.pyc new file mode 100644 index 0000000..ffdce54 Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/logger.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/parallel_utils.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/parallel_utils.cpython-313.pyc new file mode 100644 index 0000000..f53c8a9 Binary files 
/dev/null and b/core/models/depth_anything_3/utils/__pycache__/parallel_utils.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/pca_utils.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/pca_utils.cpython-313.pyc new file mode 100644 index 0000000..2453a4c Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/pca_utils.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/pose_align.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/pose_align.cpython-313.pyc new file mode 100644 index 0000000..fa5b120 Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/pose_align.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/ray_utils.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/ray_utils.cpython-313.pyc new file mode 100644 index 0000000..6e463f2 Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/ray_utils.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/sh_helpers.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/sh_helpers.cpython-313.pyc new file mode 100644 index 0000000..fbf36ef Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/sh_helpers.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/__pycache__/visualize.cpython-313.pyc b/core/models/depth_anything_3/utils/__pycache__/visualize.cpython-313.pyc new file mode 100644 index 0000000..af89897 Binary files /dev/null and b/core/models/depth_anything_3/utils/__pycache__/visualize.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/alignment.py b/core/models/depth_anything_3/utils/alignment.py new file mode 100644 index 0000000..ceb8983 --- /dev/null +++ b/core/models/depth_anything_3/utils/alignment.py @@ -0,0 +1,163 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Alignment utilities for depth estimation and metric scaling. +""" + +from typing import Tuple +import torch + + +def least_squares_scale_scalar( + a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12 +) -> torch.Tensor: + """ + Compute least squares scale factor s such that a ≈ s * b. 
+ + Args: + a: First tensor + b: Second tensor + eps: Small epsilon for numerical stability + + Returns: + Scalar tensor containing the scale factor + + Raises: + ValueError: If tensors have mismatched shapes or devices + TypeError: If tensors are not floating point + """ + if a.shape != b.shape: + raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}") + if a.device != b.device: + raise ValueError(f"Device mismatch: {a.device} vs {b.device}") + if not a.is_floating_point() or not b.is_floating_point(): + raise TypeError("Tensors must be floating point type") + + # Compute dot products for least squares solution + num = torch.dot(a.reshape(-1), b.reshape(-1)) + den = torch.dot(b.reshape(-1), b.reshape(-1)).clamp_min(eps) + return num / den + + +def compute_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor: + """ + Compute non-sky mask from sky prediction. + + Args: + sky_prediction: Sky prediction tensor + threshold: Threshold for sky classification + + Returns: + Boolean mask where True indicates non-sky regions + """ + return sky_prediction < threshold + + +def compute_alignment_mask( + depth_conf: torch.Tensor, + non_sky_mask: torch.Tensor, + depth: torch.Tensor, + metric_depth: torch.Tensor, + median_conf: torch.Tensor, + min_depth_threshold: float = 1e-3, + min_metric_depth_threshold: float = 1e-2, +) -> torch.Tensor: + """ + Compute mask for depth alignment based on confidence and depth thresholds. + + Args: + depth_conf: Depth confidence tensor + non_sky_mask: Non-sky region mask + depth: Predicted depth tensor + metric_depth: Metric depth tensor + median_conf: Median confidence threshold + min_depth_threshold: Minimum depth threshold + min_metric_depth_threshold: Minimum metric depth threshold + + Returns: + Boolean mask for valid alignment regions + """ + return ( + (depth_conf >= median_conf) + & non_sky_mask + & (metric_depth > min_metric_depth_threshold) + & (depth > min_depth_threshold) + ) + + +def sample_tensor_for_quantile(tensor: torch.Tensor, max_samples: int = 100000) -> torch.Tensor: + """ + Sample tensor elements for quantile computation to reduce memory usage. + + Args: + tensor: Input tensor to sample + max_samples: Maximum number of samples to take + + Returns: + Sampled tensor + """ + if tensor.numel() <= max_samples: + return tensor + + idx = torch.randperm(tensor.numel(), device=tensor.device)[:max_samples] + return tensor.flatten()[idx] + + +def apply_metric_scaling( + depth: torch.Tensor, intrinsics: torch.Tensor, scale_factor: float = 300.0 +) -> torch.Tensor: + """ + Apply metric scaling to depth based on camera intrinsics. + + Args: + depth: Input depth tensor + intrinsics: Camera intrinsics tensor + scale_factor: Scaling factor for metric conversion + + Returns: + Scaled depth tensor + """ + focal_length = (intrinsics[:, :, 0, 0] + intrinsics[:, :, 1, 1]) / 2 + return depth * (focal_length[:, :, None, None] / scale_factor) + + +def set_sky_regions_to_max_depth( + depth: torch.Tensor, + depth_conf: torch.Tensor, + non_sky_mask: torch.Tensor, + max_depth: float = 200.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Set sky regions to maximum depth and high confidence. 
+ + Args: + depth: Depth tensor + depth_conf: Depth confidence tensor + non_sky_mask: Non-sky region mask + max_depth: Maximum depth value for sky regions + + Returns: + Tuple of (updated_depth, updated_depth_conf) + """ + depth = depth.clone() + + # Set sky regions to max depth and high confidence + depth[~non_sky_mask] = max_depth + if depth_conf is not None: + depth_conf = depth_conf.clone() + depth_conf[~non_sky_mask] = 1.0 + return depth, depth_conf + else: + return depth, None diff --git a/core/models/depth_anything_3/utils/api_helpers.py b/core/models/depth_anything_3/utils/api_helpers.py new file mode 100644 index 0000000..b327331 --- /dev/null +++ b/core/models/depth_anything_3/utils/api_helpers.py @@ -0,0 +1,58 @@ +import argparse + + +def parse_scalar(s): + if not isinstance(s, str): + return s + t = s.strip() + l = t.lower() + if l == "true": + return True + if l == "false": + return False + if l in ("none", "null"): + return None + try: + return int(t, 10) + except Exception: + pass + try: + return float(t) + except Exception: + return s + + +def fn_kv_csv(s: str) -> dict[str, dict[str, object]]: + """ + Parse a string of comma-separated triplets: fn:key:value + + Returns: + dict[fn_name] -> dict[key] = parsed_value + + Example: + "fn1:width:1920,fn1:height:1080,fn2:quality:0.8" + -> {"fn1": {"width": 1920, "height": 1080}, "fn2": {"quality": 0.8}} + """ + result: dict[str, dict[str, object]] = {} + if not s: + return result + + for item in s.split(","): + if not item: + continue + parts = item.split(":", 2) # allow value to contain ":" beyond first two separators + if len(parts) < 3: + raise argparse.ArgumentTypeError(f"Bad item '{item}', expected FN:KEY:VALUE") + fn, key, raw_val = parts[0], parts[1], parts[2] + # If you need to allow colons in values, join leftover parts: + # fn, key, raw_val = parts[0], parts[1], ":".join(parts[2:]) + + if not fn: + raise argparse.ArgumentTypeError(f"Bad item '{item}': empty function name") + if not key: + raise argparse.ArgumentTypeError(f"Bad item '{item}': empty key") + + val = parse_scalar(raw_val) + bucket = result.setdefault(fn, {}) + bucket[key] = val + return result diff --git a/core/models/depth_anything_3/utils/camera_trj_helpers.py b/core/models/depth_anything_3/utils/camera_trj_helpers.py new file mode 100644 index 0000000..8b30d5a --- /dev/null +++ b/core/models/depth_anything_3/utils/camera_trj_helpers.py @@ -0,0 +1,479 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from einops import einsum, rearrange, reduce + +try: + from scipy.spatial.transform import Rotation as R +except ImportError: + from core.models.depth_anything_3.utils.logger import logger + + logger.warn("Dependency 'scipy' not found. 
Required for interpolating camera trajectory.") + +from core.models.depth_anything_3.utils.geometry import as_homogeneous + + +@torch.no_grad() +def render_stabilization_path(poses, k_size=45): + """Rendering stabilized camera path. + poses: [batch, 4, 4] or [batch, 3, 4], + return: + smooth path: [batch 4 4]""" + num_frames = poses.shape[0] + device = poses.device + dtype = poses.dtype + + # Early exit for trivial cases + if num_frames <= 1: + return as_homogeneous(poses) + + # Make k_size safe: positive odd and not larger than num_frames + # 1) Ensure odd + if k_size < 1: + k_size = 1 + if k_size % 2 == 0: + k_size += 1 + # 2) Cap to num_frames (keep odd) + max_odd = num_frames if (num_frames % 2 == 1) else (num_frames - 1) + if max_odd < 1: + max_odd = 1 # covers num_frames == 0 theoretically + k_size = min(k_size, max_odd) + # 3) enforce a minimum of 3 when possible (for better smoothing) + if num_frames >= 3 and k_size < 3: + k_size = 3 + + input_poses = [] + for i in range(num_frames): + input_poses.append( + torch.cat([poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], dim=-1) + ) + input_poses = torch.stack(input_poses) # (num_frames, 3, 3) + + # Prepare Gaussian kernel + gaussian_kernel = cv2.getGaussianKernel(ksize=k_size, sigma=-1).astype(np.float32).squeeze() + gaussian_kernel = torch.tensor(gaussian_kernel, dtype=dtype, device=device).view(1, 1, -1) + pad = k_size // 2 + + output_vectors = [] + for idx in range(3): # For r1, r2, t + vec = ( + input_poses[:, :, idx].T.unsqueeze(0).unsqueeze(0) + ) # (1, 1, 3, num_frames) -> (1, 1, 3, num_frames) + # But actually, we want (batch=3, channel=1, width=num_frames) + # So: + vec = input_poses[:, :, idx].T.unsqueeze(1) # (3, 1, num_frames) + vec_padded = F.pad(vec, (pad, pad), mode="reflect") + filtered = F.conv1d(vec_padded, gaussian_kernel) + output_vectors.append(filtered.squeeze(1).T) # (num_frames, 3) + + output_r1, output_r2, output_t = output_vectors # Each is (num_frames, 3) + + # Normalize r1 and r2 + output_r1 = output_r1 / output_r1.norm(dim=-1, keepdim=True) + output_r2 = output_r2 / output_r2.norm(dim=-1, keepdim=True) + + output_poses = [] + for i in range(num_frames): + output_r3 = torch.linalg.cross(output_r1[i], output_r2[i]) + render_pose = torch.cat( + [ + output_r1[i].unsqueeze(-1), + output_r2[i].unsqueeze(-1), + output_r3.unsqueeze(-1), + output_t[i].unsqueeze(-1), + ], + dim=-1, + ) + output_poses.append(render_pose[:3, :]) + output_poses = as_homogeneous(torch.stack(output_poses, dim=0)) + + return output_poses + + +@torch.no_grad() +def render_wander_path( + cam2world: torch.Tensor, + intrinsic: torch.Tensor, + h: int, + w: int, + num_frames: int = 120, + max_disp: float = 48.0, +): + device, dtype = cam2world.device, cam2world.dtype + fx = intrinsic[0, 0] * w + r = max_disp / fx + th = torch.linspace(0, 2.0 * torch.pi, steps=num_frames, device=device, dtype=dtype) + x = r * torch.sin(th) + yz = r * torch.cos(th) / 3.0 + T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1) + T[:, :3, 3] = torch.stack([x, yz, yz], dim=-1) * -1.0 + c2ws = cam2world.unsqueeze(0) @ T + # Start at reference pose and end back at reference pose + c2ws = torch.cat([cam2world.unsqueeze(0), c2ws, cam2world.unsqueeze(0)], dim=0) + Ks = intrinsic.unsqueeze(0).repeat(c2ws.shape[0], 1, 1) + return c2ws, Ks + + +@torch.no_grad() +def render_dolly_zoom_path( + cam2world: torch.Tensor, + intrinsic: torch.Tensor, + h: int, + w: int, + num_frames: int = 120, + max_disp: float = 0.1, + D_focus: float = 10.0, 
+): + device, dtype = cam2world.device, cam2world.dtype + fx0, fy0 = intrinsic[0, 0] * w, intrinsic[1, 1] * h + t = torch.linspace(0.0, 2.0, steps=num_frames, device=device, dtype=dtype) + z = 0.5 * (1.0 - torch.cos(torch.pi * t)) * max_disp + T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1) + T[:, 2, 3] = -z + c2ws = cam2world.unsqueeze(0) @ T + Df = torch.as_tensor(D_focus, device=device, dtype=dtype) + scale = (Df / (Df + z)).clamp(min=1e-6) + Ks = intrinsic.unsqueeze(0).repeat(num_frames, 1, 1) + Ks[:, 0, 0] = (fx0 * scale) / w + Ks[:, 1, 1] = (fy0 * scale) / h + return c2ws, Ks + + +@torch.no_grad() +def interpolate_intrinsics( + initial: torch.Tensor, # "*#batch 3 3" + final: torch.Tensor, # "*#batch 3 3" + t: torch.Tensor, # " time_step" +) -> torch.Tensor: # "*batch time_step 3 3" + initial = rearrange(initial, "... i j -> ... () i j") + final = rearrange(final, "... i j -> ... () i j") + t = rearrange(t, "t -> t () ()") + return initial + (final - initial) * t + + +def intersect_rays( + a_origins: torch.Tensor, # "*#batch dim" + a_directions: torch.Tensor, # "*#batch dim" + b_origins: torch.Tensor, # "*#batch dim" + b_directions: torch.Tensor, # "*#batch dim" +) -> torch.Tensor: # "*batch dim" + """Compute the least-squares intersection of rays. Uses the math from here: + https://math.stackexchange.com/a/1762491/286022 + """ + + # Broadcast and stack the tensors. + a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors( + a_origins, a_directions, b_origins, b_directions + ) + origins = torch.stack((a_origins, b_origins), dim=-2) + directions = torch.stack((a_directions, b_directions), dim=-2) + + # Compute n_i * n_i^T - eye(3) from the equation. + n = einsum(directions, directions, "... n i, ... n j -> ... n i j") + n = n - torch.eye(3, dtype=origins.dtype, device=origins.device) + + # Compute the left-hand side of the equation. + lhs = reduce(n, "... n i j -> ... i j", "sum") + + # Compute the right-hand side of the equation. + rhs = einsum(n, origins, "... n i j, ... n j -> ... n i") + rhs = reduce(rhs, "... n i -> ... i", "sum") + + # Left-matrix-multiply both sides by the inverse of lhs to find p. + return torch.linalg.lstsq(lhs, rhs).solution + + +def normalize(a: torch.Tensor) -> torch.Tensor: # "*#batch dim" -> "*#batch dim" + return a / a.norm(dim=-1, keepdim=True) + + +def generate_coordinate_frame( + y: torch.Tensor, # "*#batch 3" + z: torch.Tensor, # "*#batch 3" +) -> torch.Tensor: # "*batch 3 3" + """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors.""" + y, z = torch.broadcast_tensors(y, z) + return torch.stack([y.cross(z, dim=-1), y, z], dim=-1) + + +def generate_rotation_coordinate_frame( + a: torch.Tensor, # "*#batch 3" + b: torch.Tensor, # "*#batch 3" + eps: float = 1e-4, +) -> torch.Tensor: # "*batch 3 3" + """Generate a coordinate frame where the Y direction is normal to the plane defined + by unit vectors a and b. The other axes are arbitrary.""" + device = a.device + + # Replace every entry in b that's parallel to the corresponding entry in a with an + # arbitrary vector. + b = b.detach().clone() + parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps + b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device) + parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps + b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device) + + # Generate the coordinate frame. The initial cross product defines the plane. 
+ return generate_coordinate_frame(normalize(torch.linalg.cross(a, b)), a) + + +def matrix_to_euler( + rotations: torch.Tensor, # "*batch 3 3" + pattern: str, +) -> torch.Tensor: # "*batch 3" + *batch, _, _ = rotations.shape + rotations = rotations.reshape(-1, 3, 3) + angles_np = R.from_matrix(rotations.detach().cpu().numpy()).as_euler(pattern) + rotations = torch.tensor(angles_np, dtype=rotations.dtype, device=rotations.device) + return rotations.reshape(*batch, 3) + + +def euler_to_matrix( + rotations: torch.Tensor, # "*batch 3" + pattern: str, +) -> torch.Tensor: # "*batch 3 3" + *batch, _ = rotations.shape + rotations = rotations.reshape(-1, 3) + matrix_np = R.from_euler(pattern, rotations.detach().cpu().numpy()).as_matrix() + rotations = torch.tensor(matrix_np, dtype=rotations.dtype, device=rotations.device) + return rotations.reshape(*batch, 3, 3) + + +def extrinsics_to_pivot_parameters( + extrinsics: torch.Tensor, # "*#batch 4 4" + pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3" + pivot_point: torch.Tensor, # "*#batch 3" +) -> torch.Tensor: # "*batch 5" + """Convert the extrinsics to a representation with 5 degrees of freedom: + 1. Distance from pivot point in the "X" (look cross pivot axis) direction. + 2. Distance from pivot point in the "Y" (pivot axis) direction. + 3. Distance from pivot point in the Z (look) direction + 4. Angle in plane + 5. Twist (rotation not in plane) + """ + + # The pivot coordinate frame's Z axis is normal to the plane. + pivot_axis = pivot_coordinate_frame[..., :, 1] + + # Compute the translation elements of the pivot parametrization. + translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2]) + origin = extrinsics[..., :3, 3] + delta = pivot_point - origin + translation = einsum(translation_frame, delta, "... i j, ... i -> ... j") + + # Add the rotation elements of the pivot parametrization. + inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3] + y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1) + + return torch.cat([translation, y[..., None], z[..., None]], dim=-1) + + +def pivot_parameters_to_extrinsics( + parameters: torch.Tensor, # "*#batch 5" + pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3" + pivot_point: torch.Tensor, # "*#batch 3" +) -> torch.Tensor: # "*batch 4 4" + translation, y, z = parameters.split((3, 1, 1), dim=-1) + + euler = torch.cat((y, torch.zeros_like(y), z), dim=-1) + rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ") + + # The pivot coordinate frame's Z axis is normal to the plane. + pivot_axis = pivot_coordinate_frame[..., :, 1] + + translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2]) + delta = einsum(translation_frame, translation, "... i j, ... j -> ... i") + origin = pivot_point - delta + + *batch, _ = origin.shape + extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device) + extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone() + extrinsics[..., 3, 3] = 1 + extrinsics[..., :3, :3] = rotation + extrinsics[..., :3, 3] = origin + return extrinsics + + +def interpolate_circular( + a: torch.Tensor, # "*#batch" + b: torch.Tensor, # "*#batch" + t: torch.Tensor, # "*#batch" +) -> torch.Tensor: # " *batch" + a, b, t = torch.broadcast_tensors(a, b, t) + + tau = 2 * torch.pi + a = a % tau + b = b % tau + + # Consider piecewise edge cases. 
+ d = (b - a).abs() + a_left = a - tau + d_left = (b - a_left).abs() + a_right = a + tau + d_right = (b - a_right).abs() + use_d = (d < d_left) & (d < d_right) + use_d_left = (d_left < d_right) & (~use_d) + use_d_right = (~use_d) & (~use_d_left) + + result = a + (b - a) * t + result[use_d_left] = (a_left + (b - a_left) * t)[use_d_left] + result[use_d_right] = (a_right + (b - a_right) * t)[use_d_right] + + return result + + +def interpolate_pivot_parameters( + initial: torch.Tensor, # "*#batch 5" + final: torch.Tensor, # "*#batch 5" + t: torch.Tensor, # " time_step" +) -> torch.Tensor: # "*batch time_step 5" + initial = rearrange(initial, "... d -> ... () d") + final = rearrange(final, "... d -> ... () d") + t = rearrange(t, "t -> t ()") + ti, ri = initial.split((3, 2), dim=-1) + tf, rf = final.split((3, 2), dim=-1) + + t_lerp = ti + (tf - ti) * t + r_lerp = interpolate_circular(ri, rf, t) + + return torch.cat((t_lerp, r_lerp), dim=-1) + + +@torch.no_grad() +def interpolate_extrinsics( + initial: torch.Tensor, # "*#batch 4 4" + final: torch.Tensor, # "*#batch 4 4" + t: torch.Tensor, # " time_step" + eps: float = 1e-4, +) -> torch.Tensor: # "*batch time_step 4 4" + """Interpolate extrinsics by rotating around their "focus point," which is the + least-squares intersection between the look vectors of the initial and final + extrinsics. + """ + + initial = initial.type(torch.float64) + final = final.type(torch.float64) + t = t.type(torch.float64) + + # Based on the dot product between the look vectors, pick from one of two cases: + # 1. Look vectors are parallel: interpolate about their origins' midpoint. + # 3. Look vectors aren't parallel: interpolate about their focus point. + initial_look = initial[..., :3, 2] + final_look = final[..., :3, 2] + dot_products = einsum(initial_look, final_look, "... i, ... i -> ...") + parallel_mask = (dot_products.abs() - 1).abs() < eps + + # Pick focus points. + initial_origin = initial[..., :3, 3] + final_origin = final[..., :3, 3] + pivot_point = 0.5 * (initial_origin + final_origin) + pivot_point[~parallel_mask] = intersect_rays( + initial_origin[~parallel_mask], + initial_look[~parallel_mask], + final_origin[~parallel_mask], + final_look[~parallel_mask], + ) + + # Convert to pivot parameters. + pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps) + initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point) + final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point) + + # Interpolate the pivot parameters. + interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t) + + # Convert back. + return pivot_parameters_to_extrinsics( + interpolated_params.type(torch.float32), + rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32), + rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32), + ) + + +@torch.no_grad() +def generate_wobble_transformation( + radius: torch.Tensor, # "*#batch" + t: torch.Tensor, # " time_step" + num_rotations: int = 1, + scale_radius_with_t: bool = True, +) -> torch.Tensor: # "*batch time_step 4 4"]: + # Generate a translation in the image plane. 
+ tf = torch.eye(4, dtype=torch.float32, device=t.device) + tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() + radius = radius[..., None] + if scale_radius_with_t: + radius = radius * t + tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius + tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius + return tf + + +@torch.no_grad() +def render_wobble_inter_path( + cam2world: torch.Tensor, intr_normed: torch.Tensor, inter_len: int, n_skip: int = 3 +): + """ + cam2world: [batch, 4, 4], + intr_normed: [batch, 3, 3] + """ + frame_per_round = n_skip * inter_len + num_rotations = 1 + + t = torch.linspace(0, 1, frame_per_round, dtype=torch.float32, device=cam2world.device) + # t = (torch.cos(torch.pi * (t + 1)) + 1) / 2 + tgt_c2w_b = [] + tgt_intr_b = [] + for b_idx in range(cam2world.shape[0]): + tgt_c2w = [] + tgt_intr = [] + for cur_idx in range(0, cam2world.shape[1] - n_skip, n_skip): + origin_a = cam2world[b_idx, cur_idx, :3, 3] + origin_b = cam2world[b_idx, cur_idx + n_skip, :3, 3] + delta = (origin_a - origin_b).norm(dim=-1) + if cur_idx == 0: + delta_prev = delta + else: + delta = (delta_prev + delta) / 2 + delta_prev = delta + tf = generate_wobble_transformation( + radius=delta * 0.5, + t=t, + num_rotations=num_rotations, + scale_radius_with_t=False, + ) + cur_extrs = ( + interpolate_extrinsics( + cam2world[b_idx, cur_idx], + cam2world[b_idx, cur_idx + n_skip], + t, + ) + @ tf + ) + tgt_c2w.append(cur_extrs[(0 if cur_idx == 0 else 1) :]) + tgt_intr.append( + interpolate_intrinsics( + intr_normed[b_idx, cur_idx], + intr_normed[b_idx, cur_idx + n_skip], + t, + )[(0 if cur_idx == 0 else 1) :] + ) + tgt_c2w_b.append(torch.cat(tgt_c2w)) + tgt_intr_b.append(torch.cat(tgt_intr)) + tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4 + tgt_intr = torch.stack(tgt_intr_b) # b v 3 3 + return tgt_c2w, tgt_intr diff --git a/core/models/depth_anything_3/utils/constants.py b/core/models/depth_anything_3/utils/constants.py new file mode 100644 index 0000000..25c330e --- /dev/null +++ b/core/models/depth_anything_3/utils/constants.py @@ -0,0 +1,270 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
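A usage sketch for the trajectory helpers above, with a single identity reference pose and made-up normalized intrinsics (the `intrinsic[0, 0] * w` usage in `render_wander_path` implies fx/cx are stored normalized by image size):

```python
import torch

from core.models.depth_anything_3.utils.camera_trj_helpers import render_wander_path

c2w = torch.eye(4)  # reference camera-to-world pose
K_normed = torch.tensor(
    [[0.8, 0.0, 0.5],   # fx/w, 0,    cx/w
     [0.0, 1.2, 0.5],   # 0,    fy/h, cy/h
     [0.0, 0.0, 1.0]]
)
c2ws, Ks = render_wander_path(c2w, K_normed, h=480, w=640, num_frames=60, max_disp=24.0)
# The reference pose is prepended and appended, so the path holds num_frames + 2 poses
print(c2ws.shape, Ks.shape)  # torch.Size([62, 4, 4]) torch.Size([62, 3, 3])
```

`render_dolly_zoom_path` works analogously but keeps the focus plane fixed by scaling fx/fy with `Df / (Df + z)` as the camera translates, which is what produces the dolly-zoom effect.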
+ +DEFAULT_MODEL = "depth-anything/DA3NESTED-GIANT-LARGE-1.1" +DEFAULT_EXPORT_DIR = "workspace/gallery/scene" +DEFAULT_GALLERY_DIR = "workspace/gallery" +DEFAULT_GRADIO_DIR = "workspace/gradio" +THRESH_FOR_REF_SELECTION = 3 + +# ============================================================================= +# Benchmark Evaluation Constants +# ============================================================================= + +# Default evaluation workspace directory +DEFAULT_EVAL_WORKSPACE = "workspace/evaluation" + +# Default reference view selection strategy for evaluation +# Use "first" for consistent and reproducible evaluation results +# Other options: "saddle_balanced", "auto", "mid" +EVAL_REF_VIEW_STRATEGY = "first" + +# ----------------------------------------------------------------------------- +# DTU Dataset Configuration +# Reference: https://roboimagedata.compute.dtu.dk/ +# Note: DepthAnything3 was never trained on any images from DTU. +# ----------------------------------------------------------------------------- + +# Root directory for DTU evaluation data (MVSNet format) +# Download from: https://drive.google.com/file/d/1rX0EXlUL4prRxrRu2DgLJv2j7-tpUD4D/view +DTU_EVAL_DATA_ROOT = "workspace/benchmark_dataset/dtu" + +# List of DTU evaluation scenes +DTU_SCENES = [ + "scan1", + "scan4", + "scan9", + "scan10", + "scan11", + "scan12", + "scan13", + "scan15", + "scan23", + "scan24", + "scan29", + "scan32", + "scan33", + "scan34", + "scan48", + "scan49", + "scan62", + "scan75", + "scan77", + "scan110", + "scan114", + "scan118", +] + +# Point cloud fusion hyperparameters +DTU_DIST_THRESH = 0.2 # Distance threshold for geometric consistency (mm) +DTU_NUM_CONSIST = 4 # Minimum number of consistent views for a point +DTU_MAX_POINTS = 4_000_000 # Maximum points in fused point cloud + +# 3D reconstruction evaluation hyperparameters +DTU_DOWN_DENSE = 0.2 # Downsample density for evaluation (mm) +DTU_PATCH_SIZE = 60 # Patch size for boundary handling +DTU_MAX_DIST = 20 # Outlier threshold for accuracy/completeness (mm) + +# ----------------------------------------------------------------------------- +# DTU-64 Dataset Configuration (Pose Evaluation Only) +# This is a subset of DTU with 64 images per scene for pose evaluation. +# Note: This dataset is ONLY for pose evaluation, not 3D reconstruction. +# ----------------------------------------------------------------------------- + +# Root directory for DTU-64 evaluation data +DTU64_EVAL_DATA_ROOT = "workspace/benchmark_dataset/dtu64" +DTU64_CAMERA_ROOT = "workspace/benchmark_dataset/dtu64/Cameras" + +# List of DTU-64 evaluation scenes (13 scenes) +DTU64_SCENES = [ + "scan105", + "scan114", + "scan118", + "scan122", + "scan24", + "scan37", + "scan40", + "scan55", + "scan63", + "scan65", + "scan69", + "scan83", + "scan97", +] + +# ----------------------------------------------------------------------------- +# ETH3D Dataset Configuration +# Reference: https://www.eth3d.net/ +# High-resolution multi-view stereo benchmark with laser-scanned ground truth. +# Note: DepthAnything3 was never trained on any images from ETH3D. 
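The DTU fusion hyperparameters above amount to a view-consistency filter. A toy, self-contained sketch of how such thresholds are typically applied; the distance matrix here is random stand-in data, whereas the real pipeline would compute reprojection distances between views:

```python
import numpy as np

from core.models.depth_anything_3.utils.constants import (
    DTU_DIST_THRESH,
    DTU_MAX_POINTS,
    DTU_NUM_CONSIST,
)

rng = np.random.default_rng(0)
# Toy per-point distances to 10 neighboring views, in mm
dists = rng.uniform(0.0, 2.0, size=(100_000, 10))

# Keep a point if at least DTU_NUM_CONSIST views agree within DTU_DIST_THRESH mm
consistent = (dists < DTU_DIST_THRESH).sum(axis=1) >= DTU_NUM_CONSIST
kept = np.flatnonzero(consistent)[:DTU_MAX_POINTS]
print(f"kept {kept.size} of {dists.shape[0]} toy points")
```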
+# ----------------------------------------------------------------------------- + +# Root directory for ETH3D evaluation data +ETH3D_EVAL_DATA_ROOT = "workspace/benchmark_dataset/eth3d" + +# List of ETH3D evaluation scenes (indoor and outdoor) +ETH3D_SCENES = [ + "courtyard", + "electro", + "kicker", + "pipes", + "relief", + # "terrace", # Excluded: known issues + "delivery_area", + "facade", + # "meadow", # Excluded: known issues + "office", + "playground", + "relief_2", + "terrains", +] + +# Images to filter out (known problematic views per scene) +ETH3D_FILTER_KEYS = { + "delivery_area": ["711.JPG", "712.JPG", "713.JPG", "714.JPG"], + "electro": ["9289.JPG", "9290.JPG", "9291.JPG", "9292.JPG", "9293.JPG", "9298.JPG"], + "playground": ["587.JPG", "588.JPG", "589.JPG", "590.JPG", "591.JPG", "592.JPG"], + "relief": [ + "427.JPG", "428.JPG", "429.JPG", "430.JPG", "431.JPG", "432.JPG", + "433.JPG", "434.JPG", "435.JPG", "436.JPG", "437.JPG", "438.JPG", + ], + "relief_2": [ + "458.JPG", "459.JPG", "460.JPG", "461.JPG", "462.JPG", "463.JPG", + "464.JPG", "465.JPG", "466.JPG", "467.JPG", "468.JPG", + ], +} + +# TSDF fusion hyperparameters (scaled for outdoor scenes) +ETH3D_VOXEL_LENGTH = 4.0 / 512.0 * 5 # Voxel size for TSDF (meters) +ETH3D_SDF_TRUNC = 0.04 * 5 # SDF truncation distance (meters) +ETH3D_MAX_DEPTH = 100000.0 # Maximum depth for integration (effectively no truncation) + +# Point cloud sampling +ETH3D_SAMPLING_NUMBER = 1_000_000 # Number of points to sample from mesh + +# 3D reconstruction evaluation hyperparameters +ETH3D_EVAL_THRESHOLD = 0.05 * 5 # Distance threshold for precision/recall (meters) +ETH3D_DOWN_SAMPLE = 4.0 / 512.0 * 5 # Voxel size for evaluation downsampling (meters) + + +# ============================================================================== +# 7Scenes Dataset Configuration +# ============================================================================== +# Reference: https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/ +# Note: Indoor RGB-D dataset with ground truth poses and meshes. + +# Root directory for 7Scenes evaluation data +SEVENSCENES_EVAL_DATA_ROOT = "workspace/benchmark_dataset/7scenes" + +# List of 7Scenes evaluation scenes +SEVENSCENES_SCENES = [ + "chess", + "fire", + "heads", + "office", + "pumpkin", + "redkitchen", + "stairs", +] + +# Fixed camera intrinsics for 7Scenes (all images share same intrinsics) +SEVENSCENES_FX = 585.0 +SEVENSCENES_FY = 585.0 +SEVENSCENES_CX = 320.0 +SEVENSCENES_CY = 240.0 + +# TSDF fusion hyperparameters (indoor scenes, smaller voxels) +SEVENSCENES_VOXEL_LENGTH = 4.0 / 512.0 # Voxel size for TSDF (meters) +SEVENSCENES_SDF_TRUNC = 0.04 # SDF truncation distance (meters) +SEVENSCENES_MAX_DEPTH = 1000000.0 # Maximum depth for integration (no truncation) + +# Point cloud sampling +SEVENSCENES_SAMPLING_NUMBER = 1_000_000 # Number of points to sample from mesh + +# 3D reconstruction evaluation hyperparameters +SEVENSCENES_EVAL_THRESHOLD = 0.05 # Distance threshold for precision/recall (meters) +SEVENSCENES_DOWN_SAMPLE = 4.0 / 512.0 # Voxel size for evaluation downsampling (meters) + + +# ============================================================================== +# ScanNet++ Dataset Configuration +# ============================================================================== +# Reference: https://kaldir.vc.in.tum.de/scannetpp/ +# Note: High-quality indoor RGB-D dataset with iPhone and DSLR images. 
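The 7Scenes intrinsics and TSDF settings above are sufficient to configure a fusion volume. A sketch assuming Open3D as the TSDF backend (the diff does not show which implementation actually consumes these constants); the 640×480 resolution is implied by cx = 320, cy = 240:

```python
import open3d as o3d

from core.models.depth_anything_3.utils.constants import (
    SEVENSCENES_CX,
    SEVENSCENES_CY,
    SEVENSCENES_FX,
    SEVENSCENES_FY,
    SEVENSCENES_SDF_TRUNC,
    SEVENSCENES_VOXEL_LENGTH,
)

intrinsic = o3d.camera.PinholeCameraIntrinsic(
    640, 480, SEVENSCENES_FX, SEVENSCENES_FY, SEVENSCENES_CX, SEVENSCENES_CY
)
volume = o3d.pipelines.integration.ScalableTSDFVolume(
    voxel_length=SEVENSCENES_VOXEL_LENGTH,
    sdf_trunc=SEVENSCENES_SDF_TRUNC,
    color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8,
)
# Per frame: volume.integrate(rgbd, intrinsic, np.linalg.inv(c2w))
```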
+ +# Root directory for ScanNet++ evaluation data +SCANNETPP_EVAL_DATA_ROOT = "workspace/benchmark_dataset/scannetpp" + +# List of ScanNet++ evaluation scenes +SCANNETPP_SCENES = [ + "09c1414f1b", + "1ada7a0617", + "40aec5fffa", + "3e8bba0176", + "acd95847c5", + "578511c8a9", + "5f99900f09", + "c4c04e6d6c", + "f3d64c30f8", + "7bc286c1b6", + "c5439f4607", + "286b55a2bf", + "fb5a96b1a2", + "7831862f02", + "38d58a7a31", + "bde1e479ad", + "9071e139d9", + "21d970d8de", + "bcd2436daf", + "cc5237fd77", +] + +# Input resolution for ScanNet++ (after undistortion and resize) +SCANNETPP_INPUT_H = 768 +SCANNETPP_INPUT_W = 1024 + +# TSDF fusion hyperparameters (indoor scenes) +SCANNETPP_VOXEL_LENGTH = 0.02 # Voxel size for TSDF (meters) +SCANNETPP_SDF_TRUNC = 0.15 # SDF truncation distance (meters) +SCANNETPP_MAX_DEPTH = 5.0 # Maximum depth for integration (meters) + +# Point cloud sampling +SCANNETPP_SAMPLING_NUMBER = 1_000_000 # Number of points to sample from mesh + +# 3D reconstruction evaluation hyperparameters +SCANNETPP_EVAL_THRESHOLD = 0.05 # Distance threshold for precision/recall (meters) +SCANNETPP_DOWN_SAMPLE = 0.02 # Voxel size for evaluation downsampling (meters) + + +# ============================================================================== +# HiRoom Dataset Configuration +# ============================================================================== +# Note: Indoor RGB-D dataset. + +# Root directory for HiRoom evaluation data +HIROOM_EVAL_DATA_ROOT = "workspace/benchmark_dataset/hiroom/data" +HIROOM_GT_ROOT_PATH = "workspace/benchmark_dataset/hiroom/fused_pcd" +HIROOM_SCENE_LIST_PATH = "workspace/benchmark_dataset/hiroom/selected_scene_list_val.txt" + +# TSDF fusion hyperparameters (indoor scenes) +HIROOM_VOXEL_LENGTH = 4.0 / 512.0 # Voxel size for TSDF (meters) +HIROOM_SDF_TRUNC = 0.04 # SDF truncation distance (meters) +HIROOM_MAX_DEPTH = 10000.0 # Maximum depth for integration (no truncation) + +# Point cloud sampling +HIROOM_SAMPLING_NUMBER = 1_000_000 # Number of points to sample from mesh + +# 3D reconstruction evaluation hyperparameters +HIROOM_EVAL_THRESHOLD = 0.05 # Distance threshold for precision/recall (meters) +HIROOM_DOWN_SAMPLE = 4.0 / 512.0 # Voxel size for evaluation downsampling (meters) diff --git a/core/models/depth_anything_3/utils/export/__init__.py b/core/models/depth_anything_3/utils/export/__init__.py new file mode 100644 index 0000000..7db434c --- /dev/null +++ b/core/models/depth_anything_3/utils/export/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
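Each benchmark block above pairs a point-sampling count with a distance threshold for precision/recall. A minimal sketch of that metric on two toy point clouds, again assuming Open3D for the nearest-neighbor distances:

```python
import numpy as np
import open3d as o3d

from core.models.depth_anything_3.utils.constants import SCANNETPP_EVAL_THRESHOLD

rng = np.random.default_rng(0)
pred = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(rng.random((10_000, 3))))
gt = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(rng.random((10_000, 3))))

d_pred = np.asarray(pred.compute_point_cloud_distance(gt))  # pred -> nearest GT point
d_gt = np.asarray(gt.compute_point_cloud_distance(pred))    # GT -> nearest pred point
precision = float((d_pred < SCANNETPP_EVAL_THRESHOLD).mean())
recall = float((d_gt < SCANNETPP_EVAL_THRESHOLD).mean())
f1 = 2 * precision * recall / max(precision + recall, 1e-8)
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
```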
+
+from core.models.depth_anything_3.specs import Prediction
+from core.models.depth_anything_3.utils.export.gs import export_to_gs_ply, export_to_gs_video
+
+from .colmap import export_to_colmap
+from .depth_vis import export_to_depth_vis
+from .feat_vis import export_to_feat_vis
+from .glb import export_to_glb
+from .npz import export_to_mini_npz, export_to_npz
+
+
+def export(
+    prediction: Prediction,
+    export_format: str,
+    export_dir: str,
+    **kwargs,
+):
+    # Hyphen-joined formats run each exporter in turn, e.g. "glb-depth_vis"
+    if "-" in export_format:
+        export_formats = export_format.split("-")
+        for export_format in export_formats:
+            export(prediction, export_format, export_dir, **kwargs)
+        return  # Prevent falling through to single-format handling
+
+    if export_format == "glb":
+        export_to_glb(prediction, export_dir, **kwargs.get(export_format, {}))
+    elif export_format == "mini_npz":
+        export_to_mini_npz(prediction, export_dir)
+    elif export_format == "npz":
+        export_to_npz(prediction, export_dir)
+    elif export_format == "feat_vis":
+        export_to_feat_vis(prediction, export_dir, **kwargs.get(export_format, {}))
+    elif export_format == "depth_vis":
+        export_to_depth_vis(prediction, export_dir)
+    elif export_format == "gs_ply":
+        export_to_gs_ply(prediction, export_dir, **kwargs.get(export_format, {}))
+    elif export_format == "gs_video":
+        export_to_gs_video(prediction, export_dir, **kwargs.get(export_format, {}))
+    elif export_format == "colmap":
+        export_to_colmap(prediction, export_dir, **kwargs.get(export_format, {}))
+    else:
+        raise ValueError(f"Unsupported export format: {export_format}")
+
+
+__all__ = [
+    "export",
+]
diff --git a/core/models/depth_anything_3/utils/export/__pycache__/__init__.cpython-313.pyc b/core/models/depth_anything_3/utils/export/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..4ee1539
Binary files /dev/null and b/core/models/depth_anything_3/utils/export/__pycache__/__init__.cpython-313.pyc differ
diff --git a/core/models/depth_anything_3/utils/export/__pycache__/colmap.cpython-313.pyc b/core/models/depth_anything_3/utils/export/__pycache__/colmap.cpython-313.pyc
new file mode 100644
index 0000000..b70fa29
Binary files /dev/null and b/core/models/depth_anything_3/utils/export/__pycache__/colmap.cpython-313.pyc differ
diff --git a/core/models/depth_anything_3/utils/export/__pycache__/depth_vis.cpython-313.pyc b/core/models/depth_anything_3/utils/export/__pycache__/depth_vis.cpython-313.pyc
new file mode 100644
index 0000000..c16a923
Binary files /dev/null and b/core/models/depth_anything_3/utils/export/__pycache__/depth_vis.cpython-313.pyc differ
diff --git a/core/models/depth_anything_3/utils/export/__pycache__/feat_vis.cpython-313.pyc b/core/models/depth_anything_3/utils/export/__pycache__/feat_vis.cpython-313.pyc
new file mode 100644
index 0000000..4f6035d
Binary files /dev/null and b/core/models/depth_anything_3/utils/export/__pycache__/feat_vis.cpython-313.pyc differ
diff --git a/core/models/depth_anything_3/utils/export/__pycache__/glb.cpython-313.pyc b/core/models/depth_anything_3/utils/export/__pycache__/glb.cpython-313.pyc
new file mode 100644
index 0000000..6254f7c
Binary files /dev/null and b/core/models/depth_anything_3/utils/export/__pycache__/glb.cpython-313.pyc differ
diff --git a/core/models/depth_anything_3/utils/export/__pycache__/gs.cpython-313.pyc b/core/models/depth_anything_3/utils/export/__pycache__/gs.cpython-313.pyc
new file mode 100644
index 0000000..584f225
Binary files /dev/null and b/core/models/depth_anything_3/utils/export/__pycache__/gs.cpython-313.pyc differ
diff --git 
a/core/models/depth_anything_3/utils/export/__pycache__/npz.cpython-313.pyc b/core/models/depth_anything_3/utils/export/__pycache__/npz.cpython-313.pyc new file mode 100644 index 0000000..e150686 Binary files /dev/null and b/core/models/depth_anything_3/utils/export/__pycache__/npz.cpython-313.pyc differ diff --git a/core/models/depth_anything_3/utils/export/colmap.py b/core/models/depth_anything_3/utils/export/colmap.py new file mode 100644 index 0000000..bcc00df --- /dev/null +++ b/core/models/depth_anything_3/utils/export/colmap.py @@ -0,0 +1,150 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pycolmap +import cv2 as cv +import numpy as np + +from PIL import Image + +from core.models.depth_anything_3.specs import Prediction +from core.models.depth_anything_3.utils.logger import logger + +from .glb import _depths_to_world_points_with_colors + + +def export_to_colmap( + prediction: Prediction, + export_dir: str, + image_paths: list[str], + conf_thresh_percentile: float = 40.0, + process_res_method: str = "upper_bound_resize", +) -> None: + # 1. Data preparation + conf_thresh = np.percentile(prediction.conf, conf_thresh_percentile) + points, colors = _depths_to_world_points_with_colors( + prediction.depth, + prediction.intrinsics, + prediction.extrinsics, # w2c + prediction.processed_images, + prediction.conf, + conf_thresh, + ) + num_points = len(points) + logger.info(f"Exporting to COLMAP with {num_points} points") + num_frames = len(prediction.processed_images) + h, w = prediction.processed_images.shape[1:3] + points_xyf = _create_xyf(num_frames, h, w) + points_xyf = points_xyf[prediction.conf >= conf_thresh] + + # 2. 
Set Reconstruction + reconstruction = pycolmap.Reconstruction() + + point3d_ids = [] + for vidx in range(num_points): + point3d_id = reconstruction.add_point3D(points[vidx], pycolmap.Track(), colors[vidx]) + point3d_ids.append(point3d_id) + + for fidx in range(num_frames): + orig_w, orig_h = Image.open(image_paths[fidx]).size + + intrinsic = prediction.intrinsics[fidx] + if process_res_method.endswith("resize"): + intrinsic[:1] *= orig_w / w + intrinsic[1:2] *= orig_h / h + elif process_res_method == "crop": + raise NotImplementedError("COLMAP export for crop method is not implemented") + else: + raise ValueError(f"Unknown process_res_method: {process_res_method}") + + pycolmap_intri = np.array( + [intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]] + ) + + extrinsic = prediction.extrinsics[fidx] + cam_from_world = pycolmap.Rigid3d(pycolmap.Rotation3d(extrinsic[:3, :3]), extrinsic[:3, 3]) + + # set and add camera + camera = pycolmap.Camera() + camera.camera_id = fidx + 1 + camera.model = pycolmap.CameraModelId.PINHOLE + camera.width = orig_w + camera.height = orig_h + camera.params = pycolmap_intri + reconstruction.add_camera(camera) + + # set and add rig (from camera) + rig = pycolmap.Rig() + rig.rig_id = camera.camera_id + rig.add_ref_sensor(camera.sensor_id) + reconstruction.add_rig(rig) + + # set image + image = pycolmap.Image() + image.image_id = fidx + 1 + image.camera_id = camera.camera_id + + # set and add frame (from image) + frame = pycolmap.Frame() + frame.frame_id = image.image_id + frame.rig_id = camera.camera_id + frame.add_data_id(image.data_id) + frame.rig_from_world = cam_from_world + reconstruction.add_frame(frame) + + # set point2d and update track + point2d_list = [] + points_in_frame = points_xyf[:, 2].astype(np.int32) == fidx + for vidx in np.where(points_in_frame)[0]: + point2d = points_xyf[vidx][:2] + point2d[0] *= orig_w / w + point2d[1] *= orig_h / h + point3d_id = point3d_ids[vidx] + point2d_list.append(pycolmap.Point2D(point2d, point3d_id)) + reconstruction.point3D(point3d_id).track.add_element( + image.image_id, len(point2d_list) - 1 + ) + + # set and add image + image.frame_id = image.image_id + image.name = os.path.basename(image_paths[fidx]) + image.points2D = pycolmap.Point2DList(point2d_list) + reconstruction.add_image(image) + + # 3. Export + reconstruction.write(export_dir) + + +def _create_xyf(num_frames, height, width): + """ + Creates a grid of pixel coordinates and frame indices (fidx) for all frames. + """ + # Create coordinate grids for a single frame + y_grid, x_grid = np.indices((height, width), dtype=np.int32) + x_grid = x_grid[np.newaxis, :, :] + y_grid = y_grid[np.newaxis, :, :] + + # Broadcast to all frames + x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) + y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) + + # Create frame indices and broadcast + f_idx = np.arange(num_frames, dtype=np.int32)[:, np.newaxis, np.newaxis] + f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) + + # Stack coordinates and frame indices + points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) + + return points_xyf diff --git a/core/models/depth_anything_3/utils/export/depth_vis.py b/core/models/depth_anything_3/utils/export/depth_vis.py new file mode 100644 index 0000000..58c7dd5 --- /dev/null +++ b/core/models/depth_anything_3/utils/export/depth_vis.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 ByteDance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import imageio +import numpy as np + +from core.models.depth_anything_3.specs import Prediction +from core.models.depth_anything_3.utils.visualize import visualize_depth + + +def export_to_depth_vis( + prediction: Prediction, + export_dir: str, +): + # Use prediction.processed_images, which is already processed image data + if prediction.processed_images is None: + raise ValueError("prediction.processed_images is required but not available") + + images_u8 = prediction.processed_images # (N,H,W,3) uint8 + + os.makedirs(os.path.join(export_dir, "depth_vis"), exist_ok=True) + for idx in range(prediction.depth.shape[0]): + depth_vis = visualize_depth(prediction.depth[idx]) + image_vis = images_u8[idx] + depth_vis = depth_vis.astype(np.uint8) + image_vis = image_vis.astype(np.uint8) + vis_image = np.concatenate([image_vis, depth_vis], axis=1) + save_path = os.path.join(export_dir, f"depth_vis/{idx:04d}.jpg") + imageio.imwrite(save_path, vis_image, quality=95) diff --git a/core/models/depth_anything_3/utils/export/feat_vis.py b/core/models/depth_anything_3/utils/export/feat_vis.py new file mode 100644 index 0000000..15dd7e2 --- /dev/null +++ b/core/models/depth_anything_3/utils/export/feat_vis.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import imageio +import numpy as np +from tqdm.auto import tqdm + +from core.models.depth_anything_3.utils.parallel_utils import async_call +from core.models.depth_anything_3.utils.pca_utils import PCARGBVisualizer + + +@async_call +def export_to_feat_vis( + prediction, + export_dir, + fps=15, +): + """Export feature visualization with PCA. 
+ + Args: + prediction: Model prediction containing feature maps + export_dir: Directory to export results + fps: Frame rate for output video (default: 15) + """ + out_dir = os.path.join(export_dir, "feat_vis") + os.makedirs(out_dir, exist_ok=True) + + images = prediction.processed_images + for k, v in prediction.aux.items(): + if not k.startswith("feat_layer_"): + continue + os.makedirs(os.path.join(out_dir, k), exist_ok=True) + viz = PCARGBVisualizer(basis_mode="fixed", percentile_mode="global", clip_percent=10.0) + viz.fit_reference(v) + feats_vis = viz.transform_video(v) + for idx in tqdm(range(len(feats_vis))): + img = images[idx] + feat_vis = (feats_vis[idx] * 255).astype(np.uint8) + feat_vis = cv2.resize( + feat_vis, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST + ) + save_path = os.path.join(out_dir, f"{k}/{idx:06d}.jpg") + save = np.concatenate([img, feat_vis], axis=1) + imageio.imwrite(save_path, save, quality=95) + cmd = ( + "ffmpeg -loglevel error -hide_banner -y " + f"-framerate {fps} -start_number 0 " + f"-i {out_dir}/{k}/%06d.jpg " + f"-c:v libx264 -pix_fmt yuv420p " + f"{out_dir}/{k}.mp4" + ) + os.system(cmd) diff --git a/core/models/depth_anything_3/utils/export/glb.py b/core/models/depth_anything_3/utils/export/glb.py new file mode 100644 index 0000000..ac30561 --- /dev/null +++ b/core/models/depth_anything_3/utils/export/glb.py @@ -0,0 +1,432 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
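A smoke test for `export_to_depth_vis` from the diff above, using a dummy `Prediction` (shapes follow the field comments in `specs.py`; assumes `imageio` and the `visualize_depth` helper behave as shown in the diff):

```python
import numpy as np

from core.models.depth_anything_3.specs import Prediction
from core.models.depth_anything_3.utils.export.depth_vis import export_to_depth_vis

pred = Prediction(
    depth=np.random.rand(2, 240, 320).astype(np.float32),        # (N, H, W)
    is_metric=0,
    processed_images=np.zeros((2, 240, 320, 3), dtype=np.uint8),  # (N, H, W, 3)
)
export_to_depth_vis(pred, "workspace/demo")
# Writes workspace/demo/depth_vis/0000.jpg and 0001.jpg (image | depth, side by side)
```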
+ +from __future__ import annotations + +import os +import numpy as np +import trimesh + +from core.models.depth_anything_3.specs import Prediction +from core.models.depth_anything_3.utils.logger import logger + +from .depth_vis import export_to_depth_vis + + +def set_sky_depth(prediction: Prediction, sky_mask: np.ndarray, sky_depth_def: float = 98.0): + non_sky_mask = ~sky_mask + valid_depth = prediction.depth[non_sky_mask] + if valid_depth.size > 0: + max_depth = np.percentile(valid_depth, sky_depth_def) + prediction.depth[sky_mask] = max_depth + + +def get_conf_thresh( + prediction: Prediction, + sky_mask: np.ndarray, + conf_thresh: float, + conf_thresh_percentile: float = 10.0, + ensure_thresh_percentile: float = 90.0, +): + if sky_mask is not None and (~sky_mask).sum() > 10: + conf_pixels = prediction.conf[~sky_mask] + else: + conf_pixels = prediction.conf + lower = np.percentile(conf_pixels, conf_thresh_percentile) + upper = np.percentile(conf_pixels, ensure_thresh_percentile) + conf_thresh = min(max(conf_thresh, lower), upper) + return conf_thresh + + +def export_to_glb( + prediction: Prediction, + export_dir: str, + num_max_points: int = 1_000_000, + conf_thresh: float = 1.05, + filter_black_bg: bool = False, + filter_white_bg: bool = False, + conf_thresh_percentile: float = 40.0, + ensure_thresh_percentile: float = 90.0, + sky_depth_def: float = 98.0, + show_cameras: bool = True, + camera_size: float = 0.03, + export_depth_vis: bool = True, +) -> str: + """Generate a 3D point cloud and camera wireframes and export them as a ``.glb`` file. + + The function builds a point cloud from the predicted depth maps, aligns it to the + first camera in glTF coordinates (X-right, Y-up, Z-backward), optionally draws + camera wireframes, and writes the result to ``scene.glb``. Auxiliary assets such as + depth visualizations can also be generated alongside the main export. + + Args: + prediction: Model prediction containing depth, confidence, intrinsics, extrinsics, + and pre-processed images. + export_dir: Output directory where the glTF assets will be written. + num_max_points: Maximum number of points retained after downsampling. + conf_thresh: Base confidence threshold used before percentile adjustments. + filter_black_bg: Mark near-black background pixels for removal during confidence filtering. + filter_white_bg: Mark near-white background pixels for removal during confidence filtering. + conf_thresh_percentile: Lower percentile used when adapting the confidence threshold. + ensure_thresh_percentile: Upper percentile clamp for the adaptive threshold. + sky_depth_def: Percentile used to fill sky pixels with plausible depth values. + show_cameras: Whether to render camera wireframes in the exported scene. + camera_size: Relative camera wireframe scale as a fraction of the scene diagonal. + export_depth_vis: Whether to export raster depth visualisations alongside the glTF. + + Returns: + Path to the exported ``scene.glb`` file. 
+ """ + # 1) Use prediction.processed_images, which is already processed image data + assert ( + prediction.processed_images is not None + ), "Export to GLB: prediction.processed_images is required but not available" + assert ( + prediction.depth is not None + ), "Export to GLB: prediction.depth is required but not available" + assert ( + prediction.intrinsics is not None + ), "Export to GLB: prediction.intrinsics is required but not available" + assert ( + prediction.extrinsics is not None + ), "Export to GLB: prediction.extrinsics is required but not available" + assert ( + prediction.conf is not None + ), "Export to GLB: prediction.conf is required but not available" + logger.info(f"conf_thresh_percentile: {conf_thresh_percentile}") + logger.info(f"num max points: {num_max_points}") + logger.info(f"Exporting to GLB with num_max_points: {num_max_points}") + if prediction.processed_images is None: + raise ValueError("prediction.processed_images is required but not available") + + images_u8 = prediction.processed_images # (N,H,W,3) uint8 + + # 2) Sky processing (if sky_mask is provided) + if getattr(prediction, "sky_mask", None) is not None: + set_sky_depth(prediction, prediction.sky_mask, sky_depth_def) + + # 3) Confidence threshold (if no conf, then no filtering) + if filter_black_bg: + prediction.conf[(prediction.processed_images < 16).all(axis=-1)] = 1.0 + if filter_white_bg: + prediction.conf[(prediction.processed_images >= 240).all(axis=-1)] = 1.0 + conf_thr = get_conf_thresh( + prediction, + getattr(prediction, "sky_mask", None), + conf_thresh, + conf_thresh_percentile, + ensure_thresh_percentile, + ) + + # 4) Back-project to world coordinates and get colors (world frame) + points, colors = _depths_to_world_points_with_colors( + prediction.depth, + prediction.intrinsics, + prediction.extrinsics, # w2c + images_u8, + prediction.conf, + conf_thr, + ) + + # 5) Based on first camera orientation + glTF axis system, center by point cloud, + # construct alignment transform, and apply to point cloud + A = _compute_alignment_transform_first_cam_glTF_center_by_points( + prediction.extrinsics[0], points + ) # (4,4) + + if points.shape[0] > 0: + points = trimesh.transform_points(points, A) + + # 6) Clean + downsample + points, colors = _filter_and_downsample(points, colors, num_max_points) + + # 7) Assemble scene (add point cloud first) + scene = trimesh.Scene() + if scene.metadata is None: + scene.metadata = {} + scene.metadata["hf_alignment"] = A # For camera wireframes and external reuse + + if points.shape[0] > 0: + pc = trimesh.points.PointCloud(vertices=points, colors=colors) + scene.add_geometry(pc) + + # 8) Draw cameras (wireframe pyramids), using the same transform A + if show_cameras and prediction.intrinsics is not None and prediction.extrinsics is not None: + scene_scale = _estimate_scene_scale(points, fallback=1.0) + H, W = prediction.depth.shape[1:] + _add_cameras_to_scene( + scene=scene, + K=prediction.intrinsics, + ext_w2c=prediction.extrinsics, + image_sizes=[(H, W)] * prediction.depth.shape[0], + scale=scene_scale * camera_size, + ) + + # 9) Export + os.makedirs(export_dir, exist_ok=True) + out_path = os.path.join(export_dir, "scene.glb") + scene.export(out_path) + + if export_depth_vis: + export_to_depth_vis(prediction, export_dir) + os.system(f"cp -r {export_dir}/depth_vis/0000.jpg {export_dir}/scene.jpg") + return out_path + + +# ========================= +# utilities +# ========================= + + +def _as_homogeneous44(ext: np.ndarray) -> np.ndarray: + """ + Accept 
(4,4) or (3,4) extrinsic parameters, return (4,4) homogeneous matrix. + """ + if ext.shape == (4, 4): + return ext + if ext.shape == (3, 4): + H = np.eye(4, dtype=ext.dtype) + H[:3, :4] = ext + return H + raise ValueError(f"extrinsic must be (4,4) or (3,4), got {ext.shape}") + + +def _depths_to_world_points_with_colors( + depth: np.ndarray, + K: np.ndarray, + ext_w2c: np.ndarray, + images_u8: np.ndarray, + conf: np.ndarray | None, + conf_thr: float, +) -> tuple[np.ndarray, np.ndarray]: + """ + For each frame, transform (u,v,1) through K^{-1} to get rays, + multiply by depth to camera frame, then use (w2c)^{-1} to transform to world frame. + Simultaneously extract colors. + """ + N, H, W = depth.shape + us, vs = np.meshgrid(np.arange(W), np.arange(H)) + ones = np.ones_like(us) + pix = np.stack([us, vs, ones], axis=-1).reshape(-1, 3) # (H*W,3) + + pts_all, col_all = [], [] + + for i in range(N): + d = depth[i] # (H,W) + valid = np.isfinite(d) & (d > 0) + if conf is not None: + valid &= conf[i] >= conf_thr + if not np.any(valid): + continue + + d_flat = d.reshape(-1) + vidx = np.flatnonzero(valid.reshape(-1)) + + K_inv = np.linalg.inv(K[i]) # (3,3) + c2w = np.linalg.inv(_as_homogeneous44(ext_w2c[i])) # (4,4) + + rays = K_inv @ pix[vidx].T # (3,M) + Xc = rays * d_flat[vidx][None, :] # (3,M) + Xc_h = np.vstack([Xc, np.ones((1, Xc.shape[1]))]) + Xw = (c2w @ Xc_h)[:3].T.astype(np.float32) # (M,3) + + cols = images_u8[i].reshape(-1, 3)[vidx].astype(np.uint8) # (M,3) + + pts_all.append(Xw) + col_all.append(cols) + + if len(pts_all) == 0: + return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8) + + return np.concatenate(pts_all, 0), np.concatenate(col_all, 0) + + +def _filter_and_downsample(points: np.ndarray, colors: np.ndarray, num_max: int): + if points.shape[0] == 0: + return points, colors + finite = np.isfinite(points).all(axis=1) + points, colors = points[finite], colors[finite] + if points.shape[0] > num_max: + idx = np.random.choice(points.shape[0], num_max, replace=False) + points, colors = points[idx], colors[idx] + return points, colors + + +def _estimate_scene_scale(points: np.ndarray, fallback: float = 1.0) -> float: + if points.shape[0] < 2: + return fallback + lo = np.percentile(points, 5, axis=0) + hi = np.percentile(points, 95, axis=0) + diag = np.linalg.norm(hi - lo) + return float(diag if np.isfinite(diag) and diag > 0 else fallback) + + +def _compute_alignment_transform_first_cam_glTF_center_by_points( + ext_w2c0: np.ndarray, + points_world: np.ndarray, +) -> np.ndarray: + """Computes the transformation matrix to align the scene with glTF standards. + + This function calculates a 4x4 homogeneous matrix that centers the scene's + point cloud and transforms its coordinate system from the computer vision (CV) + standard to the glTF standard. + + The transformation process involves three main steps: + 1. **Initial Alignment**: Orients the world coordinate system to match the + first camera's view (x-right, y-down, z-forward). + 2. **Coordinate System Conversion**: Converts the CV camera frame to the + glTF frame (x-right, y-up, z-backward) by flipping the Y and Z axes. + 3. **Centering**: Translates the entire scene so that the median of the + point cloud becomes the new origin (0,0,0). + + Returns: + A 4x4 homogeneous transformation matrix (torch.Tensor or np.ndarray) + that applies these transformations. 
A: X' = A @ [X;1] + """ + + w2c0 = _as_homogeneous44(ext_w2c0).astype(np.float64) + + # CV -> glTF axis transformation + M = np.eye(4, dtype=np.float64) + M[1, 1] = -1.0 # flip Y + M[2, 2] = -1.0 # flip Z + + # Don't center first + A_no_center = M @ w2c0 + + # Calculate point cloud center in new coordinate system (use median to resist outliers) + if points_world.shape[0] > 0: + pts_tmp = trimesh.transform_points(points_world, A_no_center) + center = np.median(pts_tmp, axis=0) + else: + center = np.zeros(3, dtype=np.float64) + + T_center = np.eye(4, dtype=np.float64) + T_center[:3, 3] = -center + + A = T_center @ A_no_center + return A + + +def _add_cameras_to_scene( + scene: trimesh.Scene, + K: np.ndarray, + ext_w2c: np.ndarray, + image_sizes: list[tuple[int, int]], + scale: float, +) -> None: + """Draws camera frustums to visualize their position and orientation. + + This function renders each camera as a wireframe pyramid, originating from + the camera's center and extending to the corners of its imaging plane. + + It reads the 'hf_alignment' metadata from the scene to ensure the + wireframes are correctly aligned with the 3D point cloud. + """ + N = K.shape[0] + if N == 0: + return + + # Alignment matrix consistent with point cloud (use identity matrix if missing) + A = None + try: + A = scene.metadata.get("hf_alignment", None) if scene.metadata else None + except Exception: + A = None + if A is None: + A = np.eye(4, dtype=np.float64) + + for i in range(N): + H, W = image_sizes[i] + segs = _camera_frustum_lines(K[i], ext_w2c[i], W, H, scale) # (8,2,3) world frame + # Apply unified transformation + segs = trimesh.transform_points(segs.reshape(-1, 3), A).reshape(-1, 2, 3) + path = trimesh.load_path(segs) + color = _index_color_rgb(i, N) + if hasattr(path, "colors"): + path.colors = np.tile(color, (len(path.entities), 1)) + scene.add_geometry(path) + + +def _camera_frustum_lines( + K: np.ndarray, ext_w2c: np.ndarray, W: int, H: int, scale: float +) -> np.ndarray: + corners = np.array( + [ + [0, 0, 1.0], + [W - 1, 0, 1.0], + [W - 1, H - 1, 1.0], + [0, H - 1, 1.0], + ], + dtype=float, + ) # (4,3) + + K_inv = np.linalg.inv(K) + c2w = np.linalg.inv(_as_homogeneous44(ext_w2c)) + + # camera center in world + Cw = (c2w @ np.array([0, 0, 0, 1.0]))[:3] + + # rays -> z=1 plane points (camera frame) + rays = (K_inv @ corners.T).T + z = rays[:, 2:3] + z[z == 0] = 1.0 + plane_cam = (rays / z) * scale # (4,3) + + # to world + plane_w = [] + for p in plane_cam: + pw = (c2w @ np.array([p[0], p[1], p[2], 1.0]))[:3] + plane_w.append(pw) + plane_w = np.stack(plane_w, 0) # (4,3) + + segs = [] + # center to corners + for k in range(4): + segs.append(np.stack([Cw, plane_w[k]], 0)) + # rectangle edges + order = [0, 1, 2, 3, 0] + for a, b in zip(order[:-1], order[1:]): + segs.append(np.stack([plane_w[a], plane_w[b]], 0)) + + return np.stack(segs, 0) # (8,2,3) + + +def _index_color_rgb(i: int, n: int) -> np.ndarray: + h = (i + 0.5) / max(n, 1) + s, v = 0.85, 0.95 + r, g, b = _hsv_to_rgb(h, s, v) + return (np.array([r, g, b]) * 255).astype(np.uint8) + + +def _hsv_to_rgb(h: float, s: float, v: float) -> tuple[float, float, float]: + i = int(h * 6.0) + f = h * 6.0 - i + p = v * (1.0 - s) + q = v * (1.0 - f * s) + t = v * (1.0 - (1.0 - f) * s) + i = i % 6 + if i == 0: + r, g, b = v, t, p + elif i == 1: + r, g, b = q, v, p + elif i == 2: + r, g, b = p, v, t + elif i == 3: + r, g, b = p, q, v + elif i == 4: + r, g, b = t, p, v + else: + r, g, b = v, p, q + return r, g, b diff --git 
a/core/models/depth_anything_3/utils/export/gs.py b/core/models/depth_anything_3/utils/export/gs.py new file mode 100644 index 0000000..82aa3d7 --- /dev/null +++ b/core/models/depth_anything_3/utils/export/gs.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Literal, Optional +try: + import moviepy.editor as mpy +except Exception: + mpy = None + +import torch + +from core.models.depth_anything_3.model.utils.gs_renderer import run_renderer_in_chunk_w_trj_mode +from core.models.depth_anything_3.specs import Prediction +from core.models.depth_anything_3.utils.gsply_helpers import save_gaussian_ply +from core.models.depth_anything_3.utils.layout_helpers import hcat, vcat +from core.models.depth_anything_3.utils.visualize import vis_depth_map_tensor + +VIDEO_QUALITY_MAP = { + "low": {"crf": "28", "preset": "veryfast"}, + "medium": {"crf": "23", "preset": "medium"}, + "high": {"crf": "18", "preset": "slow"}, +} + + +def export_to_gs_ply( + prediction: Prediction, + export_dir: str, + gs_views_interval: Optional[ + int + ] = 1, # export GS every N views, useful for extremely dense inputs +): + gs_world = prediction.gaussians + pred_depth = torch.from_numpy(prediction.depth).unsqueeze(-1).to(gs_world.means) # v h w 1 + idx = 0 + os.makedirs(os.path.join(export_dir, "gs_ply"), exist_ok=True) + save_path = os.path.join(export_dir, f"gs_ply/{idx:04d}.ply") + if gs_views_interval is None: # select around 12 views in total + gs_views_interval = max(pred_depth.shape[0] // 12, 1) + save_gaussian_ply( + gaussians=gs_world, + save_path=save_path, + ctx_depth=pred_depth, + shift_and_scale=False, + save_sh_dc_only=True, + gs_views_interval=gs_views_interval, + inv_opacity=True, + prune_by_depth_percent=0.9, + prune_border_gs=True, + match_3dgs_mcmc_dev=False, + ) + + +def export_to_gs_video( + prediction: Prediction, + export_dir: str, + extrinsics: Optional[torch.Tensor] = None, # render views' world2cam, "b v 4 4" + intrinsics: Optional[torch.Tensor] = None, # render views' unnormed intrinsics, "b v 3 3" + out_image_hw: Optional[tuple[int, int]] = None, # render views' resolution, (h, w) + chunk_size: Optional[int] = 4, + trj_mode: Literal[ + "original", + "smooth", + "interpolate", + "interpolate_smooth", + "wander", + "dolly_zoom", + "extend", + "wobble_inter", + ] = "extend", + color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+ED", + vis_depth: Optional[Literal["hcat", "vcat"]] = "hcat", + enable_tqdm: Optional[bool] = True, + output_name: Optional[str] = None, + video_quality: Literal["low", "medium", "high"] = "high", +) -> None: + gs_world = prediction.gaussians + # if target poses are not provided, render the (smooth/interpolate) input poses + if extrinsics is not None: + tgt_extrs = extrinsics + else: + tgt_extrs = torch.from_numpy(prediction.extrinsics).unsqueeze(0).to(gs_world.means) + if prediction.is_metric: + scale_factor = prediction.scale_factor + if scale_factor is not None: + tgt_extrs[:, 
:, :3, 3] /= scale_factor + tgt_intrs = ( + intrinsics + if intrinsics is not None + else torch.from_numpy(prediction.intrinsics).unsqueeze(0).to(gs_world.means) + ) + # if render resolution is not provided, render the input ones + if out_image_hw is not None: + H, W = out_image_hw + else: + H, W = prediction.depth.shape[-2:] + # if single views, render wander trj + if tgt_extrs.shape[1] <= 1: + trj_mode = "wander" + # trj_mode = "dolly_zoom" + + color, depth = run_renderer_in_chunk_w_trj_mode( + gaussians=gs_world, + extrinsics=tgt_extrs, + intrinsics=tgt_intrs, + image_shape=(H, W), + chunk_size=chunk_size, + trj_mode=trj_mode, + use_sh=True, + color_mode=color_mode, + enable_tqdm=enable_tqdm, + ) + + # save as video + ffmpeg_params = [ + "-crf", + VIDEO_QUALITY_MAP[video_quality]["crf"], + "-preset", + VIDEO_QUALITY_MAP[video_quality]["preset"], + "-pix_fmt", + "yuv420p", + ] # best compatibility + + os.makedirs(os.path.join(export_dir, "gs_video"), exist_ok=True) + for idx in range(color.shape[0]): + video_i = color[idx] + if vis_depth is not None: + depth_i = vis_depth_map_tensor(depth[0]) + cat_fn = hcat if vis_depth == "hcat" else vcat + video_i = torch.stack([cat_fn(c, d) for c, d in zip(video_i, depth_i)]) + frames = list( + (video_i.clamp(0, 1) * 255).byte().permute(0, 2, 3, 1).cpu().numpy() + ) # T x H x W x C, uint8, numpy() + + fps = 24 + clip = mpy.ImageSequenceClip(frames, fps=fps) + output_name = f"{idx:04d}_{trj_mode}" if output_name is None else output_name + save_path = os.path.join(export_dir, f"gs_video/{output_name}.mp4") + # clip.write_videofile(save_path, codec="libx264", audio=False, bitrate="4000k") + clip.write_videofile( + save_path, + codec="libx264", + audio=False, + fps=fps, + ffmpeg_params=ffmpeg_params, + ) + return diff --git a/core/models/depth_anything_3/utils/export/npz.py b/core/models/depth_anything_3/utils/export/npz.py new file mode 100644 index 0000000..6b9fb43 --- /dev/null +++ b/core/models/depth_anything_3/utils/export/npz.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
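The Gaussian-splat video writer above maps a quality preset onto x264 `-crf`/`-preset` flags and hands rendered frames to MoviePy. A minimal sketch of that encode step under the same assumptions (MoviePy 1.x with a working ffmpeg; `write_frames` and the quality table are illustrative, not the module's API):

```python
import numpy as np
import moviepy.editor as mpy

QUALITY = {"low": ("28", "veryfast"), "medium": ("23", "medium"), "high": ("18", "slow")}

def write_frames(frames: list[np.ndarray], out_path: str, fps: int = 24,
                 quality: str = "high") -> None:
    """Write uint8 (H, W, 3) RGB frames to H.264 using a crf/preset quality preset."""
    crf, preset = QUALITY[quality]
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(
        out_path, codec="libx264", audio=False, fps=fps,
        ffmpeg_params=["-crf", crf, "-preset", preset, "-pix_fmt", "yuv420p"],
    )
```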
+ +import os +import numpy as np + +from core.models.depth_anything_3.specs import Prediction +from core.models.depth_anything_3.utils.parallel_utils import async_call + + +@async_call +def export_to_npz( + prediction: Prediction, + export_dir: str, +): + output_file = os.path.join(export_dir, "exports", "npz", "results.npz") + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + # Use prediction.processed_images, which is already processed image data + if prediction.processed_images is None: + raise ValueError("prediction.processed_images is required but not available") + + image = prediction.processed_images # (N,H,W,3) uint8 + + # Build save dict with only non-None values + save_dict = { + "image": image, + "depth": np.round(prediction.depth, 6), + } + + if prediction.conf is not None: + save_dict["conf"] = np.round(prediction.conf, 2) + if prediction.extrinsics is not None: + save_dict["extrinsics"] = prediction.extrinsics + if prediction.intrinsics is not None: + save_dict["intrinsics"] = prediction.intrinsics + + # aux = {k: np.round(v, 4) for k, v in prediction.aux.items()} + np.savez_compressed(output_file, **save_dict) + + +@async_call +def export_to_mini_npz( + prediction: Prediction, + export_dir: str, +): + output_file = os.path.join(export_dir, "exports", "mini_npz", "results.npz") + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + # Build save dict with only non-None values + save_dict = { + "depth": np.round(prediction.depth, 8), + } + + if prediction.conf is not None: + save_dict["conf"] = np.round(prediction.conf, 2) + if prediction.extrinsics is not None: + save_dict["extrinsics"] = prediction.extrinsics + if prediction.intrinsics is not None: + save_dict["intrinsics"] = prediction.intrinsics + + np.savez_compressed(output_file, **save_dict) diff --git a/core/models/depth_anything_3/utils/export/utils.py b/core/models/depth_anything_3/utils/export/utils.py new file mode 100644 index 0000000..81f45fb --- /dev/null +++ b/core/models/depth_anything_3/utils/export/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
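The NPZ exporters above write compressed archives whose optional keys depend on which prediction fields were populated. A minimal read-back sketch, assuming the default export location used by `export_to_npz`; the path is hypothetical:

```python
import numpy as np

data = np.load("exports/npz/results.npz")
depth = data["depth"]                                    # (N, H, W), always present
image = data["image"]                                    # (N, H, W, 3) uint8
conf = data["conf"] if "conf" in data.files else None    # optional, rounded to 2 decimals
extrinsics = data["extrinsics"] if "extrinsics" in data.files else None  # (N, 4, 4)
intrinsics = data["intrinsics"] if "intrinsics" in data.files else None  # (N, 3, 3)
```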
+ +import numpy as np +import torch + + +def _denorm_and_to_uint8(image_tensor: torch.Tensor) -> np.ndarray: + """Denormalize to [0,255] and output (N, H, W, 3) uint8.""" + resnet_mean = torch.tensor( + [0.485, 0.456, 0.406], dtype=image_tensor.dtype, device=image_tensor.device + ) + resnet_std = torch.tensor( + [0.229, 0.224, 0.225], dtype=image_tensor.dtype, device=image_tensor.device + ) + img = image_tensor * resnet_std[None, :, None, None] + resnet_mean[None, :, None, None] + img = torch.clamp(img, 0.0, 1.0) + img = (img.permute(0, 2, 3, 1).cpu().numpy() * 255.0).round().astype(np.uint8) # (N,H,W,3) + return img diff --git a/core/models/depth_anything_3/utils/geometry.py b/core/models/depth_anything_3/utils/geometry.py new file mode 100644 index 0000000..41a4219 --- /dev/null +++ b/core/models/depth_anything_3/utils/geometry.py @@ -0,0 +1,498 @@ +# flake8: noqa: F722 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from types import SimpleNamespace +from typing import Optional +import numpy as np +import torch +import torch.nn.functional as F +from einops import einsum + + +def as_homogeneous(ext): + """ + Accept (..., 3,4) or (..., 4,4) extrinsics, return (...,4,4) homogeneous matrix. + Supports torch.Tensor or np.ndarray. + """ + if isinstance(ext, torch.Tensor): + # If already in homogeneous form + if ext.shape[-2:] == (4, 4): + return ext + elif ext.shape[-2:] == (3, 4): + # Create a new homogeneous matrix + ones = torch.zeros_like(ext[..., :1, :4]) + ones[..., 0, 3] = 1.0 + return torch.cat([ext, ones], dim=-2) + else: + raise ValueError(f"Invalid shape for torch.Tensor: {ext.shape}") + + elif isinstance(ext, np.ndarray): + if ext.shape[-2:] == (4, 4): + return ext + elif ext.shape[-2:] == (3, 4): + ones = np.zeros_like(ext[..., :1, :4]) + ones[..., 0, 3] = 1.0 + return np.concatenate([ext, ones], axis=-2) + else: + raise ValueError(f"Invalid shape for np.ndarray: {ext.shape}") + + else: + raise TypeError("Input must be a torch.Tensor or np.ndarray.") + + +@torch.jit.script +def affine_inverse(A: torch.Tensor): + R = A[..., :3, :3] # ..., 3, 3 + T = A[..., :3, 3:] # ..., 3, 1 + P = A[..., 3:, :] # ..., 1, 4 + return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2) + + +def transpose_last_two_axes(arr): + """ + for np < 2 + """ + if arr.ndim < 2: + return arr + axes = list(range(arr.ndim)) + # swap the last two + axes[-2], axes[-1] = axes[-1], axes[-2] + return arr.transpose(axes) + + +def affine_inverse_np(A: np.ndarray): + R = A[..., :3, :3] + T = A[..., :3, 3:] + P = A[..., 3:, :] + return np.concatenate( + [ + np.concatenate([transpose_last_two_axes(R), -transpose_last_two_axes(R) @ T], axis=-1), + P, + ], + axis=-2, + ) + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). 
+ + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape( + batch_dim + (4,) + ) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). 
+ """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def sample_image_grid( + shape: tuple[int, ...], + device: torch.device = torch.device("cpu"), +) -> tuple[ + torch.Tensor, # float coordinates (xy indexing), "*shape dim" + torch.Tensor, # integer indices (ij indexing), "*shape dim" +]: + """Get normalized (range 0 to 1) coordinates and integer indices for an image.""" + + # Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a + # (row, col) coordinate. + indices = [torch.arange(length, device=device) for length in shape] + stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1) + + # Each entry is a floating-point coordinate in the range (0, 1). In the 2D case, + # each entry is an (x, y) coordinate. + coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)] + coordinates = reversed(coordinates) + coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1) + + return coordinates, stacked_indices + + +def homogenize_points(points: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1" + """Convert batched points (xyz) to (xyz1).""" + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +def homogenize_vectors(vectors: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1" + """Convert batched vectors (xyz) to (xyz0).""" + return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1) + + +def transform_rigid( + homogeneous_coordinates: torch.Tensor, # "*#batch dim" + transformation: torch.Tensor, # "*#batch dim dim" +) -> torch.Tensor: # "*batch dim" + """Apply a rigid-body transformation to points or vectors.""" + return einsum( + transformation, + homogeneous_coordinates.to(transformation.dtype), + "... i j, ... j -> ... i", + ) + + +def transform_cam2world( + homogeneous_coordinates: torch.Tensor, # "*#batch dim" + extrinsics: torch.Tensor, # "*#batch dim dim" +) -> torch.Tensor: # "*batch dim" + """Transform points from 3D camera coordinates to 3D world coordinates.""" + return transform_rigid(homogeneous_coordinates, extrinsics) + + +def unproject( + coordinates: torch.Tensor, # "*#batch dim" + z: torch.Tensor, # "*#batch" + intrinsics: torch.Tensor, # "*#batch dim+1 dim+1" +) -> torch.Tensor: # "*batch dim+1" + """Unproject 2D camera coordinates with the given Z values.""" + + # Apply the inverse intrinsics to the coordinates. + coordinates = homogenize_points(coordinates) + ray_directions = einsum( + intrinsics.float().inverse().to(intrinsics), + coordinates.to(intrinsics.dtype), + "... i j, ... j -> ... i", + ) + + # Apply the supplied depth values. + return ray_directions * z[..., None] + + +def get_world_rays( + coordinates: torch.Tensor, # "*#batch dim" + extrinsics: torch.Tensor, # "*#batch dim+2 dim+2" + intrinsics: torch.Tensor, # "*#batch dim+1 dim+1" +) -> tuple[ + torch.Tensor, # origins, "*batch dim+1" + torch.Tensor, # directions, "*batch dim+1" +]: + # Get camera-space ray directions. + directions = unproject( + coordinates, + torch.ones_like(coordinates[..., 0]), + intrinsics, + ) + directions = directions / directions.norm(dim=-1, keepdim=True) + + # Transform ray directions to world coordinates. + directions = homogenize_vectors(directions) + directions = transform_cam2world(directions, extrinsics)[..., :-1] + + # Tile the ray origins to have the same shape as the ray directions. 
+ origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape) + + return origins, directions + + +def get_fov(intrinsics: torch.Tensor) -> torch.Tensor: # "batch 3 3" -> "batch 2" + intrinsics_inv = intrinsics.float().inverse().to(intrinsics) + + def process_vector(vector): + vector = torch.tensor(vector, dtype=intrinsics.dtype, device=intrinsics.device) + vector = einsum(intrinsics_inv, vector, "b i j, j -> b i") + return vector / vector.norm(dim=-1, keepdim=True) + + left = process_vector([0, 0.5, 1]) + right = process_vector([1, 0.5, 1]) + top = process_vector([0.5, 0, 1]) + bottom = process_vector([0.5, 1, 1]) + fov_x = (left * right).sum(dim=-1).acos() + fov_y = (top * bottom).sum(dim=-1).acos() + return torch.stack((fov_x, fov_y), dim=-1) + + +def map_pdf_to_opacity( + pdf: torch.Tensor, # " *batch" + global_step: int = 0, + opacity_mapping: Optional[dict] = None, +) -> torch.Tensor: # " *batch" + # https://www.desmos.com/calculator/opvwti3ba9 + + # Figure out the exponent. + if opacity_mapping is not None: + cfg = SimpleNamespace(**opacity_mapping) + x = cfg.initial + min(global_step / cfg.warm_up, 1) * (cfg.final - cfg.initial) + else: + x = 0.0 + exponent = 2**x + + # Map the probability density to an opacity. + return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent)) + +def normalize_homogenous_points(points): + """Normalize the point vectors""" + return points / points[..., -1:] + +def inverse_intrinsic_matrix(ixts): + """ """ + return torch.inverse(ixts) + +def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics): + """ + Convert pixel space points to camera space points. + + Args: + pixel_space_points (torch.Tensor): Pixel space points with shape (h, w, 2) + depth (torch.Tensor): Depth map with shape (b, v, h, w, 1) + intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3) + + Returns: + torch.Tensor: Camera space points with shape (b, v, h, w, 3). + """ + pixel_space_points = homogenize_points(pixel_space_points) + # camera_space_points = torch.einsum( + # "b v i j , h w j -> b v h w i", intrinsics.inverse(), pixel_space_points + # ) + camera_space_points = torch.einsum( + "b v i j , h w j -> b v h w i", inverse_intrinsic_matrix(intrinsics), pixel_space_points + ) + camera_space_points = camera_space_points * depth + return camera_space_points + + +def camera_space_to_world_space(camera_space_points, c2w): + """ + Convert camera space points to world space points. + + Args: + camera_space_points (torch.Tensor): Camera space points with shape (b, v, h, w, 3) + c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v, 4, 4) + + Returns: + torch.Tensor: World space points with shape (b, v, h, w, 3). + """ + camera_space_points = homogenize_points(camera_space_points) + world_space_points = torch.einsum("b v i j , b v h w j -> b v h w i", c2w, camera_space_points) + return world_space_points[..., :3] + + +def camera_space_to_pixel_space(camera_space_points, intrinsics): + """ + Convert camera space points to pixel space points. + + Args: + camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3) + c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 3, 3) + + Returns: + torch.Tensor: World space points with shape (b, v1, v2, h, w, 2). 
+ """ + camera_space_points = normalize_homogenous_points(camera_space_points) + pixel_space_points = torch.einsum( + "b u i j , b v u h w j -> b v u h w i", intrinsics, camera_space_points + ) + return pixel_space_points[..., :2] + + +def world_space_to_camera_space(world_space_points, c2w): + """ + Convert world space points to pixel space points. + + Args: + world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3) + c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4) + + Returns: + torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3). + """ + world_space_points = homogenize_points(world_space_points) + camera_space_points = torch.einsum( + "b u i j , b v h w j -> b v u h w i", c2w.inverse(), world_space_points + ) + return camera_space_points[..., :3] + + +def unproject_depth( + depth, intrinsics, c2w=None, ixt_normalized=False, num_patches_x=None, num_patches_y=None +): + """ + Turn the depth map into a 3D point cloud in world space + + Args: + depth: (b, v, h, w, 1) + intrinsics: (b, v, 3, 3) + c2w: (b, v, 4, 4) + + Returns: + torch.Tensor: World space points with shape (b, v, h, w, 3). + """ + if c2w is None: + c2w = torch.eye(4, device=depth.device, dtype=depth.dtype) + c2w = c2w[None, None].repeat(depth.shape[0], depth.shape[1], 1, 1) + + if not ixt_normalized: + # Compute indices of pixels + h, w = depth.shape[-3], depth.shape[-2] + x_grid, y_grid = torch.meshgrid( + torch.arange(w, device=depth.device, dtype=depth.dtype), + torch.arange(h, device=depth.device, dtype=depth.dtype), + indexing="xy", + ) # (h, w), (h, w) + else: + # ixt_normalized: h=w=2.0. cx, cy, fx, fy are normalized according to h=w=2.0 + assert num_patches_x is not None and num_patches_y is not None + dx = 1 / num_patches_x + dy = 1 / num_patches_y + max_y = 1 - dy + min_y = -max_y + max_x = 1 - dx + min_x = -max_x + + grid_shift = 1.0 + y_grid, x_grid = torch.meshgrid( + torch.linspace( + min_y + grid_shift, + max_y + grid_shift, + num_patches_y, + dtype=torch.float32, + device=depth.device, + ), + torch.linspace( + min_x + grid_shift, + max_x + grid_shift, + num_patches_x, + dtype=torch.float32, + device=depth.device, + ), + indexing="ij", + ) + + # Compute coordinates of pixels in camera space + pixel_space_points = torch.stack((x_grid, y_grid), dim=-1) # (..., h, w, 2) + camera_points = pixel_space_to_camera_space( + pixel_space_points, depth, intrinsics + ) # (..., h, w, 3) + + # Convert points to world space + world_points = camera_space_to_world_space(camera_points, c2w) # (..., h, w, 3) + + return world_points \ No newline at end of file diff --git a/core/models/depth_anything_3/utils/gsply_helpers.py b/core/models/depth_anything_3/utils/gsply_helpers.py new file mode 100644 index 0000000..0560a35 --- /dev/null +++ b/core/models/depth_anything_3/utils/gsply_helpers.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
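The geometry helpers above back-project depth maps through `K^{-1}` and a camera-to-world transform, with batch and view dimensions carried along. The same math for a single view, with the batch dimensions dropped, looks roughly like this (`unproject_single` is a hypothetical helper, not part of the module):

```python
import torch

def unproject_single(depth: torch.Tensor, K: torch.Tensor, c2w: torch.Tensor) -> torch.Tensor:
    """Lift one (H, W) depth map to world-space points of shape (H, W, 3)."""
    h, w = depth.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=depth.dtype), torch.arange(w, dtype=depth.dtype), indexing="ij"
    )
    pix = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1)  # (H, W, 3) homogeneous pixels
    # K^{-1} @ pixel gives a ray; scaling by depth lands in the camera frame.
    cam = torch.einsum("ij,hwj->hwi", torch.inverse(K), pix) * depth[..., None]
    cam_h = torch.cat([cam, torch.ones_like(cam[..., :1])], dim=-1)  # (H, W, 4)
    world = torch.einsum("ij,hwj->hwi", c2w, cam_h)[..., :3]         # world frame
    return world
```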
+from pathlib import Path +from typing import Optional +import numpy as np +import torch +from einops import rearrange, repeat +try: + from plyfile import PlyData, PlyElement +except Exception: + PlyData = PlyElement = None + +from torch import Tensor + +from core.models.depth_anything_3.specs import Gaussians + + +def construct_list_of_attributes(num_rest: int) -> list[str]: + attributes = ["x", "y", "z", "nx", "ny", "nz"] + for i in range(3): + attributes.append(f"f_dc_{i}") + for i in range(num_rest): + attributes.append(f"f_rest_{i}") + attributes.append("opacity") + for i in range(3): + attributes.append(f"scale_{i}") + for i in range(4): + attributes.append(f"rot_{i}") + return attributes + + +def export_ply( + means: Tensor, # "gaussian 3" + scales: Tensor, # "gaussian 3" + rotations: Tensor, # "gaussian 4" + harmonics: Tensor, # "gaussian 3 d_sh" + opacities: Tensor, # "gaussian" + path: Path, + shift_and_scale: bool = False, + save_sh_dc_only: bool = True, + match_3dgs_mcmc_dev: Optional[bool] = False, +): + if shift_and_scale: + # Shift the scene so that the median Gaussian is at the origin. + means = means - means.median(dim=0).values + + # Rescale the scene so that most Gaussians are within range [-1, 1]. + scale_factor = means.abs().quantile(0.95, dim=0).max() + means = means / scale_factor + scales = scales / scale_factor + + rotations = rotations.detach().cpu().numpy() + + # Since current model use SH_degree = 4, + # which require large memory to store, we can only save the DC band to save memory. + f_dc = harmonics[..., 0] + f_rest = harmonics[..., 1:].flatten(start_dim=1) + + if match_3dgs_mcmc_dev: + sh_degree = 3 + n_rest = 3 * (sh_degree + 1) ** 2 - 3 + f_rest = repeat( + torch.zeros_like(harmonics[..., :1]), "... i -> ... (n i)", n=(n_rest // 3) + ).flatten(start_dim=1) + dtype_full = [ + (attribute, "f4") + for attribute in construct_list_of_attributes(num_rest=n_rest) + if attribute not in ("nx", "ny", "nz") + ] + else: + dtype_full = [ + (attribute, "f4") + for attribute in construct_list_of_attributes( + 0 if save_sh_dc_only else f_rest.shape[1] + ) + ] + elements = np.empty(means.shape[0], dtype=dtype_full) + attributes = [ + means.detach().cpu().numpy(), + torch.zeros_like(means).detach().cpu().numpy(), + f_dc.detach().cpu().contiguous().numpy(), + f_rest.detach().cpu().contiguous().numpy(), + opacities[..., None].detach().cpu().numpy(), + scales.log().detach().cpu().numpy(), + rotations, + ] + if match_3dgs_mcmc_dev: + attributes.pop(1) # dummy normal is not needed + elif save_sh_dc_only: + attributes.pop(3) # remove f_rest from attributes + + attributes = np.concatenate(attributes, axis=1) + elements[:] = list(map(tuple, attributes)) + path.parent.mkdir(exist_ok=True, parents=True) + PlyData([PlyElement.describe(elements, "vertex")]).write(path) + + +def inverse_sigmoid(x): + return torch.log(x / (1 - x)) + + +def save_gaussian_ply( + gaussians: Gaussians, + save_path: str, + ctx_depth: torch.Tensor, # depth of input views; for getting shape and filtering, "v h w 1" + shift_and_scale: bool = False, + save_sh_dc_only: bool = True, + gs_views_interval: int = 1, + inv_opacity: Optional[bool] = True, + prune_by_depth_percent: Optional[float] = 1.0, + prune_border_gs: Optional[bool] = True, + match_3dgs_mcmc_dev: Optional[bool] = False, +): + b = gaussians.means.shape[0] + assert b == 1, "must set batch_size=1 when exporting 3D gaussians" + src_v, out_h, out_w, _ = ctx_depth.shape + + # extract gs params + world_means = gaussians.means + world_shs = 
gaussians.harmonics
+    world_rotations = gaussians.rotations
+    gs_scales = gaussians.scales
+    gs_opacities = inverse_sigmoid(gaussians.opacities) if inv_opacity else gaussians.opacities
+
+    # Create a mask to filter the Gaussians.
+
+    # TODO: prune the sky region here
+
+    # Throw away Gaussians at the borders, since they're generally of lower quality.
+    if prune_border_gs:
+        mask = torch.zeros_like(ctx_depth, dtype=torch.bool)
+        gstrim_h = int(8 / 256 * out_h)
+        gstrim_w = int(8 / 256 * out_w)
+        mask[:, gstrim_h:-gstrim_h, gstrim_w:-gstrim_w, :] = 1
+    else:
+        mask = torch.ones_like(ctx_depth, dtype=torch.bool)
+
+    # Trim far-away points based on depth.
+    if prune_by_depth_percent is not None and prune_by_depth_percent < 1:
+        in_depths = ctx_depth
+        d_percentile = torch.quantile(
+            in_depths.view(in_depths.shape[0], -1), q=prune_by_depth_percent, dim=1
+        ).view(-1, 1, 1)
+        d_mask = (in_depths[..., 0] <= d_percentile).unsqueeze(-1)
+        mask = mask & d_mask
+    mask = mask.squeeze(-1)  # v h w
+
+    # Helper fn; must be defined after the mask is built.
+    def trim_select_reshape(element):
+        selected_element = rearrange(
+            element[0], "(v h w) ... -> v h w ...", v=src_v, h=out_h, w=out_w
+        )
+        selected_element = selected_element[::gs_views_interval][mask[::gs_views_interval]]
+        return selected_element
+
+    export_ply(
+        means=trim_select_reshape(world_means),
+        scales=trim_select_reshape(gs_scales),
+        rotations=trim_select_reshape(world_rotations),
+        harmonics=trim_select_reshape(world_shs),
+        opacities=trim_select_reshape(gs_opacities),
+        path=Path(save_path),
+        shift_and_scale=shift_and_scale,
+        save_sh_dc_only=save_sh_dc_only,
+        match_3dgs_mcmc_dev=match_3dgs_mcmc_dev,
+    )
diff --git a/core/models/depth_anything_3/utils/io/__pycache__/input_processor.cpython-313.pyc b/core/models/depth_anything_3/utils/io/__pycache__/input_processor.cpython-313.pyc
new file mode 100644
index 0000000..5401722
Binary files /dev/null and b/core/models/depth_anything_3/utils/io/__pycache__/input_processor.cpython-313.pyc differ
diff --git a/core/models/depth_anything_3/utils/io/__pycache__/output_processor.cpython-313.pyc b/core/models/depth_anything_3/utils/io/__pycache__/output_processor.cpython-313.pyc
new file mode 100644
index 0000000..3013669
Binary files /dev/null and b/core/models/depth_anything_3/utils/io/__pycache__/output_processor.cpython-313.pyc differ
diff --git a/core/models/depth_anything_3/utils/io/input_processor.py b/core/models/depth_anything_3/utils/io/input_processor.py
new file mode 100644
index 0000000..725329c
--- /dev/null
+++ b/core/models/depth_anything_3/utils/io/input_processor.py
@@ -0,0 +1,501 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Input processor for Depth Anything 3 (parallelized).
+
+This version removes the square center-crop step for "*crop" methods.
+In addition, it parallelizes per-image preprocessing using the provided `parallel_execution`.
+""" + +from __future__ import annotations + +from typing import Sequence +import cv2 +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image + +from core.models.depth_anything_3.utils.logger import logger +from core.models.depth_anything_3.utils.parallel_utils import parallel_execution + + +class InputProcessor: + """Prepares a batch of images for model inference. + This processor converts a list of image file paths into a single, model-ready + tensor. The processing pipeline is executed in parallel across multiple workers + for efficiency. + + Pipeline: + 1) Load image and convert to RGB + 2) Boundary resize (upper/lower bound, preserving aspect ratio) + 3) Enforce divisibility by PATCH_SIZE: + - "*resize" methods: each dimension is rounded to nearest multiple + (may up/downscale a few px) + - "*crop" methods: each dimension is floored to nearest multiple via center crop + 4) Convert to tensor and apply ImageNet normalization + 5) Stack into (1, N, 3, H, W) + + Parallelization: + - Each image is processed independently in a worker. + - Order of outputs matches the input order. + """ + + NORMALIZE = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + PATCH_SIZE = 14 + + def __init__(self): + pass + + # ----------------------------- + # Public API + # ----------------------------- + def __call__( + self, + image: list[np.ndarray | Image.Image | str], + extrinsics: np.ndarray | None = None, + intrinsics: np.ndarray | None = None, + process_res: int = 504, + process_res_method: str = "upper_bound_resize", + *, + num_workers: int = 8, + print_progress: bool = False, + sequential: bool | None = None, + desc: str | None = "Preprocess", + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """ + Returns: + (tensor, extrinsics_list, intrinsics_list) + tensor shape: (1, N, 3, H, W) + """ + sequential = self._resolve_sequential(sequential, num_workers) + exts_list, ixts_list = self._validate_and_pack_meta(image, extrinsics, intrinsics) + + results = self._run_parallel( + image=image, + exts_list=exts_list, + ixts_list=ixts_list, + process_res=process_res, + process_res_method=process_res_method, + num_workers=num_workers, + print_progress=print_progress, + sequential=sequential, + desc=desc, + ) + + proc_imgs, out_sizes, out_ixts, out_exts = self._unpack_results(results) + proc_imgs, out_sizes, out_ixts = self._unify_batch_shapes(proc_imgs, out_sizes, out_ixts) + + batch_tensor = self._stack_batch(proc_imgs) + out_exts = ( + torch.from_numpy(np.asarray(out_exts)).float() + if out_exts is not None and out_exts[0] is not None + else None + ) + out_ixts = ( + torch.from_numpy(np.asarray(out_ixts)).float() + if out_ixts is not None and out_ixts[0] is not None + else None + ) + return (batch_tensor, out_exts, out_ixts) + + # ----------------------------- + # __call__ helpers + # ----------------------------- + def _resolve_sequential(self, sequential: bool | None, num_workers: int) -> bool: + return (num_workers <= 1) if sequential is None else sequential + + def _validate_and_pack_meta( + self, + images: list[np.ndarray | Image.Image | str], + extrinsics: np.ndarray | None, + intrinsics: np.ndarray | None, + ) -> tuple[list[np.ndarray | None] | None, list[np.ndarray | None] | None]: + if extrinsics is not None and len(extrinsics) != len(images): + raise ValueError("Length of extrinsics must match images when provided.") + if intrinsics is not None and len(intrinsics) != len(images): + raise ValueError("Length of intrinsics must 
match images when provided.") + exts_list = [e for e in extrinsics] if extrinsics is not None else None + ixts_list = [k for k in intrinsics] if intrinsics is not None else None + return exts_list, ixts_list + + def _run_parallel( + self, + *, + image: list[np.ndarray | Image.Image | str], + exts_list: list[np.ndarray | None] | None, + ixts_list: list[np.ndarray | None] | None, + process_res: int, + process_res_method: str, + num_workers: int, + print_progress: bool, + sequential: bool, + desc: str | None, + ): + results = parallel_execution( + image, + exts_list, + ixts_list, + action=self._process_one, # (img, extrinsic, intrinsic, ...) + num_processes=num_workers, + print_progress=print_progress, + sequential=sequential, + desc=desc, + process_res=process_res, + process_res_method=process_res_method, + ) + if not results: + raise RuntimeError( + "No preprocessing results returned. Check inputs and parallel_execution." + ) + return results + + def _unpack_results(self, results): + """ + results: List[Tuple[torch.Tensor, Tuple[H, W], Optional[np.ndarray], Optional[np.ndarray]]] + -> processed_images, out_sizes, out_intrinsics, out_extrinsics + """ + try: + processed_images, out_sizes, out_intrinsics, out_extrinsics = zip(*results) + except Exception as e: + raise RuntimeError( + "Unexpected results structure from parallel_execution: " + f"{type(results)} / sample: {results[0]}" + ) from e + + return list(processed_images), list(out_sizes), list(out_intrinsics), list(out_extrinsics) + + def _unify_batch_shapes( + self, + processed_images: list[torch.Tensor], + out_sizes: list[tuple[int, int]], + out_intrinsics: list[np.ndarray | None], + ) -> tuple[list[torch.Tensor], list[tuple[int, int]], list[np.ndarray | None]]: + """Center-crop all tensors to the smallest H, W; adjust intrinsics' cx, cy accordingly.""" + if len(set(out_sizes)) <= 1: + return processed_images, out_sizes, out_intrinsics + + min_h = min(h for h, _ in out_sizes) + min_w = min(w for _, w in out_sizes) + logger.warn( + f"Images in batch have different sizes {out_sizes}; " + f"center-cropping all to smallest ({min_h},{min_w})" + ) + + center_crop = T.CenterCrop((min_h, min_w)) + new_imgs, new_sizes, new_ixts = [], [], [] + for img_t, (H, W), K in zip(processed_images, out_sizes, out_intrinsics): + crop_top = max(0, (H - min_h) // 2) + crop_left = max(0, (W - min_w) // 2) + new_imgs.append(center_crop(img_t)) + new_sizes.append((min_h, min_w)) + if K is None: + new_ixts.append(None) + else: + K_adj = K.copy() + K_adj[0, 2] -= crop_left + K_adj[1, 2] -= crop_top + new_ixts.append(K_adj) + return new_imgs, new_sizes, new_ixts + + def _stack_batch(self, processed_images: list[torch.Tensor]) -> torch.Tensor: + return torch.stack(processed_images) + + # ----------------------------- + # Per-item worker + # ----------------------------- + def _process_one( + self, + img: np.ndarray | Image.Image | str, + extrinsic: np.ndarray | None = None, + intrinsic: np.ndarray | None = None, + *, + process_res: int, + process_res_method: str, + ) -> tuple[torch.Tensor, tuple[int, int], np.ndarray | None, np.ndarray | None]: + # Load & remember original size + pil_img = self._load_image(img) + orig_w, orig_h = pil_img.size + + # Boundary resize + pil_img = self._resize_image(pil_img, process_res, process_res_method) + w, h = pil_img.size + intrinsic = self._resize_ixt(intrinsic, orig_w, orig_h, w, h) + + # Enforce divisibility by PATCH_SIZE + if process_res_method.endswith("resize"): + pil_img = self._make_divisible_by_resize(pil_img, 
self.PATCH_SIZE) + new_w, new_h = pil_img.size + intrinsic = self._resize_ixt(intrinsic, w, h, new_w, new_h) + w, h = new_w, new_h + elif process_res_method.endswith("crop"): + pil_img = self._make_divisible_by_crop(pil_img, self.PATCH_SIZE) + new_w, new_h = pil_img.size + intrinsic = self._crop_ixt(intrinsic, w, h, new_w, new_h) + w, h = new_w, new_h + else: + raise ValueError(f"Unsupported process_res_method: {process_res_method}") + + # Convert to tensor & normalize + img_tensor = self._normalize_image(pil_img) + _, H, W = img_tensor.shape + assert (W, H) == (w, h), "Tensor size mismatch with PIL image size after processing." + + # Return: (img_tensor, (H, W), intrinsic, extrinsic) + return img_tensor, (H, W), intrinsic, extrinsic + + # ----------------------------- + # Intrinsics transforms + # ----------------------------- + def _resize_ixt( + self, + intrinsic: np.ndarray | None, + orig_w: int, + orig_h: int, + w: int, + h: int, + ) -> np.ndarray | None: + if intrinsic is None: + return None + K = intrinsic.copy() + # scale fx, cx by w ratio; fy, cy by h ratio + K[:1] *= w / float(orig_w) + K[1:2] *= h / float(orig_h) + return K + + def _crop_ixt( + self, + intrinsic: np.ndarray | None, + orig_w: int, + orig_h: int, + w: int, + h: int, + ) -> np.ndarray | None: + if intrinsic is None: + return None + K = intrinsic.copy() + crop_h = (orig_h - h) // 2 + crop_w = (orig_w - w) // 2 + K[0, 2] -= crop_w + K[1, 2] -= crop_h + return K + + # ----------------------------- + # I/O & normalization + # ----------------------------- + def _load_image(self, img: np.ndarray | Image.Image | str) -> Image.Image: + if isinstance(img, str): + return Image.open(img).convert("RGB") + elif isinstance(img, np.ndarray): + # Assume HxWxC uint8/RGB + return Image.fromarray(img).convert("RGB") + elif isinstance(img, Image.Image): + return img.convert("RGB") + else: + raise ValueError(f"Unsupported image type: {type(img)}") + + def _normalize_image(self, img: Image.Image) -> torch.Tensor: + img_tensor = T.ToTensor()(img) + return self.NORMALIZE(img_tensor) + + # ----------------------------- + # Boundary resizing + # ----------------------------- + def _resize_image(self, img: Image.Image, target_size: int, method: str) -> Image.Image: + if method in ("upper_bound_resize", "upper_bound_crop"): + return self._resize_longest_side(img, target_size) + elif method in ("lower_bound_resize", "lower_bound_crop"): + return self._resize_shortest_side(img, target_size) + else: + raise ValueError(f"Unsupported resize method: {method}") + + def _resize_longest_side(self, img: Image.Image, target_size: int) -> Image.Image: + w, h = img.size + longest = max(w, h) + if longest == target_size: + return img + scale = target_size / float(longest) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA + arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation) + return Image.fromarray(arr) + + def _resize_shortest_side(self, img: Image.Image, target_size: int) -> Image.Image: + w, h = img.size + shortest = min(w, h) + if shortest == target_size: + return img + scale = target_size / float(shortest) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA + arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation) + return Image.fromarray(arr) + + # ----------------------------- + # Make divisible by 
PATCH_SIZE + # ----------------------------- + def _make_divisible_by_crop(self, img: Image.Image, patch: int) -> Image.Image: + """ + Floor each dimension to the nearest multiple of PATCH_SIZE via center crop. + Example: 504x377 -> 504x364 + """ + w, h = img.size + new_w = (w // patch) * patch + new_h = (h // patch) * patch + if new_w == w and new_h == h: + return img + left = (w - new_w) // 2 + top = (h - new_h) // 2 + return img.crop((left, top, left + new_w, top + new_h)) + + def _make_divisible_by_resize(self, img: Image.Image, patch: int) -> Image.Image: + """ + Round each dimension to nearest multiple of PATCH_SIZE via small resize. + """ + w, h = img.size + + def nearest_multiple(x: int, p: int) -> int: + down = (x // p) * p + up = down + p + return up if abs(up - x) <= abs(x - down) else down + + new_w = max(1, nearest_multiple(w, patch)) + new_h = max(1, nearest_multiple(h, patch)) + if new_w == w and new_h == h: + return img + upscale = (new_w > w) or (new_h > h) + interpolation = cv2.INTER_CUBIC if upscale else cv2.INTER_AREA + arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation) + return Image.fromarray(arr) + + +# Backward compatibility alias +InputAdapter = InputProcessor + + +# =========================== +# Minimal test runner (parallel execution) +# =========================== +if __name__ == "__main__": + """ + Minimal test suite: + - Creates pairs of images so batch shapes match. + - Tests all four process_res_methods. + - Prints fx fy cx cy IN->OUT per image. + - Includes cases with K/E provided and with None. + """ + + def fmt_k_line(K: np.ndarray | None) -> str: + if K is None: + return "None" + fx, fy, cx, cy = float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2]) + return f"fx={fx:.3f} fy={fy:.3f} cx={cx:.3f} cy={cy:.3f}" + + def show_result( + tag: str, + tensor: torch.Tensor, + Ks_in: Sequence[np.ndarray | None] | None = None, + Ks_out: Sequence[np.ndarray | None] | None = None, + ): + B, N, C, H, W = tensor.shape + print(f"[{tag}] shape={tuple(tensor.shape)} HxW=({H},{W}) div14=({H%14==0},{W%14==0})") + assert H % 14 == 0 and W % 14 == 0, f"{tag}: output size not divisible by 14!" 
+ if Ks_in is not None or Ks_out is not None: + Ks_in = Ks_in or [None] * N + Ks_out = Ks_out or [None] * N + for i in range(N): + print(f" K[{i}]: {fmt_k_line(Ks_in[i])} -> {fmt_k_line(Ks_out[i])}") + + proc = InputProcessor() + process_res = 504 + methods = ["upper_bound_resize", "upper_bound_crop", "lower_bound_resize", "lower_bound_crop"] + + # Example sizes (two orientations) + small_sizes = [(680, 1208), (1208, 680)] + large_sizes = [(1208, 680), (680, 1208)] + + def make_K(w, h, fx=1200.0, fy=1100.0): + cx, cy = w / 2.0, h / 2.0 + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + return K + + def run_suite(suite_name: str, sizes: list[tuple[int, int]]): + print(f"\n===== {suite_name} =====") + for w, h in sizes: + img = Image.new("RGB", (w, h), color=(123, 222, 100)) + batch_imgs = [img, img] + + # intrinsics / extrinsics examples + Ks_in = [make_K(w, h), make_K(w, h)] + Es_in = [np.eye(4, dtype=np.float32), np.eye(4, dtype=np.float32)] + + for m in methods: + tensor, Es_out, Ks_out = proc( + image=batch_imgs, + process_res=process_res, + process_res_method=m, + num_workers=8, + print_progress=False, + intrinsics=Ks_in, # test with non-None + extrinsics=Es_in, + ) + show_result(f"{suite_name} size=({w},{h}) | {m}", tensor, Ks_in, Ks_out) + + # Also test None path + tensor2, Es_out2, Ks_out2 = proc( + image=batch_imgs, + process_res=process_res, + process_res_method="upper_bound_resize", + num_workers=8, + intrinsics=None, + extrinsics=None, + ) + show_result( + f"{suite_name} size=({w},{h}) | upper_bound_resize | no K/E", + tensor2, + None, + Ks_out2, + ) + + run_suite("SMALL", small_sizes) + run_suite("LARGE", large_sizes) + + # Extra sanity for 504x376 + print("\n===== EXTRA sanity for 504x376 =====") + img_example = Image.new("RGB", (504, 376), color=(10, 20, 30)) + Ks_in_extra = [make_K(504, 376, fx=900.0, fy=900.0), make_K(504, 376, fx=900.0, fy=900.0)] + + out_r, _, Ks_out_r = proc( + image=[img_example, img_example], + process_res=504, + process_res_method="upper_bound_resize", + num_workers=8, + intrinsics=Ks_in_extra, + ) + out_c, _, Ks_out_c = proc( + image=[img_example, img_example], + process_res=504, + process_res_method="upper_bound_crop", + num_workers=8, + intrinsics=Ks_in_extra, + ) + _, _, _, Hr, Wr = out_r.shape + _, _, _, Hc, Wc = out_c.shape + print(f"upper_bound_resize -> ({Hr},{Wr}) (rounded to nearest multiple of 14)") + show_result("Ks after upper_bound_resize", out_r, Ks_in_extra, Ks_out_r) + print(f"upper_bound_crop -> ({Hc},{Wc}) (floored to multiple of 14)") + show_result("Ks after upper_bound_crop", out_c, Ks_in_extra, Ks_out_c) diff --git a/core/models/depth_anything_3/utils/io/output_processor.py b/core/models/depth_anything_3/utils/io/output_processor.py new file mode 100644 index 0000000..4586479 --- /dev/null +++ b/core/models/depth_anything_3/utils/io/output_processor.py @@ -0,0 +1,172 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
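The test runner above exercises `InputProcessor` end to end across all four resize/crop methods. For reference, a minimal usage sketch, assuming the module imports resolve as in this diff; the frame paths are hypothetical:

```python
from PIL import Image

from core.models.depth_anything_3.utils.io.input_processor import InputProcessor

proc = InputProcessor()
imgs = [Image.open(p).convert("RGB") for p in ["frame_000.jpg", "frame_001.jpg"]]

# Longest side is resized toward 504, then each dimension is rounded
# to a multiple of PATCH_SIZE (14).
tensor, exts, ixts = proc(
    image=imgs,
    process_res=504,
    process_res_method="upper_bound_resize",
    num_workers=4,
)
print(tensor.shape)  # stacked, ImageNet-normalized batch; H and W divisible by 14
```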
+
+"""
+Output processor for Depth Anything 3.
+
+This module handles model output processing, including tensor-to-numpy conversion,
+batch dimension removal, and Prediction object creation.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from addict import Dict as AddictDict
+
+from core.models.depth_anything_3.specs import Prediction
+
+
+class OutputProcessor:
+    """
+    Output processor for converting model outputs to Prediction objects.
+
+    Handles tensor-to-numpy conversion, batch dimension removal,
+    and creates structured Prediction objects with proper data types.
+    """
+
+    def __init__(self) -> None:
+        """Initialize the output processor."""
+
+    def __call__(self, model_output: dict[str, torch.Tensor]) -> Prediction:
+        """
+        Convert model output to Prediction object.
+
+        Args:
+            model_output: Model output dictionary containing depth, conf, extrinsics, intrinsics
+                Expected shapes: depth (B, N, H, W, 1), conf (B, N, H, W),
+                extrinsics (B, N, 4, 4), intrinsics (B, N, 3, 3)
+
+        Returns:
+            Prediction: Object containing depth estimation results with shapes:
+                depth (N, H, W), conf (N, H, W), extrinsics (N, 4, 4), intrinsics (N, 3, 3)
+        """
+        # Extract data from batch dimension (B=1, N=number of images)
+        depth = self._extract_depth(model_output)
+        conf = self._extract_conf(model_output)
+        extrinsics = self._extract_extrinsics(model_output)
+        intrinsics = self._extract_intrinsics(model_output)
+        sky = self._extract_sky(model_output)
+        aux = self._extract_aux(model_output)
+        gaussians = model_output.get("gaussians", None)
+        scale_factor = model_output.get("scale_factor", None)
+
+        return Prediction(
+            depth=depth,
+            sky=sky,
+            conf=conf,
+            extrinsics=extrinsics,
+            intrinsics=intrinsics,
+            is_metric=model_output.get("is_metric", 0),
+            gaussians=gaussians,
+            aux=aux,
+            scale_factor=scale_factor,
+        )
+
+    def _extract_depth(self, model_output: dict[str, torch.Tensor]) -> np.ndarray:
+        """
+        Extract depth tensor from model output and convert to numpy.
+
+        Args:
+            model_output: Model output dictionary
+
+        Returns:
+            Depth array with shape (N, H, W)
+        """
+        depth = model_output["depth"].squeeze(0).squeeze(-1).cpu().numpy()  # (N, H, W)
+        return depth
+
+    def _extract_conf(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+        """
+        Extract confidence tensor from model output and convert to numpy.
+
+        Args:
+            model_output: Model output dictionary
+
+        Returns:
+            Confidence array with shape (N, H, W) or None
+        """
+        conf = model_output.get("depth_conf", None)
+        if conf is not None:
+            conf = conf.squeeze(0).cpu().numpy()  # (N, H, W)
+        return conf
+
+    def _extract_extrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+        """
+        Extract extrinsics tensor from model output and convert to numpy.
+
+        Args:
+            model_output: Model output dictionary
+
+        Returns:
+            Extrinsics array with shape (N, 4, 4) or None
+        """
+        extrinsics = model_output.get("extrinsics", None)
+        if extrinsics is not None:
+            extrinsics = extrinsics.squeeze(0).cpu().numpy()  # (N, 4, 4)
+        return extrinsics
+
+    def _extract_intrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+        """
+        Extract intrinsics tensor from model output and convert to numpy.
+
+        Args:
+            model_output: Model output dictionary
+
+        Returns:
+            Intrinsics array with shape (N, 3, 3) or None
+        """
+        intrinsics = model_output.get("intrinsics", None)
+        if intrinsics is not None:
+            intrinsics = intrinsics.squeeze(0).cpu().numpy()  # (N, 3, 3)
+        return intrinsics
+
+    def _extract_sky(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+        """
+        Extract sky tensor from model output and convert to numpy.
+
+        Args:
+            model_output: Model output dictionary
+
+        Returns:
+            Sky mask array with shape (N, H, W) or None
+        """
+        sky = model_output.get("sky", None)
+        if sky is not None:
+            sky = sky.squeeze(0).cpu().numpy() >= 0.5  # (N, H, W)
+        return sky
+
+    def _extract_aux(self, model_output: dict[str, torch.Tensor]) -> AddictDict:
+        """
+        Extract auxiliary data from model output and convert to numpy.
+
+        Args:
+            model_output: Model output dictionary
+
+        Returns:
+            Dictionary containing auxiliary data
+        """
+        aux = model_output.get("aux", None)
+        ret = AddictDict()
+        if aux is not None:
+            for k in aux.keys():
+                if isinstance(aux[k], torch.Tensor):
+                    ret[k] = aux[k].squeeze(0).cpu().numpy()
+                else:
+                    ret[k] = aux[k]
+        return ret
+
+
+# Backward compatibility alias
+OutputAdapter = OutputProcessor
diff --git a/core/models/depth_anything_3/utils/layout_helpers.py b/core/models/depth_anything_3/utils/layout_helpers.py
new file mode 100644
index 0000000..189c170
--- /dev/null
+++ b/core/models/depth_anything_3/utils/layout_helpers.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This file contains useful layout utilities for images. They are:
+
+- add_border: Add a border to an image.
+- cat/hcat/vcat: Join images by arranging them in a line. If the images have different
+  sizes, they are aligned as specified (start, end, center). Allows you to specify a gap
+  between images.
+
+Images are assumed to be float32 tensors with shape (channel, height, width).
+"""
+
+from typing import Any, Generator, Iterable, Literal, Union
+import torch
+from torch import Tensor
+
+Alignment = Literal["start", "center", "end"]
+Axis = Literal["horizontal", "vertical"]
+Color = Union[
+    int,
+    float,
+    Iterable[int],
+    Iterable[float],
+    Tensor,
+]
+
+
+def _sanitize_color(color: Color) -> Tensor:  # "#channel"
+    # Convert tensor to list (or individual item).
+    if isinstance(color, torch.Tensor):
+        color = color.tolist()
+
+    # Turn iterators and individual items into lists.
+ if isinstance(color, Iterable): + color = list(color) + else: + color = [color] + + return torch.tensor(color, dtype=torch.float32) + + +def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]: + it = iter(iterable) + yield next(it) + for item in it: + yield delimiter + yield item + + +def _get_main_dim(main_axis: Axis) -> int: + return { + "horizontal": 2, + "vertical": 1, + }[main_axis] + + +def _get_cross_dim(main_axis: Axis) -> int: + return { + "horizontal": 1, + "vertical": 2, + }[main_axis] + + +def _compute_offset(base: int, overlay: int, align: Alignment) -> slice: + assert base >= overlay + offset = { + "start": 0, + "center": (base - overlay) // 2, + "end": base - overlay, + }[align] + return slice(offset, offset + overlay) + + +def overlay( + base: Tensor, # "channel base_height base_width" + overlay: Tensor, # "channel overlay_height overlay_width" + main_axis: Axis, + main_axis_alignment: Alignment, + cross_axis_alignment: Alignment, +) -> Tensor: # "channel base_height base_width" + # The overlay must be smaller than the base. + _, base_height, base_width = base.shape + _, overlay_height, overlay_width = overlay.shape + assert base_height >= overlay_height and base_width >= overlay_width + + # Compute spacing on the main dimension. + main_dim = _get_main_dim(main_axis) + main_slice = _compute_offset( + base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment + ) + + # Compute spacing on the cross dimension. + cross_dim = _get_cross_dim(main_axis) + cross_slice = _compute_offset( + base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment + ) + + # Combine the slices and paste the overlay onto the base accordingly. + selector = [..., None, None] + selector[main_dim] = main_slice + selector[cross_dim] = cross_slice + result = base.clone() + result[selector] = overlay + return result + + +def cat( + main_axis: Axis, + *images: Iterable[Tensor], # "channel _ _" + align: Alignment = "center", + gap: int = 8, + gap_color: Color = 1, +) -> Tensor: # "channel height width" + """Arrange images in a line. The interface resembles a CSS div with flexbox.""" + device = images[0].device + gap_color = _sanitize_color(gap_color).to(device) + + # Find the maximum image side length in the cross axis dimension. + cross_dim = _get_cross_dim(main_axis) + cross_axis_length = max(image.shape[cross_dim] for image in images) + + # Pad the images. + padded_images = [] + for image in images: + # Create an empty image with the correct size. + padded_shape = list(image.shape) + padded_shape[cross_dim] = cross_axis_length + base = torch.ones(padded_shape, dtype=torch.float32, device=device) + base = base * gap_color[:, None, None] + padded_images.append(overlay(base, image, main_axis, "start", align)) + + # Intersperse separators if necessary. + if gap > 0: + # Generate a separator. + c, _, _ = images[0].shape + separator_size = [gap, gap] + separator_size[cross_dim - 1] = cross_axis_length + separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device) + separator = separator * gap_color[:, None, None] + + # Intersperse the separator between the images. 
+        padded_images = list(_intersperse(padded_images, separator))
+
+    return torch.cat(padded_images, dim=_get_main_dim(main_axis))
+
+
+def hcat(
+    *images: Iterable[Tensor],  # "channel _ _"
+    align: Literal["start", "center", "end", "top", "bottom"] = "start",
+    gap: int = 8,
+    gap_color: Color = 1,
+):
+    """Shorthand for a horizontal linear concatenation."""
+    return cat(
+        "horizontal",
+        *images,
+        align={
+            "start": "start",
+            "center": "center",
+            "end": "end",
+            "top": "start",
+            "bottom": "end",
+        }[align],
+        gap=gap,
+        gap_color=gap_color,
+    )
+
+
+def vcat(
+    *images: Iterable[Tensor],  # "channel _ _"
+    align: Literal["start", "center", "end", "left", "right"] = "start",
+    gap: int = 8,
+    gap_color: Color = 1,
+):
+    """Shorthand for a vertical linear concatenation."""
+    return cat(
+        "vertical",
+        *images,
+        align={
+            "start": "start",
+            "center": "center",
+            "end": "end",
+            "left": "start",
+            "right": "end",
+        }[align],
+        gap=gap,
+        gap_color=gap_color,
+    )
+
+
+def add_border(
+    image: Tensor,  # "channel height width"
+    border: int = 8,
+    color: Color = 1,
+) -> Tensor:  # "channel new_height new_width"
+    color = _sanitize_color(color).to(image)
+    c, h, w = image.shape
+    result = torch.empty(
+        (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device
+    )
+    result[:] = color[:, None, None]
+    result[:, border : h + border, border : w + border] = image
+    return result
diff --git a/core/models/depth_anything_3/utils/logger.py b/core/models/depth_anything_3/utils/logger.py
new file mode 100644
index 0000000..897855d
--- /dev/null
+++ b/core/models/depth_anything_3/utils/logger.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
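+
+"""Minimal leveled console logger with ANSI colors. The level is read from the
+DA3_LOG_LEVEL environment variable (ERROR, WARN, INFO, DEBUG; default INFO)."""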
+ +import os +import sys + + +class Color: + RED = "\033[91m" + YELLOW = "\033[93m" + WHITE = "\033[97m" + GREEN = "\033[92m" + RESET = "\033[0m" + + +LOG_LEVELS = {"ERROR": 0, "WARN": 1, "INFO": 2, "DEBUG": 3} + +COLOR_MAP = {"ERROR": Color.RED, "WARN": Color.YELLOW, "INFO": Color.WHITE, "DEBUG": Color.GREEN} + + +def get_env_log_level(): + level = os.environ.get("DA3_LOG_LEVEL", "INFO").upper() + return LOG_LEVELS.get(level, LOG_LEVELS["INFO"]) + + +class Logger: + def __init__(self): + self.level = get_env_log_level() + + def log(self, level_str, *args, **kwargs): + level_key = level_str.split(":")[0].strip() + level_val = LOG_LEVELS.get(level_key) + if level_val is None: + raise ValueError(f"Unknown log level: {level_str}") + if self.level >= level_val: + color = COLOR_MAP[level_key] + msg = " ".join(str(arg) for arg in args) + + # Align log level output in square brackets + # ERROR and DEBUG are 5 characters, INFO and WARN have an extra space for alignment + tag = level_key + if tag in ("INFO", "WARN"): + tag += " " + print( + f"{color}[{tag}] {msg}{Color.RESET}", + file=sys.stderr if level_key == "ERROR" else sys.stdout, + **kwargs, + ) + + def error(self, *args, **kwargs): + self.log("ERROR:", *args, **kwargs) + + def warn(self, *args, **kwargs): + self.log("WARN:", *args, **kwargs) + + def info(self, *args, **kwargs): + self.log("INFO:", *args, **kwargs) + + def debug(self, *args, **kwargs): + self.log("DEBUG:", *args, **kwargs) + + +logger = Logger() + +__all__ = ["logger"] + +if __name__ == "__main__": + logger.info("This is an info message") + logger.warn("This is a warning message") + logger.error("This is an error message") + logger.debug("This is a debug message") \ No newline at end of file diff --git a/core/models/depth_anything_3/utils/memory.py b/core/models/depth_anything_3/utils/memory.py new file mode 100644 index 0000000..682dad7 --- /dev/null +++ b/core/models/depth_anything_3/utils/memory.py @@ -0,0 +1,127 @@ +""" +GPU memory utility helpers. + +Shared cleanup and memory checking logic used by both the backend API and +the Gradio UI to keep memory-management behavior consistent. +""" +from __future__ import annotations + +import gc + +from typing import Any, Dict, Optional + +import torch + + +def get_gpu_memory_info() -> Optional[Dict[str, Any]]: + """Return a snapshot of current GPU memory usage or None if CUDA not available. + + Keys in returned dict: total_gb, allocated_gb, reserved_gb, free_gb, utilization + """ + if not torch.cuda.is_available(): + return None + + try: + device = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(device).total_memory + allocated_memory = torch.cuda.memory_allocated(device) + reserved_memory = torch.cuda.memory_reserved(device) + free_memory = total_memory - reserved_memory + + return { + "total_gb": total_memory / 1024 ** 3, + "allocated_gb": allocated_memory / 1024 ** 3, + "reserved_gb": reserved_memory / 1024 ** 3, + "free_gb": free_memory / 1024 ** 3, + "utilization": (reserved_memory / total_memory) * 100, + } + except Exception: + return None + + +def cleanup_cuda_memory() -> None: + """Perform a robust GPU cleanup sequence. + + This includes synchronizing, emptying caches, collecting IPC handles and + running the Python garbage collector. Use this instead of a raw + ``torch.cuda.empty_cache()`` where you need reliable freeing of GPU memory + between model loads or in error handling paths. 
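+
+    Illustrative usage (a sketch, not a prescribed call site)::
+
+        cleanup_cuda_memory()          # e.g. between unloading one depth model and loading the next
+        info = get_gpu_memory_info()   # optionally confirm how much was reclaimed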
+    """
+    try:
+        if torch.cuda.is_available():
+            mem_before = get_gpu_memory_info()
+
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+            # Collect cross-process cuda resources
+            try:
+                torch.cuda.ipc_collect()
+            except Exception:
+                # Older PyTorch versions or non-cuda devices may not support
+                # ipc_collect (no-op if not available)
+                pass
+            gc.collect()
+
+            mem_after = get_gpu_memory_info()
+            if mem_before and mem_after:
+                freed = mem_before["reserved_gb"] - mem_after["reserved_gb"]
+                print(
+                    f"CUDA cleanup: freed {freed:.2f}GB, "
+                    f"available: {mem_after['free_gb']:.2f}GB/{mem_after['total_gb']:.2f}GB"
+                )
+            else:
+                print("CUDA memory cleanup completed")
+    except Exception as e:
+        print(f"Warning: CUDA cleanup failed: {e}")
+
+
+def check_memory_availability(required_gb: float = 2.0) -> tuple[bool, str]:
+    """Return whether at least ``required_gb`` seems available on the current GPU.
+
+    The returned tuple is (is_available, message) with a human-friendly message.
+    """
+    try:
+        if not torch.cuda.is_available():
+            return False, "CUDA is not available"
+
+        mem_info = get_gpu_memory_info()
+        if mem_info is None:
+            return True, "Cannot check memory, proceeding anyway"
+
+        if mem_info["free_gb"] < required_gb:
+            return (
+                False,
+                (
+                    f"Insufficient GPU memory: {mem_info['free_gb']:.2f}GB available, "
+                    f"{required_gb:.2f}GB required. Total: {mem_info['total_gb']:.2f}GB, "
+                    f"Used: {mem_info['reserved_gb']:.2f}GB ({mem_info['utilization']:.1f}%)"
+                ),
+            )
+
+        return (
+            True,
+            (
+                f"Memory check passed: {mem_info['free_gb']:.2f}GB available, "
+                f"{required_gb:.2f}GB required"
+            ),
+        )
+    except Exception as e:
+        return True, f"Memory check failed: {e}, proceeding anyway"
+
+
+def estimate_memory_requirement(num_images: int, process_res: int) -> float:
+    """Heuristic estimate for memory usage (GB) based on image count and resolution.
+
+    This mirrors the simple policy used by the backend service so other code
+    (e.g., Gradio UI) can make consistent decisions when checking available
+    memory before loading a model or running inference.
+
+    Args:
+        num_images: Number of images to process.
+        process_res: Processing resolution.
+
+    Returns:
+        Estimated memory requirement in GB.
+    """
+    base_memory = 2.0
+    per_image_memory = (process_res / 504) ** 2 * 0.5
+    total_memory = base_memory + (num_images * per_image_memory * 0.1)
+    return total_memory
diff --git a/core/models/depth_anything_3/utils/model_loading.py b/core/models/depth_anything_3/utils/model_loading.py
new file mode 100644
index 0000000..eda85ee
--- /dev/null
+++ b/core/models/depth_anything_3/utils/model_loading.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Model loading and state dict conversion utilities.
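+
+Illustrative flow (a sketch; ``build_da3_model`` and the weight path are
+placeholders, not part of this module):
+
+    model = build_da3_model()
+    missed, unexpected = load_pretrained_weights(model, "weights/da3.pth", is_metric=False)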
+""" + +from typing import Dict, Tuple +import torch + +from core.models.depth_anything_3.utils.logger import logger + + +def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert general model state dict to match current model architecture. + + Args: + state_dict: Original state dictionary + + Returns: + Converted state dictionary + """ + # Replace module prefixes + state_dict = {k.replace("module.", "model."): v for k, v in state_dict.items()} + state_dict = {k.replace(".net.", ".backbone."): v for k, v in state_dict.items()} + + # Remove camera token if present + if "model.backbone.pretrained.camera_token" in state_dict: + del state_dict["model.backbone.pretrained.camera_token"] + + # Replace camera token naming + state_dict = { + k.replace(".camera_token_extra", ".camera_token"): v for k, v in state_dict.items() + } + + # Replace head naming + state_dict = { + k.replace("model.all_heads.camera_cond_head", "model.cam_enc"): v + for k, v in state_dict.items() + } + state_dict = { + k.replace("model.all_heads.camera_head", "model.cam_dec"): v for k, v in state_dict.items() + } + state_dict = {k.replace(".more_mlps.", ".backbone."): v for k, v in state_dict.items()} + state_dict = {k.replace(".fc_rot.", ".fc_qvec."): v for k, v in state_dict.items()} + state_dict = { + k.replace("model.all_heads.head", "model.head"): v for k, v in state_dict.items() + } + + # Replace output naming + state_dict = { + k.replace("output_conv2_additional.sky_mask", "sky_output_conv2"): v + for k, v in state_dict.items() + } + state_dict = {k.replace("_ray.", "_aux."): v for k, v in state_dict.items()} + + # Update GS-DPT head naming and value + state_dict = {k.replace("gaussian_param_head.", "gs_head."): v for k, v in state_dict.items()} + + return state_dict + + +def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert metric model state dict to match current model architecture. + + Args: + state_dict: Original metric state dictionary + + Returns: + Converted state dictionary + """ + # Add module prefix for metric models + state_dict = {"module." + k: v for k, v in state_dict.items()} + return convert_general_state_dict(state_dict) + + +def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]: + """ + Load pretrained weights for a single model. + + Args: + model: Model instance to load weights into + model_path: Path to the pretrained weights + is_metric: Whether this is a metric model + + Returns: + Tuple of (missed_keys, unexpected_keys) + """ + state_dict = torch.load(model_path, map_location="cpu") + + if is_metric: + state_dict = convert_metric_state_dict(state_dict) + else: + state_dict = convert_general_state_dict(state_dict) + + missed, unexpected = model.load_state_dict(state_dict, strict=False) + logger.info("Missed keys:", missed) + logger.info("Unexpected keys:", unexpected) + + return missed, unexpected + + +def load_pretrained_nested_weights( + model, main_model_path: str, metric_model_path: str +) -> Tuple[list, list]: + """ + Load pretrained weights for a nested model with both main and metric branches. 
+
+    Args:
+        model: Nested model instance
+        main_model_path: Path to main model weights
+        metric_model_path: Path to metric model weights
+
+    Returns:
+        Tuple of (missed_keys, unexpected_keys)
+    """
+    # Load main model weights
+    state_dict0 = torch.load(main_model_path, map_location="cpu")
+    state_dict0 = convert_general_state_dict(state_dict0)
+    state_dict0 = {k.replace("model.", "model.da3."): v for k, v in state_dict0.items()}
+
+    # Load metric model weights
+    state_dict1 = torch.load(metric_model_path, map_location="cpu")
+    state_dict1 = convert_metric_state_dict(state_dict1)
+    state_dict1 = {k.replace("model.", "model.da3_metric."): v for k, v in state_dict1.items()}
+
+    # Combine state dictionaries
+    combined_state_dict = state_dict0.copy()
+    combined_state_dict.update(state_dict1)
+
+    missed, unexpected = model.load_state_dict(combined_state_dict, strict=False)
+
+    logger.info("Missed keys:", missed)
+    logger.info("Unexpected keys:", unexpected)
+
+    return missed, unexpected
diff --git a/core/models/depth_anything_3/utils/parallel_utils.py b/core/models/depth_anything_3/utils/parallel_utils.py
new file mode 100644
index 0000000..9ff108e
--- /dev/null
+++ b/core/models/depth_anything_3/utils/parallel_utils.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from functools import partial, wraps
+from multiprocessing.pool import ThreadPool
+from threading import Thread
+from typing import Callable, Dict, List
+import imageio
+from tqdm import tqdm
+
+
+def async_call_func(func):
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        loop = asyncio.get_event_loop()
+        # run_in_executor only forwards positional arguments, so bind any
+        # keyword arguments with functools.partial before dispatching the
+        # blocking function to the executor thread
+        return await loop.run_in_executor(None, partial(func, *args, **kwargs))
+
+    return wrapper
+
+
+slice_func = lambda chunk_index, chunk_dim, chunk_size: [slice(None)] * chunk_dim + [
+    slice(chunk_index, chunk_index + chunk_size)
+]
+
+
+def async_call(fn):
+    def wrapper(*args, **kwargs):
+        Thread(target=fn, args=args, kwargs=kwargs).start()
+
+    return wrapper
+
+
+def _save_image_impl(save_img, save_path):
+    """Common implementation for saving images synchronously or asynchronously"""
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+    imageio.imwrite(save_path, save_img)
+
+
+@async_call
+def save_image_async(save_img, save_path):
+    """Save image asynchronously"""
+    _save_image_impl(save_img, save_path)
+
+
+def save_image(save_img, save_path):
+    """Save image synchronously"""
+    _save_image_impl(save_img, save_path)
+
+
+def parallel_execution(
+    *args,
+    action: Callable,
+    num_processes=32,
+    print_progress=False,
+    sequential=False,
+    async_return=False,
+    desc=None,
+    **kwargs,
+):
+    # Partially copy from EasyVolumetricVideo (parallel_execution)
+    # NOTE: we expect first arg / or kwargs to be distributed
+    # NOTE: print_progress arg is reserved.
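+    # Illustrative call (sketch; `load_one` is a hypothetical worker function):
+    #   images = parallel_execution(list_of_paths, action=load_one, num_processes=8)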
+ # `*args` packs all positional arguments passed to the function into a tuple + args = list(args) + + def get_length(args: List, kwargs: Dict): + for a in args: + if isinstance(a, list): + return len(a) + for v in kwargs.values(): + if isinstance(v, list): + return len(v) + raise NotImplementedError + + def get_action_args(length: int, args: List, kwargs: Dict, i: int): + action_args = [ + (arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args + ] + # TODO: Support all types of iterable + action_kwargs = { + key: ( + kwargs[key][i] + if isinstance(kwargs[key], list) and len(kwargs[key]) == length + else kwargs[key] + ) + for key in kwargs + } + return action_args, action_kwargs + + if not sequential: + # Create ThreadPool + pool = ThreadPool(processes=num_processes) + + # Spawn threads + results = [] + asyncs = [] + length = get_length(args, kwargs) + for i in range(length): + action_args, action_kwargs = get_action_args(length, args, kwargs, i) + async_result = pool.apply_async(action, action_args, action_kwargs) + asyncs.append(async_result) + + # Join threads and get return values + if not async_return: + for async_result in tqdm(asyncs, desc=desc, disable=not print_progress): + results.append(async_result.get()) # will sync the corresponding thread + pool.close() + pool.join() + return results + else: + return pool + else: + results = [] + length = get_length(args, kwargs) + for i in tqdm(range(length), desc=desc, disable=not print_progress): + action_args, action_kwargs = get_action_args(length, args, kwargs, i) + async_result = action(*action_args, **action_kwargs) + results.append(async_result) + return results diff --git a/core/models/depth_anything_3/utils/pca_utils.py b/core/models/depth_anything_3/utils/pca_utils.py new file mode 100644 index 0000000..2b9eee2 --- /dev/null +++ b/core/models/depth_anything_3/utils/pca_utils.py @@ -0,0 +1,284 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PCA utilities for feature visualization and dimensionality reduction (video-friendly). +- Support frame-by-frame: transform_frame / transform_video +- Support one-time global PCA fitting and reuse (mean, V3) for stable colors +- Support Procrustes alignment (solving principal component order/sign/rotation jumps) +- Support global fixed or temporal EMA for percentiles (time dimension only, no spatial) +""" + +import numpy as np +import torch + + +def pca_to_rgb_4d_bf16_percentile( + x_np: np.ndarray, + device=None, + q_oversample: int = 6, + clip_percent: float = 10.0, # Percentage to clip from top and bottom (0~49.9) + return_uint8: bool = False, + enable_autocast_bf16: bool = True, +): + """ + Reduce numpy array of shape (49, 27, 36, 3072) to 3D via PCA and visualize as (49, 27, 36, 3). + - PCA uses torch.pca_lowrank (randomized SVD), defaults to GPU. + - Uses CUDA bf16 autocast in computation (if available), + then per-channel percentile clipping and normalization. 
+    - By default removes the top and bottom outliers per channel (clip_percent,
+      default 10%) to improve visualization contrast.
+
+    Parameters
+    ----------
+    x_np : np.ndarray
+        Shape must be (49, 27, 36, 3072). dtype recommended float32/float64.
+    device : str | None
+        Specify 'cuda' or 'cpu'. Auto-select if None (prefer cuda).
+    q_oversample : int
+        Oversampling q for pca_lowrank, must be >= 3.
+        Slightly larger than target dim (3) is more stable, default 6.
+    clip_percent : float
+        Percentage to clip from top and bottom (0~49.9),
+        e.g. 5.0 means clip lowest 5% and highest 5% per channel.
+    return_uint8 : bool
+        True returns uint8(0~255), otherwise returns float32(0~1).
+    enable_autocast_bf16 : bool
+        Enable bf16 autocast on CUDA.
+
+    Returns
+    -------
+    np.ndarray
+        Array of shape (49, 27, 36, 3), float32[0,1] or uint8[0,255].
+    """
+    assert (
+        x_np.ndim == 4
+    )  # and x_np.shape[-1] == 3072, f"expect (49,27,36,3072), got {x_np.shape}"
+    B1, B2, B3, D = x_np.shape
+    N = B1 * B2 * B3
+
+    # Device selection
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Convert input to torch, unified float32
+    X = torch.from_numpy(x_np.reshape(N, D)).to(device=device, dtype=torch.float32)
+
+    # Parameter and safety checks
+    k = 3
+    q = max(int(q_oversample), k)
+    clip_percent = float(clip_percent)
+    if not (0.0 <= clip_percent < 50.0):
+        raise ValueError(
+            "clip_percent must be in [0, 50), e.g. 5.0 means clip 5% from top and bottom"
+        )
+    low = clip_percent / 100.0
+    high = 1.0 - low
+
+    with torch.no_grad():
+        # Zero mean
+        mean = X.mean(dim=0, keepdim=True)
+        Xc = X - mean
+
+        # Main computation: PCA + projection under bf16 autocast on CUDA when
+        # requested (ops without bf16 support automatically fall back to fp32)
+        use_bf16 = device.startswith("cuda") and enable_autocast_bf16
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16):
+            U, S, V = torch.pca_lowrank(Xc, q=q, center=False)  # V: (D, q)
+            V3 = V[:, :k]  # (3072, 3)
+            PCs = (Xc @ V3).float()  # (N, 3), back to fp32 for quantile/clamp below
+
+        # === Per-channel percentile clipping and normalization to [0,1] ===
+        # Vectorized one-time calculation of low/high percentiles for each channel
+        qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype)
+        qvals = torch.quantile(PCs, q=qs, dim=0)  # Shape (2, 3)
+        lo = qvals[0]  # (3,)
+        hi = qvals[1]  # (3,)
+
+        # Avoid degenerate case where hi==lo
+        denom = torch.clamp(hi - lo, min=1e-8)
+
+        # Broadcast clipping + normalization
+        PCs = torch.clamp(PCs, lo, hi)
+        PCs = (PCs - lo) / denom  # (N, 3) in [0,1]
+
+        # Restore 4D
+        PCs = PCs.reshape(B1, B2, B3, k)
+
+        # Output
+        if return_uint8:
+            out = (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+        else:
+            out = PCs.clamp(0, 1).to(torch.float32).cpu().numpy()
+
+    return out
+
+
+class PCARGBVisualizer:
+    """
+    Stable PCA→RGB for video features shaped (T, H, W, D) or a single frame (H, W, D).
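+
+    Illustrative use (a sketch; `feats` and its shape are assumptions)::
+
+        viz = PCARGBVisualizer(return_uint8=True)
+        viz.fit_reference(feats)          # feats: (T, H, W, D) float ndarray
+        rgb = viz.transform_video(feats)  # -> (T, H, W, 3) uint8
+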
+ - Global mean/V3 reference for stable colors + - Per-frame PCA with Procrustes alignment to V3_ref (basis_mode='procrustes') + - Percentile normalization with global or EMA stats (time-only, no spatial smoothing) + """ + + def __init__( + self, + device=None, + q_oversample: int = 16, + clip_percent: float = 10.0, + return_uint8: bool = False, + enable_autocast_bf16: bool = True, + basis_mode: str = "procrustes", # 'fixed' | 'procrustes' + percentile_mode: str = "ema", # 'global' | 'ema' + ema_alpha: float = 0.1, + denom_eps: float = 1e-4, + ): + assert 0.0 <= clip_percent < 50.0 + assert basis_mode in ("fixed", "procrustes") + assert percentile_mode in ("global", "ema") + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.q = max(int(q_oversample), 6) + self.clip_percent = float(clip_percent) + self.return_uint8 = return_uint8 + self.enable_autocast_bf16 = enable_autocast_bf16 + self.basis_mode = basis_mode + self.percentile_mode = percentile_mode + self.ema_alpha = float(ema_alpha) + self.denom_eps = float(denom_eps) + + # reference state + self.mean_ref = None # (1, D) + self.V3_ref = None # (D, 3) + self.lo_ref = None # (3,) + self.hi_ref = None # (3,) + + @torch.no_grad() + def fit_reference(self, frames): + """ + Fit global mean/V3 and initialize percentiles from a reference set. + frames: ndarray (T,H,W,D) or list of (H,W,D) + """ + if isinstance(frames, np.ndarray): + if frames.ndim != 4: + raise ValueError("fit_reference expects (T,H,W,D) ndarray.") + T, H, W, D = frames.shape + X = torch.from_numpy(frames.reshape(T * H * W, D)) + else: # list of (H,W,D) + xs = [torch.from_numpy(x.reshape(-1, x.shape[-1])) for x in frames] + D = xs[0].shape[-1] + X = torch.cat(xs, dim=0) + + X = X.to(self.device, dtype=torch.float32) + X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6) + + mean = X.mean(0, keepdim=True) + Xc = X - mean + + U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 8), center=False) + V3 = V[:, :3] # (D,3) + + PCs = Xc @ V3 + low = self.clip_percent / 100.0 + high = 1.0 - low + qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype) + qvals = torch.quantile(PCs, q=qs, dim=0) + lo, hi = qvals[0], qvals[1] + + self.mean_ref = mean + self.V3_ref = V3 + if self.percentile_mode == "global": + self.lo_ref, self.hi_ref = lo, hi + else: + self.lo_ref = lo.clone() + self.hi_ref = hi.clone() + + @torch.no_grad() + def _project_with_stable_colors(self, X: torch.Tensor) -> torch.Tensor: + """ + X: (N,D) where N = H*W + Returns PCs_raw: (N,3) using stable basis (fixed or Procrustes-aligned) + """ + assert self.mean_ref is not None and self.V3_ref is not None, "Call fit_reference() first." 
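+        # Sanitize features, center them with the reference mean, then project with
+        # a stable basis (fixed reference, or per-frame PCA aligned via Procrustes).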
+ X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6) + Xc = X - self.mean_ref + + if self.basis_mode == "fixed": + V3_used = self.V3_ref + else: + U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 6), center=False) + V3 = V[:, :3] # (D,3) + M = V3.T @ self.V3_ref + Uo, So, Vh = torch.linalg.svd(M) + R = Uo @ Vh + V3_used = V3 @ R + # Optional polarity fix via anchor + a = self.V3_ref.mean(0, keepdim=True) + sign = torch.sign((V3_used * a).sum(0, keepdim=True)).clamp(min=-1) + V3_used = V3_used * sign + + return Xc @ V3_used + + @torch.no_grad() + def _normalize_rgb(self, PCs_raw: torch.Tensor) -> torch.Tensor: + assert self.lo_ref is not None and self.hi_ref is not None + if self.percentile_mode == "global": + lo, hi = self.lo_ref, self.hi_ref + else: + low = self.clip_percent / 100.0 + high = 1.0 - low + qs = torch.tensor([low, high], device=PCs_raw.device, dtype=PCs_raw.dtype) + qvals = torch.quantile(PCs_raw, q=qs, dim=0) + lo_now, hi_now = qvals[0], qvals[1] + a = self.ema_alpha + self.lo_ref = (1 - a) * self.lo_ref + a * lo_now + self.hi_ref = (1 - a) * self.hi_ref + a * hi_now + lo, hi = self.lo_ref, self.hi_ref + + denom = torch.clamp(hi - lo, min=self.denom_eps) + PCs = torch.clamp(PCs_raw, lo, hi) + PCs = (PCs - lo) / denom + return PCs.clamp_(0, 1) + + @torch.no_grad() + def transform_frame(self, frame: np.ndarray) -> np.ndarray: + """ + frame: (H,W,D) -> (H,W,3) + """ + if frame.ndim != 3: + raise ValueError("transform_frame expects (H,W,D).") + H, W, D = frame.shape + X = torch.from_numpy(frame.reshape(H * W, D)).to(self.device, dtype=torch.float32) + PCs_raw = self._project_with_stable_colors(X) + PCs = self._normalize_rgb(PCs_raw).reshape(H, W, 3) + if self.return_uint8: + return (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + return PCs.to(torch.float32).cpu().numpy() + + @torch.no_grad() + def transform_video(self, frames) -> np.ndarray: + """ + frames: (T,H,W,D) or list of (H,W,D) + returns: (T,H,W,3) + """ + outs = [] + if isinstance(frames, np.ndarray): + if frames.ndim != 4: + raise ValueError("transform_video expects (T,H,W,D).") + T, H, W, D = frames.shape + for t in range(T): + outs.append(self.transform_frame(frames[t])) + else: + for f in frames: + outs.append(self.transform_frame(f)) + return np.stack(outs, axis=0) diff --git a/core/models/depth_anything_3/utils/pose_align.py b/core/models/depth_anything_3/utils/pose_align.py new file mode 100644 index 0000000..708b974 --- /dev/null +++ b/core/models/depth_anything_3/utils/pose_align.py @@ -0,0 +1,347 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
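+
+"""
+Umeyama Sim(3) trajectory alignment utilities (optionally RANSAC-robust), plus
+helpers to apply the estimated similarity to extrinsics and point clouds.
+"""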
+ +from typing import List +import numpy as np +import torch +from evo.core.trajectory import PosePath3D + +from core.models.depth_anything_3.utils.geometry import affine_inverse, affine_inverse_np + + +def batch_apply_alignment_to_enc( + rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, enc_list: List[torch.Tensor] +): + pass + + +def batch_apply_alignment_to_ext( + rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, ext: torch.Tensor +): + device, _ = ext.device, ext.dtype + if ext.shape[-2:] == (3, 4): + pad = torch.zeros((*ext.shape[:-2], 4, 4), dtype=ext.dtype, device=device) + pad[..., :3, :4] = ext + pad[..., 3, 3] = 1.0 + ext = pad + pose_est = affine_inverse(ext) + pose_new_align_rot = rots[:, None] @ pose_est[..., :3, :3] + pose_new_align_trans = ( + scales[:, None, None] * (rots[:, None] @ pose_est[..., :3, 3:])[..., 0] + trans[:, None] + ) + pose_new_align = torch.zeros_like(ext) + pose_new_align[..., :3, :3] = pose_new_align_rot + pose_new_align[..., :3, 3] = pose_new_align_trans + pose_new_align[..., 3, 3] = 1.0 + return affine_inverse(pose_new_align)[:, :3] + + +def batch_align_poses_umeyama(ext_ref: torch.Tensor, ext_est: torch.Tensor): + device, dtype = ext_ref.device, ext_ref.dtype + assert ext_ref.dtype in [torch.float32, torch.float64] + assert ext_est.dtype in [torch.float32, torch.float64] + assert ext_ref.requires_grad is False + assert ext_est.requires_grad is False + rots, trans, scales = [], [], [] + for b in range(ext_ref.shape[0]): + r, t, s = align_poses_umeyama(ext_ref[b].cpu().numpy(), ext_est[b].cpu().numpy()) + rots.append(torch.from_numpy(r).to(device=device, dtype=dtype)) + trans.append(torch.from_numpy(t).to(device=device, dtype=dtype)) + scales.append(torch.tensor(s, device=device, dtype=dtype)) + return torch.stack(rots), torch.stack(trans), torch.stack(scales) + + +# Dependencies: affine_inverse_np, PosePath3D (maintain consistency with your existing project) + + +def _to44(ext): + if ext.shape[1] == 3: + out = np.eye(4)[None].repeat(len(ext), 0) + out[:, :3, :4] = ext + return out + return ext + + +def _poses_from_ext(ext_ref, ext_est): + ext_ref = _to44(ext_ref) + ext_est = _to44(ext_est) + pose_ref = affine_inverse_np(ext_ref) + pose_est = affine_inverse_np(ext_est) + return pose_ref, pose_est + + +def _umeyama_sim3_from_paths(pose_ref, pose_est): + path_ref = PosePath3D(poses_se3=pose_ref.copy()) + path_est = PosePath3D(poses_se3=pose_est.copy()) + r, t, s = path_est.align(path_ref, correct_scale=True) + pose_est_aligned = np.stack(path_est.poses_se3) + return r, t, s, pose_est_aligned + + +def _apply_sim3_to_poses(poses, r, t, s): + out = poses.copy() + Ri = poses[:, :3, :3] + ti = poses[:, :3, 3] + out[:, :3, :3] = r @ Ri + out[:, :3, 3] = (r @ (s * ti.T)).T + t + return out + + +def _median_nn_thresh(pose_ref, pose_est_aligned): + P_ref = pose_ref[:, :3, 3] + P_est = pose_est_aligned[:, :3, 3] + dists = [] + for p in P_est: + dd = np.linalg.norm(P_ref - p[None, :], axis=1) + dists.append(dd.min()) + return float(np.median(dists)) if dists else 0.0 + + +def _ransac_align_sim3( + pose_ref, pose_est, sub_n=None, inlier_thresh=None, max_iters=10, random_state=None +): + rng = np.random.default_rng(random_state) + N = pose_ref.shape[0] + idx_all = np.arange(N) + if sub_n is None: + sub_n = max(3, (N + 1) // 2) + else: + sub_n = max(3, min(sub_n, N)) + + # Pre-alignment + default threshold + r0, t0, s0, pose_est0 = _umeyama_sim3_from_paths(pose_ref, pose_est) + if inlier_thresh is None: + inlier_thresh = 
_median_nn_thresh(pose_ref, pose_est0)
+
+    P_ref_all = pose_ref[:, :3, 3]
+
+    best_model = (r0, t0, s0)
+    best_inliers = None
+    best_score = (-1, np.inf)  # (num_inliers, mean_err)
+
+    for _ in range(max_iters):
+        sample = rng.choice(idx_all, size=sub_n, replace=False)
+        try:
+            r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[sample], pose_est[sample])
+        except Exception:
+            continue
+        pose_h = _apply_sim3_to_poses(pose_est, r, t, s)
+        P_h = pose_h[:, :3, 3]
+        errs = np.linalg.norm(P_h - P_ref_all, axis=1)  # Match by same index
+        inliers = errs <= inlier_thresh
+        k = int(inliers.sum())
+        mean_err = float(errs[inliers].mean()) if k > 0 else np.inf
+        if (k > best_score[0]) or (k == best_score[0] and mean_err < best_score[1]):
+            best_score = (k, mean_err)
+            best_model = (r, t, s)
+            best_inliers = inliers
+
+    # Fit again with best inliers
+    if best_inliers is not None and best_inliers.sum() >= 3:
+        r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[best_inliers], pose_est[best_inliers])
+    else:
+        r, t, s = best_model
+    return r, t, s
+
+
+def align_poses_umeyama(
+    ext_ref: np.ndarray,
+    ext_est: np.ndarray,
+    return_aligned=False,
+    ransac=False,
+    sub_n=None,
+    inlier_thresh=None,
+    ransac_max_iters=10,
+    random_state=None,
+):
+    """
+    Align estimated trajectory to reference using Umeyama Sim(3).
+    Default no RANSAC; if ransac=True, use RANSAC (max iterations default 10).
+    - sub_n defaults to half the number of frames (rounded up, at least 3)
+    - inlier_thresh defaults to median of "distance from each estimated pose to
+      nearest reference pose after pre-alignment"
+    Returns rotation (3x3), translation (3,), scale; optionally returns aligned extrinsics (4x4).
+    """
+    pose_ref, pose_est = _poses_from_ext(ext_ref, ext_est)
+
+    if not ransac:
+        r, t, s, pose_est_aligned = _umeyama_sim3_from_paths(pose_ref, pose_est)
+    else:
+        r, t, s = _ransac_align_sim3(
+            pose_ref,
+            pose_est,
+            sub_n=sub_n,
+            inlier_thresh=inlier_thresh,
+            max_iters=ransac_max_iters,
+            random_state=random_state,
+        )
+        pose_est_aligned = _apply_sim3_to_poses(pose_est, r, t, s)
+
+    if return_aligned:
+        ext_est_aligned = affine_inverse_np(pose_est_aligned)
+        return r, t, s, ext_est_aligned
+    return r, t, s
+
+
+def apply_umeyama_alignment_to_ext(
+    rot: np.ndarray,  # (3,3)
+    trans: np.ndarray,  # (3,) or (1,3)
+    scale: float,
+    ext_est: np.ndarray,  # (...,4,4) or (...,3,4)
+) -> np.ndarray:
+    """
+    Apply Sim(3) (R, t, s) to a batch of world-to-camera extrinsics ext_est.
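+    Camera-to-world poses are transformed as R_a = rot @ R_e and
+    t_a = scale * (rot @ t_e) + trans, then inverted back to world-to-camera.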
+    Returns the aligned extrinsics, with the same shape as input.
+    """
+
+    # Allow 3x4 extrinsics: pad to 4x4
+    if ext_est.shape[-2:] == (3, 4):
+        pad = np.zeros((*ext_est.shape[:-2], 4, 4), dtype=ext_est.dtype)
+        pad[..., :3, :4] = ext_est
+        pad[..., 3, 3] = 1.0
+        ext_est = pad
+
+    # Convert world-to-camera to camera-to-world
+    pose_est = affine_inverse_np(ext_est)  # (...,4,4)
+    R_e = pose_est[..., :3, :3]  # (...,3,3)
+    t_e = pose_est[..., :3, 3]  # (...,3)
+
+    # Apply Sim(3) transformation
+    R_a = np.einsum("ij,...jk->...ik", rot, R_e)  # (...,3,3)
+    t_a = scale * np.einsum("ij,...j->...i", rot, t_e) + trans  # (...,3)
+
+    # Assemble the transformed pose
+    pose_a = np.zeros_like(pose_est)
+    pose_a[..., :3, :3] = R_a
+    pose_a[..., :3, 3] = t_a
+    pose_a[..., 3, 3] = 1.0
+
+    # Convert back to world-to-camera
+    return affine_inverse_np(pose_a)
+
+
+def transform_points_sim3(points, rot, trans, scale, inverse=False):
+    """
+    Sim(3) transform point cloud
+    points: (N, 3)
+    rot: (3, 3)
+    trans: (3,) or (1, 3)
+    scale: float
+    inverse: Whether to do inverse transform (ref->est)
+    Returns: (N, 3)
+    """
+    if not inverse:
+        # Forward: est -> ref
+        return scale * (points @ rot.T) + trans
+    else:
+        # Inverse: ref -> est
+        return ((points - trans) @ rot) / scale
+
+
+def _rand_rot():
+    u1, u2, u3 = np.random.rand(3)
+    q = np.array(
+        [
+            np.sqrt(1 - u1) * np.sin(2 * np.pi * u2),
+            np.sqrt(1 - u1) * np.cos(2 * np.pi * u2),
+            np.sqrt(u1) * np.sin(2 * np.pi * u3),
+            np.sqrt(u1) * np.cos(2 * np.pi * u3),
+        ]
+    )
+    w, x, y, z = q
+    return np.array(
+        [
+            [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
+            [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
+            [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
+        ]
+    )
+
+
+def _rand_pose():
+    R, t = _rand_rot(), np.random.randn(3)
+    P = np.eye(4)
+    P[:3, :3] = R
+    P[:3, 3] = t
+    return P
+
+
+if __name__ == "__main__":
+    np.random.seed(42)
+    # 1. Randomly generate reference trajectory and Sim(3)
+    N = 8
+    pose_ref = np.stack([_rand_pose() for _ in range(N)])  # (N,4,4) cam→world
+    rot_gt = _rand_rot()
+    scale_gt = 2.3
+    trans_gt = np.random.randn(3)
+    # 2. Generate estimated trajectory (apply Sim(3))
+    pose_est = np.zeros_like(pose_ref)
+    for i in range(N):
+        R = pose_ref[i][:3, :3]
+        t = pose_ref[i][:3, 3]
+        pose_est[i][:3, :3] = rot_gt @ R
+        pose_est[i][:3, 3] = scale_gt * (rot_gt @ t) + trans_gt
+        pose_est[i][3, 3] = 1.0
+    # 3. Get extrinsics (world->cam)
+    ext_ref = affine_inverse_np(pose_ref)
+    ext_est = affine_inverse_np(pose_est)
+    # 4. Use umeyama alignment, estimate Sim(3)
+    r_est, t_est, s_est = align_poses_umeyama(ext_ref, ext_est)
+    print("GT scale:", scale_gt, "Estimated:", s_est)
+    print("GT trans:", trans_gt, "Estimated:", t_est)
+    print("GT rot:\n", rot_gt, "\nEstimated:\n", r_est)
+    # 5. Random point cloud, in ref frame
+    num_points = 100
+    points_ref = np.random.randn(num_points, 3)
+    # 6. Use GT Sim(3) inverse transform to est frame
+    points_est = transform_points_sim3(points_ref, rot_gt, trans_gt, scale_gt, inverse=True)
+    # 7. Use estimated Sim(3) forward transform back to ref frame
+    points_ref_recovered = transform_points_sim3(points_est, r_est, t_est, s_est, inverse=False)
+    # 8. Check error
+    err = np.abs(points_ref_recovered - points_ref)
+    print("Point cloud sim3 transform error (mean abs):", err.mean())
+    print("Point cloud sim3 transform error (max abs):", err.max())
+    assert err.mean() < 1e-6, "Mean sim3 transform error too large!"
+ assert err.max() < 1e-5, "Max sim3 transform error too large!" + print("Sim(3) point cloud transform & alignment test passed!") diff --git a/core/models/depth_anything_3/utils/ray_utils.py b/core/models/depth_anything_3/utils/ray_utils.py new file mode 100644 index 0000000..6244dc4 --- /dev/null +++ b/core/models/depth_anything_3/utils/ray_utils.py @@ -0,0 +1,523 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from einops import repeat +from .geometry import unproject_depth + + +def compute_optimal_rotation_intrinsics_batch( + rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2, weights=None, + n_sample = None, + n_iter=100, + num_sample_for_ransac=8, + rand_sample_iters_idx=None, +): + """ + Args: + rays_origin (torch.Tensor): (B, N, 3) + rays_target (torch.Tensor): (B, N, 3) + z_threshold (float): Threshold for z value to be considered valid. + + Returns: + R (torch.tensor): (3, 3) + focal_length (torch.tensor): (2,) + principal_point (torch.tensor): (2,) + """ + device = rays_origin.device + B, N, _ = rays_origin.shape + z_mask = torch.logical_and( + torch.abs(rays_target[:, :, 2]) > z_threshold, torch.abs(rays_origin[:, :, 2]) > z_threshold + ) # (B, N, 1) + rays_origin = rays_origin.clone() + rays_target = rays_target.clone() + rays_origin[:, :, 0][z_mask] /= rays_origin[:, :, 2][z_mask] + rays_origin[:, :, 1][z_mask] /= rays_origin[:, :, 2][z_mask] + rays_target[:, :, 0][z_mask] /= rays_target[:, :, 2][z_mask] + rays_target[:, :, 1][z_mask] /= rays_target[:, :, 2][z_mask] + + rays_origin = rays_origin[:, :, :2] + rays_target = rays_target[:, :, :2] + assert weights is not None, "weights must be provided" + weights[~z_mask] = 0 + + A_list = [] + max_chunk_size = 2 + for i in range(0, rays_origin.shape[0], max_chunk_size): + A = ransac_find_homography_weighted_fast_batch( + rays_origin[i:i+max_chunk_size], + rays_target[i:i+max_chunk_size], + weights[i:i+max_chunk_size], + n_iter=n_iter, + n_sample = n_sample, + num_sample_for_ransac=num_sample_for_ransac, + reproj_threshold=reproj_threshold, + rand_sample_iters_idx=rand_sample_iters_idx, + max_inlier_num=8000, + ) + A = A.to(device) + A_need_inv_mask = torch.linalg.det(A) < 0 + A[A_need_inv_mask] = -A[A_need_inv_mask] + A_list.append(A) + + A = torch.cat(A_list, dim=0) + + R_list = [] + f_list = [] + pp_list = [] + for i in range(A.shape[0]): + R, L = ql_decomposition(A[i]) + L = L / L[2][2] + + f = torch.stack((L[0][0], L[1][1])) + pp = torch.stack((L[2][0], L[2][1])) + R_list.append(R) + f_list.append(f) + pp_list.append(pp) + + R = torch.stack(R_list) + f = torch.stack(f_list) + pp = torch.stack(pp_list) + + return R, f, pp + + +# https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/ +def ql_decomposition(A): + P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float() + A_tilde = torch.matmul(A, P) + Q_tilde, R_tilde = torch.linalg.qr(A_tilde) + Q = torch.matmul(Q_tilde, P) + L = 
torch.matmul(torch.matmul(P, R_tilde), P) + d = torch.diag(L) + Q[:, 0] *= torch.sign(d[0]) + Q[:, 1] *= torch.sign(d[1]) + Q[:, 2] *= torch.sign(d[2]) + L[0] *= torch.sign(d[0]) + L[1] *= torch.sign(d[1]) + L[2] *= torch.sign(d[2]) + return Q, L + +def find_homography_least_squares_weighted_torch(src_pts, dst_pts, confident_weight): + """ + src_pts: (N,2) source points (torch.Tensor, float32/float64) + dst_pts: (N,2) target points (torch.Tensor, float32/float64) + confident_weight: (N,) weights (torch.Tensor) + Returns: (3,3) homography matrix H (torch.Tensor) + """ + assert src_pts.shape == dst_pts.shape + N = src_pts.shape[0] + if N < 4: + raise ValueError("At least 4 points are required to compute homography.") + assert confident_weight.shape == (N,) + + w = confident_weight.sqrt().unsqueeze(1) # (N,1) + + x = src_pts[:, 0:1] # (N,1) + y = src_pts[:, 1:2] # (N,1) + u = dst_pts[:, 0:1] + v = dst_pts[:, 1:2] + + zeros = torch.zeros_like(x) + + # Construct A matrix (2N, 9) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1) + A = torch.cat([A1, A2], dim=0) # (2N, 9) + + # SVD + # Note: torch.linalg.svd returns U, S, Vh, where Vh is the transpose of V + _, _, Vh = torch.linalg.svd(A) + H = Vh[-1].reshape(3, 3) + H = H / H[-1, -1] + return H + + +def ransac_find_homography_weighted( + src_pts, + dst_pts, + confident_weight, + n_iter=100, + sample_ratio=0.2, + reproj_threshold=3.0, + num_sample_for_ransac=16, + random_seed=None, +): + """ + RANSAC version of weighted Homography estimation. + Sample 4 points from the top 50% weighted points each time. + reproj_threshold: points with reprojection error less than this value are inliers + Returns: best_H + """ + if random_seed is not None: + torch.manual_seed(random_seed) + N = src_pts.shape[0] + assert N >= 4 + # 1. Select top 50% weighted points + sorted_idx = torch.argsort(confident_weight, descending=True) + n_sample = max(num_sample_for_ransac, int(N * sample_ratio)) + candidate_idx = sorted_idx[:n_sample] + best_inlier_mask = None + best_score = 0 + for _ in range(n_iter): + # 2. Randomly sample 4 points + idx = candidate_idx[torch.randperm(n_sample)[:num_sample_for_ransac]] + # 3. Compute Homography + try: + H = find_homography_least_squares_weighted_torch( + src_pts[idx], dst_pts[idx], confident_weight[idx] + ) + except Exception: + H = torch.eye(3, dtype=src_pts.dtype, device=src_pts.device) + # 4. Compute reprojection error for all points + src_homo = torch.cat( + [src_pts, torch.ones(N, 1, dtype=src_pts.dtype, device=src_pts.device)], dim=1 + ) + proj = (H @ src_homo.T).T + proj = proj[:, :2] / proj[:, 2:3] + error = ((proj - dst_pts) ** 2).sum(dim=1).sqrt() # Euclidean distance + inlier_mask = error < reproj_threshold + total_score = (inlier_mask * confident_weight).sum().item() + n_inlier = inlier_mask.sum().item() + if n_inlier < 4: + continue # At least 4 inliers required for fitting + + if total_score > best_score: + best_score = total_score + best_inlier_mask = inlier_mask + + # 5. 
Refit Homography using inliers + H_inlier = find_homography_least_squares_weighted_torch( + src_pts[best_inlier_mask], dst_pts[best_inlier_mask], confident_weight[best_inlier_mask] + ) + + return H_inlier + + +def find_homography_least_squares_weighted_torch_batch( + src_pts_batch, dst_pts_batch, confident_weight_batch +): + """ + Batch version of weighted least squares Homography + src_pts_batch: (B, K, 2) + dst_pts_batch: (B, K, 2) + confident_weight_batch: (B, K) + Returns: (B, 3, 3) + """ + B, K, _ = src_pts_batch.shape + w = confident_weight_batch.sqrt().unsqueeze(2) # (B,K,1) + x = src_pts_batch[:, :, 0:1] + y = src_pts_batch[:, :, 1:2] + u = dst_pts_batch[:, :, 0:1] + v = dst_pts_batch[:, :, 1:2] + zeros = torch.zeros_like(x) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2) + A = torch.cat([A1, A2], dim=1) # (B, 2K, 9) + # SVD: torch.linalg.svd supports batch + _, _, Vh = torch.linalg.svd(A) + H = Vh[:, -1].reshape(B, 3, 3) + H = H / H[:, 2:3, 2:3] + return H + + +def ransac_find_homography_weighted_fast( + src_pts, + dst_pts, + confident_weight, + n_sample, + n_iter=100, + reproj_threshold=3.0, + num_sample_for_ransac=8, + random_seed=None, + rand_sample_iters_idx=None, +): + """ + Batch version of RANSAC weighted Homography estimation. + Returns: H_inlier + """ + if random_seed is not None: + torch.manual_seed(random_seed) + N = src_pts.shape[0] + device = src_pts.device + assert N >= 4 + # 1. Select top weighted points by sample_ratio + sorted_idx = torch.argsort(confident_weight, descending=True) + candidate_idx = sorted_idx[:n_sample] # (n_sample,) + if rand_sample_iters_idx is None: + rand_sample_iters_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)], + dim=0, + ) # (n_iter, num_sample_for_ransac) + # 2. Generate all sampling groups at once + # shape: (n_iter, num_sample_for_ransac) + rand_idx = candidate_idx[rand_sample_iters_idx] # (n_iter, num_sample_for_ransac) + # 3. Construct batch input + src_pts_batch = src_pts[rand_idx] # (n_iter, num_sample_for_ransac, 2) + dst_pts_batch = dst_pts[rand_idx] # (n_iter, num_sample_for_ransac, 2) + confident_weight_batch = confident_weight[rand_idx] # (n_iter, num_sample_for_ransac) + # 4. Batch fit Homography + H_batch = find_homography_least_squares_weighted_torch_batch( + src_pts_batch, dst_pts_batch, confident_weight_batch + ) # (n_iter, 3, 3) + # 5. Batch evaluate inliers for all H + src_homo = torch.cat( + [src_pts, torch.ones(N, 1, dtype=src_pts.dtype, device=src_pts.device)], dim=1 + ) # (N,3) + src_homo_expand = src_homo.unsqueeze(0).expand(n_iter, N, 3) # (n_iter, N, 3) + dst_pts_expand = dst_pts.unsqueeze(0).expand(n_iter, N, 2) # (n_iter, N, 2) + confident_weight_expand = confident_weight.unsqueeze(0).expand(n_iter, N) # (n_iter, N) + # H_batch: (n_iter, 3, 3) + proj = torch.bmm(src_homo_expand, H_batch.transpose(1, 2)) # (n_iter, N, 3) + proj_xy = proj[:, :, :2] / proj[:, :, 2:3] # (n_iter, N, 2) + error = ((proj_xy - dst_pts_expand) ** 2).sum(dim=2).sqrt() # (n_iter, N) + inlier_mask = error < reproj_threshold # (n_iter, N) + total_score = (inlier_mask * confident_weight_expand).sum(dim=1) # (n_iter,) + # 6. 
+    best_idx = torch.argmax(total_score)
+    best_inlier_mask = inlier_mask[best_idx]  # (N,)
+    inlier_src_pts = src_pts[best_inlier_mask]
+    inlier_dst_pts = dst_pts[best_inlier_mask]
+    inlier_confident_weight = confident_weight[best_inlier_mask]
+
+    max_inlier_num = 10000
+    sorted_idx = torch.argsort(inlier_confident_weight, descending=True)
+
+    # method 1: sort according to confident_weight, and keep only max_inlier_num pts
+    # sorted_idx = sorted_idx[:max_inlier_num]
+
+    # method 2: randomly choose max_inlier_num pts
+    sorted_idx = sorted_idx[torch.randperm(len(sorted_idx))[:max_inlier_num]]
+
+    inlier_src_pts = inlier_src_pts[sorted_idx]
+    inlier_dst_pts = inlier_dst_pts[sorted_idx]
+    inlier_confident_weight = inlier_confident_weight[sorted_idx]
+    # 7. Refit Homography using inliers
+    H_inlier = find_homography_least_squares_weighted_torch(
+        inlier_src_pts, inlier_dst_pts, inlier_confident_weight
+    )
+    return H_inlier
+
+
+def ransac_find_homography_weighted_fast_batch(
+    src_pts,  # (B, N, 2)
+    dst_pts,  # (B, N, 2)
+    confident_weight,  # (B, N)
+    n_sample,
+    n_iter=100,
+    reproj_threshold=3.0,
+    num_sample_for_ransac=8,
+    max_inlier_num=10000,
+    random_seed=None,
+    rand_sample_iters_idx=None,
+):
+    """
+    Batch version of RANSAC weighted Homography estimation (supports batch).
+    Input:
+        src_pts: (B, N, 2)
+        dst_pts: (B, N, 2)
+        confident_weight: (B, N)
+    Returns:
+        H_inlier: (B, 3, 3)
+    """
+    if random_seed is not None:
+        torch.manual_seed(random_seed)
+    B, N, _ = src_pts.shape
+    assert N >= 4
+
+    device = src_pts.device
+
+    # 1. Select top weighted points by sample_ratio
+    sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)  # (B, N)
+    candidate_idx = sorted_idx[:, :n_sample]  # (B, n_sample)
+
+    # 2. Generate all sampling groups at once
+    # rand_idx: (B, n_iter, num_sample_for_ransac)
+    if rand_sample_iters_idx is None:
+        rand_sample_iters_idx = torch.stack(
+            [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
+            dim=0,
+        )  # (n_iter, num_sample_for_ransac)
+
+    rand_idx = candidate_idx[:, rand_sample_iters_idx]  # (B, n_iter, num_sample_for_ransac)
+
+    # 3. Construct batch input
+    # Indexing method below: (B, n_iter, num_sample_for_ransac, ...)
+    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, n_iter, num_sample_for_ransac)
+    src_pts_batch = src_pts[b_idx, rand_idx]  # (B, n_iter, num_sample_for_ransac, 2)
+    dst_pts_batch = dst_pts[b_idx, rand_idx]  # (B, n_iter, num_sample_for_ransac, 2)
+    confident_weight_batch = confident_weight[b_idx, rand_idx]  # (B, n_iter, num_sample_for_ransac)
+
+    # 4. Batch fit Homography
+    # Flatten (B, n_iter) into one batch dim so the pairwise solver can be reused
+    cB, cN = src_pts_batch.shape[:2]
+    H_batch = find_homography_least_squares_weighted_torch_batch(
+        src_pts_batch.flatten(0, 1), dst_pts_batch.flatten(0, 1), confident_weight_batch.flatten(0, 1)
+    )  # (B * n_iter, 3, 3)
+    H_batch = H_batch.unflatten(0, (cB, cN))  # (B, n_iter, 3, 3)
+
+    # 5. 
Batch evaluate inliers for all H + src_homo = torch.cat( + [src_pts, torch.ones(B, N, 1, dtype=src_pts.dtype, device=src_pts.device)], dim=2 + ) # (B, N, 3) + src_homo_expand = src_homo.unsqueeze(1).expand(B, n_iter, N, 3) # (B, n_iter, N, 3) + dst_pts_expand = dst_pts.unsqueeze(1).expand(B, n_iter, N, 2) # (B, n_iter, N, 2) + confident_weight_expand = confident_weight.unsqueeze(1).expand(B, n_iter, N) # (B, n_iter, N) + + # H_batch: (B, n_iter, 3, 3) + # Need to reshape H_batch to (B*n_iter, 3, 3), src_homo_expand to (B*n_iter, N, 3) + H_batch_flat = H_batch.reshape(-1, 3, 3) + src_homo_expand_flat = src_homo_expand.reshape(-1, N, 3) + proj = torch.bmm(src_homo_expand_flat, H_batch_flat.transpose(1, 2)) # (B*n_iter, N, 3) + proj_xy = proj[:, :, :2] / proj[:, :, 2:3] # (B*n_iter, N, 2) + proj_xy = proj_xy.reshape(B, n_iter, N, 2) + error = ((proj_xy - dst_pts_expand) ** 2).sum(dim=3).sqrt() # (B, n_iter, N) + inlier_mask = error < reproj_threshold # (B, n_iter, N) + total_score = (inlier_mask * confident_weight_expand).sum(dim=2) # (B, n_iter) + + # 6. Select the sampling group with the highest score + best_idx = torch.argmax(total_score, dim=1) # (B,) + best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx] # (B, N) + + # 7. Refit Homography using inliers + H_inlier_list = [] + for b in range(B): + mask = best_inlier_mask[b] + inlier_src_pts = src_pts[b][mask] # (?, 3) + inlier_dst_pts = dst_pts[b][mask] # (?, 2) + inlier_confident_weight = confident_weight[b][mask] # (?) + + sorted_idx = torch.argsort(inlier_confident_weight, descending=True) + # # method 1: sort according to confident_weight, and only keep max_inlier_num pts + # sorted_idx = sorted_idx[:max_inlier_num] + # method 2: random choose max_inlier_num pts + if len(sorted_idx) > max_inlier_num: + # random choose from first 95% confident pts + keep_len = max(int(len(sorted_idx) * 0.95), max_inlier_num) + sorted_idx = sorted_idx[:keep_len] + perm = torch.randperm(len(sorted_idx), device=device)[:max_inlier_num] + sorted_idx = sorted_idx[perm] + inlier_src_pts = inlier_src_pts[sorted_idx] + inlier_dst_pts = inlier_dst_pts[sorted_idx] + inlier_confident_weight = inlier_confident_weight[sorted_idx] + + H_inlier = find_homography_least_squares_weighted_torch( + inlier_src_pts, inlier_dst_pts, inlier_confident_weight + ) # (3, 3) + H_inlier_list.append(H_inlier) + H_inlier = torch.stack(H_inlier_list, dim=0) # (B, 3, 3) + return H_inlier + +def get_params_for_ransac(N, device): + n_iter=100 + sample_ratio=0.3 + num_sample_for_ransac=8 + n_sample = max(num_sample_for_ransac, int(N * sample_ratio)) + rand_sample_iters_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)], + dim=0, + ) # (n_iter, num_sample_for_ransac) + return n_iter, num_sample_for_ransac, n_sample, rand_sample_iters_idx + + +def camray_to_caminfo(camray, confidence=None, reproj_threshold=0.2, training=False): + """ + Args: + camray: (B, S, num_patches_y, num_patches_x, 6) + confidence: (B, S, num_patches_y, num_patches_x) + Returns: + R: (B, S, 3, 3) + T: (B, S, 3) + focal_lengths: (B, S, 2) + principal_points: (B, S, 2) + """ + if confidence is None: + confidence = torch.ones_like(camray[:, :, :, :, 0]) + B, S, num_patches_y, num_patches_x, _ = camray.shape + # identity K, assume imw=imh=2.0 + I_K = torch.eye(3, dtype=camray.dtype, device=camray.device) + I_K[0, 2] = 1.0 + I_K[1, 2] = 1.0 + # repeat I_K to match camray + I_K = I_K.unsqueeze(0).unsqueeze(0).expand(B, S, -1, -1) + + 
cam_plane_depth = torch.ones( + B, S, num_patches_y, num_patches_x, 1, dtype=camray.dtype, device=camray.device + ) + I_cam_plane_unproj = unproject_depth( + cam_plane_depth, + I_K, + c2w=None, + ixt_normalized=True, + num_patches_x=num_patches_x, + num_patches_y=num_patches_y, + ) # (B, S, num_patches_y, num_patches_x, 3) + + camray = camray.flatten(0, 1).flatten(1, 2) # (B*S, num_patches_y*num_patches_x, 6) + I_cam_plane_unproj = I_cam_plane_unproj.flatten(0, 1).flatten( + 1, 2 + ) # (B*S, num_patches_y*num_patches_x, 3) + confidence = confidence.flatten(0, 1).flatten(1, 2) # (B*S, num_patches_y*num_patches_x) + + # Compute optimal rotation to align rays + N = camray.shape[-2] + device = camray.device + n_iter, num_sample_for_ransac, n_sample, rand_sample_iters_idx = get_params_for_ransac(N, device) + + # Use batch processing (confidence is guaranteed to be not None at this point) + if training: + camray = camray.clone().detach() + I_cam_plane_unproj = I_cam_plane_unproj.clone().detach() + confidence = confidence.clone().detach() + R, focal_lengths, principal_points = compute_optimal_rotation_intrinsics_batch( + I_cam_plane_unproj, + camray[:, :, :3], + reproj_threshold=reproj_threshold, + weights=confidence, + n_sample = n_sample, + n_iter=n_iter, + num_sample_for_ransac=num_sample_for_ransac, + rand_sample_iters_idx=rand_sample_iters_idx, + ) + + T = torch.sum(camray[:, :, 3:] * confidence.unsqueeze(-1), dim=1) / torch.sum( + confidence, dim=-1, keepdim=True + ) + + R = R.reshape(B, S, 3, 3) + T = T.reshape(B, S, 3) + focal_lengths = focal_lengths.reshape(B, S, 2) + principal_points = principal_points.reshape(B, S, 2) + + return R, T, 1.0 / focal_lengths, principal_points + 1.0 + +def get_extrinsic_from_camray(camray, conf, patch_size_y, patch_size_x, training=False): + pred_R, pred_T, pred_focal_lengths, pred_principal_points = camray_to_caminfo( + camray, confidence=conf.squeeze(-1), training=training + ) + + pred_extrinsic = torch.cat( + [ + torch.cat([pred_R, pred_T.unsqueeze(-1)], dim=-1), + repeat( + torch.tensor([0, 0, 0, 1], dtype=pred_R.dtype, device=pred_R.device), + "c -> b s 1 c", + b=pred_R.shape[0], + s=pred_R.shape[1], + ), + ], + dim=-2, + ) # B, S, 4, 4 + return pred_extrinsic, pred_focal_lengths, pred_principal_points \ No newline at end of file diff --git a/core/models/depth_anything_3/utils/read_write_model.py b/core/models/depth_anything_3/utils/read_write_model.py new file mode 100644 index 0000000..4b4cf19 --- /dev/null +++ b/core/models/depth_anything_3/utils/read_write_model.py @@ -0,0 +1,585 @@ +# Copyright (c), ETH Zurich and UNC Chapel Hill. +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# All rights reserved. +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. on 11/05/2025 +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + + +import argparse +import collections +import os +import struct +import numpy as np + +CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = {camera_model.model_id: camera_model for camera_model in CAMERA_MODELS} +CAMERA_MODEL_NAMES = {camera_model.model_name: camera_model for camera_model in CAMERA_MODELS} + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
+ should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path) as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, + model=model, + width=width, + height=height, + params=params, + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, + num_bytes=8 * num_params, + format_char_sequence="d" * num_params, + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = ( + "# Camera list with one line of data per camera:\n" + + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + + f"# Number of cameras: {len(cameras)}\n" + ) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path) as fid: + while True: + line = fid.readline() + if not line: + break 
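+            # Note: COLMAP images.txt stores two lines per image, a pose line
+            # (IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME) followed by
+            # a POINTS2D line of (X, Y, POINT3D_ID) triplets; both are parsed below.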
+ line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [ + tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + binary_image_name = b"" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + binary_image_name += current_char + current_char = read_next_bytes(fid, 1, "c")[0] + image_name = binary_image_name.decode("utf-8") + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [ + tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images) + HEADER = ( + "# Image list with two lines of data per image:\n" + + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [ + img.id, + *img.qvec, + *img.tvec, + img.camera_id, + img.name, + ] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + 
write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path) as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D) + HEADER = ( + "# 3D point list with one line of data per point:\n" + + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3D_binary(points3D, path_to_model_file): + """ + see: 
src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if ( + os.path.isfile(os.path.join(path, "cameras" + ext)) + and os.path.isfile(os.path.join(path, "images" + ext)) + and os.path.isfile(os.path.join(path, "points3D" + ext)) + ): + print("Detected model format: '" + ext + "'") + return True + + return False + + +def read_model(path, ext=""): + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + print("Provide model format: '.bin' or '.txt'") + return + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def main(): + parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models") + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + help="input model format", + default="", + ) + parser.add_argument("--output_model", help="path to 
output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="output model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, + images, + points3D, + path=args.output_model, + ext=args.output_format, + ) + + +if __name__ == "__main__": + main() diff --git a/core/models/depth_anything_3/utils/registry.py b/core/models/depth_anything_3/utils/registry.py new file mode 100644 index 0000000..7db16d5 --- /dev/null +++ b/core/models/depth_anything_3/utils/registry.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from addict import Dict + + +class Registry(Dict[str, Any]): + def __init__(self): + super().__init__() + self._map = Dict({}) + + def register(self, name=None): + def decorator(cls): + key = name or cls.__name__ + self._map[key] = cls + return cls + + return decorator + + def get(self, name): + return self._map[name] + + def all(self): + return self._map diff --git a/core/models/depth_anything_3/utils/sh_helpers.py b/core/models/depth_anything_3/utils/sh_helpers.py new file mode 100644 index 0000000..75040cb --- /dev/null +++ b/core/models/depth_anything_3/utils/sh_helpers.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from math import isqrt +import torch +from einops import einsum + +try: + from e3nn.o3 import matrix_to_angles, wigner_D +except ImportError: + from core.models.depth_anything_3.utils.logger import logger + + logger.warn("Dependency 'e3nn' not found. Required for rotating the camera space SH coeff") + + +def project_to_so3_strict(M: torch.Tensor) -> torch.Tensor: + if M.shape[-2:] != (3, 3): + raise ValueError("Input must be a batch of 3x3 matrices (i.e., shape [..., 3, 3]).") + + # 1. Compute SVD + U, S, Vh = torch.linalg.svd(M) + V = Vh.mH + + # 2. 
Handle reflection case (det = -1) + det_U = torch.det(U) + det_V = torch.det(V) + is_reflection = (det_U * det_V) < 0 + correction_sign = torch.where( + is_reflection[..., None], + torch.tensor([1, 1, -1.0], device=M.device, dtype=M.dtype), + torch.tensor([1, 1, 1.0], device=M.device, dtype=M.dtype), + ) + correction_matrix = torch.diag_embed(correction_sign) + U_corrected = U @ correction_matrix + R_so3_initial = U_corrected @ V.transpose(-2, -1) + + # 3. Explicitly ensure determinant is 1 (or extremely close) + current_det = torch.det(R_so3_initial) + det_correction_factor = torch.pow(current_det, -1 / 3)[..., None, None] + R_so3_final = R_so3_initial * det_correction_factor + + return R_so3_final + + +def rotate_sh( + sh_coefficients: torch.Tensor, # "*#batch n" + rotations: torch.Tensor, # "*#batch 3 3" +) -> torch.Tensor: # "*batch n" + # https://github.com/graphdeco-inria/gaussian-splatting/issues/176#issuecomment-2452412653 + device = sh_coefficients.device + dtype = sh_coefficients.dtype + + *_, n = sh_coefficients.shape + + with torch.autocast(device_type=rotations.device.type, enabled=False): + rotations_float32 = rotations.to(torch.float32) + + # switch axes: yzx -> xyz + P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]]).unsqueeze(0).to(rotations_float32) + permuted_rotations = torch.linalg.inv(P) @ rotations_float32 @ P + + # ensure rotation has det == 1 in float32 type + permuted_rotations_so3 = project_to_so3_strict(permuted_rotations) + + alpha, beta, gamma = matrix_to_angles(permuted_rotations_so3) + result = [] + for degree in range(isqrt(n)): + with torch.device(device): + sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype) + sh_rotated = einsum( + sh_rotations, + sh_coefficients[..., degree**2 : (degree + 1) ** 2], + "... i j, ... j -> ... i", + ) + result.append(sh_rotated) + + return torch.cat(result, dim=-1) diff --git a/core/models/depth_anything_3/utils/visualize.py b/core/models/depth_anything_3/utils/visualize.py new file mode 100644 index 0000000..38e9e0d --- /dev/null +++ b/core/models/depth_anything_3/utils/visualize.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import matplotlib +import numpy as np +import torch +from einops import rearrange + +from core.models.depth_anything_3.utils.logger import logger + + +def visualize_depth( + depth: np.ndarray, + depth_min=None, + depth_max=None, + percentile=2, + ret_minmax=False, + ret_type=np.uint8, + cmap="Spectral", +): + """ + Visualize a depth map using a colormap. + + Args: + depth: Input depth map array + depth_min: Minimum depth value for normalization. If None, uses percentile + depth_max: Maximum depth value for normalization. 
If None, uses percentile
+        percentile: Percentile for min/max computation if not provided
+        ret_minmax: Whether to return min/max depth values
+        ret_type: Return array type (uint8 or float)
+        cmap: Matplotlib colormap name to use
+
+    Returns:
+        Colored depth visualization as numpy array
+        If ret_minmax=True, also returns depth_min and depth_max
+    """
+    depth = depth.copy()
+    valid_mask = depth > 0
+    depth[valid_mask] = 1 / depth[valid_mask]
+    if depth_min is None:
+        if valid_mask.sum() <= 10:
+            depth_min = 0
+        else:
+            depth_min = np.percentile(depth[valid_mask], percentile)
+    if depth_max is None:
+        if valid_mask.sum() <= 10:
+            depth_max = 0
+        else:
+            depth_max = np.percentile(depth[valid_mask], 100 - percentile)
+    if depth_min == depth_max:
+        depth_min = depth_min - 1e-6
+        depth_max = depth_max + 1e-6
+    cm = matplotlib.colormaps[cmap]
+    depth = ((depth - depth_min) / (depth_max - depth_min)).clip(0, 1)
+    depth = 1 - depth
+    img_colored_np = cm(depth[None], bytes=False)[:, :, :, 0:3]  # value from 0 to 1
+    if ret_type == np.uint8:
+        img_colored_np = (img_colored_np[0] * 255.0).astype(np.uint8)
+    elif ret_type == np.float32 or ret_type == np.float64:
+        img_colored_np = img_colored_np[0]
+    else:
+        raise ValueError(f"Invalid return type: {ret_type}")
+    if ret_minmax:
+        return img_colored_np, depth_min, depth_max
+    else:
+        return img_colored_np
+
+
+# GS video rendering visualization function, since it operates in Tensor space...
+
+
+def vis_depth_map_tensor(
+    result: torch.Tensor,  # "*batch height width"
+    color_map: str = "Spectral",
+) -> torch.Tensor:  # "*batch 3 height width"
+    """
+    Color-map the depth map.
+    """
+    far = result.reshape(-1)[:16_000_000].float().quantile(0.99).log().to(result)
+    try:
+        near = result[result > 0][:16_000_000].float().quantile(0.01).log().to(result)
+    except (RuntimeError, ValueError) as e:
+        logger.error(f"No valid depth values found. Reason: {e}")
+        near = torch.zeros_like(far)
+    result = result.log()
+    result = (result - near) / (far - near)
+    return apply_color_map_to_image(result, color_map)
+
+
+def apply_color_map(
+    x: torch.Tensor,  # " *batch"
+    color_map: str = "inferno",
+) -> torch.Tensor:  # "*batch 3"
+    # Use the colormaps registry (matplotlib.cm.get_cmap is deprecated)
+    cmap = matplotlib.colormaps[color_map]
+
+    # Convert to NumPy so that Matplotlib color maps can be used.
+    mapped = cmap(x.float().detach().clip(min=0, max=1).cpu().numpy())[..., :3]
+
+    # Convert back to the original format.
+    return torch.tensor(mapped, device=x.device, dtype=torch.float32)
+
+
+def apply_color_map_to_image(
+    image: torch.Tensor,  # "*batch height width"
+    color_map: str = "inferno",
+) -> torch.Tensor:  # "*batch 3 height width"
+    image = apply_color_map(image, color_map)
+    return rearrange(image, "... h w c -> ... 
c h w") diff --git a/core/models/video_depth_anything/__pycache__/dinov2.cpython-313.pyc b/core/models/video_depth_anything/__pycache__/dinov2.cpython-313.pyc new file mode 100644 index 0000000..e34b9ba Binary files /dev/null and b/core/models/video_depth_anything/__pycache__/dinov2.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/__pycache__/dpt.cpython-313.pyc b/core/models/video_depth_anything/__pycache__/dpt.cpython-313.pyc new file mode 100644 index 0000000..d567ad7 Binary files /dev/null and b/core/models/video_depth_anything/__pycache__/dpt.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/__pycache__/dpt_temporal.cpython-313.pyc b/core/models/video_depth_anything/__pycache__/dpt_temporal.cpython-313.pyc new file mode 100644 index 0000000..ef452ba Binary files /dev/null and b/core/models/video_depth_anything/__pycache__/dpt_temporal.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/__pycache__/video_depth.cpython-313.pyc b/core/models/video_depth_anything/__pycache__/video_depth.cpython-313.pyc new file mode 100644 index 0000000..6b773f6 Binary files /dev/null and b/core/models/video_depth_anything/__pycache__/video_depth.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2.py b/core/models/video_depth_anything/dinov2.py new file mode 100644 index 0000000..ddd60f5 --- /dev/null +++ b/core/models/video_depth_anything/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from core.models.video_depth_anything.dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding 
dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + 
nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. 
If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, 
attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/core/models/video_depth_anything/dinov2_layers/__init__.py b/core/models/video_depth_anything/dinov2_layers/__init__.py new file mode 100644 index 0000000..8120f4b --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..f67cdb6 Binary files /dev/null and b/core/models/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-313.pyc new file mode 100644 index 0000000..bbdd851 Binary files /dev/null and b/core/models/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/block.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/block.cpython-313.pyc new file mode 100644 index 0000000..9646a6c Binary files /dev/null and b/core/models/video_depth_anything/dinov2_layers/__pycache__/block.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-313.pyc new file mode 100644 index 0000000..20e72ab Binary files /dev/null and b/core/models/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-313.pyc new file mode 100644 index 0000000..3383a08 Binary files /dev/null and b/core/models/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-313.pyc new file mode 100644 index 0000000..b40f1b2 Binary files /dev/null and b/core/models/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-313.pyc new file mode 100644 index 0000000..faf835a Binary files /dev/null and 
b/core/models/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-313.pyc b/core/models/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-313.pyc new file mode 100644 index 0000000..cef2986 Binary files /dev/null and b/core/models/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/dinov2_layers/attention.py b/core/models/video_depth_anything/dinov2_layers/attention.py new file mode 100644 index 0000000..815a2bf --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + \ No newline at end of file diff --git a/core/models/video_depth_anything/dinov2_layers/block.py b/core/models/video_depth_anything/dinov2_layers/block.py new file mode 100644 index 0000000..1dc8b29 --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
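+#
+# Note: this module vendors DINOv2's transformer blocks (Block and
+# NestedTensorBlock); the xFormers-backed fast paths below are optional and
+# guarded by an import check at module load.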
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from core.models.video_depth_anything.dinov2_layers.attention import Attention, MemEffAttention +from core.models.video_depth_anything.dinov2_layers.drop_path import DropPath +from core.models.video_depth_anything.dinov2_layers.layer_scale import LayerScale +from core.models.video_depth_anything.dinov2_layers.mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = 
residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = 
drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/core/models/video_depth_anything/dinov2_layers/drop_path.py b/core/models/video_depth_anything/dinov2_layers/drop_path.py new file mode 100644 index 0000000..af05625 --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/core/models/video_depth_anything/dinov2_layers/layer_scale.py b/core/models/video_depth_anything/dinov2_layers/layer_scale.py new file mode 100644 index 0000000..ca5daa5 --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree.
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/core/models/video_depth_anything/dinov2_layers/mlp.py b/core/models/video_depth_anything/dinov2_layers/mlp.py new file mode 100644 index 0000000..5e4b315 --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/core/models/video_depth_anything/dinov2_layers/patch_embed.py b/core/models/video_depth_anything/dinov2_layers/patch_embed.py new file mode 100644 index 0000000..574abe4 --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/core/models/video_depth_anything/dinov2_layers/swiglu_ffn.py b/core/models/video_depth_anything/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000..b3324b2 --- /dev/null +++ b/core/models/video_depth_anything/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/core/models/video_depth_anything/dpt.py b/core/models/video_depth_anything/dpt.py new file mode 100644 index 0000000..0d6ccf7 --- /dev/null +++ b/core/models/video_depth_anything/dpt.py @@ -0,0 +1,160 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
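+
+# DPT-style decoder head: per-level 1x1 projections and resize layers turn ViT token
+# sequences back into 2D feature maps, refinenet blocks fuse them coarse-to-fine, and
+# the output convolutions regress a single-channel depth map upsampled to 14x the
+# patch grid (the DINOv2 backbone uses a patch size of 14).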
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from core.models.video_depth_anything.util.blocks import FeatureFusionBlock, _make_scratch + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = 
self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + \ No newline at end of file diff --git a/core/models/video_depth_anything/dpt_temporal.py b/core/models/video_depth_anything/dpt_temporal.py new file mode 100644 index 0000000..f7fd892 --- /dev/null +++ b/core/models/video_depth_anything/dpt_temporal.py @@ -0,0 +1,125 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F +import torch.nn as nn +from core.models.video_depth_anything.dpt import DPTHead +from core.models.video_depth_anything.motion_module.motion_module import TemporalModule +from easydict import EasyDict + + +class DPTHeadTemporal(DPTHead): + def __init__(self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False, + num_frames=32, + pe='ape' + ): + super().__init__(in_channels, features, use_bn, out_channels, use_clstoken) + + assert num_frames > 0 + motion_module_kwargs = EasyDict(num_attention_heads = 8, + num_transformer_block = 1, + num_attention_blocks = 2, + temporal_max_len = num_frames, + zero_initialize = True, + pos_embedding_type = pe) + + self.motion_modules = nn.ModuleList([ + TemporalModule(in_channels=out_channels[2], + **motion_module_kwargs), + TemporalModule(in_channels=out_channels[3], + **motion_module_kwargs), + TemporalModule(in_channels=features, + **motion_module_kwargs), + TemporalModule(in_channels=features, + **motion_module_kwargs) + ]) + + def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size=4, cached_hidden_state_list=None): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous() + + B, T = x.shape[0] // frame_length, frame_length + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + B, T = layer_1.shape[0] // frame_length, frame_length + if cached_hidden_state_list is not None: + N = len(cached_hidden_state_list) // len(self.motion_modules) + else: + N = 0 + + layer_3, h0 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[0:N] if N else None) + layer_3 = layer_3.permute(0, 2, 1, 3, 4).flatten(0, 1) + layer_4, h1 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[N:2*N] if N else None) + layer_4 = layer_4.permute(0, 2, 1, 3, 4).flatten(0, 1) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) 
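+
+ # Motion modules 0 and 1 (above) temporally smooth the two deepest feature levels;
+ # modules 2 and 3 (below) run on fusion paths 4 and 3. The per-module hidden states
+ # h0..h3 are returned together so a subsequent frame window can pass them back in
+ # as cached temporal context for streaming inference.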
+ + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_4, h2 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[2*N:3*N] if N else None) + path_4 = path_4.permute(0, 2, 1, 3, 4).flatten(0, 1) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_3, h3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[3*N:] if N else None) + path_3 = path_3.permute(0, 2, 1, 3, 4).flatten(0, 1) + + batch_size = layer_1_rn.shape[0] + if batch_size <= micro_batch_size or batch_size % micro_batch_size != 0: + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate( + out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True + ) + ori_type = out.dtype + with torch.autocast(device_type="cuda", enabled=False): + out = self.scratch.output_conv2(out.float()) + + output = out.to(ori_type) + else: + ret = [] + for i in range(0, batch_size, micro_batch_size): + path_2 = self.scratch.refinenet2(path_3[i:i + micro_batch_size], layer_2_rn[i:i + micro_batch_size], size=layer_1_rn[i:i + micro_batch_size].shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn[i:i + micro_batch_size]) + out = self.scratch.output_conv1(path_1) + out = F.interpolate( + out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True + ) + ori_type = out.dtype + with torch.autocast(device_type="cuda", enabled=False): + out = self.scratch.output_conv2(out.float()) + ret.append(out.to(ori_type)) + output = torch.cat(ret, dim=0) + + return output, h0 + h1 + h2 + h3 diff --git a/core/models/video_depth_anything/motion_module/__pycache__/attention.cpython-313.pyc b/core/models/video_depth_anything/motion_module/__pycache__/attention.cpython-313.pyc new file mode 100644 index 0000000..da97d04 Binary files /dev/null and b/core/models/video_depth_anything/motion_module/__pycache__/attention.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/motion_module/__pycache__/motion_module.cpython-313.pyc b/core/models/video_depth_anything/motion_module/__pycache__/motion_module.cpython-313.pyc new file mode 100644 index 0000000..9a6fe1c Binary files /dev/null and b/core/models/video_depth_anything/motion_module/__pycache__/motion_module.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/motion_module/attention.py b/core/models/video_depth_anything/motion_module/attention.py new file mode 100644 index 0000000..41f551b --- /dev/null +++ b/core/models/video_depth_anything/motion_module/attention.py @@ -0,0 +1,429 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
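+
+# Attention building blocks for the temporal motion modules: a diffusers-style
+# CrossAttention with sliced and xFormers memory-efficient paths (the xFormers path
+# is split into batches of at most 65535 rows), GELU/GEGLU feed-forward variants,
+# and rotary position embedding helpers (precompute_freqs_cis / apply_rotary_emb).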
+from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +try: + import xformers + import xformers.ops + + XFORMERS_AVAILABLE = True +except ImportError: + print("xFormers not available") + XFORMERS_AVAILABLE = False + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.upcast_efficient_attention = False + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous() + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous() + return tensor + + def reshape_heads_to_4d(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous() + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous() + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous() + return tensor + + def reshape_4d_to_heads(self, tensor): + batch_size, 
seq_len, head_size, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous() + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = 
torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + if self.upcast_efficient_attention: + org_dtype = query.dtype + query = query.float() + key = key.float() + value = value.float() + if attention_mask is not None: + attention_mask = attention_mask.float() + hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask) + + if self.upcast_efficient_attention: + hidden_states = hidden_states.to(org_dtype) + + hidden_states = self.reshape_4d_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_split(self, query, key, value, attention_mask): + batch_size = query.shape[0] + max_batch_size = 65535 + num_batches = (batch_size + max_batch_size - 1) // max_batch_size + results = [] + for i in range(num_batches): + start_idx = i * max_batch_size + end_idx = min((i + 1) * max_batch_size, batch_size) + query_batch = query[start_idx:end_idx] + key_batch = key[start_idx:end_idx] + value_batch = value[start_idx:end_idx] + if attention_mask is not None: + attention_mask_batch = attention_mask[start_idx:end_idx] + else: + attention_mask_batch = None + result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch) + results.append(result) + full_result = torch.cat(results, dim=0) + return full_result + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous()) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous()) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2) + return 
xq_out.type_as(xq), xk_out.type_as(xk) diff --git a/core/models/video_depth_anything/motion_module/motion_module.py b/core/models/video_depth_anything/motion_module/motion_module.py new file mode 100644 index 0000000..330c3ab --- /dev/null +++ b/core/models/video_depth_anything/motion_module/motion_module.py @@ -0,0 +1,321 @@ +# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff +# SPDX-License-Identifier: Apache-2.0 license +# +# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification] +# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme]. +import torch +import torch.nn.functional as F +from torch import nn + +from core.models.video_depth_anything.motion_module.attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis + +from einops import rearrange, repeat +import math + +try: + import xformers + import xformers.ops + + XFORMERS_AVAILABLE = True +except ImportError: + print("xFormers not available") + XFORMERS_AVAILABLE = False + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +class TemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + num_attention_blocks = 2, + norm_num_groups = 32, + temporal_max_len = 32, + zero_initialize = True, + pos_embedding_type = "ape", + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, + num_layers=num_transformer_block, + num_attention_blocks=num_attention_blocks, + norm_num_groups=norm_num_groups, + temporal_max_len=temporal_max_len, + pos_embedding_type=pos_embedding_type, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + + def forward(self, input_tensor, encoder_hidden_states, attention_mask=None, cached_hidden_state_list=None): + hidden_states = input_tensor + hidden_states, output_hidden_state_list = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, cached_hidden_state_list) + + output = hidden_states + return output, output_hidden_state_list # list of hidden states + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + num_attention_blocks = 2, + norm_num_groups = 32, + temporal_max_len = 32, + pos_embedding_type = "ape", + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_attention_blocks=num_attention_blocks, + temporal_max_len=temporal_max_len, + pos_embedding_type=pos_embedding_type, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cached_hidden_state_list=None): + assert hidden_states.dim() 
== 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + output_hidden_state_list = [] + + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous() + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + if cached_hidden_state_list is not None: + n = len(cached_hidden_state_list) // len(self.transformer_blocks) + else: + n = 0 + for i, block in enumerate(self.transformer_blocks): + hidden_states, hidden_state_list = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask, + cached_hidden_state_list=cached_hidden_state_list[i*n:(i+1)*n] if n else None) + output_hidden_state_list.extend(hidden_state_list) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output, output_hidden_state_list + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + num_attention_blocks = 2, + temporal_max_len = 32, + pos_embedding_type = "ape", + ): + super().__init__() + + self.attention_blocks = nn.ModuleList( + [ + TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + temporal_max_len=temporal_max_len, + pos_embedding_type=pos_embedding_type, + ) + for i in range(num_attention_blocks) + ] + ) + self.norms = nn.ModuleList( + [ + nn.LayerNorm(dim) + for i in range(num_attention_blocks) + ] + ) + + self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu") + self.ff_norm = nn.LayerNorm(dim) + + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, cached_hidden_state_list=None): + output_hidden_state_list = [] + for i, (attention_block, norm) in enumerate(zip(self.attention_blocks, self.norms)): + norm_hidden_states = norm(hidden_states) + residual_hidden_states, output_hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + attention_mask=attention_mask, + cached_hidden_states=cached_hidden_state_list[i] if cached_hidden_state_list is not None else None, + ) + hidden_states = residual_hidden_states + hidden_states + output_hidden_state_list.append(output_hidden_states) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output, output_hidden_state_list + + +class PositionalEncoding(nn.Module): + def __init__( + self, + d_model, + dropout = 0., + max_len = 32 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)].to(x.dtype) + return self.dropout(x) + +class 
TemporalAttention(CrossAttention): + def __init__( + self, + temporal_max_len = 32, + pos_embedding_type = "ape", + *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.pos_embedding_type = pos_embedding_type + self._use_memory_efficient_attention_xformers = True + + self.pos_encoder = None + self.freqs_cis = None + if self.pos_embedding_type == "ape": + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_max_len + ) + + elif self.pos_embedding_type == "rope": + self.freqs_cis = precompute_freqs_cis( + kwargs["query_dim"], + temporal_max_len + ) + + else: + raise NotImplementedError + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, cached_hidden_states=None): + # TODO: support cache for these + assert encoder_hidden_states is None + assert attention_mask is None + + d = hidden_states.shape[1] + d_in = 0 + if cached_hidden_states is None: + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + input_hidden_states = hidden_states # (bxd) f c + else: + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=1) + input_hidden_states = hidden_states + d_in = cached_hidden_states.shape[1] + hidden_states = torch.cat([cached_hidden_states, hidden_states], dim=1) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states[:, d_in:, ...]) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if self.freqs_cis is not None: + seq_len = query.shape[1] + freqs_cis = self.freqs_cis[:seq_len].to(query.device) + query, key = apply_rotary_emb(query, key, freqs_cis) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + + use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers + if use_memory_efficient and (dim // self.heads) % 8 != 0: + # print('Warning: the dim {} cannot be divided by 8. 
Fall into normal attention'.format(dim // self.heads)) + use_memory_efficient = False + + # attention, what we cannot get enough of + if use_memory_efficient: + query = self.reshape_heads_to_4d(query) + key = self.reshape_heads_to_4d(key) + value = self.reshape_heads_to_4d(value) + + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + raise NotImplementedError + # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states, input_hidden_states diff --git a/core/models/video_depth_anything/util/__pycache__/blocks.cpython-313.pyc b/core/models/video_depth_anything/util/__pycache__/blocks.cpython-313.pyc new file mode 100644 index 0000000..2602f68 Binary files /dev/null and b/core/models/video_depth_anything/util/__pycache__/blocks.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/util/__pycache__/transform.cpython-313.pyc b/core/models/video_depth_anything/util/__pycache__/transform.cpython-313.pyc new file mode 100644 index 0000000..b05bb18 Binary files /dev/null and b/core/models/video_depth_anything/util/__pycache__/transform.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/util/__pycache__/util.cpython-313.pyc b/core/models/video_depth_anything/util/__pycache__/util.cpython-313.pyc new file mode 100644 index 0000000..84fb236 Binary files /dev/null and b/core/models/video_depth_anything/util/__pycache__/util.cpython-313.pyc differ diff --git a/core/models/video_depth_anything/util/blocks.py b/core/models/video_depth_anything/util/blocks.py new file mode 100644 index 0000000..0be16c0 --- /dev/null +++ b/core/models/video_depth_anything/util/blocks.py @@ -0,0 +1,162 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. 
+ + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + ): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand is True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1 + ) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output diff --git a/core/models/video_depth_anything/util/dc_utils.py b/core/models/video_depth_anything/util/dc_utils.py new file mode 100644 index 0000000..1380956 --- /dev/null +++ b/core/models/video_depth_anything/util/dc_utils.py @@ -0,0 +1,86 @@ +# This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter +# SPDX-License-Identifier: MIT License license +# +# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification] +# Original file is released under [ MIT License license], with the full license text available at [https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file]. 
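+
+# Video I/O helpers: read_video_frames prefers decord (with optional max-resolution
+# downscaling and frame striding toward a target FPS) and falls back to OpenCV;
+# save_video writes frames via imageio/ffmpeg and can render depth sequences with
+# the inferno colormap or as grayscale.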
+import numpy as np +import matplotlib.cm as cm +import imageio +try: + from decord import VideoReader, cpu + DECORD_AVAILABLE = True +except: + import cv2 + DECORD_AVAILABLE = False + +def ensure_even(value): + return value if value % 2 == 0 else value + 1 + +def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1): + if DECORD_AVAILABLE: + vid = VideoReader(video_path, ctx=cpu(0)) + original_height, original_width = vid.get_batch([0]).shape[1:3] + height = original_height + width = original_width + if max_res > 0 and max(height, width) > max_res: + scale = max_res / max(original_height, original_width) + height = ensure_even(round(original_height * scale)) + width = ensure_even(round(original_width * scale)) + + vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height) + + fps = vid.get_avg_fps() if target_fps == -1 else target_fps + stride = round(vid.get_avg_fps() / fps) + stride = max(stride, 1) + frames_idx = list(range(0, len(vid), stride)) + if process_length != -1 and process_length < len(frames_idx): + frames_idx = frames_idx[:process_length] + frames = vid.get_batch(frames_idx).asnumpy() + else: + cap = cv2.VideoCapture(video_path) + original_fps = cap.get(cv2.CAP_PROP_FPS) + original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + + if max_res > 0 and max(original_height, original_width) > max_res: + scale = max_res / max(original_height, original_width) + height = round(original_height * scale) + width = round(original_width * scale) + + fps = original_fps if target_fps < 0 else target_fps + + stride = max(round(original_fps / fps), 1) + + frames = [] + frame_count = 0 + while cap.isOpened(): + ret, frame = cap.read() + if not ret or (process_length > 0 and frame_count >= process_length): + break + if frame_count % stride == 0: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB + if max_res > 0 and max(original_height, original_width) > max_res: + frame = cv2.resize(frame, (width, height)) # Resize frame + frames.append(frame) + frame_count += 1 + cap.release() + frames = np.stack(frames, axis=0) + + return frames, fps + + +def save_video(frames, output_video_path, fps=10, is_depths=False, grayscale=False): + writer = imageio.get_writer(output_video_path, fps=fps, macro_block_size=1, codec='libx264', ffmpeg_params=['-crf', '18']) + if is_depths: + colormap = np.array(cm.get_cmap("inferno").colors) + d_min, d_max = frames.min(), frames.max() + for i in range(frames.shape[0]): + depth = frames[i] + depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8) + depth_vis = (colormap[depth_norm] * 255).astype(np.uint8) if not grayscale else depth_norm + writer.append_data(depth_vis) + else: + for i in range(frames.shape[0]): + writer.append_data(frames[i]) + + writer.close() diff --git a/core/models/video_depth_anything/util/transform.py b/core/models/video_depth_anything/util/transform.py new file mode 100644 index 0000000..b14aacd --- /dev/null +++ b/core/models/video_depth_anything/util/transform.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) + + # resize sample + sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) + + if self.__resize_target: + if "depth" in sample: + 
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) + + if "mask" in sample: + sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + return sample \ No newline at end of file diff --git a/core/models/video_depth_anything/util/util.py b/core/models/video_depth_anything/util/util.py new file mode 100644 index 0000000..75ff80a --- /dev/null +++ b/core/models/video_depth_anything/util/util.py @@ -0,0 +1,74 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np + +def compute_scale_and_shift(prediction, target, mask, scale_only=False): + if scale_only: + return compute_scale(prediction, target, mask), 0 + else: + return compute_scale_and_shift_full(prediction, target, mask) + + +def compute_scale(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + prediction = prediction.astype(np.float32) + target = target.astype(np.float32) + mask = mask.astype(np.float32) + + a_00 = np.sum(mask * prediction * prediction) + a_01 = np.sum(mask * prediction) + a_11 = np.sum(mask) + + # right hand side: b = [b_0, b_1] + b_0 = np.sum(mask * prediction * target) + + x_0 = b_0 / (a_00 + 1e-6) + + return x_0 + +def compute_scale_and_shift_full(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + prediction = prediction.astype(np.float32) + target = target.astype(np.float32) + mask = mask.astype(np.float32) + + a_00 = np.sum(mask * prediction * prediction) + a_01 = np.sum(mask * prediction) + a_11 = np.sum(mask) + + b_0 = np.sum(mask * prediction * target) + b_1 = np.sum(mask * target) + + x_0 = 1 + x_1 = 0 + + det = a_00 * a_11 - a_01 * a_01 + + if det != 0: + x_0 = (a_11 * b_0 - a_01 * b_1) / det + x_1 = (-a_01 * b_0 + a_00 * b_1) / det + + return x_0, x_1 + + +def get_interpolate_frames(frame_list_pre, frame_list_post): + assert len(frame_list_pre) == len(frame_list_post) + min_w = 0.0 + max_w = 1.0 + step = (max_w - min_w) / (len(frame_list_pre)-1) + post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w] + interpolated_frames = [] + for i in range(len(frame_list_pre)): + interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i]) + return interpolated_frames \ No newline at end of file diff --git a/core/models/video_depth_anything/video_depth.py b/core/models/video_depth_anything/video_depth.py new file mode 100644 index 0000000..801e5a4 --- /dev/null +++ b/core/models/video_depth_anything/video_depth.py @@ -0,0 +1,163 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
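# --- Editor's sketch (not part of the diff): a quick numeric check of
# compute_scale_and_shift above. It solves the 2x2 normal equations for
# min || mask * (s * prediction + t - target) ||^2 in closed form, so a prediction
# that is an exact affine remap of the target recovers (s, t).
import numpy as np

from core.models.video_depth_anything.util.util import compute_scale_and_shift

target = np.random.rand(4, 4).astype(np.float32)
prediction = (target - 0.5) / 2.0            # i.e. target == 2.0 * prediction + 0.5
mask = np.ones_like(target)

scale, shift = compute_scale_and_shift(prediction, target, mask)
print(round(float(scale), 3), round(float(shift), 3))   # -> 2.0 0.5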
+import torch +import torch.nn.functional as F +import torch.nn as nn +from torchvision.transforms import Compose +import cv2 +from tqdm import tqdm +import numpy as np +import gc + +from core.models.video_depth_anything.dinov2 import DINOv2 +from core.models.video_depth_anything.dpt_temporal import DPTHeadTemporal +from core.models.video_depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet + +from core.models.video_depth_anything.util.util import compute_scale_and_shift, get_interpolate_frames + +# infer settings, do not change +INFER_LEN = 32 +OVERLAP = 10 +KEYFRAMES = [0,12,24,25,26,27,28,29,30,31] +INTERP_LEN = 8 + +class VideoDepthAnything(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False, + num_frames=32, + pe='ape', + metric=False, + ): + super(VideoDepthAnything, self).__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + "vitb": [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23] + } + + self.encoder = encoder + self.pretrained = DINOv2(model_name=encoder) + + self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe) + self.metric = metric + + def forward(self, x): + B, T, C, H, W = x.shape + patch_h, patch_w = H // 14, W // 14 + features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True) + depth = self.head(features, patch_h, patch_w, T)[0] + depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True) + depth = F.relu(depth) + return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W] + + def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda', fp32=False): + frame_height, frame_width = frames[0].shape[:2] + ratio = max(frame_height, frame_width) / min(frame_height, frame_width) + if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation + input_size = int(input_size * 1.777 / ratio) + input_size = round(input_size / 14) * 14 + + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + frame_list = [frames[i] for i in range(frames.shape[0])] + frame_step = INFER_LEN - OVERLAP + org_video_len = len(frame_list) + append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step) + frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len + + depth_list = [] + pre_input = None + for frame_id in tqdm(range(0, org_video_len, frame_step)): + cur_list = [] + for i in range(INFER_LEN): + cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0)) + cur_input = torch.cat(cur_list, dim=1).to(device) + if pre_input is not None: + cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...] 
+ + with torch.no_grad(): + with torch.autocast(device_type=device, enabled=(not fp32)): + depth = self.forward(cur_input) # depth shape: [1, T, H, W] + + depth = depth.to(cur_input.dtype) + depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True) + depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])] + + pre_input = cur_input + + del frame_list + gc.collect() + + depth_list_aligned = [] + ref_align = [] + align_len = OVERLAP - INTERP_LEN + kf_align_list = KEYFRAMES[:align_len] + + for frame_id in range(0, len(depth_list), INFER_LEN): + if len(depth_list_aligned) == 0: + depth_list_aligned += depth_list[:INFER_LEN] + for kf_id in kf_align_list: + ref_align.append(depth_list[frame_id+kf_id]) + else: + curr_align = [] + for i in range(len(kf_align_list)): + curr_align.append(depth_list[frame_id+i]) + + if self.metric: + scale, shift = 1.0, 0.0 + else: + scale, shift = compute_scale_and_shift(np.concatenate(curr_align), + np.concatenate(ref_align), + np.concatenate(np.ones_like(ref_align)==1)) + + pre_depth_list = depth_list_aligned[-INTERP_LEN:] + post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP] + for i in range(len(post_depth_list)): + post_depth_list[i] = post_depth_list[i] * scale + shift + post_depth_list[i][post_depth_list[i]<0] = 0 + depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list) + + for i in range(OVERLAP, INFER_LEN): + new_depth = depth_list[frame_id+i] * scale + shift + new_depth[new_depth<0] = 0 + depth_list_aligned.append(new_depth) + + ref_align = ref_align[:1] + for kf_id in kf_align_list[1:]: + new_depth = depth_list[frame_id+kf_id] * scale + shift + new_depth[new_depth<0] = 0 + ref_align.append(new_depth) + + depth_list = depth_list_aligned + + return np.stack(depth_list[:org_video_len], axis=0), target_fps + diff --git a/core/models/video_depth_anything/video_depth_stream.py b/core/models/video_depth_anything/video_depth_stream.py new file mode 100644 index 0000000..25e7480 --- /dev/null +++ b/core/models/video_depth_anything/video_depth_stream.py @@ -0,0 +1,161 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
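# --- Editor's sketch (not part of the diff): the window schedule that
# infer_video_depth above walks. Windows of INFER_LEN frames advance by
# INFER_LEN - OVERLAP, and the first OVERLAP inputs of each new window are replaced
# with the previous window's KEYFRAMES so the scale/shift alignment has shared
# anchor frames. The 60-frame clip length is illustrative.
INFER_LEN, OVERLAP = 32, 10
KEYFRAMES = [0, 12, 24, 25, 26, 27, 28, 29, 30, 31]

org_video_len = 60
frame_step = INFER_LEN - OVERLAP                     # 22 genuinely new frames per window
append = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
print(f"pad tail with {append} copies of the last frame")   # 16 -> padded length 76

for start in range(0, org_video_len, frame_step):            # windows start at 0, 22, 44
    print(f"window covers frames {start}..{start + INFER_LEN - 1}")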
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torchvision.transforms import Compose
+import cv2
+import numpy as np
+
+from core.models.video_depth_anything.dinov2 import DINOv2
+from core.models.video_depth_anything.dpt_temporal import DPTHeadTemporal
+from core.models.video_depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
+
+from core.models.video_depth_anything.util.util import compute_scale_and_shift, get_interpolate_frames
+
+# infer settings, do not change
+INFER_LEN = 32
+OVERLAP = 10
+INTERP_LEN = 8
+
+class VideoDepthAnything(nn.Module):
+    def __init__(
+        self,
+        encoder='vitl',
+        features=256,
+        out_channels=[256, 512, 1024, 1024],
+        use_bn=False,
+        use_clstoken=False,
+        num_frames=32,
+        pe='ape'
+    ):
+        super(VideoDepthAnything, self).__init__()
+
+        self.intermediate_layer_idx = {
+            'vits': [2, 5, 8, 11],
+            "vitb": [2, 5, 8, 11],
+            'vitl': [4, 11, 17, 23]
+        }
+
+        self.encoder = encoder
+        self.pretrained = DINOv2(model_name=encoder)
+
+        self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
+        self.transform = None
+        self.frame_id_list = []
+        self.frame_cache_list = []
+        self.gap = (INFER_LEN - OVERLAP) * 2 - 1 - (OVERLAP - INTERP_LEN)
+        assert self.gap == 41
+        self.id = -1
+
+    def forward(self, x):
+        return self.forward_depth(self.forward_features(x), x.shape)[0]
+
+    def forward_features(self, x):
+        features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
+        return features
+
+    def forward_depth(self, features, x_shape, cached_hidden_state_list=None):
+        B, T, C, H, W = x_shape
+        patch_h, patch_w = H // 14, W // 14
+        depth, cur_cached_hidden_state_list = self.head(features, patch_h, patch_w, T, cached_hidden_state_list=cached_hidden_state_list)
+        depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
+        depth = F.relu(depth)
+        return depth.squeeze(1).unflatten(0, (B, T)), cur_cached_hidden_state_list # return shape [B, T, H, W]
+
+    def infer_video_depth_one(self, frame, input_size=518, device='cuda', fp32=False):
+        self.id += 1
+
+        if self.transform is None: # first frame
+            # Initialize the transform
+            frame_height, frame_width = frame.shape[:2]
+            self.frame_height = frame_height
+            self.frame_width = frame_width
+            ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
+            if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation
+                input_size = int(input_size * 1.777 / ratio)
+                input_size = round(input_size / 14) * 14
+
+            self.transform = Compose([
+                Resize(
+                    width=input_size,
+                    height=input_size,
+                    resize_target=False,
+                    keep_aspect_ratio=True,
+                    ensure_multiple_of=14,
+                    resize_method='lower_bound',
+                    image_interpolation_method=cv2.INTER_CUBIC,
+                ),
+                NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+                PrepareForNet(),
+            ])
+
+            # Infer the first frame
+            cur_list = [torch.from_numpy(self.transform({'image': frame.astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0)]
+            cur_input = torch.cat(cur_list, dim=1).to(device)
+
+            with torch.no_grad():
+                with torch.autocast(device_type=device, enabled=(not fp32)):
+                    cur_feature = self.forward_features(cur_input)
+                    x_shape = cur_input.shape
+                    depth, cached_hidden_state_list = self.forward_depth(cur_feature, x_shape)
+
+            depth = depth.to(cur_input.dtype)
+            depth = 
F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True) + + # Copy multiple cache to simulate the windows + self.frame_cache_list = [cached_hidden_state_list] * INFER_LEN + self.frame_id_list.extend([0] * (INFER_LEN - 1)) + + new_depth = depth[0][0].cpu().numpy() + else: + frame_height, frame_width = frame.shape[:2] + assert frame_height == self.frame_height + assert frame_width == self.frame_width + + # infer feature + cur_input = torch.from_numpy(self.transform({'image': frame.astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0).to(device) + with torch.no_grad(): + with torch.autocast(device_type=device, enabled=(not fp32)): + cur_feature = self.forward_features(cur_input) + x_shape = cur_input.shape + + cur_list = self.frame_cache_list[0:2] + self.frame_cache_list[-INFER_LEN+3:] + ''' + cur_id = self.frame_id_list[0:2] + self.frame_id_list[-INFER_LEN+3:] + print(f"cur_id: {cur_id}") + ''' + assert len(cur_list) == INFER_LEN - 1 + cur_cache = [torch.cat([h[i] for h in cur_list], dim=1) for i in range(len(cur_list[0]))] + + # infer depth + with torch.no_grad(): + with torch.autocast(device_type=device, enabled=(not fp32)): + depth, new_cache = self.forward_depth(cur_feature, x_shape, cached_hidden_state_list=cur_cache) + + depth = depth.to(cur_input.dtype) + depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True) + depth_list = [depth[i][0].cpu().numpy() for i in range(depth.shape[0])] + + new_depth = depth_list[-1] + + self.frame_cache_list.append(new_cache) + + # adjust the sliding window + self.frame_id_list.append(self.id) + if self.id + INFER_LEN > self.gap + 1: + del self.frame_id_list[1] + del self.frame_cache_list[1] + + return new_depth diff --git a/core/preview_gui.py b/core/preview_gui.py index 7abea02..1ca841d 100644 --- a/core/preview_gui.py +++ b/core/preview_gui.py @@ -8,6 +8,9 @@ import os, platform, warnings import datetime from tkinter import filedialog +import numpy as np + + from core.render_3d import ( frame_to_tensor, @@ -74,6 +77,10 @@ def open_3d_preview_window( # --- IPD preview state --- ipd_enabled = tk.BooleanVar(value=bool(settings.get("ipd_enabled", True))) ipd_scale = tk.DoubleVar(value=float(settings.get("ipd_scale", 1.00))) + + # --- Convergence overlay state --- + show_convergence_guides = tk.BooleanVar(value=bool(settings.get("show_convergence_guides", False))) + preview_win = tk.Toplevel() @@ -235,18 +242,27 @@ def apply_size(): feathering_checkbox = tk.Checkbutton(top_controls_frame, text="Feathering", variable=enable_feathering) feathering_checkbox.grid(row=0, column=7) + + guides_checkbox = tk.Checkbutton( + top_controls_frame, + text="Convergence Guides", + variable=show_convergence_guides, + command=lambda: update_preview_debounced() + ) + guides_checkbox.grid(row=0, column=9, padx=(0, 10)) + shift_frame = tk.LabelFrame(control_container, text="Depth Shift Settings", padx=10, pady=5) shift_frame.pack(pady=(0, 10), anchor="center") - fg_slider = tk.Scale(shift_frame, from_=-20, to=20, resolution=0.5, orient="horizontal", label="BG Shift", variable=fg_shift, length=200) - fg_slider.grid(row=0, column=2, padx=10) + fg_slider = tk.Scale(shift_frame, from_=-20, to=20, resolution=0.5, orient="horizontal", label="FG Shift", variable=fg_shift, length=200) + fg_slider.grid(row=0, column=0, padx=10) mg_slider = tk.Scale(shift_frame, from_=-10, to=10, resolution=0.5, orient="horizontal", label="MG Shift", 
variable=mg_shift, length=200) mg_slider.grid(row=0, column=1, padx=10) - bg_slider = tk.Scale(shift_frame, from_=-20, to=20, resolution=0.5, orient="horizontal", label="FG Shift", variable=bg_shift, length=200) - bg_slider.grid(row=0, column=0, padx=10) + bg_slider = tk.Scale(shift_frame, from_=-20, to=20, resolution=0.5, orient="horizontal", label="BG Shift", variable=bg_shift, length=200) + bg_slider.grid(row=0, column=2, padx=10) feather_frame = tk.LabelFrame(control_container, text="Parallax Control", padx=10, pady=5) feather_frame.pack(pady=(0, 10), anchor="center") @@ -457,6 +473,76 @@ def update_convergence_strength(*_): messagebox.showwarning("Invalid Input", "Enter a valid float for convergence strength.") convergence_slider.bind("", update_convergence_strength) + + def draw_convergence_guides(img_bgr, w, h): + """ + Draws convergence crosshair + grid markers. + Automatically respects letterbox borders. + """ + if img_bgr is None: + return img_bgr + + out = img_bgr.copy() + + gray = cv2.cvtColor(out, cv2.COLOR_BGR2GRAY) + row_mean = gray.mean(axis=1) + + # Detect non-black rows (active picture area) + thresh = 8.0 + active_rows = np.where(row_mean > thresh)[0] + + if active_rows.size > 0: + y0 = int(active_rows[0]) + y1 = int(active_rows[-1]) + else: + y0, y1 = 0, h - 1 + + pad = max(2, int(min(w, h) * 0.01)) + y0 = max(0, y0 + pad) + y1 = min(h - 1, y1 - pad) + + ax0, ax1 = 0, w - 1 + ay0, ay1 = y0, y1 + aw = ax1 - ax0 + 1 + ah = ay1 - ay0 + 1 + + cx = (ax0 + ax1) // 2 + cy = (ay0 + ay1) // 2 + + base = min(aw, ah) + + # ⬇️ Bigger, more readable sizes + arm = max(10, base // 55) + gap = max(6, base // 110) + thickness = 2 + col = (235, 235, 235) + + # Crosshair + cv2.line(out, (cx - arm, cy), (cx - gap, cy), col, thickness, cv2.LINE_AA) + cv2.line(out, (cx + gap, cy), (cx + arm, cy), col, thickness, cv2.LINE_AA) + cv2.line(out, (cx, cy - arm), (cx, cy - gap), col, thickness, cv2.LINE_AA) + cv2.line(out, (cx, cy + gap), (cx, cy + arm), col, thickness, cv2.LINE_AA) + + # Center dot + cv2.circle(out, (cx, cy), 3, col, -1, cv2.LINE_AA) + + # ⬇️ Larger grid markers (easy to see, not noisy) + marker_size = max(12, base // 35) + + for gx in (ax0 + aw // 4, cx, ax0 + (3 * aw) // 4): + for gy in (ay0 + ah // 4, cy, ay0 + (3 * ah) // 4): + cv2.drawMarker( + out, + (gx, gy), + (200, 200, 200), + markerType=cv2.MARKER_CROSS, + markerSize=marker_size, + thickness=2, + line_type=cv2.LINE_AA + ) + + return out + def update_preview_now(): nonlocal preview_job @@ -590,9 +676,15 @@ def update_preview_now(): else: preview_img = generate_preview_image(preview_mode, left_tensor, right_tensor, shift_map, w, h) + # Resize and render # Resize and render if preview_img is not None: + # Optional convergence overlay (draw in preview image space) + if show_convergence_guides.get(): + preview_img = draw_convergence_guides(preview_img, w, h) + img_rgb = cv2.cvtColor(preview_img, cv2.COLOR_BGR2RGB) + try: preview_width = int(width_entry.get()) preview_height = int(height_entry.get()) @@ -636,6 +728,8 @@ def on_close(): 'brightness': brightness.get(), 'ipd_enabled': ipd_enabled.get(), 'ipd_scale': ipd_scale.get(), + 'show_convergence_guides': show_convergence_guides.get(), + } save_settings(settings) diff --git a/core/render_3d.py b/core/render_3d.py index 4371214..b4ecd3c 100644 --- a/core/render_3d.py +++ b/core/render_3d.py @@ -135,6 +135,71 @@ def merge_audio_from_source(final_video, original_video, output_with_audio): +def ffmpeg_rgb48_reader(path, width, height, start_s=None, end_s=None): + """ + Decode 
video frames to RGB 16-bit (rgb48le) while explicitly preserving HDR signaling. + This avoids FFmpeg doing implicit/guessed colorspace conversions on HDR10 sources. + + Returns frames as float32 RGB in [0,1] (still PQ-encoded values, not tonemapped). + """ + cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error"] + + # Seek before input for speed (keyframe seek). If you need exact frame-accurate + # seeking, do a second -ss after -i, but this is usually fine for rendering. + if start_s is not None: + cmd += ["-ss", str(float(start_s))] + + cmd += ["-i", path] + + # Clip window + if end_s is not None and start_s is not None: + dur = max(0.0, float(end_s) - float(start_s)) + cmd += ["-t", str(dur)] + elif end_s is not None: + cmd += ["-to", str(float(end_s))] + + # Force HDR colorspace handling so FFmpeg doesn't guess: + # - zscale sets primaries/transfer/matrix and preserves PQ/BT.2020 + # - npl=1000 sets nominal peak luminance (helps prevent weird scaling) + # - format=rgb48le ensures 16-bit RGB output + vf = ( + "zscale=primaries=bt2020:transfer=smpte2084:matrix=bt2020nc:" + "range=tv:npl=1000,format=rgb48le" + ) + + cmd += [ + "-an", "-sn", "-dn", + "-vf", vf, + "-f", "rawvideo", + "-pix_fmt", "rgb48le", + "-vsync", "0", + "-" + ] + + p = subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=10**7) + frame_bytes = int(width) * int(height) * 3 * 2 # 3 channels * 16-bit + + try: + while True: + buf = p.stdout.read(frame_bytes) + if not buf or len(buf) < frame_bytes: + break + arr = np.frombuffer(buf, dtype=np.uint16).reshape((height, width, 3)) + # Still PQ-encoded, just higher precision; do not tonemap here. + yield (arr.astype(np.float32) / 65535.0).clip(0.0, 1.0) + finally: + try: + if p.stdout: + p.stdout.close() + except Exception: + pass + p.wait() + # Optional: surface decode errors + if p.returncode not in (0, None): + raise RuntimeError(f"ffmpeg_rgb48_reader: ffmpeg exited with code {p.returncode}") + + + def ffmpeg_yuv10_reader(path, width, height): """ Yields P010LE frames as float32 RGB in [0,1] with simple 10-bit scaling. @@ -171,6 +236,26 @@ def ffmpeg_yuv10_reader(path, width, height): p.stdout.close(); p.wait() +def reset_render_state(): + # reset shift EMA + if hasattr(pixel_shift_cuda, "_shift_ema"): + pixel_shift_cuda._shift_ema = None + + # reset floating window tracker + if "floating_window_tracker" in globals(): + floating_window_tracker.prev_offset = 0.0 + floating_window_tracker.frame_counter = 0 + + # reset DFW easing + for k in ("dfw_last_side", "dfw_last_width"): + if k in globals(): + del globals()[k] + + # reset depth percentile EMA so it learns per render + global depth_ema_norm + depth_ema_norm = DepthPercentileEMA(p_lo=0.02, p_hi=0.98, alpha=0.92) + + def sculpt_depth_u8(base_depth_u8, mask_u8, *, near=1.0, far=0.4, feather_px=12, round_gamma=1.2): @@ -480,6 +565,120 @@ def update(self, x): conv_ema = ConvergenceEMA(alpha=0.97) MID_GAMMA = 0.90 # 0.80–0.95 works well +def frame16_to_tensor(rgb_float_01): + """ + rgb_float_01: [H,W,3] float32 0..1 in RGB (PQ-encoded values, but high precision) + Returns torch [3,H,W] float32 0..1 on device. + """ + t = torch.from_numpy(rgb_float_01).float().permute(2, 0, 1).contiguous() + return t.to(torch_device) + +def tensor_to_rgb48_bytes(rgb_tensor): + """ + rgb_tensor: torch [3,H,W] float in [0,1] + Returns bytes for rgb48le (uint16 little-endian). 
+ """ + x = rgb_tensor.clamp(0.0, 1.0).permute(1, 2, 0).detach().cpu().numpy() + u16 = (x * 65535.0 + 0.5).astype(np.uint16) + return u16.tobytes() + +import torch.nn.functional as F + +def tensor_pad_to_aspect_ratio(rgb_t, target_width, target_height): + """ + Torch equivalent of pad_to_aspect_ratio() + rgb_t: [3,H,W] RGB float in 0..1 + Returns [3,target_height,target_width] + """ + + C, h, w = rgb_t.shape + target_aspect = target_width / target_height + current_aspect = w / h + + # Step 1: resize to fit while preserving aspect + if current_aspect > target_aspect: + # wider → match width + new_w = target_width + new_h = int(target_width / current_aspect) + else: + # taller → match height + new_h = target_height + new_w = int(current_aspect * target_height) + + resized = F.interpolate( + rgb_t.unsqueeze(0), + size=(new_h, new_w), + mode="bilinear", + align_corners=False + ).squeeze(0) + + # Step 2: padded canvas (black) + padded = torch.zeros( + (3, target_height, target_width), + device=rgb_t.device, + dtype=rgb_t.dtype + ) + + # Step 3: center it + x_offset = (target_width - new_w) // 2 + y_offset = (target_height - new_h) // 2 + + padded[:, y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized + + return padded.clamp(0.0, 1.0) + +def tensor_sharpen(rgb_t, factor=0.0): + # factor 0 = no sharpen + if factor <= 1e-6: + return rgb_t + # unsharp-ish kernel (simple and stable) + # conv2d expects [N,C,H,W] + k = torch.tensor([[0, -1, 0], + [-1, 5.0 + float(factor), -1], + [0, -1, 0]], device=rgb_t.device, dtype=rgb_t.dtype).view(1,1,3,3) + x = rgb_t.unsqueeze(0) # [1,3,H,W] + # apply per-channel by groups=3 + k3 = k.repeat(3, 1, 1, 1) # [3,1,3,3] + y = F.conv2d(x, k3, padding=1, groups=3) + return y.squeeze(0).clamp(0.0, 1.0) + +def tensor_apply_side_mask(rgb_t, side="left", width=40, solid_black=True, fade=False): + if width <= 0: + return rgb_t + C, H, W = rgb_t.shape + w = min(int(width), W) + mask = torch.ones((1, H, W), device=rgb_t.device, dtype=rgb_t.dtype) + + if solid_black: + if side == "left": + mask[:, :, :w] = 0 + else: + mask[:, :, W-w:] = 0 + else: + if fade: + ramp = torch.linspace(0, 1, w, device=rgb_t.device, dtype=rgb_t.dtype) + if side == "left": + mask[:, :, :w] = ramp.view(1, 1, w) + else: + mask[:, :, W-w:] = ramp.flip(0).view(1, 1, w) + else: + if side == "left": + mask[:, :, :w] = 0 + else: + mask[:, :, W-w:] = 0 + + return (rgb_t * mask).clamp(0.0, 1.0) + +def format_3d_output_torch(left_t, right_t, fmt): + # left_t/right_t: [3,H,W] + if fmt in ("Half-SBS", "Full-SBS", "VR"): + return torch.cat([left_t, right_t], dim=2) # SBS + elif fmt == "Passive Interlaced": + out = left_t.clone() + out[:, 1::2, :] = right_t[:, 1::2, :] + return out + else: + return torch.cat([left_t, right_t], dim=2) def tensor_to_frame(tensor): frame_cpu = (tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) @@ -816,6 +1015,7 @@ def pixel_shift_cuda( fg_pop_multiplier=1.20, bg_push_multiplier=1.10, subject_lock_strength=1.00, + return_tensors=False, ): width = int(width) height = int(height) @@ -850,7 +1050,7 @@ def pixel_shift_cuda( # weights from shaped depth (steeper foreground falloff) fg_weight = (1.0 - d_shaped).pow(1.5).clamp(0, 1) - mg_weight = (1.0 - (d_shaped - depth_pop_mid).abs() * 3.0).clamp(0, 1) # slightly tighter mid band + mg_weight = (1.0 - (d_shaped - depth_pop_mid).abs() * 5.0).clamp(0, 1) # slightly tighter mid band bg_weight = d_shaped.clamp(0, 1) half_width = width / 2.0 @@ -959,10 +1159,58 @@ def pixel_shift_cuda( left_blended, right_blended = 
warped_left, warped_right if return_shift_map: + if return_tensors: + return left_blended, right_blended, final_shift.detach().cpu() return tensor_to_frame(left_blended), tensor_to_frame(right_blended), final_shift.detach().cpu() else: + if return_tensors: + return left_blended, right_blended return tensor_to_frame(left_blended), tensor_to_frame(right_blended) +def tensor_pad_to_aspect(t: torch.Tensor, target_w: int, target_h: int) -> torch.Tensor: + """ + t: [3,H,W] RGB float 0..1 + Pads with black to exactly target_w/target_h, centered. + """ + C, H, W = t.shape + out = t + # resize to fit inside target while preserving aspect + src_ar = W / max(H, 1) + dst_ar = target_w / max(target_h, 1) + + if abs(src_ar - dst_ar) > 1e-6: + if src_ar > dst_ar: + # too wide, fit width + new_w = target_w + new_h = int(round(target_w / src_ar)) + else: + # too tall, fit height + new_h = target_h + new_w = int(round(target_h * src_ar)) + else: + new_w, new_h = target_w, target_h + + out = F.interpolate(out.unsqueeze(0), size=(new_h, new_w), mode="bilinear", align_corners=False).squeeze(0) + + # pad to target + pad_l = max(0, (target_w - new_w) // 2) + pad_r = max(0, target_w - new_w - pad_l) + pad_t = max(0, (target_h - new_h) // 2) + pad_b = max(0, target_h - new_h - pad_t) + + return F.pad(out, (pad_l, pad_r, pad_t, pad_b), mode="constant", value=0.0).clamp(0.0, 1.0) + + +def tensor_apply_sharpen(t: torch.Tensor, factor: float = 1.0) -> torch.Tensor: + """ + Simple unsharp-style sharpen for tensors [3,H,W] in 0..1. + """ + if factor <= 0: + return t + # light blur + blur = tv_gaussian_blur(t, kernel_size=3, sigma=1.0) + out = t + (t - blur) * float(factor) + return out.clamp(0.0, 1.0) # Sharpening @@ -1101,7 +1349,7 @@ def format_3d_output(left, right, fmt): return np.hstack((lw, rw)) elif fmt == "Red-Cyan Anaglyph": - return generate_anaglyph_3d(left, right, mode="halfcolor") # start with halfcolor + return generate_anaglyph_3d(left, right, mode="dubois") # start with halfcolor elif fmt == "Passive Interlaced": interlaced = np.zeros_like(left) @@ -1153,7 +1401,7 @@ def apply_side_mask(image, side="left", width=40, fade=False, solid_black=True): output = image.copy() if solid_black: - # 🧱 Solid opaque black bar — cinema floating window + # Solid opaque black bar — cinema floating window if side == "left": output[:, :width] = 0 else: @@ -1271,11 +1519,32 @@ def render_sbs_3d( end_s=None, eye_mode="sbs", ): - + reset_render_state() cap, dcap = cv2.VideoCapture(input_path), cv2.VideoCapture(depth_path) if not cap.isOpened() or not dcap.isOpened(): return + hdr_gen = None + if preserve_hdr10: + # Use original input dimensions for decode + src_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + src_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + hdr_gen = ffmpeg_rgb48_reader(input_path, src_w, src_h, start_s=start_s, end_s=end_s) + + def read_next_frame(): + """Returns (ret, frame_tensor, frame_bgr_or_None).""" + if preserve_hdr10: + try: + rgb = next(hdr_gen) # float RGB 0..1 + except StopIteration: + return False, None, None + return True, frame16_to_tensor(rgb), None + else: + ret, frame_bgr = cap.read() + if not ret: + return False, None, None + return True, frame_to_tensor(frame_bgr), frame_bgr + # base facts total_frames_full = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = cap.get(cv2.CAP_PROP_FPS) or fps or 30.0 @@ -1305,13 +1574,13 @@ def render_sbs_3d( dcap.set(cv2.CAP_PROP_POS_FRAMES, start_frame_idx) # FIRST READ occurs *after* seeking - ret1, frame = cap.read() + ret1, frame_tensor, frame = 
read_next_frame() ret2, depth = dcap.read() + if not ret1 or not ret2: cap.release(); dcap.release() return - global global_session_start_time if global_session_start_time is None: global_session_start_time = time.time() @@ -1344,7 +1613,7 @@ def render_sbs_3d( # 🆕 blank frame indices are absolute — offset them for the clip window blank_offset = start_frame_idx - first_frame_tensor = frame_to_tensor(frame) + first_frame_tensor = frame_tensor.clone() if auto_crop_black_bars: # Detect once on first frame @@ -1379,6 +1648,7 @@ def render_sbs_3d( # fallback to current frame tensor size _, h0, w0 = first_frame_tensor.shape original_video_width, original_video_height = w0, h0 + resized_width = original_video_width resized_height = original_video_height @@ -1387,21 +1657,38 @@ def render_sbs_3d( per_eye_h = resized_height out_width = per_eye_w * 2 out_height = per_eye_h + elif output_format == "Half-SBS": per_eye_w = resized_width // 2 per_eye_h = resized_height out_width = resized_width - out_height = resized_height + out_height = resized_height + elif output_format == "VR": per_eye_w = 1440 per_eye_h = 1600 out_width = per_eye_w * 2 out_height = per_eye_h + + elif output_format == "Red-Cyan Anaglyph": + per_eye_w = resized_width + per_eye_h = resized_height + out_width = resized_width + out_height = resized_height + + elif output_format == "Passive Interlaced": + # IMPORTANT: interlaced is single-frame size (not SBS) + per_eye_w = resized_width + per_eye_h = resized_height + out_width = resized_width + out_height = resized_height + else: per_eye_w = resized_width per_eye_h = resized_height out_width = resized_width * 2 out_height = resized_height + else: resized_height = output_height resized_width = int(resized_height * target_ratio) @@ -1412,32 +1699,43 @@ def render_sbs_3d( per_eye_w, per_eye_h = 1920, 1080 out_width = per_eye_w * 2 out_height = per_eye_h + elif output_format == "Half-SBS": per_eye_w = resized_width // 2 per_eye_h = resized_height out_width = resized_width out_height = resized_height + elif output_format == "VR": per_eye_w = 1440 per_eye_h = 1600 out_width = per_eye_w * 2 out_height = per_eye_h + elif output_format == "Red-Cyan Anaglyph": # One frame only, not SBS per_eye_w = resized_width per_eye_h = resized_height out_width = resized_width out_height = resized_height + + elif output_format == "Passive Interlaced": + # IMPORTANT: interlaced is single-frame size (not SBS) + per_eye_w = resized_width + per_eye_h = resized_height + out_width = resized_width + out_height = resized_height + else: per_eye_w = resized_width per_eye_h = resized_height out_width = resized_width * 2 out_height = resized_height - + if eye_mode in ("left", "right"): - out_width = per_eye_w + out_width = per_eye_w out_height = per_eye_h - + # --- invariants (fixed for the whole render) --- cinema_aspect_ratio = aspect_ratios.get(selected_aspect_ratio.get(), 16/9) single_eye = eye_mode in ("left", "right") @@ -1452,12 +1750,8 @@ def render_sbs_3d( eye_w = per_eye_w eye_h = per_eye_h - # safer: compute width for floating window based on the actual frame you’ll mask - if single_eye: - width_for_bars = per_eye_w - else: - width_for_bars = resized_width // 2 # half for each eye in SBS - + # Floating window should always operate on per-eye width + width_for_bars = per_eye_w # DOF / Color grading flags don’t change during render need_dof = (dof_strength > 0.0) @@ -1474,7 +1768,7 @@ def render_sbs_3d( ffmpeg_cmd = [ "ffmpeg","-y", "-f","rawvideo","-vcodec","rawvideo", - "-pix_fmt","bgr24", + "-pix_fmt", 
"rgb48le" if preserve_hdr10 else "bgr24", "-s", f"{out_width}x{out_height}", "-r", str(fps), "-i","-", @@ -1482,10 +1776,10 @@ def render_sbs_3d( "-c:v", selected_ffmpeg_codec, ] + is_nvenc = "nvenc" in selected_ffmpeg_codec # h264_nvenc/hevc_nvenc/av1_nvenc if preserve_hdr10: - # 10-bit + HDR signaling (no tone-map) ffmpeg_cmd += [ "-pix_fmt","p010le", "-color_range","tv", @@ -1493,6 +1787,7 @@ def render_sbs_3d( "-color_primaries","bt2020", "-color_trc","smpte2084", ] + if is_nvenc: ffmpeg_cmd += [ "-preset","p5", # NVENC preset (p1 fastest…p7 slowest) @@ -1548,18 +1843,10 @@ def render_sbs_3d( global temporal_depth_filter temporal_depth_filter = TemporalDepthFilter(alpha=0.5) - cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame_idx) - dcap.set(cv2.CAP_PROP_POS_FRAMES, start_frame_idx) avg_fps = 0 prev_depth_tensor = None focal_tracker = FocalDepthTracker(alpha=0.15, deadband=0.03, max_step=0.02) matte_ema = MatteEMA(alpha=ROTO_EMA_ALPHA) - - ret1, frame = cap.read() - ret2, depth = dcap.read() - if not ret1 or not ret2: - cap.release(); dcap.release() - return # Decide how many frames to process (for loop + progress) total_frames = clip_total_frames if clip_total_frames > 0 else total_frames_full @@ -1594,12 +1881,11 @@ def render_sbs_3d( if cancel_flag.is_set(): break - ret1, frame = cap.read() + ret1, frame_tensor, frame = read_next_frame() ret2, depth = dcap.read() if not ret1 or not ret2: break - - frame_tensor = frame_to_tensor(frame) + depth_tensor = depth_to_tensor(depth) if auto_crop_black_bars: @@ -1665,8 +1951,7 @@ def render_sbs_3d( # Continue with your existing temporal/percentile normalization depth_tensor = temporal_depth_filter.smooth(depth_tensor) depth_tensor = depth_ema_norm.normalize(depth_tensor) - - + fg, mg, bg = smoother.smooth(fg_shift, mg_shift, bg_shift) # dynamic IPD scale @@ -1678,12 +1963,17 @@ def render_sbs_3d( if (blank_offset + idx) in blank_frames: print(f"⏩ Skipping blank frame {idx}") - left_frame = frame - right_frame = frame + if preserve_hdr10: + # frame is None in HDR mode, so use the current tensor as both eyes + left_frame = frame_tensor + right_frame = frame_tensor + else: + left_frame = frame + right_frame = frame else: if ipd_factor == 0.0: left_frame, right_frame = pixel_shift_cuda( - frame_tensor, depth_tensor, resized_width, resized_height, + frame_tensor, depth_tensor, eye_w, eye_h, fg, mg, bg, blur_ksize=blur_ksize, feather_strength=feather_strength, @@ -1704,11 +1994,12 @@ def render_sbs_3d( fg_pop_multiplier=1.20, bg_push_multiplier=1.10, subject_lock_strength=1.00, + return_tensors=True, ) else: fg *= ipd_factor; mg *= ipd_factor; bg *= ipd_factor left_frame, right_frame = pixel_shift_cuda( - frame_tensor, depth_tensor, resized_width, resized_height, + frame_tensor, depth_tensor, eye_w, eye_h, fg, mg, bg, blur_ksize=blur_ksize, feather_strength=feather_strength, @@ -1729,6 +2020,7 @@ def render_sbs_3d( fg_pop_multiplier=1.20, bg_push_multiplier=1.10, subject_lock_strength=1.00, + return_tensors=True, ) @@ -1739,8 +2031,8 @@ def render_sbs_3d( if need_dof or need_color: # 1) to tensors once - left_t = frame_to_tensor(left_frame) # [3,H,W] - right_t = frame_to_tensor(right_frame) + left_t = left_frame + right_t = right_frame # 2) match depth to the eye frame once H, W = left_t.shape[1], left_t.shape[2] @@ -1767,9 +2059,15 @@ def render_sbs_3d( contrast=color_contrast, brightness=color_brightness) - # 5) back to numpy once - left_frame = tensor_to_frame(left_t) - right_frame = tensor_to_frame(right_t) + # 5) back to numpy for SDR only + if 
preserve_hdr10: + # keep tensors for HDR pipe + left_frame = left_t + right_frame = right_t + else: + left_frame = tensor_to_frame(left_t) + right_frame = tensor_to_frame(right_t) + # floating window mask subject_depth = estimate_subject_depth(depth_tensor) @@ -1782,34 +2080,24 @@ def render_sbs_3d( ) / (width_for_bars / 2 + 1e-6) zero_parallax_offset = float( - floating_window_tracker.smooth_offset(raw_zero, threshold=0.001)) - - # sharpen & pack - left_sharp = apply_sharpening(left_frame, sharpness_factor) - right_sharp = apply_sharpening(right_frame, sharpness_factor) - - if output_format == "Full-SBS": - left_out = pad_to_aspect_ratio(left_sharp, per_eye_w, per_eye_h) - right_out = pad_to_aspect_ratio(right_sharp, per_eye_w, per_eye_h) - elif output_format == "Half-SBS": - left_out = cv2.resize(left_sharp, (per_eye_w, per_eye_h), interpolation=cv2.INTER_AREA) - right_out = cv2.resize(right_sharp, (per_eye_w, per_eye_h), interpolation=cv2.INTER_AREA) - else: - left_out = pad_to_aspect_ratio(left_sharp, per_eye_w, per_eye_h) - right_out = pad_to_aspect_ratio(right_sharp, per_eye_w, per_eye_h) + floating_window_tracker.smooth_offset(raw_zero, threshold=0.001) + ) - # --- Dynamic Floating Window (softer and side aware) --- + # --- Dynamic Floating Window (shared compute, HDR + SDR) --- + dfw_apply = False + dfw_side = "left" + dfw_width = 0 + if use_floating_window and use_subject_tracking: global dfw_last_side, dfw_last_width - if 'dfw_last_side' not in globals(): + if "dfw_last_side" not in globals(): dfw_last_side = "left" dfw_last_width = 0 - # zero_parallax_offset just above is in "grid" space, usually [-1, 1] + # zero_parallax_offset is in "grid" space, usually [-1, 1] parallax_mag = abs(float(zero_parallax_offset)) - # Do not draw any bar if parallax is tiny if parallax_mag < DFW_MIN_PARALLAX: target_width = 0 else: @@ -1821,76 +2109,126 @@ def render_sbs_3d( depth_delta = abs(subject_depth_val - 0.5) - # Blend parallax and subject depth together parallax_delta = ( DFW_PARALLAX_WEIGHT * parallax_mag + DFW_DEPTH_WEIGHT * depth_delta ) - # Clamp the influence so big parallax does not explode the bar parallax_delta = min(parallax_delta, 0.12) - # Convert to pixels and clamp to a small fraction of the eye width target_width = int(per_eye_w * parallax_delta) max_bar_px = int(per_eye_w * DFW_MAX_BAR_FRAC) target_width = max(0, min(target_width, max_bar_px)) - # Decide which side to place the window on - # If this feels flipped for your content, just swap "left"/"right" here dfw_last_side = "left" if zero_parallax_offset > 0.0 else "right" - # Ease width over time so it does not pop dfw_last_width = int( DFW_WIDTH_EASE * dfw_last_width + (1.0 - DFW_WIDTH_EASE) * target_width ) - # Small widths are basically invisible, so skip - if dfw_last_width > 1: + dfw_side = dfw_last_side + dfw_width = dfw_last_width + dfw_apply = (dfw_width > 1) + + if preserve_hdr10 and not use_ffmpeg: + raise RuntimeError("HDR10 output requires FFmpeg. OpenCV VideoWriter is SDR-only in this pipeline.") + + # sharpen & pack + if preserve_hdr10: + # left_frame/right_frame are torch tensors [3,H,W] RGB float 0..1 + + # 1) Sharpen in tensor space + left_t = tensor_apply_sharpen(left_frame, sharpness_factor) + right_t = tensor_apply_sharpen(right_frame, sharpness_factor) + + # 2) Pack/resize in tensor space to match the exact per-eye target + # For HDR you should not use cv2. Keep tensors. 
+ if output_format == "Half-SBS": + # Half-SBS means each eye is half width, same height + left_t = F.interpolate(left_t.unsqueeze(0), size=(per_eye_h, per_eye_w), mode="bilinear", align_corners=False).squeeze(0) + right_t = F.interpolate(right_t.unsqueeze(0), size=(per_eye_h, per_eye_w), mode="bilinear", align_corners=False).squeeze(0) + else: + # Full-SBS, VR, Anaglyph, Passive: keep per-eye sizing consistent + left_t = tensor_pad_to_aspect(left_t, per_eye_w, per_eye_h) + right_t = tensor_pad_to_aspect(right_t, per_eye_w, per_eye_h) + + # 3) Dynamic Floating Window, apply in tensor space + if dfw_apply: + left_t = tensor_apply_side_mask( + left_t, side=dfw_side, width=dfw_width, + fade=DFW_USE_FADE, solid_black=(not DFW_USE_FADE) + ) + right_t = tensor_apply_side_mask( + right_t, side=dfw_side, width=dfw_width, + fade=DFW_USE_FADE, solid_black=(not DFW_USE_FADE) + ) + + # 4) Final pack as tensor + if eye_mode == "left": + final_tensor = left_t + elif eye_mode == "right": + final_tensor = right_t + else: + # SBS tensor pack (RGB) + final_tensor = torch.cat([left_t, right_t], dim=2) # concat width + + # Optional: if you really need Passive Interlaced in HDR, do it in tensor space + if (eye_mode == "sbs") and (output_format == "Passive Interlaced"): + # interlace rows: even rows left, odd rows right, output is single-eye size + H, W2 = final_tensor.shape[1], final_tensor.shape[2] + W = W2 // 2 + left_eye = final_tensor[:, :, :W] + right_eye = final_tensor[:, :, W:] + inter = left_eye.clone() + inter[:, 1::2, :] = right_eye[:, 1::2, :] + final_tensor = inter + + else: + # SDR numpy path, keep your existing code + left_sharp = apply_sharpening(left_frame, sharpness_factor) + right_sharp = apply_sharpening(right_frame, sharpness_factor) + + if output_format == "Full-SBS": + left_out = pad_to_aspect_ratio(left_sharp, per_eye_w, per_eye_h) + right_out = pad_to_aspect_ratio(right_sharp, per_eye_w, per_eye_h) + elif output_format == "Half-SBS": + left_out = cv2.resize(left_sharp, (per_eye_w, per_eye_h), interpolation=cv2.INTER_AREA) + right_out = cv2.resize(right_sharp, (per_eye_w, per_eye_h), interpolation=cv2.INTER_AREA) + else: + left_out = pad_to_aspect_ratio(left_sharp, per_eye_w, per_eye_h) + right_out = pad_to_aspect_ratio(right_sharp, per_eye_w, per_eye_h) + + # Dynamic Floating Window stays the same for SDR (your existing apply_side_mask calls) + if dfw_apply: if DFW_USE_FADE: - left_out = apply_side_mask( - left_out, - side=dfw_last_side, - width=dfw_last_width, - fade=True, - solid_black=False, - ) - right_out = apply_side_mask( - right_out, - side=dfw_last_side, - width=dfw_last_width, - fade=True, - solid_black=False, - ) + left_out = apply_side_mask(left_out, side=dfw_side, width=dfw_width, fade=True, solid_black=False) + right_out = apply_side_mask(right_out, side=dfw_side, width=dfw_width, fade=True, solid_black=False) else: - # Hard cinema style black bar - left_out = apply_side_mask( - left_out, - side=dfw_last_side, - width=dfw_last_width, - fade=False, - solid_black=True, - ) - right_out = apply_side_mask( - right_out, - side=dfw_last_side, - width=dfw_last_width, - fade=False, - solid_black=True, - ) - - if eye_mode == "left": - final = left_out - elif eye_mode == "right": - final = right_out - else: # "sbs" - final = format_3d_output(left_out, right_out, output_format) - + left_out = apply_side_mask(left_out, side=dfw_side, width=dfw_width, fade=False, solid_black=True) + right_out = apply_side_mask(right_out, side=dfw_side, width=dfw_width, fade=False, solid_black=True) + + 
+ if eye_mode == "left": + final = left_out + elif eye_mode == "right": + final = right_out + else: + final = format_3d_output(left_out, right_out, output_format) # write frame if use_ffmpeg: try: - ffmpeg_proc.stdin.write(final.astype(np.uint8).tobytes()) + if preserve_hdr10: + # ✅ HDR10 path: write 16-bit RGB (rgb48le) to ffmpeg stdin + # Expectation: you must be generating a float RGB tensor [3,H,W] in 0..1 + # (example name: final_tensor). If you only have `final` as uint8 BGR, + # you are NOT preserving HDR10. + ffmpeg_proc.stdin.write(tensor_to_rgb48_bytes(final_tensor)) + else: + # SDR path: write 8-bit BGR + ffmpeg_proc.stdin.write(final.astype(np.uint8).tobytes()) + except Exception as e: print(f"❌ FFmpeg write error: {e}") break @@ -2009,6 +2347,8 @@ def render_sbs_3d_image( color_brightness: float = 0.0, eye_mode: str = "sbs", ): + reset_render_state() + """ Single image version of render_sbs_3d. Runs pixel_shift_cuda with the same depth shaping, parallax logic, and @@ -2073,8 +2413,8 @@ def _val(v): # Optional black bar crop (same logic as video path) cached_crop = (0, 0) if auto_crop_black_bars: - # Reuse the first-frame crop for the entire clip so frame and depth - # stay perfectly aligned and do not jitter. + top_crop, bottom_crop = detect_black_bars(frame_tensor) + cached_crop = (top_crop, bottom_crop) frame_tensor, _ = crop_black_bars_torch(frame_tensor, cached_crop) depth_tensor, _ = crop_black_bars_torch(depth_tensor, cached_crop) @@ -2135,10 +2475,9 @@ def _val(v): eye_h = per_eye_h single_eye = eye_mode in ("left", "right") - if single_eye: - width_for_bars = per_eye_w - else: - width_for_bars = resized_width // 2 + + # Floating window math should always use per-eye width (VR, SBS, single-eye all consistent) + width_for_bars = per_eye_w need_dof = (dof_strength > 0.0) need_color = ( @@ -2181,8 +2520,9 @@ def _val(v): mask_u8, near=ROTO_NEAR, far=ROTO_FAR, - feather_px=ROTO_ROUND_GAMMA, + feather_px=ROTO_FEATHER_PX, round_gamma=ROTO_ROUND_GAMMA, + ) depth_tensor = torch.from_numpy(depth_u8).to(frame_tensor.device).float().unsqueeze(0) / 255.0 @@ -2214,8 +2554,8 @@ def _val(v): left_frame, right_frame = pixel_shift_cuda( frame_tensor, depth_tensor, - resized_width, - resized_height, + eye_w, + eye_h, fg, mg, bg, @@ -2577,12 +2917,20 @@ def process_video( if format_selected == "Full-SBS": output_width = width * 2 output_height = height + elif format_selected == "Half-SBS": output_width = width output_height = height + + elif format_selected == "Passive Interlaced": + # IMPORTANT: same size as original frame (not SBS!) 
+ output_width = width + output_height = height + elif format_selected == "VR": output_width = 4096 output_height = int(output_width / aspect_ratio) + else: output_width = width output_height = int(output_width / aspect_ratio) @@ -2596,7 +2944,7 @@ def process_video( final_render_path = None # 🔥 Start render process - if format_selected in ["Full-SBS", "Half-SBS", "Red-Cyan Anaglyph", "Passive Interlaced"]: + if format_selected in ["Full-SBS", "Half-SBS", "VR", "Red-Cyan Anaglyph", "Passive Interlaced"]: final_render_path = render_sbs_3d( input_path, depth_path, diff --git a/core/render_depth.py b/core/render_depth.py index fa53eed..c39660d 100644 --- a/core/render_depth.py +++ b/core/render_depth.py @@ -60,6 +60,7 @@ def _vd3d_base_dir(): from core.models.depth_anything_v2.dpt import DepthAnythingV2 + global pipe pipe = None pipe_type = None @@ -100,6 +101,63 @@ def _vd3d_base_dir(): "AV1 (QSV - Intel ARC / Gen11+)": "av1_qsv", } +# ===== Codec helpers ===== + +def is_opencv_safe_fourcc(ffmpeg_codec: str) -> bool: + return ffmpeg_codec in ("mp4v", "XVID", "DIVX") + + +def start_ffmpeg_writer(output_path, fps, w, h, ffmpeg_codec): + cmd = [ + "ffmpeg", "-y", + "-hide_banner", "-loglevel", "error", + + # Bigger queue so stdin bursts do not stall as easily + "-thread_queue_size", "512", + + "-f", "rawvideo", + "-pix_fmt", "bgr24", + "-s", f"{w}x{h}", + "-r", str(float(fps)), + "-i", "-", + "-an", + "-c:v", ffmpeg_codec, + ] + + # Codec specific speed tuning + if ffmpeg_codec in ("libx264", "libx265"): + cmd += [ + "-preset", "veryfast", + "-crf", "18", + "-pix_fmt", "yuv420p", + ] + + elif ffmpeg_codec in ("h264_nvenc", "hevc_nvenc", "av1_nvenc"): + # NVENC speed presets: p1 fastest .. p7 best quality + cmd += [ + "-preset", "p2", + "-rc", "vbr", + "-cq", "19", + "-pix_fmt", "yuv420p", + ] + + else: + # Safe default + cmd += ["-pix_fmt", "yuv420p"] + + cmd += [output_path] + + return subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + bufsize=10**8, # big buffer helps a lot for raw video piping + ) + + + + def request_depth_cancel(): """ Called by the global Cancel button when Depth pipeline is running. 
@@ -404,32 +462,27 @@ def _pred_to_np(pred):
         return pred.detach().cpu().float().numpy()
     return np.asarray(pred, dtype=np.float32)
 
-def _run_pipe_or_tile(images_pil, inference_size):
+def _run_pipe_or_tile(images_pil, inference_size=None, **kwargs):
     """
     Returns list[{'predicted_depth': ndarray or tensor}]
     """
     global pipe, pipe_type, PIPE_EXTRA_ARGS
 
-    extra = PIPE_EXTRA_ARGS if isinstance(PIPE_EXTRA_ARGS, dict) else {}
+    extra_global = PIPE_EXTRA_ARGS if isinstance(PIPE_EXTRA_ARGS, dict) else {}
 
     # --- ONNX special handling: force the warm-up proven size ---
     if pipe_type == "onnx":
-        # 1) Prefer the size that succeeded during warm-up
        good = getattr(pipe, "_good_size", None)
 
-        # 2) Pick a base size: good size > user-provided > safe default
         if good is not None:
             inf_w, inf_h = good
         elif inference_size is not None:
             inf_w, inf_h = int(inference_size[0]), int(inference_size[1])
         else:
-            # conservative safe default for VDA-style exports
             inf_w, inf_h = 512, 288
 
-        # 3) Snap to multiples of 32 (keeps all intermediate shapes aligned)
         inf_w, inf_h = snap_for_vda(inf_w, inf_h, base=32)
         inference_size = (inf_w, inf_h)
 
-        # 4) Run the ONNX callable
         try:
             res = pipe(images_pil, inference_size=inference_size)
             if isinstance(res, list):
@@ -446,7 +499,10 @@ def _run_pipe_or_tile(images_pil, inference_size):
                     r = r[0]
                 outs.append(r if isinstance(r, dict) else {"predicted_depth": r})
             return outs
-    if USE_TILED_DEPTH:
+
+    # ✅ VDA is sequence-based; never tile it
+    use_tiled = bool(USE_TILED_DEPTH) and (pipe_type not in ("vda",))
+    if use_tiled:
         preds = []
         for img in images_pil:
             rgb = np.array(img.convert("RGB"))
@@ -456,23 +512,35 @@ def _run_pipe_or_tile(images_pil, inference_size):
             preds.append({"predicted_depth": dep})
         return preds
 
+    # ✅ Merge per-call kwargs (target_fps, input_size, etc.) with global args (steps, etc.)
+    call_kwargs = {}
+    call_kwargs.update(extra_global)
+    call_kwargs.update(kwargs)
 
-    # Non-tiled
-    try:
-        # first try with extra
-        res = pipe(images_pil, inference_size=inference_size, **extra)
-    except TypeError:
-        # retry without extra kwargs (HF depth-estimation won't accept them)
+    # Some pipes accept kwargs, HF depth-estimation often doesn't
+    forward_ok = pipe_type in ("vda", "da3", "depthcrafter", "onnx")
+
+    if forward_ok:
         try:
+            res = pipe(images_pil, inference_size=inference_size, **call_kwargs)
+        except TypeError:
+            # retry without any extras
             res = pipe(images_pil, inference_size=inference_size)
+    else:
+        # HF / generic: try global extras only
+        try:
+            res = pipe(images_pil, inference_size=inference_size, **extra_global)
         except TypeError:
-            outs = []
-            for img in images_pil:
-                r = pipe(img, inference_size=inference_size)
-                if isinstance(r, list):
-                    r = r[0]
-                outs.append(r if isinstance(r, dict) else {"predicted_depth": r})
-            return outs
+            try:
+                res = pipe(images_pil, inference_size=inference_size)
+            except TypeError:
+                outs = []
+                for img in images_pil:
+                    r = pipe(img, inference_size=inference_size)
+                    if isinstance(r, list):
+                        r = r[0]
+                    outs.append(r if isinstance(r, dict) else {"predicted_depth": r})
+                return outs
 
     if isinstance(res, list):
         return res
@@ -481,6 +549,7 @@ def _run_pipe_or_tile(images_pil, inference_size):
     else:
         return [{"predicted_depth": res}]
 
+
 class TemporalDepthNormalizer:
     """
     Keeps a smooth running [lo, hi] range over time so video depth
@@ -613,7 +682,7 @@ def detect_letterbox_strict_robust(
     y_thresh=24,
     var_thresh=3.0,
     sat_thresh=10.0,
-    max_scan_frac=0.25,
+    max_scan_frac=0.18,
     min_band_frac=0.06,
     edge_max=0.06   # rows with more than ~6% edges are not “bars”
 ):
@@
-848,7 +917,7 @@ def update(self, frame_bgr, frame_idx): return self.top, self.bot -# ---------- Cropping (unchanged, with guard) ---------- +# ---------- Cropping ---------- def crop_by_bars(frame_bgr, top, bottom): h = frame_bgr.shape[0] top = max(int(top), 0); bottom = max(int(bottom), 0) @@ -900,8 +969,8 @@ def get_weights_dir(): "Original": None, # General square resolutions - "256x256": (256, 256), - "384x384": (384, 384), + "256x256": (256, 256), + "384x384": (384, 384), "448x448": (448, 448), "518x518": (518, 518), "576x576": (576, 576), @@ -927,8 +996,6 @@ def get_weights_dir(): "1792x1008":(1792, 1008), "1920x1088":(1920, 1088), # NOTE: 1088 instead of 1080 # Squares / general - "256x256": (256, 256), - "384x384": (384, 384), "512x512": (512, 512), "640x640": (640, 640), "768x768": (768, 768), @@ -967,13 +1034,29 @@ def load_supported_models(): # Depth Anything v2 # in load_supported_models() + + "Video Depth Anything Large": "vda:depth-anything/Video-Depth-Anything-Large", + "Video Depth Anything Small": "vda:depth-anything/Video-Depth-Anything-Small", + + "Video Depth Anything (ONNX)": "onnx:VideoDepthAnything", + "Distill-Any-Depth Large(ONNX)": "onnx:DistillAnyDepthLarge", "Distill-Any-Depth Base(ONNX)": "onnx:DistillAnyDepthBase", + "Distill-Any-Depth Small(ONNX)": "onnx:DistillAnyDepthSmall", + + "DA3METRIC-LARGE": "da3:depth-anything/DA3METRIC-LARGE", + "DA3MONO-LARGE": "da3:depth-anything/DA3MONO-LARGE", + "DA3-LARGE": "da3:depth-anything/DA3-LARGE", + "DA3-LARGE-1.1": "da3:depth-anything/DA3-LARGE-1.1", + "DA3-BASE": "da3:depth-anything/DA3-BASE", + "DA3-SMALL": "da3:depth-anything/DA3-SMALL", + "DA3-GIANT": "da3:depth-anything/DA3-GIANT", + "DA3-GIANT-1.1": "da3:depth-anything/DA3-GIANT-1.1", + "DA3NESTED-GIANT-LARGE": "da3:depth-anything/DA3NESTED-GIANT-LARGE", + "DA3NESTED-GIANT-LARGE-1.1": "da3:depth-anything/DA3NESTED-GIANT-LARGE-1.1", + + -# "DA3-GIANT": "depth-anything/DA3-GIANT", -# "DA3-LARGE": "depth-anything/DA3-LARGE", -# "DA3-BASE": "depth-anything/DA3-BASE", -# "DA3-SMALL": "depth-anything/DA3-SMALL", "Depth Anything v2 Large": "depth-anything/Depth-Anything-V2-Large-hf", "Depth Anything v2 Base": "depth-anything/Depth-Anything-V2-Base-hf", @@ -995,6 +1078,7 @@ def load_supported_models(): # Other popular models # "DA-2 (Haodongli)": "haodongli/DA-2", # "Bridge (Dingning)": "Dingning/BRIDGE", +# "Pixel-Perfect-Depth": "gangweix/Pixel-Perfect-Depth", "LBM Depth": "jasperai/LBM_depth", "DepthPro (Apple)": "apple/DepthPro-hf", "ZoeDepth (NYU+KITTI)": "Intel/zoedepth-nyu-kitti", @@ -1084,6 +1168,28 @@ def ensure_model_downloaded(checkpoint, use_fp16: bool = False): except Exception as e: print(f"❌ DA-V2 adapter failed: {e}") return None, None + + # --- DepthAnything v3 adapter: da3: --- + if isinstance(checkpoint, str) and checkpoint.startswith(("da3:", "dav3:")): + from core.adapters.depthanything3_adapter import load_da3_adapter + spec = checkpoint.split(":", 1)[1].strip() + print(f"🧩 Loading DA3 adapter for: {spec}") + try: + return load_da3_adapter(spec, cache_dir=local_model_dir, use_fp16=use_fp16) + except Exception as e: + print(f"❌ DA3 adapter failed: {e}") + return None, None + + # --- Video Depth Anything adapter: vda: --- + if isinstance(checkpoint, str) and checkpoint.startswith("vda:"): + from core.adapters.videodepthanything_adapter import load_vda_adapter + spec = checkpoint.split(":", 1)[1].strip() + print(f"🧩 Loading VDA adapter for: {spec}") + try: + return load_vda_adapter(spec, cache_dir=local_model_dir, use_fp16=use_fp16) + except Exception 
as e: + print(f"❌ VDA adapter failed: {e}") + return None, None # --- LBM Adapter --- if isinstance(checkpoint, str) and "lbm" in checkpoint.lower(): @@ -1117,7 +1223,13 @@ def ensure_model_downloaded(checkpoint, use_fp16: bool = False): fixed_dir, torch_dtype=dtype if torch.cuda.is_available() else None, ) + + if torch.cuda.is_available(): + model = model.to(memory_format=torch.channels_last) + model.eval() + processor = _load_flexible_processor(fixed_dir, prefer_fast=True) + print(f"📂 Loaded local Hugging Face model from {fixed_dir}") return model, processor except Exception as e: @@ -1195,39 +1307,42 @@ def generic_diffusers_call(x, **kw): # --- Hugging Face online model (tolerant to custom names) --- safe_folder_name = checkpoint.replace("/", "_") local_path = os.path.join(local_model_dir, safe_folder_name) + try: # Try standard load first model = AutoModelForDepthEstimation.from_pretrained( checkpoint, cache_dir=local_path, ) + + if torch.cuda.is_available(): + model = model.to(memory_format=torch.channels_last) + model.eval() + processor = _load_flexible_processor(checkpoint, cache_dir=local_path, prefer_fast=True) print(f"⬇️ Downloaded model from Hugging Face: {checkpoint}") return model, processor + except Exception as e1: print(f"⚠️ Standard HF load failed, trying normalization: {e1}") try: - # Pull a snapshot, normalize names, then load from the local folder from huggingface_hub import snapshot_download snap_dir = snapshot_download( repo_id=checkpoint, cache_dir=local_path, local_files_only=False, ) - # Local HF folder load + fixed_dir = _ensure_expected_weight_name(snap_dir) - if torch.cuda.is_available() and use_fp16: - model = AutoModelForDepthEstimation.from_pretrained( - fixed_dir, - torch_dtype=dtype, - ) - else: - model = AutoModelForDepthEstimation.from_pretrained(fixed_dir) - try: - p0 = next(model.parameters()) - print(f"🧪 Depth model loaded | dtype={p0.dtype} device={p0.device}") - except Exception: - pass + + model = AutoModelForDepthEstimation.from_pretrained( + fixed_dir, + torch_dtype=dtype if (torch.cuda.is_available() and use_fp16) else None, + ) + + if torch.cuda.is_available(): + model = model.to(memory_format=torch.channels_last) + model.eval() processor = _load_flexible_processor(fixed_dir, cache_dir=local_path, prefer_fast=True) if processor is None: @@ -1235,12 +1350,14 @@ def generic_diffusers_call(x, **kw): print(f"🛠️ Normalized non-standard weights; loaded from {fixed_dir}") return model, processor + except Exception as e2: print(f"❌ Failed to load Hugging Face model after normalization: {e2}") return None, None + def load_onnx_model(model_dir, device="CUDAExecutionProvider"): import onnxruntime as ort model_path = os.path.join(model_dir, "model.onnx") @@ -1251,11 +1368,15 @@ def load_onnx_model(model_dir, device="CUDAExecutionProvider"): print(f"🧠 Loading ONNX model from: {model_path}") so = ort.SessionOptions() - # Stay conservative for VDA and friends - so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - so.intra_op_num_threads = 1 - so.inter_op_num_threads = 1 + # Safe performance + so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL # keep safe + # Let ORT decide threads unless you KNOW better + so.intra_op_num_threads = 0 + so.inter_op_num_threads = 0 + so.enable_mem_pattern = True + so.enable_cpu_mem_arena = True + # Multi backend detection, but still safe available = ort.get_available_providers() 
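+    # e.g. a CUDA build of onnxruntime-gpu typically reports something like
+    #   ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
+    # while a CPU-only build reports just ['CPUExecutionProvider']; the requested
+    # provider is presumably validated against this list before session creation.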
@@ -1313,25 +1434,52 @@ def load_onnx_model(model_dir, device="CUDAExecutionProvider"): except Exception: fixed_T = None - # 📝 Some ONNX exports are locked to a single input size (e.g. DistillAnyDepthBase) + # 📝 Some ONNX exports are locked to a single input size (e.g. DistillAnyDepth*) fixed_HW = None model_tag = os.path.basename(os.path.normpath(model_dir)).lower() - if "distillanydepthbase" in model_tag: + + # 🔍 Debug: confirm which model folder name we’re matching against + print(f"🏷️ ONNX model_tag: {model_tag}") + + if "distillanydepth" in model_tag: fixed_HW = (518, 518) - print(f"📐 Detected DistillAnyDepthBase ONNX – forcing fixed input size {fixed_HW}") + print(f"📐 Detected DistillAnyDepth ONNX – forcing fixed input size {fixed_HW}") print(f"🔎 Input shape: {input_shape} | Rank: {input_rank} | fixed_T={fixed_T} | fixed_HW={fixed_HW}") - def _prep_images(images, inference_size): arrs = [] + metas = [] for img in images: if inference_size: - img = img.resize(inference_size, Image.BICUBIC) - x = np.asarray(img, dtype=np.float32) / 255.0 # HWC - x = ((x - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1).copy() # CHW + W, H = inference_size + ow, oh = img.size + + # scale that fits inside W,H + scale = min(W / ow, H / oh) + nw, nh = int(round(ow * scale)), int(round(oh * scale)) + + img_r = img.resize((nw, nh), Image.BICUBIC) + + # compute padding to center + pad_left = (W - nw) // 2 + pad_top = (H - nh) // 2 + pad_right = W - nw - pad_left + pad_bottom = H - nh - pad_top + + img_p = ImageOps.expand(img_r, border=(pad_left, pad_top, pad_right, pad_bottom), fill=(0, 0, 0)) + + metas.append((pad_left, pad_top, nw, nh, W, H)) + img = img_p + else: + metas.append((0, 0, img.size[0], img.size[1], img.size[0], img.size[1])) + + x = np.asarray(img, dtype=np.float32) / 255.0 + x = ((x - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1).copy() arrs.append(x) - return arrs + + return arrs, metas + def run_onnx(images, inference_size=None): # Some models (like DistillAnyDepthBase) only work at one exact size. @@ -1347,7 +1495,7 @@ def run_onnx(images, inference_size=None): inference_size = (W, H) - img_batch = _prep_images(images, inference_size) + img_batch, metas = _prep_images(images, inference_size) if input_rank == 5: T = len(img_batch) @@ -1358,6 +1506,8 @@ def run_onnx(images, inference_size=None): img_batch = img_batch[:fixed_T] T = fixed_T input_tensor = np.stack(img_batch, axis=0)[None, ...] # [1, T, 3, H, W] + + elif input_rank == 4: input_tensor = np.stack(img_batch, axis=0) # [B, 3, H, W] else: @@ -1619,7 +1769,84 @@ def warmup_thread(): status_label_widget, f"✅ Diffusers pipeline loaded: {selected_checkpoint} (device: {device_display_name()})" )) + + # --- Video Depth Anything adapter callable --- + is_vda = bool(caps.get("kind") == "vda" or caps.get("is_video_model", False)) + + if is_vda: + pipe = model_callable + pipe_type = "vda" + skip_warmup = True + + status_label_widget.after( + 0, + lambda: start_spinner(status_label_widget, "🔄 Warming up Video Depth Anything...") + ) + + # VDA is sequence-based. 
Warm it up with a tiny fake “clip”
+        # NOTE: skip_warmup is set to True just above, so this dummy-clip pass is
+        # currently skipped; set it to False to re-enable the VDA warm-up.
+        if not skip_warmup:
+            try:
+                # Small, short clip for warmup
+                dummy_frames = [
+                    Image.new("RGB", (512, 288), (127, 127, 127))
+                    for _ in range(4)
+                ]
+
+                # Let adapter infer; pass an input_size if you want it explicit
+                _ = pipe(dummy_frames, inference_size=(512, 288), input_size=518, target_fps=24)
+
+                print("🔥 VDA warmed up with dummy clip")
+            except Exception as e:
+                print(f"⚠️ VDA warm-up failed: {e}")
+        else:
+            print("⏭️ Skipping VDA warm-up by request.")
+
+        status_label_widget.after(
+            0,
+            lambda: stop_spinner(
+                status_label_widget,
+                f"✅ VDA model loaded: {selected_checkpoint} (device: {device_display_name()})"
+            )
+        )
+        return
+
+
+    # --- Depth Anything v3 adapter callable ---
+    is_da3 = bool(caps.get("kind") == "da3" or caps.get("has_builtin_processor", False))
+
+    if is_da3:
+        pipe = model_callable
+        pipe_type = "da3"
+        skip_warmup = False  # assumed default; not otherwise set on this path (the VDA branch above returns)
+        status_label_widget.after(
+            0,
+            lambda: start_spinner(status_label_widget, "🔄 Warming up Depth Anything v3...")
+        )
+
+        if not skip_warmup:
+            try:
+                # Use a sane default warm-up size. DA3 sizes itself internally via process_res.
+                dummy = Image.new("RGB", (512, 288), (127, 127, 127))
+
+                # Keep process_res consistent with the value used for real inference:
+                _ = pipe([dummy], process_res=756)
+
+                print("🔥 DA3 warmed up with dummy frame")
+            except Exception as e:
+                print(f"⚠️ DA3 warm-up failed: {e}")
+        else:
+            print("⏭️ Skipping DA3 warm-up by request.")
+
+        status_label_widget.after(
+            0,
+            lambda: stop_spinner(
+                status_label_widget,
+                f"✅ DA3 model loaded: {selected_checkpoint} (device: {device_display_name()})"
+            )
+        )
+        return
+
     else:
         processor = meta
         raw_pipe = pipeline(
@@ -1630,18 +1857,24 @@ def warmup_thread():
         )
 
         def hf_batch_safe_pipe(images, inference_size=None, **_):
-            # resize
+            # resize ONCE here
             if inference_size:
                 if isinstance(images, list):
                     images = [img.resize(inference_size, Image.BICUBIC) for img in images]
                 else:
                     images = images.resize(inference_size, Image.BICUBIC)
 
-            # call raw HF pipeline (it supports single image or list)
-            if isinstance(images, list):
-                return raw_pipe(images)
-            else:
-                return [raw_pipe(images)]
+            # inference guards
+            use_fp16 = bool(fp16_var.get()) and (torch_device.type in ("cuda", "mps"))
+
+            if torch.cuda.is_available():
+                with torch.inference_mode():
+                    if use_fp16:
+                        with torch.autocast("cuda", dtype=torch.float16):
+                            return raw_pipe(images) if isinstance(images, list) else [raw_pipe(images)]
+                    else:
+                        return raw_pipe(images) if isinstance(images, list) else [raw_pipe(images)]
+
+            # CPU / MPS fallback (plain pipeline call, no autocast)
+            with torch.inference_mode():
+                return raw_pipe(images) if isinstance(images, list) else [raw_pipe(images)]
+
         pipe = hf_batch_safe_pipe
         pipe_type = "hf"
@@ -2091,6 +2324,7 @@ def open_image(status_label_widget, progress_bar_widget, colormap_var, invert_va
 def process_video_folder(
     folder_path,
     batch_size_widget,
+    codec_var,
     inference_steps_entry,
     output_dir_var,
     inference_res_var,
@@ -2132,6 +2366,7 @@ def process_video_folder(
                 progress_bar,
                 cancel_requested,
                 invert_var,
+                codec_var,
                 inference_steps_entry,
             ),
             daemon=True
@@ -2236,6 +2471,7 @@ def process_video2(
     save_frames=False,
     target_fps=15,
     ignore_letterbox_bars=False,
+    prefer_opencv_writer=False,
 ):
@@ -2333,20 +2569,25 @@ def _warn():
     original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
+    # pick how often to run VDA
+    if pipe_type == "vda" and target_fps and target_fps > 0 and fps and fps > target_fps:
+        stride = max(1, int(round(fps / target_fps)))
+    else:
+        stride = 1
+
     # ... after you read fps/original_w/h
     tracker = LetterboxTracker(original_height, fps)
-    bars_top, bars_bottom, (locked_bars, locked_zero) = tracker.bootstrap(cap)
-    # Try to reuse previously measured bars from sidecar (stabilizes pass 2 / other eye)
+    bars_top, bars_bottom, (_lb, _lz) = tracker.bootstrap(cap)
+
+    # 1) Try sidecar first (best signal)
     try:
-        # 1) Prefer a sidecar next to the current input (same basename)
         candidate_in = os.path.splitext(file_path)[0] + ".letterbox.json"
         meta = None
         if os.path.exists(candidate_in):
             with open(candidate_in, "r", encoding="utf-8") as f:
                 meta = json.load(f)
-        # 2) Otherwise, if we’re writing a _depth.mkv, look for a sidecar from pass 1
         if meta is None:
             sibling = os.path.splitext(os.path.join(
                 os.path.dirname(file_path),
@@ -2362,15 +2603,39 @@ def _warn():
             tracker.top, tracker.bot = t, b
             tracker.locked_bars = (t + b) > 0
             tracker.locked_zero = (t + b) == 0
+            tracker._cooldown = 0
             bars_top, bars_bottom = t, b
             print(f"[VD3D] Sidecar override: top={t} bottom={b}")
-    except Exception as _e:
+    except Exception:
         pass
-
+
+    # 2) If still zero, probe ~2 seconds in (skip dark intros)
+    if (bars_top + bars_bottom) == 0:
+        pos_backup = cap.get(cv2.CAP_PROP_POS_FRAMES)
+        cap.set(cv2.CAP_PROP_POS_MSEC, 2000)
+        ok, f = cap.read()
+        cap.set(cv2.CAP_PROP_POS_FRAMES, pos_backup or 0)
+
+        if ok and not is_near_black_frame(f):
+            t2, b2 = detect_letterbox_strict_robust(f)
+            if (t2 + b2) > 0:
+                tracker.top, tracker.bot = t2, b2
+                tracker.locked_bars = True
+                tracker.locked_zero = False
+                tracker._cooldown = 0
+                bars_top, bars_bottom = t2, b2
+                print(f"[VD3D] Fallback probe bars: top={t2} bottom={b2}")
+
+    # 3) Now print real current lock state
+    locked_bars = tracker.locked_bars
+    locked_zero = tracker.locked_zero
     print(f"[VD3D] Bootstrap bars: top={bars_top} bottom={bars_bottom} | "
           f"locked_bars={locked_bars} locked_zero={locked_zero}")
 
-    # (Optional) write sidecar using bootstrapped values:
+    # 4) Write sidecar using real current lock state
     try:
         sidecar = os.path.splitext(output_path)[0] + ".letterbox.json"
         with open(sidecar, "w", encoding="utf-8") as f:
@@ -2382,7 +2647,6 @@ def _warn():
     except Exception as e:
         print(f"⚠️ Failed to write letterbox sidecar: {e}")
 
-    print(f"📁 Saving video to: {output_path}")
 
     # Codec handling
@@ -2392,28 +2656,40 @@ def _warn():
     else:
         ffmpeg_codec = None
 
-    # Try to pick a matching OpenCV FourCC
-    if ffmpeg_codec:
-        if "264" in ffmpeg_codec:
-            fourcc = cv2.VideoWriter_fourcc(*"H264")  # H.264
-        elif "265" in ffmpeg_codec:
-            fourcc = cv2.VideoWriter_fourcc(*"HEVC")  # HEVC / H.265
-        elif "xvid" in ffmpeg_codec.lower():
-            fourcc = cv2.VideoWriter_fourcc(*"XVID")
-        elif "mp4v" in ffmpeg_codec.lower():
-            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+    # Prefer FFmpeg pipe by default (fastest).
+    # Only use OpenCV if user explicitly requests it (troubleshooting).
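+    # Decision table (sketch of the flag logic below):
+    #   prefer_opencv_writer=True  + mp4v/XVID/DIVX codec  -> OpenCV VideoWriter
+    #   prefer_opencv_writer=True  + any FFmpeg-only codec -> FFmpeg pipe
+    #   prefer_opencv_writer=False (the default)           -> FFmpeg pipe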
+ use_opencv = bool(prefer_opencv_writer) and (ffmpeg_codec is None or is_opencv_safe_fourcc(ffmpeg_codec)) + + ff_proc = None + out = None + + if use_opencv: + # OpenCV FourCC mapping (ONLY for safe codecs) + if ffmpeg_codec: + if ffmpeg_codec.lower() == "mp4v": + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + elif ffmpeg_codec.upper() == "XVID": + fourcc = cv2.VideoWriter_fourcc(*"XVID") + elif ffmpeg_codec.upper() == "DIVX": + fourcc = cv2.VideoWriter_fourcc(*"DIVX") + else: + fourcc = cv2.VideoWriter_fourcc(*"XVID") else: - fourcc = cv2.VideoWriter_fourcc(*"XVID") # fallback - else: - fourcc = cv2.VideoWriter_fourcc(*"XVID") - - out = cv2.VideoWriter(output_path, fourcc, fps, (original_width, original_height)) + fourcc = cv2.VideoWriter_fourcc(*"XVID") - if not out.isOpened(): - print("⚠️ Failed to create writer with chosen codec. Falling back to XVID.") - fourcc = cv2.VideoWriter_fourcc(*"XVID") out = cv2.VideoWriter(output_path, fourcc, fps, (original_width, original_height)) - + + if not out.isOpened(): + print("⚠️ OpenCV writer failed. Falling back to FFmpeg pipe.") + out = None + use_opencv = False + + if not use_opencv: + # FFmpeg pipe writer for everything else + if not ffmpeg_codec: + ffmpeg_codec = "libx264" + ff_proc = start_ffmpeg_writer(output_path, fps, original_width, original_height, ffmpeg_codec) + def cleanup_video_handles(cap_obj, out_obj): try: if cap_obj is not None: @@ -2441,9 +2717,18 @@ def cleanup_video_handles(cap_obj, out_obj): write_index = 0 frames_batch = [] total_processed_frames = 0 - + # For VDA stride mapping + repeat_counts = [] # same length as frames_batch + bars_batch = [] # (top,bottom) per inferred frame + inference_size = parse_inference_resolution(inference_res_var.get()) - resize_required = inference_size is not None + if inference_size is not None: + target_w, target_h = map(int, inference_size) + interp = cv2.INTER_AREA if (target_w < original_width or target_h < original_height) else cv2.INTER_LINEAR + else: + target_w = target_h = None + interp = None + try: inference_steps = int(inference_steps_entry.get().strip()) @@ -2467,6 +2752,8 @@ def cleanup_video_handles(cap_obj, out_obj): global_session_start_time = time.time() previous_depth = None + neutral_u8 = np.full((original_height, original_width), 128, dtype=np.uint8) + # Smooth per-video depth normalization to avoid global flicker temp_normalizer = TemporalDepthNormalizer(pclip=(1.0, 99.0), momentum=0.95) prev_depth_u8 = None # for light temporal smoothing @@ -2489,20 +2776,40 @@ def cleanup_video_handles(cap_obj, out_obj): break frame_count += 1 - # When True, we IGNORE bars; when False, we HANDLE bars - if not ignore_letterbox_bars: + # NOTE: Current UI passes ignore_letterbox_bars=True. + # We interpret True as: DETECT/USE bars (so the "fill bars" logic can run). 
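+            # i.e. with the current (inverted) flag semantics:
+            #   ignore_letterbox_bars=True  -> tracker.update() runs and detected bar
+            #                                  rows are later filled with neutral depth
+            #   ignore_letterbox_bars=False -> bars stay (0, 0); frames pass through untouched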
+ if ignore_letterbox_bars: bars_top, bars_bottom = tracker.update(frame, frame_count) else: bars_top, bars_bottom = 0, 0 + + # DA3 + VDA do their own internal sizing (process_res / input_size), so don't resize here + if pipe_type in ("da3", "vda"): + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + else: + if inference_size is None: + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + else: + frame_rs = cv2.resize(frame, (target_w, target_h), interpolation=interp) + frame_rgb = cv2.cvtColor(frame_rs, cv2.COLOR_BGR2RGB) + + # Decide whether to infer this frame (VDA stride) or skip it + do_infer = True + if pipe_type == "vda" and stride > 1: + do_infer = ((frame_count - 1) % stride == 0) + + if do_infer: + frames_batch.append(Image.fromarray(frame_rgb)) + repeat_counts.append(1) # start coverage at 1 + bars_batch.append((bars_top, bars_bottom)) + else: + # This frame is skipped from inference, so extend coverage of the last inferred frame + if repeat_counts: + repeat_counts[-1] += 1 - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - pil_image = Image.fromarray(frame_rgb) - - if resize_required: - pil_image = pil_image.resize(inference_size, Image.BICUBIC) - frames_batch.append(pil_image) + # Trigger inference when we have enough inferred frames, or at end of video + if len(frames_batch) == batch_size or (frame_count == total_frames and frames_batch): - if len(frames_batch) == batch_size or frame_count == total_frames: wait_if_paused(status_label) if cancel_requested.is_set(): @@ -2572,7 +2879,19 @@ def cleanup_video_handles(cap_obj, out_obj): # continue else: - predictions = _run_pipe_or_tile(frames_batch, inference_size) + extra = {} + if pipe_type == "vda": + extra = { + "target_fps": int(target_fps) if target_fps and target_fps > 0 else int(fps), + "input_size": 518, # or expose it later, but this is the VDA default +# "fp32": True, +# "max_res": 1280, # optional cap if your adapter supports it + } + + # DA3 can also take process_res here if you want to override per-video: + # if pipe_type == "da3": extra["process_res"] = 756 + + predictions = _run_pipe_or_tile(frames_batch, inference_size, **extra) assert isinstance(predictions, list), "Expected list of predictions from pipeline" @@ -2616,7 +2935,14 @@ def cleanup_video_handles(cap_obj, out_obj): + (1.0 - alpha) * depth_u8.astype(np.float32) ).astype(np.uint8) prev_depth_u8 = smoothed_u8 + + cover = repeat_counts[i] if (pipe_type == "vda" and i < len(repeat_counts)) else 1 + bt, bb = bars_batch[i] if (pipe_type == "vda" and i < len(bars_batch)) else (bars_top, bars_bottom) + # Temporarily override bars so repeats use the SAME bars + old_top, old_bottom = bars_top, bars_bottom + bars_top, bars_bottom = bt, bb + # Optional: extra temporal median filter across a short history # depth_history.append(smoothed_u8) # if len(depth_history) > 1: @@ -2626,30 +2952,34 @@ def cleanup_video_handles(cap_obj, out_obj): # depth_u8 = smoothed_u8 # letterbox handling (unchanged) + # Fill detected letterbox bars with neutral depth (median of the core) if ignore_letterbox_bars and (bars_top or bars_bottom): - core_h = original_height - bars_top - bars_bottom - if core_h <= 0: - bars_top = bars_bottom = 0 - core_h = original_height - gray_core = cv2.resize( - depth_u8, - (original_width, core_h), - interpolation=cv2.INTER_CUBIC - ) - neutral = int(np.median(gray_core)) if gray_core.size else 0 - full_gray = np.full( - (original_height, original_width), - neutral, - dtype=np.uint8 - ) - full_gray[bars_top:bars_top+core_h, :] = gray_core - bgr = 
cv2.cvtColor(full_gray, cv2.COLOR_GRAY2BGR) - to_save = full_gray + top = max(0, int(bars_top)) + bot = max(0, int(bars_bottom)) + if top + bot < original_height: + full_gray = depth_u8.copy() + core = full_gray[top:original_height - bot, :] + neutral = int(np.median(core)) if core.size else 0 + + if top > 0: + full_gray[:top, :] = neutral + if bot > 0: + full_gray[original_height - bot:, :] = neutral + + bgr = cv2.cvtColor(full_gray, cv2.COLOR_GRAY2BGR) + to_save = full_gray + else: + bgr = cv2.cvtColor(depth_u8, cv2.COLOR_GRAY2BGR) + to_save = depth_u8 else: bgr = cv2.cvtColor(depth_u8, cv2.COLOR_GRAY2BGR) to_save = depth_u8 + + if use_opencv: + out.write(bgr) + else: + ff_proc.stdin.write(bgr.tobytes()) - out.write(bgr) if save_frames: frame_filename = os.path.join( frame_output_dir, @@ -2658,7 +2988,20 @@ def cleanup_video_handles(cap_obj, out_obj): cv2.imwrite(frame_filename, to_save) write_index += 1 total_processed_frames += 1 - + + for _ in range(max(0, cover - 1)): + if use_opencv: + out.write(bgr) + else: + ff_proc.stdin.write(bgr.tobytes()) + if save_frames: + frame_filename = os.path.join(frame_output_dir, f"frame_{write_index:05d}.png") + cv2.imwrite(frame_filename, to_save) + write_index += 1 + total_processed_frames += 1 + + bars_top, bars_bottom = old_top, old_bottom + except Exception as e: print(f"⚠️ Depth processing error: {e}") @@ -2666,13 +3009,16 @@ def cleanup_video_handles(cap_obj, out_obj): if cancel_requested.is_set(): break - if frame_count % 100 == 0: - if torch.cuda.is_available(): + if torch.cuda.is_available() and frame_count % 300 == 0: + # optional: only if reserved > X + if torch.cuda.memory_reserved() > 0.90 * torch.cuda.get_device_properties(0).total_memory: torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() frames_batch.clear() + repeat_counts.clear() + bars_batch.clear() elapsed = time.time() - global_session_start_time avg_fps = total_processed_frames / elapsed if elapsed > 0 else 0 @@ -2696,6 +3042,20 @@ def cleanup_video_handles(cap_obj, out_obj): # Always release handles and clean up GPU/CPU memory cleanup_video_handles(cap, out) + # ✅ Close FFmpeg pipe if we used it + if ff_proc is not None: + try: + ff_proc.stdin.close() + except Exception: + pass + try: + ff_proc.wait(timeout=10) + except Exception: + try: + ff_proc.kill() + except Exception: + pass + if cancel_requested.is_set(): ui_set_status("🛑 Cancelled.") ui_set_progress(0) @@ -2793,8 +3153,9 @@ def open_video(status_label, progress_bar, batch_size_widget, output_dir_var, in ), kwargs={ "offload_mode_dropdown": offload_mode_dropdown, - "target_fps": 15, + "target_fps": 8, "ignore_letterbox_bars": True, + "prefer_opencv_writer": False, } ).start() @@ -2812,4 +3173,3 @@ def _thread_hook(args): _log_ex(args.exc_type, args.exc_value, args.exc_traceback) threading.excepthook = _thread_hook - diff --git a/core/vd3d_live.py b/core/vd3d_live.py index 253fba3..edf0aa3 100644 --- a/core/vd3d_live.py +++ b/core/vd3d_live.py @@ -323,35 +323,39 @@ def depth_from_frame_fast(model, norm, device: str, t_inp = torch.empty((1, 3, ih, iw), device=device, dtype=torch.float32) _STAGING["inp"] = t_inp - t_inp.copy_( - torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to( - device=device, dtype=torch.float32 - ), - non_blocking=True, - ) + src = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).contiguous() # CPU uint8 + t_inp.copy_(src, non_blocking=False) # copy CPU->GPU into persistent tensor + mean, std = norm t_inp = (t_inp / 255.0 - mean) / std - with torch.inference_mode(), torch.autocast( 
- device_type=device, dtype=getattr(next(model.parameters()), "dtype", torch.float16) - ): - out = model(pixel_values=t_inp).predicted_depth + with torch.inference_mode(): + if device == "cuda": + model_dtype = next(model.parameters()).dtype + with torch.autocast(device_type="cuda", dtype=model_dtype): + out = model(pixel_values=t_inp).predicted_depth + else: + out = model(pixel_values=t_inp).predicted_depth + if out.ndim == 4: out = out[:, 0] + pred = torch.nn.functional.interpolate( out.unsqueeze(1).float(), size=(h, w), mode="bicubic", align_corners=False, ).squeeze(1) + depth = pred[0] - d = depth.flatten() - lo = torch.quantile(d, 0.01) - hi = torch.quantile(d, 0.99) + + lo = torch.amin(depth) + hi = torch.amax(depth) depth01 = torch.clamp((depth - lo) / (hi - lo + 1e-6), 0, 1) - return depth01.detach().cpu().numpy().astype(np.float32) + return depth01 + # -------------------- Utilities -------------------- # def sbs_pack_gpu_rgb(left_t: torch.Tensor, right_t: torch.Tensor) -> np.ndarray: @@ -421,6 +425,10 @@ def run_live(args, external_stop: threading.Event | None = None): raise frame_q, stop_cap = start_latest_capture(cap) + # --- persistent GPU buffers to reduce per-frame allocations --- + frm_gpu_u8 = None # uint8 CUDA buffer (3,H,W) + frm_gpu_f32 = None # float32 CUDA buffer (3,H,W) + if diag: print("[diag] args:", vars(args)) @@ -475,7 +483,6 @@ def run_live(args, external_stop: threading.Event | None = None): fps_ema = None t_last = time.time() - depth01 = None depth_last_t = 0.0 depth_period = 1.0 / max(1e-3, args.depth_fps) @@ -516,6 +523,8 @@ def run_live(args, external_stop: threading.Event | None = None): print("▶️ Streaming… (f=fullscreen, m=mode, q=quit)") out_bgr = first + depth01_t = None + while True: if external_stop is not None and external_stop.is_set(): @@ -560,54 +569,61 @@ def run_live(args, external_stop: threading.Event | None = None): if x1 > x0 and y1 > y0: frame[y0:y1, x0:x1] = 0 - now = time.time() + now = time.perf_counter() # Depth update - if (depth01 is None) or (now - depth_last_t >= depth_period): + if (depth01_t is None) or (now - depth_last_t >= depth_period): depth_new = depth_from_frame_fast( model, proc, device, frame, (args.infer_w, args.infer_h) ) depth_last_t = now + if args.smooth: if depth_ema is None: depth_ema = depth_new else: depth_ema = (1.0 - ema_alpha) * depth_ema + ema_alpha * depth_new - depth01 = cv2.medianBlur( - (depth_ema * 255).astype(np.uint8), 3 - ).astype(np.float32) / 255.0 + depth01_t = depth_ema else: - depth01 = depth_new + depth01_t = depth_new + # View modes if view_mode == 0: out_bgr = frame elif view_mode == 1: - d8 = (depth01 * 255.0).astype(np.uint8) - out_bgr = cv2.applyColorMap(d8, cv2.COLORMAP_VIRIDIS) + d_cpu = (depth01_t * 255.0).clamp(0,255).byte().detach().cpu().numpy() + out_bgr = cv2.applyColorMap(d_cpu, cv2.COLORMAP_VIRIDIS) + else: - if HAVE_PIXEL_SHIFT and pixel_shift_cuda is not None and CUDA_AVAILABLE: - if getattr(args, "pixelshift_rgb", False): - base = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - else: - base = frame + # 3D-SBS mode + if HAVE_PIXEL_SHIFT and (pixel_shift_cuda is not None) and CUDA_AVAILABLE and (depth01_t is not None): + h, w = frame.shape[:2] - frm_t = torch.from_numpy(base).permute(2, 0, 1).to( - "cuda", dtype=torch.float32 - ).div_(255.0) - d_t = torch.from_numpy(depth01).to( - "cuda", dtype=torch.float32 - ).unsqueeze(0) + # Choose base in CPU memory + base_cpu = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) if args.pixelshift_rgb else frame + + # Ensure contiguous CPU tensor view + 
tmp_cpu = torch.from_numpy(base_cpu).permute(2, 0, 1).contiguous()  # uint8 CPU
+
+                # Allocate persistent CUDA buffers once per resolution
+                if (frm_gpu_u8 is None) or (frm_gpu_u8.shape[1] != h) or (frm_gpu_u8.shape[2] != w):
+                    frm_gpu_u8 = torch.empty((3, h, w), device="cuda", dtype=torch.uint8)
+                    frm_gpu_f32 = torch.empty((3, h, w), device="cuda", dtype=torch.float32)
+
+                # Copy CPU -> GPU (no new allocation)
+                frm_gpu_u8.copy_(tmp_cpu, non_blocking=False)
+
+                # Convert to float + normalize using persistent buffer
+                frm_gpu_f32.copy_(frm_gpu_u8)           # uint8 -> float32 conversion
+                frm_t = frm_gpu_f32.mul_(1.0 / 255.0)   # in-place normalize 0..1
+
+                # Depth tensor for pixel shift (already CUDA, already 0..1)
+                d_t = depth01_t.unsqueeze(0)  # shape [1,H,W]
 
-                h, w = frame.shape[:2]
                 left, right = pixel_shift_cuda(
-                    frm_t,
-                    d_t,
-                    w,
-                    h,
-                    args.fg_shift,
-                    args.mg_shift,
-                    args.bg_shift,
+                    frm_t, d_t, w, h,
+                    args.fg_shift, args.mg_shift, args.bg_shift,
                     blur_ksize=9,
                     feather_strength=12.0,
                     return_shift_map=False,
@@ -615,40 +631,45 @@ def run_live(args, external_stop: threading.Event | None = None):
                     enable_edge_masking=True,
                 )
 
+                # ---- handle both torch and numpy returns ----
                 if isinstance(left, torch.Tensor) and isinstance(right, torch.Tensor):
+                    # make sure on cuda and CHW
                     if left.device.type != "cuda":
-                        left = left.to("cuda")
+                        left = left.to("cuda", non_blocking=True)
                     if right.device.type != "cuda":
-                        right = right.to("cuda")
+                        right = right.to("cuda", non_blocking=True)
 
-                    if left.dim() == 3 and left.shape[0] != 3:
-                        left = left.permute(2, 0, 1)
-                        right = right.permute(2, 0, 1)
+                    if left.ndim == 3 and left.shape[0] != 3:  # likely HWC
+                        left = left.permute(2, 0, 1).contiguous()
+                        right = right.permute(2, 0, 1).contiguous()
 
-                    if getattr(args, "pixelshift_rgb", False):
+                    if args.pixelshift_rgb:
                         out_bgr = sbs_pack_gpu_rgb(left, right)
                     else:
                         out_bgr = sbs_pack_gpu_bgr(left, right)
+
                 else:
-                    if left.dtype != np.uint8:
-                        left_u8 = np.clip(left * 255.0, 0, 255).astype(np.uint8)
-                        right_u8 = np.clip(right * 255.0, 0, 255).astype(np.uint8)
-                    else:
-                        left_u8, right_u8 = left, right
+                    # numpy path
+                    left_np = left
+                    right_np = right
 
-                    if left_u8.ndim == 3 and left_u8.shape[0] == 3:
-                        left_u8 = np.transpose(left_u8, (1, 2, 0))
-                        right_u8 = np.transpose(right_u8, (1, 2, 0))
+                    # if float 0..1, convert to uint8
+                    if left_np.dtype != np.uint8:
+                        left_np = np.clip(left_np * 255.0, 0, 255).astype(np.uint8)
+                        right_np = np.clip(right_np * 255.0, 0, 255).astype(np.uint8)
 
-                    if getattr(args, "pixelshift_rgb", False):
-                        left_bgr = cv2.cvtColor(left_u8, cv2.COLOR_RGB2BGR)
-                        right_bgr = cv2.cvtColor(right_u8, cv2.COLOR_RGB2BGR)
-                    else:
-                        left_bgr, right_bgr = left_u8, right_u8
+                    # if CHW, convert to HWC
+                    if left_np.ndim == 3 and left_np.shape[0] == 3:
+                        left_np = np.transpose(left_np, (1, 2, 0))
+                        right_np = np.transpose(right_np, (1, 2, 0))
+
+                    # if RGB input, convert to BGR for OpenCV output
+                    if args.pixelshift_rgb:
+                        left_np = cv2.cvtColor(left_np, cv2.COLOR_RGB2BGR)
+                        right_np = cv2.cvtColor(right_np, cv2.COLOR_RGB2BGR)
+
+                    out_bgr = np.hstack([left_np, right_np])
 
-                out_bgr = np.hstack([left_bgr, right_bgr])
             else:
+                # keep the 2D fallback: without CUDA pixel shift, duplicate the frame
+                # so SBS mode still produces live output instead of a frozen frame
                 out_bgr = np.hstack([frame, frame])
 
         # Virtual cam
         if vcam is None and args.virtualcam and HAVE_VCAM:
@@ -744,35 +765,35 @@ def _build_vars(self):
         self.source_var = tk.StringVar(value="device")  # "device" or "screen:1"
         self.device_index_var = tk.IntVar(value=0)
         self.backend_var = tk.StringVar(value=default_backend)
-        self.capture_fps_var = tk.IntVar(value=60)
+        self.capture_fps_var =
tk.IntVar(value=30) self.width_var = tk.IntVar(value=0) # 0 = auto self.height_var = tk.IntVar(value=0) # 0 = auto self.cam_fps_var = tk.IntVar(value=0) # 0 = no explicit FPS self.fourcc_var = tk.StringVar(value="") self.no_capture_swap_var = tk.BooleanVar(value=False) - self.force_bgr_swap_var = tk.BooleanVar(value=False) + self.force_bgr_swap_var = tk.BooleanVar(value=True) # Depth / model self.model_var = tk.StringVar( - value="depth-anything/Depth-Anything-V2-Small-hf" + value="depth-anything/Depth-Anything-V2-Large-hf" ) self.fp16_var = tk.BooleanVar(value=CUDA_AVAILABLE) - self.infer_w_var = tk.IntVar(value=448) - self.infer_h_var = tk.IntVar(value=256) - self.depth_fps_var = tk.DoubleVar(value=8.0) - self.smooth_var = tk.BooleanVar(value=True) - self.ema_var = tk.DoubleVar(value=0.4) + self.infer_w_var = tk.IntVar(value=320) + self.infer_h_var = tk.IntVar(value=180) + self.depth_fps_var = tk.DoubleVar(value=5.0) + self.smooth_var = tk.BooleanVar(value=False) + self.ema_var = tk.DoubleVar(value=0.35) # 3D / Pixel-shift self.sbs_var = tk.BooleanVar(value=True) - self.fg_shift_var = tk.DoubleVar(value=7.0) - self.mg_shift_var = tk.DoubleVar(value=3.0) - self.bg_shift_var = tk.DoubleVar(value=-5.0) + self.fg_shift_var = tk.DoubleVar(value=8) + self.mg_shift_var = tk.DoubleVar(value=2.0) + self.bg_shift_var = tk.DoubleVar(value=-4.0) self.pixelshift_rgb_var = tk.BooleanVar(value=False) # Preview / window self.preview_var = tk.BooleanVar(value=True) - self.force_preview_var = tk.BooleanVar(value=False) + self.force_preview_var = tk.BooleanVar(value=True) self.mask_preview_var = tk.BooleanVar(value=False) self.preview_x_var = tk.IntVar(value=60) self.preview_y_var = tk.IntVar(value=60) diff --git a/presets/Best3DSettings.json b/presets/Best3DSettings.json deleted file mode 100644 index 33dbc18..0000000 --- a/presets/Best3DSettings.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "fg_shift": 2.5, - "mg_shift": 1.5, - "bg_shift": -6.0, - "zero_parallax_strength": 0.02, - "max_pixel_shift": 0.03, - "parallax_balance": 0.8, - "sharpness_factor": 0.2, - "use_ffmpeg": true, - "enable_feathering": true, - "enable_edge_masking": false, - "use_floating_window": true, - "auto_crop_black_bars": false, - "skip_blank_frames": true, - "dof_strength": 2.5, - "convergence_strength": 0.01, - "enable_dynamic_convergence": true -} \ No newline at end of file diff --git a/presets/Cinema Wide Boost.json b/presets/Cinema Wide Boost.json new file mode 100644 index 0000000..7200460 --- /dev/null +++ b/presets/Cinema Wide Boost.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 7.5, + "mg_shift": 2.5, + "bg_shift": -3.5, + "zero_parallax_strength": -0.005, + "max_pixel_shift": 0.1, + "parallax_balance": 1.0, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": false, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": false, + "dof_strength": 0.8, + "convergence_strength": 0.03, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 0.87, + "depth_pop_mid": 0.45, + "depth_stretch_lo": 0.06, + "depth_stretch_hi": 0.95, + "fg_pop_multiplier": 1.3, + "bg_push_multiplier": 1.15, + "subject_lock_strength": 1.4, + "saturation": 1.15, + "contrast": 1.1, + "brightness": 0.05, + "ipd_enabled": true, + "ipd_factor": 1.15, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/Cinema-Pop Max.json b/presets/Cinema-Pop Max.json new file mode 100644 index 0000000..a736842 --- /dev/null +++ 
b/presets/Cinema-Pop Max.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 12.0, + "mg_shift": 4.0, + "bg_shift": -6.0, + "zero_parallax_strength": -0.015, + "max_pixel_shift": 0.032, + "parallax_balance": 0.60, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": true, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": true, + "dof_strength": 0.65, + "convergence_strength": 0.020, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 0.80, + "depth_pop_mid": 0.55, + "depth_stretch_lo": 0.05, + "depth_stretch_hi": 0.95, + "fg_pop_multiplier": 1.25, + "bg_push_multiplier": 1.05, + "subject_lock_strength": 1.0, + "saturation": 1.0, + "contrast": 1.0, + "brightness": 0.0, + "ipd_enabled": true, + "ipd_factor": 1.0, + "preset_version": "3.6-cinema-pop" +} diff --git a/presets/Cinematic Pop Balanced.json b/presets/Cinematic Pop Balanced.json new file mode 100644 index 0000000..4b5c73f --- /dev/null +++ b/presets/Cinematic Pop Balanced.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 8.5, + "mg_shift": -1.0, + "bg_shift": -5.5, + "zero_parallax_strength": -0.005, + "max_pixel_shift": 0.1, + "parallax_balance": 0.9, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": false, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": false, + "dof_strength": 0.9, + "convergence_strength": 0.03, + "enable_dynamic_convergence": false, + "depth_pop_gamma": 0.82, + "depth_pop_mid": 0.5, + "depth_stretch_lo": 0.04, + "depth_stretch_hi": 0.97, + "fg_pop_multiplier": 1.35, + "bg_push_multiplier": 1.25, + "subject_lock_strength": 0.95, + "saturation": 1.1, + "contrast": 1.2, + "brightness": 0.05, + "ipd_enabled": false, + "ipd_factor": 1.1, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/Cinematic Punch Grade.json b/presets/Cinematic Punch Grade.json new file mode 100644 index 0000000..b995c18 --- /dev/null +++ b/presets/Cinematic Punch Grade.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 10.0, + "mg_shift": -3.5, + "bg_shift": -5.5, + "zero_parallax_strength": -0.005, + "max_pixel_shift": 0.1, + "parallax_balance": 0.9, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": false, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": false, + "dof_strength": 0.9, + "convergence_strength": 0.03, + "enable_dynamic_convergence": false, + "depth_pop_gamma": 1.11, + "depth_pop_mid": 0.55, + "depth_stretch_lo": 0.05, + "depth_stretch_hi": 1.0, + "fg_pop_multiplier": 1.2, + "bg_push_multiplier": 1.1, + "subject_lock_strength": 0.8, + "saturation": 1.25, + "contrast": 1.5, + "brightness": 0.1, + "ipd_enabled": false, + "ipd_factor": 1.1, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/Cinematic immersion.json b/presets/Cinematic immersion.json new file mode 100644 index 0000000..ff049ec --- /dev/null +++ b/presets/Cinematic immersion.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 10.0, + "mg_shift": 3.0, + "bg_shift": -6.0, + "zero_parallax_strength": 0.01, + "max_pixel_shift": 0.02, + "parallax_balance": 0.8, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": true, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": true, + 
"dof_strength": 0.7, + "convergence_strength": 0.012, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 0.86, + "depth_pop_mid": 0.47, + "depth_stretch_lo": 0.05, + "depth_stretch_hi": 0.94, + "fg_pop_multiplier": 1.18, + "bg_push_multiplier": 1.08, + "subject_lock_strength": 1.1, + "saturation": 1.0, + "contrast": 1.0, + "brightness": 0.0, + "ipd_enabled": true, + "ipd_factor": 1.0, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/Clean Cinema Depth.json b/presets/Clean Cinema Depth.json new file mode 100644 index 0000000..83953cb --- /dev/null +++ b/presets/Clean Cinema Depth.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 7.7, + "mg_shift": 2.5, + "bg_shift": -3.7, + "zero_parallax_strength": 0.0, + "max_pixel_shift": 0.1, + "parallax_balance": 0.85, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": false, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": false, + "dof_strength": 0.9, + "convergence_strength": 0.03, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 0.87, + "depth_pop_mid": 0.45, + "depth_stretch_lo": 0.06, + "depth_stretch_hi": 0.95, + "fg_pop_multiplier": 1.3, + "bg_push_multiplier": 1.15, + "subject_lock_strength": 1.1, + "saturation": 1.0, + "contrast": 1.0, + "brightness": 0.0, + "ipd_enabled": true, + "ipd_factor": 1.1, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/Natural Cinema Depth.json b/presets/Natural Cinema Depth.json new file mode 100644 index 0000000..4bdb1be --- /dev/null +++ b/presets/Natural Cinema Depth.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 7.5, + "mg_shift": 3.5, + "bg_shift": -5.0, + "zero_parallax_strength": -0.015, + "max_pixel_shift": 0.03, + "parallax_balance": 0.6, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": true, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": true, + "dof_strength": 0.65, + "convergence_strength": 0.02, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 1.05, + "depth_pop_mid": 0.5, + "depth_stretch_lo": 0.05, + "depth_stretch_hi": 0.95, + "fg_pop_multiplier": 1.25, + "bg_push_multiplier": 1.05, + "subject_lock_strength": 1.0, + "saturation": 1.0, + "contrast": 1.0, + "brightness": 0.0, + "ipd_enabled": true, + "ipd_factor": 1.0, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/Pop-Depth.json b/presets/Pop-Depth.json new file mode 100644 index 0000000..f5b228f --- /dev/null +++ b/presets/Pop-Depth.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 8.0, + "mg_shift": 2.0, + "bg_shift": -3.0, + "zero_parallax_strength": -0.005, + "max_pixel_shift": 0.07, + "parallax_balance": 0.95, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": true, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": false, + "dof_strength": 0.7, + "convergence_strength": 0.03, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 0.94, + "depth_pop_mid": 0.48, + "depth_stretch_lo": 0.04, + "depth_stretch_hi": 0.96, + "fg_pop_multiplier": 1.29, + "bg_push_multiplier": 1.16, + "subject_lock_strength": 0.7, + "saturation": 1.5, + "contrast": 1.1, + "brightness": 0.04, + "ipd_enabled": true, + "ipd_factor": 1.13, + "preset_version": "3.5" +} \ No newline at end of file diff --git 
a/presets/Vivid Cinema Depth.json b/presets/Vivid Cinema Depth.json new file mode 100644 index 0000000..533f3a2 --- /dev/null +++ b/presets/Vivid Cinema Depth.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 7.5, + "mg_shift": 3.0, + "bg_shift": -4.5, + "zero_parallax_strength": 0.0, + "max_pixel_shift": 0.07, + "parallax_balance": 0.95, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": true, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "use_subject_tracking": false, + "dof_strength": 0.7, + "convergence_strength": 0.03, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 1.0, + "depth_pop_mid": 0.5, + "depth_stretch_lo": 0.05, + "depth_stretch_hi": 0.95, + "fg_pop_multiplier": 1.29, + "bg_push_multiplier": 1.16, + "subject_lock_strength": 0.7, + "saturation": 1.5, + "contrast": 1.1, + "brightness": 0.04, + "ipd_enabled": true, + "ipd_factor": 1.1, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/VividDepth.json b/presets/VividDepth.json new file mode 100644 index 0000000..0653e78 --- /dev/null +++ b/presets/VividDepth.json @@ -0,0 +1,31 @@ +{ + "fg_shift": 2.5, + "mg_shift": -1.5, + "bg_shift": -5.5, + "zero_parallax_strength": 0.012, + "max_pixel_shift": 0.08, + "parallax_balance": 0.8, + "sharpness_factor": 0.2, + "use_ffmpeg": true, + "enable_feathering": false, + "enable_edge_masking": false, + "use_floating_window": false, + "auto_crop_black_bars": false, + "skip_blank_frames": false, + "dof_strength": 0.9, + "convergence_strength": 0.0, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 1.1, + "depth_pop_mid": 0.45, + "depth_stretch_lo": 0.05, + "depth_stretch_hi": 0.94, + "fg_pop_multiplier": 1.1, + "bg_push_multiplier": 1.05, + "subject_lock_strength": 1.2, + "saturation": 1.4, + "contrast": 1.15, + "brightness": 0.04, + "ipd_enabled": true, + "ipd_factor": 1.05, + "preset_version": "3.5" +} \ No newline at end of file diff --git a/presets/balanced_depth.json b/presets/balanced_depth.json deleted file mode 100644 index 45768be..0000000 --- a/presets/balanced_depth.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "fg_shift": 8.0, - "mg_shift": -3.0, - "bg_shift": -6.0, - "convergence_offset": 0.013, - "max_pixel_shift": 0.035, - "parallax_balance": 0.35, - "sharpness_factor": 1.0, - "use_ffmpeg": true, - "enable_feathering": true, - "enable_edge_masking": true, - "use_floating_window": true, - "auto_crop_black_bars": true, - "skip_blank_frames": true -} diff --git a/presets/custom_preset.json b/presets/custom_preset.json new file mode 100644 index 0000000..6ca2696 --- /dev/null +++ b/presets/custom_preset.json @@ -0,0 +1,32 @@ +{ + "fg_shift": 8.0, + "mg_shift": 1.5, + "bg_shift": -2.5, + "zero_parallax_strength": 0.01, + "max_pixel_shift": 0.02, + "parallax_balance": 0.8, + "sharpness_factor": 0.2, + "use_ffmpeg": false, + "enable_feathering": true, + "enable_edge_masking": true, + "use_floating_window": true, + "auto_crop_black_bars": true, + "skip_blank_frames": false, + "use_subject_tracking": true, + "dof_strength": 2.0, + "convergence_strength": 0.0, + "enable_dynamic_convergence": true, + "depth_pop_gamma": 0.85, + "depth_pop_mid": 0.5, + "depth_stretch_lo": 0.05, + "depth_stretch_hi": 0.95, + "fg_pop_multiplier": 1.2, + "bg_push_multiplier": 1.1, + "subject_lock_strength": 1.0, + "saturation": 1.0, + "contrast": 1.0, + "brightness": 0.0, + "ipd_enabled": true, + "ipd_factor": 1.0, + "preset_version": "3.5" +} \ No newline at end of file diff 
--git a/requirements.txt b/requirements.txt index 9ebbf8a..09715dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,15 @@ diffusers safetensors mediapy scenedetect -Flask +easydict +trimesh +evo +pycolmap +gsplat +plyfile +einops +addict +moviepy +omegaconf +flask +mss
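+# (note: the additions above presumably back the new DA3/VDA adapters and the
+#  live-preview capture path; e.g. mss provides screen capture, while moviepy
+#  and omegaconf cover video I/O and config handling)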