In [None]:
%load_ext autoreload
%autoreload 2

import os
from src.video_preprocessing.download_videos.youtube_download import preprocess_video
from src.video_preprocessing.scene_detection.scene_detect import detect_scenes
from src.video_preprocessing.download_videos.download_utils import (
    transcribe_audio_files,
    extract_and_store_audio,
)
from src.ocr.pytesseract_image_to_text import extract_text_from_image
from sentence_transformers import SentenceTransformer
from src.text_embedder.embedder import text_to_embedding_transformer

from src.llm.ollama_implementation.ollama_experiment import (
    prompt_llm_summary,
    generate_caption_using_llava,
    prompt_llm_extensive_summary,
)
from src.video_preprocessing.download_videos.download_utils import (
    transcription_to_text,
    create_metadata,
)

from PIL import Image

import pandas as pd
import torch
from loguru import logger
import pickle

from src.clip.clip_model import CLIPEmbeddingsModel

import tqdm
from pathlib import Path
import os

# Load Data

In [None]:
# Load the ground-truth CSV into `df`; later cells iterate over its rows and
# read the 'Prompt' and 'GT_Keyframe' columns.
filename = "test_frames_gt.csv"
df = pd.read_csv(filename)
df

In [None]:
# Load the pickled records into `data` and log them. Later cells index `data`
# by key and read fields such as 'img_path' and 'ocr_extracted_text'.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only open pickle files that come from a trusted source.
with open("bio_3_3_th5.pickle", "rb") as file:
    data = pickle.load(file)

logger.info(f"new_pickle: {data}")

In [None]:
# Instantiate the CLIP wrapper and hand it the image paths stored in the pickle.
clip_model = CLIPEmbeddingsModel()

# Collect 'img_path' from every record that carries one.
extracted_data_path = [
    record['img_path']
    for record in (data[key] for key in data.keys())
    if 'img_path' in record
]

# Clear any previous value before installing the freshly extracted paths.
clip_model.img_paths = None
clip_model.img_paths = extracted_data_path

# Collect the OCR text from every record that carries it.
extracted_data_ocr_text = [
    record['ocr_extracted_text']
    for record in (data[key] for key in data.keys())
    if 'ocr_extracted_text' in record
]

logger.info(f"Extracted_data_ocr_text: {extracted_data_ocr_text}")
logger.info(f"extracted_data_path: {extracted_data_path}")

# Spot check: show the first path next to the first OCR text.
# NOTE(review): the two lists are filtered independently, so they only line up
# index-by-index when every record has both fields -- confirm for this pickle.
print(extracted_data_path[0])
print(extracted_data_ocr_text[0])

In [None]:
# Create one empty result frame per experiment.
columns = ['Prompt', 'GT_Keyframe', 'Top_1', 'Top_2', 'Top_3']

df_test = pd.DataFrame(columns=columns)

# Each experiment gets its OWN copy. Plain assignment (`df_x = df_test`) only
# binds another name to the same object, so writing to one frame would silently
# modify all of them.
df_ocr_only = df_test.copy()
df_ocr_lava = df_test.copy()
df_ocr_transcriptions = df_test.copy()

df_short_llm_summary = df_test.copy()
df_extensive_summary = df_test.copy()

df_clip_llm_summary = df_test.copy()
df_clip_extensive_summary = df_test.copy()

In [None]:
def add_to_df(prompt, gt_keyframe, result):
    """
    Builds one result row for a prompt from its ranked search hits.

    Parameters:
    prompt (str): The query text.
    gt_keyframe: The ground-truth keyframe for the prompt.
    result (sequence): Ranked hits; each used entry is passed to
        extract_keyframe_number.

    Returns:
    dict: Keys 'Prompt', 'GT_Keyframe', 'Top_1', 'Top_2', 'Top_3'. A 'Top_k'
    slot is None when fewer than k hits were supplied.
    """
    record = {'Prompt': prompt, 'GT_Keyframe': gt_keyframe}
    available = len(result)
    for rank in range(3):
        if rank < available:
            record[f'Top_{rank + 1}'] = extract_keyframe_number(result[rank])
        else:
            record[f'Top_{rank + 1}'] = None
    return record

