In [26]:
!pip install moviepy pytube gradio transformers langchain faiss-cpu opencv-python-headless sentence-transformers



In [27]:
!pip install -U langchain--community



In [28]:
!pip install --trusted-host pypi.org --trusted-host files.pythonhosted.org openai-whisper



In [29]:
!pip install --trusted-host pypi.org --trusted-host files.pythonhosted.org torch



In [30]:
import uuid
import logging
import cv2
from pytube import YouTube
from pydub import AudioSegment
import gradio as gr
from langchain.vectorstores import FAISS
from transformers import pipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
import whisper
from sentence_transformers import SentenceTransformer

In [31]:
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [32]:
LLM_MODEL = "facebook/bart-large-cnn"  # For summarization
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # For embeddings

In [33]:
def load_summarizer():
    return pipeline("summarization", model=LLM_MODEL)

def load_embedding_model():
    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

In [34]:
vector_db = None
# -------------------------------
# 📤 AUDIO EXTRACTION
# -------------------------------
def extract_audio(video_path):
    try:
        audio_path = f"{uuid.uuid4().hex}.wav"
        audio = AudioSegment.from_file(video_path)
        audio.export(audio_path, format="wav")
        return audio_path
    except Exception as e:
        logger.error(f"Audio extraction failed: {e}")
        return None


In [42]:
def transcribe_audio(audio_path):
    try:
        model = whisper.load_model("base")
        result = model.transcribe(audio_path)
        return result['text']
    except Exception as e:
        logger.error(f"Transcription failed: {e}")
        return None

In [43]:
def generate_embeddings(content):
    try:
        if isinstance(content, bytes):
            content = content.decode("utf-8")
        return load_embedding_model.embed_query(content,convert_to_tensor=True)
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return None

In [44]:
def extract_keyframes(video_path):
    try:
        cap = cv2.VideoCapture(video_path)
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        interval = fps * 2  # Extract a frame every 2 seconds

        image_paths = []
        count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if count % interval == 0:
                image_path = f"frame_{uuid.uuid4().hex}.jpg"
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
            count += 1

        cap.release()
        return image_paths
    except Exception as e:
        logger.error(f"Image extraction failed: {e}")
        return None

In [45]:
# def store_in_vector_db(transcription, audio_path,frame_paths):
#     global vector_db
#     segments = transcription.get('segments', [])
#     documents = []
#     for segment in segments:
#         text = segment.get('text', '')
#         if not text.strip():
#             continue
#         start = segment.get('start', 0.0)
#         end = segment.get('end', 0.0)
#         chunk_embedding = generate_embeddings(text)
#         documents.append(Document(
#             page_content=text,
#             metadata={
#                 "type": "audio_chunk",
#                 "audio_path": audio_path,
#                 "start_time": start,
#                 "end_time": end,
#                 "embedding": chunk_embedding
#             }
#         ))
#     vector_db = FAISS.from_documents(documents, embedding_model)
#     logger.info("Vector database updated successfully!")
def store_in_vector_db(transcription, audio_path, frame_paths):
    global vector_db
    documents = []

    # Process transcription data
    if transcription:
        segments = transcription.get('segments', [])
        if segments:  # Handle segmented transcription
            for segment in segments:
                text = segment.get('text', '').strip()
                if not text:
                    continue
                start = segment.get('start', 0.0)
                end = segment.get('end', 0.0)
                chunk_embedding = generate_embeddings(text)
                documents.append(Document(
                    page_content=text,
                    metadata={
                        "type": "audio_chunk",
                        "audio_path": audio_path,
                        "start_time": start,
                        "end_time": end,
                        "embedding": chunk_embedding
                    }
                ))
        else:  # Handle plain transcription text
            transcription_text = transcription.get('text', '').strip()
            if transcription_text:
                transcription_embedding = generate_embeddings(transcription_text)
                documents.append(Document(
                    page_content=transcription_text,
                    metadata={
                        "type": "transcript",
                        "embedding": transcription_embedding
                    }
                ))

    # Process frame data
    if frame_paths:
        for frame_path in frame_paths:
          with open(frame_path, "rb") as img_file:
              frame_embedding = generate_embeddings(img_file.read())
              documents.append(Document(page_content=f"Frame at {frame_path}", metadata={"type": "frame", "path": frame_path, "embedding": frame_embedding}))
    vector_db = FAISS.from_documents(documents, load_embedding_model)
    logger.info("Vector database updated successfully!")


