[Reference](https://chaimrand.medium.com/capturing-and-deploying-pytorch-models-with-torch-export-480f0d9ea8fd)

In [1]:
import torch

# Decoder vocabulary size (assigned to config.decoder.vocab_size in get_model).
NUM_TOKENS = 1024
# Upper bound on generated length: used as config.max_length and as the
# iteration limit of the decoding loop in image_to_text_generator.
MAX_SEQ_LEN = 256
# Special token ids wired into the model config in get_model:
# padding, decoder start token, and end-of-sequence respectively.
PAD_ID = 0
START_ID = 1
END_ID = 2

# Set up an image-to-text model.
def get_model():
    """Build the ViT-encoder / GPT-2-decoder image-to-text model.

    The model is constructed from configuration objects only; no pretrained
    checkpoint is loaded in this function. The decoder vocabulary size,
    special token ids and maximum length come from the module-level
    constants.

    Returns:
        A ``VisionEncoderDecoderModel`` with its encoder pooler removed,
        switched to evaluation mode.
    """
    # transformers is imported inside the function rather than at file level
    from transformers import (
        AutoConfig,
        VisionEncoderDecoderConfig,
        VisionEncoderDecoderModel,
    )

    vit_config = AutoConfig.for_model("vit")
    gpt2_config = AutoConfig.for_model("gpt2")
    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
        encoder_config=vit_config,
        decoder_config=gpt2_config,
    )

    # decoder-specific settings
    config.decoder.vocab_size = NUM_TOKENS
    config.decoder.use_cache = False

    # special tokens and generation length limit
    config.decoder_start_token_id = START_ID
    config.pad_token_id = PAD_ID
    config.eos_token_id = END_ID
    config.max_length = MAX_SEQ_LEN

    model = VisionEncoderDecoderModel(config=config)
    # the pooler output is not used by the generation code in this file
    model.encoder.pooler = None
    # evaluation mode
    model.eval()
    return model

In [2]:
# generate the next token
def generate_token(decoder, encoder_hidden_states, sequence):
    """Greedily select the next token id for each row of ``sequence``.

    Args:
        decoder: Callable invoked as ``decoder(sequence, encoder_hidden_states)``;
            the first element of its result is indexed as a rank-3 tensor
            (presumably logits laid out as batch x position x vocabulary).
        encoder_hidden_states: Passed through to ``decoder`` unchanged.
        sequence: Token ids generated so far, passed through to ``decoder``.

    Returns:
        Tensor holding the argmax index over the last axis at the final
        position, with the reduced axis kept (one column per row).
    """
    decoder_out = decoder(sequence, encoder_hidden_states)
    # only the last position is needed to predict the next token
    last_step_logits = decoder_out[0][:, -1, :]
    return last_step_logits.argmax(dim=-1, keepdim=True)

# simple auto-regressive sequence generator
def image_to_text_generator(encoder, decoder, image):
    """Greedy auto-regressive caption generation for a batch of images.

    Every row starts with ``START_ID``. On each step the decoder is run on the
    ids generated so far and the arg-max token is appended. A row is
    considered finished once it has emitted ``END_ID``; from then on it is
    padded with ``PAD_ID`` instead of receiving further decoder predictions.
    Generation stops when every row has finished or after ``MAX_SEQ_LEN``
    steps, whichever comes first.

    Args:
        encoder: Callable applied to ``image``; the first element of its
            result is used as the encoder hidden states.
        decoder: Callable forwarded to ``generate_token``.
        image: Tensor whose first dimension is the batch size; generated ids
            are created on ``image.device``.

    Returns:
        ``torch.long`` tensor with one row per image and
        ``1 + number_of_steps`` columns (the leading column is ``START_ID``).
    """
    # run encoder
    encoder_hidden_states = encoder(image)[0]

    batch_size = image.shape[0]

    # initialize sequence with the start token
    generated_ids = torch.full(
        (batch_size, 1),
        START_ID,
        dtype=torch.long,
        device=image.device
    )
    # per-row flag: True once that row has produced END_ID
    finished = torch.zeros(
        (batch_size, 1),
        dtype=torch.bool,
        device=image.device
    )

    for _ in range(MAX_SEQ_LEN):
        # generate next token
        next_token = generate_token(
            decoder,
            encoder_hidden_states,
            generated_ids
        )
        # Rows that already ended must not keep growing with new predictions;
        # checking only the current step would also miss rows that ended on
        # different steps and never terminate early.
        next_token = next_token.masked_fill(finished, PAD_ID)
        finished = finished | (next_token == END_ID)
        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
        if finished.all():
            break

    return generated_ids