def get_results(df):
    """
    Runs the top-3 image search for every row of the ground-truth dataframe.

    Parameters:
    df (pandas.DataFrame): Frame with 'Prompt' and 'GT_Keyframe' columns.

    Returns:
    list: One dict per input row, as built by add_to_df.
    """
    # Start from a fresh list on every call. Previously the function appended
    # to a module-level `rows` it never initialised: the first call on a fresh
    # kernel raised NameError, and later calls kept appending to the list
    # returned by the previous experiment.
    rows = []
    for _, row in df.iterrows():
        prompt = row['Prompt']
        gt_keyframe = row['GT_Keyframe']
        logger.info(prompt)

        # Search for similar images
        output = clip_model.search_similar_images_top_3(prompt, gt_keyframe)
        rows.append(add_to_df(prompt, gt_keyframe, output))

    return rows

import os

def extract_keyframe_number(path):
    """
    Pulls the scene number out of a keyframe file path.

    The file stem is cut after its last '-Scene-' marker, and the text up to
    the next '-' is returned. If the marker is missing, the stem up to its
    first '-' is returned.

    Parameters:
    path (str): The full path of the file.

    Returns:
    str: The extracted scene number.
    """
    # Strip the directory part, then the extension.
    _, base_name = os.path.split(path)
    stem, _ext = os.path.splitext(base_name)

    # e.g. 'video-Scene-032-01' -> '032-01' -> '032'
    after_marker = stem.rpartition('-Scene-')[2]
    scene_number = after_marker.partition('-')[0]

    return scene_number

In [None]:
# Experiment: sentence-transformer embeddings built from the OCR text only.
logger.info("Embedded with standard Tokenizer: Only OCR")

# Drop whatever embeddings an earlier cell left on the model.
clip_model.text_embeddings = None

# Collect the OCR text from every record that carries it.
extracted_data_ocr_text = [
    record['ocr_extracted_text']
    for record in (data[key] for key in data.keys())
    if 'ocr_extracted_text' in record
]

# Load the sentence-transformer model and pass it, together with the OCR
# strings, to text_to_embedding_transformer.
embedder_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
ocr_embeddings = text_to_embedding_transformer(extracted_data_ocr_text, embedder_model)
clip_model.text_embeddings = ocr_embeddings

# Query every ground-truth prompt and collect the result rows.
rows = get_results(df)

df_ocr_only = pd.DataFrame(
    rows,
    columns=['Prompt', 'GT_Keyframe', 'Top_1', 'Top_2', 'Top_3'],
)

# Save dataframe to CSV
df_ocr_only.to_csv('df_ocr_only.csv', index=False)

# LOAD MODEL 

In [None]:
# Experiment: sentence-transformer embeddings of OCR text + transcription.
logger.info("Embedded with standard Tokenizer: OCR * Transcriptions")

# Drop whatever embeddings an earlier cell left on the model.
clip_model.text_embeddings = None

extracted_data_ocr_text = [data[key]['ocr_extracted_text'] for key in data.keys() if
                           'ocr_extracted_text' in data[key]]

extracted_data_transcriptions = [data[key]['transcription'] for key in data.keys() if
                                 'transcription' in data[key]]

# The two lists are filtered independently. If any record lacks one of the
# fields, zip() would silently truncate and pair text from different records,
# so fail loudly instead of producing misaligned embeddings.
assert len(extracted_data_ocr_text) == len(extracted_data_transcriptions), (
    f"OCR texts ({len(extracted_data_ocr_text)}) and transcriptions "
    f"({len(extracted_data_transcriptions)}) differ in length; cannot pair them by position"
)

# Concatenate OCR text and transcription record by record.
concat_result = [a + ' ' + b for a, b in zip(extracted_data_ocr_text, extracted_data_transcriptions)]

# Load the sentence-transformer model and embed the concatenated strings.
embedder_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

concat_result_embeddings = text_to_embedding_transformer(concat_result, embedder_model)

clip_model.text_embeddings = concat_result_embeddings

# Reset the module-level `result` list.
result = []

rows = get_results(df)

df_ocr_transcriptions = pd.DataFrame(rows, columns=['Prompt', 'GT_Keyframe', 'Top_1', 'Top_2', 'Top_3'])

# save on disk
df_ocr_transcriptions.to_csv('df_ocr_transcriptions.csv', index=False)

In [None]:
# Experiment: sentence-transformer embeddings of OCR text + LLaVA output.
logger.info("Embedded with standard Tokenizer: OCR * LLAVA")

# Drop whatever embeddings an earlier cell left on the model.
clip_model.text_embeddings = None

