From 237d5d966e5d1e54b96bbce1fb6fbe6c1e1eb2eb Mon Sep 17 00:00:00 2001 From: Gabriele Picco Date: Wed, 12 Oct 2022 15:50:02 +0100 Subject: [PATCH] :recycle: Download smxm from the hugginface hub Signed-off-by: Gabriele Picco --- zshot/linker/linker_smxm.py | 14 +++++--------- zshot/linker/smxm/data.py | 2 +- zshot/linker/smxm/utils.py | 25 +------------------------ 3 files changed, 7 insertions(+), 34 deletions(-) diff --git a/zshot/linker/linker_smxm.py b/zshot/linker/linker_smxm.py index d605793..c2bc5c0 100644 --- a/zshot/linker/linker_smxm.py +++ b/zshot/linker/linker_smxm.py @@ -5,25 +5,21 @@ from torch.utils.data import DataLoader from transformers import BertTokenizerFast -from zshot.config import MODELS_CACHE_PATH from zshot.linker.linker import Linker from zshot.linker.smxm.data import ( ByDescriptionTaggerDataset, encode_data, tagger_multiclass_collator ) +from zshot.linker.smxm.model import BertTaggerMultiClass, device from zshot.linker.smxm.utils import ( SmxmInput, get_entities_names_descriptions, - load_model, predictions_to_span_annotations, ) from zshot.utils.data_models import Span -SMXM_MODEL_FILES_URL = ( - "https://ibm.box.com/shared/static/duni7p7i4gbk0prksc6zv5uahiemfy00.zip" -) -SMXM_MODEL_FOLDER_NAME = "BertTaggerMultiClass_config03_mode_tagger_multiclass_filtered_classes__entity_descriptions_mode_annotation_guidelines__per_gpu_train_batch_size_7/checkpoint" +MODEL_NAME = "gabriele-picco/smxm" class LinkerSMXM(Linker): @@ -47,9 +43,9 @@ def is_end2end(self) -> bool: def load_models(self): """ Load SMXM model """ if self.model is None: - self.model = load_model( - SMXM_MODEL_FILES_URL, MODELS_CACHE_PATH, SMXM_MODEL_FOLDER_NAME - ) + self.model = BertTaggerMultiClass.from_pretrained( + MODEL_NAME, output_hidden_states=True + ).to(device) def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: """ diff --git a/zshot/linker/smxm/data.py b/zshot/linker/smxm/data.py index 177d896..34a76e5 100644 --- a/zshot/linker/smxm/data.py +++ b/zshot/linker/smxm/data.py @@ -4,7 +4,7 @@ from torch.utils.data import Dataset from transformers import BertTokenizerFast -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +from zshot.linker.smxm.model import device class ByDescriptionTaggerDataset(Dataset): diff --git a/zshot/linker/smxm/utils.py b/zshot/linker/smxm/utils.py index 8fa005d..b31ccf3 100644 --- a/zshot/linker/smxm/utils.py +++ b/zshot/linker/smxm/utils.py @@ -1,16 +1,11 @@ -import os -import zipfile from typing import List, Tuple import torch from transformers import BertTokenizerFast +from zshot.linker.smxm.model import device from zshot.utils.data_models import Entity from zshot.utils.data_models import Span -from zshot.linker.smxm.model import BertTaggerMultiClass -from zshot.utils import download_file - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class SmxmInput(dict): @@ -34,24 +29,6 @@ def __init__( super().__init__(**config) -def load_model(url: str, output_path: str, folder_name: str) -> BertTaggerMultiClass: - filename = url.rsplit("/", 1)[1] - model_zipfile_path = os.path.join(output_path, filename) - model_folder_path = os.path.join(output_path, folder_name) - - if not os.path.isdir(model_folder_path): - download_file(url, output_path) - with zipfile.ZipFile(model_zipfile_path, "r") as model_zip: - model_zip.extractall(output_path) - os.remove(model_zipfile_path) - - model = BertTaggerMultiClass.from_pretrained( - model_folder_path, output_hidden_states=True - ).to(device) - - return model - - def predictions_to_span_annotations( sentences: List[str], predictions: List[List[int]],