diff --git a/client/platform/desktop/api/main.ts b/client/platform/desktop/api/main.ts index ccaabf3e1..eeec85a90 100644 --- a/client/platform/desktop/api/main.ts +++ b/client/platform/desktop/api/main.ts @@ -73,7 +73,7 @@ async function getTrainingConfigurations(): Promise { async function runTraining( // eslint-disable-next-line @typescript-eslint/no-unused-vars - folderId: string, pipelineName: string, config: string, + folderIds: string[], pipelineName: string, config: string, ): Promise { return Promise.resolve(); } diff --git a/client/platform/web-girder/api/viame.service.ts b/client/platform/web-girder/api/viame.service.ts index b6e84f127..6237f7aab 100644 --- a/client/platform/web-girder/api/viame.service.ts +++ b/client/platform/web-girder/api/viame.service.ts @@ -92,8 +92,8 @@ async function getTrainingConfigurations(): Promise { return data; } -function runTraining(folderId: string, pipelineName: string, config: string) { - return girderRest.post('/viame/train', null, { params: { folderId, pipelineName, config } }); +function runTraining(folderIds: string[], pipelineName: string, config: string) { + return girderRest.post('/viame/train', folderIds, { params: { pipelineName, config } }); } function saveMetadata(folderId: string, metadata: object) { diff --git a/client/viame-web-common/apispec.ts b/client/viame-web-common/apispec.ts index 075e7f81e..2b8d901e8 100644 --- a/client/viame-web-common/apispec.ts +++ b/client/viame-web-common/apispec.ts @@ -65,7 +65,7 @@ interface Api { runPipeline(itemId: string, pipeline: Pipe): Promise; getTrainingConfigurations(): Promise; - runTraining(folderId: string, pipelineName: string, config: string): Promise; + runTraining(folderIds: string[], pipelineName: string, config: string): Promise; loadDetections(datasetId: string): Promise<{ [key: string]: TrackData }>; saveDetections(datasetId: string, args: SaveDetectionsArgs): Promise; diff --git a/client/viame-web-common/components/RunTrainingMenu.vue b/client/viame-web-common/components/RunTrainingMenu.vue index eb3954d92..c862644f1 100644 --- a/client/viame-web-common/components/RunTrainingMenu.vue +++ b/client/viame-web-common/components/RunTrainingMenu.vue @@ -31,7 +31,7 @@ export default defineComponent({ selectedTrainingConfig.value = resp.default; }); - const trainingDisabled = computed(() => props.selectedDatasetIds.length !== 1); + const trainingDisabled = computed(() => props.selectedDatasetIds.length === 0); const trainingOutputName = ref(null); const menuOpen = ref(false); @@ -46,7 +46,7 @@ export default defineComponent({ try { await runTraining( - props.selectedDatasetIds[0], + props.selectedDatasetIds, trainingOutputName.value, selectedTrainingConfig.value, ); diff --git a/server/viame_server/viame.py b/server/viame_server/viame.py index c6f0e0628..2fe528447 100644 --- a/server/viame_server/viame.py +++ b/server/viame_server/viame.py @@ -108,13 +108,10 @@ def run_pipeline_task(self, folder, pipeline: PipelineDescription): @access.user @autoDescribeRoute( Description("Run training on a folder") - .modelParam( - "folderId", - description="The folder containing the training data", - model=Folder, - paramType="query", - required=True, - level=AccessType.WRITE, + .jsonParam( + "folderIds", + description="Array of folderIds to run training on", + paramType="body", ) .param( "pipelineName", @@ -129,20 +126,33 @@ def run_pipeline_task(self, folder, pipeline: PipelineDescription): required=True, ) ) - def run_training(self, folder, pipelineName, config): + def run_training(self, folderIds, pipelineName, config): user = self.getCurrentUser() token = Token().createToken(user=user, days=14) - detections = list( - Item().find({"meta.detection": str(folder["_id"])}).sort([("created", -1)]) - ) - detection = detections[0] if detections else None + detection_list = [] + folder_list = [] + folder_names = [] + if folderIds is None or len(folderIds) == 0: + raise Exception("No folderIds in param") + + for folderId in folderIds: + folder = Folder().load(folderId, level=AccessType.READ, user=user) + if folder is None: + raise Exception(f"Cannot access folder {folderId}") + folder_names.append(folder['name']) + detections = list( + Item().find({"meta.detection": str(folderId)}).sort([("created", -1)]) + ) + detection = detections[0] if detections else None - if not detection: - raise Exception(f"No detections for folder {folder['name']}") + if not detection: + raise Exception(f"No detections for folder {folder['name']}") - # Ensure detection has a csv format - csv_detection_file(folder, detection, user) + # Ensure detection has a csv format + csv_detection_file(folder, detection, user) + detection_list.append(detection) + folder_list.append(folder) # Ensure the folder to upload results to exists results_folder = training_output_folder(user) @@ -151,12 +161,14 @@ def run_training(self, folder, pipelineName, config): queue="training", kwargs=dict( results_folder=results_folder, - source_folder=folder, - groundtruth=detection, + source_folder_list=folder_list, + groundtruth_list=detection_list, pipeline_name=pipelineName, config=config, girder_client_token=str(token["_id"]), - girder_job_title=(f"Running training on folder: {str(folder['name'])}"), + girder_job_title=( + f"Running training on folder: {', '.join(folder_names)}" + ), girder_job_type="training", ), ) diff --git a/server/viame_tasks/tasks.py b/server/viame_tasks/tasks.py index b4de201ed..95c4a4d3c 100644 --- a/server/viame_tasks/tasks.py +++ b/server/viame_tasks/tasks.py @@ -4,7 +4,7 @@ import tempfile from pathlib import Path from subprocess import DEVNULL, Popen -from typing import Dict +from typing import Dict, List from girder_client import GirderClient from girder_worker.app import app @@ -184,27 +184,24 @@ def run_pipeline(self: Task, params: PipelineJob): def train_pipeline( self: Task, results_folder: Dict, - source_folder: Dict, - groundtruth: Dict, + source_folder_list: List[Dict], + groundtruth_list: List[Dict], pipeline_name: str, config: str, ): """ Train a pipeline by making a call to viame_train_detector - :param source_folder: The Girder Folder to pull training data from + :param source_folder_list: The Girder Folders to pull training data from :param results_folder: The Girder Folder to place the results of training into - :param groundtruth: The relative path to either the file containing detections, - or the folder containing that file. + :param groundtruth_list: A list of relative paths to either a file containing detections, + or a folder containing that file. :param pipeline_name: The base name of the resulting pipeline. """ conf = Config() gc: GirderClient = self.girder_client manager: JobManager = self.job_manager - # Generator of items - training_data = gc.listItem(source_folder["_id"]) - viame_install_path = Path(conf.viame_install_path) pipeline_base_path = Path(conf.pipeline_base_path) training_executable = viame_install_path / "bin" / "viame_train_detector" @@ -212,20 +209,47 @@ def train_pipeline( pipeline_name = pipeline_name.replace(" ", "_") + if len(source_folder_list) != len(groundtruth_list): + raise Exception("Ground truth doesn't exist for all folders") + + # List of folderIds used for training + trained_on_list: List[str] = [] + # List of[input folder / ground truth file] pairs for creating input lists + input_groundtruth_list: List[[Path, Path]] = [] # root_data_dir is the directory passed to `viame_train_detector` with tempfile.TemporaryDirectory() as _temp_dir_string: manager.updateStatus(JobStatus.FETCHING_INPUT) root_data_dir = Path(_temp_dir_string) - download_path = Path(tempfile.mkdtemp(dir=root_data_dir)) - # Download data onto server - gc.downloadItem(str(groundtruth["_id"]), download_path) - for item in training_data: - gc.downloadItem(str(item["_id"]), download_path) + for index in range(len(source_folder_list)): + source_folder = source_folder_list[index] + groundtruth = groundtruth_list[index] + download_path = Path(tempfile.mkdtemp(dir=root_data_dir)) + trained_on_list.append(str(source_folder["_id"])) - # Organize data - groundtruth_path = download_path / groundtruth["name"] - organize_folder_for_training(root_data_dir, download_path, groundtruth_path) + # Generator of items + training_data = gc.listItem(source_folder["_id"]) + + # Download data onto server + gc.downloadItem(str(groundtruth["_id"]), download_path) + for item in training_data: + gc.downloadItem(str(item["_id"]), download_path) + + # Organize data + groundtruth_path = download_path / groundtruth["name"] + groundtruth_file = organize_folder_for_training( + root_data_dir, download_path, groundtruth_path + ) + input_groundtruth_list.append([download_path, groundtruth_file]) + + input_folder_file_list = root_data_dir / "input_folder_list.txt" + ground_truth_file_list = root_data_dir / "input_truth_list.txt" + with open(input_folder_file_list, "w+") as data_list: + folder_paths = [f"{item[0]}\n" for item in input_groundtruth_list] + data_list.writelines(folder_paths) + with open(ground_truth_file_list, "w+") as truth_list: + truth_paths = [f"{item[1]}\n" for item in input_groundtruth_list] + truth_list.writelines(truth_paths) # Completely separate directory from `root_data_dir` with tempfile.TemporaryDirectory() as _training_output_path: @@ -233,15 +257,17 @@ def train_pipeline( command = [ f". {conf.viame_install_path}/setup_viame.sh &&", str(training_executable), - "-i", - str(root_data_dir), + "-il", + str(input_folder_file_list), + "-it", + str(ground_truth_file_list), "-c", str(config_file), + "--no-query", ] process_log_file = tempfile.TemporaryFile() process_err_file = tempfile.TemporaryFile() - manager.updateStatus(JobStatus.RUNNING) # Call viame_train_detector process = Popen( @@ -279,7 +305,7 @@ def train_pipeline( pipeline_name, metadata={ "trained_pipeline": True, - "trained_on": str(source_folder["_id"]), + "trained_on": trained_on_list, }, ) diff --git a/server/viame_tasks/utils.py b/server/viame_tasks/utils.py index 722de707e..c9721e738 100644 --- a/server/viame_tasks/utils.py +++ b/server/viame_tasks/utils.py @@ -66,18 +66,4 @@ def organize_folder_for_training( groundtruth = data_dir / "groundtruth.csv" shutil.move(str(downloaded_groundtruth), groundtruth) - # Generate labels.txt - labels = set() - with open(groundtruth, 'r') as groundtruth_infile: - for line in groundtruth_infile.readlines(): - if not line.strip().startswith('#'): - row = [c.strip() for c in line.split(",")] - - # Confidence pairs start at the 9th index - # 9th index is label, 10th is confidence, 11th is another label, etc. - for label in row[9::2]: - labels.add(label) - - with open(root_training_dir / "labels.txt", "w") as labels_file: - label_lines = [f"{label}\n" for label in labels] - labels_file.writelines(label_lines) + return groundtruth