In [1]:
from neuron_padded_generation import NeuronBlipForQuestionAnswering, TextEncoderWrapper, VisionModelWrapper, DecoderPaddedGenerator
from transformers import BlipForQuestionAnswering , BlipProcessor
import torch
from PIL import Image
import os

# import numpy as np
# from io import BytesIO
import requests

import torch_neuronx

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory that receives the saved checkpoint/processor files and the traced
# *.pt artifacts; it is the `directory` argument later passed to trace().
NEURON_MODEL = "neuron_models"
# Directory created by trace() and referenced by the --logfile option in its
# default compiler_args.
LOG_DIR = "logs"
# Used by infer() as max_length both for padding/truncating the processor
# output and for model.generate(); also copied into each model's config.max_length.
max_decoder_length = 8                  # Maximum output token length

In [3]:
def _trace_component(name, get_module, example_inputs, directory, compiler_args):
    """Compile one sub-module with ``torch_neuronx.trace`` unless it is already saved.

    Args:
        name: Base name of the artifact; the traced module is stored as
            ``<directory>/<name>.pt``.
        get_module: Zero-argument callable returning the module to trace. It is
            only invoked when tracing is actually required, so ``model`` is not
            touched when the artifact already exists.
        example_inputs: Example tensor (or tuple of tensors) handed to
            ``torch_neuronx.trace``.
        directory: Folder holding the traced artifacts.
        compiler_args: Forwarded unchanged to ``torch_neuronx.trace``.
    """
    artifact = os.path.join(directory, f'{name}.pt')

    # skip trace if the model is already traced
    if os.path.isfile(artifact):
        print(f'Skipping {name}.pt')
        return

    print(f"Tracing {name}")
    traced = torch_neuronx.trace(get_module(), example_inputs, compiler_args=compiler_args)
    torch.jit.save(traced, artifact)


def trace(model, directory, compiler_args=f"--auto-cast-type fp16 --logfile {LOG_DIR}/log-neuron-cc.txt"):
    """Trace the three sub-modules of ``model`` and load the traced model.

    The text decoder, vision model and text encoder are traced in that order
    and saved as ``text_decoder.pt``, ``vision_model.pt`` and
    ``text_encoder.pt`` inside ``directory``. A component whose file already
    exists is skipped, so re-running only compiles what is missing.

    Args:
        model: Model exposing ``text_decoder.decoder.decoder``,
            ``vision_model.model`` and ``text_encoder.model`` (the attributes
            that are traced).
        directory: Folder for the traced artifacts; created if missing. It is
            also what ``NeuronBlipForQuestionAnswering.from_pretrained`` loads from.
        compiler_args: Compiler options forwarded to ``torch_neuronx.trace``.
            The default is built once, at definition time, from ``LOG_DIR``.

    Returns:
        The first element of
        ``NeuronBlipForQuestionAnswering.from_pretrained(directory, 1)``, or
        ``None`` (after printing a message) when ``directory`` is an existing file.
    """
    if os.path.isfile(directory):
        print(f"Provided path ({directory}) should be a directory, not a file")
        return

    os.makedirs(directory, exist_ok=True)
    os.makedirs(LOG_DIR, exist_ok=True)

    # NOTE(review): the length 8 below presumably mirrors max_decoder_length
    # (the padded question length used by infer) and 20 the text decoder's own
    # max_length - confirm before changing either value.
    _trace_component(
        'text_decoder',
        lambda: model.text_decoder.decoder.decoder,
        (
            torch.ones((1, 20), dtype=torch.int64),
            torch.ones((1, 8, 768), dtype=torch.float32),
            torch.ones((1, 8), dtype=torch.int64),
            torch.ones((1, 20), dtype=torch.int64),
            torch.tensor([3]),
        ),
        directory,
        compiler_args,
    )

    # The vision model receives a single bare tensor (not a 1-tuple), exactly
    # as before.
    _trace_component(
        'vision_model',
        lambda: model.vision_model.model,
        torch.ones((1, 3, 384, 384), dtype=torch.float32),
        directory,
        compiler_args,
    )

    _trace_component(
        'text_encoder',
        lambda: model.text_encoder.model,
        (
            torch.ones((1, 8), dtype=torch.int64),
            torch.ones((1, 8), dtype=torch.int64),
            torch.ones((1, 577, 768), dtype=torch.float32),
            torch.ones((1, 577), dtype=torch.int64),
        ),
        directory,
        compiler_args,
    )

    traced_model = NeuronBlipForQuestionAnswering.from_pretrained(directory, 1)[0]

    return traced_model


