In [None]:
# Mount Google Drive at /content/drive; the paths under /content/drive/MyDrive
# used by later cells are only reachable after this call.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd /content
!pip install -U sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer

# Smoke test: encode two sample sentences with the Vietnamese SBERT checkpoint.
sentences = ["Cô giáo đang ăn kem", "Chị gái đang thử món thịt dê"]

model = SentenceTransformer('keepitreal/vietnamese-sbert')
embeddings = model.encode(sentences)

# Show only the shape: printing the whole embedding matrix floods the output
# without telling the reader anything useful.
print(embeddings.shape)

In [None]:
from IPython.display import clear_output
!pip install faiss-gpu
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install translate
!pip install googletrans==3.1.0a0
!pip install langdetect
clear_output()

In [None]:
import numpy as np
import faiss
import glob
import json
import matplotlib.pyplot as plt
import os
import math
import clip
import torch
import pandas as pd
import re
from langdetect import detect

from tqdm import tqdm

## Translation

In [None]:
import googletrans
import translate

class Translation:
    """Thin wrapper around the ``googletrans`` and ``translate`` libraries.

    The backend is selected once, at construction time, from ``mode`` and is
    then used consistently by every call.

    :param from_lang: Source language code. It is forwarded to the
        ``translate`` backend only; the googletrans call does not pass a
        source language.
    :param to_lang: Target language code.
    :param mode: Backend selector. ``'google'`` / ``'googletrans'`` select
        googletrans and ``'translate'`` selects the translate library. For
        backward compatibility any non-empty substring of those names is
        accepted, with googletrans taking precedence when both match.
    :raises ValueError: If ``mode`` does not select a known backend.
    """

    def __init__(self, from_lang='vi', to_lang='en', mode='google'):
        self.__mode = mode
        self.__from_lang = from_lang
        self.__to_lang = to_lang

        # Resolve the backend exactly once. Previously __init__ and __call__
        # each ran their own substring test against a different name, so a
        # mode such as 'trans' built a googletrans translator but was then
        # called through the 'translate' code path.
        usable = isinstance(mode, str) and bool(mode)
        if usable and mode in 'googletrans':
            self.__backend = 'googletrans'
            self.translator = googletrans.Translator()
        elif usable and mode in 'translate':
            self.__backend = 'translate'
            self.translator = translate.Translator(from_lang=from_lang, to_lang=to_lang)
        else:
            # Fail here rather than with an AttributeError on the first call.
            raise ValueError(
                f"Unsupported translation mode {mode!r}; "
                "expected 'google'/'googletrans' or 'translate'"
            )

    def preprocessing(self, text):
        """
        Normalise the text before it is sent to the translator.
        :param text: The text to be processed
        :return: The text converted to lowercase.
        """
        return text.lower()

    def __call__(self, text):
        """
        Preprocess ``text`` and translate it with the selected backend.
        :param text: The text to be translated
        :return: The translated text as a string.
        """
        text = self.preprocessing(text)

        if self.__backend == 'translate':
            # The translate library was configured with both languages in
            # __init__ and returns the translated string directly.
            return self.translator.translate(text)

        # googletrans returns a result object; the string is in `.text`.
        return self.translator.translate(text, dest=self.__to_lang).text

## Faiss