In [46]:
def retrieve_data(query, top_n=5):
    global vector_db

    if vector_db is None:
        logger.error("Vector database is not initialized!")
        return {"frames": [], "transcripts": [], "audio_segments": []}

    query_embedding = generate_embeddings(query)
    results = vector_db.similarity_search_by_vector(query_embedding, k=top_n)

    frames = [result.metadata["path"] for result in results if result.metadata.get("type") == "frame"]
    transcripts = [result.page_content for result in results if result.metadata.get("type") == "transcript"]
    audio_segments = [
        (result.metadata["audio_path"], result.metadata["start_time"], result.metadata["end_time"])
        for result in results
        if result.metadata.get("type") == "audio_chunk"
    ]

    return {"frames": frames, "transcripts": transcripts, "audio_segments": audio_segments}


def extract_audio_clips(audio_segments):
    combined_audio = AudioSegment.empty()
    for audio_path, start_time, end_time in audio_segments:
        try:
            audio = AudioSegment.from_wav(audio_path)
            start_ms = int(start_time * 1000)
            end_ms = int(end_time * 1000)
            clip = audio[start_ms:end_ms]
            combined_audio += clip
        except Exception as e:
            logger.error(f"Failed to extract audio clip from {audio_path}: {e}")

    if combined_audio:
        combined_clip_path = f"combined_clip_{uuid.uuid4().hex}.wav"
        combined_audio.export(combined_clip_path, format="wav")
        logger.info(f"Combined audio clip saved at {combined_clip_path}")
        return combined_clip_path
    else:
        logger.warning("No valid audio clips were processed.")
        return None


def process_query(query, top_n=5):
    results = retrieve_data(query, top_n)
    frames = results.get("frames", [])
    transcripts = results.get("transcripts", [])
    audio_segments = results.get("audio_segments", [])
    combined_audio_path = None
    if audio_segments:
        combined_audio_path = extract_audio_clips(audio_segments)

    return {
        "frames": frames,
        "transcripts": transcripts,
        "combined_audio_path": combined_audio_path,
    }


In [47]:
def process_media(uploaded_file=None, youtube_url=None, query=None):
    try:
        if not (uploaded_file or youtube_url):
            return {"error": "Please provide either a video file or a YouTube URL."}
        if youtube_url:
            video_path = download_youtube_video(youtube_url)
            if not video_path:
                return {"error": "Failed to download YouTube video."}
        else:
            video_path = f"{uuid.uuid4().hex}.mp4"
            with open(video_path, "wb") as f:
                f.write(uploaded_file)
        audio_path = extract_audio(video_path)
        if not audio_path:
            return {"error": "Audio extraction failed."}
        transcription = transcribe_audio(audio_path)
        if not transcription:
            return {"error": "Transcription failed."}
        frame_paths = extract_keyframes(video_path)
        store_in_vector_db(transcription, audio_path=audio_path, frame_paths=frame_paths)
        if query:
            results = retrieve_data(query)
            frames = results.get("frames", [])
            transcripts = results.get("transcripts", [])
            audio_segments = results.get("audio_segments", [])
            combined_clip_path = None
            if audio_segments:
                combined_clip_path = extract_audio_clips(audio_segments)

            return {
                "audio_path": audio_path,
                "transcripts": transcripts,
                "frames": frames,
                "combined_audio_path": combined_clip_path,
            }
        return {
            "audio_path": audio_path,
            "transcription": transcription,
            "frames": frame_paths,
            "transcripts": None,
            "combined_audio_path": None,
        }

    except Exception as e:
        logger.error(f"Processing failed: {e}")
        return {"error": f"Processing Error: {e}"}


In [48]:
# Gradio Interface function
def gradio_interface(video=None, youtube_url=None, query=None):
    results = process_media(uploaded_file=video, youtube_url=youtube_url, query=query)

    if "error" in results:
        return None, results["error"], []

    return (
        results.get("combined_audio_path", results.get("audio_path")),
        results.get("transcription", ""),
        results.get("frames", []),
    )


# Gradio Interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.File(label="📹 Upload Video (Optional)", type="binary"),
        gr.Textbox(label="🔗 YouTube URL (Optional)", placeholder="Enter YouTube URL"),
        gr.Textbox(label="🔍 Query", placeholder="Ask a question about the video content")
    ],
    outputs=[
        gr.Audio(label="🎵 Extracted Audio"),
        gr.Textbox(label="📝 Transcript"),
        gr.Gallery(label="🖼 Relevant Frames")
    ],
    title="🎓 Video RAG Tool",
    description="Upload a video or provide a YouTube URL. Extract audio, transcribe, and query video content!",
)

if __name__ == "__main__":
    interface.launch(debug=True, share=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://065413a34ea5cf41b6.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


100%|███████████████████████████████████████| 139M/139M [00:02<00:00, 56.0MiB/s]
  checkpoint = torch.load(fp, map_location=device)
ERROR:__main__:Processing failed: 'str' object has no attribute 'get'


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://065413a34ea5cf41b6.gradio.live
