# Данный ноутбук посвящен тестированию Бейзлайна



In [1]:
import os
import pathlib

# Project root = parent of the current working directory.
# The previous `pathlib.Path(__name__)` used a module name as if it were a
# path; resolving it merely produced "<cwd>/<name>", so two `.parent` hops
# ended up at the parent of the working directory. Spell that out directly.
# NOTE(review): assumes the notebook is started from a direct subfolder of
# the repository -- confirm.
ROOT_PATH = pathlib.Path.cwd().resolve().parent
print(ROOT_PATH)
# Seed constant; it is not referenced by the cells visible in this chunk.
RANDOM_SEED = 42

/Users/andrey/PycharmProjects/vector-search-hse


In [2]:
# Make the project root the working directory so that the relative
# 'data/...' paths opened in later cells resolve against it.
os.chdir(ROOT_PATH)
os.getcwd()  # echoed as the cell output to confirm the switch

'/Users/andrey/PycharmProjects/vector-search-hse'

In [71]:
import pickle

# Load the three pre-computed artefacts in a fixed order: frame index,
# per-frame metadata, thumbnail metadata.
# NOTE: pickle can execute arbitrary code while loading; only use it on
# trusted, locally produced files.
_loaded = []
for _pickle_path in ('data/index.pkl', 'data/metadata.pkl', 'data/thumbnails.pkl'):
    with open(_pickle_path, 'rb') as handle:
        _loaded.append(pickle.load(handle))

index, meta, thumbnails_meta = _loaded

In [93]:

from typing import Any, List, Tuple, Dict
import clip
from numpy.typing import NDArray
import numpy as np

import torch


def brute_force_query_torch(X, x, certainty_threshold):
    """Score ``x`` against every row of ``X`` and return the matches, best first.

    The score is the dot product of ``x`` with each row of ``X``; rows whose
    score is below ``certainty_threshold`` are dropped.

    Args:
        X: Matrix of candidate vectors, one per row.
            (assumes shape (N, D) -- TODO confirm with callers)
        x: Query vector. (assumes shape (1, D), so that squeezing dim 0
            leaves one score per row of ``X`` -- TODO confirm)
        certainty_threshold: Minimum score (inclusive) for a row to be kept.

    Returns:
        A tuple ``(indices, scores)``: row indices into ``X`` ordered by
        descending score, and the matching scores cast to float32.
    """
    scores = torch.matmul(x, X.t()).squeeze(0)  # one score per row of X

    # Positions in X whose score reaches the threshold.
    keep = torch.nonzero(scores >= certainty_threshold).squeeze(1)

    # Rank the surviving scores in descending order.
    ranked = torch.sort(scores[keep], descending=True)

    return keep[ranked.indices], ranked.values.float()


from pydantic import BaseModel


class VideoDescription(BaseModel):
    """Search-result entry describing one matched video.

    Instances are built by ``LocalSearchEngine.query_videos_by_tensor``.
    """

    # Video key, taken from the first element of the metadata records.
    path: str
    # Position of the video in the engine's sorted list of unique videos.
    video_id: int
    # Frame number of a representative matched frame of the video.
    frame_num: int
    # Value read from ``thumbnails_meta[path][1]``.
    fps: int


class UsedVideo(BaseModel):
    """Per-video summary of the frames that matched a query.

    Instances are built by ``LocalSearchEngine.query_videos_by_tensor``.
    """

    # Smallest frame number among the video's matched frames.
    start_pos: float
    # Largest frame number among the video's matched frames.
    end_pos: float
    # Certainty value selected for the video; it is compared against the
    # video-level threshold and used to rank the results.
    score: float


