In [None]:
!mkdir database
%cd database
!gdown 1-w5iuyMYw4sBh7zsWfhCJTEyV2fhisyw
from tqdm import tqdm
import zipfile
with zipfile.ZipFile('data.zip', 'r') as zip_ref:
    for file in tqdm(zip_ref.namelist(), desc='Unzipping'):
        zip_ref.extract(file)

!rm data.zip
%cd ..

/content/database
Downloading...
From (original): https://drive.google.com/uc?id=1-w5iuyMYw4sBh7zsWfhCJTEyV2fhisyw
From (redirected): https://drive.google.com/uc?id=1-w5iuyMYw4sBh7zsWfhCJTEyV2fhisyw&confirm=t&uuid=d97571e2-07d2-492a-be05-67997a120c61
To: /content/database/data.zip
100% 3.81G/3.81G [00:53<00:00, 71.3MB/s]


Unzipping: 100%|██████████| 285497/285497 [01:00<00:00, 4745.68it/s]


/content


In [None]:
# Install the app's Python dependencies; requirements.txt is resolved relative to the current directory (/content after the previous cell).
!pip install -q -r requirements.txt

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/981.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m972.8/981.5 kB[0m [31m31.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.5/981.5 kB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.4/235.4 kB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.4 MB/s[0m eta [36m0:

In [None]:
%%writefile app.py
import streamlit as st
import os
import glob
import torch
from PIL import Image
import faiss
import numpy as np
import pandas as pd
import json
import math
import googletrans
import translate
import underthesea
from pyvi import ViUtils, ViTokenizer
from difflib import SequenceMatcher
from langdetect import detect
from tqdm import tqdm
import ast
import re
from fuzzywuzzy import process as fuwu_process, fuzz as fuwu_fuzz
from rapidfuzz import process as rafu_process, fuzz as rafu_fuzz

st.set_page_config(layout="wide")

# -------------------------- #
#        Session State       #
# -------------------------- #

# Per-session defaults. Each key is created only when a session does not have it
# yet, so later reruns keep whatever the user has accumulated. Values are
# factories so every key gets its own fresh object.
_SESSION_DEFAULTS = {
    'expander_content': lambda: None,
    'copy_to_clipboard': lambda: None,
    'selected_images': dict,
    'checkbox_states': dict,
    'search_results': lambda: None,
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# -------------------------- #
#    Data Loading Functions   #
# -------------------------- #

@st.cache_resource
def load_data_and_models():
    """Load the keyframe catalogue, video metadata and every retrieval model.

    Decorated with ``st.cache_resource`` so the expensive work happens once per
    server process instead of on every Streamlit rerun.

    Returns:
        ``(lst_keyframes, id2img_fps, vid_url, url_fps, keyframes, device,
        models, image_processors, text_processors, clip_model, blip_model)``.
        ``models``, ``image_processors`` and ``text_processors`` hold one entry
        per checkpoint in ``clip_model`` followed by the BLIP-2 entry last.
    """
    # Sorted keyframe paths; the list position is the id used by id2img_fps.
    lst_keyframes = sorted(glob.glob('database/s_optimized_keyframes/*.webp'))
    id2img_fps = dict(enumerate(lst_keyframes))
    print(f"Total keyframes loaded: {len(id2img_fps)}")

    def _read_json(path):
        with open(path, 'r') as handle:
            return json.load(handle)

    vid_url = _read_json('database/vid_url.json')
    print(f"Total videos loaded: {len(vid_url)}")

    url_fps = _read_json('database/url_fps.json')
    print(f"Total FPS data loaded: {len(url_fps)}")

    keyframes = pd.read_csv('database/keyframes.csv', sep='|', index_col=False)

    # Heavy third-party imports are deferred until this (cached) function runs.
    from transformers import CLIPModel, CLIPImageProcessor, CLIPTokenizer
    from lavis.models import load_model_and_preprocess

    # (HuggingFace checkpoint, short name of the matching features .bin file)
    clip_model = [
        ("openai/clip-vit-base-patch32", 'clipB32'),
        # Add more models if needed
    ]
    # (LAVIS model name, short name of the matching features .bin file)
    blip_model = ("blip2_feature_extractor", "blip2fe")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Device in use: {device}')

    # BLIP-2 is loaded first, then the CLIP checkpoints (same order as before).
    blip, blip_vis_processors, blip_txt_processors = load_model_and_preprocess(
        name=blip_model[0],
        model_type="pretrain",
        is_eval=True,
        device=device,
    )

    checkpoint_names = [name for name, _ in clip_model]
    models = [CLIPModel.from_pretrained(name).to(device) for name in checkpoint_names]
    models.append(blip.to(device))  # BLIP-2 is always last

    image_processors = [CLIPImageProcessor.from_pretrained(name) for name in checkpoint_names]
    image_processors.append(blip_vis_processors)

    text_processors = [CLIPTokenizer.from_pretrained(name) for name in checkpoint_names]
    text_processors.append(blip_txt_processors)

    return (lst_keyframes, id2img_fps, vid_url, url_fps, keyframes, device,
            models, image_processors, text_processors, clip_model, blip_model)

# Load data and models; cached, so only the first run pays the loading cost.
with st.spinner('Loading database and models...'):
    (lst_keyframes, id2img_fps, vid_url, url_fps, keyframes, device,
     models, image_processors, text_processors, clip_model, blip_model) = load_data_and_models()

# -------------------------- #
#      Helper Classes        #
# -------------------------- #

class Translation:
    """Lowercase a query, then translate it with one of two backends.

    Args:
        from_lang: source language code (only forwarded to the ``translate`` backend).
        to_lang: target language code.
        mode: ``'googletrans'`` or ``'translate'``; any other value raises ValueError.
    """

    def __init__(self, from_lang='vi', to_lang='en', mode='googletrans'):
        self.__mode = mode
        self.__from_lang = from_lang
        self.__to_lang = to_lang

        if mode not in ('googletrans', 'translate'):
            raise ValueError(f"Unsupported translation mode: {mode}")
        self.translator = (
            googletrans.Translator()
            if mode == 'googletrans'
            else translate.Translator(from_lang=from_lang, to_lang=to_lang)
        )

    def preprocessing(self, text):
        """Normalise case before translation."""
        return text.lower()

    def __call__(self, text):
        """Return ``text`` (lowercased first) translated into the target language."""
        lowered = self.preprocessing(text)
        if self.__mode != 'translate':
            # googletrans returns a result object; the translated string is ``.text``.
            return self.translator.translate(lowered, dest=self.__to_lang).text
        return self.translator.translate(lowered)

class Text_Preprocessing:
    """Vietnamese text pipeline: lowercase -> drop stopwords -> normalize -> classify.

    Args:
        stopwords_path: UTF-8 file with one stopword per line (surrounding
            whitespace is stripped from each line).
    """

    def __init__(self, stopwords_path='./dict/vietnamese-stopwords-dash.txt'):
        with open(stopwords_path, 'rb') as f:
            lines = f.readlines()
        self.stop_words = [line.decode('utf8').strip() for line in lines]
        # Set view of stop_words so remove_stopwords does O(1) lookups per token
        # instead of scanning the whole list for every word.
        self._stop_word_set = frozenset(self.stop_words)

    def remove_stopwords(self, text):
        """Word-segment ``text`` with pyvi's ViTokenizer and drop stopword tokens."""
        text = ViTokenizer.tokenize(text)
        return " ".join([w for w in text.split() if w not in self._stop_word_set])

    def lowercasing(self, text):
        """Return ``text`` lowercased."""
        return text.lower()

    def text_norm(self, text):
        """Normalize ``text`` with ``underthesea.text_normalize``."""
        return underthesea.text_normalize(text)

    def text_classify(self, text):
        """Return the categories ``underthesea.classify`` assigns to ``text``."""
        return underthesea.classify(text)

    def __call__(self, text):
        """Run the full pipeline and return the classification categories."""
        text = self.lowercasing(text)
        text = self.remove_stopwords(text)
        text = self.text_norm(text)
        categories = self.text_classify(text)
        return categories

@st.cache_resource
def load_stopwords():
    """Return a process-wide cached ``Text_Preprocessing`` instance.

    ``st.cache_resource`` (rather than ``st.cache_data``) is used because the
    result is a helper object, not serializable data. It is shared as-is
    instead of being pickled and copied on every call.
    """
    return Text_Preprocessing()

class Myfaiss:
    """A FAISS index paired with the model and processors that embed queries for it.

    Args:
        bin_file: path of a serialized FAISS index, loaded with ``faiss.read_index``.
        id2img_fps: mapping from FAISS row id to keyframe image path.
        device: torch device the model inputs are moved to.
        model: either a CLIP-style model (``get_text_features`` /
            ``get_image_features``) or a LAVIS BLIP-2 feature extractor
            (``extract_features``).
        text_processor: tokenizer for CLIP, or the LAVIS text-processor mapping
            (indexed with ``"eval"``) for BLIP-2.
        image_processor: image processor for CLIP, or the LAVIS vision-processor
            mapping (indexed with ``"eval"``) for BLIP-2.
        vid_url: mapping from video name to video URL.
        url_fps: mapping from video URL to frames per second.
    """

    def __init__(self, bin_file: str, id2img_fps, device, model, text_processor, image_processor, vid_url, url_fps):
        self.index = self.load_bin_file(bin_file)
        self.id2img_fps = id2img_fps
        self.device = device
        self.model = model
        self.text_processor = text_processor
        self.image_processor = image_processor
        self.vid_url = vid_url
        self.url_fps = url_fps

    def load_bin_file(self, bin_file: str):
        """Read a serialized FAISS index from ``bin_file``."""
        return faiss.read_index(bin_file)

    def _is_blip(self):
        """True when the wrapped model is the LAVIS BLIP-2 feature extractor.

        Decided from the model itself instead of comparing against a module-level
        ``models`` list, so the class does not depend on globals defined elsewhere.
        Only the BLIP-2 path calls ``extract_features``.
        """
        return hasattr(self.model, "extract_features")

    def _search(self, query_feats, k):
        """Search the index with ``query_feats`` and map hit ids to keyframe paths.

        Returns:
            ``(scores, idx_image, infos_query, image_paths)``: FAISS scores, the
            flattened hit ids, the keyframe path per hit (``None`` for ids missing
            from ``id2img_fps``), and a separate list with the same paths.
        """
        # FAISS expects C-contiguous float32 input.
        query_feats = np.ascontiguousarray(query_feats, dtype=np.float32)
        scores, idx_image = self.index.search(query_feats, k=k)
        idx_image = idx_image.flatten()
        infos_query = [self.id2img_fps.get(i) for i in idx_image]
        image_paths = list(infos_query)
        return scores, idx_image, infos_query, image_paths

    def show_images(self, image_paths):
        """Render result thumbnails in a grid, each with a jump button and a select checkbox.

        The button stores ``(vid_name, frame, video_url)`` in
        ``st.session_state['expander_content']`` and its label in
        ``'copy_to_clipboard'``. The checkbox adds or removes the frame in
        ``st.session_state['selected_images']``.

        Args:
            image_paths: keyframe paths named ``<vid_name>-<frame>.<ext>``; falsy
                entries (ids that had no path in ``id2img_fps``) are skipped.
        """
        num_cols = 5  # thumbnails per grid row
        image_paths = [path for path in image_paths if path]
        rows = [image_paths[i:i + num_cols] for i in range(0, len(image_paths), num_cols)]

        for row_idx, row in enumerate(rows):
            cols = st.columns(len(row))
            for idx, img_path in enumerate(row):
                vid_id = os.path.basename(img_path).split('.')[0]
                vid_name, frame = vid_id.rsplit('-', 1)
                base_url = self.vid_url.get(vid_name, "")
                # A missing or zero fps entry falls back to 1 to avoid ZeroDivisionError.
                fps = self.url_fps.get(base_url, 1) or 1
                timestamp = int(int(frame) / fps)
                video_url = f"{base_url}&t={timestamp}"

                # Row/column indices keep widget keys unique even if a keyframe repeats.
                image_id = f"{vid_name}_{frame}_{row_idx}_{idx}"
                checkbox_key = f"checkbox_{image_id}"

                if image_id not in st.session_state['checkbox_states']:
                    st.session_state['checkbox_states'][image_id] = False

                with cols[idx]:
                    try:
                        st.image(img_path, width=150)  # fixed thumbnail width
                    except Exception as e:
                        st.error(f"Error loading image: {e}")

                    # NOTE(review): each st.markdown call likely renders as its own
                    # element, so this opening <div> probably does not wrap the widgets.
                    st.markdown(
                        f"""
                        <div style="display:flex; flex-direction: column; align-items: center;">
                        """,
                        unsafe_allow_html=True,
                    )

                    def button_callback(vid_name=vid_name, frame=frame, video_url=video_url, button_label=f"{vid_name}, {frame}"):
                        # Default arguments bind this thumbnail's values at definition time.
                        st.session_state['expander_content'] = (vid_name, frame, video_url)
                        st.session_state['copy_to_clipboard'] = button_label

                    st.button(
                        f"{vid_name}, {frame}",
                        key=f"btn_{image_id}",
                        on_click=button_callback
                    )

                    selected = st.checkbox("Select", key=checkbox_key)
                    st.session_state['checkbox_states'][image_id] = selected

                    if selected:
                        st.session_state['selected_images'][image_id] = (vid_name, frame, img_path)
                    else:
                        st.session_state['selected_images'].pop(image_id, None)

                    st.markdown(
                        "</div>",
                        unsafe_allow_html=True,
                    )

    def image_search(self, id_query, k, bin_file=None):
        """Find the ``k`` nearest neighbours of a vector already stored in the index.

        Args:
            id_query: FAISS row id whose vector is reconstructed and used as the query.
            k: number of neighbours to return.
            bin_file: unused; kept for backward compatibility.
        """
        query_feats = self.index.reconstruct(id_query).reshape(1, -1)
        return self._search(query_feats, k)

    @staticmethod
    def _detect_language(text):
        """Return langdetect's language code for ``text``, or None if it cannot decide."""
        from langdetect.lang_detect_exception import LangDetectException
        try:
            return detect(text)
        except LangDetectException:
            # Raised for inputs without detectable features (e.g. only digits/punctuation).
            return None

    def text_search(self, text, k):
        """Embed a text query (Vietnamese is translated first) and search the index."""
        if self._detect_language(text) == 'vi':
            text = Translation()(text)

        ###### TEXT FEATURES EXTRACTING ######
        if self._is_blip():
            text_input = self.text_processor["eval"](text)
            sample = {"text_input": [text_input]}
            with torch.no_grad():
                features_text = self.model.extract_features(sample, mode="text")
            text_features = features_text.text_embeds_proj[:, 0, :].detach().cpu().numpy()
        else:
            # Truncate to the tokenizer's maximum length; longer (e.g. translated)
            # queries otherwise fail inside the CLIP text encoder.
            inputs = self.text_processor(
                text, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs).cpu().detach().numpy().astype(np.float32)

        ###### SEARCHING #####
        return self._search(text_features, k)

    def image_similarity_search(self, image_path, k, online=False, timeout=30):
        """Embed an image (local path, or URL when ``online``) and search the index.

        Args:
            image_path: local file path, or an http(s) URL when ``online`` is True.
            k: number of neighbours to return.
            online: fetch ``image_path`` over HTTP instead of opening a file.
            timeout: seconds to wait for the HTTP download before giving up.

        Raises:
            requests.HTTPError: when ``online`` and the server returns an error status.
        """
        if online:
            import io
            import requests
            response = requests.get(image_path, timeout=timeout)
            response.raise_for_status()
            # .content is decoded (e.g. gzip), unlike the raw stream.
            img = Image.open(io.BytesIO(response.content)).convert('RGB')
        else:
            # Convert to RGB so palette/alpha images match the processors' 3 channels,
            # and close the underlying file once the pixels are loaded.
            with Image.open(image_path) as opened:
                img = opened.convert('RGB')

        if self._is_blip():
            image = self.image_processor["eval"](img).unsqueeze(0).to(self.device)
            sample = {"image": image}
            with torch.no_grad():
                features_image = self.model.extract_features(sample, mode="image")
                image_features = features_image.image_embeds_proj[:, 0, :].detach().cpu().numpy()
        else:
            inputs = self.image_processor(images=img, return_tensors="pt").to(self.device)
            with torch.no_grad():
                image_features = self.model.get_image_features(**inputs).detach().cpu().numpy()

        return self._search(image_features, k)

# -------------------------- #
#        FAISS Search        #
# -------------------------- #

# Define root features path
root_features = 'database/features'

# One serialized FAISS index per model, named after the model's short name and
# listed in the same order as `models` (CLIP checkpoints first, BLIP-2 last).
bin_paths = [os.path.join(root_features, bin_name + '.bin') for _, bin_name in clip_model]
bin_paths.append(os.path.join(root_features, blip_model[1] + '.bin'))

# Ensure that the number of bin_paths matches number of models
if len(bin_paths) != len(models):
    st.error("Number of bin paths does not match number of models.")
    st.stop()


@st.cache_resource
def build_faiss_search(index_paths):
    """Build one Myfaiss wrapper per (index file, model) pair.

    Streamlit re-executes this script on every widget interaction. Without
    caching, every FAISS index would be re-read from disk on each click.
    ``index_paths`` (a hashable tuple) is the cache key. The models, processors
    and metadata come from the cached ``load_data_and_models`` result.
    """
    return [
        Myfaiss(path, id2img_fps, device, model, text_proc, image_proc, vid_url, url_fps)
        for path, model, text_proc, image_proc in zip(index_paths, models, text_processors, image_processors)
    ]


# Initialize faiss_search
faiss_search = build_faiss_search(tuple(bin_paths))

# -------------------------- #
#           Main             #
# -------------------------- #

def main():
    # Title with gradient and centered
    st.markdown("""
        <h1 style='text-align: center; background: linear-gradient(to right, blue, purple); -webkit-background-clip: text; color: transparent;'>Image Retrieval System - AIC2024</h1>
        """, unsafe_allow_html=True)

    col1, col2 = st.columns([1, 2])

    with col1:
        # Video Details Expander
        video_details_expander = st.expander("Video details")
        with video_details_expander:
            if st.session_state['expander_content']:
                vid_name, frame, video_url = st.session_state['expander_content']
                st.video(video_url)
                st.write(f"**Video ID:** {vid_name}, {frame}")
                st.write(f"**Video URL:** {video_url}")

                # Checkbox in video details
                image_id = f"{vid_name}_{frame}"
                checkbox_key = f"checkbox_{image_id}"
                if image_id not in st.session_state['checkbox_states']:
                    st.session_state['checkbox_states'][image_id] = False

                selected = st.checkbox("Select", key=checkbox_key)
                st.session_state['checkbox_states'][image_id] = selected

                # Update selected_images
                if selected:
                    # Find img_path from id2img_fps
                    img_path = None
                    for path in id2img_fps.values():
                        if image_id in path:
                            img_path = path
                            break
                    if img_path:
                        st.session_state['selected_images'][image_id] = (vid_name, frame, img_path)
                else:
                    st.session_state['selected_images'].pop(image_id, None)
            else:
                st.write("No video selected.")

        # Selected Images Expander
        selected_images_expander = st.expander("Selected image(s)")
        with selected_images_expander:
            selected_images = st.session_state['selected_images'].values()
            if selected_images:
                for vid_name, frame, img_path in selected_images:
                    st.write(f"**{vid_name}, {frame}**")
            else:
                st.write("No images selected.")

    with col2:
        # ------------------------------ #
        #    Added Slider and Checkbox   #
        # ------------------------------ #

        # Slider for K_neighbors
        K_neighbors = st.slider(
            "Number of Neighbors (K_neighbors)",
            min_value=10,
            max_value=1000,
            value=100,
            step=10,
            help="Adjust the number of nearest neighbors to retrieve."
        )

        # Checkbox for high_performance
        if st.checkbox(
            "Use High Performance Mode",
            value=False,
            help="Toggle to use high performance search mode."
        ):
            high_performance = 1
        else:
            high_performance = 0

        # Search bar
        text_query = st.text_input("Enter a text query, a frame or an image url", placeholder='Eg: "Cảnh quay một chiếc thuyền cứu hộ đi trên băng..." || "L01_V001, 1" || "https://bitexco.c...scaled.jpg"', key="text_query")
        search_clicked = st.button("Search", key="search_button")

        if search_clicked and text_query:
            # Determine the model index based on high_performance
            if high_performance and len(faiss_search) > 1:
                search_index = 1  # Assuming the second model is for high performance
            else:
                search_index = 0  # Default to the first model

            with st.spinner('Performing search...'):
                if "https://" in text_query:
                    scores, idx_image, infos_query, image_paths = faiss_search[search_index].image_similarity_search(text_query, k=K_neighbors, online=True)
                elif re.match(r'^L\d{2}_V\d{3},\s*(\d|[1-9]\d{0,4})$', text_query):
                    ROOT_IMG = "database/s_optimized_keyframes"

                    input_vid_name, input_frame = text_query.split(', ')
                    input_frame = int(input_frame)
                    filtered_df = keyframes[(keyframes['vid_name'] == input_vid_name) & (keyframes['shot'].apply(lambda x: eval(x)[0] <= input_frame <= eval(x)[1]))]
                    closest_row = filtered_df.iloc[(filtered_df['frame'] - input_frame).abs().argsort()[:1]]
                    text_query = f"{closest_row['vid_name'].values[0]}, {str(closest_row['frame'].values[0]).zfill(5)}"

                    image_path = os.path.join(ROOT_IMG, '-'.join(text_query.split(', ')) + '.webp')
                    print(image_path)
                    scores, idx_image, infos_query, image_paths = faiss_search[search_index].image_similarity_search(image_path, k=K_neighbors)
                else:
                    # Perform the search
                    scores, idx_image, infos_query, image_paths = faiss_search[search_index].text_search(text_query, k=K_neighbors)
                # Store results in session_state
                st.session_state['search_results'] = image_paths
                # Reset checkbox states for new search
                st.session_state['checkbox_states'] = {}
                # Reset selected_images for new search
                st.session_state['selected_images'] = {}


        # Display images from session_state if available
        if st.session_state.get('search_results'):
            # Determine the model index for displaying based on high_performance
            if high_performance and len(faiss_search) > 1:
                display_index = 1
            else:
                display_index = 0

            with st.spinner('Loading images...'):
                faiss_search[display_index].show_images(st.session_state['search_results'])

    # Handle copying to clipboard
    if st.session_state['copy_to_clipboard']:
        js_button_label = json.dumps(st.session_state['copy_to_clipboard'])
        js_code = f"""
        <script>
        navigator.clipboard.writeText({js_button_label});
        </script>
        """
        st.markdown(js_code, unsafe_allow_html=True)
        # Reset copy_to_clipboard after copying
        st.session_state['copy_to_clipboard'] = None

if __name__ == "__main__":
    main()


Overwriting app.py


# Launch the app (models are loaded when the Streamlit script first runs)

In [None]:
# Serving tools: localtunnel exposes the local Streamlit port (8501) publicly; streamlit runs app.py.
!npm install -q localtunnel
!pip install -q streamlit

[K[?25h
added 22 packages, and audited 23 packages in 1s

3 packages are looking for funding
  run `npm fund` for details

2 [33m[1mmoderate[22m[39m severity vulnerabilities

To address all issues, run:
  npm audit fix

Run `npm audit` for details.


In [None]:
print(" password")
print("     |")
print("     V")
!wget -q -O - ipv4.icanhazip.com
!streamlit run app.py & npx localtunnel --port 8501 --subdomain aicretrievalsystem

 password
     |
     V
35.203.187.102

Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://35.203.187.102:8501[0m
[0m
your url is: https://aicretrievalsystem.loca.lt
Total keyframes loaded: 285492
Total videos loaded: 726
Total FPS data loaded: 726
2024-10-12 14:47:04.591080: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-12 14:47:04.851632: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-12 14:47:04.923923: E external/local_xla/xla/stream_executor/