
Commit

add conversion to onnx in api
AsakoKabe committed May 2, 2023
1 parent 34b1635 commit b7be167
Showing 8 changed files with 160 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -101,7 +101,7 @@ The Dash library was used for dashboard. It is based on components and callbacks
 ### 📝 Citing
 ```
 @misc{Mamatin:2023,
-  Author = {Denis Mamtin},
+  Author = {Denis Mamatin},
   Title = {AdeleCV},
   Year = {2023},
   Publisher = {GitHub},
2 changes: 1 addition & 1 deletion adelecv/api/models/base/base_model.py
@@ -60,7 +60,7 @@ def save_weights(self) -> None:
         if not os.path.exists(path):
             os.mkdir(path)
         save_path = path / f'{self._id}.pt'
-        torch.save(self._torch_model, save_path)
+        torch.save(self._torch_model.cpu(), save_path)
         get_logger().debug(
             "Save weights model: %s, path: %s", str(self), save_path.as_posix()
         )
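
Moving the model to CPU before `torch.save` presumably lets the full-module checkpoint be opened on machines without CUDA. A minimal loading sketch; the checkpoint path and id below are made-up placeholders (the real files live under `Settings.WEIGHTS_PATH` as `<id>.pt`):

```python
import torch

# Hypothetical checkpoint written by the save_weights() method above.
checkpoint = "weights/3f2c9b.pt"

# map_location="cpu" is just a guard; with the .cpu() call above the stored
# tensors are already CPU tensors and load fine without a GPU.
model = torch.load(checkpoint, map_location="cpu")
model.eval()
```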
4 changes: 3 additions & 1 deletion adelecv/api/modification_models/__init__.py
@@ -1,5 +1,7 @@
+from .convert import ConvertWeights
 from .export import ExportWeights
 
 __all__ = [
-    "ExportWeights"
+    "ExportWeights",
+    "ConvertWeights"
 ]
109 changes: 109 additions & 0 deletions adelecv/api/modification_models/convert.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import os
+import zipfile
+from pathlib import Path
+from uuid import uuid4
+
+import torch
+
+from adelecv.api.config import Settings
+from adelecv.api.logs import get_logger
+
+from .conveter import TorchToOnnx
+
+
+def _create_zip(
+        converted_weights_path: Path,
+        zip_path: Path
+) -> None:
+    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
+        for entry in converted_weights_path.rglob("*"):
+            zip_file.write(
+                entry, entry.relative_to(
+                    converted_weights_path
+                )
+            )
+    get_logger().info(
+        "Create zip with converted weights, path: %s", zip_path.as_posix()
+    )
+
+
+class ConvertWeights:
+    """
+    Class for converting weights.
+
+    :param img_shape: Input image shape in HxWxC format.
+    :param weights_path: Path to saved weights.
+    """
+
+    def __init__(
+            self,
+            img_shape: list[int] | tuple[int],  # HxWxC
+            weights_path: Path = Settings.WEIGHTS_PATH,
+    ):
+        if len(img_shape) != 3:
+            raise ValueError("Input shape must be in the format HxWxC")
+
+        # BxCxHxW
+        self._input_shape = (1, img_shape[2], img_shape[0], img_shape[1])
+        self._weights_path = weights_path
+        self._supported_formats = ['onnx']
+        self._converter = {
+            # the converter builds its dummy input from the 4D BxCxHxW shape
+            'onnx': TorchToOnnx(self._input_shape)
+        }
+
+    def run(
+            self,
+            id_selected: None | set[str] | list[str] = None,
+            new_format: None | str = None
+    ) -> Path:
+        """
+        Convert the selected models to the specified format.
+
+        :param new_format: Target weights format for the conversion.
+        :param id_selected: List of model ids from stats_models.
+        :return: Path to the created zip file with the converted weights.
+        """
+
+        if new_format not in self.supported_formats:
+            raise ValueError(
+                f"{new_format} format is not supported for conversion. "
+                f"Supported formats: {self.supported_formats}"
+            )
+
+        id_convert = uuid4().hex
+        path_to_save = self.weights_path.parent / f'converted_{id_convert}'
+        os.mkdir(path_to_save.as_posix())
+        path_to_zip = self.weights_path.parent / f'converted_{id_convert}.zip'
+
+        for id_model in id_selected:
+            get_logger().info(
+                "Convert weights model: %s to %s format",
+                id_model, new_format
+            )
+            path_weights = self.weights_path / f'{id_model}.pt'
+            torch_model = torch.load(path_weights)
+            torch_model.eval()
+            self._convert(torch_model, new_format, id_model, path_to_save)
+        _create_zip(path_to_save, path_to_zip)
+
+        return path_to_zip
+
+    def _convert(
+            self,
+            torch_model: torch.nn.Module,
+            new_format: str,
+            id_model: str,
+            path_to_save: Path
+    ) -> None:
+        path_to_save_weights = path_to_save / f'{new_format}_{id_model}'
+        self._converter[new_format].convert(torch_model, path_to_save_weights)
+
+    @property
+    def supported_formats(self) -> list[str]:
+        return self._supported_formats
+
+    @property
+    def weights_path(self) -> Path:
+        return self._weights_path
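
A rough usage sketch of the new API; the image shape, weights directory, and model ids below are illustrative placeholders rather than values from this commit:

```python
from pathlib import Path

from adelecv.api.modification_models import ConvertWeights

# Hypothetical inputs: 256x256 RGB images and two model ids from stats_models.
converter = ConvertWeights(img_shape=(256, 256, 3), weights_path=Path("weights"))
zip_path = converter.run(id_selected=["1a2b3c", "4d5e6f"], new_format="onnx")
print(zip_path)  # .../converted_<uuid>.zip with one exported file per selected model
```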
44 changes: 44 additions & 0 deletions adelecv/api/modification_models/conveter.py
@@ -0,0 +1,44 @@
+import abc
+from pathlib import Path
+
+import torch
+
+
+class BaseConverter(abc.ABC):
+    def __init__(self, input_shape):
+        self._dummy_input = torch.zeros(input_shape)
+
+    @property
+    def dummy_input(self):
+        return self._dummy_input
+
+    @abc.abstractmethod
+    def convert(
+            self,
+            torch_model: torch.nn.Module,
+            path_to_save_weights: Path
+    ) -> None:
+        pass
+
+
+class TorchToOnnx(BaseConverter):
+    def convert(
+            self,
+            torch_model: torch.nn.Module,
+            path_to_save_weights: Path
+    ) -> None:
+        torch.onnx.export(
+            torch_model,               # model being run
+            self.dummy_input,          # model input (or a tuple for multiple inputs)
+            path_to_save_weights,      # where to save the model (file or file-like object)
+            export_params=True,        # store the trained parameter weights inside the model file
+            opset_version=10,          # the ONNX opset version to export the model to
+            do_constant_folding=True,  # execute constant folding for optimization
+            input_names=['input'],     # the model's input names
+            output_names=['output'],   # the model's output names
+            dynamic_axes={'input': {0: 'batch_size'},   # variable length axes
+                          'output': {0: 'batch_size'}}
+        )
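
To sanity-check an exported file one could load it with ONNX Runtime. A hedged sketch, assuming `onnxruntime` and `numpy` are installed, the path and input shape below (both illustrative) match the export, and the `input`/`output` names set by `TorchToOnnx`:

```python
import numpy as np
import onnxruntime as ort

# Hypothetical path to a file written by TorchToOnnx.convert().
sess = ort.InferenceSession("converted_abc123/onnx_1a2b3c")

# The batch axis is dynamic ('batch_size'); the remaining dims must match the dummy input.
dummy = np.zeros((1, 3, 256, 256), dtype=np.float32)
outputs = sess.run(["output"], {"input": dummy})
print(outputs[0].shape)
```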
Empty file.
5 changes: 1 addition & 4 deletions adelecv/ui/dashboard/index.py
@@ -69,10 +69,7 @@ def main(envfile: str = './.env') -> None:
             '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
         )
     )
-    app.run(
-        port=Settings.DASHBOARD_PORT,
-        debug=False
-    )
+    app.run()
 
 
 if __name__ == "__main__":
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "adelecv"
-version = "0.0.1"
+version = "0.0.2"
 authors = ["Denis Mamatin <mamatin-denis@yandex.ru>"]
 description = ""
 readme = "README.md"
