## Подключение библиотек

In [None]:
import os
import json
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from annoy import AnnoyIndex

from annoy import AnnoyIndex
import json

from tqdm import tqdm

import csv

from typing import Callable, List

from collections import Counter

## Загружаем модель

In [None]:
# Pretrained CLIP checkpoint; the preprocessor and the model weights are
# loaded from the same checkpoint name.
clip_checkpoint = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(clip_checkpoint)
model = CLIPModel.from_pretrained(clip_checkpoint)

## Загружаем векторную базу данных

In [3]:
# Location of the prebuilt Annoy index and the length of the vectors it stores.
path = '/home/jupyter/datasphere/project/annoy indexes/image_embeddings_256.ann'
embedding_dim = 512

# Metadata is looked up later by the stringified Annoy item id and is expected
# to provide 'directory' and 'filename' for each item.
# The encoding is given explicitly so decoding does not depend on the locale.
# NOTE(review): the index and the metadata live in different directories
# ("annoy indexes" vs "annoy_full") - confirm they were built from the same run.
with open("/home/jupyter/datasphere/project/annoy_full/annoy_metadata.json", 'r', encoding='utf-8') as f:
    metadata = json.load(f)

# The index uses the angular metric; load() maps the index file from disk.
index = AnnoyIndex(embedding_dim, 'angular')
index.load(path)

True

## Оценка тестовых объектов

In [4]:
# Root folder that is walked recursively to find the test images to score.
root_directory = "/home/jupyter/datasphere/project/TEST_DATASET"

In [5]:
# File names that the scoring loop skips (matched by bare file name).
broken_files = [
    '664e035d-05b3-4766-922c-432dcad827b2.jpg',
    '1dddee44-7ae9-4a95-8b7d-b0918c62064c.jpg',
    'c2232a78-6d52-4e1b-9dc4-dd38d457217c.jpg',
]

In [6]:
k = 10  # number of nearest neighbours / recommendations per test image

# Parallel lists: item[i] is a test file name and predicted[i] is the list of
# recommended file names for it.
item = []
predicted = []

# Loop invariants, computed once: the Annoy search budget and the number of
# vectors stored in the index.
search_k = k * 256 * 10
n_indexed = index.get_n_items()

# Walk the test directory tree and score every file found.
for dirpath, _, filenames in os.walk(root_directory):
    for filename in tqdm(filenames, desc="Processing files"):

        # Skip files listed as broken.
        if filename in broken_files:
            continue

        item.append(filename)

        # Load the image. The context manager closes the underlying file
        # handle as soon as the RGB copy exists, so handles do not pile up
        # over the course of the run.
        file_path = os.path.join(dirpath, filename)
        with Image.open(file_path) as raw_image:
            image = raw_image.convert("RGB")

        # Embed the query image with CLIP (inference only, no gradients).
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_embedding = model.get_image_features(**inputs)

        # Fetch the k nearest indexed images.
        image_embedding = image_embedding.cpu().numpy().flatten()
        indices = index.get_nns_by_vector(image_embedding, k, search_k=search_k)

        # Vote: take the class ('directory') that occurs most often among the
        # k nearest neighbours. Direct indexing makes a missing metadata entry
        # fail with a KeyError naming the id instead of an opaque TypeError.
        candidat_list = [metadata[str(nn_idx)]['directory'] for nn_idx in indices]
        most_common_item = Counter(candidat_list).most_common(1)[0][0]

        # Final prediction: scan neighbours in the order Annoy returns them
        # and keep the first k that belong to the winning class.
        # NOTE(review): search_k bounds how much of the index Annoy inspects,
        # so fewer than n_indexed neighbours (and possibly fewer than k
        # matches) may come back - confirm against the Annoy docs.
        pred_images = []
        indices_sort = index.get_nns_by_vector(image_embedding, n_indexed, search_k=search_k)
        for nn_idx in indices_sort:
            nn_metadata_entry = metadata[str(nn_idx)]
            if nn_metadata_entry['directory'] == most_common_item:
                pred_images.append(nn_metadata_entry['filename'])
                if len(pred_images) == k:
                    break

        predicted.append(pred_images)

Processing files: 100%|██████████| 423/423 [01:44<00:00,  4.05it/s]


In [7]:
# Write the submission file: a header row, then one row per test image whose
# 'recs' field holds the recommended file names joined by commas.
submission_path = '/home/jupyter/datasphere/project/submission/submission.csv'
with open(submission_path, mode='w', newline='', encoding='utf-8') as csv_file:
    submission = csv.writer(csv_file)
    submission.writerow(['image', 'recs'])
    submission.writerows(
        [img, ','.join(recs)] for img, recs in zip(item, predicted)
    )

In [8]:
# Sanity check: number of test files that received a prediction
# (one entry is appended to `predicted` per non-skipped file).
len(predicted)

420