Some of Facebook AI Research's ESM models are fairly large. One option for speeding up prediction is to convert the PyTorch model to ONNX and run inference with ONNX Runtime.
While the conversion should be possible using just the familiar torch.onnx.export() syntax, if you want to extract embeddings from the model, you'll most likely want to pass the repr_layers argument to the model's forward() method. Trying to pass this argument inside the export() call will cause an exception, because while ProteinBertModel.forward() expects a plain-old [], the export operation requires a JIT-compatible format such as torch.tensor([]).
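One possible way around this mismatch, shown here only as a sketch and not necessarily what the conversion script in this repository does, is to wrap the model in a small module that fixes repr_layers at construction time, so that export() only ever sees a tensor input. The names EmbeddingOnly and export_embeddings are hypothetical, and the sketch assumes forward() returns a dict with a "representations" entry keyed by layer index.

```python
import torch


class EmbeddingOnly(torch.nn.Module):
    """Hypothetical wrapper: hides repr_layers from the exporter."""

    def __init__(self, model: torch.nn.Module, repr_layer: int):
        super().__init__()
        self.model = model
        self.repr_layer = repr_layer  # chosen once, before export

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # forward() still receives the plain Python list it expects
        out = self.model(tokens, repr_layers=[self.repr_layer])
        # Assumption: the output dict holds per-layer embeddings here
        return out["representations"][self.repr_layer]


def export_embeddings(model, example_tokens, path, repr_layer):
    """Export only the embedding output of one layer to an ONNX file."""
    wrapper = EmbeddingOnly(model, repr_layer).eval()
    torch.onnx.export(
        wrapper,
        (example_tokens,),  # the token tensor is the only exported input
        path,
        input_names=["tokens"],
        output_names=["embeddings"],
        # Let batch size and sequence length vary at inference time
        dynamic_axes={
            "tokens": {0: "batch", 1: "length"},
            "embeddings": {0: "batch", 1: "length"},
        },
    )
```

The input and output names ("tokens", "embeddings") are arbitrary labels chosen for this sketch.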
Option 1: Using Docker
docker build -t esm-2-onnx -f Dockerfile.CUDA . # if you have a CUDA-capable machine with nvidia-container-runtime
docker run -it --gpus=0 esm-2-onnx:latest bash # running using just 1 GPU (also set inside the Dockerfile)
# or
docker build -t esm-2-onnx -f Dockerfile.CPU . # otherwise
docker run -it esm-2-onnx:latest bash
Consider adding a bind-mount argument to the docker run command in order to persist the model files: -v /mnt/models/esm1b:/mnt/models/esm1b
Option 2: Without using Docker
pip install -r requirements.CUDA.txt # if you have a CUDA-capable machine
# or
pip install -r requirements.CPU.txt # otherwise
mkdir -p $(dirname $CONVERTED_GRAPH_PATH) $(dirname $OPTIMIZED_GRAPH_PATH)
python src/ --model-path $MODEL_PATH --converted-model-path $CONVERTED_GRAPH_PATH
python -m onnxruntime_tools.optimizer_cli --float16 --opt_level 99 --use_gpu \
--model_type bert --hidden_size 1024 --num_heads 16 --input $CONVERTED_GRAPH_PATH \
--output $OPTIMIZED_GRAPH_PATH # convert to float 16 precision and apply all available optimizations
python src/ --model-path $OPTIMIZED_GRAPH_PATH
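For reference, loading a converted or optimized graph directly with ONNX Runtime might look roughly like the following. This is a minimal sketch rather than the code used by the scripts above: load_session and run_first_output are hypothetical helper names, and the sketch assumes the graph takes a single token input and that the first output is the one of interest.

```python
def load_session(model_path, use_gpu=False):
    """Create an ONNX Runtime session for the graph at model_path."""
    # Imported here so run_first_output works with any session-like object
    import onnxruntime as ort

    providers = ["CPUExecutionProvider"]
    if use_gpu:
        # Prefer CUDA when requested, keeping CPU as a fallback
        providers.insert(0, "CUDAExecutionProvider")
    return ort.InferenceSession(model_path, providers=providers)


def run_first_output(session, tokens):
    """Feed tokens to the graph's first input and return its first output."""
    input_name = session.get_inputs()[0].name
    # Passing None as the output list asks for every graph output
    return session.run(None, {input_name: tokens})[0]
```

Here tokens would typically be an int64 NumPy array of shape (batch, length) produced by the model's tokenizer.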
When comparing the prediction outputs of the original PyTorch model and the ONNX version, there seems to be a big loss of precision (significantly more pronounced than for the HuggingFace ProtBERT models once the optimizations are applied to the ESM model). Performing a few comparisons and judging the significance of the precision loss is crucial before adopting the ONNX version of the model for your workloads.
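Such a comparison could be sketched as follows. The function name compare_outputs and the choice of metrics are assumptions made for illustration; it expects two flat sequences of numbers of equal length, for example embeddings flattened from each model's output for the same input sequence.

```python
import math


def compare_outputs(reference, candidate):
    """Summarize how far candidate values drift from reference values."""
    ref = [float(x) for x in reference]
    cand = [float(x) for x in candidate]
    if not ref or len(ref) != len(cand):
        raise ValueError("outputs must be non-empty and of equal length")

    diffs = [abs(r - c) for r, c in zip(ref, cand)]
    dot = sum(r * c for r, c in zip(ref, cand))
    norm = math.sqrt(sum(r * r for r in ref)) * math.sqrt(sum(c * c for c in cand))
    return {
        "max_abs_diff": max(diffs),
        "mean_abs_diff": sum(diffs) / len(diffs),
        # Undefined when either vector is all zeros
        "cosine_similarity": dot / norm if norm else float("nan"),
    }
```

With NumPy arrays, passing array.flatten() for both arguments should work, since each element is converted to a Python float. What counts as an acceptable difference depends on the downstream task, so the thresholds are left to the reader.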