class LocalSearchEngine:
    """In-memory CLIP search over pre-computed frame embeddings.

    Queries are embedded with the CLIP ``ViT-B/32`` model and compared by
    brute force against every stored frame vector; frame hits can then be
    aggregated into per-video results.
    """

    def __init__(
        self,
        index: List[NDArray],
        meta: List[Any],
        thumbnails_meta: Dict[str, Any],
        device: str,
    ):
        """Load the CLIP model and prepare the lookup tensors.

        Args:
            index: Per-frame embedding arrays; they are stacked into a single
                tensor whose rows are searched.
            meta: One record per row of ``index``; element ``0`` is the video
                key and element ``1`` the frame number. Row indices returned
                by the search are used to index this metadata, so both must
                be aligned.
            thumbnails_meta: Mapping from video key to a record whose element
                ``1`` is reported as the video's ``fps``.
            device: Device the CLIP model is loaded on.
        """
        # Remembered so that model inputs can be moved to the model's device.
        self.device = device
        self.model, self.preprocessor = clip.load(
            'ViT-B/32',
            device=device,
        )
        self.dataset = torch.tensor(np.array(index))
        self.thumbnails_meta = thumbnails_meta
        self.meta = meta
        # Stable integer ids for videos: position in the sorted unique list.
        self.all_videos = sorted({m[0] for m in meta})
        self.video_to_int = {v: i for i, v in enumerate(self.all_videos)}
        self.int_to_video = dict(enumerate(self.all_videos))
        self.meta_video_ids = torch.tensor(
            [self.video_to_int[m[0]] for m in meta],
            device='cpu',
            dtype=torch.int32,
        )
        self.meta_frame_nums = torch.tensor(
            [m[1] for m in meta],
            device='cpu',
            dtype=torch.int32,
        )

    def encode_text(
        self,
        text: str
    ) -> torch.Tensor:
        """Return the norm-scaled CLIP embedding of ``text``."""
        # The tokens must live on the same device as the model.
        tokens = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            data = self.model.encode_text(tokens)
        data /= torch.linalg.norm(data)
        return data

    def encode_image(
        self,
        file: NDArray,
    ) -> torch.Tensor:
        """Return the norm-scaled CLIP embedding of an image.

        NOTE(review): the CLIP preprocessing transform usually expects a PIL
        image; the ``NDArray`` annotation is kept from the original code --
        confirm against callers.
        """
        # The image batch must live on the same device as the model.
        batch = self.preprocessor(file).unsqueeze(0).to(self.device)
        with torch.no_grad():
            data = self.model.encode_image(batch)
        data /= torch.linalg.norm(data)
        return data

    def query_frames(
        self,
        x,
        threshold: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Find the stored frames whose score against ``x`` reaches ``threshold``.

        Args:
            x: Query embedding, e.g. from ``encode_text`` / ``encode_image``.
            threshold: Minimum score (inclusive) for a frame to be returned.

        Returns:
            ``(idxs, certs)``: row indices into the index ordered by
            descending score, and the scores themselves on the CPU.
        """
        # Align the query with the index so the matrix product is valid even
        # when the model runs on another device or in another precision.
        x = x.to(device=self.dataset.device, dtype=self.dataset.dtype)
        idxs, certs = brute_force_query_torch(self.dataset, x, threshold)
        return idxs, certs.cpu()

    def query_videos_by_tensor(
        self,
        x: torch.Tensor,
        frame_threshold: float,
        percentile: float,
        video_threshold: float,
        max_results: int = 100,
    ) -> Tuple[List['VideoDescription'], Dict[str, 'UsedVideo']]:
        """Aggregate frame hits for ``x`` into ranked per-video results.

        Frames scoring at least ``frame_threshold`` are grouped by video. A
        video's score is the certainty of the matched frame found
        ``(1 - percentile)`` of the way down that video's matched frames,
        ranked best first. Videos whose score is below ``video_threshold``
        are dropped.

        Args:
            x: Query embedding.
            frame_threshold: Minimum score for a single frame to count.
            percentile: Fraction selecting which ranked frame provides the
                video score (``1`` picks the best frame; values outside
                ``[0, 1]`` are clamped to the best/worst matched frame).
            video_threshold: Minimum video score for a video to be returned.
            max_results: Maximum number of video descriptions returned.

        Returns:
            ``(video_descriptions, used_videos)``: the descriptions of the
            best ``max_results`` videos sorted by descending score, and a
            mapping from video key to its ``UsedVideo`` summary for every
            video that passed ``video_threshold``.
        """
        idxs, certs = self.query_frames(x, frame_threshold)
        idxs = idxs.cpu()  # the metadata tensors live on the CPU
        video_idxs = self.meta_video_ids[idxs]
        video_frames = self.meta_frame_nums[idxs]

        # A stable sort keeps each video's hits in their incoming order,
        # i.e. by descending certainty; the percentile pick and the
        # "best frame" below both rely on that.
        vals, order = torch.sort(video_idxs, stable=True)

        # Video ids are 0..len(all_videos)-1; bucketize gives, per id, the
        # half-open slice [left, right) it occupies in the sorted hits.
        targets = torch.arange(len(self.all_videos), device=vals.device)
        left = torch.bucketize(targets, vals, right=False)
        right = torch.bucketize(targets, vals, right=True)

        # Rank (within each video's slice) of the frame providing the score,
        # clamped so it can never leave the slice.
        counts = right - left
        offsets = (counts * (1 - percentile)).long()
        offsets = torch.minimum(offsets, counts - 1).clamp(min=0)

        video_descriptions = []
        used_videos = {}

        bounds = zip(
            self.all_videos, left.tolist(), right.tolist(), offsets.tolist()
        )
        for video, lo, hi, offset in bounds:
            if lo == hi:
                continue  # no matched frame for this video
            score = certs[order[lo + offset]].item()
            if score < video_threshold:
                continue
            subset = video_frames[order[lo:hi]]
            used_videos[video] = UsedVideo(
                start_pos=torch.min(subset).item(),
                end_pos=torch.max(subset).item(),
                score=score,
            )
            video_descriptions.append(
                VideoDescription(
                    path=video,
                    video_id=self.video_to_int[video],
                    # First entry of the slice = best-scoring matched frame.
                    frame_num=int(subset[0]),
                    fps=self.thumbnails_meta[video][1],
                )
            )

        video_descriptions = sorted(
            video_descriptions,
            key=lambda desc: used_videos[desc.path].score,
            reverse=True,
        )[:max_results]

        return video_descriptions, used_videos

In [94]:
# Build the search engine on the CPU from the artefacts loaded earlier.
engine = LocalSearchEngine(index, meta, thumbnails_meta, device='cpu')

In [95]:
# Embed the text query with CLIP.
x = engine.encode_text('beard')

# Positional arguments: frame_threshold=0.2, percentile=0.8, video_threshold=0.2.
video_desc, videos = engine.query_videos_by_tensor(x, 0.2, 0.8, 0.2)

In [98]:
# Display the per-video summaries returned by the query.
videos

{'data/video/FullSizeRender.MOV': UsedVideo(start_pos=44.0, end_pos=748.0, score=0.21421445906162262),
 'data/video/IMG_0703.MOV': UsedVideo(start_pos=480.0, end_pos=1890.0, score=0.21215511858463287),
 'data/video/IMG_0704.MOV': UsedVideo(start_pos=60.0, end_pos=1680.0, score=0.21190747618675232),
 'data/video/IMG_0705.MOV': UsedVideo(start_pos=30.0, end_pos=6120.0, score=0.23442897200584412),
 'data/video/IMG_0706.MOV': UsedVideo(start_pos=30.0, end_pos=630.0, score=0.26028674840927124),
 'data/video/IMG_0707.MOV': UsedVideo(start_pos=30.0, end_pos=840.0, score=0.2433510273694992),
 'data/video/IMG_0708.MOV': UsedVideo(start_pos=240.0, end_pos=2280.0, score=0.20236071944236755),
 'data/video/IMG_0710.MOV': UsedVideo(start_pos=960.0, end_pos=1410.0, score=0.20281651616096497),
 'data/video/IMG_0711.MOV': UsedVideo(start_pos=30.0, end_pos=330.0, score=0.21214191615581512),
 'data/video/IMG_0712.MOV': UsedVideo(start_pos=180.0, end_pos=2910.0, score=0.2042262703180313),
 'data/video/IMG