<a href="https://colab.research.google.com/github/AdopleAIOrg/Image-visual-question-and-answer/blob/main/Image_visual_question_and_answer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers
!pip install streamlit

In [None]:
import streamlit as st
import requests
from PIL import Image
import torch
from transformers import ViltProcessor, ViltForQuestionAnswering

class ImageVisulaQuestionAnswer:
    """Streamlit page that answers a free-text question about an image with ViLT.

    The page lets the user upload an image (a default image is downloaded
    when nothing is uploaded), type a question, and shows the answer label
    with the highest score from the model.
    """

    # Hugging Face checkpoint used for both the processor and the model.
    MODEL_NAME = "dandelin/vilt-b32-finetuned-vqa"

    # Image shown (and queried) when the user has not uploaded one.
    DEFAULT_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"

    # Seconds to wait for the default image before giving up; without a
    # timeout a stalled connection would hang the page indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, model_name: str = MODEL_NAME) -> None:
        """
        Load the pre-trained ViLT processor and question-answering model.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the processor and the model.
        """

        # Processor that encodes the (image, question) pair into model inputs
        self.processor = ViltProcessor.from_pretrained(model_name)

        # ViLT model with the Visual Question Answering head
        self.model = ViltForQuestionAnswering.from_pretrained(model_name)

    def _load_image(self, uploaded_file):
        """
        Return the image to query as an RGB PIL image.

        Args:
            uploaded_file: The value returned by ``st.file_uploader``; ``None``
                when the user has not uploaded anything.

        Returns:
            The uploaded image, or the default image downloaded from
            ``DEFAULT_IMAGE_URL`` when nothing was uploaded.

        Raises:
            requests.RequestException: If the default image cannot be
                downloaded (connection error, timeout, or HTTP error status).
        """

        if uploaded_file is not None:
            # PNG uploads may be RGBA, palette or greyscale; normalise to
            # three channels before handing the image to the processor.
            return Image.open(uploaded_file).convert("RGB")

        with requests.get(
            self.DEFAULT_IMAGE_URL, stream=True, timeout=self.REQUEST_TIMEOUT
        ) as response:
            # Fail with a clear HTTP error instead of letting PIL try to
            # decode an error page.
            response.raise_for_status()
            # convert() forces the pixel data to be read while the
            # connection is still open.
            return Image.open(response.raw).convert("RGB")

    def _get_question(self) -> tuple:
        """
        Displays the image uploader and question input fields and returns the image and the user's question.

        Returns:
            tuple: A tuple containing the PIL image (RGB) and the user's question.
        """

        st.title("Visual Question Answering with ViLT")

        # upload image (falls back to the default image when nothing is uploaded)
        uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
        image = self._load_image(uploaded_file)

        # ask question
        question = st.text_input("Enter your question")

        return image, question

    def _predict(self, image, question: str) -> str:
        """
        Run the ViLT model on one image/question pair.

        Args:
            image: The PIL image to ask about.
            question: The non-empty question text.

        Returns:
            The answer label whose logit is highest.
        """

        # Encode the image and question using the ViLT processor
        encoding = self.processor(image, question, return_tensors="pt")

        # Inference only: skip autograd bookkeeping
        with torch.no_grad():
            outputs = self.model(**encoding)

        # Index of the highest-scoring answer class
        idx = outputs.logits.argmax(-1).item()

        # Map the class index to its answer label
        return self.model.config.id2label[idx]

    def get_answer(self) -> None:
        """
        Gets the user's question and displays the answer using the ViLT model.

        When no question has been entered, a warning is shown instead and
        the model is not run. The image is displayed in both cases.
        """

        # Get the image and user's question from the _get_question method
        image, question = self._get_question()

        if question:
            # Display the predicted answer
            st.write("Answer:", self._predict(image, question))
        else:
            # If no question is provided, display a warning
            st.warning("Please enter a question")

        # Show the image in the Streamlit app
        st.image(image, caption="Image", use_column_width=True)

if __name__ == "__main__":
    # Script entry point: construct the app object (its constructor loads the
    # ViLT processor and model) and render the question-answering page.
    ImageVisulaQuestionAnswer().get_answer()

In [None]:
!streamlit run app.py & npx localtunnel --port 8501