# Video Scene Extraction with Image Captioning and Text Query

This project extends the functionality of video scene extraction by incorporating an image captioning model (BLIP) alongside the CLIP model for improved query-to-video relevance. It enables users to extract the most relevant scene from a video based on a text query, using both visual context and automatically generated captions for better scene understanding.

### Features
Scene Detection: Automatically detects scenes in a video using the scenedetect library.

Frame Extraction: Extracts frames from each detected scene for analysis.

Image-to-Text (I2T) Embedding: Uses the BLIP model to generate captions for each frame, improving the contextual understanding of video content.

Text-Image Similarity: Uses CLIP's text encoder to embed both the text query and the BLIP caption generated for each frame, then scores each scene by the similarity between the query and its frames' captions.

Scene Extraction: Saves the scene with the highest similarity to the query as a new video file.

Gradio Interface: A simple web interface to upload videos and input search queries for scene extraction.

In [1]:
# Setting up logging
import logging

# Named logger shared by the rest of the notebook.
logger = logging.getLogger('my_custom_logger')
logger.setLevel(logging.DEBUG)

# logging.getLogger returns the same object on every call, so re-running this
# cell would otherwise stack one more StreamHandler each time and every
# message would be printed repeatedly. Only attach a handler if none exists.
if not logger.handlers:
    ch = logging.StreamHandler()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)

    logger.addHandler(ch)

In [None]:
import cv2
import torch
from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import scenedetect
import logging
import gradio as gr
import datetime

import warnings
warnings.filterwarnings('ignore')


