In [1]:
import os
import numpy as np
import torch
import openai
import argparse
import pprint

from tqdm import tqdm
from typing import List, cast
from dotenv import load_dotenv

from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from PIL import Image

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device

from scenedetect import open_video,SceneManager,StatsManager, save_images
from scenedetect.detectors import ContentDetector


from llama_index.core import SimpleDirectoryReader
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

In [5]:
# Name of the input video; main() looks for it under ./video/.
filename = "input_vid.mp4"

# Natural-language question to answer from the video's scenes and transcript.
query = 'What activities are the astronauts performing?'

In [6]:

def video_to_images(video_path, output_folder):
    """Detect scenes in a video and save one PNG frame per scene.

    Scenes are found with PySceneDetect's ``ContentDetector``. Each scene is
    written as ``<video name>-Scene-<NNN>.png`` into ``output_folder``, where
    ``NNN`` is the zero-based scene index padded to three digits.

    Args:
        video_path: Path of the video file to analyse.
        output_folder: Directory that receives the extracted PNG frames.
    """
    video = open_video(video_path)

    scene_manager = SceneManager(stats_manager=StatsManager())
    scene_manager.add_detector(ContentDetector())
    scene_manager.detect_scenes(video)

    # start_in_scene=True makes a video without any detected cut count as one
    # scene spanning the whole video; otherwise the list would be empty and no
    # frame at all would be written for single-shot videos.
    scene_list = scene_manager.get_scene_list(start_in_scene=True)

    # Scenes are saved one at a time so the file name can carry our own
    # zero-based index instead of the library's default numbering.
    for index, scene in enumerate(scene_list):
        save_images(
            scene_list=[scene],
            video=video,
            image_extension='png',
            image_name_template=f'$VIDEO_NAME-Scene-{index:03}',
            output_dir=output_folder,
            num_images=1,
        )

def _embed_batches(model, batches):
    """Run ``model`` on every batch and return one CPU tensor per input item."""
    embeddings: List[torch.Tensor] = []
    for batch in batches:
        with torch.no_grad():
            batch = {k: v.to(model.device) for k, v in batch.items()}
            output = model(**batch)
        embeddings.extend(torch.unbind(output.to("cpu")))
    return embeddings


