diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 64b9d037..68a76f07 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Cache pip uses: actions/cache@v4 @@ -95,7 +95,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Cache pip uses: actions/cache@v4 @@ -158,7 +158,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Install system deps (Qt/OpenCV runtime) shell: bash diff --git a/.github/workflows/deploy_docs.yml b/.github/workflows/deploy_docs.yml index c37aabab..badb1839 100644 --- a/.github/workflows/deploy_docs.yml +++ b/.github/workflows/deploy_docs.yml @@ -22,7 +22,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Install dependencies run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c5ff9634..87e8c850 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -59,7 +59,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Cache pip uses: actions/cache@v4 @@ -129,7 +129,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Cache pip uses: actions/cache@v4 @@ -191,7 +191,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Install system deps (Qt/OpenCV runtime) shell: bash diff --git a/README.md b/README.md index 234af069..c1890855 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ cd VideoAnnotationTool ### Step 1 – Create a new Conda environment ```bash -conda create -n VideoAnnotationTool python=3.11 -y +conda create -n VideoAnnotationTool python=3.12 -y conda activate VideoAnnotationTool ``` diff --git a/annotation_tool/controllers/classification/inference_manager.py b/annotation_tool/controllers/classification/inference_manager.py index cd2e7552..ebb3ac19 100644 --- a/annotation_tool/controllers/classification/inference_manager.py +++ b/annotation_tool/controllers/classification/inference_manager.py @@ -887,6 +887,7 @@ def start_batch_inference(self, start_idx: int, end_idx: int): ) label_map = self._get_label_map_from_config() + self.panel.show_inference_loading(True) worker = BatchInferenceWorker( self.config_path, self.base_dir, @@ -902,6 +903,7 @@ def start_batch_inference(self, start_idx: int, end_idx: int): worker.start() def _on_batch_inference_success(self, _metrics: dict, results_list: list): + self.panel.show_inference_loading(False) if self.model is None: return @@ -946,6 +948,7 @@ def _on_batch_inference_success(self, _metrics: dict, results_list: list): self.controller.saveStateRefreshRequested.emit() def _on_batch_inference_error(self, error_msg): + self.panel.show_inference_loading(False) QMessageBox.critical(self.panel, "Batch Inference Error", f"An error occurred during batch inference:\n\n{error_msg}") def clear_smart_annotations_for_path(self, path: str): diff --git a/annotation_tool/controllers/localization/localization_editor_controller.py b/annotation_tool/controllers/localization/localization_editor_controller.py index 74c7ba27..2f3dfba7 100644 --- a/annotation_tool/controllers/localization/localization_editor_controller.py +++ b/annotation_tool/controllers/localization/localization_editor_controller.py @@ -74,6 +74,7 @@ def __init__(self, localization_panel): def reset_ui(self): self.localization_panel.annot_mgmt.update_schema({}) self.localization_panel.table.set_data([]) + self.localization_panel.show_inference_loading(False) self.localization_panel.setEnabled(False) self.current_video_path = None self.current_sample_id = "" @@ -578,6 +579,9 @@ def _prompt_inference_range(self): def _on_head_smart_inference_requested(self, head_name: str): if not self.current_video_path or not self.current_sample_id: return + if self.inference_manager.has_running_threads(): + self.statusMessageRequested.emit("Inference", "Localization inference is already running.", 1200) + return labels = self._head_labels(head_name) if not labels: @@ -597,6 +601,7 @@ def _on_head_smart_inference_requested(self, head_name: str): return self._pending_inference_head = str(head_name or "") + self.localization_panel.show_inference_loading(True) self.statusMessageRequested.emit("Inference", "Running localization inference...", 1200) self.inference_manager.start_inference( self.current_video_path, @@ -646,6 +651,7 @@ def _prediction_confidence(event: dict) -> float: return 1.0 def _on_inference_success(self, predicted_events: list): + self.localization_panel.show_inference_loading(False) if not self.current_video_path or not self.current_sample_id: self._pending_inference_head = None return @@ -704,6 +710,7 @@ def _on_inference_success(self, predicted_events: list): self._pending_inference_head = None def _on_inference_error(self, error_msg: str): + self.localization_panel.show_inference_loading(False) self._pending_inference_head = None QMessageBox.critical(self.localization_panel, "Inference Error", f"Failed to run model:\n{error_msg}") diff --git a/annotation_tool/ui/classification/__init__.py b/annotation_tool/ui/classification/__init__.py index 94db5d35..e7b4d21c 100644 --- a/annotation_tool/ui/classification/__init__.py +++ b/annotation_tool/ui/classification/__init__.py @@ -13,6 +13,7 @@ QLabel, QLineEdit, QMenu, + QProgressDialog, QPushButton, QRadioButton, QScrollArea, @@ -22,6 +23,7 @@ QWidget, ) +from ui.dialogs import BusyStatusDialog from utils import resource_path @@ -303,6 +305,20 @@ def set_smart_state(self, predicted_label: str, confidence_score: float, is_smar def get_row_smart_widgets(self, label_text: str): return self._smart_controls_by_label.get(str(label_text or "")) + def set_inference_loading(self, is_loading: bool): + self.btn_smart_infer.setEnabled(not is_loading) + self.btn_smart_infer.setText("Loading..." if is_loading else "Smart Inference") + for _conf_btn, accept_btn, reject_btn in self._smart_controls_by_label.values(): + accept_btn.setEnabled(not is_loading) + reject_btn.setEnabled(not is_loading) + + def set_inference_loading(self, is_loading: bool): + self.btn_smart_infer.setEnabled(not is_loading) + self.btn_smart_infer.setText("Loading..." if is_loading else "Smart Inference") + for _conf_btn, accept_btn, reject_btn in self._smart_controls_by_label.values(): + accept_btn.setEnabled(not is_loading) + reject_btn.setEnabled(not is_loading) + class DynamicMultiLabelGroup(QWidget): value_changed = pyqtSignal(str, list) @@ -449,6 +465,13 @@ def set_smart_state(self, predicted_label: str, confidence_score: float, is_smar def get_row_smart_widgets(self, label_text: str): return self._smart_controls_by_label.get(str(label_text or "")) + def set_inference_loading(self, is_loading: bool): + self.btn_smart_infer.setEnabled(not is_loading) + self.btn_smart_infer.setText("Loading..." if is_loading else "Smart Inference") + for _conf_btn, accept_btn, reject_btn in self._smart_controls_by_label.values(): + accept_btn.setEnabled(not is_loading) + reject_btn.setEnabled(not is_loading) + class ClassificationAnnotationPanel(QWidget): add_head_clicked = pyqtSignal(str) @@ -539,6 +562,7 @@ def __init__(self, parent=None): self.chart_widget.setVisible(False) self._configure_train_defaults() + self._configure_inference_feedback() self.clear_dynamic_labels() self.manual_box.setEnabled(False) self._update_confirm_button_state() @@ -704,6 +728,26 @@ def _configure_train_defaults(self): self.btn_stop_train.setEnabled(False) + def _configure_inference_feedback(self): + self._inference_loading_dialog = BusyStatusDialog( + "Inference", + "Loading model and running inference. Please wait...", + self, + ) + self._inference_loading_dialog.hide() + + def _set_inference_controls_loading(self, is_loading: bool): + self.head_tabs_widget.setEnabled(not is_loading) + self.clear_sel_btn.setEnabled(not is_loading) + self.btn_batch_infer.setEnabled(not is_loading) + self.btn_run_batch.setEnabled(not is_loading) + self.spin_start.setEnabled(not is_loading) + self.spin_end.setEnabled(not is_loading) + + for group in self.label_groups.values(): + if hasattr(group, "set_inference_loading"): + group.set_inference_loading(is_loading) + def _toggle_batch_widget(self): self.batch_input_widget.setVisible(not self.batch_input_widget.isVisible()) @@ -747,6 +791,7 @@ def reset_smart_inference(self): self.is_batch_mode_active = False self.pending_batch_results = {} self.chart_widget.setVisible(False) + self.show_inference_loading(False) def reset_train_ui(self): self.train_progress.setValue(0) @@ -771,7 +816,19 @@ def update_action_list(self, action_names: list): self._validate_batch_range() def show_inference_loading(self, is_loading: bool): - _ = is_loading + is_loading = bool(is_loading) + self._set_inference_controls_loading(is_loading) + + if is_loading: + self._inference_loading_dialog.set_message("Loading model and running inference. Please wait...") + self._inference_loading_dialog.show() + self._inference_loading_dialog.raise_() + self._inference_loading_dialog.activateWindow() + self.setCursor(Qt.CursorShape.WaitCursor) + return + + self._inference_loading_dialog.hide() + self.unsetCursor() def display_inference_result(self, target_head: str, predicted_label: str, conf_dict: dict): score = 0.0 diff --git a/annotation_tool/ui/dialogs.py b/annotation_tool/ui/dialogs.py index c83c1564..4d308a07 100644 --- a/annotation_tool/ui/dialogs.py +++ b/annotation_tool/ui/dialogs.py @@ -2,7 +2,7 @@ from PyQt6.QtWidgets import ( QDialog, QVBoxLayout, QRadioButton, QTreeView, QDialogButtonBox, QAbstractItemView, QGroupBox, QFormLayout, QLineEdit, QHBoxLayout, - QFrame, QListWidget, QComboBox, QPushButton, QLabel, + QFrame, QListWidget, QComboBox, QPushButton, QLabel, QProgressBar, QMessageBox, QWidget, QListWidgetItem, QStyle, QButtonGroup, QScrollArea ) from PyQt6.QtCore import QDir, Qt, QSize @@ -132,3 +132,26 @@ def __init__(self, error_string: str, parent=None) -> None: self.setDetailedText(f"System Diagnostic Logs:\n{error_string}") self.setStandardButtons(QMessageBox.StandardButton.Ok) + + +class BusyStatusDialog(QDialog): + def __init__(self, title: str, message: str, parent=None) -> None: + super().__init__(parent) + self.setWindowTitle(title) + self.setModal(True) + + layout = QVBoxLayout(self) + + self._label = QLabel(message, self) + self._label.setWordWrap(True) + layout.addWidget(self._label) + + self._progress = QProgressBar(self) + self._progress.setRange(0, 0) + self._progress.setTextVisible(False) + layout.addWidget(self._progress) + + self.setMinimumWidth(320) + + def set_message(self, message: str) -> None: + self._label.setText(message) diff --git a/annotation_tool/ui/localization/__init__.py b/annotation_tool/ui/localization/__init__.py index f66c7c6a..c7178513 100644 --- a/annotation_tool/ui/localization/__init__.py +++ b/annotation_tool/ui/localization/__init__.py @@ -28,6 +28,7 @@ localization_label_text_hex, normalize_hex_color, ) +from ui.dialogs import BusyStatusDialog from utils import resource_path @@ -400,6 +401,7 @@ def update_schema(self, label_definitions): "scroll": scroll, "labels": labels, "label_colors": dict(definition.get("label_colors", {})), + "smart_infer_btn": smart_infer_btn, } smart_infer_btn.clicked.connect(lambda _, h=head: self.smartInferenceRequested.emit(h)) self._populate_head_buttons(head) @@ -641,6 +643,15 @@ def _handle_add_head(self): if ok and name.strip(): self.headAdded.emit(name.strip()) + def set_inference_loading(self, is_loading: bool): + self._tabs.setEnabled(not is_loading) + for page_info in self._head_pages.values(): + smart_infer_btn = page_info.get("smart_infer_btn") + if smart_infer_btn is None: + continue + smart_infer_btn.setEnabled(not is_loading) + smart_infer_btn.setText("Loading..." if is_loading else "Smart Inference") + class _AnnotationManagementAdapter(QObject): def __init__(self, spotting_tabs: QTabWidget, parent=None): @@ -650,6 +661,9 @@ def __init__(self, spotting_tabs: QTabWidget, parent=None): def update_schema(self, label_definitions): self.tabs.update_schema(label_definitions) + def set_inference_loading(self, is_loading: bool): + self.tabs.set_inference_loading(is_loading) + class _SmartWidgetAdapter(QObject): """ @@ -783,5 +797,38 @@ def __init__(self, parent=None): self.btn_prev_event.clicked.connect(lambda: self.eventNavigateRequested.emit(-1)) self.btn_next_event.clicked.connect(lambda: self.eventNavigateRequested.emit(1)) + self._inference_loading_dialog = BusyStatusDialog( + "Inference", + "Loading model and running inference. Please wait...", + self, + ) + self._inference_loading_dialog.hide() + + def show_inference_loading(self, is_loading: bool): + is_loading = bool(is_loading) + self.annot_mgmt.set_inference_loading(is_loading) + self.table.table.setEnabled(not is_loading) + self.btn_prev_event.setEnabled(not is_loading) + self.btn_next_event.setEnabled(not is_loading) + + if self.table.btn_set_time is not None: + if is_loading: + self.table.btn_set_time.setEnabled(False) + else: + selection_model = self.table.table.selectionModel() + has_selection = bool(selection_model and selection_model.selectedRows()) + self.table.btn_set_time.setEnabled(has_selection) + + if is_loading: + self._inference_loading_dialog.set_message("Loading model and running inference. Please wait...") + self._inference_loading_dialog.show() + self._inference_loading_dialog.raise_() + self._inference_loading_dialog.activateWindow() + self.setCursor(Qt.CursorShape.WaitCursor) + return + + self._inference_loading_dialog.hide() + self.unsetCursor() + __all__ = ["LocalizationAnnotationPanel"] diff --git a/docs/installation.md b/docs/installation.md index 152909ac..8157127d 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -8,7 +8,7 @@ Pre-built binaries for Windows, macOS, and Linux are available on the [GitHub Re ## Requirements -- Python **3.11** or later +- Python **3.12** or later - PyQt6 - Other dependencies (see `requirements.txt`) diff --git a/requirements.txt b/requirements.txt index 0b59c4d4..96782030 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,11 @@ PyQt6 pyinstaller torch-geometric==2.7.0 -opensportslib +opensportslib==0.1.1 +datasets==4.8.2 imageio-ffmpeg==0.6.0 lightning==2.6.1 tabulate==0.10.0 wandb +pytest +pytest-qt diff --git a/tests/gui/test_workflow_classification.py b/tests/gui/test_workflow_classification.py index fa31e2c3..1d50130e 100644 --- a/tests/gui/test_workflow_classification.py +++ b/tests/gui/test_workflow_classification.py @@ -367,6 +367,46 @@ def test_classification_smart_inference_persists_confidence_and_confirm_strips_i assert "confidence_score" not in sample["labels"]["action"] +@pytest.mark.gui +def test_classification_inference_loading_cue_toggles_controls( + window, + monkeypatch, + qtbot, + synthetic_project_json, +): + project_json_path = synthetic_project_json("classification") + monkeypatch.setattr(window.dataset_explorer_controller, "check_and_close_current_project", lambda: True) + monkeypatch.setattr( + "controllers.dataset_explorer_controller.QFileDialog.getOpenFileName", + lambda *args, **kwargs: (str(project_json_path), "JSON Files (*.json)"), + ) + window.dataset_explorer_controller.import_annotations() + + first_index = window.tree_model.index(0, 0) + assert first_index.isValid() + window.dataset_explorer_panel.tree.setCurrentIndex(first_index) + qtbot.wait(50) + + panel = window.classification_panel + group = panel.label_groups["action"] + + panel.show_inference_loading(True) + qtbot.wait(50) + + assert panel._inference_loading_dialog.isVisible() is True + assert group.btn_smart_infer.isEnabled() is False + assert group.btn_smart_infer.text() == "Loading..." + assert panel.head_tabs_widget.isEnabled() is False + + panel.show_inference_loading(False) + qtbot.wait(50) + + assert panel._inference_loading_dialog.isVisible() is False + assert group.btn_smart_infer.isEnabled() is True + assert group.btn_smart_infer.text() == "Smart Inference" + assert panel.head_tabs_widget.isEnabled() is True + + @pytest.mark.gui def test_classification_clear_smart_restores_manual_or_removes_label_when_no_manual( window, diff --git a/tests/gui/test_workflow_dense_description.py b/tests/gui/test_workflow_dense_description.py index 8aadec94..10e78e19 100644 --- a/tests/gui/test_workflow_dense_description.py +++ b/tests/gui/test_workflow_dense_description.py @@ -212,7 +212,7 @@ def test_dense_add_button_text_is_defined_in_ui(window): @pytest.mark.gui -def test_dense_add_description_modal_flow_creates_event_and_resumes_playback( +def test_dense_add_description_modal_flow_creates_event_at_player_position( window, monkeypatch, qtbot, @@ -230,7 +230,7 @@ def test_dense_add_description_modal_flow_creates_event_and_resumes_playback( assert first_index.isValid() window.dataset_explorer_panel.tree.setCurrentIndex(first_index) qtbot.wait(50) - window.dense_editor_controller.on_media_position_changed(7777) + monkeypatch.setattr(window.center_panel.player, "position", lambda: 7777) monkeypatch.setattr( "controllers.dense_description.dense_editor_controller.QInputDialog.getMultiLineText", lambda *args, **kwargs: (" Added from popup ", True), diff --git a/tests/gui/test_workflow_localization.py b/tests/gui/test_workflow_localization.py index ecdb4df2..cd7dfb62 100644 --- a/tests/gui/test_workflow_localization.py +++ b/tests/gui/test_workflow_localization.py @@ -113,6 +113,69 @@ def fake_start_inference(video_path, start_ms, end_ms, model_id, head_name, labe assert captured["input_fps"] == pytest.approx(25.0) +@pytest.mark.gui +def test_localization_inference_loading_cue_toggles_controls( + window, + monkeypatch, + qtbot, + synthetic_project_json, +): + project_json_path = synthetic_project_json("localization") + monkeypatch.setattr(window.dataset_explorer_controller, "check_and_close_current_project", lambda: True) + monkeypatch.setattr( + "controllers.dataset_explorer_controller.QFileDialog.getOpenFileName", + lambda *args, **kwargs: (str(project_json_path), "JSON Files (*.json)"), + ) + window.dataset_explorer_controller.import_annotations() + + first_index = window.tree_model.index(0, 0) + assert first_index.isValid() + window.dataset_explorer_panel.tree.setCurrentIndex(first_index) + qtbot.wait(50) + + controller = window.localization_editor_controller + panel = window.localization_panel + + monkeypatch.setattr(controller, "_prompt_model_id", lambda: "jeetv/snpro-snbas-2024") + monkeypatch.setattr(controller, "_prompt_inference_range", lambda: (0, 5000)) + + captured = {} + + def fake_start_inference(video_path, start_ms, end_ms, model_id, head_name, labels, input_fps): + captured.update( + { + "video_path": video_path, + "start_ms": start_ms, + "end_ms": end_ms, + "model_id": model_id, + "head_name": head_name, + "labels": list(labels), + "input_fps": input_fps, + } + ) + + monkeypatch.setattr(controller.inference_manager, "start_inference", fake_start_inference) + monkeypatch.setattr( + "controllers.localization.localization_editor_controller.QMessageBox.critical", + lambda *args, **kwargs: None, + ) + + controller._on_head_smart_inference_requested("ball_action") + qtbot.wait(50) + + assert captured["video_path"] == controller.current_video_path + assert panel._inference_loading_dialog.isVisible() is True + assert panel.spottingTabs.isEnabled() is False + assert panel.table.table.isEnabled() is False + + controller._on_inference_error("synthetic failure") + qtbot.wait(50) + + assert panel._inference_loading_dialog.isVisible() is False + assert panel.spottingTabs.isEnabled() is True + assert panel.table.table.isEnabled() is True + + @pytest.mark.gui # Workflow: Localization annotation round-trip with timestamp edit: # 1) create event(label+time) + save + reopen, then 2) change time + save + reopen and verify final timestamp.