In [3]:
import os, time, random, torch

# Fix the torch and Python RNG seeds to 42.
torch.manual_seed(42)
random.seed(42)

# Number of random images fed through the model in test_inference.
BATCH_SIZE = 64
# Default directory for captured model artifacts (default `path` argument of
# capture_model / load_model and default `model_path` of test_inference).
EXPORT_PATH = '/tmp/export/'

def test_inference(model_path=EXPORT_PATH, mode=None, compile=False):
    """Load the model and time one batched caption-generation pass.

    A random batch of shape ``(BATCH_SIZE, 3, 224, 224)`` is created, moved to
    CUDA when available (CPU otherwise), and run through
    ``image_to_text_generator`` under bfloat16 autocast with gradients
    disabled. The elapsed wall-clock time is printed.

    Args:
        model_path: Directory passed to ``load_model``.
        mode: Loading mode passed to ``load_model``.
        compile: When true, wrap encoder and decoder with ``torch.compile``
            and run ten warm-up generations before timing.

    Returns:
        None. The timing is reported via ``print``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rnd_image = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
    encoder, decoder = load_model(model_path, mode)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    if compile:
        encoder = torch.compile(encoder, mode="reduce-overhead")
        decoder = torch.compile(decoder, dynamic=True)

    # optionally enable mixed precision
    with torch.amp.autocast(device, dtype=torch.bfloat16, enabled=True):
        with torch.no_grad():
            if compile:
                # Warm up inside the same autocast / no_grad context as the
                # timed run, so compilation happens for the mode that is
                # actually measured and no autograd state is accumulated.
                for _ in range(10):
                    image_to_text_generator(encoder, decoder, rnd_image)

            # CUDA kernels are launched asynchronously: drain pending work
            # before starting and before stopping the clock.
            if device == 'cuda':
                torch.cuda.synchronize()
            t0 = time.perf_counter()

            image_to_text_generator(encoder, decoder, rnd_image)

            if device == 'cuda':
                torch.cuda.synchronize()
            total_time = time.perf_counter() - t0

    print(f'batched inference total time: {total_time}')

In [4]:
class DecoderWrapper(torch.nn.Module):
    """Module that exposes a decoder through a two-argument positional call.

    The wrapped decoder is always invoked with caching, attention outputs,
    hidden-state outputs and dict-style returns switched off.
    """

    def __init__(self, decoder_model):
        """Store ``decoder_model`` as the wrapped decoder."""
        super().__init__()
        self.decoder = decoder_model

    def forward(self, input_ids, encoder_hidden_states):
        """Call the wrapped decoder and return its result unchanged.

        Args:
            input_ids: Forwarded as the ``input_ids`` keyword argument.
            encoder_hidden_states: Forwarded as the ``encoder_hidden_states``
                keyword argument.
        """
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            return_dict=False,
            use_cache=False,
            output_hidden_states=False,
            output_attentions=False,
        )
        return decoder_outputs

def load_model(path=EXPORT_PATH, mode=None):
    """Baseline loader: build a fresh model and split it into its two halves.

    ``path`` and ``mode`` are accepted but not used by this version.

    Returns:
        Tuple ``(encoder, decoder)`` where the decoder is wrapped in
        ``DecoderWrapper``.
    """
    full_model = get_model()
    vision_encoder = full_model.encoder
    wrapped_decoder = DecoderWrapper(full_model.decoder)
    return vision_encoder, wrapped_decoder

# Model Capturing and Deployment Strategies

In [5]:
def capture_model(model, path=EXPORT_PATH):
    """Save the model's state dict as ``weights.pth`` inside ``path``.

    Args:
        model: Object exposing ``state_dict()``.
        path: Target directory; created if it does not exist.
    """
    # torch.save does not create missing parent directories
    os.makedirs(path, exist_ok=True)

    # weights only; honour the caller-supplied directory rather than the
    # module-level default
    weights_path = os.path.join(path, "weights.pth")
    torch.save(model.state_dict(), weights_path)

def load_model(path=EXPORT_PATH, mode=None):
    """Build the model and optionally restore saved weights.

    Args:
        path: Directory containing ``weights.pth`` (read only when
            ``mode == 'weights'``).
        mode: ``'weights'`` loads the saved state dict into a freshly built
            model; any other value returns the freshly built model as is.

    Returns:
        Tuple ``(encoder, decoder)`` where the decoder is wrapped in
        ``DecoderWrapper``.
    """
    model = get_model()
    if mode == 'weights':
        weights_path = os.path.join(path, "weights.pth")
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered file cannot execute arbitrary code.
        state_dict = torch.load(
            weights_path,
            map_location="cpu",
            weights_only=True
        )
        model.load_state_dict(state_dict)
    return model.encoder, DecoderWrapper(model.decoder)

In [6]:
def capture_model(model, path=EXPORT_PATH):
    """Save the model's weights and TorchScript captures into ``path``.

    Three artifacts are attempted:
      * ``weights.pth`` - the full model state dict;
      * ``encoder.pt``  - the encoder traced with a ``(1, 3, 224, 224)``
        random example and frozen;
      * ``decoder.pt``  - the wrapped decoder scripted and frozen. This step
        is best-effort: any failure is printed and otherwise ignored.

    Args:
        model: Object exposing ``state_dict()``, ``encoder`` and ``decoder``.
        path: Target directory; created if it does not exist.
    """
    # neither torch.save nor torch.jit.save creates missing directories
    os.makedirs(path, exist_ok=True)

    # weights only; honour the caller-supplied directory rather than the
    # module-level default so all artifacts land in the same place
    weights_path = os.path.join(path, "weights.pth")
    torch.save(model.state_dict(), weights_path)

    encoder = model.encoder
    decoder = DecoderWrapper(model.decoder)

    # torchscript encoder using trace
    example = torch.randn(1, 3, 224, 224)
    encoder_jit = torch.jit.trace(encoder, example)
    # optionally apply jit.freeze optimization
    encoder_jit = torch.jit.freeze(encoder_jit)
    encoder_path = os.path.join(path, "encoder.pt")
    torch.jit.save(encoder_jit, encoder_path)

    try:
        # torchscript decoder using scripting
        decoder_jit = torch.jit.script(decoder)
        # optionally apply jit.freeze optimization
        decoder_jit = torch.jit.freeze(decoder_jit)
        decoder_path = os.path.join(path, "decoder.pt")
        torch.jit.save(decoder_jit, decoder_path)
    except Exception as e:
        # deliberate best-effort: report the failure and keep the artifacts
        # that were already written
        print(f'torch.jit.script(model.decoder) failed\n{e}')

def load_model(path=EXPORT_PATH, mode=None):
    """Load the encoder/decoder pair according to ``mode``.

    Args:
        path: Directory holding the captured artifacts.
        mode: ``'torchscript'`` loads ``encoder.pt`` and ``decoder.pt`` with
            ``torch.jit.load`` and passes both through
            ``torch.jit.optimize_for_inference``; ``'weights'`` builds the
            model and restores ``weights.pth``; any other value returns a
            freshly built model.

    Returns:
        Tuple ``(encoder, decoder)``. Outside the torchscript mode the decoder
        is wrapped in ``DecoderWrapper``.
    """
    if mode == 'torchscript':
        encoder_path = os.path.join(path, "encoder.pt")
        decoder_path = os.path.join(path, "decoder.pt")
        encoder = torch.jit.load(encoder_path)
        decoder = torch.jit.load(decoder_path)
        # optionally apply target-device optimization
        encoder = torch.jit.optimize_for_inference(encoder)
        decoder = torch.jit.optimize_for_inference(decoder)
        return encoder, decoder

    model = get_model()
    if mode == 'weights':
        weights_path = os.path.join(path, "weights.pth")
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered file cannot execute arbitrary code.
        state_dict = torch.load(
            weights_path,
            map_location="cpu",
            weights_only=True
        )
        model.load_state_dict(state_dict)
    return model.encoder, DecoderWrapper(model.decoder)