torch_ort.rst

File metadata and controls

35 lines (20 loc) · 1.52 KB

Torch ORT Callback
==================

Torch ORT converts your model into an optimized ONNX graph, speeding up training and inference when running on NVIDIA or AMD GPUs. See the installation instructions here.

This is primarily useful when training a Transformer model. The ORT callback works when a single model is assigned to ``self.model`` within the LightningModule, as shown below.

Note

Not all Transformer models are supported. See this table for the supported models, and for branches containing fixes for certain models.

from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModel

from pl_bolts.callbacks import ORTCallback


class MyTransformerModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained('bert-base-cased')

    ...


model = MyTransformerModel()
trainer = Trainer(gpus=1, callbacks=ORTCallback())
trainer.fit(model)
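Conceptually, the callback's job is small: when training is set up, it replaces the module's ``self.model`` attribute with an ORT-optimized wrapper, and it raises an error if no ``self.model`` attribute exists. A minimal, dependency-free sketch of that pattern (all class names here are hypothetical stand-ins, not the actual ``pl_bolts`` implementation):

```python
class MisconfigurationError(Exception):
    """Raised when the LightningModule does not expose `self.model`."""


class ORTWrapper:
    # Hypothetical stand-in for the ORT-optimized module; the real
    # integration would convert `inner` to an optimized ONNX graph.
    def __init__(self, inner):
        self.inner = inner

    def __call__(self, *args, **kwargs):
        # Delegate to the wrapped model unchanged in this sketch.
        return self.inner(*args, **kwargs)


class SimpleORTCallback:
    # Sketch of the callback pattern: swap `pl_module.model` for an
    # optimized wrapper before training begins.
    def setup(self, trainer, pl_module, stage=None):
        if not hasattr(pl_module, "model"):
            raise MisconfigurationError(
                "ORT requires the LightningModule to define `self.model`"
            )
        pl_module.model = ORTWrapper(pl_module.model)
```

Because the wrapper replaces ``self.model`` in place, the rest of the LightningModule (forward, training step, optimizers) keeps working unchanged, which is why the callback only supports a single model attribute.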

For even easier setup and integration, have a look at our Lightning Flash integrations for :doc:`Text Classification <lightning_flash:text_classification_ort>`, :doc:`Translation <lightning_flash:translation_ort>` and :doc:`Summarization <lightning_flash:summarization_ort>`.