# Metric Learning

In [1]:
import os
import random
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
import fiftyone.zoo as foz
import wandb  # импортируем wandb
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def fix_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

fix_seed()

## Эксперимент 1

Добавим к baseline дополнительные аугментации и оптимизатор AdamW.

In [3]:
class TripletFODataset(Dataset):
    def __init__(self, samples, transform=None, label_to_idx=None):
        self.transform = transform
        if label_to_idx is None:
            labels = sorted({label for _, label in samples})
            self.label_to_idx = {label: idx for idx, label in enumerate(labels)}
        else:
            self.label_to_idx = label_to_idx

        self.samples = [(filepath, self.label_to_idx[label]) for filepath, label in samples]
        self.class_to_indices = {}
        for idx, (_, label) in enumerate(self.samples):
            self.class_to_indices.setdefault(label, []).append(idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        filepath, anchor_label = self.samples[index]
        anchor_img = Image.open(filepath).convert("RGB")
        if self.transform:
            anchor_img = self.transform(anchor_img)

        positive_index = index
        while positive_index == index:
            positive_index = random.choice(self.class_to_indices[anchor_label])
        positive_filepath, _ = self.samples[positive_index]
        positive_img = Image.open(positive_filepath).convert("RGB")
        if self.transform:
            positive_img = self.transform(positive_img)

        negative_label = anchor_label
        while negative_label == anchor_label:
            negative_label = random.choice(list(self.class_to_indices.keys()))
        negative_index = random.choice(self.class_to_indices[negative_label])
        negative_filepath, negative_label = self.samples[negative_index]
        negative_img = Image.open(negative_filepath).convert("RGB")
        if self.transform:
            negative_img = self.transform(negative_img)

        return (anchor_img, positive_img, negative_img,
                torch.tensor(anchor_label), torch.tensor(negative_label))

In [4]:
class EmbeddingNet(nn.Module):
    def __init__(self, backbone_name="levit_128", embedding_dim=128, pretrained=True):
        """
        Загружаем модель LeViT-128 для получения эмбеддингов
        """
        super(EmbeddingNet, self).__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=1000)
        self.backbone.reset_classifier(0)
        backbone_features = self.backbone.num_features
        self.fc = nn.Linear(backbone_features, embedding_dim)

    def forward(self, x):
        features = self.backbone(x)
        embedding = self.fc(features)
        embedding = nn.functional.normalize(embedding, p=2, dim=1)
        return embedding

In [5]:
def train_one_epoch(model, dataloader, optimizer, device, margin=1.0, semi_hard=True):
    model.train()
    running_loss = 0.0
    triplet_loss_fn = nn.TripletMarginLoss(margin=margin, p=2)
    
    for batch_idx, batch in enumerate(dataloader):
        anchor, positive, negative, anchor_label, negative_label = batch

        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        anchor_label = anchor_label.to(device)
        negative_label = negative_label.to(device)

        optimizer.zero_grad()

        anchor_out = model(anchor)
        positive_out = model(positive)
        negative_out = model(negative)

        if semi_hard:
            # Реализуем semi-hard mining вручную
            candidate_embeddings = torch.cat([anchor_out, negative_out], dim=0)
            candidate_labels = torch.cat([anchor_label, negative_label], dim=0)
            batch_loss = 0.0
            batch_size = anchor_out.size(0)

            for i in range(batch_size):
                d_ap = torch.norm(anchor_out[i] - positive_out[i], p=2)
                mask = (candidate_labels != anchor_label[i])
                if mask.sum() == 0:
                    chosen_negative = negative_out[i]
                else:
                    candidate_emb = candidate_embeddings[mask]
                    d_an = torch.norm(anchor_out[i].unsqueeze(0) - candidate_emb, p=2, dim=1)
                    semi_hard_mask = (d_an > d_ap) & (d_an < d_ap + margin)
                    if semi_hard_mask.sum() > 0:
                        candidate_d_an = d_an[semi_hard_mask]
                        chosen_idx = torch.argmin(candidate_d_an)
                        chosen_negative = candidate_emb[semi_hard_mask][chosen_idx]
                    else:
                        chosen_negative = negative_out[i]
                d_an_final = torch.norm(anchor_out[i] - chosen_negative, p=2)
                loss_i = torch.relu(d_ap - d_an_final + margin)
                batch_loss += loss_i
            loss = batch_loss / batch_size
        else:
            loss = triplet_loss_fn(anchor_out, positive_out, negative_out)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}/{len(dataloader)}: Loss = {loss.item():.4f}")

    avg_loss = running_loss / len(dataloader)
    wandb.log({"Train/Triplet Loss": avg_loss})
    return avg_loss


In [5]:
def compute_base_embeddings(model, dataloader, device):
    model.eval()
    sums = {}
    counts = {}
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            embeddings = model(images)  # извлекаем эмбеддинги
            labels = labels.to(device)
            for emb, label in zip(embeddings, labels):
                label = label.item()
                if label not in sums:
                    sums[label] = emb.clone()
                    counts[label] = 1
                else:
                    sums[label] += emb
                    counts[label] += 1
    base_embeddings = {label: sums[label] / counts[label] for label in sums}
    return base_embeddings

