diff --git a/adelecv/api/modification_models/convert.py b/adelecv/api/modification_models/convert.py index 5914780..18f0c8c 100644 --- a/adelecv/api/modification_models/convert.py +++ b/adelecv/api/modification_models/convert.py @@ -50,7 +50,7 @@ def __init__( self._weights_path = weights_path self._supported_formats = ['onnx'] self._converter = { - 'onnx': TorchToOnnx(img_shape) + 'onnx': TorchToOnnx(self._input_shape) } def run( diff --git a/adelecv/api/modification_models/conveter.py b/adelecv/api/modification_models/conveter.py index 4612484..ff4887d 100644 --- a/adelecv/api/modification_models/conveter.py +++ b/adelecv/api/modification_models/conveter.py @@ -5,8 +5,8 @@ class BaseConverter(abc.ABC): - def __init__(self, input_shape): - self._dummy_input = torch.zeros(input_shape) + def __init__(self, input_shape: list[int]): + self._dummy_input = torch.zeros(input_shape) # BxCxHxW @property def dummy_input(self): diff --git a/adelecv/ui/dashboard/callbacks/table_models.py b/adelecv/ui/dashboard/callbacks/table_models.py index 63917d1..c0ef734 100644 --- a/adelecv/ui/dashboard/callbacks/table_models.py +++ b/adelecv/ui/dashboard/callbacks/table_models.py @@ -2,8 +2,8 @@ from dash.exceptions import PreventUpdate from adelecv.api.config import Settings -from adelecv.api.modification_models.export import ExportWeights -from adelecv.ui.dashboard.app import app +from adelecv.api.modification_models import ExportWeights, ConvertWeights +from adelecv.ui.dashboard.app import app, _task @app.callback( @@ -24,7 +24,8 @@ def export_weights(n_clicks, rows, derived_virtual_selected_rows): @app.callback( - Input("convert-weights", "n_clicks"), + Output("download-converted-onnx", "data"), + Input("convert-weights-format-onnx", "n_clicks"), State('stats-models-table', "derived_virtual_data"), State('stats-models-table', "derived_virtual_selected_rows"), prevent_initial_call=True @@ -33,6 +34,10 @@ def convert_weights(n_clicks, rows, derived_virtual_selected_rows): if not n_clicks: raise PreventUpdate() - print('convert weights', derived_virtual_selected_rows) - print(rows) - # return '' + id_selected = {rows[i]['_id'] for i in derived_virtual_selected_rows} + img_shape = _task.img_shape[0], _task.img_shape[0], 3 + zip_path = ConvertWeights( + img_shape, Settings.WEIGHTS_PATH + ).run(id_selected, 'onnx') + + return dcc.send_file(zip_path.as_posix()) diff --git a/adelecv/ui/dashboard/components/table_models.py b/adelecv/ui/dashboard/components/table_models.py index d137850..af3a057 100644 --- a/adelecv/ui/dashboard/components/table_models.py +++ b/adelecv/ui/dashboard/components/table_models.py @@ -20,13 +20,23 @@ def table_models(df: pd.DataFrame) -> dbc.Container: dbc.DropdownMenuItem( "Export weights", id='export-weights' ), - dbc.DropdownMenuItem( - "Convert weights", id='convert-weights' - ), + dbc.DropdownMenu( + label="Convert weights", + id='convert-weights', + children=[ + dbc.DropdownMenuItem( + "ONNX", id='convert-weights-format-onnx' + ), + ], + color="secondary", + direction="end", + style={"margin-bottom": "1%"} + ), ], style={"margin-bottom": "1%"} ), dcc.Download(id="download-weights"), + dcc.Download(id="download-converted-onnx"), html.Div( [ dash_table.DataTable( diff --git a/adelecv/ui/dashboard/task/base.py b/adelecv/ui/dashboard/task/base.py index 9a420c5..efd928e 100644 --- a/adelecv/ui/dashboard/task/base.py +++ b/adelecv/ui/dashboard/task/base.py @@ -42,6 +42,7 @@ def __init__(self): self._dataset = None self._tb = None self._session_dataset = None + self._img_shape = None def launch(self, env_path: Path | None): if env_path is not None: @@ -71,3 +72,11 @@ def _run_optimize(self) -> None: @property def stats_models(self) -> pd.DataFrame: return self._stats_models + + @property + def img_shape(self) -> list[int, int]: + return self._img_shape + + @img_shape.setter + def img_shape(self, new_val: list[int, int]): + self._img_shape = new_val diff --git a/adelecv/ui/dashboard/task/segmentation_task.py b/adelecv/ui/dashboard/task/segmentation_task.py index a47c877..290bd73 100644 --- a/adelecv/ui/dashboard/task/segmentation_task.py +++ b/adelecv/ui/dashboard/task/segmentation_task.py @@ -56,3 +56,4 @@ def load_dataset( batch_size ) self._create_dataset_session() + self.img_shape = img_size