      
# Metric Learning and Fast Vector Search

## Introduction

This notebook mainly demonstrates metric learning and vector search with the qdrant library.
The task is kinship retrieval: given a photo of a person, find photos of their relatives in a database. It uses the Families in the Wild (FIW) dataset, which has already been prepared for the models used in this assignment (the face detection and alignment steps are done).

## Goals

 * Using a pretrained insightface encoder based on resnet18, build a database of face vectors for FIW.
 * Use qdrant to compute search metrics (Precision@K, False Match Rate @ False Non-Match Rate).
 * Try qdrant approximate-search configurations and compare their search metrics with the previous ones.
 * Improve the results on our task by fine-tuning the model with metric learning methods.
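
To make the two metrics named in the goals concrete, here is a minimal plain-Python sketch on toy data. The helper names `precision_at_k` and `fmr_at_fnmr`, the family labels, and all similarity values are illustrative assumptions; the notebook's own implementation works on Qdrant search results instead.

```python
def precision_at_k(retrieved_labels, query_label, k):
    """Fraction of the top-k retrieved items that share the query's label."""
    top_k = retrieved_labels[:k]
    return sum(label == query_label for label in top_k) / k


def fmr_at_fnmr(genuine_sims, impostor_sims, target_fnmr):
    """FMR at the similarity threshold that rejects about target_fnmr of genuine pairs."""
    genuine = sorted(genuine_sims)
    # A pair is rejected when its similarity falls below the threshold.
    threshold = genuine[int(len(genuine) * target_fnmr)]
    fnmr = sum(s < threshold for s in genuine) / len(genuine)
    fmr = sum(s >= threshold for s in impostor_sims) / len(impostor_sims)
    return fmr, fnmr, threshold


# Toy ranking for one query from a hypothetical family "F0001".
ranking = ["F0001", "F0002", "F0001", "F0003", "F0001"]
p_at_1 = precision_at_k(ranking, "F0001", 1)
p_at_5 = precision_at_k(ranking, "F0001", 5)

# Toy similarities for same-family (genuine) and different-family (impostor) pairs.
genuine = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.85, 0.75, 0.65, 0.55]
impostor = [0.1, 0.2, 0.3, 0.45, 0.6, 0.05, 0.15, 0.25, 0.35, 0.5]
fmr, fnmr, threshold = fmr_at_fnmr(genuine, impostor, target_fnmr=0.1)
print(p_at_1, p_at_5, fmr, fnmr, threshold)
```

In this toy case the threshold lands at 0.5: one of ten genuine pairs is rejected (FNMR 0.1) and two of ten impostor pairs are accepted (FMR 0.2). A lower FMR at the same FNMR means a better embedding.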

## Instructions

 * Throughout the code you will find `#TODO: ...` sections, each followed by something you are asked to do.
 * The fine-tuning parameters are collected near the beginning of the code and can be changed.
 * The grading criteria are also written alongside the code.
 * Before submitting, it is best to rerun the whole notebook from scratch so that it is clear how it was executed.

---

    

# 1. Installing Libraries

In [None]:
!pip install torch torchvision torchaudio timm qdrant-client pytorch-metric-learning scikit-learn pandas tqdm Pillow ipywidgets wldhx.yadisk-direct --quiet
print("Libraries installed successfully!")

Libraries installed successfully!


# 2. Downloading and Preparing the Dataset and the Pretrained Backbone

In [None]:
!cd /content/ && mkdir fiw_downloads -p
!curl -L $(yadisk-direct https://disk.yandex.kz/d/7LRCtOL1RUVbdg) -o /content/fiw_downloads/train.tar.gz
!curl -L $(yadisk-direct https://disk.yandex.kz/d/I-hA1X76G3RAAQ) -o /content/fiw_downloads/val.tar.gz
!cd /content/fiw_downloads/ && tar -xzf train.tar.gz && tar -xzf val.tar.gz
!cd /content/ && cp -r /content/fiw_downloads/train-faces-det /content/ && cp -r /content/fiw_downloads/val-faces-det /content/
!ls -lah /content/
!curl -L $(yadisk-direct https://disk.yandex.kz/d/iJ4UtNPJ9iH9EQ) -o /content/backbone.pth

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:--  0:00:01 --:--:--     0
100 69.6M  100 69.6M    0     0  9965k      0  0:00:07  0:00:07 --:--:-- 14.8M
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 22.2M  100 22.2M    0     0  5253k      0  0:00:04  0:00:04 --:--:-- 6718k
total 92M
drwxr-xr-x   1 root root 4.0K Jun  7 01:19 .
drwxr-xr-x   1 root root 4.0K Jun  7 01:14 ..
-rw-r--r--   1 root root  92M Jun  7 01:19 backbone.pth
drwxr-xr-x   4 root root 4.0K Jun  5 13:38 .config
drwxr-xr-x   4 root root 4.0K Jun  7 01:18 fiw_downloads
drwxr-xr-x   1 root root 4.0K Jun  5 13:38 sample_data
drwxr-xr-x 572 root root  20K Jun  7 01:18 train

# 3. Importing Libraries

In [None]:
import os
import glob
import random
from collections import defaultdict
from PIL import Image
import numpy as np
import pandas as pd
import time # For timing searches
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision import transforms
import timm

from qdrant_client import QdrantClient, models
from qdrant_client.http.models import PointStruct, Distance, VectorParams

from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.distances import CosineSimilarity

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score
from tqdm.notebook import tqdm # Use tqdm.notebook for better Colab integration

# For reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print("Libraries imported.")

Libraries imported.


# 4. Configuration Parameters (Editable Forms)

In [None]:
#@markdown **Dataset paths are set automatically based on Cell 2.**
#@markdown Subset parameters apply to the corresponding datasets (train for fine-tuning, val for the DB/queries).
FIW_TRAIN_DATA_PATH = "/content/train-faces-det"  #@param {type:"string"}
FIW_VAL_DATA_PATH = "/content/val-faces-det"    #@param {type:"string"}
MIN_IMAGES_PER_PERSON_GLOBAL = 2
#@markdown ---
#@markdown **Qdrant collection names:**
COLLECTION_NAME_PRETRAINED = "fiw_pretrained_val"  #@param {type:"string"}
COLLECTION_NAME_PRETRAINED_OPTIMIZED = "fiw_pretrained_val_optimized"  #@param {type:"string"}
COLLECTION_NAME_FINETUNED = "fiw_finetuned_val"  #@param {type:"string"}
#@markdown ---
#@markdown **Fine-tuning parameters:**
NUM_EPOCHS_FINETUNE = 2  #@param {type:"integer"}
LEARNING_RATE_MODEL = 1e-5 #@param {type:"number"}
LEARNING_RATE_LOSS = 1e-3 #@param {type:"number"}
#@markdown ---
#@markdown **Batch parameters:**
BATCH_SIZE = 64  #@param {type: "integer"}
MIN_IMAGES_PER_CLASS_IN_BATCH = 4  #@param {type: "integer"}  # Number of images of the same class inside the batch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
print(f"FIW training data path: {FIW_TRAIN_DATA_PATH}")
print(f"FIW validation data path: {FIW_VAL_DATA_PATH}")

# Parameter validation / adjustment helper
def _validate_subset_param(param_val):
    return None if param_val <= 0 else param_val

for path in [FIW_TRAIN_DATA_PATH, FIW_VAL_DATA_PATH]:
    if not os.path.exists(path):
        print(f"WARNING: FIW Path '{path}' does not exist. Ensure Cell 2 (Download and Prepare Dataset) ran successfully.")
    else:
        num_families = len(glob.glob(os.path.join(path, 'F*')))
        print(f"Found {num_families} family folders in {path}.")

Using device: cuda
FIW training data path: /content/train-faces-det
FIW validation data path: /content/val-faces-det
Found 570 family folders in /content/train-faces-det.
Found 192 family folders in /content/val-faces-det.
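
The two batch parameters above jointly fix how many distinct classes can appear in a batch: with `BATCH_SIZE = 64` and `MIN_IMAGES_PER_CLASS_IN_BATCH = 4`, a batch holds 64 / 4 = 16 classes. The plain-Python sketch below illustrates this kind of "m images per class" batch construction. `make_m_per_class_batch` is a hypothetical helper and the toy labels are made up; it is only meant to show the arithmetic, not how the notebook's training code actually samples batches.

```python
import random
from collections import Counter, defaultdict


def make_m_per_class_batch(labels, batch_size, m, rng):
    """Pick batch_size // m classes and m sample indices from each class."""
    indices_by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        indices_by_class[label].append(idx)
    classes = rng.sample(sorted(indices_by_class), batch_size // m)
    batch = []
    for cls in classes:
        pool = indices_by_class[cls]
        if len(pool) >= m:
            batch.extend(rng.sample(pool, m))     # enough images: no repeats
        else:
            batch.extend(rng.choices(pool, k=m))  # small class: sample with replacement
    return batch


rng = random.Random(42)
# Toy dataset (assumption): 40 classes with 6 images each.
labels = [f"F{c:04d}" for c in range(40) for _ in range(6)]
batch = make_m_per_class_batch(labels, batch_size=64, m=4, rng=rng)
per_class = Counter(labels[i] for i in batch)
print(len(batch), len(per_class), set(per_class.values()))
```

Guaranteeing several images per class in every batch matters for pair- and triplet-based losses, which need positives inside the batch to form training signal.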


# 5. IResNet Model (from InsightFace)

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

using_ckpt = False

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out

    def forward(self, x):
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        x = F.normalize(x, 2, 1)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError("pretrained=True is not supported here; load weights with model.load_state_dict() instead.")
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)

def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)
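
The constant `fc_scale = 7 * 7` in `IResNet` follows from the input resolution: the stem convolution keeps the spatial size, and each of the four stages halves it once. The short sketch below redoes that arithmetic for a 112×112 input (the size used for the dummy forward pass later in this notebook). `conv_out_size` is a hypothetical helper implementing the standard convolution output-size formula.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    """Output spatial size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


size = 112                            # input height/width
size = conv_out_size(size, stride=1)  # stem conv1: 3x3, stride 1, padding 1
sizes = [size]
for _ in range(4):                    # layer1..layer4: one stride-2 3x3 conv each
    size = conv_out_size(size, stride=2)
    sizes.append(size)

fc_in_features = 512 * size * size    # channels of layer4 times the final feature map area
print(sizes, fc_in_features)
```

The final 7×7 map with 512 channels gives 25088 inputs to `fc`, which is why the network as defined here expects 112×112 face crops; a different input size would not match the fixed `nn.Linear` layer.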


# 6. Helper Functions

In [None]:
# Additional loss functions, more extended variants than those in pytorch-metric-learning

import math
import typing as t


class NormedLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        kwargs['bias'] = False
        super().__init__(*args, **kwargs)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        normed_weights = F.normalize(self.weight)
        return F.linear(input_tensor, normed_weights)


class EmbeddingNormalization(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return F.normalize(input_tensor)


class CosFace(nn.Module):
    def __init__(self, s=64.0, m=0.40, label_smoothing: float = 0.0, target_loss: t.Any = F.cross_entropy):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m
        self.label_smoothing = label_smoothing
        self.target_loss = target_loss

    def forward(self, cosine, label):
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
        m_hot.scatter_(1, label[index, None], self.m)
        cosine[index] -= m_hot
        ret = cosine * self.s
        return self.target_loss(ret, label, label_smoothing=self.label_smoothing)


class ArcFace(nn.Module):
    def __init__(self, s=64.0, m=0.5, label_smoothing: float = 0.0, target_loss: t.Any = F.cross_entropy):
        super(ArcFace, self).__init__()
        self.s = s
        self.m = m
        self.label_smoothing = label_smoothing
        self.target_loss = target_loss

    def forward(self, cosine: torch.Tensor, label):
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
        m_hot.scatter_(1, label[index, None], self.m)
        cosine.acos_()
        cosine[index] += m_hot
        cosine.cos_().mul_(self.s)
        return self.target_loss(cosine, label, label_smoothing=self.label_smoothing)


# based on https://github.com/niliusha123/Margin-based-Softmax/blob/main/sphereface2.py
class SphereProduct2(nn.Module):
    def __init__(self, lamb=0.7, r=30, m=0.4, t=3, b=0.25, label_smoothing: float = 0.0, reduction: str = 'mean'):
        super(SphereProduct2, self).__init__()
        self.lamb = lamb
        self.r = r
        self.m = m
        self.t = t
        self.b = b
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def similarity_correction(self, cosine):
        return 2 * ((cosine + 1) / 2) ** self.t - 1

    def forward(self, cosine: torch.Tensor, label):
        num_classes = cosine.shape[1]
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cos_theta = cosine.clamp(-1, 1)
        similarity = self.similarity_correction(cos_theta)
        cos_m_theta_p = self.r * (similarity - self.m) + self.b
        cos_m_theta_n = self.r * (similarity + self.m) + self.b
        cos_p_theta = (self.lamb / self.r) * torch.log(1 + torch.exp(-cos_m_theta_p))

        cos_n_theta = ((1 - self.lamb) / self.r) * torch.log(1 + torch.exp(cos_m_theta_n))
        # --------------------------- convert label to one-hot ---------------------------
        index = torch.where(label != -1)[0]
        one_hot = torch.zeros(cos_theta.size(), device=cosine.device) + self.label_smoothing / (num_classes - 1)
        one_hot.scatter_(1, label[index, None].long(), 1 - self.label_smoothing)
        # --------------------------- Calculate output ---------------------------
        loss = (one_hot * cos_p_theta) + (1 - one_hot) * cos_n_theta
        loss = loss.sum(dim=1)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            return loss.mean()
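
The margin losses above all operate on cosine similarities between an embedding and the class centers (which is what a `NormedLinear` layer yields for normalized embeddings). As a worked example, the sketch below applies the CosFace and ArcFace margins to a single toy sample in plain Python, using the same default `s` and `m` values as the classes above. The function names and the cosine values are illustrative assumptions, and unlike the classes above this sketch does not modify its input in place.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for a single sample (numerically stable)."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target]


def cosface_logits(cosines, target, s=64.0, m=0.40):
    """Subtract the margin m from the target cosine, then scale everything by s."""
    return [s * (c - m if i == target else c) for i, c in enumerate(cosines)]


def arcface_logits(cosines, target, s=64.0, m=0.5):
    """Add the margin m to the target angle, then scale the cosines by s."""
    return [s * (math.cos(math.acos(c) + m) if i == target else c)
            for i, c in enumerate(cosines)]


cosines = [0.8, 0.3, -0.1]  # toy similarities to 3 class centers; class 0 is the target
target = 0
plain = [64.0 * c for c in cosines]
cos_l = cosface_logits(cosines, target)
arc_l = arcface_logits(cosines, target)

loss_plain = cross_entropy(plain, target)
loss_cos = cross_entropy(cos_l, target)
loss_arc = cross_entropy(arc_l, target)
print(cos_l[0], arc_l[0], loss_plain, loss_cos, loss_arc)
```

Only the target logit changes; the margin lowers it, so the loss stays larger than with plain scaled cosines even for a sample that is already classified correctly. This keeps pushing same-class embeddings closer together.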

In [None]:
def get_image_paths_and_ids(dataset_path, max_images_per_person=None, min_images_per_person=2, subset_families=None):
    """
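    Scans the dataset directory and collects (image_path, label) pairs.
    The label is the family ID (e.g. the `F...` folder name): images from all `MID*`
    member folders of a family share one label, so "person" below means "family".
    `max_images_per_person` and `subset_families` can be None or 0 to use all.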
    """
    image_data = []
    person_image_counts = defaultdict(list)

    if not os.path.exists(dataset_path):
        print(f"Error: Dataset path {dataset_path} not found.")
        return []

    family_folders = sorted([f for f in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, f)) and f.startswith("F")])

    if subset_families is None or subset_families <= 0:
        # print(f"Using all {len(family_folders)} families from {dataset_path}.") # Can be verbose
        pass
    elif subset_families < len(family_folders):
        family_folders = family_folders[:subset_families]
        print(f"Using subset of {subset_families} families from {os.path.basename(dataset_path)}.")
    else:
        print(f"Requested subset_families ({subset_families}) is >= total families ({len(family_folders)}) in {os.path.basename(dataset_path)}. Using all available.")

    for family_id in tqdm(family_folders, desc=f"Scanning F in {os.path.basename(dataset_path)}", leave=False):
        family_dir = os.path.join(dataset_path, family_id)
        member_folders = sorted([m for m in os.listdir(family_dir) if os.path.isdir(os.path.join(family_dir, m)) and m.startswith("MID")])
        for member_id in member_folders:
            person_id_str = f"{family_id}"
            member_dir = os.path.join(family_dir, member_id)
            images_in_member_dir = glob.glob(os.path.join(member_dir, "*.jpg")) + \
                                   glob.glob(os.path.join(member_dir, "*.png"))
            if images_in_member_dir:
                person_image_counts[person_id_str].extend(images_in_member_dir)

    valid_person_ids = [pid for pid, imgs in person_image_counts.items() if len(imgs) >= min_images_per_person]

    for person_id_str in valid_person_ids:
        person_images = person_image_counts[person_id_str]
        if max_images_per_person is not None and max_images_per_person > 0:
            if len(person_images) > max_images_per_person:
                 person_images = random.sample(person_images, max_images_per_person)
        for img_path in person_images:
            image_data.append((img_path, person_id_str))

    print(f"Dataset {os.path.basename(dataset_path)}: Total images collected: {len(image_data)} from {len(valid_person_ids)} persons.")
    if not image_data and ( (subset_families is not None and subset_families > 0) or \
                           (max_images_per_person is not None and max_images_per_person > 0) ):
        print(f"Warning: No images collected for {os.path.basename(dataset_path)} with current subsetting. Check parameters or dataset integrity.")
    return image_data


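def get_timm_model_and_transform(model_name, pretrained=True):
    backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
    # Resolve the data config from the timm model itself: the nn.Sequential wrapper
    # below carries no pretrained config, so resolving on it would fall back to defaults.
    config = timm.data.resolve_data_config({}, model=backbone)
    transform = timm.data.create_transform(**config)
    model = nn.Sequential(backbone, EmbeddingNormalization())
    model = model.to(DEVICE)
    model.eval()
    # Dummy forward pass to discover the embedding dimension.
    with torch.no_grad():
        features = model(torch.zeros((1, 3, 112, 112)).to(DEVICE))
    model.num_features = features.shape[-1]
    return model, transform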

def get_arcface_model_and_transform(pretrained=True):
    model = iresnet18()
    if pretrained:
        model.load_state_dict(torch.load('/content/backbone.pth', map_location='cpu'))
    model.to(DEVICE)
    model.eval()
    transform = T.Compose(
            [T.ToTensor(),
             T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ])
    features = model(torch.zeros((1, 3, 112, 112)).to(DEVICE))
    model.num_features = features.shape[-1]
    return model, transform

def embed_images(model, transform, image_data_list, batch_size=32, desc_prefix=""):
    all_embeddings_list = [] # Use a list to append numpy arrays
    all_person_ids = []
    all_image_paths = []
    model.eval()
    with torch.no_grad(), torch.inference_mode(True):
        for i in tqdm(range(0, len(image_data_list), batch_size), desc=f"{desc_prefix}Embedding images", leave=False):
            batch_data = image_data_list[i:i+batch_size]
            images_batch, current_person_ids_batch, current_image_paths_batch = [], [], []
            for img_path, person_id_str in batch_data:
                try:
                    img = Image.open(img_path).convert("RGB")
                    img_tensor = transform(img)
                    images_batch.append(img_tensor)
                    current_person_ids_batch.append(person_id_str)
                    current_image_paths_batch.append(img_path)
                except Exception: pass # Silently skip problematic images
            if not images_batch: continue
            images_tensor = torch.stack(images_batch).to(DEVICE)
            embeddings_np = model(images_tensor).cpu().numpy()
            all_embeddings_list.append(embeddings_np) # Append numpy array directly
            all_person_ids.extend(current_person_ids_batch)
            all_image_paths.extend(current_image_paths_batch)

    if not all_embeddings_list: return np.array([]), [], [] # Return empty if nothing was embedded
    all_embeddings_np = np.concatenate(all_embeddings_list, axis=0)
    return all_embeddings_np, all_person_ids, all_image_paths


def setup_qdrant_collection(client, collection_name, vector_size, distance=Distance.COSINE, recreate=True, hnsw_config=None, quantization_config=None):
    collection_exists = False
    try:
        client.get_collection(collection_name=collection_name)
        collection_exists = True
    except Exception: pass # Assumes error means collection doesn't exist

    if recreate and collection_exists:
        client.delete_collection(collection_name=collection_name)
        print(f"Collection {collection_name} deleted.")
        collection_exists = False

    if not collection_exists:
        print(f"Creating collection {collection_name}...")
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size, distance=distance),
            hnsw_config=hnsw_config, quantization_config=quantization_config)
        print(f"Collection {collection_name} created.")
    else: print(f"Collection {collection_name} already exists and recreate was False.")


def index_embeddings_qdrant(client, collection_name, embeddings, image_paths, person_ids):
    if not isinstance(embeddings, np.ndarray) or embeddings.size == 0:
        print(f"No valid embeddings to index for {collection_name}.")
        return

    points_to_upsert = []
    for i, (emb, img_path, p_id) in enumerate(zip(embeddings, image_paths, person_ids)):
        # Ensure emb is a list of floats, not numpy array or other types
        vector_as_list = emb.tolist() if isinstance(emb, np.ndarray) else list(map(float, emb))
        points_to_upsert.append(
            PointStruct(id=i, vector=vector_as_list, payload={"image_path": img_path, "person_id": p_id})
        )

    if points_to_upsert:
        client.upsert(collection_name=collection_name, points=points_to_upsert, wait=True)
        print(f"Indexed {len(points_to_upsert)} embeddings into {collection_name}")
    else:
        print(f"No points were prepared for indexing into {collection_name}.")


def calculate_precision_at_k(qdrant_client, collection_name, query_embeddings, query_person_ids, query_image_paths, k_values, db_image_paths_set):
    precision_at_k = {k: [] for k in k_values}
    max_k = max(k_values) if k_values else 0
    if max_k == 0: return {k: 0.0 for k in k_values}, 0.0

    search_times = []
    for i in tqdm(range(len(query_embeddings)), desc="Calculating P@K", leave=False):
        # Ensure query_emb is a flat list of floats
        query_emb_np = query_embeddings[i]
        query_emb_list = query_emb_np.tolist() if isinstance(query_emb_np, np.ndarray) else list(map(float, query_emb_np))

        query_pid = query_person_ids[i]
        query_img_path = query_image_paths[i]

        start_time = time.time()
        search_results_response = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_emb_list,  # Pass the vector to 'query' parameter
            limit=max_k + 5,
            with_payload=True,
            with_vectors=False,
        )
        search_times.append(time.time() - start_time)

        retrieved_pids = []
        # query_points returns a response object whose `points` attribute holds the
        # ScoredPoint hits; the fallback below also accepts a plain list of hits.
        actual_hits = search_results_response
        if hasattr(search_results_response, 'points') and isinstance(search_results_response.points, list):
            actual_hits = search_results_response.points


        for hit in actual_hits:
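            if hit.payload["image_path"] == query_img_path or \
               (db_image_paths_set and hit.payload["image_path"] not in db_image_paths_set):
                # Skip the query image itself and any hit outside the expected DB set.
                print(f"Skipping hit (self-match or outside DB set): {hit.payload['image_path']}")
                continue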
            retrieved_pids.append(hit.payload["person_id"])
            if len(retrieved_pids) == max_k:
                break
        # print('hits', query_pid, retrieved_pids)

        if not retrieved_pids:
            for k_val in k_values: precision_at_k[k_val].append(0.0)
            continue

        for k_val in k_values:
            if k_val == 0:
                precision_at_k[k_val].append(0.0)
                continue
            actual_retrieved_count = len(retrieved_pids)
            # Consider only the top min(k_val, actual_retrieved_count) results for P@k
            capped_retrieved_pids = retrieved_pids[:min(k_val, actual_retrieved_count)]
            relevant_count = sum(1 for pid in capped_retrieved_pids if pid == query_pid)

            denominator = min(k_val, actual_retrieved_count) # P@k is relevant/k, but if less than k retrieved, use actual retrieved
            if denominator == 0 and k_val > 0 : # If k>0 but 0 retrieved, precision is 0
                 precision_at_k[k_val].append(0.0)
            elif denominator > 0:
                 precision_at_k[k_val].append(relevant_count / denominator)
            else: # k_val is 0 or something unexpected
                 precision_at_k[k_val].append(0.0)


    avg_search_time = np.mean(search_times) if search_times else 0
    return {k: np.mean(values) if values else 0.0 for k, values in precision_at_k.items()}, avg_search_time


def calculate_fmr_at_fnmr(qdrant_client, collection_name, query_embeddings, query_person_ids, query_image_paths, target_fnmr=0.1, num_eval_queries=100, db_image_paths_set=None):
    genuine_scores, imposter_scores = [], []
    if not db_image_paths_set: print("Warning: db_image_paths_set not provided for FMR@FNMR.")
    if len(query_embeddings) == 0: print("No query embeddings for FMR@FNMR calculation."); return -1.0

    # query_indices = random.sample(range(len(query_embeddings)), min(num_eval_queries, len(query_embeddings)))
    query_indices = range(len(query_embeddings))
    if not query_indices: print("No queries to evaluate for FMR@FNMR after sampling."); return -1.0

    for i in tqdm(query_indices, desc="Collecting scores for FMR@FNMR", leave=False):
        query_emb_np = query_embeddings[i]
        query_emb_list = query_emb_np.tolist() if isinstance(query_emb_np, np.ndarray) else list(map(float, query_emb_np))
        query_pid = query_person_ids[i]
        query_img_path = query_image_paths[i]

        # Retrieve up to 500 nearest neighbours with payloads (vectors are not needed)
        search_results_response = qdrant_client.query_points(
            collection_name=collection_name,
            query=query_emb_list,  # Pass the vector to 'query' parameter
            limit=500,
            with_vectors=False,
            with_payload=True
        )

        actual_hits = search_results_response
        if hasattr(search_results_response, 'points') and isinstance(search_results_response.points, list):
            actual_hits = search_results_response.points

        # print('fnmr hits', query_pid, [(hit.payload["person_id"], hit.score) for hit in actual_hits])

        for hit in actual_hits:
            if hit.payload["image_path"] == query_img_path: continue
            if db_image_paths_set and hit.payload["image_path"] not in db_image_paths_set: continue
            score = -hit.score
            if hit.payload["person_id"] == query_pid: genuine_scores.append(score)
            else: imposter_scores.append(score)

    if not genuine_scores or not imposter_scores:
        print("Warning: Not enough genuine or imposter scores for FMR@FNMR."); return -1.0

    genuine_scores, imposter_scores = np.array(sorted(genuine_scores)), np.array(sorted(imposter_scores))
    y_true = np.concatenate((np.ones(len(genuine_scores)), np.zeros(len(imposter_scores))))
    y_pred = -np.concatenate((genuine_scores, imposter_scores))

    if len(genuine_scores) == 0: print("Error: No genuine scores for FMR@FNMR."); return -2.0

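    # Calculate threshold for target FNMR
    # Scores are negated similarities, so lower is better and genuine pairs above the threshold are missed.
    # FNMR = count(genuine_scores > threshold) / len(genuine_scores)
    # We want FNMR <= target_fnmr. threshold = (1-target_fnmr)-th percentile of genuine_scores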
    threshold_idx_fnmr = int(len(genuine_scores) * (1.0 - target_fnmr)) # Use 1.0 for float context
    threshold_idx_fnmr = min(max(0, threshold_idx_fnmr), len(genuine_scores) - 1) # Bound index
    threshold = genuine_scores[threshold_idx_fnmr]

    # Calculate FMR at this threshold
    # FMR = count(imposter_scores < threshold) / len(imposter_scores)
    fmr = np.sum(imposter_scores < threshold) / len(imposter_scores) if len(imposter_scores) > 0 else 0.0

    actual_fnmr_count = np.sum(genuine_scores > threshold)
    actual_fnmr = actual_fnmr_count / len(genuine_scores) if len(genuine_scores) > 0 else 0.0
    print(f"  Target FNMR: {target_fnmr:.4f}, Actual FNMR: {actual_fnmr:.4f} at threshold {threshold:.4f} (lower score is better)")
    return fmr


class FIWMetricLearningDataset(Dataset):
    def __init__(self, image_paths, person_ids_str, transform, label_encoder):
        self.image_paths, self.person_ids_str, self.transform, self.label_encoder = image_paths, person_ids_str, transform, label_encoder
        self.labels = self.label_encoder.transform(self.person_ids_str)
    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img_path, label = self.image_paths[idx], self.labels[idx]
        image = Image.open(img_path).convert("RGB")
        return self.transform(image), label


def test_qdrant_configuration(qdrant_client,
                              db_embeddings, db_person_ids, db_image_paths,
                              query_embeddings, query_person_ids, query_image_paths,
                              k_values, target_fnmr, embedding_dim,
                              hnsw_config, quantization_config,
                              description_prefix):
    """
    Sets up a Qdrant collection with the given parameters, indexes the embeddings,
    and computes the search metrics.

    Returns a dict with Precision@K, FMR@FNMR, and the average search time.
    """
    collection_name = "test_collection"
    print(f"\n--- Тестирование конфигурации Qdrant: {description_prefix} ---")
    print(f"Настройка коллекции '{collection_name}' с HNSW={hnsw_config}, Квантование={quantization_config}")

    # Recreate the collection so that it exists with the requested configuration
    setup_qdrant_collection(qdrant_client, collection_name, embedding_dim, Distance.COSINE, recreate=True, hnsw_config=hnsw_config, quantization_config=quantization_config)
    # Index the DB embeddings
    index_embeddings_qdrant(qdrant_client, collection_name, db_embeddings, db_image_paths, db_person_ids)

    # Build the set of DB image paths used to exclude self-matches
    db_image_paths_set = set(db_image_paths)

    # Compute Precision@K
    precision, search_time = calculate_precision_at_k(qdrant_client, collection_name, query_embeddings, query_person_ids, query_image_paths, k_values, db_image_paths_set)
    print(f"{description_prefix} - P@K: {precision}, Среднее Время Поиска: {search_time:.4f}с")

    # Compute FMR@FNMR
    fmr = calculate_fmr_at_fnmr(qdrant_client, collection_name, query_embeddings, query_person_ids, query_image_paths, target_fnmr=target_fnmr, db_image_paths_set=db_image_paths_set)
    print(f"{description_prefix} - FMR @ FNMR={target_fnmr:.2f}: {fmr:.4f}")

    return {"p@k": precision, "fmr@fnmr": fmr, "search_time": search_time}

print("Helper functions defined with corrected Qdrant query_points usage.")

Helper functions defined with corrected Qdrant query_points usage.
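
The FMR@FNMR logic of `calculate_fmr_at_fnmr` can be checked on toy numbers. The sketch below is a simplified, standalone version that works on raw cosine similarities (higher is better) rather than on the negated Qdrant scores used above. The function name `fmr_at_fnmr` and the toy arrays are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def fmr_at_fnmr(genuine_sim, imposter_sim, target_fnmr=0.1):
    """Return (fmr, fnmr, threshold) for similarities where higher means closer."""
    genuine = np.sort(np.asarray(genuine_sim, dtype=float))  # ascending
    imposter = np.asarray(imposter_sim, dtype=float)
    # Reject the lowest `target_fnmr` share of genuine pairs: the threshold
    # is the similarity found at that position of the sorted genuine scores.
    idx = min(int(len(genuine) * target_fnmr), len(genuine) - 1)
    threshold = genuine[idx]
    fnmr = float(np.mean(genuine < threshold))    # genuine pairs that get rejected
    fmr = float(np.mean(imposter >= threshold))   # impostor pairs that get accepted
    return fmr, fnmr, threshold

# Toy data (hypothetical): 10 genuine pairs and 4 impostor pairs.
toy_genuine = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]
toy_imposter = [0.05, 0.15, 0.5, -0.2]
toy_fmr, toy_fnmr, toy_threshold = fmr_at_fnmr(toy_genuine, toy_imposter, target_fnmr=0.1)
print(f"threshold={toy_threshold:.2f}, FNMR={toy_fnmr:.2f}, FMR={toy_fmr:.2f}")
```

With `target_fnmr=0.1` the threshold rejects the lowest 10% of genuine pairs, and FMR is the share of impostor pairs that still pass it.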


# 7. Running the Main Script

In [None]:
# --- Part 0: Data Preparation ---
print("--- Part 0: Data Preparation (Train/Val Split) ---")

# Load training data (for fine-tuning)
print(f"Loading training data from: {FIW_TRAIN_DATA_PATH}")
train_image_data_all = get_image_paths_and_ids(
    FIW_TRAIN_DATA_PATH,
    max_images_per_person=None,
    min_images_per_person=2
)
if not train_image_data_all:
    print("No training data loaded. Fine-tuning will be skipped.")
else:
    print(f"Loaded {len(train_image_data_all)} images for training.")

# Load validation data (for DB indexing and querying)
print(f"\nLoading validation data from: {FIW_VAL_DATA_PATH}")
val_image_data_all = get_image_paths_and_ids(
    FIW_VAL_DATA_PATH,
    max_images_per_person=None,
    min_images_per_person=2 # Ensure query images also meet this
)
if not val_image_data_all:
    raise RuntimeError("No validation data loaded. Evaluation cannot proceed.")

db_data_val, query_data_val = train_test_split(val_image_data_all, train_size=0.8, stratify=[el[1] for el in val_image_data_all])

if not db_data_val or not query_data_val:
    print("Error: DB or Query set from validation data is empty after splitting.")

db_data_train = train_image_data_all

print(f"Validation Data Split: DB set has {len(db_data_val)} images")
print(f"Validation Data Split: Query set has {len(query_data_val)} images")

--- Part 0: Data Preparation (Train/Val Split) ---
Loading training data from: /content/train-faces-det


Scanning F in train-faces-det:   0%|          | 0/570 [00:00<?, ?it/s]

Dataset train-faces-det: Total images collected: 15822 from 570 persons.
Loaded 15822 images for training.

Loading validation data from: /content/val-faces-det


Scanning F in val-faces-det:   0%|          | 0/192 [00:00<?, ?it/s]

Dataset val-faces-det: Total images collected: 5045 from 192 persons.
Validation Data Split: DB set has 4036 images
Validation Data Split: Query set has 1009 images
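
The DB/query split above is stratified by person ID, which is why `min_images_per_person=2` matters: a person with a single image cannot appear on both sides. The hand-rolled sketch below illustrates the idea in pure Python. It is not the algorithm used by `train_test_split`, and the name `stratified_split` and the toy file names are assumptions made for illustration.

```python
import random
from collections import defaultdict

def stratified_split(items, db_share=0.8, seed=0):
    """Split (path, person_id) pairs so that each person lands in both parts."""
    by_person = defaultdict(list)
    for path, pid in items:
        by_person[pid].append((path, pid))

    rng = random.Random(seed)
    db, query = [], []
    for group in by_person.values():
        rng.shuffle(group)
        # Keep at least one image on each side; this needs >= 2 images per person.
        n_db = min(max(1, round(len(group) * db_share)), len(group) - 1)
        db.extend(group[:n_db])
        query.extend(group[n_db:])
    return db, query

# Toy data (hypothetical): person A has 5 images, person B has 2.
toy_items = [(f"a{i}.jpg", "A") for i in range(5)] + [("b0.jpg", "B"), ("b1.jpg", "B")]
toy_db, toy_query = stratified_split(toy_items)
print(len(toy_db), "DB images,", len(toy_query), "query images")
```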


In [None]:
# Example miner and loss functions.
# They are defined here so that the settings are easy to change.
miner = miners.MultiSimilarityMiner(epsilon=0.1, distance=CosineSimilarity())
loss_contrastive = losses.ContrastiveLoss(distance=CosineSimilarity())
loss_infonce = losses.NTXentLoss(temperature=0.5)
# This is the implementation described above, not the one from pytorch-metric-learning.
# Note: it takes classifier outputs, not embeddings (see the training loop for an example).
loss_arcface = ArcFace(label_smoothing=1e-5)
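
`ArcFace` expects classifier outputs rather than embeddings. As a rough illustration of what an additive angular margin does to such outputs, the NumPy sketch below adds a margin to the angle of the target class only and then rescales. The function `arcface_logits` and the defaults `scale=64.0` and `margin=0.5` are assumptions for illustration; the notebook's own `ArcFace` class is defined earlier and may differ in details, for example in how it handles angles close to pi.

```python
import numpy as np

def arcface_logits(embeddings, class_weights, labels, scale=64.0, margin=0.5):
    """Cosine logits with an angular margin added to the target class."""
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    w = class_weights / np.linalg.norm(class_weights, axis=1, keepdims=True)
    cos = np.clip(emb @ w.T, -1.0, 1.0)           # cosine to every class center
    theta = np.arccos(cos)                        # angle to every class center
    rows = np.arange(len(labels))
    logits = cos.copy()
    # Only the target class is penalized: cos(theta + m) <= cos(theta) for small angles.
    logits[rows, labels] = np.cos(theta[rows, labels] + margin)
    return scale * logits

# Toy data (hypothetical): two embeddings perfectly aligned with their class centers.
toy_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
toy_w = np.array([[1.0, 0.0], [0.0, 1.0]])
toy_labels = np.array([0, 1])
toy_logits = arcface_logits(toy_emb, toy_w, toy_labels)
print(toy_logits)
```

Even a perfectly aligned embedding gets a target logit below `scale`, so the loss keeps pushing each class toward a tighter angular cluster.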

### Qdrant Configurations (10 points)

Add at least two more distinct Qdrant configurations. They are measured in the main part, and the results are saved in the last cell.

In [None]:
# Define all Qdrant configurations to test.
# This is a list of dicts, each describing one configuration.
# TODO: add at least 2 more distinct configurations here
qdrant_test_scenarios = [
    {
        "key": "pretrained_on_val_with_train",
        "description": "Предобученная Модель Val+Train (Стандартный Qdrant)",
        "hnsw_config": None,
        "quantization_config": None,
    },
    {
        "key": "pretrained_on_val_optimized_qdrant",
        "description": "Предобученная Модель Val+Train (Оптимизированный Qdrant - Скалярное Квантование)",
        "hnsw_config": models.HnswConfigDiff(m=1, ef_construct=4),
        "quantization_config": models.ScalarQuantization(scalar=models.ScalarQuantizationConfig(type=models.ScalarType.INT8, quantile=0.99, always_ram=True )),
    },
    {
        "key": "pretrained_on_val_binary_quant",
        "description": "Предобученная Модель Val+Train (BQ)",
        "hnsw_config": None, # default HNSW
        "quantization_config": models.BinaryQuantization(binary=models.BinaryQuantizationConfig(always_ram=True)),
    },
    {
        "key": "pretrained_on_val_product_quant",
        "description": "Предобученная Модель Val+Train (PQ)",
        "hnsw_config": None, # default HNSW
        "quantization_config": models.ProductQuantization(
            product=models.ProductQuantizationConfig(
                compression=models.CompressionRatio.X4, # X8, X16, X32 are also worth trying
                always_ram=True
            )
        ),
    },
    {
        "key": "pretrained_on_val_aggressive_hnsw",
        "description": "Предобученная Модель Val+Train (Агрессивное HNSW)",
        "hnsw_config": models.HnswConfigDiff(m=16, ef_construct=200, full_scan_threshold=1000000), # more aggressive HNSW parameters
        "quantization_config": None,
    },
    {
        "key": "pretrained_on_val_aggressive_hnsw_scalar",
        "description": "Предобученная Модель Val (Агрессивное HNSW + Квантование)",
        "hnsw_config": models.HnswConfigDiff(m=16, ef_construct=200, full_scan_threshold=1000000), # more aggressive HNSW parameters
        "quantization_config": models.ScalarQuantization(scalar=models.ScalarQuantizationConfig(type=models.ScalarType.INT8, quantile=0.99, always_ram=True )),
    },
]
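
To make the `ScalarQuantization` settings less abstract, the NumPy sketch below shows the general idea behind 8-bit scalar quantization with a quantile: extreme values are clipped and the remaining range is mapped onto 256 levels. This is only an illustration under stated assumptions (two-sided clipping, one global range, unsigned codes); it is not Qdrant's actual implementation, and the names `quantize_8bit` and `dequantize_8bit` are made up for this example.

```python
import numpy as np

def quantize_8bit(vectors, quantile=0.99):
    """Clip outliers to a central quantile range and map values to 0..255."""
    tail = (1.0 - quantile) / 2.0
    lo, hi = np.quantile(vectors, [tail, 1.0 - tail])
    step = (hi - lo) / 255.0
    codes = np.round((np.clip(vectors, lo, hi) - lo) / step).astype(np.uint8)
    return codes, float(lo), float(step)

def dequantize_8bit(codes, lo, step):
    """Approximate reconstruction of the clipped float values."""
    return codes.astype(np.float32) * step + lo

# Toy data (hypothetical): 100 random 16-dimensional vectors.
toy_rng = np.random.default_rng(0)
toy_vectors = toy_rng.normal(size=(100, 16)).astype(np.float32)
toy_codes, toy_lo, toy_step = quantize_8bit(toy_vectors, quantile=0.99)
toy_restored = dequantize_8bit(toy_codes, toy_lo, toy_step)
print("bytes before:", toy_vectors.nbytes, "bytes after:", toy_codes.nbytes)
```

The codes take 4 times less memory than `float32`, at the cost of a rounding error of at most half a step for values inside the clipping range. When comparing configurations, keep in mind that with only a few thousand vectors and an in-memory client the search may effectively be exact regardless of these settings; this is an assumption worth checking, not something the notebook verifies.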

## Main Part of the Homework

In [None]:
target_fnmr=0.1
def run_homework_train_val_split(target_fnmr=0.1):
    # Global Label Encoder for consistency IF person IDs could overlap (not the case here due to train/val files)
    # For this setup, it's fine to create label encoders per dataset part as needed.
    results_summary = {} # To store metrics from different stages

    # --- Part 1: Pretrained Model Evaluation on Validation Data ---
    print("\n--- Part 1: Pretrained Model Evaluation (on Validation Set) ---")
    pretrained_model, transform = get_arcface_model_and_transform(pretrained=True)
    embedding_dim = pretrained_model.num_features

    print("Embedding DB images from Validation Set (pretrained)...")
    db_embeddings_val_pt, db_person_ids_val_pt, db_image_paths_val_pt = embed_images(pretrained_model, transform, db_data_val, desc_prefix="Val-DB Pretrained: ")
    if db_embeddings_val_pt.size == 0: print("Failed to generate Val-DB embeddings (pretrained). Skipping Part 1 eval."); return results_summary

    print("Embedding train images as DB...")
    db_embeddings_train_pt, db_person_ids_train_pt, db_image_paths_train_pt = embed_images(pretrained_model, transform, db_data_train, desc_prefix="Val-DB Pretrained: ")
    if db_embeddings_train_pt.size == 0: print("Failed to generate Train-DB embeddings (pretrained). Skipping Part 1 eval."); return results_summary

    print("Embedding Query images from Validation Set (pretrained)...")
    query_embeddings_val_pt, query_person_ids_val_pt, query_image_paths_val_pt = embed_images(pretrained_model, transform, query_data_val, desc_prefix="Val-Query Pretrained: ")
    if query_embeddings_val_pt.size == 0: print("Failed to generate Val-Query embeddings (pretrained). Skipping Part 1 eval."); return results_summary

    qdrant_client = QdrantClient(location=":memory:")

    setup_qdrant_collection(qdrant_client, COLLECTION_NAME_PRETRAINED, embedding_dim, Distance.COSINE, recreate=True)
    index_embeddings_qdrant(qdrant_client, COLLECTION_NAME_PRETRAINED, db_embeddings_val_pt, db_image_paths_val_pt, db_person_ids_val_pt)

    k_values = [1, 5, 10]
    db_image_paths_set_val_pt = set(db_image_paths_val_pt)

    precision_pt, search_time_pt = calculate_precision_at_k(qdrant_client, COLLECTION_NAME_PRETRAINED, query_embeddings_val_pt, query_person_ids_val_pt, query_image_paths_val_pt, k_values, db_image_paths_set_val_pt)
    print(f"Pretrained Model (Val Set) - P@K: {precision_pt}, Avg Search Time: {search_time_pt:.4f}s")
    fmr_pt = calculate_fmr_at_fnmr(qdrant_client, COLLECTION_NAME_PRETRAINED, query_embeddings_val_pt, query_person_ids_val_pt, query_image_paths_val_pt, target_fnmr=target_fnmr, db_image_paths_set=db_image_paths_set_val_pt)
    print(f"Pretrained Model (Val Set) - FMR @ FNMR={target_fnmr:.2f}: {fmr_pt:.4f}")
    results_summary["pretrained_on_val"] = {"p@k": precision_pt, "fmr@fnmr": fmr_pt, "search_time": search_time_pt}

    for scenario in qdrant_test_scenarios:
        results_summary[scenario["key"]] = test_qdrant_configuration(
            qdrant_client,
            db_embeddings_val_pt, db_person_ids_val_pt, db_image_paths_val_pt,
            query_embeddings_val_pt, query_person_ids_val_pt, query_image_paths_val_pt,
            k_values, target_fnmr, embedding_dim,
            scenario["hnsw_config"], scenario["quantization_config"],
            scenario["description"]
        )


    # --- Part 3: Fine-tuning with PyTorch Metric Learning (on Training Data) ---
    if not train_image_data_all:
        print("\nSkipping Part 3: Fine-tuning, as no training data was loaded.")
    else:
        print("\n--- Part 3: Fine-tuning Model (on Training Set) ---")
        train_image_paths_ft = [item[0] for item in train_image_data_all]
        train_person_ids_str_ft = [item[1] for item in train_image_data_all]

        # Label encoder for the training data
        train_label_encoder = LabelEncoder()
        train_label_encoder.fit(train_person_ids_str_ft) # Fit on actual training person IDs
        num_classes_train_ft = len(train_label_encoder.classes_)

        if num_classes_train_ft <= 1:
            print(f"ERROR: Training dataset for fine-tuning has only {num_classes_train_ft} classes. Need at least 2. Skipping fine-tuning.")
            # Jump to final summary if fine-tuning is skipped
            return results_summary
        # --------
        # Some augmentations for regularization
        # TODO: feel free to change these to reduce overfitting
        transform_ft = T.Compose([T.RandomHorizontalFlip(), T.RandomGrayscale(), transform])
        # --------


        # Fine-tuning dataset: the augmentations above followed by the pretrained model's transform
        train_dataset_ft = FIWMetricLearningDataset(train_image_paths_ft, train_person_ids_str_ft, transform_ft, train_label_encoder)

        finetune_model, _ = get_arcface_model_and_transform(pretrained=True) # Start from pretrained

        finetune_model = finetune_model.to(DEVICE)
        ft_embedding_dim = finetune_model.num_features
        classifier = NormedLinear(ft_embedding_dim, num_classes_train_ft, bias=False).to(DEVICE)

        optimizer = optim.AdamW([
            {'params': finetune_model.parameters(), 'lr': LEARNING_RATE_MODEL},
            {'params': classifier.parameters(), 'lr': LEARNING_RATE_LOSS}
        ])

        # Sampler for fine-tuning data
        unique_labels_train_ft, counts_train_ft = np.unique(train_dataset_ft.labels, return_counts=True)
        min_class_size_ft = counts_train_ft.min() if len(counts_train_ft) > 0 else 0
        batch_s_ft = BATCH_SIZE # Default batch size

        if min_class_size_ft < 2 and len(train_dataset_ft) > 0:
            print(f"Warning: Smallest class in fine-tuning data has {min_class_size_ft} samples. MPerClassSampler might fail. Using RandomSampler.")
            train_sampler_ft = torch.utils.data.RandomSampler(train_dataset_ft)
            batch_s_ft = BATCH_SIZE if len(train_dataset_ft) >=BATCH_SIZE else max(1, len(train_dataset_ft))
            train_dataloader_ft = DataLoader(train_dataset_ft, batch_size=batch_s_ft, sampler=train_sampler_ft, drop_last=True if len(train_dataset_ft)>batch_s_ft else False)
        elif len(train_dataset_ft) == 0:
            print("ERROR: Training dataset for fine-tuning is empty. Skipping fine-tuning.")
            return results_summary
        else:
            m_per_class = MIN_IMAGES_PER_CLASS_IN_BATCH
            num_unique_classes_in_train_ft = len(unique_labels_train_ft)
            # Ensure num_classes_per_batch is at least 2 for most miners, and that m_per_class * num_classes_per_batch isn't too large or zero
            num_classes_per_batch = min(max(2, num_unique_classes_in_train_ft // m_per_class if m_per_class > 0 else num_unique_classes_in_train_ft), BATCH_SIZE)
            if num_unique_classes_in_train_ft < num_classes_per_batch : num_classes_per_batch = num_unique_classes_in_train_ft

            if num_classes_per_batch < 2 or m_per_class == 0 or num_classes_per_batch == 0:
                 print(f"Warning: Cannot form valid batches for MPerClassSampler (m={m_per_class}, C_batch={num_classes_per_batch}, C_total={num_unique_classes_in_train_ft}). Using RandomSampler.")
                 train_sampler_ft = torch.utils.data.RandomSampler(train_dataset_ft)
                 batch_s_ft = BATCH_SIZE if len(train_dataset_ft) >=BATCH_SIZE else max(1, len(train_dataset_ft))
                 train_dataloader_ft = DataLoader(train_dataset_ft, batch_size=batch_s_ft, sampler=train_sampler_ft, drop_last=True if len(train_dataset_ft)>batch_s_ft else False)
            else:
                print(f"Using MPerClassSampler for fine-tuning with m={m_per_class}, C_batch={num_classes_per_batch}.")
                train_sampler_ft = samplers.MPerClassSampler(train_dataset_ft.labels, m=m_per_class, length_before_new_iter=len(train_dataset_ft))
                train_dataloader_ft = DataLoader(train_dataset_ft, sampler=train_sampler_ft, batch_size=batch_s_ft)

        print(f"Starting fine-tuning for {NUM_EPOCHS_FINETUNE} epochs...")
        finetune_model.eval()
        for epoch in range(1, NUM_EPOCHS_FINETUNE + 1):
            finetune_model.train()
            total_loss_epoch, num_batches = 0, 0
            for data_b, labels_b in tqdm(train_dataloader_ft, desc=f"Epoch {epoch}/{NUM_EPOCHS_FINETUNE}"):
                data_b, labels_b = data_b.to(DEVICE), labels_b.to(DEVICE)
                optimizer.zero_grad()
                embeddings_b = finetune_model(data_b)
                # --------
                # TODO: You can change the losses here, and combine them
                loss1 = loss_arcface(classifier(embeddings_b), labels_b)
                indices_tuple = miner(embeddings_b, labels_b)
                loss2 = loss_contrastive(embeddings_b, labels_b, indices_tuple)
                combined_loss = loss1  # + 10 * loss2
                # --------
                combined_loss.backward()
                optimizer.step()
                total_loss_epoch += combined_loss.item()
                num_batches += 1
            avg_loss = total_loss_epoch / num_batches if num_batches > 0 else total_loss_epoch
            print(f"Epoch {epoch} Average Loss: {avg_loss:.4f}")
            print("Fine-tuning finished.")
            finetune_model.eval()

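            # --- Re-evaluate with the fine-tuned model on Validation Data ---
            # Note: this block is inside the epoch loop, so it runs after every epoch
            # and results_summary["finetuned_on_val"] keeps the metrics of the last epoch.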
            print("\nRe-evaluating with fine-tuned model (on Validation Set)...")
            print("Embedding Val-DB images (fine-tuned)...")
            db_embeddings_val_ft, db_person_ids_val_ft, db_image_paths_val_ft = embed_images(finetune_model, transform, db_data_val, desc_prefix="Val-DB Fine-tuned: ")
            if db_embeddings_val_ft.size == 0: print("Failed to generate Val-DB embeddings (fine-tuned). Skipping FT eval."); return results_summary

            print("Embedding Val-Query images (fine-tuned)...")
            query_embeddings_val_ft, query_person_ids_val_ft, query_image_paths_val_ft = embed_images(finetune_model, transform, query_data_val, desc_prefix="Val-Query Fine-tuned: ")
            if query_embeddings_val_ft.size == 0: print("Failed to generate Val-Query embeddings (fine-tuned). Skipping FT eval."); return results_summary

            setup_qdrant_collection(qdrant_client, COLLECTION_NAME_FINETUNED, ft_embedding_dim, Distance.COSINE, recreate=True)
            index_embeddings_qdrant(qdrant_client, COLLECTION_NAME_FINETUNED, db_embeddings_val_ft, db_image_paths_val_ft, db_person_ids_val_ft)

            db_image_paths_set_val_ft = set(db_image_paths_val_ft)
            precision_ft, search_time_ft = calculate_precision_at_k(qdrant_client, COLLECTION_NAME_FINETUNED, query_embeddings_val_ft, query_person_ids_val_ft, query_image_paths_val_ft, k_values, db_image_paths_set_val_ft)
            print(f"Fine-tuned Model (Val Set) - P@K: {precision_ft}, Avg Search Time: {search_time_ft:.4f}s")
            fmr_ft = calculate_fmr_at_fnmr(qdrant_client, COLLECTION_NAME_FINETUNED, query_embeddings_val_ft, query_person_ids_val_ft, query_image_paths_val_ft, target_fnmr=target_fnmr, db_image_paths_set=db_image_paths_set_val_ft)
            print(f"Fine-tuned Model (Val Set) - FMR @ FNMR={target_fnmr}: {fmr_ft:.4f}")
            results_summary["finetuned_on_val"] = {"p@k": precision_ft, "fmr@fnmr": fmr_ft, "search_time": search_time_ft}

    return results_summary

# Execute the main homework logic
final_results_train_val = run_homework_train_val_split(target_fnmr=target_fnmr)


--- Part 1: Pretrained Model Evaluation (on Validation Set) ---


Embedding DB images from Validation Set (pretrained)...


Val-DB Pretrained: Embedding images:   0%|          | 0/127 [00:00<?, ?it/s]

Embedding train images as DB...


Val-DB Pretrained: Embedding images:   0%|          | 0/495 [00:00<?, ?it/s]

Embedding Query images from Validation Set (pretrained)...


Val-Query Pretrained: Embedding images:   0%|          | 0/32 [00:00<?, ?it/s]

Creating collection fiw_pretrained_val...
Collection fiw_pretrained_val created.
Indexed 4036 embeddings into fiw_pretrained_val


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Pretrained Model (Val Set) - P@K: {1: np.float64(0.9286422200198216), 5: np.float64(0.7728444003964322), 10: np.float64(0.5709613478691774)}, Avg Search Time: 0.0087s


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1209 (lower score is better)
Pretrained Model (Val Set) - FMR @ FNMR=0.10: 0.6497

--- Тестирование конфигурации Qdrant: Предобученная Модель Val+Train (Стандартный Qdrant) ---
Настройка коллекции 'test_collection' с HNSW=None, Квантование=None
Creating collection test_collection...
Collection test_collection created.
Indexed 4036 embeddings into test_collection


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Предобученная Модель Val+Train (Стандартный Qdrant) - P@K: {1: np.float64(0.9286422200198216), 5: np.float64(0.7728444003964322), 10: np.float64(0.5709613478691774)}, Среднее Время Поиска: 0.0074с


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1209 (lower score is better)
Предобученная Модель Val+Train (Стандартный Qdrant) - FMR @ FNMR=0.10: 0.6497

--- Тестирование конфигурации Qdrant: Предобученная Модель Val+Train (Оптимизированный Qdrant - Скалярное Квантование) ---
Настройка коллекции 'test_collection' с HNSW=m=1 ef_construct=4 full_scan_threshold=None max_indexing_threads=None on_disk=None payload_m=None, Квантование=scalar=ScalarQuantizationConfig(type=<ScalarType.INT8: 'int8'>, quantile=0.99, always_ram=True)
Collection test_collection deleted.
Creating collection test_collection...
Collection test_collection created.
Indexed 4036 embeddings into test_collection


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Предобученная Модель Val+Train (Оптимизированный Qdrant - Скалярное Квантование) - P@K: {1: np.float64(0.9286422200198216), 5: np.float64(0.7728444003964322), 10: np.float64(0.5709613478691774)}, Среднее Время Поиска: 0.0093с


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1209 (lower score is better)
Предобученная Модель Val+Train (Оптимизированный Qdrant - Скалярное Квантование) - FMR @ FNMR=0.10: 0.6497

--- Тестирование конфигурации Qdrant: Предобученная Модель Val+Train (BQ) ---
Настройка коллекции 'test_collection' с HNSW=None, Квантование=binary=BinaryQuantizationConfig(always_ram=True)
Collection test_collection deleted.
Creating collection test_collection...
Collection test_collection created.
Indexed 4036 embeddings into test_collection


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Предобученная Модель Val+Train (BQ) - P@K: {1: np.float64(0.9286422200198216), 5: np.float64(0.7728444003964322), 10: np.float64(0.5709613478691774)}, Среднее Время Поиска: 0.0077с


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1209 (lower score is better)
Предобученная Модель Val+Train (BQ) - FMR @ FNMR=0.10: 0.6497

--- Тестирование конфигурации Qdrant: Предобученная Модель Val+Train (PQ) ---
Настройка коллекции 'test_collection' с HNSW=None, Квантование=product=ProductQuantizationConfig(compression=<CompressionRatio.X4: 'x4'>, always_ram=True)
Collection test_collection deleted.
Creating collection test_collection...
Collection test_collection created.
Indexed 4036 embeddings into test_collection


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Предобученная Модель Val+Train (PQ) - P@K: {1: np.float64(0.9286422200198216), 5: np.float64(0.7728444003964322), 10: np.float64(0.5709613478691774)}, Среднее Время Поиска: 0.0092с


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1209 (lower score is better)
Предобученная Модель Val+Train (PQ) - FMR @ FNMR=0.10: 0.6497

--- Тестирование конфигурации Qdrant: Предобученная Модель Val+Train (Агрессивное HNSW) ---
Настройка коллекции 'test_collection' с HNSW=m=16 ef_construct=200 full_scan_threshold=1000000 max_indexing_threads=None on_disk=None payload_m=None, Квантование=None
Collection test_collection deleted.
Creating collection test_collection...
Collection test_collection created.
Indexed 4036 embeddings into test_collection


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Предобученная Модель Val+Train (Агрессивное HNSW) - P@K: {1: np.float64(0.9286422200198216), 5: np.float64(0.7728444003964322), 10: np.float64(0.5709613478691774)}, Среднее Время Поиска: 0.0082с


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1209 (lower score is better)
Предобученная Модель Val+Train (Агрессивное HNSW) - FMR @ FNMR=0.10: 0.6497

--- Тестирование конфигурации Qdrant: Предобученная Модель Val (Агрессивное HNSW + Квантование) ---
Настройка коллекции 'test_collection' с HNSW=m=16 ef_construct=200 full_scan_threshold=1000000 max_indexing_threads=None on_disk=None payload_m=None, Квантование=scalar=ScalarQuantizationConfig(type=<ScalarType.INT8: 'int8'>, quantile=0.99, always_ram=True)
Collection test_collection deleted.
Creating collection test_collection...
Collection test_collection created.
Indexed 4036 embeddings into test_collection


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Предобученная Модель Val (Агрессивное HNSW + Квантование) - P@K: {1: np.float64(0.9286422200198216), 5: np.float64(0.7728444003964322), 10: np.float64(0.5709613478691774)}, Среднее Время Поиска: 0.0094с


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1209 (lower score is better)
Предобученная Модель Val (Агрессивное HNSW + Квантование) - FMR @ FNMR=0.10: 0.6497

--- Part 3: Fine-tuning Model (on Training Set) ---
Using MPerClassSampler for fine-tuning with m=4, C_batch=64.
Starting fine-tuning for 2 epochs...


Epoch 1/2:   0%|          | 0/214 [00:00<?, ?it/s]

Epoch 1 Average Loss: 33.4654
Fine-tuning finished.

Re-evaluating with fine-tuned model (on Validation Set)...
Embedding Val-DB images (fine-tuned)...


Val-DB Fine-tuned: Embedding images:   0%|          | 0/127 [00:00<?, ?it/s]

Embedding Val-Query images (fine-tuned)...


Val-Query Fine-tuned: Embedding images:   0%|          | 0/32 [00:00<?, ?it/s]

Creating collection fiw_finetuned_val...
Collection fiw_finetuned_val created.
Indexed 4036 embeddings into fiw_finetuned_val


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Fine-tuned Model (Val Set) - P@K: {1: np.float64(0.931615460852329), 5: np.float64(0.7726461843409316), 10: np.float64(0.576808721506442)}, Avg Search Time: 0.0074s


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1194 (lower score is better)
Fine-tuned Model (Val Set) - FMR @ FNMR=0.1: 0.6466


Epoch 2/2:   0%|          | 0/214 [00:00<?, ?it/s]

Epoch 2 Average Loss: 24.3649
Fine-tuning finished.

Re-evaluating with fine-tuned model (on Validation Set)...
Embedding Val-DB images (fine-tuned)...


Val-DB Fine-tuned: Embedding images:   0%|          | 0/127 [00:00<?, ?it/s]

Embedding Val-Query images (fine-tuned)...


Val-Query Fine-tuned: Embedding images:   0%|          | 0/32 [00:00<?, ?it/s]

Collection fiw_finetuned_val deleted.
Creating collection fiw_finetuned_val...
Collection fiw_finetuned_val created.
Indexed 4036 embeddings into fiw_finetuned_val


Calculating P@K:   0%|          | 0/1009 [00:00<?, ?it/s]

Fine-tuned Model (Val Set) - P@K: {1: np.float64(0.9385530227948464), 5: np.float64(0.7781962338949455), 10: np.float64(0.5836471754212091)}, Avg Search Time: 0.0096s


Collecting scores for FMR@FNMR:   0%|          | 0/1009 [00:00<?, ?it/s]

  Target FNMR: 0.1000, Actual FNMR: 0.1000 at threshold -0.1158 (lower score is better)
Fine-tuned Model (Val Set) - FMR @ FNMR=0.1: 0.6488


## Experiment Results

* (20 points) the results were computed successfully and displayed in the cell
* (10 points) for the additional Qdrant configurations (as described above)
* (10 points) for Precision@K increasing for every K after fine-tuning
* (10 points) for a 2% relative improvement in Precision@10 (for example, `0.576` improving to `0.588`)
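
The last two criteria can be checked with a few lines of arithmetic. The sketch below uses the Precision@K values printed in the recorded run above (pretrained model and the fine-tuned model after epoch 2); the helper name `relative_improvement` is an assumption made for this example.

```python
def relative_improvement(before, after):
    """Relative change of a metric, e.g. 0.02 means +2%."""
    return (after - before) / before

# Precision@K values copied from the recorded outputs of this notebook.
pretrained_pk = {1: 0.9286422200198216, 5: 0.7728444003964322, 10: 0.5709613478691774}
finetuned_pk = {1: 0.9385530227948464, 5: 0.7781962338949455, 10: 0.5836471754212091}

improved_for_all_k = all(finetuned_pk[k] > pretrained_pk[k] for k in pretrained_pk)
rel_p10 = relative_improvement(pretrained_pk[10], finetuned_pk[10])

print(f"Precision@K improved for all K: {improved_for_all_k}")
print(f"Relative Precision@10 improvement: {rel_p10:.2%}")
print(f"Criterion example 0.576 -> 0.588: {relative_improvement(0.576, 0.588):.2%}")
```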

In [None]:
#@markdown This cell prints a summary of the collected metrics.

if 'final_results_train_val' in locals() and final_results_train_val:
    print("\n--- Итоговая сводка метрик (Оценка с разделением Train/Val) ---")
    for stage, metrics in final_results_train_val.items():
        # Format the P@K results
        pk_items = metrics.get("p@k", {})
        pk_str = ", ".join([f"P@{k}={v:.3f}" for k, v in pk_items.items()]) if pk_items else "N/A"

        fmr_val = metrics.get('fmr@fnmr', -1.0) # -1.0 marks missing data
        fmr_str = f"FMR@FNMR={target_fnmr:.2f}: {fmr_val:.4f}" if fmr_val != -1.0 else f"FMR@FNMR={target_fnmr:.2f}: N/A"

        search_time_val = metrics.get('search_time', -1.0)
        time_str = f"Среднее время поиска: {search_time_val:.4f} сек" if search_time_val != -1.0 else "Среднее время поиска: N/A"

        print(f"\nЭтап: {stage.replace('_', ' ').title()}")
        print(f"  Метрики оценены на: Валидационном наборе") # states which set the metrics were evaluated on
        print(f"  Precision@K: {pk_str}")
        print(f"  {fmr_str}")
        print(f"  {time_str}")
else:
    print("Нет результатов для отображения. Убедитесь, что основной скрипт (Ячейка 6) выполнен успешно.")



--- Итоговая сводка метрик (Оценка с разделением Train/Val) ---

Этап: Pretrained On Val
  Метрики оценены на: Валидационном наборе
  Precision@K: P@1=0.929, P@5=0.773, P@10=0.571
  FMR@FNMR=0.10: 0.6497
  Среднее время поиска: 0.0087 сек

Этап: Pretrained On Val With Train
  Метрики оценены на: Валидационном наборе
  Precision@K: P@1=0.929, P@5=0.773, P@10=0.571
  FMR@FNMR=0.10: 0.6497
  Среднее время поиска: 0.0074 сек

Этап: Pretrained On Val Optimized Qdrant
  Метрики оценены на: Валидационном наборе
  Precision@K: P@1=0.929, P@5=0.773, P@10=0.571
  FMR@FNMR=0.10: 0.6497
  Среднее время поиска: 0.0093 сек

Этап: Pretrained On Val Binary Quant
  Метрики оценены на: Валидационном наборе
  Precision@K: P@1=0.929, P@5=0.773, P@10=0.571
  FMR@FNMR=0.10: 0.6497
  Среднее время поиска: 0.0077 сек

Этап: Pretrained On Val Product Quant
  Метрики оценены на: Валидационном наборе
  Precision@K: P@1=0.929, P@5=0.773, P@10=0.571
  FMR@FNMR=0.10: 0.6497
  Среднее время поиска: 0.0092 сек

Этап