In [1]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
lyly99_logodet3k_path = kagglehub.dataset_download('lyly99/logodet3k')

print('Data source import complete.')

Downloading from https://www.kaggle.com/api/v1/datasets/download/lyly99/logodet3k?dataset_version_number=1...


100%|██████████| 2.87G/2.87G [02:30<00:00, 20.5MB/s]

Extracting files...





Data source import complete.


# Ставим библиотечку для оптимизатора Prodigy (в целом, конечно, можно воспользоваться обычным Adam или AdamW)

In [2]:
!pip install prodigyopt

Collecting prodigyopt
  Downloading prodigyopt-1.1.2-py3-none-any.whl.metadata (4.8 kB)
Downloading prodigyopt-1.1.2-py3-none-any.whl (10 kB)
Installing collected packages: prodigyopt
Successfully installed prodigyopt-1.1.2


# Импортируем нужные библиотеки

In [274]:
import os
from pathlib import Path
import PIL
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, cdist

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

from prodigyopt import Prodigy
from transformers import ViTModel, ViTImageProcessor
import xml.etree.ElementTree as E

from numba import njit

torch.manual_seed(42)
np.random.seed(42)

# Девайс, модель, препроцессор картинок

In [275]:
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

# Класс основной модели

In [276]:
class VitMatcher(nn.Module):
    def __init__(self, model_name: str):
      super().__init__()

      self.vit_processor = ViTImageProcessor.from_pretrained(model_name)
      self.vit_model = ViTModel.from_pretrained(model_name)
      self.threshold = None
      self.positive_vectors = None # num_positives x emb_size

      for param_name, param in self.vit_model.named_parameters():
        if "pooler" in param_name:
          param.requires_grad = True
        else:
          param.requires_grad = False

    def forward(self, x: torch.Tensor):
      return self.vit_model(x).pooler_output

    # Предсказания делаются на основе того, насколько близко среднее расстояние от эмбеддинга кропа
    # до среднего расстояния внутри образцов
    def predict(self, x: torch.Tensor):
      if self.positive_vectors is None:
        return
      else:
        with torch.no_grad():
          input_embeddings = self.forward(x).cpu().numpy() # batch_size x emb_size
        distances_matrix = cdist(self.positive_vectors, input_embeddings)

        return distances_matrix.mean(axis=0) < self.threshold

    # задаём трешхолд для расстояния
    def set_threshold(self, positive_samples: torch.Tensor):
      with torch.no_grad():
        self.positive_vectors = self.forward(positive_samples).cpu().numpy()
      distance_matrix = pdist(self.positive_vectors)
      self.threshold = distance_matrix.mean() + distance_matrix.std()

# Функция для парсинга xml с разметкой кропов

In [277]:
def get_bbox_from_xml(xml_filename: Path):
    with open(xml_filename.as_posix(), "r", encoding="utf8") as f:
        xml_string = f.read()

    root = E.fromstring(xml_string)

    bbox = root.find(".//object/bndbox")
    if bbox is not None:
        xmin = int(bbox.find("xmin").text)
        ymin = int(bbox.find("ymin").text)
        xmax = int(bbox.find("xmax").text)
        ymax = int(bbox.find("ymax").text)

        return xmin, ymin, xmax, ymax

    return None

# Базовый препроцессор для картинок

In [278]:
def my_transformer(img: PIL.Image, path_to_image: str) -> torch.Tensor:
    path_to_xml = Path(path_to_image).parent / (Path(path_to_image).stem + '.xml')
    crop_coords = get_bbox_from_xml(path_to_xml)

    img = img.crop(crop_coords)
    processed_img = processor(images=img, return_tensors="pt")["pixel_values"]

    return processed_img

# Класс для создания датасета для обучения модели на Triplet Loss и вспомогательная функция для создания триплетов (anchor, positive, negative)

In [279]:
@njit
def generate_triplets(targets: np.ndarray, n_triplets: int):
      labels = targets
      triplets = []
      for x in np.arange(n_triplets):
          idx = np.random.randint(0, labels.shape[0])
          idx_matches = np.where(labels == labels[idx])[0]
          idx_no_matches = np.where(labels != labels[idx])[0]
          idx_a, idx_p = np.random.choice(idx_matches, 2, replace=False)
          idx_n = np.random.choice(idx_no_matches, 1)[0]
          triplets.append([idx_a, idx_p, idx_n])
      return triplets