def retrieve(output_folder, query, top_k=5):
    """Rank the PNG images in ``output_folder`` against ``query`` with ColPali.

    Args:
        output_folder: Directory containing the scene images (``*.png``).
        query: Text query to score the images against.
        top_k: Maximum number of best-matching images to return per query.
            Defaults to 5.

    Returns:
        A 2-D integer NumPy array; each row lists, best match first, the
        indices of the highest-scoring images. An index is the position of the
        image in the alphabetically sorted list of PNG file names found in
        ``output_folder``.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
        FileNotFoundError: If ``output_folder`` contains no PNG files.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    embedding_model_name = "vidore/colpali-v1.2"
    embedding_model = ColPali.from_pretrained(
        embedding_model_name,
        torch_dtype=torch.bfloat16,
        # Pick the device via the colpali_engine helper instead of assuming a
        # CUDA GPU is present.
        device_map=get_torch_device("auto"),
    ).eval()
    processor = ColPaliProcessor.from_pretrained(embedding_model_name)

    # Collect the PNG files of the folder in file-name order.
    png_files = sorted(name for name in os.listdir(output_folder) if name.endswith('.png'))
    if not png_files:
        raise FileNotFoundError(f"No PNG images found in {output_folder!r}")

    # Load every image fully into memory and close its file handle right away,
    # so a long scene list does not exhaust the open-file limit.
    images = []
    for name in png_files:
        with Image.open(os.path.join(output_folder, name)) as image:
            images.append(image.copy())

    # Run inference - docs
    doc_loader = DataLoader(
        dataset=ListDataset(images),
        batch_size=4,
        shuffle=False,
        collate_fn=processor.process_images,
    )
    ds = _embed_batches(embedding_model, tqdm(doc_loader))

    # Run inference - queries
    query_loader = DataLoader(
        dataset=ListDataset[str]([query]),
        batch_size=4,
        shuffle=False,
        collate_fn=processor.process_queries,
    )
    qs = _embed_batches(embedding_model, query_loader)

    scores = processor.score(qs, ds).cpu().numpy()
    # argsort is ascending: keep the last top_k columns and reverse them so the
    # best match comes first.
    idx_top_n = scores.argsort(axis=1)[:, -top_k:][:, ::-1]

    return idx_top_n



def run_llama(txt, query, output_folder, qa_tmpl_str, idx_top_n):
    """Ask GPT-4o the query using the retrieved scene images and text context.

    Args:
        txt: Iterable of strings; they are concatenated to form the context.
        query: The question to answer.
        output_folder: Directory containing the scene images (``*.png``).
        qa_tmpl_str: Prompt template with ``{context_str}`` and ``{query_str}``
            placeholders.
        idx_top_n: 2-D array of image indices as returned by ``retrieve``; only
            the first row is used. Each index is a position in the
            alphabetically sorted list of PNG files in ``output_folder``.

    Returns:
        The text of the model's answer (it is also printed).

    Raises:
        IndexError: If an index does not correspond to a PNG file in
            ``output_folder``.
    """
    # Input images: resolve each index through the sorted PNG listing (the same
    # ordering retrieve() ranks) rather than rebuilding file names by hand, so
    # this works for any video name and any number of scenes.
    png_files = sorted(name for name in os.listdir(output_folder) if name.endswith('.png'))
    img = [os.path.join(output_folder, png_files[int(i)]) for i in idx_top_n[0]]

    # Query
    query_str = query

    # Documents
    image_documents = SimpleDirectoryReader(
        input_dir=output_folder, input_files=img
    ).load_data()
    context_str = "".join(txt)

    # Load the LLM
    openai_mm_llm = OpenAIMultiModal(
        model="gpt-4o", api_key=os.getenv('OPENAI_API_KEY'), max_new_tokens=1500
    )

    # Generate the answer
    response_1 = openai_mm_llm.complete(
        prompt=qa_tmpl_str.format(
            context_str=context_str, query_str=query_str, ),
        image_documents=image_documents,
    )
    print(response_1.text)
    return response_1.text
    
def generate_answer(output_folder, query, idx_top_n):
    """Build the prompt and text context, then ask the multimodal LLM.

    Loads environment variables from a ``.env`` file (for ``OPENAI_API_KEY``),
    assembles the hard-coded video transcript and the question-answering prompt
    template, and delegates the actual request to ``run_llama``.

    Args:
        output_folder: Directory containing the scene images.
        query: The question to answer.
        idx_top_n: Image indices as returned by ``retrieve``.
    """
    load_dotenv()
    openai.api_key = os.getenv('OPENAI_API_KEY')

    # Text information: transcript of the video, used as textual context.
    txt = ["As I look back on the mission that we've had here on the International Space Station, I'm proud to have been a part of much of the science activities that happened over the last two months. I didn't think I would do another spacewalk and to now have the chance to have done four more was just icing on the cake for a wonderful mission. The 10th one, do you like the first one? No, a little more comfortable. It's hard to put into words just what it was like to be a part of this expedition, the Expedition 63. It'll be kind of a memory that will last a lifetime for me. It's been a true honor. Try and space X, Undock sequence commanded. The thrusters looking good. The hardest part was getting us launched, but the most important part is bringing us home. I've been trying that day. We love you. Hurry home for weeks and don't get my dog. Slash down. Welcome back to Planet Earth and thanks for flying SpaceX. We're literally on our own. Space dads are back on Earth after a 19-hour return journey from space. The Earth is a very important part of the planet. The Earth is a very important part of the planet. The Earth is a very important part of the planet. The Earth is a very important part of the planet. The Earth is a very important part of the planet."]

    # Adjacent string literals are concatenated by the parser. The previous
    # triple-quoted version kept the quote characters and "\n" fragments of
    # such a concatenation inside the prompt text itself.
    qa_tmpl_str = (
        "Given the provided information, including relevant images and retrieved context from the video, "
        "accurately and precisely answer the query without any additional prior knowledge.\n"
        "Please ensure honesty and responsibility, refraining from any racist or sexist remarks.\n"
        "---------------------\n"
        "Context: {context_str}\n"
        "---------------------\n"
        "Query: {query_str}\n"
        "Answer: "
    )
    run_llama(txt, query, output_folder, qa_tmpl_str, idx_top_n)
    
def main(filename,query):
    """Run the whole pipeline for one video: frames, retrieval, answer.

    Args:
        filename: Name of the video file located under ``./video/``.
        query: The question to answer about the video.
    """
    source_video = './video/' + filename
    # Frames are stored in a per-video directory named after the file.
    frames_dir = "./img/" + filename + '/'

    # 1) Extract one frame per detected scene.
    video_to_images(source_video, frames_dir)

    # 2) Rank the frames against the query and show the winning indices.
    ranked_indices = retrieve(frames_dir, query)
    print(ranked_indices)

    # 3) Let the multimodal LLM answer using the best frames.
    generate_answer(frames_dir, query, ranked_indices)
    


In [7]:
# Run the full pipeline on the configured video and question.
main(filename, query)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.04it/s]


[[12 20 11  5  8]]
The astronauts are performing several activities, including:

1. Conducting science activities on the International Space Station.
2. Participating in spacewalks, with one astronaut completing four additional spacewalks.
3. Preparing for and executing the undocking sequence from the ISS.
4. Returning to Earth after a 19-hour journey from space.
