In [2]:
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import torch
from PIL import Image
from typing import Optional, Dict, Any


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class DiagramCaptioner:
    def __init__(self, model_name: str = "Salesforce/instructblip-flan-t5-xl"):
        
        #check which device is being used
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on device: {self.device}")
        self.model = InstructBlipForConditionalGeneration.from_pretrained(model_name)
        self.processor = InstructBlipProcessor.from_pretrained(model_name)
        self.model.to(self.device)
        self.prompt ="""
        
        You are a helpful assistant that generates captions for diagrams. 
        That retains all the important information in the diagram. The caption should be concise and information rich. 
        The caption should be in English. It should be capable of replacing the diagram in a document.

        If it is a flowchart, 
            Describe this flowchart diagram by listing:
            1. Each step or decision point in order
            2. The flow direction between elements
            3. Any conditional branches or loops
            4. The start and end points
            
            Format your response as a clear description of the process flow.

        If it has arctectural elements:
            Describe this architecture diagram by identifying:
            1. All components and their names as shown
            2. The connections between components
            3. Any data flow or communication paths
            4. The overall system structure
            
            Provide a clear description of the architectural layout and relationships.
        """
    
    def caption_image(self,
                      image: Image.Image, #required type for PIL images
                      prompt: Optional[str] = None,
                      kwargs: Optional[Dict[str, Any]] = None) -> str:
        if prompt is None:
            prompt = self.prompt

        defualt_kwargs = {
            "do_sample": True,
            "num_beams": 5,
            "max_length": 512,
            "min_length": 1,
            "top_p": 0.9,
            "repetition_penalty": 1,
            "length_penalty": 1.0,
            "temperature": 1.4,
            }
        
        if kwargs:
            defualt_kwargs.update(kwargs)

        if image.mode != 'RGB':
            image = image.convert('RGB')
        try:
            inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)

            outputs = self.model.generate(**inputs, **defualt_kwargs)

            generated_caption = self.processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        finally:
            del inputs
            if 'outputs' in locals():
                del outputs
            
            # Clear GPU cache of unused tensors
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


        return generated_caption




In [4]:
captioner = DiagramCaptioner() 
res = captioner.caption_image(
    image=Image.open("Images/O-RAN.WG3.TS.E2AP-R004-v07.00 (1)/page41_img2.png").convert("RGB")
)

print(res)

Loading model on device: cuda


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  7.25it/s]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


If it is a flowchart, describe this flowchart diagram by listing: 1. Each step or decision point in order 2. The flow direction between elements 3. Any conditional branches or loops 4. The start and end points Format your response in a clear description of the process flow. If it has arctectural elements: Describe this architecture diagram by identifying: 1. All components and their names as shown 2. The connections between components 3. Any data flow or communication paths 4. The overall system structure Provide a clear description of the architectural layout and relationships.
