<a href="https://colab.research.google.com/github/NUMAIRn/AI-Video-Activity-Clip-Extractor/blob/main/AI_Video_Activity_Clip_Extractor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install sentence_transformers transformers opencv-python Pillow

In [None]:
import cv2
import numpy as np
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer, util

# One checkpoint supplies the captioning model, its image processor and its tokenizer
CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Prefer the GPU when one is available; otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained(CAPTION_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(CAPTION_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CAPTION_CHECKPOINT)
model.to(device)

# Sentence-embedding model used to compare the user query against captions
similarity_model = SentenceTransformer("all-MiniLM-L6-v2")

# Keyword arguments forwarded to model.generate() by predict_step
max_length = 16
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

# Function to generate captions for frames
def predict_step(images, batch_size=16):
    """Generate one caption per input image.

    Args:
        images: A list or tuple of images accepted by the feature extractor.
            Any other value is forwarded to the feature extractor unchanged
            in a single call.
        batch_size: Maximum number of images sent through the model per
            call when ``images`` is a list or tuple. Keeps peak memory
            bounded when many frames are captioned at once.

    Returns:
        A list of whitespace-stripped caption strings in input order.
        An empty list or tuple yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    if isinstance(images, (list, tuple)):
        # An empty sequence produces no batches, so the model is never called.
        batches = [
            images[start:start + batch_size]
            for start in range(0, len(images), batch_size)
        ]
    else:
        batches = [images]

    preds = []
    for batch in batches:
        pixel_values = feature_extractor(images=batch, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)
        output_ids = model.generate(pixel_values, **gen_kwargs)
        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        preds.extend(pred.strip() for pred in decoded)
    return preds

# Function to process video, generate captions, and save to a text file
def generate_captions(video_path, captions_file):
    """Caption the first frame of every second of a video and save the result.

    Each sampled frame is converted from BGR to an RGB PIL image, captioned
    with ``predict_step``, and written to ``captions_file`` as one
    ``"<caption>: <second> sec"`` line.

    Args:
        video_path: Path of the video to read.
        captions_file: Path of the text file to (over)write.

    Returns:
        A ``(captions, timestamps)`` tuple of parallel lists; ``timestamps``
        holds whole seconds. Both are empty when the video has no frames.

    Raises:
        IOError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video file: {video_path}")

    images_to_caption = []
    timestamps = []
    try:
        # The frame rate is truncated to an int (extract_video_by_query does
        # the same, so its frame indices line up with these timestamps).
        # Anything below 1 (including 0 or NaN) falls back to 1 so the modulo
        # below cannot divide by zero.
        reported_fps = cap.get(cv2.CAP_PROP_FPS)
        fps = int(reported_fps) if reported_fps >= 1 else 1

        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Capture first frame of each second
            if frame_count % fps == 0:
                # OpenCV delivers BGR arrays; PIL expects RGB
                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                images_to_caption.append(pil_image)
                timestamps.append(frame_count // fps)

            frame_count += 1
    finally:
        # Release the capture even if reading or conversion fails
        cap.release()

    # Skip the model entirely when no frame was sampled
    captions = predict_step(images_to_caption) if images_to_caption else []

    # Save captions to a text file with timestamps, one record per line.
    # Internal whitespace is collapsed in the file only, so a caption with an
    # embedded line break cannot split a record.
    with open(captions_file, 'w') as f:
        f.writelines(
            f"{' '.join(caption.split())}: {timestamp} sec\n"
            for caption, timestamp in zip(captions, timestamps)
        )

    return captions, timestamps

# Helper: parse the captions file written as "<caption>: <second> sec" lines
def _read_captions(captions_file):
    """Return parallel ``(captions, timestamps)`` lists read from ``captions_file``.

    Blank lines are skipped. A non-blank line that does not end in
    ``: <int> sec`` raises ``ValueError``.
    """
    captions = []
    timestamps = []
    with open(captions_file, 'r') as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            # Split on the LAST colon so captions may themselves contain ':'
            caption, separator, timestamp = line.rpartition(':')
            try:
                if not separator:
                    raise ValueError("missing ':' separator")
                seconds = int(timestamp.strip().replace("sec", ""))
            except ValueError as err:
                raise ValueError(
                    f"Malformed line {line_number} in {captions_file}: {line!r}"
                ) from err
            captions.append(caption.strip())
            timestamps.append(seconds)
    return captions, timestamps


# Helper: split sorted timestamps into runs whose neighbours are <= max_gap apart
def _group_timestamps(sorted_timestamps, max_gap):
    """Group a non-empty, ascending list of timestamps into contiguous runs."""
    groups = [[sorted_timestamps[0]]]
    for timestamp in sorted_timestamps[1:]:
        if timestamp - groups[-1][-1] <= max_gap:
            groups[-1].append(timestamp)
        else:
            groups.append([timestamp])
    return groups


# Function to extract part of the video based on the query using semantic similarity
def extract_video_by_query(video_path, captions_file, query, output_path, threshold=0.5, max_gap=10):
    """Write the earliest stretch of the video whose captions match ``query``.

    Captions are compared to the query by cosine similarity of their
    sentence embeddings. Timestamps whose similarity exceeds ``threshold``
    are grouped into runs separated by at most ``max_gap`` seconds, and the
    chronologically first run is written to ``output_path``.

    Args:
        video_path: Path of the source video.
        captions_file: File of ``"<caption>: <second> sec"`` lines.
        query: Free-text description of the activity to look for.
        output_path: Path of the clip to write (``mp4v`` codec).
        threshold: A caption matches when its similarity is strictly
            greater than this value.
        max_gap: Largest gap, in seconds, between two matched timestamps
            that still belong to the same run.

    Returns:
        None. When nothing matches, a message is printed and no file is
        written.

    Raises:
        ValueError: If a non-blank line of ``captions_file`` is malformed.
        IOError: If ``video_path`` cannot be opened.
    """
    captions, timestamps = _read_captions(captions_file)
    if not captions:
        # Nothing to embed; avoid calling the model on an empty list
        print(f"No relevant matches found for query: {query}")
        return

    # Generate embeddings for the captions and for the query
    caption_embeddings = similarity_model.encode(captions, convert_to_tensor=True)
    query_embedding = similarity_model.encode(query, convert_to_tensor=True)

    # The similarity matrix has one row per query. Indexing row 0 (rather
    # than squeeze()) keeps the result 1-D even when there is one caption.
    similarities = util.pytorch_cos_sim(query_embedding, caption_embeddings)[0]

    # Step 1: Collect and sort the timestamps of sufficiently similar captions
    matched_timestamps = sorted(
        timestamps[idx]
        for idx, similarity in enumerate(similarities.tolist())
        if similarity > threshold
    )

    if not matched_timestamps:
        print(f"No relevant matches found for query: {query}")
        return

    # Step 2: Group timestamps that lie within max_gap seconds of each other
    grouped_timestamps = _group_timestamps(matched_timestamps, max_gap)

    # Step 3: Choose the first group, i.e. the earliest one in the video
    # (not necessarily the one with the highest similarity)
    best_group = grouped_timestamps[0]
    start_time = best_group[0]
    end_time = best_group[-1]

    # Step 4: Extract video based on start and end times
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video file: {video_path}")

    out = None
    try:
        # Same int truncation as generate_captions, so second N starts at
        # frame N * fps here too. Values below 1 (or NaN) fall back to 1.
        reported_fps = cap.get(cv2.CAP_PROP_FPS)
        fps = int(reported_fps) if reported_fps >= 1 else 1
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Define the codec and create VideoWriter object for the output video
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))

        start_frame = start_time * fps
        # Exclusive bound: first frame of the second AFTER the last match
        end_frame = (end_time + 1) * fps

        frame_count = 0
        while frame_count < end_frame:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count >= start_frame:
                out.write(frame)

            frame_count += 1
    finally:
        # Always release the reader and writer, even on failure
        cap.release()
        if out is not None:
            out.release()