In [4]:
def infer(model, processor, text, image):
    """Run one visual-question-answering pass and print the decoded answer.

    Args:
        model: Object exposing ``generate(**inputs, max_length=...)``.
        processor: Callable that encodes ``(image, text)`` into model inputs and
            offers ``decode`` for the generated token ids.
        text: The question to ask about the image.
        image: The image the question refers to.

    Returns:
        None. The decoded answer is written to stdout.
    """
    # Truncate and pad to the fixed max length so that the token size is
    # compatible with the fixed-sized encoder (not necessary for pure CPU execution).
    encoded = processor(
        image,
        text,
        max_length=max_decoder_length,
        truncation=True,
        padding='max_length',
        return_tensors="pt",
    )
    generated_ids = model.generate(**encoded, max_length=max_decoder_length)
    # Only the first generated sequence is decoded.
    answer = processor.decode(generated_ids[0], skip_special_tokens=True)
    print(answer)

In [5]:
# Image and question used for every run below (CPU, padded CPU, traced).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the wait with a timeout and fail loudly on an HTTP error, instead of
# hanging indefinitely or handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

text = "Where is the pet?"

# Load the reference processor and model, and store both under NEURON_MODEL -
# the same directory that is later passed to trace(), which loads the traced
# model from it.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model_cpu = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
model_cpu.save_pretrained(NEURON_MODEL)
processor.save_pretrained(NEURON_MODEL)
model_cpu.config.max_length = max_decoder_length
model_cpu.eval()
# Remember the text decoder's own max_length; it is re-applied to the traced
# model's text decoder config further down.
text_decoder_max_length = model_cpu.text_decoder.config.max_length

# Baseline answer from the unmodified model, for comparison with the runs below.
print('Default CPU Results:')
infer(model_cpu, processor, text, image)
print()

# Second copy of the model whose text encoder, vision model and text decoder
# are replaced by the wrappers imported from neuron_padded_generation.
# NOTE(review): the wrappers presumably adapt each sub-module to fixed-size
# inputs for tracing - confirm against neuron_padded_generation.
model_padded = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
model_padded.text_encoder = TextEncoderWrapper(model_padded.text_encoder)
model_padded.vision_model = VisionModelWrapper(model_padded.vision_model)
model_padded.text_decoder = DecoderPaddedGenerator.from_model(model_padded.text_decoder)

model_padded.config.max_length = max_decoder_length
model_padded.eval()

# Run the wrapped model on CPU before tracing it, so its answer can be compared
# with the baseline above.
print('Padded CPU Results:')
infer(model_padded, processor, text, image)
print()

# Trace the wrapped sub-modules (components already saved in NEURON_MODEL are
# skipped) and load the resulting model.
traced_model = trace(model_padded, NEURON_MODEL)
# Apply the same generation limits used for the CPU runs; the text decoder gets
# back the max_length captured from the reference model.
traced_model.config.max_length = max_decoder_length
traced_model.text_decoder.config.max_length = text_decoder_max_length
print('Traced Results:')
infer(traced_model, processor, text, image)
print()

print("Complete!")


Default CPU Results:
couch

Padded CPU Results:
couch

Tracing text_decoder
....
Compiler status PASS
Tracing vision_model
...
Compiler status PASS
Tracing text_encoder
...
Compiler status PASS
Traced Results:
couch

Complete!
