In [None]:
%cd baseline
import pandas as pd
import numpy as np

In [None]:
from pathlib import Path
from pprint import pprint
from typing import Tuple

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig
from oml.const import TCfg
from oml.datasets.images import get_retrieval_images_datasets
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
from oml.lightning.pipelines.parser import (
    check_is_config_for_ddp,
    parse_logger_from_config,
    parse_ckpt_callback_from_config,
    parse_engine_params_from_config,
    parse_sampler_from_config,
    parse_scheduler_from_config,
)
from oml.metrics.embeddings import EmbeddingMetrics
from oml.registry.losses import get_criterion_by_cfg
from oml.registry.models import get_extractor_by_cfg
from oml.registry.optimizers import get_optimizer_by_cfg
from oml.registry.transforms import TRANSFORMS_REGISTRY, get_transforms_by_cfg
from oml.utils.misc import dictconfig_to_dict, set_global_seed
from torch.utils.data import DataLoader

import torch

import albumentations as albu
import cv2
from albumentations.pytorch import ToTensorV2
from oml.const import MEAN, PAD_COLOR, STD, TNormParam

In [None]:
from datetime import datetime

# Tag stored in the config under the "postfix" key.
postfix = "metric_learning"

# Snapshot of the current local time, split into its individual parts.
# NOTE(review): these parts are not referenced by any other cell visible in this notebook.
current_dateTime = datetime.now()
y, month, d, hour, minute, s = current_dateTime.timetuple()[:6]
ms = current_dateTime.microsecond

# Feature-extractor description consumed by `get_extractor_by_cfg`.
extractor_cfg = {
    "name": "resnet",
    "args": {
        "arch": "resnet50",
        "gem_p": 1.0,
        "remove_fc": True,
        "normalise_features": False,
        "weights": "checkpoints/best.ckpt",
    },
}

# Inference config: input images, data loading and trainer engine settings.
cfg: TCfg = dict(
    postfix=postfix,
    seed=42,
    image_size=640,
    accelerator="gpu",
    devices=1,
    num_workers=4,
    cache_size=0,
    test_data_dir="test/",
    bs_val=8,
    extractor=extractor_cfg,
)


In [None]:
def get_transforms(im_size: int, mean: TNormParam = MEAN, std: TNormParam = STD) -> albu.Compose:
    """
    Build the image preprocessing pipeline (default oml albu augs, but without HorizontalFlip).

    The steps, in order: resize by the longest side to ``im_size``, pad to at least
    ``im_size`` x ``im_size`` with a constant ``PAD_COLOR`` border, normalise with
    ``mean`` / ``std`` and convert the image to a tensor.

    :param im_size: Target size used for both the longest-side resize and the padding.
    :param mean: Normalisation mean passed to ``albu.Normalize``.
    :param std: Normalisation std passed to ``albu.Normalize``.
    :return: The composed albumentations pipeline.
    """
    resize_longest_side = albu.LongestMaxSize(max_size=im_size)
    pad_to_square = albu.PadIfNeeded(
        min_height=im_size,
        min_width=im_size,
        border_mode=cv2.BORDER_CONSTANT,
        value=PAD_COLOR,
    )
    normalise = albu.Normalize(mean=mean, std=std)
    to_tensor = ToTensorV2()

    return albu.Compose([resize_longest_side, pad_to_square, normalise, to_tensor])

# Формирование предсказания

In [None]:
import itertools
import json
from pathlib import Path

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from oml.const import IMAGE_EXTENSIONS
from oml.datasets.images import ImageBaseDataset
from oml.ddp.utils import get_world_size_safe, is_main_process, sync_dicts_ddp
from oml.transforms.images.utils import get_im_reader_for_transforms
from oml.utils.images.images import find_broken_images
from oml.utils.misc import dictconfig_to_dict