# Clear, then re-install, the image paths extracted from the pickle.
clip_model.img_paths = None

clip_model.img_paths = extracted_data_path

extracted_data_ocr_text = [data[key]['ocr_extracted_text'] for key in data.keys() if
                           'ocr_extracted_text' in data[key]]

extracted_data_llava_result = [data[key]['llava_result'] for key in data.keys() if
                               'llava_result' in data[key]]

# The two lists are filtered independently. If any record lacks one of the
# fields, zip() would silently truncate and pair text from different records,
# so fail loudly instead of producing misaligned embeddings.
assert len(extracted_data_ocr_text) == len(extracted_data_llava_result), (
    f"OCR texts ({len(extracted_data_ocr_text)}) and LLaVA results "
    f"({len(extracted_data_llava_result)}) differ in length; cannot pair them by position"
)

# Concatenate OCR text and LLaVA output record by record.
concat_result = [a + ' ' + b for a, b in zip(extracted_data_ocr_text, extracted_data_llava_result)]

# Load the sentence-transformer model and embed the concatenated strings.
embedder_model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

concat_result_embeddings = text_to_embedding_transformer(concat_result, embedder_model)

clip_model.text_embeddings = concat_result_embeddings

# Reset the module-level `result` list.
result = []

rows = get_results(df)

df_ocr_lava = pd.DataFrame(rows, columns=['Prompt', 'GT_Keyframe', 'Top_1', 'Top_2', 'Top_3'])

# save on disk
df_ocr_lava.to_csv('df_ocr_lava.csv', index=False)

In [None]:
# Standard-tokenizer experiment on the short LLM summary (OCR * Transcriptions * LLAVA).
# TODO: Need to get embedding with standard tokenizer for short_llm_summary
# NOTE(review): this cell only gathers the raw text fields; nothing is embedded
# or searched here yet.

logger.info("Embedded with standard Tokenizer - clip_llm_summary: OCR * LLAVA * Transcriptions")

# Drop whatever embeddings an earlier cell left on the model.
clip_model.text_embeddings = None

# Assuming that the pickle file has standard tokenizer embeddings
extracted_data_text = [
    record['ocr_extracted_text']
    for record in (data[key] for key in data.keys())
    if 'ocr_extracted_text' in record
]

extracted_data_transcriptions = [
    record['transcription']
    for record in (data[key] for key in data.keys())
    if 'transcription' in record
]

extracted_data_llava_result = [
    record['llava_result']
    for record in (data[key] for key in data.keys())
    if 'llava_result' in record
]




In [None]:
# Load a second pickle into `data_std_tokenizer` and log its contents. None of
# the other cells visible in this notebook reference `data_std_tokenizer`.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only open pickle files that come from a trusted source.
with open("data_standard_tokenizer.pickle", "rb") as file:
    data_std_tokenizer = pickle.load(file)

logger.info(f"new_pickle: {data_std_tokenizer}")

In [None]:
# Standard-tokenizer experiment on the extensive summary (OCR * Transcriptions * LLAVA).
# TODO: Need to get embedding with standard tokenizer for short_llm_summary
# NOTE(review): this cell only gathers the raw text fields; nothing is embedded
# or searched here yet.

logger.info("Embedded with standard Tokenizer - extensive_summary : OCR * LLAVA * Transcriptions")

# Drop whatever embeddings an earlier cell left on the model.
clip_model.text_embeddings = None

# Assuming that the pickle file has standard tokenizer embeddings
extracted_data_text = [
    record['ocr_extracted_text']
    for record in (data[key] for key in data.keys())
    if 'ocr_extracted_text' in record
]

extracted_data_transcriptions = [
    record['transcription']
    for record in (data[key] for key in data.keys())
    if 'transcription' in record
]

extracted_data_llava_result = [
    record['llava_result']
    for record in (data[key] for key in data.keys())
    if 'llava_result' in record
]


In [None]:
# Experiment: CLIP text embeddings stored in the pickle
# (clip_llm_summary: OCR * Transcriptions * LLAVA).
logger.info(" Embedded with CLIP - clip_llm_summary: OCR * Transcriptions * LLAVA")

# Drop whatever embeddings an earlier cell left on the model.
clip_model.text_embeddings = None

# Clear, then re-install, the image paths extracted from the pickle.
clip_model.img_paths = None

clip_model.img_paths = extracted_data_path