def validate_classification(model, base_embeddings, dataloader, device):
    model.eval()
    total = 0
    correct = 0

    base_labels = []
    base_embs = []
    for label, emb in base_embeddings.items():
        base_labels.append(label)
        base_embs.append(emb.unsqueeze(0))
    base_embs = torch.cat(base_embs, dim=0)

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            embeddings = model(images)  # получаем эмбеддинги
            dists = torch.cdist(embeddings, base_embs, p=2)
            preds = torch.argmin(dists, dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    accuracy = correct / total
    return accuracy

class Caltech256ClassificationDataset(Dataset):
    def __init__(self, samples, transform=None, label_to_idx=None):
        self.transform = transform
        if label_to_idx is None:
            labels = sorted({label for _, label in samples})
            self.label_to_idx = {label: idx for idx, label in enumerate(labels)}
        else:
            self.label_to_idx = label_to_idx
        self.samples = [(filepath, self.label_to_idx[label]) for filepath, label in samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        filepath, label = self.samples[index]
        img = Image.open(filepath).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, torch.tensor(label)

In [None]:
wandb.init(project="ITMO_metric_learning_caltech256_triplet", config={
    "backbone": "levit_128",
    "embedding_dim": 128,
    "lr": 1e-4,
    "batch_size": 32,
    "num_epochs": 2,
    "margin": 1.0,
    "semi_hard": True,
})
config = wandb.config

dataset = foz.load_zoo_dataset("caltech256")
print(f"Загружен Caltech256: {len(dataset)} образцов")

# список имен файлов валидационной выборки
val_df = pd.read_csv("val.csv")
val_filenames = set(val_df["filename"].tolist())

train_samples = []
val_samples = []
for sample in dataset:
    filename = os.path.basename(sample.filepath)
    label = sample["ground_truth"]["label"] if "ground_truth" in sample and sample["ground_truth"] is not None else sample.get("label", None)
    if label is None:
        continue
    if filename in val_filenames:
        val_samples.append((sample.filepath, label))
    else:
        train_samples.append((sample.filepath, label))

print(f"Обучающих сэмплов: {len(train_samples)}")
print(f"Валидационных сэмплов: {len(val_samples)}")

all_labels = {label for _, label in (train_samples + val_samples)}
labels_sorted = sorted(all_labels)
label_to_idx = {label: idx for idx, label in enumerate(labels_sorted)}
num_classes = len(label_to_idx)

transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
])

train_dataset = TripletFODataset(train_samples, transform=transform, label_to_idx=label_to_idx)
val_dataset = Caltech256ClassificationDataset(val_samples, transform=val_transform, label_to_idx=label_to_idx)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используем устройство: {device}")

model = EmbeddingNet(backbone_name=config.backbone, embedding_dim=config.embedding_dim, pretrained=True)
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=config.lr)

num_epochs = config.num_epochs
for epoch in range(num_epochs):
    print(f"\nЭпоха {epoch+1}/{num_epochs}")
    train_loss = train_one_epoch(model, train_loader, optimizer, device, margin=config.margin, semi_hard=config.semi_hard)
    print(f"Epoch {epoch+1} - Average Training Loss: {train_loss:.4f}")
    
    os.makedirs("train_triplet", exist_ok=True)
    model_path = f"train_triplet/model_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    wandb.save(model_path)
    
    wandb.log({"epoch": epoch+1})

print("Формирование бейз-эмбеддингов для каждого класса...")
base_embeddings = compute_base_embeddings(model, DataLoader(
    Caltech256ClassificationDataset(train_samples, transform=transform, label_to_idx=label_to_idx),
    batch_size=config.batch_size, shuffle=False, num_workers=4), device)

