In [19]:
import os
import tempfile
import torch

import mlflow
import mlflow.pyfunc
from mlflow.tracking import MlflowClient

from utils.modelling import get_device

In [None]:
# Address of the MLflow server. Replace both placeholders before running.
ML_HOST = None  # replace
ML_PORT = None  # replace

# Fail fast: with the placeholders left in place the f-string below would
# silently evaluate to "http://None:None" and be registered as the tracking URI.
if ML_HOST is None or ML_PORT is None:
    raise ValueError(
        "ML_HOST and ML_PORT are still placeholders; set them to the MLflow "
        "server host and port before running this notebook."
    )

ML_SERVER = f"http://{ML_HOST}:{ML_PORT}"

# The same server is used for both run tracking and the model registry.
mlflow.set_tracking_uri(ML_SERVER)
mlflow.set_registry_uri(ML_SERVER)

In [21]:
# Client bound to the tracking/registry URIs configured above.
client = MlflowClient()

# Sanity check that we are talking to the expected server.
EXPECTED_EXPERIMENT_COUNT = 6

# `list_experiments` was removed in MLflow 2.0 in favour of `search_experiments`;
# prefer the newer API when the installed client has it, else use the old one.
_fetch_experiments = getattr(client, "search_experiments", None) or client.list_experiments
_experiment_count = len(_fetch_experiments())
assert _experiment_count == EXPECTED_EXPERIMENT_COUNT, (
    f"expected {EXPECTED_EXPERIMENT_COUNT} experiments on the MLflow server, "
    f"found {_experiment_count}"
)

In [22]:
# Print (name, version, run_id) for every registered version of the model.
version_filter = "name='baseline_lstm'"
for model_version in client.search_model_versions(version_filter):
    summary = (model_version.name, model_version.version, model_version.run_id)
    print(summary)

('baseline_lstm', '1', 'a8efecd449054715aa9bc3500964fcbb')
('baseline_lstm', '2', '3f863e1dae724de69098369c0ab80e70')
('baseline_lstm', '3', 'd97a6bba3a52478eb956382b2d0a9f9b')
('baseline_lstm', '4', 'b8bed3ea0cd94b2e8fd066c51ae1d841')


In [23]:
def load_torchscript_model(run_id: str, src: str = ".", use_cuda: bool = False) -> torch.nn.Module:
    """Download a run's artifacts from MLflow and load its TorchScript model.

    The artifacts under ``src`` are downloaded through the notebook-level
    ``client`` into a temporary directory, and the file ``model.pt`` found
    there is loaded with ``torch.jit.load``. The temporary directory is
    removed when the function returns.

    Args:
        run_id: ID of the MLflow run that holds the model artifacts.
        src: Artifact path to download, relative to the run's artifact root.
            The downloaded location must contain ``model.pt``.
        use_cuda: If True and CUDA is available, map the model onto the GPU.
            Otherwise the target device is whatever ``get_device()`` returns,
            which is the behaviour this function always had.

    Returns:
        The loaded TorchScript module.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        local_artifacts = client.download_artifacts(run_id, src, tmpdir)
        model_path = os.path.join(local_artifacts, "model.pt")

        # `use_cuda` used to be accepted but never read. Honour an explicit
        # request for CUDA when possible; in every other case keep deferring
        # to the project helper so existing callers see no change.
        if use_cuda and torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = get_device()

        return torch.jit.load(model_path, map_location=device)

In [24]:
# Run ID of the highest registered version of the model (version 4 in the listing above).
RUN_ID = "b8bed3ea0cd94b2e8fd066c51ae1d841"

# Fetch that run's artifacts and load the TorchScript model from them.
model = load_torchscript_model(run_id=RUN_ID)

In [25]:
# Display the loaded module's repr to inspect its submodules.
model

RecursiveScriptModule(
  original_name=LSTM
  (lstm1): RecursiveScriptModule(original_name=LSTMCell)
  (lstm2): RecursiveScriptModule(original_name=LSTMCell)
  (linear): RecursiveScriptModule(original_name=Linear)
)