Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion client/dive-common/apispec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ interface Api {

getTrainingConfigurations(): Promise<TrainingConfigs>;
runTraining(
folderIds: string[], pipelineName: string, config: string, annotatedFramesOnly: boolean
folderIds: string[],
pipelineName: string,
config: string,
annotatedFramesOnly: boolean,
labelText?: string,
): Promise<unknown>;

loadMetadata(datasetId: string): Promise<DatasetMeta>;
Expand Down
4 changes: 2 additions & 2 deletions client/platform/desktop/backend/native/common.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,8 @@ describe('native.common', () => {

await common.dataFileImport(settings, final.id, '/home/user/data/annotationImport/foreign.meta.json');
const meta2 = await common.loadMetadata(settings, final.id, urlMapper);
expect(meta2.confidenceFilters).toStrictEqual({ "default": 0.8 });
expect(meta2.type).toBe("image-sequence"); // Ensure meta import cannot change immutable fields.
expect(meta2.confidenceFilters).toStrictEqual({ default: 0.8 });
expect(meta2.type).toBe('image-sequence'); // Ensure meta import cannot change immutable fields.
});

it('import with CSV annotations without specifying track file', async () => {
Expand Down
2 changes: 2 additions & 0 deletions client/platform/desktop/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ export interface RunTraining {
trainingConfig: string;
// train only on annotated frames
annotatedFramesOnly: boolean;
// contents of labels.txt file
labelText?: string;
}

export interface ConversionArgs {
Expand Down
7 changes: 6 additions & 1 deletion client/platform/desktop/frontend/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,18 @@ async function runPipeline(itemId: string, pipeline: Pipe): Promise<DesktopJob>
}

async function runTraining(
folderIds: string[], pipelineName: string, config: string, annotatedFramesOnly: boolean,
folderIds: string[],
pipelineName: string,
config: string,
annotatedFramesOnly: boolean,
labelText?: string,
): Promise<DesktopJob> {
const args: RunTraining = {
datasetIds: folderIds,
pipelineName,
trainingConfig: config,
annotatedFramesOnly,
labelText,
};
return ipcRenderer.invoke('run-training', args);
}
Expand Down
12 changes: 9 additions & 3 deletions client/platform/web-girder/api/rpc.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ function runPipeline(itemId: string, pipeline: Pipe) {
}

function runTraining(
folderIds: string[], pipelineName: string, config: string, annotatedFramesOnly: boolean,
folderIds: string[],
pipelineName: string,
config: string,
annotatedFramesOnly: boolean,
labelText?: string,
) {
return girderRest.post('dive_rpc/train', folderIds, {
params: { pipelineName, config, annotatedFramesOnly },
return girderRest.post('dive_rpc/train', { folderIds, labelText }, {
params: {
pipelineName, config, annotatedFramesOnly,
},
});
}

Expand Down
1 change: 1 addition & 0 deletions client/platform/web-girder/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Promise<{ canceled: boolean; filePaths: string[]; fileList?: File[]}> {
input.accept = inputAnnotationTypes
.concat(inputAnnotationFileTypes.map((item) => `.${item}`)).join(',');
}

return new Promise(((resolve) => {
input.onchange = (event) => {
if (event) {
Expand Down
4 changes: 3 additions & 1 deletion client/platform/web-girder/views/Export.vue
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import {
DatasetSourceMedia, getDataset, getDatasetMedia, getUri,
} from 'platform/web-girder/api';
import { GirderMetadataStatic } from 'platform/web-girder/constants';
import { ImageSequenceType, MultiType, VideoType } from 'dive-common/constants';
import {
ImageSequenceType, MultiType, VideoType,
} from 'dive-common/constants';

export default defineComponent({
components: { AutosavePrompt },
Expand Down
2 changes: 1 addition & 1 deletion client/platform/web-girder/views/Home.vue
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import {
import { mapGetters, mapState } from 'vuex';

import RunPipelineMenu from 'dive-common/components/RunPipelineMenu.vue';
import RunTrainingMenu from 'dive-common/components/RunTrainingMenu.vue';
import { usePrompt } from 'dive-common/vue-utilities/prompt-service';
import RunTrainingMenu from './RunTrainingMenu.vue';

import { deleteResources, getUri } from '../api';
import Export from './Export.vue';
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
<script lang="ts">
import {
defineComponent, computed, PropType, ref, onBeforeMount,
defineComponent, computed, PropType, ref, onBeforeMount, watch,
} from '@vue/composition-api';

import { useApi, TrainingConfigs } from 'dive-common/apispec';
import JobLaunchDialog from 'dive-common/components/JobLaunchDialog.vue';
import ImportButton from 'dive-common/components/ImportButton.vue';
import { useRequest } from 'dive-common/use';


export default defineComponent({
name: 'RunTrainingMenu',

components: { JobLaunchDialog },
components: { JobLaunchDialog, ImportButton },

props: {
selectedDatasetIds: {
Expand Down Expand Up @@ -50,6 +52,8 @@ export default defineComponent({
const trainingDisabled = computed(() => props.selectedDatasetIds.length === 0);
const trainingOutputName = ref<string | null>(null);
const menuOpen = ref(false);
const labelText = ref<string>('');
const labelFile = ref<File>();

async function runTrainingOnFolder() {
const outputPipelineName = trainingOutputName.value;
Expand All @@ -60,6 +64,15 @@ export default defineComponent({
if (!trainingConfigurations.value || !selectedTrainingConfig.value) {
throw new Error('Training configurations not found.');
}
if (labelText) {
return runTraining(
props.selectedDatasetIds,
outputPipelineName,
selectedTrainingConfig.value,
annotatedFramesOnly.value,
labelText.value,
);
}
return runTraining(
props.selectedDatasetIds,
outputPipelineName,
Expand All @@ -71,6 +84,21 @@ export default defineComponent({
trainingOutputName.value = null;
}

watch(labelFile, () => {
if (labelFile.value) {
const reader = new FileReader();
reader.onload = (evt) => {
labelText.value = evt.target?.result as string;
};
reader.readAsText(labelFile.value);
}
});

const clearLabelText = () => {
labelText.value = '';
};


return {
trainingConfigurations,
selectedTrainingConfig,
Expand All @@ -82,6 +110,8 @@ export default defineComponent({
successMessage,
dismissJobDialog,
runTrainingOnFolder,
labelFile,
clearLabelText,
};
},
});
Expand Down Expand Up @@ -136,6 +166,11 @@ export default defineComponent({
<p>
Specify the name of the resulting pipeline
and configuration file to use for training.
Check the
<a href="https://kitware.github.io/dive/Pipeline-Documentation/#training">
documentation
</a>
for more information about these options.
</p>
<v-alert
dense
Expand All @@ -149,9 +184,10 @@ export default defineComponent({
<v-text-field
v-model="trainingOutputName"
outlined
hide-details
class="my-4"
label="Output Name"
label="New Model Name"
hint="Choose a name for the newly trained model"
persistent-hint
/>
<v-select
v-model="selectedTrainingConfig"
Expand All @@ -161,10 +197,18 @@ export default defineComponent({
label="Configuration File"
:items="trainingConfigurations.configs"
/>
<v-file-input
v-model="labelFile"
Comment thread
BryonLewis marked this conversation as resolved.
icon="mdi-folder-open"
label="Labels.txt mapping file (optional)"
hint="Combine or rename output classes using a labels.txt file"
persistent-hint
clearable
@click:clear="clearLabelText"
/>
<v-checkbox
v-model="annotatedFramesOnly"
label="Use annotated frames only"
dense
hint="Train only on frames with groundtruth and ignore frames without annotations"
persistent-hint
class="pt-0"
Expand Down
32 changes: 29 additions & 3 deletions docs/Pipeline-Documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,35 @@ Run model training on ground truth annotations. Currently, training configurati

### Options

* Output Name - a recognizeable name for the pipeline that results from the training run.
* Configuration File - chosen from the options below
* Use anootation Frames Only - by default, training runs include all frames from the chosen input datasets, and frames without annotations are considered negatives examples. If you choose to use annotated frames only, frames or images with zero annotations will be discarded. This option is useful for trying to train on datasets that are only partially annotated.
![Training options dialog](images/TrainingMenu.png)

#### New Model Name

A recognizable name for the pipeline that results from the training run.

#### Configuration File

One of the configuration options in the table below.

#### Labels.txt file

This **optional** file controls the output classes that a newly trained model will generate.

* Use if you annotated using higher granularity labels (such as species names) and want to train a classifier using more
* Or you want to restrict your training session to only train on certain kinds of ground-truth data.

The following example `labels.txt` shows how to train a `FISH` classifier by combining `redfish` and `bluefish`, preserve the `ROCK` label, and omit every other label.

``` text
FISH redfish bluefish
ROCK
```

By default, all classes from all input datasets are preserved in the output model.

#### Use annotation frames only

By default, training runs include all frames from the chosen input datasets, and frames without annotations are considered negatives examples. If you choose to use annotated frames only, frames or images with zero annotations will be discarded. This option is useful for trying to train on datasets that are only partially annotated.

### Configurations

Expand Down
Binary file added docs/images/TrainingMenu.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 13 additions & 4 deletions server/dive_server/crud_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from girder.models.token import Token
from girder.models.upload import Upload
from girder_jobs.models.job import Job
from pydantic import BaseModel
import pymongo

from dive_server import crud
Expand All @@ -20,6 +21,11 @@
from . import crud_dataset


class RunTrainingArgs(BaseModel):
folderIds: List[str]
labelText: Optional[str]


def _get_queue_name(user: types.GirderUserModel, default="celery") -> str:
if user.get(constants.UserPrivateQueueEnabledMarker, False):
return f'{user["login"]}@private'
Expand Down Expand Up @@ -226,18 +232,18 @@ def ensure_csv_detections_file(
def run_training(
user: types.GirderUserModel,
token: types.GirderModel,
folderIds: List[str],
bodyParams: RunTrainingArgs,
pipelineName: str,
config: str,
annotatedFramesOnly: bool,
) -> types.GirderModel:
detection_list = []
folder_list = []
folder_names = []
if folderIds is None or len(folderIds) == 0:
if len(bodyParams.folderIds) == 0:
raise RestException("No folderIds in param")

for folderId in folderIds:
for folderId in bodyParams.folderIds:
folder = Folder().load(folderId, level=AccessType.READ, user=user)
if folder is None:
raise RestException(f"Cannot access folder {folderId}")
Expand Down Expand Up @@ -276,16 +282,19 @@ def run_training(
pipeline_name=pipelineName,
config=config,
annotated_frames_only=annotatedFramesOnly,
label_text=bodyParams.labelText,
girder_client_token=str(token["_id"]),
girder_job_title=(f"Running training on {len(folder_list)} datasets"),
girder_job_type="private" if job_is_private else "training",
),
)
newjob.job[constants.JOBCONST_PRIVATE_QUEUE] = job_is_private
newjob.job[constants.JOBCONST_TRAINING_INPUT_IDS] = folderIds
newjob.job[constants.JOBCONST_TRAINING_INPUT_IDS] = bodyParams.folderIds
newjob.job[constants.JOBCONST_RESULTS_FOLDER_ID] = str(results_folder['_id'])
newjob.job[constants.JOBCONST_TRAINING_CONFIG] = config
newjob.job[constants.JOBCONST_PIPELINE_NAME] = pipelineName
newjob.job[constants.JOBCONST_LABEL_TEXT] = bodyParams.labelText

Job().save(newjob.job)
return newjob.job

Expand Down
15 changes: 10 additions & 5 deletions server/dive_server/views_rpc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from girder.api import access
from girder.api.describe import Description, autoDescribeRoute
from girder.api.rest import Resource
Expand All @@ -7,7 +9,7 @@

from dive_utils.types import PipelineDescription

from . import crud_rpc
from . import crud, crud_rpc


class RpcResource(Resource):
Expand Down Expand Up @@ -41,9 +43,11 @@ def run_pipeline_task(self, folder, pipeline: PipelineDescription):
@autoDescribeRoute(
Description("Run training on a folder")
.jsonParam(
"folderIds",
description="Array of folderIds to run training on",
"body",
description="JSON object with Array of folderIds to run training on\
and labels.txt file content",
paramType="body",
schema={"folderIds": List[str], "labelText": str},
)
.param(
"pipelineName",
Expand All @@ -66,11 +70,12 @@ def run_pipeline_task(self, folder, pipeline: PipelineDescription):
required=False,
)
)
def run_training(self, folderIds, pipelineName, config, annotatedFramesOnly):
def run_training(self, body, pipelineName, config, annotatedFramesOnly):
user = self.getCurrentUser()
token = Token().createToken(user=user, days=14)
run_training_args = crud.get_validated_model(crud_rpc.RunTrainingArgs, **body)
return crud_rpc.run_training(
user, token, folderIds, pipelineName, config, annotatedFramesOnly
user, token, run_training_args, pipelineName, config, annotatedFramesOnly
)

@access.user
Expand Down
Loading