<a href="https://colab.research.google.com/github/ArkS0001/VQA--Visual-Question-Answering/blob/main/VQA_%403.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
# Update package lists and install Tesseract OCR (optional fallback)
!apt-get update
!apt-get install -y tesseract-ocr

# Install required Python packages
!pip install easyocr torch torchvision transformers timm


0% [Working]            Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
0% [Connecting to archive.ubuntu.com (185.125.190.81)] [Connecting to security.ubuntu.com] [Waiting                                                                                                     Hit:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:3 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:4 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:7 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists... Done
W: Skipping acquire of config

In [15]:
import logging
import re
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor, ViTModel, pipeline as hf_pipeline
import easyocr

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -----------------------------
# Visual Feature Extraction Module using ViT
# -----------------------------
class VisualFeatureExtractor(nn.Module):
    def __init__(self):
        super(VisualFeatureExtractor, self).__init__()
        # Load a pre-trained Vision Transformer model and its feature extractor.
        self.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_model.eval()  # set to evaluation mode

    def forward(self, image: Image.Image) -> torch.Tensor:
        # Convert PIL image to a numpy array
        image_np = np.array(image)
        # Preprocess image using ViTFeatureExtractor (handles resizing, normalization, etc.)
        inputs = self.feature_extractor(images=image_np, return_tensors="pt")
        with torch.no_grad():
            outputs = self.vit_model(**inputs)
            # Use the [CLS] token representation as global image feature.
            features = outputs.last_hidden_state[:, 0, :]  # shape: (batch, hidden_size)
        return features

# -----------------------------
# OCR Module using EasyOCR
# -----------------------------
class OCRModule:
    def __init__(self, languages=["en"]):
        self.reader = easyocr.Reader(languages, gpu=torch.cuda.is_available())

    def extract_text(self, image: Image.Image) -> str:
        try:
            # EasyOCR expects a numpy array (in BGR format); convert RGB to BGR
            image_np = np.array(image)[:, :, ::-1]
            result = self.reader.readtext(image_np, detail=0)
            return " ".join(result)
        except Exception as e:
            logger.error("Error during OCR extraction: %s", e)
            return ""

    def extract_numbers(self, text: str) -> list:
        numbers = re.findall(r'\d+\.\d+|\d+', text)
        try:
            return [float(num) for num in numbers]
        except Exception as e:
            logger.error("Error converting numbers: %s", e)
            return []

# -----------------------------
# Diagram Understanding Module
# -----------------------------
class DiagramUnderstandingModule:
    def classify_diagram(self, ocr_text: str) -> str:
        lower_text = ocr_text.lower()
        if "start" in lower_text and "end" in lower_text:
            return "flowchart"
        elif "gantt" in lower_text or "project" in lower_text or re.search(r'\d{1,2}/\d{1,2}/\d{2,4}', lower_text):
            return "gantt chart"
        elif "bar" in lower_text or "chart" in lower_text or "graph" in lower_text:
            return "chart"
        else:
            return "diagram"

    def extract_structure(self, diagram_type: str, ocr_text: str) -> dict:
        # Placeholder for more advanced layout analysis
        return {"diagram_type": diagram_type, "content": ocr_text.strip()}

# -----------------------------
# Text Generation Module using an Instruction-Tuned Model
# -----------------------------
class TextGenerationModule:
    def __init__(self):
        # Use a state-of-the-art instruction-tuned model for detailed responses.
        # Note: "google/flan-t5-xl" is a large model; if resources are limited, consider "google/flan-t5-base".
        self.generator = hf_pipeline("text2text-generation", model="google/flan-t5-xl", tokenizer="google/flan-t5-xl")

    def generate_answer(self, prompt: str) -> str:
        result = self.generator(prompt, max_length=512, truncation=True)
        return result[0]['generated_text']

# -----------------------------
# Detailed Reasoning Module
# -----------------------------
class DetailedReasoningModule:
    def __init__(self):
        self.ocr_module = OCRModule()
        self.diagram_module = DiagramUnderstandingModule()
        self.text_gen_module = TextGenerationModule()

    def reason(self, question: str, ocr_text: str) -> str:
        # Optionally extract numbers for quantitative analysis
        numbers = self.ocr_module.extract_numbers(ocr_text)
        diagram_type = self.diagram_module.classify_diagram(ocr_text)
        structure = self.diagram_module.extract_structure(diagram_type, ocr_text)

        # Construct a detailed prompt that includes extracted OCR text and diagram classification.
        prompt = (
            f"Below is text extracted from an image:\n\n"
            f"{structure['content']}\n\n"
            f"This content appears to represent a {structure['diagram_type']}. "
            f"Based on this, answer the following question in detail:\n\"{question}\"\n\n"
            f"Provide a comprehensive explanation that covers the layout, textual content, numerical data (if any), and overall structure."
        )
        detailed_answer = self.text_gen_module.generate_answer(prompt)
        return detailed_answer

# -----------------------------
# Main VQA Pipeline
# -----------------------------
class VQAPipeline:
    def __init__(self, device: str = "cpu"):
        self.device = device
        self.visual_extractor = VisualFeatureExtractor().to(self.device)
        self.ocr_module = OCRModule()
        self.reasoning_module = DetailedReasoningModule()
        # For possible image processing (if needed) we include a simple transform.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def process(self, image_path: str, question: str) -> str:
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            logger.error("Error loading image: %s", e)
            return "Error loading image."

        # Visual Feature Extraction (this feature vector can be used in advanced fusion models)
        visual_features = self.visual_extractor(image)
        logger.info("Extracted visual features shape: %s", visual_features.shape)

        # OCR Extraction
        ocr_text = self.ocr_module.extract_text(image)
        logger.info("Extracted OCR text: %s", ocr_text.strip())

        # Detailed Reasoning with Text Generation
        answer = self.reasoning_module.reason(question, ocr_text)
        return answer


In [16]:
# For file upload in Colab:
from google.colab import files
uploaded = files.upload()  # Upload your image file here

# Get the first uploaded image file name
image_path = list(uploaded.keys())[0]

# Define your question. Try variations such as:
# "What type of diagram is this and how is it structured?"
# "Explain the process flow and key elements in the diagram."
question = "What type of diagram is this and provide a detailed explanation of its contents?"

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Instantiate and run the VQA pipeline
pipeline = VQAPipeline(device=device)
answer = pipeline.process(image_path, question)
print("Predicted Answer:\n", answer)


Saving worked5.png to worked5 (7).png
Using device: cuda


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]



Progress: |██████████████████████████████████████████████████| 100.0% Complete



Progress: |██████████████████████████████████████████████████| 100.0% Complete

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.45G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Device set to use cuda:0


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor