♻️ Download smxm from the Hugging Face hub
Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>
GabrielePicco committed Oct 12, 2022
1 parent 3895466 commit 237d5d9
Showing 3 changed files with 7 additions and 34 deletions.
14 changes: 5 additions & 9 deletions zshot/linker/linker_smxm.py
@@ -5,25 +5,21 @@
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from zshot.config import MODELS_CACHE_PATH
from zshot.linker.linker import Linker
from zshot.linker.smxm.data import (
    ByDescriptionTaggerDataset,
    encode_data,
    tagger_multiclass_collator
)
from zshot.linker.smxm.model import BertTaggerMultiClass, device
from zshot.linker.smxm.utils import (
    SmxmInput,
    get_entities_names_descriptions,
    load_model,
    predictions_to_span_annotations,
)
from zshot.utils.data_models import Span

SMXM_MODEL_FILES_URL = (
"https://ibm.box.com/shared/static/duni7p7i4gbk0prksc6zv5uahiemfy00.zip"
)
SMXM_MODEL_FOLDER_NAME = "BertTaggerMultiClass_config03_mode_tagger_multiclass_filtered_classes__entity_descriptions_mode_annotation_guidelines__per_gpu_train_batch_size_7/checkpoint"
MODEL_NAME = "gabriele-picco/smxm"


class LinkerSMXM(Linker):
@@ -47,9 +43,9 @@ def is_end2end(self) -> bool:
    def load_models(self):
        """ Load SMXM model """
        if self.model is None:
            self.model = load_model(
                SMXM_MODEL_FILES_URL, MODELS_CACHE_PATH, SMXM_MODEL_FOLDER_NAME
            )
            self.model = BertTaggerMultiClass.from_pretrained(
                MODEL_NAME, output_hidden_states=True
            ).to(device)

    def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
        """
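For reference, the new loading path in load_models reduces to a single from_pretrained call against the Hub repository. A minimal standalone sketch, assuming BertTaggerMultiClass subclasses a transformers pretrained-model class (which is what makes Hub repo IDs and local caching work):

# Assumed import path, mirroring the diff above.
from zshot.linker.smxm.model import BertTaggerMultiClass, device

MODEL_NAME = "gabriele-picco/smxm"

# The first call downloads the checkpoint from the Hugging Face Hub and caches it
# locally; later calls reuse the cache instead of re-downloading a zip archive.
model = BertTaggerMultiClass.from_pretrained(
    MODEL_NAME, output_hidden_states=True
).to(device)
model.eval()  # inference-only usage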
2 changes: 1 addition & 1 deletion zshot/linker/smxm/data.py
@@ -4,7 +4,7 @@
from torch.utils.data import Dataset
from transformers import BertTokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from zshot.linker.smxm.model import device


class ByDescriptionTaggerDataset(Dataset):
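The data.py change replaces a locally built device constant with the one shared from zshot/linker/smxm/model.py. A sketch of the assumed shared definition, matching the line removed here:

# zshot/linker/smxm/model.py (sketch; only the shared constant is shown)
import torch

# A single module owns device selection; data.py, utils.py and linker_smxm.py
# import it instead of each recomputing its own copy.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Centralizing the constant keeps CPU/GPU placement consistent across the SMXM modules.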
25 changes: 1 addition & 24 deletions zshot/linker/smxm/utils.py
@@ -1,16 +1,11 @@
import os
import zipfile
from typing import List, Tuple

import torch
from transformers import BertTokenizerFast

from zshot.linker.smxm.model import device
from zshot.utils.data_models import Entity
from zshot.utils.data_models import Span
from zshot.linker.smxm.model import BertTaggerMultiClass
from zshot.utils import download_file

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SmxmInput(dict):
@@ -34,24 +29,6 @@ def __init__(
        super().__init__(**config)


def load_model(url: str, output_path: str, folder_name: str) -> BertTaggerMultiClass:
    filename = url.rsplit("/", 1)[1]
    model_zipfile_path = os.path.join(output_path, filename)
    model_folder_path = os.path.join(output_path, folder_name)

    if not os.path.isdir(model_folder_path):
        download_file(url, output_path)
        with zipfile.ZipFile(model_zipfile_path, "r") as model_zip:
            model_zip.extractall(output_path)
        os.remove(model_zipfile_path)

    model = BertTaggerMultiClass.from_pretrained(
        model_folder_path, output_hidden_states=True
    ).to(device)

    return model


def predictions_to_span_annotations(
    sentences: List[str],
    predictions: List[List[int]],
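Taken together, the three files drop the download-unzip-load helper entirely. A hypothetical smoke test of the new flow (default construction of LinkerSMXM is assumed; only load_models and the model attribute appear in the diff above):

from zshot.linker.linker_smxm import LinkerSMXM

linker = LinkerSMXM()      # assumed to be constructible with defaults
linker.load_models()       # fetches "gabriele-picco/smxm" from the Hub on first run
print(type(linker.model))  # expected: BertTaggerMultiClass

With no cache_dir passed to from_pretrained, the checkpoint now lands in the standard Hugging Face cache rather than under zshot's MODELS_CACHE_PATH, so it is shared with any other transformers model on the machine.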
