diff --git a/webapp/src/components/Settings/AutocompleteStringField.tsx b/webapp/src/components/Settings/AutocompleteStringField.tsx new file mode 100644 index 00000000..9c617c2f --- /dev/null +++ b/webapp/src/components/Settings/AutocompleteStringField.tsx @@ -0,0 +1,53 @@ +import { + Autocomplete, + TextField, + TextFieldProps, + createFilterOptions, +} from "@mui/material"; +import React from "react"; +import { FieldProps, FIELD_COMMON_PROPS } from "./utils"; + +const filterOptions = createFilterOptions(); + +const AutocompleteStringField: React.FC< + Omit & FieldProps & { options: string[] } +> = ({ value, onChange, options, disabled, ...props }) => ( + false} + options={options} + filterOptions={(options, params) => { + const filtered = filterOptions(options, params); + + if (params.inputValue !== "" && !options.includes(params.inputValue)) { + filtered.push(params.inputValue); + } + + return filtered; + }} + value={value} + disabled={disabled} + onChange={onChange && ((_, newValue) => onChange(newValue as string))} + // The autoSelect prop would normally cause onChange when onBlur happens, + // except it doesn't work when the input is empty, so we do it manually. + onBlur={ + onChange && + ((event: React.FocusEvent) => { + if (event.target.value !== value) { + onChange(event.target.value); + } + }) + } + renderInput={(params) => ( + + )} + /> +); + +export default React.memo(AutocompleteStringField); diff --git a/webapp/src/components/Settings/EditableArray.tsx b/webapp/src/components/Settings/EditableArray.tsx new file mode 100644 index 00000000..6e8fcc37 --- /dev/null +++ b/webapp/src/components/Settings/EditableArray.tsx @@ -0,0 +1,86 @@ +import { AddCircle, DeleteForever } from "@mui/icons-material"; +import { + Box, + Divider, + IconButton, + Paper, + Tooltip, + dividerClasses, +} from "@mui/material"; +import React from "react"; +import { splicedArray } from "./utils"; + +const EditableArrayDivider: React.FC<{ + disabled: boolean; + title: string; + onAdd: () => void; + onRemove?: () => void; +}> = ({ disabled, title, onAdd, onRemove }) => ( + theme.spacing(2) }, + [`& .${dividerClasses.wrapper}`]: { padding: 0 }, + }} + > + + theme.palette.grey[400] }} + > + + + + {onRemove && ( + + + theme.palette.grey[400] }} + > + + + + + )} + +); + +const EditableArray = ({ + array, + disabled, + title, + newItem, + onChange, + renderItem, +}: { + array: T[]; + disabled: boolean; + title: string; + newItem: T; + onChange: (array: T[]) => void; + renderItem: (item: T, index: number, array: T[]) => React.ReactNode; +}) => ( + + {array.map((item, index, array) => ( + + onChange(splicedArray(array, index, 0, newItem))} + onRemove={() => onChange(splicedArray(array, index, 1))} + /> + {renderItem(item, index, array)} + + ))} + onChange([...array, newItem])} + /> + +); + +export default EditableArray; diff --git a/webapp/src/components/Settings/StringField.tsx b/webapp/src/components/Settings/StringField.tsx index 8d258eee..8aac1705 100644 --- a/webapp/src/components/Settings/StringField.tsx +++ b/webapp/src/components/Settings/StringField.tsx @@ -20,8 +20,7 @@ const StringField = ({ select={Boolean(options)} inputProps={{ sx: { textOverflow: "ellipsis" } }} value={value ?? ""} - error={value === ""} - helperText={value === "" ? "Set a value" : undefined} + {...(value === "" && { error: true, helperText: "Set a value" })} onChange={ onChange && (nullable diff --git a/webapp/src/components/Settings/utils.ts b/webapp/src/components/Settings/utils.ts index dfe4b7f6..78e64d96 100644 --- a/webapp/src/components/Settings/utils.ts +++ b/webapp/src/components/Settings/utils.ts @@ -5,3 +5,11 @@ export const FIELD_COMMON_PROPS = { variant: "standard", InputLabelProps: { shrink: true }, } as const; + +// Similar to Array.prototype.splice(), but returning a copy instead of modifying in place. +export const splicedArray = ( + array: T[], + start: number, + deleteCount: number, + ...insert: T[] +) => [...array.slice(0, start), ...insert, ...array.slice(start + deleteCount)]; diff --git a/webapp/src/pages/Settings.tsx b/webapp/src/pages/Settings.tsx index 5f996fba..b0b94d11 100644 --- a/webapp/src/pages/Settings.tsx +++ b/webapp/src/pages/Settings.tsx @@ -18,16 +18,18 @@ import { InputBaseComponentProps, inputClasses, inputLabelClasses, - Paper, Typography, } from "@mui/material"; import noData from "assets/void.svg"; import AccordionLayout from "components/AccordionLayout"; import Loading from "components/Loading"; +import AutocompleteStringField from "components/Settings/AutocompleteStringField"; +import EditableArray from "components/Settings/EditableArray"; import JSONField from "components/Settings/JSONField"; import NumberField from "components/Settings/NumberField"; import StringArrayField from "components/Settings/StringArrayField"; import StringField from "components/Settings/StringField"; +import { splicedArray } from "components/Settings/utils"; import _ from "lodash"; import React from "react"; import { useParams } from "react-router-dom"; @@ -42,6 +44,8 @@ import { SupportedLanguage, SupportedModelContract, SupportedSpacyModels, + TemperatureScaling, + ThresholdConfig, } from "types/api"; import { PickByValue } from "types/models"; import { UNKNOWN_ERROR } from "utils/const"; @@ -104,6 +108,15 @@ const FIELDS_TRIGGERING_STARTUP_TASKS: (keyof AzimuthConfig)[] = [ "metrics", ]; +type KnownPostprocessor = TemperatureScaling | ThresholdConfig; + +const KNOWN_POSTPROCESSORS: { + [T in KnownPostprocessor as T["class_name"]]: Partial; +} = { + "azimuth.utils.ml.postprocessing.TemperatureScaling": { temperature: 1 }, + "azimuth.utils.ml.postprocessing.Thresholding": { threshold: 0.5 }, +}; + const Columns: React.FC<{ columns?: number }> = ({ columns = 1, children }) => ( {children} @@ -122,11 +135,8 @@ const KeyValuePairs: React.FC = ({ children }) => ( ); -const updateArrayAt = (array: T[], index: number, update: Partial) => [ - ...array.slice(0, index), - { ...array[index], ...update }, - ...array.slice(index + 1), -]; +const updateArrayAt = (array: T[], index: number, update: Partial) => + splicedArray(array, index, 1, { ...array[index], ...update }); type Props = { open: boolean; @@ -362,6 +372,10 @@ const Settings: React.FC = ({ open, onClose }) => { ), }); + const getDefaultPostprocessors = (pipelineIndex: number) => + config.pipelines?.[pipelineIndex]?.postprocessors ?? // TODO pipelineIndex might not correspond if the user added or removed pipelines + defaultConfig.pipelines![0].postprocessors!; + const displayToggleSectionTitle = ( field: keyof AzimuthConfig, section: string = field @@ -397,8 +411,7 @@ const Settings: React.FC = ({ open, onClose }) => { onChange={(...[, checked]) => updatePipeline(pipelineIndex, { postprocessors: checked - ? config.pipelines![pipelineIndex].postprocessors ?? - defaultConfig.pipelines![0].postprocessors + ? getDefaultPostprocessors(pipelineIndex) : null, }) } @@ -557,138 +570,185 @@ const Settings: React.FC = ({ open, onClose }) => { - {resultingConfig.pipelines?.length && ( - <> - {displaySectionTitle("Pipelines")} - - {resultingConfig.pipelines.map((pipeline, pipelineIndex) => ( - - - {displaySectionTitle("General")} - - - - updatePipeline(pipelineIndex, { name }) - } - /> - - - - - {displaySectionTitle("Model")} - + {displaySectionTitle("Pipelines")} + updatePartialConfig({ pipelines })} + renderItem={(pipeline, pipelineIndex) => ( + + + {displaySectionTitle("General")} + + + updatePipeline(pipelineIndex, { name })} + /> + + + {displaySectionTitle("Model")} + + + + updateModel(pipelineIndex, { class_name }) + } + /> + + updateModel(pipelineIndex, { remote }) + } + /> + updateModel(pipelineIndex, { args })} + /> + + updateModel(pipelineIndex, { kwargs }) + } + /> + + + {displayPostprocessorToggleSection(pipelineIndex, pipeline)} + + updatePipeline(pipelineIndex, { postprocessors }) + } + renderItem={(postprocessor, index, postprocessors) => ( + - - updateModel(pipelineIndex, { class_name }) - } - /> - - updateModel(pipelineIndex, { remote }) - } - /> - - updateModel(pipelineIndex, { args }) + options={Object.keys(KNOWN_POSTPROCESSORS)} + value={postprocessor.class_name} + autoFocus + disabled={ + isUpdatingConfig || pipeline.postprocessors === null } - /> - - updateModel(pipelineIndex, { kwargs }) + onChange={(class_name) => + updatePipeline(pipelineIndex, { + postprocessors: splicedArray( + postprocessors, + index, + 1, + { + args: [], + kwargs: {}, + remote: null, + ...(KNOWN_POSTPROCESSORS[ + class_name as keyof typeof KNOWN_POSTPROCESSORS + ] || + postprocessor.class_name in + KNOWN_POSTPROCESSORS || // true spreads nothing + postprocessor), + class_name, + } + ), + }) } /> - - - - - {displayPostprocessorToggleSection(pipelineIndex, pipeline)} - - {( - pipeline.postprocessors ?? - defaultConfig.pipelines![0].postprocessors - )?.map((postprocessor, index) => ( - - + {!(postprocessor.class_name in KNOWN_POSTPROCESSORS) && ( + <> + updatePostprocessor(pipelineIndex, index, { + remote, + }) } - onChange={(class_name) => + /> + updatePostprocessor(pipelineIndex, index, { - class_name, + args, }) } /> - {"temperature" in postprocessor && ( - - updatePostprocessor(pipelineIndex, index, { - temperature, - kwargs: { temperature }, - }) - } - {...FLOAT} - /> - )} - {"threshold" in postprocessor && ( - - updatePostprocessor(pipelineIndex, index, { - threshold, - kwargs: { threshold }, - }) - } - {...PERCENTAGE} - /> - )} - - - ))} + + updatePostprocessor(pipelineIndex, index, { + kwargs, + }) + } + /> + + )} + {"temperature" in postprocessor && ( + + updatePostprocessor(pipelineIndex, index, { + temperature, + kwargs: { temperature }, + }) + } + {...FLOAT} + /> + )} + {"threshold" in postprocessor && ( + + updatePostprocessor(pipelineIndex, index, { + threshold, + kwargs: { threshold }, + }) + } + {...PERCENTAGE} + /> + )} + - - - ))} + )} + /> + - - )} + )} + /> {displaySectionTitle("Metrics")} {CUSTOM_METRICS.map((metricName, index) => ( diff --git a/webapp/src/types/api.ts b/webapp/src/types/api.ts index c2cc4996..d9d860cb 100644 --- a/webapp/src/types/api.ts +++ b/webapp/src/types/api.ts @@ -77,6 +77,8 @@ export type SupportedModelContract = components["schemas"]["SupportedModelContract"]; export type SupportedSpacyModels = components["schemas"]["SupportedSpacyModels"]; +export type TemperatureScaling = components["schemas"]["TemperatureScaling"]; +export type ThresholdConfig = components["schemas"]["ThresholdConfig"]; export type TopWordsResponse = components["schemas"]["TopWordsResponse"]; export type TopWordsResult = components["schemas"]["TopWordsResult"]; export type Utterance = components["schemas"]["Utterance"];