extracted_data_text = [data[key]['clip_text_embedding'] for key in data.keys() if
                       'clip_text_embedding' in data[key]]

# Keep only element 0 of every stored embedding.
# NOTE(review): presumably strips a leading batch dimension of size 1 -- confirm.
clip_text_embeddings = [stored[0] for stored in extracted_data_text]

clip_model.text_embeddings = clip_text_embeddings

# Convert every embedding to a tensor so they can be stacked.
if isinstance(clip_model.text_embeddings, list):
    for i, text_embedding in enumerate(clip_model.text_embeddings):
        clip_model.text_embeddings[i] = torch.tensor(text_embedding)

# create one single torch for sim search
clip_model.text_embeddings = torch.stack(clip_model.text_embeddings, dim=0)

# Collect one row per prompt. The previous version threw the search output
# away, assigned scalars into an empty frame and iterated a stale `result`
# list, so the saved CSV contained no rows.
clip_llm_summary_rows = []
for _, row in df.iterrows():
    logger.info(row['Prompt'])
    prompt = row['Prompt']
    gt_keyframe = row['GT_Keyframe']
    # NOTE(review): assumes search_similar_images_top_3_clip returns the ranked
    # image paths, like search_similar_images_top_3 does in get_results -- confirm.
    top_matches = clip_model.search_similar_images_top_3_clip(prompt, gt_keyframe)
    if top_matches is None:
        # The search returned nothing usable: keep the row, leave Top_k empty.
        top_matches = []
    clip_llm_summary_rows.append(add_to_df(prompt, gt_keyframe, top_matches))

df_clip_llm_summary = pd.DataFrame(
    clip_llm_summary_rows,
    columns=['Prompt', 'GT_Keyframe', 'Top_1', 'Top_2', 'Top_3'],
)

# save on disk
df_clip_llm_summary.to_csv('df_clip_llm_summary.csv', index=False)

In [None]:
# Experiment: CLIP embeddings generated from the long LLM summary
# (extensive_summary: OCR * Transcriptions * LLAVA).
# TODO: Generate Embeddings is not possible without the correct paths "/Users/magic-rabbit/"
logger.info(" Embedded with CLIP - extensive_summary: OCR * Transcriptions * LLAVA")

# Drop whatever embeddings an earlier cell left on the model.
clip_model.text_embeddings = None

# Clear, then re-install, the image paths extracted from the pickle.
clip_model.img_paths = None

clip_model.img_paths = extracted_data_path

print(clip_model.img_paths)

extracted_data_extensive_summary = [data[key]['llm_long_summary'] for key in data.keys() if
                                    'llm_long_summary' in data[key]]
logger.info(f"Extracted_data llm_long_summary: {extracted_data_extensive_summary}")

embeddings = clip_model.generate_dataset_embeddings(extracted_data_extensive_summary)

clip_model.text_embeddings = embeddings

logger.info(f"Clip_model.text_embeddings: {clip_model.text_embeddings}")

# Convert every embedding to a tensor so they can be stacked.
if isinstance(clip_model.text_embeddings, list):
    for i, text_embedding in enumerate(clip_model.text_embeddings):
        clip_model.text_embeddings[i] = torch.tensor(text_embedding)

# create one single torch for sim search
clip_model.text_embeddings = torch.stack(clip_model.text_embeddings, dim=0)

# Collect one row per prompt. The previous version threw the search output
# away, assigned scalars into an empty frame and iterated a stale `result`
# list, so the saved CSV contained no rows.
clip_extensive_summary_rows = []
for _, row in df.iterrows():
    logger.info(row['Prompt'])
    prompt = row['Prompt']
    gt_keyframe = row['GT_Keyframe']
    # NOTE(review): assumes search_similar_images_top_3_clip returns the ranked
    # image paths, like search_similar_images_top_3 does in get_results -- confirm.
    top_matches = clip_model.search_similar_images_top_3_clip(prompt, gt_keyframe)
    if top_matches is None:
        # The search returned nothing usable: keep the row, leave Top_k empty.
        top_matches = []
    clip_extensive_summary_rows.append(add_to_df(prompt, gt_keyframe, top_matches))

df_clip_extensive_summary = pd.DataFrame(
    clip_extensive_summary_rows,
    columns=['Prompt', 'GT_Keyframe', 'Top_1', 'Top_2', 'Top_3'],
)

# save on disk
df_clip_extensive_summary.to_csv('df_clip_extensive_summary.csv', index=False)