In [1]:
"""Stages a model for use in production.

If based on a checkpoint, the model is converted to torchscript, saved locally,
and uploaded to W&B.

If based on a model that is already converted and uploaded, the model file is downloaded locally.

For details on how the W&B artifacts backing the checkpoints and models are handled,
see the documenation for stage_model.find_artifact.
"""
import argparse
from pathlib import Path
import tempfile

import torch
import wandb

#from text_recognizer.lit_models import TransformerLitModel

In [4]:
from training.util import setup_data_and_model_from_args

In [6]:
# these names are all set by the pl.loggers.WandbLogger
MODEL_CHECKPOINT_TYPE = "model"
BEST_CHECKPOINT_ALIAS = "best"
MODEL_CHECKPOINT_PATH = "model.ckpt"
LOG_DIR = Path("training") / "logs"

STAGED_MODEL_TYPE = "prod-ready"  # we can choose the name of this type, and ideally it's different from checkpoints
STAGED_MODEL_FILENAME = "model.pt"  # standard nomenclature; pytorch_model.bin is also used

PROJECT_ROOT = Path('.').resolve().parents[1]

In [8]:
api = wandb.Api()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [9]:
DEFAULT_ENTITY = api.default_entity
DEFAULT_FROM_PROJECT = "deepconc"
DEFAULT_TO_PROJECT = "deepconc"
DEFAULT_STAGED_MODEL_NAME = "diseaseclassifier"

PROD_STAGING_ROOT = PROJECT_ROOT / "tomatodiagnosis"/ "diseaseclassifier" / "artifacts"

In [13]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--staged_model_name', type=str, default='my_model')

args = parser.parse_args()

usage: ipykernel_launcher.py [-h] [--staged_model_name STAGED_MODEL_NAME]
ipykernel_launcher.py: error: unrecognized arguments: -f /storage/cfg/.local/share/jupyter/runtime/kernel-873b12b8-9d10-4862-8d72-c41a92bd955d.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [14]:
prod_staging_directory = PROD_STAGING_ROOT / args.staged_model_name
prod_staging_directory.mkdir(exist_ok=True, parents=True)
entity = _get_entity_from(args)
# if we're just fetching an already compiled model
if args.fetch:
     # find it and download it
    staged_model = f"{entity}/{args.from_project}/{args.staged_model_name}:latest"
    artifact = download_artifact(staged_model, prod_staging_directory)
    print_info(artifact)

NameError: name 'args' is not defined

In [None]:
# otherwise, we'll need to download the weights, compile the model, and save it
with wandb.init(
    job_type="stage", project=args.to_project, dir=LOG_DIR
):  # log staging to W&B so prod and training are connected
    # find the model checkpoint and retrieve its artifact name and an api handle
    ckpt_at, ckpt_api = find_artifact(
        entity, args.from_project, type=MODEL_CHECKPOINT_TYPE, alias=args.ckpt_alias, run=args.run
    )

    # get the run that produced that checkpoint
    logging_run = get_logging_run(ckpt_api)
    print_info(ckpt_api, logging_run)
    metadata = get_checkpoint_metadata(logging_run, ckpt_api)

    # create an artifact for the staged, deployable model
    staged_at = wandb.Artifact(args.staged_model_name, type=STAGED_MODEL_TYPE, metadata=metadata)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # download the checkpoint to a temporary directory
        download_artifact(ckpt_at, tmp_dir)
        # reload the model from that checkpoint
        model = load_model_from_checkpoint(metadata, directory=tmp_dir)
        # save the model to torchscript in the staging directory
        save_model_to_torchscript(model, directory=prod_staging_directory)

    # upload the staged model so it can be downloaded elsewhere
    upload_staged_model(staged_at, from_directory=prod_staging_directory)

In [None]:
def main(args):
    prod_staging_directory = PROD_STAGING_ROOT / args.staged_model_name
    prod_staging_directory.mkdir(exist_ok=True, parents=True)
    entity = _get_entity_from(args)
    # if we're just fetching an already compiled model
    if args.fetch:
        # find it and download it
        staged_model = f"{entity}/{args.from_project}/{args.staged_model_name}:latest"
        artifact = download_artifact(staged_model, prod_staging_directory)
        print_info(artifact)
        return  # and we're done