In [None]:
class MyFaiss:
  """Keyframe retrieval with a Faiss index and CLIP text features.

  The object holds a Faiss index read from ``bin_file``, a mapping from index
  id to keyframe info read from ``json_path``, a ``Translation`` instance and
  the CLIP ``ViT-B/16`` model used to embed text queries.

  Each keyframe info is expected to be a dict with an ``'image_path'`` entry;
  ``write_csv`` additionally reads ``'list_shot_id'``.
  """

  def __init__(self, root_database: str, bin_file: str, json_path: str):
    """Load the index, the id -> info mapping, the translator and CLIP.

    :param root_database: Accepted for interface compatibility; not used here.
    :param bin_file: Path of the serialized Faiss index.
    :param json_path: Path of the JSON file mapping index ids to keyframe info.
    """
    self.index = self.load_bin_file(bin_file)
    self.id2img_fps = self.load_json_file(json_path)

    self.translater = Translation()

    self.__device = "cuda" if torch.cuda.is_available() else "cpu"
    # clip.load also returns an image preprocessing transform; only the model
    # is needed for text queries.
    self.model, _ = clip.load("ViT-B/16", device=self.__device)

  def load_json_file(self, json_path: str):
    """Read a JSON object and return it as a dict with integer keys.

    :param json_path: Path of the JSON file.
    :return: ``{int(key): value}`` for every entry of the JSON object.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
      js = json.load(f)

    return {int(k): v for k, v in js.items()}

  def load_bin_file(self, bin_file: str):
    """Read a Faiss index from ``bin_file`` and return it."""
    return faiss.read_index(bin_file)

  def show_images(self, image_paths):
    """Plot the given images on a roughly square grid.

    Each subplot is titled with the last three components of its path.

    :param image_paths: Sequence of image file paths.
    """
    n_images = len(image_paths)
    if n_images == 0:
      # Nothing to draw; also avoids a division by zero in the grid maths.
      print("No images to show")
      return

    columns = max(1, int(math.sqrt(n_images)))
    rows = int(np.ceil(n_images / columns))

    fig = plt.figure(figsize=(15, 10))

    # Iterate over the images, not over the grid cells: the grid can hold more
    # cells than there are images (e.g. 10 images -> 3 x 4 grid), and indexing
    # by cell used to raise IndexError.
    for position, image_path in enumerate(image_paths, start=1):
      img = plt.imread(image_path)
      ax = fig.add_subplot(rows, columns, position)
      ax.set_title('/'.join(image_path.split('/')[-3:]))

      ax.imshow(img)
      ax.axis("off")

    plt.show()

  def _lookup_infos(self, idx_image):
    """Map a flat array of index ids to keyframe infos and image paths."""
    infos_query = [self.id2img_fps.get(int(idx)) for idx in idx_image]
    image_paths = [info['image_path'] for info in infos_query]
    return infos_query, image_paths

  def image_search(self, id_query, k):
    """Find the ``k`` nearest neighbours of an already indexed vector.

    :param id_query: Index id whose stored vector is used as the query.
    :param k: Number of neighbours to return.
    :return: ``(scores, idx_image, infos_query, image_paths)`` where ``scores``
        is what ``index.search`` returned, ``idx_image`` the flattened ids,
        and the last two the matching infos and their ``'image_path'`` values.
    """
    query_feats = self.index.reconstruct(id_query).reshape(1, -1)

    scores, idx_image = self.index.search(query_feats, k=k)
    idx_image = idx_image.flatten()

    infos_query, image_paths = self._lookup_infos(idx_image)

    return scores, idx_image, infos_query, image_paths

  def text_search(self, text, k):
    """Find the ``k`` keyframes closest to a text query.

    Text that ``langdetect`` classifies as Vietnamese (``'vi'``) is first
    passed through ``self.translater``.

    :param text: The query string.
    :param k: Number of neighbours to return.
    :return: ``(scores, idx_image, infos_query, image_paths)``, same layout as
        ``image_search``.
    """
    if detect(text) == 'vi':
      text = self.translater(text)

    ###### TEXT FEATURES EXTRACTING ######
    tokens = clip.tokenize([text]).to(self.__device)
    # Inference only: no autograd graph is needed for the query embedding.
    with torch.no_grad():
      text_features = self.model.encode_text(tokens)
    text_features = text_features.cpu().detach().numpy().astype(np.float32)

    ###### SEARCHING #####
    scores, idx_image = self.index.search(text_features, k=k)
    idx_image = idx_image.flatten()

    ###### GET INFOS KEYFRAMES_ID ######
    infos_query, image_paths = self._lookup_infos(idx_image)

    return scores, idx_image, infos_query, image_paths

  def write_csv(self, infos_query, scores, des_path, max_lines=100):
    """Write (video, frame, score) rows for a search result to ``des_path``.

    Every frame id in an info's ``'list_shot_id'`` produces one row carrying
    that result's score. If ``des_path`` already exists its rows are merged
    in; for a duplicated (video, frame) pair the new row wins. Rows are sorted
    by score in ascending order and written without header or index.

    :param infos_query: Keyframe infos, aligned with ``scores``.
    :param scores: Scores as returned by the search methods.
    :param des_path: Destination CSV path.
    :param max_lines: Maximum number of rows allowed in the file; when the
        merged result is larger nothing is written.
    """
    columns = ["video_names", "frame_ids", "scores"]

    ### GET INFOS SUBMIT ###
    records = []
    for score, info in zip(np.asarray(scores).flatten().tolist(), infos_query):
      video_name = info['image_path'].split('/')[-2] + '.mp4'

      for id_frame in info['list_shot_id']:
        records.append((video_name, id_frame, score))

    ### FORMAT DATAFRAME ###
    df = pd.DataFrame(records, columns=columns)
    # Compare frame ids as text so that fresh rows and rows read back from the
    # CSV de-duplicate against each other whatever type the ids had in memory.
    df["frame_ids"] = df["frame_ids"].astype(str)
    df["scores"] = df["scores"].astype(float)

    ### Merge csv exist file to faiss search information ###
    # An empty file would make read_csv raise, so it is treated as "no rows".
    if os.path.exists(des_path) and os.path.getsize(des_path) > 0:
      df_exist = pd.read_csv(
          des_path,
          header=None,
          names=columns,
          dtype={"video_names": str, "frame_ids": str},
      )

      # The result of the merge must be kept: the former df.append(...) call
      # discarded it (and DataFrame.append no longer exists in pandas 2).
      df = pd.concat([df, df_exist], ignore_index=True)

    ### Return DataFrame with duplicate rows removed ###
    df = df.drop_duplicates(subset=["video_names", "frame_ids"])

    ### Sort By Score ###
    # Ascending, as the original call requested; the sorted frame has to be
    # assigned back to take effect.
    # NOTE(review): if the index scores are similarities rather than
    # distances, descending order may be wanted - confirm.
    df = df.sort_values(by="scores", kind="stable")

    ### Specifies up to `max_lines` lines ###
    if len(df) <= max_lines:
      df.to_csv(des_path, header=False, index=False)
      print(f"Save submit file to {des_path}")
    else:
      print('Exceed the allowed number of lines')

## SBERT

In [None]:
class BERTSearch(MyFaiss):
  """Sentence-embedding search built on the MyFaiss loading helpers.

  In ``'search'`` mode the object loads the ``keepitreal/vietnamese-sbert``
  model, a Faiss index and the id -> info JSON. In any other mode nothing is
  loaded and the object is only useful for ``create_files``.

  ``MyFaiss.__init__`` is not called, so the CLIP model and translator that it
  sets up are not available on this object.
  """

  def __init__(self, dict_bert_search='./keyframes_id_bert.json', bin_file='./faiss_bert.bin', mode='write'):
    """
    :param dict_bert_search: JSON file mapping index ids to info dicts.
    :param bin_file: Serialized Faiss index.
    :param mode: ``'search'`` loads model, index and mapping (any non-empty
        substring of ``'search'`` is accepted for backward compatibility);
        every other value loads nothing.
    """
    # Defined in every mode so that bert_search can report a clear error
    # instead of failing on a missing attribute.
    self.model = None
    self.index = None
    self.id2img_fps = {}

    if isinstance(mode, str) and mode and mode in 'search':
      self.model = SentenceTransformer('keepitreal/vietnamese-sbert')

      self.index = self.load_bin_file(bin_file)
      self.id2img_fps = self.load_json_file(dict_bert_search)

  def create_files(self, des_json: str, dict_support_model: str, des_bin: str,
                   npy_glob: str = '/content/drive/MyDrive/ASR_Vietnamese_T/Embed*/*/*.npy',
                   dim: int = 768):
    """Build a flat L2 Faiss index from ``.npy`` embeddings and save it.

    An embedding file is matched to a keyframe entry when its file name
    (without ``.npy``) equals ``<parent dir of image_path>_<first shot id>_
    <last shot id>``. Each keyframe entry is used for at most one file. The
    matched infos are stored in ``self.infos`` and written to ``des_json``
    keyed by their position, which is also their id in the index.

    :param des_json: Output path of the id -> info JSON.
    :param dict_support_model: JSON file with the keyframe entries; each entry
        needs ``'image_path'`` and ``'list_shot_id'``.
    :param des_bin: Output path of the Faiss index.
    :param npy_glob: Glob pattern of the embedding files.
        NOTE(review): presumably ASR transcript embeddings, judging by the
        default path - confirm.
    :param dim: Dimensionality of the index; embeddings that do not flatten
        to this size are skipped.
    """
    from collections import defaultdict, deque

    self.infos = []

    id2img_fps = self.load_json_file(dict_support_model)
    npy_paths = sorted(glob.glob(npy_glob))

    # Group the keyframe entries by the name their embedding file would have.
    # A queue per name keeps the original "first unused entry wins" behaviour
    # while replacing the former scan of the whole dict for every file.
    pending = defaultdict(deque)
    for values in id2img_fps.values():
      list_shot_id = values['list_shot_id']
      if not list_shot_id:
        # No shot ids means no name to match against.
        continue
      start, end = int(list_shot_id[0]), int(list_shot_id[-1])
      check_path = values['image_path'].split('/')[-2] + f"_{start}_{end}"
      pending[check_path].append(values)

    index = faiss.IndexFlatL2(dim)

    for npy_path in tqdm(npy_paths):
      need_path = npy_path.split('/')[-1].replace('.npy', '')

      candidates = pending.get(need_path)
      if not candidates:
        continue

      # Load and validate BEFORE recording anything, so the JSON and the index
      # can never get out of step. Previously a failed load still appended the
      # info and then indexed an undefined or stale vector.
      try:
        feat = np.load(npy_path)
      except (OSError, ValueError, EOFError) as err:
        print(f"Skipping unreadable embedding {npy_path}: {err}")
        continue

      feat = np.asarray(feat, dtype=np.float32).reshape(1, -1)
      if feat.shape[1] != dim:
        print(f"Skipping {npy_path}: expected {dim} values, got {feat.shape[1]}")
        continue

      values = candidates.popleft()
      image_path = values['image_path']

      self.infos.append({
          "video_path": '/'.join(image_path.split('/')[:-1]),
          # bert_search reads 'image_path', so it is stored alongside.
          "image_path": image_path,
          "list_shot_id": values['list_shot_id'],
      })

      #### ADD FAISS ####
      index.add(feat)

    count = len(self.infos)
    results = dict(enumerate(self.infos))

    ##### SAVE JSON FILE #####
    with open(des_json, 'w') as f:
      json.dump(results, f)

    ##### SAVE BIN FILE #####
    faiss.write_index(index, des_bin)

    ##### Print Infos Save #####
    print(f'Saved {des_json}')
    print(f"Number of Index: {count}")
    print(f'Saved {des_bin}')

  def bert_search(self, text, k):
    """Find the ``k`` index entries closest to ``text``.

    :param text: The query string, embedded with the loaded SBERT model.
    :param k: Number of neighbours to return.
    :return: ``(scores, idx_image, infos_query, image_paths)``. For info
        dicts without ``'image_path'`` (files written before that key was
        stored) the ``'video_path'`` directory is returned in its place.
    :raises RuntimeError: If the object was not created in search mode.
    """
    if getattr(self, 'model', None) is None or getattr(self, 'index', None) is None:
      raise RuntimeError("BERTSearch must be created with mode='search' before calling bert_search")

    ###### TEXT FEATURES EXTRACTING ######
    # Use the model owned by this object, not a notebook-level global.
    text_features = np.asarray(self.model.encode([text]), dtype=np.float32)

    ###### SEARCHING #####
    scores, idx_image = self.index.search(text_features, k=k)
    idx_image = idx_image.flatten()

    infos_query = [self.id2img_fps.get(int(idx)) for idx in idx_image]
    image_paths = [
        info['image_path'] if 'image_path' in info else info['video_path']
        for info in infos_query
    ]

    return scores, idx_image, infos_query, image_paths

In [None]:
# Build the SBERT Faiss index and its id -> info JSON on Google Drive.
create_file = BERTSearch()
create_file.create_files(
    des_json='/content/drive/MyDrive/keyframes_id_bert.json',
    dict_support_model='/content/drive/MyDrive/Video_Retrieval/dict_support_model/dict_support_model_batch.json',
    des_bin='/content/drive/MyDrive/faiss_bert.bin',
)

In [None]:
# Load the saved index in search mode and run one sample query.
mybert = BERTSearch(
    dict_bert_search='/content/drive/MyDrive/keyframes_id_bert.json',
    bin_file='/content/drive/MyDrive/faiss_bert.bin',
    mode='search',
)

text = 'Lũ lụt'

# Top-10 neighbours of the query, then plot the returned paths.
scores, idx_image, infos_query, image_paths = mybert.bert_search(text, k=10)
mybert.show_images(image_paths)