# In [1]:
# This script sets up an image-captioning server using a pre-trained vision encoder-decoder
# model (ViT image encoder + GPT-2 text decoder).
# The server generates captions for images provided via URL or local file path.

import io  # For wrapping downloaded bytes as a file object
import os  # For file path operations
from urllib.parse import urlparse  # For URL validation

import litserve as ls  # LitServe library for API server
import requests  # For making HTTP requests
import torch  # For using PyTorch capabilities
from PIL import Image  # To handle image loading
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2TokenizerFast

def is_valid_url(url, allowed_schemes=("http", "https")):
    """Return True if *url* is an absolute URL this script can download.

    A URL qualifies when it parses cleanly, names a host (network location), and
    its scheme is one of *allowed_schemes*. No path is required: an image may be
    served from a site root or selected by a query string
    (``https://example.com?id=1``).

    Args:
        url: Candidate value. Anything that is not a ``str`` is never a URL.
        allowed_schemes: Lower-case schemes to accept. The default is the pair
            ``requests.get`` handles out of the box.

    Returns:
        bool: True for an acceptable URL, False otherwise.
    """
    if not isinstance(url, str):
        # urlparse would try to .decode() non-str input and raise AttributeError.
        return False
    try:
        parsed = urlparse(url)
    except ValueError:  # e.g. a malformed IPv6 host such as "http://[::1"
        return False
    # urlparse lower-cases the scheme, so "HTTP://..." is accepted too.
    return parsed.scheme in allowed_schemes and bool(parsed.netloc)

def get_image(image_source, timeout=10.0):
    """Load an image from an HTTP(S) URL or a local file path.

    The pixel data is decoded before returning. The caller therefore gets a
    self-contained image, and corrupt or non-image data fails here rather than
    later inside the model call.

    Args:
        image_source: A URL accepted by ``is_valid_url``, or a path to an
            existing local file.
        timeout: Seconds ``requests`` may wait on connect/read before giving
            up. Used only for URLs.

    Returns:
        PIL.Image.Image: The loaded image, in whatever mode the source uses.

    Raises:
        ValueError: If *image_source* is neither a valid URL nor an existing file.
        requests.RequestException: If the download fails, times out, or returns
            an HTTP error status.
        PIL.UnidentifiedImageError: If the data is not a recognized image format.

    NOTE(review): when called from the API, *image_source* comes straight from
    the client. The server will fetch any URL it can reach and open any file it
    can read. Consider an allow-list before exposing this publicly.
    """
    if is_valid_url(image_source):
        # Read the whole body (rather than .raw) so the connection is released
        # and any gzip/deflate Content-Encoding is decoded.
        response = requests.get(image_source, timeout=timeout)
        # Surface 4xx/5xx as an HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
    elif isinstance(image_source, (str, bytes, os.PathLike)) and os.path.isfile(image_source):
        # The type guard keeps an integer from being treated as a file descriptor.
        image = Image.open(image_source)
    else:
        raise ValueError("Invalid image path or URL.")
    # Image.open is lazy: force decoding now. For single-frame images opened by
    # filename, this also lets Pillow close the underlying file.
    image.load()
    return image

class CaptioningAPI(ls.LitAPI):
    """LitServe API that captions an image referenced by URL or local file path.

    Request payload: a mapping with an ``"image_path"`` key holding an HTTP(S)
    URL or a server-side file path.
    Response payload: ``{"generated_caption": <str>}``.

    LitServe drives a ``LitAPI`` through the hooks ``setup`` ->
    ``decode_request`` -> ``predict`` -> ``encode_response``. Those hooks are
    implemented below. Each delegates to a descriptively named method
    (``initialize`` / ``parse_request`` / ``generate_caption`` /
    ``format_response``), and those methods remain callable directly.

    NOTE(review): ``image_path`` is used as given, so a client can make the
    server fetch arbitrary URLs or read any image file it has access to.
    """

    # Hugging Face model id: ViT image encoder + GPT-2 text decoder.
    MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

    # ----- LitServe hooks --------------------------------------------------

    def setup(self, device):
        """Load the model onto the device LitServe assigned to this worker."""
        self.initialize(device)

    def decode_request(self, request):
        """Extract the image source from the incoming request payload."""
        return self.parse_request(request)

    def predict(self, x):
        """Return a caption for the image source *x*."""
        return self.generate_caption(x)

    def encode_response(self, output):
        """Wrap the caption in the response payload."""
        return self.format_response(output)

    # ----- Implementation --------------------------------------------------

    def initialize(self, compute_device=None):
        """Load the captioning model, tokenizer and image processor.

        Args:
            compute_device: Device to run on, as accepted by ``Module.to``
                (LitServe supplies it through ``setup``). When None, use CUDA
                if available, otherwise the CPU.
        """
        if compute_device is None:
            compute_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = compute_device

        self.caption_model = VisionEncoderDecoderModel.from_pretrained(self.MODEL_NAME).to(self.device)
        self.caption_tokenizer = GPT2TokenizerFast.from_pretrained(self.MODEL_NAME)
        self.image_processor = ViTImageProcessor.from_pretrained(self.MODEL_NAME)

    def parse_request(self, client_request):
        """Return ``client_request["image_path"]``, or ``""`` when the key is absent."""
        return client_request.get("image_path", "")

    def generate_caption(self, image_source):
        """Generate a caption for the image at *image_source* (URL or file path).

        Returns:
            str: The decoded caption, without special tokens or surrounding
            whitespace.
        """
        image = get_image(image_source)
        # The processor/model pipeline is set up for 3-channel RGB input.
        # Grayscale, palette and RGBA images must be converted first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        pixel_inputs = self.image_processor(image, return_tensors="pt").to(self.device)
        output_ids = self.caption_model.generate(**pixel_inputs)

        # One image in, so decode the first (only) generated sequence.
        return self.caption_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

    def format_response(self, generated_caption):
        """Wrap *generated_caption* in the JSON-serializable response dict."""
        return {"generated_caption": generated_caption}

# Script entry point: serve CaptioningAPI over HTTP.
if __name__ == "__main__":
    # A single worker on a single automatically chosen accelerator, on port 8000.
    server = ls.LitServer(
        CaptioningAPI(),
        accelerator="auto",
        devices=1,
        workers_per_device=1,
    )
    server.run(port=8000)


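# Output of running this cell: ModuleNotFoundError: No module named 'litserve'
# (the LitServe package is not installed in this environment; install it with `pip install litserve`).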