def triplet_collate_fn(batch: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
  anchors = []
  positives = []
  negatives = []
  for anchor, positive, negative in batch:
      anchors.append(anchor)
      positives.append(positive)
      negatives.append(negative)
  return torch.cat(anchors), torch.cat(positives), torch.cat(negatives)

class TripletImageFolder(torchvision.datasets.ImageFolder):
    """From the torchvision.datasets.ImageFolder it generates triplet samples, used in training. For testing we use normal image folder.
    Note: a triplet is composed by a pair of matching images and one of different class.
    """
    def __init__(self, *arg, **kw):
        super(TripletImageFolder, self).__init__(*arg, **kw)
        shuffled_indices = np.random.choice(np.arange(len(self.samples)), len(self.samples), replace=False)
        self.train_indices = shuffled_indices[:int(0.7 * len(self.samples))]
        self.test_indices = shuffled_indices[int(0.3 * len(self.samples)):]

        self.n_triplets_train = len(self.train_indices)
        self.n_triplets_test = len(self.test_indices)

        self.train_triplets = np.array(generate_triplets(np.array(self.targets)[self.train_indices], self.n_triplets_train))
        self.test_triplets = np.array(generate_triplets(np.array(self.targets)[self.test_indices], self.n_triplets_test))

    def set_triplets(self, triplets):
        self.train_triplets = triplets

    def generate_triplets(self):
        return np.array(generate_triplets(np.array(self.targets)[self.train_indices], self.n_triplets_train))

    def __getitem__(self, index):
        if index < self.n_triplets_train:
          t = self.train_triplets[index]
          path_a, _ = np.array(self.samples)[self.train_indices][t[0]]
          path_p, _ = np.array(self.samples)[self.train_indices][t[1]]
          path_n, _ = np.array(self.samples)[self.train_indices][t[2]]
        else:
          t = self.test_triplets[index - self.n_triplets_train]

          path_a, _ = np.array(self.samples)[self.test_indices][t[0]]
          path_p, _ = np.array(self.samples)[self.test_indices][t[1]]
          path_n, _ = np.array(self.samples)[self.test_indices][t[2]]

        img_a = self.loader(path_a)
        img_p = self.loader(path_p)
        img_n = self.loader(path_n)

        if self.transform is not None:
            img_a = self.transform(img_a, path_a)
            img_p = self.transform(img_p, path_p)
            img_n = self.transform(img_n, path_n)

        return img_a, img_p, img_n

    def __len__(self):
        return len(self.samples)

    def sample_for_test(self, num_samples: int = 10):
      stop_sampling = False
      while not stop_sampling:
        try:
          class_label = np.random.choice(np.array(self.targets)[self.test_indices], 1)
          matched_indices = np.where(np.array(np.array(self.targets)[self.test_indices]) == class_label)[0]
          positive_samples_idxs = np.random.choice(matched_indices, num_samples, replace=False)
          stop_sampling = True
        except Exception:
          stop_sampling = False

      positive_images = []
      raw_images = []
      for idx in positive_samples_idxs:
        path_p, _ = np.array(self.samples)[self.test_indices][idx]
        img_p = self.loader(path_p)
        raw_images.append(img_p.copy())
        if self.transform is not None:
            img_p = self.transform(img_p, path_p)
        positive_images.append(img_p)

      return torch.cat(positive_images), raw_images, class_label

# Задаём тренировочный и тестовый датасеты, а также модель

In [281]:
triplet_data = TripletImageFolder(root=f"{lyly99_logodet3k_path}/LogoDet-3K/Electronic", transform=my_transformer)
train_dataset, test_dataset = torch.utils.data.Subset(triplet_data, np.arange(triplet_data.n_triplets_train)), torch.utils.data.Subset(triplet_data, np.arange(triplet_data.n_triplets_test))

train_triplet_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=triplet_collate_fn)
test_triplet_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=triplet_collate_fn)

In [282]:
vit_matcher = VitMatcher("google/vit-base-patch16-224-in21k")
loss_fn = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
loss_fn = torch.compile(loss_fn)
opt = Prodigy([param for param_name, param in vit_matcher.named_parameters() if param.requires_grad == True], lr=1., weight_decay=0.01, slice_p=11)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

Using decoupled weight decay


In [283]:
vit_matcher = torch.compile(vit_matcher)
vit_matcher = vit_matcher.to(device)

# Зададим случайные тестовые образцы одного класса, а затем образцы другого класса и посмотрим, какова точность предсказаний перед началом обучения

In [284]:
positive_samples, raw_images, class_label = triplet_data.sample_for_test()
vit_matcher.set_threshold(positive_samples.to(device))

In [286]:
new_positive_samples, new_raw_images, new_class_label = triplet_data.sample_for_test()
(vit_matcher.predict(new_positive_samples.to(device)) == False).mean()

0.0

# Цикл обучения

In [None]:
num_epochs = 20

for epoch in tqdm(range(num_epochs)):
  train_epoch_losses = []
  for img_a, img_p, img_n in tqdm(train_triplet_dataloader):
    with torch.autocast(device_type=device):
      model_input = torch.cat([img_a, img_p, img_n]).to(device)
      model_output = vit_matcher(model_input)
      emb_a, emb_p, emb_n = model_output.chunk(3)
      loss = loss_fn(emb_a, emb_p, emb_n)
    train_epoch_losses.append(loss.item())

    opt.zero_grad()
    loss.backward()
    opt.step()

  test_epoch_losses = []
  for img_a, img_p, img_n in tqdm(test_triplet_dataloader):
    model_input = torch.cat([img_a, img_p, img_n]).to(device)
    with torch.no_grad():
      with torch.autocast(device_type=device):
        model_output = vit_matcher(model_input)
        emb_a, emb_p, emb_n = model_output.chunk(3)
        loss = loss_fn(emb_a, emb_p, emb_n)
        test_epoch_losses.append(loss.item())

  new_positive_samples, new_raw_images, new_class_label = triplet_data.sample_for_test()
  test_predictions = vit_matcher.predict(new_positive_samples.to(device))
  if new_class_label != class_label:
    test_predictions = test_predictions == False
  print(f"Accuracy: {test_predictions.mean():.3f}")

  print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {np.mean(train_epoch_losses)}")
  print(f"Epoch {epoch+1}/{num_epochs}, Test Loss: {np.mean(test_epoch_losses)}")
  triplets = train_triplet_dataloader.dataset.dataset.generate_triplets()
  train_triplet_dataloader.dataset.dataset.set_triplets(triplets)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/212 [00:00<?, ?it/s]

  0%|          | 0/212 [00:00<?, ?it/s]

Accuracy: 1.000
Epoch 1/20, Train Loss: 0.09923777198594697
Epoch 1/20, Test Loss: 0.029991730858729977


  0%|          | 0/212 [00:00<?, ?it/s]

  0%|          | 0/212 [00:00<?, ?it/s]

Accuracy: 1.000
Epoch 2/20, Train Loss: 0.22104076257432406
Epoch 2/20, Test Loss: 0.06981677122255962


  0%|          | 0/212 [00:00<?, ?it/s]