In [None]:
import torch
from transformers import ViltProcessor, ViltForQuestionAnswering
from PIL import Image

In [None]:
class ViLTInference:
    """Thin inference wrapper around a ViLT question-answering checkpoint.

    The processor and model are loaded once in ``__init__``; ``predict`` then
    answers a single question about a single image and reduces the model's
    output to a "Yes"/"No" string.
    """

    def __init__(self, model_path, device=None):
        """Load the processor and model and move the model to a device.

        Args:
            model_path: Value passed to ``from_pretrained`` for both
                ``ViltProcessor`` and ``ViltForQuestionAnswering``.
            device: Optional device spec accepted by ``torch.device``. When
                falsy, CUDA is selected if ``torch.cuda.is_available()``,
                otherwise CPU.
        """
        # Load the processor and model
        self.processor = ViltProcessor.from_pretrained(model_path)
        self.model = ViltForQuestionAnswering.from_pretrained(model_path)

        # Set the device
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model.to(self.device)
        # Inference only: switch the model to eval mode.
        self.model.eval()

    def predict(self, image_path, question):
        """Answer ``question`` about the image stored at ``image_path``.

        Args:
            image_path: Path (or other value accepted by ``Image.open``) of
                the image to load.
            question: Question text handed to the processor.

        Returns:
            "Yes" when the arg-max class index of the logits is 1, "No" for
            any other index, or ``None`` if any step raised an exception (the
            error is printed rather than propagated).
        """
        try:
            # Load the image inside a context manager so the underlying file
            # handle is released right away; convert() returns a separate
            # image object that stays usable after the source is closed.
            with Image.open(image_path) as source_image:
                image = source_image.convert("RGB")

            # Process the image and question
            inputs = self.processor(images=image, text=question, return_tensors="pt").to(self.device)

            # Forward pass without gradient tracking
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Get the predicted class index
            predicted_class_idx = outputs.logits.argmax(-1).item()

            # Map predicted class index to 'Yes' or 'No'.
            # NOTE(review): assumes a binary head where index 1 means "Yes";
            # confirm against the label mapping used during fine-tuning.
            answer = "Yes" if predicted_class_idx == 1 else "No"

            return answer

        except Exception as e:
            # Best-effort contract: report the failure and signal it with None.
            print(f"Error during inference: {e}")
            return None

In [None]:
if __name__ == "__main__":
    # Demo inputs: the question to ask, the image to ask it about, and the
    # checkpoint to load.
    question = "Is there any abnormality in the left lung?"
    image_path = "image.jpg"
    model_path = "fine-tuned-vilt-model"

    # Build the inference wrapper from the checkpoint
    vilt_inference = ViLTInference(model_path)

    # Run a single prediction for the image/question pair
    result = vilt_inference.predict(image_path, question)

    # Report the answer only when a prediction came back
    if result is not None:
        print(f"Predicted Answer: {result}")