In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import copy
import json
import numpy as np
import matplotlib.pyplot as plt 
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pytorch_metric_learning import losses, miners, testers, samplers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from tqdm import tqdm
import random
import wandb

In [2]:
from models.efficientnet import EffNetModel
from models.swin import SwinTransformerModel
from models.efficientnet_v2 import EffNetV2Model

In [3]:
from datasets_handlers.base_dataset import BaseImageDataset, ClassLabelsMapper
from datasets_handlers.dataset_cub import cub200_dataset
from datasets_handlers.in_shop import in_shop_dataset
from datasets_handlers.stanford_online_products import stanford_products_dataset
from datasets_handlers.google_landmark import google_landmark_dataset

In [4]:
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

In [5]:
#stanford_train = stanford_products_dataset(split='train')
#stanford_test = stanford_products_dataset(split='test')

In [6]:
cub200_train = cub200_dataset(split='train')
cub200_test = cub200_dataset(split='test')

In [7]:
#inshop_train = ClassLabelsMapper(in_shop_dataset(split='train'))
#inshop_gallery = in_shop_dataset(split='gallery')
#inshop_query = in_shop_dataset(split='query')

In [8]:
#glm_train = google_landmark_dataset()

In [9]:
image_size = 224
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ImageCompression(quality_lower=99, quality_upper=100),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=10, border_mode=0, p=0.7),
        A.Resize(image_size, image_size),
        A.CoarseDropout(max_height=int(image_size * 0.4), max_width=int(image_size * 0.4), max_holes=1, p=0.5),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2()
    ]
)

test_transform = A.Compose([
    A.Resize(image_size, image_size),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2()
])

In [10]:
dataset_train = ClassLabelsMapper(BaseImageDataset(cub200_train, path_prefix='data/', augmentations=train_transform))
dataset_test = BaseImageDataset(cub200_test, path_prefix='data/', augmentations=test_transform)
#dataset_gallery = BaseImageDataset(inshop_gallery, path_prefix='data/', augmentations=test_transform)
#dataset_query = BaseImageDataset(inshop_query, path_prefix='data/', augmentations=test_transform)

In [11]:
parameters = {
    'dataset:': type(dataset_train.dataset).__name__,
    'n_epochs': 500,
    'batch_size': 16,
    'lr': 1e-4
}

In [12]:
wandb.init(project="my-test-project", entity="ilya_fedorov", config=parameters)

