Skip to content
Merged
36 changes: 34 additions & 2 deletions medcat-trainer/webapp/api/api/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
import os
from smtplib import SMTPException
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any
import shutil

from background_task.models import Task, CompletedTask
from django.contrib.auth.views import PasswordResetView
Expand Down Expand Up @@ -579,11 +580,42 @@ def save_models(request):
project = ProjectAnnotateEntities.objects.get(id=p_id)
cat = get_medcat(project=project)

cat.cdb.save(project.concept_db.cdb_file.path)
if project.concept_db is not None:
# CDB / vocab based
cat.cdb.save(project.concept_db.cdb_file.path, overwrite=True)
else:
# ModelPack based project
_overwrite_model_pack(cat, project.model_pack.path)
Comment thread
alhendrickson marked this conversation as resolved.

return Response({'message': 'Models saved'})


def _overwrite_model_pack(cat, model_path: str):
# NOTE: cannot overwrite, so working around
Comment thread
alhendrickson marked this conversation as resolved.
# currently CAT.save_model_pack does not provide a method to
# allow overwriting an existing model pack
with TemporaryDirectory() as tmp_dir:
# making new folder name so that it's copied
# to the specific path rather than into the folder
temp_folder = os.path.join(tmp_dir, "model_copy")
shutil.move(model_path, temp_folder)
try:
cat.save_model_pack(
os.path.dirname(model_path),
pack_name=os.path.basename(model_path),
add_hash_to_pack_name=False)
except Exception as e:
logger.warning("Unable to save model pack. Restoring previous state. Issue while saving model:", exc_info=e)
if os.path.exists(model_path):
shutil.rmtree(model_path) # remove partial/corrupt output
# restore original
try:
shutil.move(temp_folder, model_path)
except Exception as restore_err:
logger.error("Failed to restore model pack:", exc_info=restore_err)
raise


@api_view(http_method_names=['POST'])
def get_create_entity(request):
label = request.data['label']
Expand Down
Loading