In [227]:
import cv2
import torch

from pymilvus import MilvusClient

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import VisionDataset
from torchvision.models import resnet50, ResNet50_Weights
from videohash import VideoHash

# Video preprocessing

In [190]:
success = 0
count = 0

vidcap = cv2.VideoCapture('/Users/wallander/Desktop/v0.mp4')

frames = []
while success:
    success, image = vidcap.read()
    if (count % 24) == 0 and success:
        frames.append(image)
    count += 1
    
frames = np.stack(frames)

print(len(frames))

58


In [28]:
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

In [202]:
class BasicVisionDataset(VisionDataset):
    def __init__(self, images, transform=None, target_transform=None):
        if isinstance(images, np.ndarray):
            transform.transforms.insert(0, transforms.ToPILImage())
        super(BasicVisionDataset, self).__init__(root=None, transform=transform, target_transform=target_transform)
        self.images = images

    def __getitem__(self, index):
        return torch.unsqueeze(self.transform(self.images[index]), 0)

    def __len__(self):
        return len(self.targets)

In [220]:
class FEDataset(Dataset):

    def __init__(self, images, root_dir=None, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        if isinstance(images, np.ndarray) and transform is not None:
            transform.transforms.insert(0, transforms.ToPILImage())
            
        self.images = images
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.transform(self.images[idx])

In [221]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
])

dataset = FEDataset(frames, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=16)

In [222]:
dataset[1].shape

torch.Size([3, 224, 224])

In [245]:
output = torch.vstack([model(batch) for batch in dataloader])
output.shape

torch.Size([58, 1000])

# Milvus

In [257]:
output = output.detach().numpy()

In [258]:
client = MilvusClient("./milvus_demo.db")

In [259]:
if client.has_collection(collection_name="demo_collection"):
    client.drop_collection(collection_name="demo_collection")
client.create_collection(
    collection_name="demo_collection",
    dimension=1000,
)

In [260]:
data = [ {"id": i, "vector": output[i], "text": str(i), "subject": "history"} for i in range(len(output)) ]
res = client.insert(
    collection_name="demo_collection",
    data=data
)

In [261]:
res

{'insert_count': 58, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57], 'cost': 0}

In [273]:
res = client.search(
    collection_name="demo_collection",
    data=[output[45]], # replace with your query vector
    limit=3,
    output_fields=["id", "text"]
)

for i in res[0]:
    print(f'distance: {i["distance"]}')
    print(f'id: {i["id"]}')

distance: 1.0000003576278687
id: 45
distance: 0.7165123224258423
id: 42
distance: 0.6817314028739929
id: 49