[34m[1mwandb[0m: Currently logged in as: [33milya_fedorov[0m (use `wandb login --relogin` to force relogin)


In [13]:
%%time
train_targets = []
for x, target in cub200_train:
    train_targets.append(target)

CPU times: user 1.5 ms, sys: 0 ns, total: 1.5 ms
Wall time: 1.51 ms


In [14]:
sampler = samplers.MPerClassSampler(
    train_targets, m=4, length_before_new_iter=len(dataset_train)
)

In [15]:
dataloader_train = DataLoader(
    dataset_train,
    batch_size=parameters['batch_size'],
    num_workers=8,
    pin_memory=True,
    sampler=sampler
)

In [16]:
#model = EffNetModel('efficientnet-b0').cuda()
#model = SwinTransformerModel('microsoft/swin-base-patch4-window7-224')
model = EffNetV2Model('swin_base_patch4_window7_224')
model.cuda()
model = model.train()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [17]:
embedding_size = model(dataset_train[0][0].unsqueeze(0).cuda()).shape[1]
train_n_classes = len(dataset_train.labels_map)

In [18]:
loss_args = {'margin': 0.02}
loss_func = losses.TripletMarginLoss(**loss_args)
#loss_func = losses.ArcFaceLoss(embedding_size=embedding_size, num_classes=train_n_classes)

miner = miners.TripletMarginMiner(**loss_args)

In [19]:
dataset_gallery = dataset_test
dataset_query = dataset_test

skip_first_neighbour = dataset_query == dataset_gallery

In [20]:
wandb.config.update({"skip_first_neighbour": skip_first_neighbour, 
                     "loss": type(loss_func).__name__, 
                     "model": type(model).__name__,
                     **loss_args
})

In [21]:
optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr'])

In [22]:
def get_all_embeddings(dataset, model):
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)

def test(train_set, test_set, model, accuracy_calculator):
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    if train_set != test_set:
        test_embeddings, test_labels = get_all_embeddings(test_set, model)
    else:
        test_embeddings, test_labels = train_embeddings, train_labels
    train_labels = train_labels.squeeze(1)
    test_labels = test_labels.squeeze(1)
    accuracies = accuracy_calculator.get_accuracy(
        test_embeddings, train_embeddings, test_labels, train_labels, skip_first_neighbour
    )
    return accuracies

In [23]:
accuracy_calculator = AccuracyCalculator(include=("mean_average_precision", "precision_at_1"), k=5)

In [24]:
train_iter_log_freq = 50

i = 0
mean_loss = 0
for epoch in range(parameters['n_epochs']):
    print(f'Starting epoch {epoch}')
    wandb.log({"epoch": epoch})
    
    for data, labels in tqdm(dataloader_train):
        data = data.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        embeddings = model(data)
        hard_pairs = miner(embeddings, labels)
        loss = loss_func(embeddings, labels, hard_pairs)
        mean_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        if (i + 1) % train_iter_log_freq == 0:
            mean_loss /= train_iter_log_freq
            wandb.log({"train/loss": mean_loss})
            mean_loss = 0
        i += 1
        
    if epoch % 10 == 0:
        # query!
        model.eval()
        test_results = test(dataset_gallery, dataset_query, model, accuracy_calculator)
        wandb.log({"test/map@1": test_results['precision_at_1']})
        wandb.log({"test/map@5": test_results['mean_average_precision']})
        model.train()

Starting epoch 0


100%|█████████████████████████████████████████| 150/150 [00:22<00:00,  6.68it/s]
100%|███████████████████████████████████████████| 95/95 [00:08<00:00, 11.45it/s]


Starting epoch 1


100%|█████████████████████████████████████████| 150/150 [00:21<00:00,  6.91it/s]


Starting epoch 2


100%|█████████████████████████████████████████| 150/150 [00:23<00:00,  6.42it/s]


Starting epoch 3


100%|█████████████████████████████████████████| 150/150 [00:21<00:00,  6.87it/s]


Starting epoch 4


100%|█████████████████████████████████████████| 150/150 [00:22<00:00,  6.71it/s]


Starting epoch 5


100%|█████████████████████████████████████████| 150/150 [00:22<00:00,  6.62it/s]


Starting epoch 6


100%|█████████████████████████████████████████| 150/150 [00:22<00:00,  6.77it/s]


Starting epoch 7


100%|█████████████████████████████████████████| 150/150 [00:21<00:00,  6.91it/s]


Starting epoch 8


100%|█████████████████████████████████████████| 150/150 [00:22<00:00,  6.78it/s]


Starting epoch 9


100%|█████████████████████████████████████████| 150/150 [00:21<00:00,  6.91it/s]


Starting epoch 10


100%|█████████████████████████████████████████| 150/150 [00:22<00:00,  6.71it/s]
100%|███████████████████████████████████████████| 95/95 [00:08<00:00, 11.32it/s]


Starting epoch 11


 71%|█████████████████████████████▏           | 107/150 [00:17<00:06,  6.27it/s]


KeyboardInterrupt: 

In [21]:
test_results = test(dataset_train, dataset_train, model, accuracy_calculator)

100%|███████████████████████████████████████| 1861/1861 [01:33<00:00, 19.94it/s]
100%|███████████████████████████████████████| 1861/1861 [01:33<00:00, 19.91it/s]


In [22]:
test_results

{'mean_average_precision': 0.8537094955024545,
 'precision_at_1': 0.8352840422495005}