def process_video_and_extract(video_path, query, output_video_path, captions_file='captions.txt', similarity_threshold=0.5, max_gap=10):
    """Caption a video, then extract the clip that matches a text query.

    Args:
        video_path: Path of the source video.
        query: Free-text description of the activity to extract.
        output_video_path: Path of the clip to write.
        captions_file: Intermediate text file used to pass captions and
            timestamps from the captioning step to the extraction step.
        similarity_threshold: Forwarded to ``extract_video_by_query`` as
            ``threshold``.
        max_gap: Forwarded to ``extract_video_by_query`` as ``max_gap``.

    Returns:
        None.
    """
    # Step 1: Generate captions and save to file (the extraction step
    # re-reads them from captions_file, so the return value is not needed)
    generate_captions(video_path, captions_file)

    # Step 2: Extract video based on query using semantic similarity
    extract_video_by_query(
        video_path,
        captions_file,
        query,
        output_video_path,
        threshold=similarity_threshold,
        max_gap=max_gap,
    )

# Example run: replace the placeholder paths and query with real values.
# The three required arguments are video_path, query and output_video_path.
process_video_and_extract(
    'input_video_filepath.mp4',
    'user query for video extraction',
    'extracted_video.mp4',
    # Adjust this to control how close the match should be
    similarity_threshold=0.5,
)