print("Валидация модели по задаче классификации...")
accuracy = validate_classification(model, base_embeddings, val_loader, device)
print(f"Accuracy: {accuracy:.4f}")
wandb.log({"Val/Accuracy": accuracy})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfelisfur[0m ([33mfelisfur-wb[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Dataset already downloaded
Loading existing dataset 'caltech256'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Загружен Caltech256: 30607 образцов
Обучающих сэмплов: 24485
Валидационных сэмплов: 6122
Используем устройство: cuda

Эпоха 1/2
Batch 0/766: Loss = 0.9171
Batch 10/766: Loss = 0.9181
Batch 20/766: Loss = 0.8939
Batch 30/766: Loss = 0.9011
Batch 40/766: Loss = 0.9059
Batch 50/766: Loss = 0.9040
Batch 60/766: Loss = 0.8909
Batch 70/766: Loss = 0.8832
Batch 80/766: Loss = 0.8460
Batch 90/766: Loss = 0.9019
Batch 100/766: Loss = 0.8345
Batch 110/766: Loss = 0.8863
Batch 120/766: Loss = 0.8812
Batch 130/766: Loss = 0.8526
Batch 140/766: Loss = 0.8374
Batch 150/766: Loss = 0.8177
Batch 160/766: Loss = 0.8581
Batch 170/766: Loss = 0.8838
Batch 180/766: Loss = 0.7860
Batch 190/766: Loss = 0.8004
Batch 200/766: Loss = 0.7498
Batch 210/766: Loss = 0.8045
Batch 220/766: Loss = 0.7685
Batch 230/766: Loss = 0.7644
Batch 240/766: Loss = 0

In [8]:
wandb.finish()

0,1
Train/Triplet Loss,█▁
Val/Accuracy,▁
epoch,▁█

0,1
Train/Triplet Loss,0.59002
Val/Accuracy,0.80072
epoch,2.0


## Эксперимент 2

Добавим в эксперимент #1 ещё один loss - contrastive loss

In [9]:
def train_one_epoch(model, dataloader, optimizer, device, margin=1.0, semi_hard=True, contrastive_weight=1.0):
    """
    Функция обучения, которая рассчитывает одновременно два лосса:
      - Triplet loss с semi-hard mining (если semi_hard=True, иначе стандартный TripletMarginLoss)
      - Contrastive loss для пары (anchor, positive) и (anchor, negative)
    Итоговый loss = triplet_loss + contrastive_weight * contrastive_loss.
    """
    model.train()
    running_loss = 0.0
    triplet_loss_fn = nn.TripletMarginLoss(margin=margin, p=2)

    for batch_idx, batch in enumerate(dataloader):
        anchor, positive, negative, anchor_label, negative_label = batch

        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        anchor_label = anchor_label.to(device)
        negative_label = negative_label.to(device)

        optimizer.zero_grad()

        anchor_out = model(anchor)
        positive_out = model(positive)
        negative_out = model(negative)

        if semi_hard:
            candidate_embeddings = torch.cat([anchor_out, negative_out], dim=0)
            candidate_labels = torch.cat([anchor_label, negative_label], dim=0)

            batch_triplet_loss = 0.0
            batch_contrastive_loss = 0.0
            batch_size = anchor_out.size(0)
            for i in range(batch_size):
                d_ap = torch.norm(anchor_out[i] - positive_out[i], p=2)
                mask = (candidate_labels != anchor_label[i])
                if mask.sum() == 0:
                    chosen_negative = negative_out[i]
                else:
                    candidate_emb = candidate_embeddings[mask]
                    d_an = torch.norm(anchor_out[i].unsqueeze(0) - candidate_emb, p=2, dim=1)
                    semi_hard_mask = (d_an > d_ap) & (d_an < d_ap + margin)
                    if semi_hard_mask.sum() > 0:
                        candidate_d_an = d_an[semi_hard_mask]
                        chosen_idx = torch.argmin(candidate_d_an)
                        chosen_negative = candidate_emb[semi_hard_mask][chosen_idx]
                    else:
                        chosen_negative = negative_out[i]
                d_an_final = torch.norm(anchor_out[i] - chosen_negative, p=2)
                # Triplet loss:
                sample_triplet_loss = torch.relu(d_ap - d_an_final + margin)
                # Contrastive loss:
                #   Для позитивной пары: стремимся уменьшить квадрат расстояния
                #   Для негативной пары: стремимся, чтобы расстояние было не меньше margin
                sample_contrastive_loss = d_ap**2 + torch.relu(margin - d_an_final)**2

                batch_triplet_loss += sample_triplet_loss
                batch_contrastive_loss += sample_contrastive_loss

            loss_triplet = batch_triplet_loss / batch_size
            loss_contrastive = batch_contrastive_loss / batch_size
            loss = loss_triplet + contrastive_weight * loss_contrastive

        else:
            d_ap = torch.norm(anchor_out - positive_out, p=2, dim=1)
            d_an = torch.norm(anchor_out - negative_out, p=2, dim=1)
            loss_triplet = triplet_loss_fn(anchor_out, positive_out, negative_out)
            loss_contrastive = torch.mean(d_ap**2 + torch.relu(margin - d_an)**2)
            loss = loss_triplet + contrastive_weight * loss_contrastive

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}/{len(dataloader)}: Total Loss = {loss.item():.4f} | Triplet = {loss_triplet.item():.4f} | Contrastive = {loss_contrastive.item():.4f}")
        wandb.log({
            "Train/Triplet Loss": loss_triplet.item(),
            "Train/Contrastive Loss": loss_contrastive.item(),
            "Train/Total Loss": loss.item()
        })

    avg_loss = running_loss / len(dataloader)
    return avg_loss

In [10]:
wandb.init(project="ITMO_metric_learning_caltech256_triplet_contractive", config={
    "backbone": "levit_128",
    "embedding_dim": 128,
    "lr": 1e-4,
    "batch_size": 32,
    "num_epochs": 2,
    "margin": 1.0,
    "semi_hard": True,
    "contrastive_weight": 1.0,  # вес для contrastive loss
})
config = wandb.config

In [11]:
model = EmbeddingNet(backbone_name=config.backbone, embedding_dim=config.embedding_dim, pretrained=True)
model.to(device)

EmbeddingNet(
  (backbone): LevitDistilled(
    (stem): Stem16(
      (conv1): ConvNorm(
        (linear): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act1): Hardswish()
      (conv2): ConvNorm(
        (linear): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act2): Hardswish()
      (conv3): ConvNorm(
        (linear): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (act3): Hardswish()
      (conv4): ConvNorm(
        (linear): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True,

In [12]:
optimizer = optim.AdamW(model.parameters(), lr=config.lr)

In [None]:
num_epochs = config.num_epochs
for epoch in range(num_epochs):
    print(f"\nЭпоха {epoch+1}/{num_epochs}")
    train_loss = train_one_epoch(model, train_loader, optimizer, device,
                                 margin=config.margin,
                                 semi_hard=config.semi_hard,
                                 contrastive_weight=config.contrastive_weight)
    print(f"Epoch {epoch+1} - Average Training Loss: {train_loss:.4f}")
    
    # Сохраняем модель
    os.makedirs("train_triplet", exist_ok=True)
    model_path = f"train_triplet/model_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    wandb.save(model_path)
    
    wandb.log({"epoch": epoch+1})

# Формируем бейз-эмбеддинги по обучающему набору для k-NN классификации
print("Формирование бейз-эмбеддингов для каждого класса...")
train_classification_dataset = Caltech256ClassificationDataset(train_samples, transform=val_transform, label_to_idx=label_to_idx)
base_embeddings = compute_base_embeddings(model, DataLoader(train_classification_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4), device)

print("Валидация модели по задаче классификации...")
accuracy = validate_classification(model, base_embeddings, val_loader, device)
print(f"Accuracy: {accuracy:.4f}")
wandb.log({"Val/Accuracy": accuracy})


Эпоха 1/2
Batch 0/766: Total Loss = 2.2602 | Triplet = 0.9182 | Contrastive = 1.3419
Batch 10/766: Total Loss = 2.3805 | Triplet = 0.9408 | Contrastive = 1.4397
Batch 20/766: Total Loss = 2.2956 | Triplet = 0.9306 | Contrastive = 1.3650
Batch 30/766: Total Loss = 2.4691 | Triplet = 0.9507 | Contrastive = 1.5185
Batch 40/766: Total Loss = 2.1917 | Triplet = 0.9287 | Contrastive = 1.2629
Batch 50/766: Total Loss = 2.2556 | Triplet = 0.9340 | Contrastive = 1.3217
Batch 60/766: Total Loss = 2.2193 | Triplet = 0.9476 | Contrastive = 1.2717
Batch 70/766: Total Loss = 2.0330 | Triplet = 0.9059 | Contrastive = 1.1271
Batch 80/766: Total Loss = 2.1117 | Triplet = 0.9255 | Contrastive = 1.1861
Batch 90/766: Total Loss = 2.1152 | Triplet = 0.9031 | Contrastive = 1.2121
Batch 100/766: Total Loss = 2.1129 | Triplet = 0.9394 | Contrastive = 1.1734
Batch 110/766: Total Loss = 1.9723 | Triplet = 0.9212 | Contrastive = 1.0510
Batch 120/766: Total Loss = 1.8925 | Triplet = 0.9202 | Contrastive = 0.9723

In [14]:
wandb.finish()

0,1
Train/Contrastive Loss,██▆▄▅▄▄▃▄▃▃▃▂▂▃▃▃▃▂▁▂▃▂▂▂▂▁▂▁▂▂▂▂▁▁▁▂▂▂▁
Train/Total Loss,███▆▆▆▄▄▅▃▃▃▄▂▂▃▃▄▃▄▂▃▃▂▂▂▂▂▁▂▁▂▃▂▂▂▁▃▁▂
Train/Triplet Loss,███▇▇▇▆█▇▇▇▆▇▆▄▅▃▅▄▇▄█▆▃▅▂▃▅▅▄▄▁▄▁▂▆▆▄▂▃
Val/Accuracy,▁
epoch,▁█

0,1
Train/Contrastive Loss,0.58194
Train/Total Loss,1.47881
Train/Triplet Loss,0.89687
Val/Accuracy,0.77181
epoch,2.0


## Эксперимент 3

Pair-based со сложной стратегией семплирования

In [7]:
class PairFODataset(Dataset):
    """
    Отдаёт одну картинку + её числовой класс.
    """
    def __init__(self, samples, transform=None, label_to_idx=None):
        self.tr = transform
        if label_to_idx is None:
            labels = sorted({l for _, l in samples})
            self.label_to_idx = {l: i for i, l in enumerate(labels)}
        else:
            self.label_to_idx = label_to_idx
        self.samples = [(fp, self.label_to_idx[l]) for fp, l in samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        fp, y = self.samples[idx]
        img = Image.open(fp).convert("RGB")
        if self.tr: img = self.tr(img)
        return img, torch.tensor(y)

Hard sampling для положительного класса и semi‑hard / hard для отрицательного

In [None]:
def mine_pairs(embeds, labels, margin=0.2, semi_hard=True):
    """
    embeds : (B, D)   — L2‑нормированные эмбеддинги батча
    labels : (B,) int — метки
    Возвращает два списка индексов (i, j) положительных и отрицательных пар
    """
    with torch.no_grad():
        dists = torch.cdist(embeds, embeds, p=2)

    pos_pairs, neg_pairs = [], []
    B = embeds.size(0)
    for i in range(B):
        same      = (labels == labels[i]).nonzero(as_tuple=False).view(-1)
        diff      = (labels != labels[i]).nonzero(as_tuple=False).view(-1)
        same = same[same != i]
        if same.numel() == 0 or diff.numel() == 0:
            continue                           # ничего не можем добыть

        # positive
        j = same[torch.argmax(dists[i][same])] # hardest positive   (самый далёкий)

        # negative
        d_ap = dists[i, j].item()
        if semi_hard:
            mask = (dists[i][diff] > d_ap) & (dists[i][diff] < d_ap + margin)
            candidates = diff[mask]
            if candidates.numel():             # semi‑hard нашёлся
                k = candidates[torch.argmin(dists[i][candidates])]
            else:                              # hardest negative
                k = diff[torch.argmin(dists[i][diff])]
        else:                                  # полностью hard
            k = diff[torch.argmin(dists[i][diff])]

        pos_pairs.append((i, j))
        neg_pairs.append((i, k))
    return pos_pairs, neg_pairs

In [9]:
def train_pair_epoch(model, loader, optimizer, device,
                     margin=0.2, semi_hard=True):
    model.train()
    contrastive = nn.CosineEmbeddingLoss(margin=margin)
    running = 0.0

    for imgs, lbs in loader:
        imgs, lbs = imgs.to(device), lbs.to(device)
        optimizer.zero_grad()
        embeds = model(imgs)

        # pair mining
        pos, neg = mine_pairs(embeds.detach(), lbs, margin, semi_hard)

        if not pos:
            continue

        # готовим пары и таргеты
        idx_a  = torch.tensor([i for i, _ in pos + neg]).to(device)
        idx_b  = torch.tensor([j for _, j in pos + neg]).to(device)
        targets = torch.cat([torch.ones(len(pos)),
                             -torch.ones(len(neg))]).to(device)

        out_a, out_b = embeds[idx_a], embeds[idx_b]
        loss = contrastive(out_a, out_b, targets)
        loss.backward()
        optimizer.step()

        running += loss.item()

    avg_loss = running / len(loader)
    wandb.log({"Train/Contrastive Loss": avg_loss})
    return avg_loss

In [19]:
wandb.init(project="ITMO_metric_learning_caltech256_pair_sampler", config={
    "backbone": "levit_128", "embedding_dim": 128,
    "lr": 1e-4, "batch_size": 64,
    "num_epochs": 2, "margin": 0.2, "semi_hard": True,
})
cfg = wandb.config

In [20]:
train_ds = PairFODataset(train_samples, transform, label_to_idx)
val_ds = PairFODataset(val_samples,   val_transform, label_to_idx)

In [21]:
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=4)

In [22]:
model = EmbeddingNet(cfg.backbone, cfg.embedding_dim, pretrained=True).to(device)
optimizer = optim.AdamW(model.parameters(), lr=cfg.lr)

In [None]:
for epoch in range(cfg.num_epochs):
    print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
    train_loss = train_pair_epoch(model, train_loader, optimizer, device,
                               cfg.margin, cfg.semi_hard)
    print(f"Epoch {epoch+1} - Pair‑based train loss: {train_loss:.4f}")

    os.makedirs("train_pair", exist_ok=True)
    model_path = f"train_pair/model_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    wandb.save(model_path)
    
    wandb.log({"epoch": epoch+1})

print("Формирование бейз-эмбеддингов для каждого класса...")
base_embeddings = compute_base_embeddings(model, DataLoader(
    Caltech256ClassificationDataset(train_samples, transform=val_transform, label_to_idx=label_to_idx),
    batch_size=cfg.batch_size, shuffle=False, num_workers=4), device)

print("Валидация модели по задаче классификации...")
accuracy = validate_classification(model, base_embeddings, val_loader, device)
print(f"Accuracy: {accuracy:.4f}")
wandb.log({"Val/Accuracy": accuracy})


Epoch 1/2
Epoch 1 - Pair‑based train loss: 0.3387

Epoch 2/2
Epoch 2 - Pair‑based train loss: 0.2957
Формирование бейз-эмбеддингов для каждого класса...
Валидация модели по задаче классификации...
Accuracy: 0.7290


In [24]:
wandb.finish()

0,1
Train/Contrastive Loss,█▁
Val/Accuracy,▁
epoch,▁█

0,1
Train/Contrastive Loss,0.29569
Val/Accuracy,0.72901
epoch,2.0


In [25]:
def train_triplet_epoch(model, loader, optimizer, device,
                        margin=0.2, semi_hard=True):
    model.train()
    tmloss = nn.TripletMarginLoss(margin=margin, p=2)
    running = 0.0

    for imgs, lbs in loader:
        imgs, lbs = imgs.to(device), lbs.to(device)
        optimizer.zero_grad()
        embeds = model(imgs)

        pos, neg = mine_pairs(embeds.detach(), lbs, margin, semi_hard)

        if not pos:     
            continue

        idx_a = torch.tensor([i for i, _ in pos]).to(device)
        idx_p = torch.tensor([j for _, j in pos]).to(device)
        idx_n = torch.tensor([k for _, k in neg]).to(device)

        loss = tmloss(embeds[idx_a], embeds[idx_p], embeds[idx_n])
        loss.backward()
        optimizer.step()

        running += loss.item()

    avg_loss = running / len(loader)
    wandb.log({"Train/Triplet Loss": avg_loss})
    return avg_loss


In [26]:
wandb.init(project="ITMO_metric_learning_caltech256_pair_sampler_triplet", config={
    "backbone": "levit_128", "embedding_dim": 128,
    "lr": 1e-4, "batch_size": 64,
    "num_epochs": 2, "margin": 0.2, "semi_hard": True,
})
cfg = wandb.config

In [27]:
for epoch in range(cfg.num_epochs):
    print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
    train_loss = train_triplet_epoch(model, train_loader, optimizer, device,
                               cfg.margin, cfg.semi_hard)
    print(f"Epoch {epoch+1} - Pair‑based train loss: {train_loss:.4f}")

    os.makedirs("train_pair", exist_ok=True)
    model_path = f"train_pair/model_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
    wandb.save(model_path)
    
    wandb.log({"epoch": epoch+1})

print("Формирование бейз-эмбеддингов для каждого класса...")
base_embeddings = compute_base_embeddings(model, DataLoader(
    Caltech256ClassificationDataset(train_samples, transform=val_transform, label_to_idx=label_to_idx),
    batch_size=cfg.batch_size, shuffle=False, num_workers=4), device)

print("Валидация модели по задаче классификации...")
accuracy = validate_classification(model, base_embeddings, val_loader, device)
print(f"Accuracy: {accuracy:.4f}")
wandb.log({"Val/Accuracy": accuracy})


Epoch 1/2
Epoch 1 - Pair‑based train loss: 0.0835

Epoch 2/2
Epoch 2 - Pair‑based train loss: 0.0701
Формирование бейз-эмбеддингов для каждого класса...
Валидация модели по задаче классификации...
Accuracy: 0.7654


In [28]:
wandb.finish()

0,1
Train/Triplet Loss,█▁
Val/Accuracy,▁
epoch,▁█

0,1
Train/Triplet Loss,0.0701
Val/Accuracy,0.76544
epoch,2.0


## Эксперимент 4

Используем ArcFace для обучения

In [9]:
class ArcFaceHead(nn.Module):
    def __init__(self, in_feats, out_feats, s=64.0, m=0.50):
        super().__init__()
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.Tensor(out_feats, in_feats))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, labels):
        x_norm = nn.functional.normalize(x, p=2, dim=1)
        w_norm = nn.functional.normalize(self.weight, p=2, dim=1)
        cos_theta = x_norm @ w_norm.t()
        theta = torch.acos(cos_theta.clamp(-1+1e-7, 1-1e-7))
        phi = torch.cos(theta + self.m)
        one_hot = nn.functional.one_hot(labels, num_classes=cos_theta.size(1)).float()
        logits = self.s * (one_hot * phi + (1 - one_hot) * cos_theta)
        return logits


In [15]:
wandb.init(project="ITMO_metric_learning_caltech256_arcface", config={
    "backbone": "levit_128",
    "embedding_dim": 128,
    "lr": 1e-4,
    "batch_size": 32,
    "num_epochs": 2,
    "arc_s": 64.0,
    "arc_m": 0.50,
    "weight_decay": 1e-4
})
cfg = wandb.config

In [16]:
train_ds = Caltech256ClassificationDataset(train_samples, transform, label_to_idx)
val_ds = Caltech256ClassificationDataset(val_samples, val_transform, label_to_idx)
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=4)

In [17]:
model = EmbeddingNet(cfg.backbone, cfg.embedding_dim, pretrained=True).to(device)

In [18]:
head = ArcFaceHead(cfg.embedding_dim, len(label_to_idx), s=cfg.arc_s, m=cfg.arc_m).to(device)

In [19]:
optimizer = optim.AdamW([
    {"params": model.backbone.parameters(), "lr": cfg.lr*0.1},
    {"params": model.fc.parameters()},
    {"params": head.parameters()}
], lr=cfg.lr, weight_decay=cfg.weight_decay)

In [20]:
for epoch in range(cfg.num_epochs):
    model.train(); head.train()
    running_loss = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        embs = model(imgs)
        logits = head(embs, labels)
        loss = nn.functional.cross_entropy(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(model.parameters())+list(head.parameters()), max_norm=1.0
        )
        optimizer.step()
        running_loss += loss.item()
    avg_loss = running_loss/len(train_loader)
    print(f"Epoch {epoch+1}/{cfg.num_epochs} — Train Loss: {avg_loss:.4f}")
    wandb.log({"Train/Loss": avg_loss, "epoch": epoch+1})

print("Формирование бейз-эмбеддингов для каждого класса...")
base_embeddings = compute_base_embeddings(model, DataLoader(
    Caltech256ClassificationDataset(train_samples, transform=val_transform, label_to_idx=label_to_idx),
    batch_size=cfg.batch_size, shuffle=False, num_workers=4), device)

print("Валидация модели по задаче классификации...")
accuracy = validate_classification(model, base_embeddings, val_loader, device)
print(f"Accuracy: {accuracy:.4f}")
wandb.log({"Val/Accuracy": accuracy})

Epoch 1/2 — Train Loss: 37.4359
Epoch 2/2 — Train Loss: 27.9637
Формирование бейз-эмбеддингов для каждого класса...
Валидация модели по задаче классификации...
Accuracy: 0.7328


In [21]:
wandb.finish()

0,1
Train/Loss,█▁
Val/Accuracy,▁
epoch,▁█

0,1
Train/Loss,27.96369
Val/Accuracy,0.73277
epoch,2.0


## Эксперимент 5

Proxy-Based Loss

In [49]:
class ProxyNCALoss(nn.Module):
    def __init__(self, num_classes, embedding_dim, scale=3.0):
        """
        num_classes: число классов (количество прокси)
        embedding_dim: размерность эмбеддингов
        scale: масштабирующий коэффициент (по сути, повышает различимость логитов)
        """
        super(ProxyNCALoss, self).__init__()
        # обучаемые векторы для каждого класса, инициализируем случайно
        self.proxies = nn.Parameter(torch.randn(num_classes, embedding_dim))
        nn.init.kaiming_normal_(self.proxies, mode='fan_out')
        self.scale = scale
        self.ce = nn.CrossEntropyLoss()

    def forward(self, embeddings, labels):
        proxies = nn.functional.normalize(self.proxies.to(embeddings.device), p=2, dim=1)
        # скалярные произведения между эмбеддингами и прокси: [batch_size, num_classes]
        logits = self.scale * torch.matmul(embeddings, proxies.t())
        loss = self.ce(logits, labels)
        return loss

In [52]:
wandb.init(project="ITMO_metric_learning_caltech256_proxy", config={
    "backbone": "levit_128",
    "embedding_dim": 128,
    "lr": 1e-4,
    "batch_size": 32,
    "num_epochs": 2,
    "scale": 3.0,  # масштаб для ProxyNCALoss
})
config = wandb.config

In [None]:
train_dataset = Caltech256ClassificationDataset(train_samples, transform=transform, label_to_idx=label_to_idx)
val_dataset = Caltech256ClassificationDataset(val_samples, transform=val_transform, label_to_idx=label_to_idx)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

In [54]:
model = EmbeddingNet(backbone_name=config.backbone, embedding_dim=config.embedding_dim, pretrained=True)
model.to(device)

proxy_loss_fn = ProxyNCALoss(num_classes=num_classes, embedding_dim=config.embedding_dim, scale=config.scale)
optimizer = optim.AdamW(list(model.parameters()) + list(proxy_loss_fn.parameters()), lr=config.lr)

In [55]:
num_epochs = config.num_epochs
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        
        embeddings = model(images)
        loss = proxy_loss_fn(embeddings, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}/{len(train_loader)}: Loss = {loss.item():.4f}")
            wandb.log({"Train/Proxy Loss": loss.item()})
    avg_loss = running_loss / len(train_loader)
    wandb.log({"Train/Average Proxy Loss": avg_loss, "epoch": epoch+1})
    print(f"Epoch {epoch+1} - Average Training Loss: {avg_loss:.4f}")
    
    os.makedirs("train_proxy", exist_ok=True)
    model_path = f"train_proxy/model_epoch_{epoch+1}.pth"
    torch.save({
        "model_state_dict": model.state_dict(),
        "proxy_state_dict": proxy_loss_fn.state_dict()
    }, model_path)
    wandb.save(model_path)
    
    # Валидация через nearest neighbor
    base_embeddings = compute_base_embeddings(model, DataLoader(train_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4), device)
    accuracy = validate_classification(model, base_embeddings, val_loader, device)
    print(f"Epoch {epoch+1} - Val Accuracy: {accuracy:.4f}")
    wandb.log({"Val/Accuracy": accuracy})

Epoch 1, Batch 0/766: Loss = 5.5349
Epoch 1, Batch 10/766: Loss = 5.6045
Epoch 1, Batch 20/766: Loss = 5.5305
Epoch 1, Batch 30/766: Loss = 5.4690
Epoch 1, Batch 40/766: Loss = 5.4560
Epoch 1, Batch 50/766: Loss = 5.3570
Epoch 1, Batch 60/766: Loss = 5.3128
Epoch 1, Batch 70/766: Loss = 5.3168
Epoch 1, Batch 80/766: Loss = 5.3181
Epoch 1, Batch 90/766: Loss = 5.2714
Epoch 1, Batch 100/766: Loss = 5.2251
Epoch 1, Batch 110/766: Loss = 5.3080
Epoch 1, Batch 120/766: Loss = 5.1738
Epoch 1, Batch 130/766: Loss = 5.2314
Epoch 1, Batch 140/766: Loss = 5.1587
Epoch 1, Batch 150/766: Loss = 5.0010
Epoch 1, Batch 160/766: Loss = 4.9476
Epoch 1, Batch 170/766: Loss = 5.1169
Epoch 1, Batch 180/766: Loss = 5.1118
Epoch 1, Batch 190/766: Loss = 5.1075
Epoch 1, Batch 200/766: Loss = 4.9189
Epoch 1, Batch 210/766: Loss = 5.1253
Epoch 1, Batch 220/766: Loss = 5.0115
Epoch 1, Batch 230/766: Loss = 4.9587
Epoch 1, Batch 240/766: Loss = 4.8964
Epoch 1, Batch 250/766: Loss = 4.9007
Epoch 1, Batch 260/766:

In [56]:
wandb.finish()

0,1
Train/Average Proxy Loss,█▁
Train/Proxy Loss,█▇▇▆▆▆▆▆▅▄▅▅▄▄▄▃▃▃▂▃▃▃▃▂▂▃▂▃▃▂▃▁▂▁▂▁▁▁▂▁
Val/Accuracy,▁█
epoch,▁█

0,1
Train/Average Proxy Loss,3.84026
Train/Proxy Loss,3.66651
Val/Accuracy,0.81885
epoch,2.0


## Эксперимент 6

In [61]:
import faiss

In [62]:
# Функция для формирования полной базы эмбеддингов обучающего набора
def build_train_index(model, dataloader, device):
    model.eval()
    embeddings_list = []
    labels_list = []
    
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            embeddings = model(images)
            embeddings_list.append(embeddings.cpu().numpy())
            labels_list.append(labels.cpu().numpy())
    train_embeddings = np.concatenate(embeddings_list, axis=0)
    train_labels = np.concatenate(labels_list, axis=0)
    return train_embeddings, train_labels

In [63]:
# Функция валидации с использованием FAISS-классификации
def validate_with_faiss(model, train_embeddings, train_labels, val_loader, device, k=5):
    model.eval()
    # Строим FAISS-индекс
    d = train_embeddings.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(train_embeddings)
    
    correct = 0
    total = 0
    
    for images, labels in val_loader:
        images = images.to(device)
        with torch.no_grad():
            emb = model(images)
        emb_np = emb.cpu().numpy()
        distances, indices = index.search(emb_np, k)
        # Голосование: для каждого эмбеддинга выбираем наиболее частую метку среди k соседей
        preds = []
        for neigh_idx in indices:
            neigh_labels = train_labels[neigh_idx]
            pred = np.bincount(neigh_labels).argmax()
            preds.append(pred)
        preds = np.array(preds)
        total += labels.size(0)
        correct += (preds == labels.cpu().numpy()).sum()
    
    accuracy = correct / total
    return accuracy

In [65]:
wandb.init(project="ITMO_metric_learning_caltech256_faiss", config={
    "backbone": "levit_128",
    "embedding_dim": 128,
    "lr": 1e-4,
    "batch_size": 32,
    "num_epochs": 2,
    "k": 5,
})
config = wandb.config

In [66]:
train_dataset = Caltech256ClassificationDataset(train_samples, transform=transform, label_to_idx=label_to_idx)
val_dataset = Caltech256ClassificationDataset(val_samples, transform=transform, label_to_idx=label_to_idx)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4)

In [67]:
train_embeddings, train_labels = build_train_index(model, DataLoader(train_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4), device)
print(f"Сформирован индекс на {train_embeddings.shape[0]} эмбеддингах.")

Сформирован индекс на 24485 эмбеддингах.


In [68]:
accuracy = validate_with_faiss(model, train_embeddings, train_labels, val_loader, device, k=config.k)
print(f"Val Accuracy (FAISS KNN): {accuracy:.4f}")
wandb.log({"Val/Accuracy_FAISS": accuracy})
wandb.finish()

Val Accuracy (FAISS KNN): 0.8275


0,1
Val/Accuracy_FAISS,▁

0,1
Val/Accuracy_FAISS,0.82751


In [69]:
wandb.finish()

In [70]:
torch.save(model, "model_last.pth")