def extractor_prediction_pipeline(cfg: TCfg) -> dict:
    """
    Extract embeddings for every image found under ``cfg["test_data_dir"]``.

    The directory is searched recursively for files with any of the
    ``IMAGE_EXTENSIONS``, the images are checked for readability, and the
    extractor described by ``cfg["extractor"]`` is run over them in prediction mode.

    :param cfg: Pipeline config. The keys read here are ``image_size``,
        ``test_data_dir``, ``bs_val``, ``num_workers`` and ``extractor``; the whole
        config is also handed to ``parse_engine_params_from_config``.
    :return: Dict mapping each image path (as ``str``) to its embedding,
        converted to a plain Python list via ``.tolist()``.
    :raises RuntimeError: If no images are found in the directory.
    :raises ValueError: If some of the images cannot be opened.
    """
    print(cfg)

    transforms = get_transforms(cfg['image_size'])

    # Collect files for every supported extension, then flatten into one list.
    filenames = [list(Path(cfg["test_data_dir"]).glob(f"**/*.{ext}")) for ext in IMAGE_EXTENSIONS]
    filenames = list(itertools.chain(*filenames))

    if len(filenames) == 0:
        raise RuntimeError(f"There are no images in the provided directory: {cfg['test_data_dir']}")

    f_imread = get_im_reader_for_transforms(transforms)

    # Fail early instead of crashing in the middle of inference.
    print("Let's check if there are broken images:")
    broken_images = find_broken_images(filenames, f_imread=f_imread)
    if broken_images:
        raise ValueError(f"There are images that cannot be open:\n {broken_images}.")

    dataset = ImageBaseDataset(paths=filenames, transform=transforms, f_imread=f_imread)

    loader = DataLoader(
        dataset=dataset, batch_size=cfg["bs_val"], num_workers=cfg["num_workers"], shuffle=False, drop_last=False
    )

    extractor = get_extractor_by_cfg(cfg["extractor"])
    pl_model = ExtractorModule(extractor=extractor)

    trainer_engine_params = parse_engine_params_from_config(cfg)
    trainer_engine_params["use_distributed_sampler"] = True
    trainer = pl.Trainer(precision=16, **trainer_engine_params)
    predictions = trainer.predict(model=pl_model, dataloaders=loader, return_predictions=True)

    # Each prediction carries the dataset indices of its items, which are
    # mapped back to file paths so paths and embeddings stay aligned.
    paths, embeddings = [], []
    for prediction in predictions:
        paths.extend([filenames[i] for i in prediction[dataset.index_key].tolist()])
        embeddings.extend(prediction[pl_model.embeddings_key].tolist())

    # Gather the per-process results (both lists are synced with the same world size).
    world_size = get_world_size_safe()
    paths = sync_dicts_ddp({"key": list(map(str, paths))}, world_size)["key"]
    embeddings = sync_dicts_ddp({"key": embeddings}, world_size)["key"]

    return dict(zip(paths, embeddings))

In [None]:
# Run the extraction pipeline on the test images; the result maps image path -> embedding.
dict_results = extractor_prediction_pipeline(cfg)

In [None]:
import faiss
import numpy as np

# Split the path -> embedding mapping into a list of paths and a float32 matrix.
paths = list(dict_results.keys())
embeddings = np.array(list(dict_results.values()), dtype=np.float32)

# L2-normalise in place so that the inner product equals cosine similarity.
faiss.normalize_L2(embeddings)

# Inner-product FAISS index holding all the embeddings.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

# Every query retrieves all the indexed images, so k is the total number of images.
k = embeddings.shape[0]

final_result = {}
# Use each image in turn as the query. The query itself is part of the index,
# so it also appears in its own result list.
for query_index, query_path in enumerate(paths):
    query = str(Path(query_path).name)
    # The search takes a 2-D batch, hence the single-row slice.
    query_embedding = embeddings[query_index : query_index + 1]
    distances, indices = index.search(query_embedding, k)

    # Keep the file names of the neighbours in the order returned by the search.
    sorted_results = [Path(paths[hit]).name for hit in indices[0]]
    final_result[query] = sorted_results


In [None]:
# One row per query image: its file name and a fresh list with the ranked recommendations.
submission_rows = [(image_name, list(ranked_names)) for image_name, ranked_names in final_result.items()]
submission_df = pd.DataFrame(submission_rows, columns=["image_name", "recommendation"])

In [None]:
# Preview the first rows of the submission table.
submission_df.head()

In [None]:
# Save the submission table to "submission.csv" without the index column.
submission_df.to_csv("submission.csv", index=False)