class VideoProcessor:
    """Find the scene of a video that best matches a text query.

    Scenes are detected with PySceneDetect, each frame of a scene is captioned
    with BLIP, and the captions are compared with the query using CLIP text
    embeddings. The best-scoring scene is written out as a new video file.
    """

    def __init__(self, device="cuda"):
        """
        Load the CLIP and BLIP checkpoints.

        Args:
            device: Device to run on when CUDA is available; "cpu" is used
                otherwise.
        """
        self.device = device if torch.cuda.is_available() else "cpu"

        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)

        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(self.device)

    def detect_scenes(self, video_path):
        """
        Detect scenes in the video using SceneDetect.

        Returns:
            The scene list produced by the SceneManager: one (start, end)
            pair per detected scene.
        """
        scene_manager = scenedetect.SceneManager()
        scene_manager.add_detector(scenedetect.detectors.ContentDetector())
        video = scenedetect.open_video(video_path)
        scene_manager.detect_scenes(video)
        return scene_manager.get_scene_list()

    def extract_frames(self, video_path, scene_start, scene_end):
        """
        Extract frames from a specific scene in the video.

        Args:
            video_path: Path of the video file.
            scene_start: Index of the first frame to read.
            scene_end: Index of the last frame to read (inclusive).

        Returns:
            A list of RGB PIL images; shorter than the requested range if the
            video ends (or cannot be read) before ``scene_end``.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            # Seek once; every read() then advances to the next frame, so a
            # seek per frame is unnecessary.
            cap.set(cv2.CAP_PROP_POS_FRAMES, scene_start)
            for _ in range(scene_start, scene_end + 1):
                ret, frame = cap.read()
                if not ret:
                    # No more decodable frames in a sequential read.
                    break
                # OpenCV delivers BGR; PIL and the models expect RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            cap.release()
        return frames

    def get_query_embedding(self, query):
        """
        Get the embedding of the query using CLIP.

        Returns:
            The tensor returned by ``get_text_features`` for the single query,
            or None if embedding fails (the error is logged).
        """
        try:
            # truncation=True keeps long queries within CLIP's context length
            # instead of failing inside the model.
            text_inputs = self.processor(
                text=[query], return_tensors="pt", padding=True, truncation=True
            ).to(self.device)
            with torch.no_grad():
                query_embedding = self.model.get_text_features(input_ids=text_inputs.input_ids)
            return query_embedding
        except Exception as ex:
            logger.error(f'CLIP query embeddings exception: {ex}')
            return None

    def compute_I2T_embeddings(self, frames):
        """
        Caption each frame with BLIP and embed the caption with CLIP.

        Frames whose captioning or embedding fails are skipped (and logged).

        Returns:
            The per-frame caption embeddings stacked along a new first axis.

        Raises:
            RuntimeError: If no frame produced an embedding.
        """
        query_embeddings = []
        for count, frame in enumerate(frames, start=1):
            try:
                inputs = self.blip_processor(frame, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    out = self.blip_model.generate(**inputs)
                query = self.blip_processor.decode(out[0], skip_special_tokens=True)

                logger.debug(f'frame NO. {count} - query: {query}')

                query_embedding = self.get_query_embedding(query)
                if query_embedding is None:
                    # get_query_embedding already logged the failure; a None
                    # entry would make torch.stack fail for the whole scene.
                    continue
                query_embeddings.append(query_embedding)
                logger.info(f'count: {count} end')
            except Exception as ex:
                logger.error(f'I2T embeddings exception: {ex}')
                continue

        if not query_embeddings:
            raise RuntimeError("No frame embeddings could be computed")
        return torch.stack(query_embeddings)

    def process_scene(self, frames_in_scene, query_embedding):
        """
        Score one scene against the query.

        Returns:
            The mean cosine similarity between the query embedding and the
            caption embedding of each frame, or -1 if the scene has no frames
            or scoring fails.
        """
        if not frames_in_scene:
            return -1
        try:
            scene_embeddings = self.compute_I2T_embeddings(frames_in_scene)
            # Reduce over the feature axis. The default dim=1 would reduce
            # over the singleton axis introduced by stacking the per-frame
            # embeddings and yield only +/-1 sign agreement per feature.
            similarities = torch.nn.functional.cosine_similarity(
                query_embedding, scene_embeddings, dim=-1
            )
            return similarities.mean().item()
        except Exception as e:
            logger.error(f"Error processing scene: {e}")
            return -1

    def extract_scene_from_video(self, video_path, scene_start, scene_end, output_video_path):
        """
        Copy frames ``scene_start``..``scene_end`` (inclusive) into a new video file.

        The output uses the 'mp4v' codec and the source's frame rate and size.

        Raises:
            RuntimeError: If the output video cannot be opened for writing.
        """
        cap = cv2.VideoCapture(video_path)
        output_video = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_size = (
                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            )
            output_video = cv2.VideoWriter(
                output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size
            )
            if not output_video.isOpened():
                raise RuntimeError(f"Could not open video writer for {output_video_path}")

            cap.set(cv2.CAP_PROP_POS_FRAMES, scene_start)
            # Write each frame as it is read so the scene is never held in
            # memory as a whole.
            for _ in range(scene_start, scene_end + 1):
                ret, frame = cap.read()
                if not ret:
                    break
                output_video.write(frame)
        finally:
            cap.release()
            if output_video is not None:
                output_video.release()

    def process_video(self, video_path, query, output_scene_path="output_scene.mp4"):
        """
        Main function to process the video, detect scenes, compute embeddings, and extract the relevant scene.

        Args:
            video_path: Path of the input video.
            query: Text describing the scene to find.
            output_scene_path: Where the best-matching scene is written.

        Returns:
            ``output_scene_path`` if a scene was selected and written,
            otherwise None.
        """
        logger.debug("Step 1: Detect scenes")
        scene_list = self.detect_scenes(video_path)

        logger.info(f"Detected {len(scene_list)} scenes")

        logger.debug("Step 2: Get query embedding")
        query_embedding = self.get_query_embedding(query)
        if query_embedding is None:
            # Without a query embedding every scene would score -1.
            logger.error("Query embedding could not be computed.")
            return None

        logger.debug("Step 3: Extract frames and process each scene sequentially to find the most relevant one")
        top_scene = None
        highest_similarity = -1
        total_frames = 0

        # Scenes are decoded and scored one at a time so that only a single
        # scene's frames are held in memory.
        for scene in scene_list:
            scene_start, scene_end = map(int, scene)
            # PySceneDetect reports the end of a scene as the first frame of
            # the next one (exclusive); the helpers here take an inclusive end.
            scene_end -= 1
            frames_in_scene = self.extract_frames(video_path, scene_start, scene_end)
            total_frames += len(frames_in_scene)

            scene_similarity = self.process_scene(frames_in_scene, query_embedding)
            if scene_similarity > highest_similarity:
                highest_similarity = scene_similarity
                top_scene = (scene_start, scene_end)

        logger.info(f"Extracted {total_frames} frames from {len(scene_list)} scenes")

        logger.debug("Step 4: Extract the most relevant scene")
        if top_scene:
            scene_start, scene_end = top_scene
            self.extract_scene_from_video(video_path, scene_start, scene_end, output_scene_path)
            logger.info(f"Extracted scene saved as {output_scene_path}")
            return output_scene_path
        else:
            logger.info("No relevant scene found.")
            return None

# Gradio interface
def _get_video_processor():
    """Return a shared VideoProcessor, loading the models only on first use."""
    processor = getattr(_get_video_processor, "_instance", None)
    if processor is None:
        logger.info("initialize Video Processor")
        processor = VideoProcessor()
        _get_video_processor._instance = processor
    return processor


def process_video_with_gradio(video_file, query):
    """
    Gradio callback: extract the scene of ``video_file`` that best matches ``query``.

    Args:
        video_file: Path of the uploaded video (falsy when nothing was uploaded).
        query: Search text entered by the user.

    Returns:
        A two-element list ``[output_video_path, error_message]`` matching the
        interface's two outputs; exactly one of the two entries is None.
    """
    # Both inputs are mandatory.
    if not query or not video_file:
        return [None, "Query and Video are required. Please enter a search term."]

    try:
        video_processor = _get_video_processor()

        # The name ends in ".mp4" so the writer can choose a container, and it
        # avoids ':' which is not allowed in file names on some platforms.
        # Seconds are included so runs within the same minute do not collide.
        timestamp = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
        output_scene_path = video_processor.process_video(
            video_file, query, f"output_{timestamp}.mp4"
        )
    except Exception as ex:
        logger.exception(f"Process video exception: {ex}")
        # Always return one value per output component, even on failure.
        return [None, f"Processing failed: {ex}"]

    if output_scene_path:
        return [output_scene_path, None]
    return [None, "No relevant scene found."]


# Build the Gradio UI: an uploaded video and a text query go in; the extracted
# clip and an error label come out.
video_input = gr.Video()
query_input = gr.Textbox(label="Query", placeholder="Enter the search query")
extracted_video = gr.Video(label="Extracted Video")
error_label = gr.Label(label="Error Message", value="", elem_id="error-message")

iface = gr.Interface(
    fn=process_video_with_gradio,
    inputs=[video_input, query_input],
    outputs=[extracted_video, error_label],
    live=False,
    allow_flagging="never",
)

# Style rule for the error label, selected through its elem_id.
iface.css = """
    #error-message {
        color: red;
        font-weight: bold;
    }
"""

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
