Skip to content

Commit

Permalink
add conversation to onnx for ui
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakoKabe committed May 2, 2023
1 parent b7be167 commit 3a449d6
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 12 deletions.
2 changes: 1 addition & 1 deletion adelecv/api/modification_models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self._weights_path = weights_path
self._supported_formats = ['onnx']
self._converter = {
'onnx': TorchToOnnx(img_shape)
'onnx': TorchToOnnx(self._input_shape)
}

def run(
Expand Down
4 changes: 2 additions & 2 deletions adelecv/api/modification_models/conveter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class BaseConverter(abc.ABC):
def __init__(self, input_shape):
self._dummy_input = torch.zeros(input_shape)
def __init__(self, input_shape: list[int]):
self._dummy_input = torch.zeros(input_shape) # BxCxHxW

@property
def dummy_input(self):
Expand Down
17 changes: 11 additions & 6 deletions adelecv/ui/dashboard/callbacks/table_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from dash.exceptions import PreventUpdate

from adelecv.api.config import Settings
from adelecv.api.modification_models.export import ExportWeights
from adelecv.ui.dashboard.app import app
from adelecv.api.modification_models import ExportWeights, ConvertWeights
from adelecv.ui.dashboard.app import app, _task


@app.callback(
Expand All @@ -24,7 +24,8 @@ def export_weights(n_clicks, rows, derived_virtual_selected_rows):


@app.callback(
Input("convert-weights", "n_clicks"),
Output("download-converted-onnx", "data"),
Input("convert-weights-format-onnx", "n_clicks"),
State('stats-models-table', "derived_virtual_data"),
State('stats-models-table', "derived_virtual_selected_rows"),
prevent_initial_call=True
Expand All @@ -33,6 +34,10 @@ def convert_weights(n_clicks, rows, derived_virtual_selected_rows):
if not n_clicks:
raise PreventUpdate()

print('convert weights', derived_virtual_selected_rows)
print(rows)
# return ''
id_selected = {rows[i]['_id'] for i in derived_virtual_selected_rows}
img_shape = _task.img_shape[0], _task.img_shape[0], 3
zip_path = ConvertWeights(
img_shape, Settings.WEIGHTS_PATH
).run(id_selected, 'onnx')

return dcc.send_file(zip_path.as_posix())
16 changes: 13 additions & 3 deletions adelecv/ui/dashboard/components/table_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,23 @@ def table_models(df: pd.DataFrame) -> dbc.Container:
dbc.DropdownMenuItem(
"Export weights", id='export-weights'
),
dbc.DropdownMenuItem(
"Convert weights", id='convert-weights'
),
dbc.DropdownMenu(
label="Convert weights",
id='convert-weights',
children=[
dbc.DropdownMenuItem(
"ONNX", id='convert-weights-format-onnx'
),
],
color="secondary",
direction="end",
style={"margin-bottom": "1%"}
),
],
style={"margin-bottom": "1%"}
),
dcc.Download(id="download-weights"),
dcc.Download(id="download-converted-onnx"),
html.Div(
[
dash_table.DataTable(
Expand Down
9 changes: 9 additions & 0 deletions adelecv/ui/dashboard/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self):
self._dataset = None
self._tb = None
self._session_dataset = None
self._img_shape = None

def launch(self, env_path: Path | None):
if env_path is not None:
Expand Down Expand Up @@ -71,3 +72,11 @@ def _run_optimize(self) -> None:
@property
def stats_models(self) -> pd.DataFrame:
return self._stats_models

@property
def img_shape(self) -> list[int, int]:
return self._img_shape

@img_shape.setter
def img_shape(self, new_val: list[int, int]):
self._img_shape = new_val
1 change: 1 addition & 0 deletions adelecv/ui/dashboard/task/segmentation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ def load_dataset(
batch_size
)
self._create_dataset_session()
self.img_shape = img_size

0 comments on commit 3a449d6

Please sign in to comment.