In [24]:
#| default_exp model

In [25]:
#| export
import clip
import torch
import tempfile
import numpy as np
import pandas as pd
from os.path import basename, splitext

from PIL import Image

In [26]:
#| export
class Model:

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super().__new__(cls, *args, **kwargs)
        return cls._instance
    
    def __init__(self) :
        self._model, self._preprocess = clip.load("ViT-B/32")
        self._model.cpu().eval()
        self._df = pd.read_pickle("../../data/runtime/index.pkl")

    def get_image_features(self, images):
        torch_images = [self._preprocess(Image.fromarray(image)) for image in images]
        batch = torch.tensor(np.stack(torch_images)).cpu()

        with torch.no_grad():
            image_features = self._model.encode_image(batch)

        image_features /= image_features.norm(dim = -1, keepdim = True)
        return image_features
    
    def query(self, text:str, k=3):
        images = [Image.open("../../data/runtime/frames/" + file_name) for file_name in self._df["preview_name"]]
        images_prep = [self._preprocess(image) for image in images]
        image_input = torch.tensor(np.stack(images_prep)).cpu()
        text_tokens = clip.tokenize(["This is " + desc for desc in [text]]).cpu()

        with torch.no_grad():
            image_features = self._model.encode_image(image_input).float()
            text_features = self._model.encode_text(text_tokens).float()

        image_features /= image_features.norm(dim = -1, keepdim = True)
        text_features /= text_features.norm(dim = -1, keepdim = True)
        similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

        indices = np.flip(np.argpartition(similarity[0], -k)[-k:])
        preview_names = self._df["preview_name"].iloc[indices]

        matches = []
        
        for i, preview_name in zip(indices, preview_names):
            matches.append(("../../data/runtime/fragments/" + splitext(basename(preview_name))[0] + ".mp4", similarity[0][i]))

        return matches

In [27]:
pd.read_pickle("../../data/runtime/index.pkl")

Unnamed: 0,video_name,preview_name,start,end
0,unicorn,unicorn_0_333.png,0,333
1,video0,video0_0_52.png,0,52
2,video0,video0_52_175.png,52,175
3,video0,video0_175_286.png,175,286
4,video0,video0_286_301.png,286,301
5,video1,video1_0_106.png,0,106
6,video1,video1_106_194.png,106,194
7,video1,video1_194_473.png,194,473
8,video2,video2_0_212.png,0,212
9,video2,video2_212_341.png,212,341


In [28]:
model = Model()

In [29]:
matches = model.query("anime girl")
matches[0][0]

'../../data/runtime/fragments/video3_98_165.mp4'

In [30]:
for match in matches:
    print(match)

('../../data/runtime/fragments/video3_98_165.mp4', np.float32(0.21989855))
('../../data/runtime/fragments/video3_0_98.mp4', np.float32(0.20733836))
('../../data/runtime/fragments/unicorn_0_333.mp4', np.float32(0.20660947))


In [31]:
matches = model.query("man playing guitar")
matches[0][0]

'../../data/runtime/fragments/video2_492_579.mp4'

In [32]:
matches = model.query("an anime with a man on a stick")
matches[0][0]

'../../data/runtime/fragments/video3_98_165.mp4'

In [33]:
df = pd.read_pickle("../../data/runtime/index.pkl")
df

Unnamed: 0,video_name,preview_name,start,end
0,unicorn,unicorn_0_333.png,0,333
1,video0,video0_0_52.png,0,52
2,video0,video0_52_175.png,52,175
3,video0,video0_175_286.png,175,286
4,video0,video0_286_301.png,286,301
5,video1,video1_0_106.png,0,106
6,video1,video1_106_194.png,106,194
7,video1,video1_194_473.png,194,473
8,video2,video2_0_212.png,0,212
9,video2,video2_212_341.png,212,341


In [34]:
#| hide
import nbdev; nbdev.nbdev_export()