Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
30 changed files
with
2,175 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
# for files opened by jupyter notebook, e.g. log files | ||
.nfs000* | ||
|
||
.DS_Store | ||
._.DS_Store | ||
*.pt | ||
*.pyc | ||
__pycache__/ | ||
.ipynb_checkpoints | ||
notebooks/ | ||
log/ | ||
|
||
# vs code files | ||
._* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,51 @@ | ||
# Divide and Conquer the Embedding Space for Metric Learning | ||
***Artsiom Sanakoyeu\*, Vadim Tschernezki\*, Uta Büchler, Björn Ommer*, In CVPR 2019** | ||
|
||
The PDF of the paper and the source code are coming soon. | ||
## About | ||
|
||
This repository contains the code for reproducing the results for [Divide and Conquer the Embedding Space for Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Sanakoyeu_Divide_and_Conquer_the_Embedding_Space_for_Metric_Learning_CVPR_2019_paper.pdf) (CVPR 2019) with the datasets [In-Shop Clothes](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html), [Stanford Online Products](http://cvgl.stanford.edu/projects/lifted_struct/) and [PKU VehicleID](https://www.pkuml.org/resources/pku-vehicleid.html). | ||
|
||
## Requirements | ||
|
||
- Python version 3.6.6 or higher | ||
- SciPy and scikit-learn packages | ||
- PyTorch ([pytorch.org](http://pytorch.org)) | ||
- Faiss with GPU support ([Faiss](https://github.com/facebookresearch/faiss)) | ||
- download and extract the datasets for [In-Shop Clothes](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html), [Stanford Online Products](http://cvgl.stanford.edu/projects/lifted_struct/) and [PKU VehicleID](https://www.pkuml.org/resources/pku-vehicleid.html) | ||
|
||
## Usage | ||
|
||
The following command will train the model with Margin loss on the In-Shop Clothes dataset for 200 epochs and a batch size of 80 while splitting the embedding layer with 8 clusters and finetuning the model from epoch 190 on. You can use this command to reproduce the results of the paper for the three datasets by changing simply `--dataset=inshop` to `--dataset=sop` (Stanford Online Products) or `--dataset=vid` (Vehicle-ID). | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=0 python experiment.py --dataset=inshop \ | ||
--dir=test --exp=0 --random-seed=0 --nb-clusters=8 --nb-epochs=200 \ | ||
--sz-batch=80 --backend=faiss-gpu --embedding-lr=1e-5 --embedding-wd=1e-4 \ | ||
--backbone-lr=1e-5 --backbone-wd=1e-4 --finetune-epoch=190 | ||
``` | ||
|
||
The model can be trained without the proposed method by setting the number of clusters to 1 with `--nb-clusters=1`. | ||
For faster clustering we run Faiss on GPU. If you installed Faiss without GPU support use flag `--backend=faiss`. | ||
## Expected Results | ||
|
||
The model checkpoints and log files are saved in the selected log-directory. You can print a summary of the results with `python browse_results <log path>`. | ||
|
||
You will get slightly higher results than what we have reported in the paper. For SOP, In-Shop and Vehicle-ID the R@1 results should be somewhat around 76.40, 87.36 and 91.54. | ||
|
||
## License | ||
|
||
You may find out more about the license [here](LICENSE) | ||
|
||
## Reference | ||
|
||
If you use this code, please cite the following paper: | ||
|
||
Artsiom Sanakoyeu, Vadim Tschernezki, Uta Büchler, Björn Ommer. "Divide and Conquer the Embedding Space for Metric Learning", CVPR 2019. | ||
|
||
``` | ||
@InProceedings{dcesml, | ||
title={Divide and Conquer the Embedding Space for Metric Learning}, | ||
author={Sanakoyeu, Artsiom and Tschernezki, Vadim and B\"uchler, Uta and Ommer, Bj\"orn}, | ||
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, | ||
year={2019}, | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import shelve | ||
from collections import defaultdict | ||
import sys | ||
import os | ||
import numpy as np | ||
import pandas as pd | ||
import time | ||
import glob | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('exp_dir', type = str) | ||
parser.add_argument('-cw', '--col-width', type=int, default=100) | ||
args = parser.parse_args() | ||
print(args) | ||
|
||
print('exp_dir=', args.exp_dir) | ||
|
||
files = sorted(list(map(lambda x: x[:-4], glob.glob(os.path.join(args.exp_dir, '*.dat'))))) | ||
|
||
|
||
results = defaultdict(list) | ||
ks = [1, 2, 4, 8, 10, 20, 30, 50] | ||
columns=[ | ||
'epoch', | ||
*['R@{}'.format(i) for i in ks], | ||
] | ||
|
||
last_modified = None | ||
|
||
for p in files: | ||
try: | ||
db = shelve.open(p) | ||
log_path = p + '.log' | ||
assert os.path.exists(log_path), log_path | ||
last_modified = (time.time() - os.path.getmtime(p + '.log')) / 60 | ||
except: | ||
print('Failed to open', p) | ||
try: | ||
p = os.path.basename(p) | ||
cur_results_t = np.array([(epoch, *d['score']['recall']) | ||
for (epoch, d) in db['metrics'].items()]) | ||
cur_results = np.zeros((cur_results_t.shape[0], 1 + len(ks)), dtype=float) | ||
cur_results[:, :] = np.nan | ||
cur_results[:, :2] = cur_results_t[:, :2] | ||
# TODO: maybe rename args to config | ||
if db['config']['dataset_selected'] == 'inshop': | ||
cur_results[:, 5:] = cur_results_t[:, 2:] | ||
else: | ||
cur_results[:, 2:5] = cur_results_t[:, 2:] | ||
|
||
except Exception as e: | ||
print(p, e) | ||
print(db['config']) | ||
|
||
idx_max_recall = cur_results[:, 1].argmax() | ||
best_epoch_results = cur_results[idx_max_recall] | ||
max_epoch = cur_results[:, 0].max() | ||
best_epoch_results = best_epoch_results.tolist() | ||
best_epoch_results[0] = '{:02}/{:02}'.format(int(best_epoch_results[0]), int(max_epoch)) | ||
assert len(best_epoch_results) == len(columns) | ||
|
||
for i, col_name in enumerate(columns): | ||
results[col_name].append(best_epoch_results[i]) | ||
|
||
# if the file was last modified < 10 minute ago; than print Running status | ||
if last_modified is None: | ||
results['S'].append('?') | ||
elif last_modified > 10: | ||
results['S'].append('-') | ||
else: | ||
results['S'].append('[R]') | ||
|
||
|
||
df = pd.DataFrame(index=list(map(os.path.basename, files)), | ||
data=results) | ||
|
||
pd.set_option('display.max_rows', 10000) | ||
pd.set_option('display.max_columns', 10000) | ||
pd.set_option('display.max_colwidth', args.col_width) | ||
pd.set_option('display.width', 1000000) | ||
df.sort_values(by=['R@1'], inplace=True) | ||
print(df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
{ | ||
"random_seed": 0, | ||
"verbose": false, | ||
"save_model": true, | ||
"cuda_device": 0, | ||
"sz_embedding": 128, | ||
"backend": "faiss-gpu", | ||
"nb_epochs": 200, | ||
"nb_clusters": 8, | ||
"finetune_epoch": 100, | ||
"dataset_selected": "sop", | ||
"dataset": { | ||
"vid": { | ||
"root": "/export/home/vtschern/data/vehicle-id/VehicleID_V1.0", | ||
"classes": { | ||
"train": "range(0, 13164)", | ||
"init": "range(0, 13164)", | ||
"_note": "small: 800, mid: +1600, large: +2400", | ||
"eval": "range(13164, 13164 + 2400)" | ||
} | ||
}, | ||
"inshop": { | ||
"root": "/export/home/vtschern/data/in-shop", | ||
"classes": { | ||
"train": "range(0, 3997)", | ||
"init": "range(0, 3997)", | ||
"eval": "range(0, 3985)" | ||
} | ||
}, | ||
"sop": { | ||
"root": "/export/home/vtschern/data/sop", | ||
"classes": { | ||
"train": "range(0, 11318)", | ||
"init": "range(0, 11318)", | ||
"eval": "range(11318, 22634)" | ||
} | ||
} | ||
}, | ||
"log": { | ||
"path": "log/default", | ||
"name": "sop-K-8-M-2-exp-0" | ||
}, | ||
"dataloader": { | ||
"num_workers": 4, | ||
"drop_last": false, | ||
"shuffle": false, | ||
"pin_memory": true, | ||
"batch_size": 80 | ||
}, | ||
"opt": { | ||
"backbone": { | ||
"lr": 1e-5, | ||
"weight_decay": 1e-4 | ||
}, | ||
"embedding": { | ||
"lr": 1e-5, | ||
"weight_decay": 1e-4 | ||
} | ||
}, | ||
"recluster": { | ||
"enabled": true, | ||
"mod_epoch": 2 | ||
}, | ||
"transform_parameters": { | ||
"rgb_to_bgr": false, | ||
"intensity_scale": [[0, 1], [0, 1]], | ||
"mean": [0.485, 0.456, 0.406], | ||
"std": [0.229, 0.224, 0.225], | ||
"sz_crop": 224 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from __future__ import print_function | ||
|
||
import argparse | ||
import math | ||
import matplotlib | ||
import sys | ||
|
||
import train | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--nb-clusters', required = True, type = int) | ||
parser.add_argument('--dataset', dest = 'dataset_selected', | ||
choices=['sop', 'inshop', 'vid'], required = True | ||
) | ||
parser.add_argument('--nb-epochs', type = int, default=200) | ||
parser.add_argument('--finetune-epoch', type = int, default=190) | ||
parser.add_argument('--mod-epoch', type = int, default=2) | ||
parser.add_argument('--num-workers', default=4, type=int) | ||
parser.add_argument('--sz-batch', type=int, default=80) | ||
parser.add_argument('--sz-embedding', default=128, type=int) | ||
parser.add_argument('--cuda-device', default = 0, type = int) | ||
parser.add_argument('--exp', default='0', type=str, help='experiment identifier') | ||
parser.add_argument('--dir', default='default', type=str) | ||
parser.add_argument('--backend', default='faiss', | ||
choices=('torch+sklearn', 'faiss', 'faiss-gpu') | ||
) | ||
parser.add_argument('--random-seed', default = 0, type = int) | ||
parser.add_argument('--backbone-wd', default=1e-4, type=float) | ||
parser.add_argument('--backbone-lr', default=1e-5, type=float) | ||
parser.add_argument('--embedding-lr', default=1e-5, type=float) | ||
parser.add_argument('--embedding-wd', default=1e-4, type=float) | ||
parser.add_argument('--verbose', action = 'store_true') | ||
args = vars(parser.parse_args()) | ||
|
||
config = train.load_config(config_name = 'config.json') | ||
|
||
config['dataloader']['batch_size'] = args.pop('sz_batch') | ||
config['dataloader']['num_workers'] = args.pop('num_workers') | ||
config['recluster']['mod_epoch'] = args.pop('mod_epoch') | ||
config['opt']['backbone']['lr'] = args.pop('backbone_lr') | ||
config['opt']['backbone']['weight_decay'] = args.pop('backbone_wd') | ||
config['opt']['embedding']['lr'] = args.pop('embedding_lr') | ||
config['opt']['embedding']['weight_decay'] = args.pop('embedding_wd') | ||
|
||
for k in args: | ||
if k in config: | ||
config[k] = args[k] | ||
|
||
if config['nb_clusters'] == 1: | ||
config['recluster']['enabled'] = False | ||
|
||
config['log'] = { | ||
'name': '{}-K-{}-M-{}-exp-{}'.format( | ||
config['dataset_selected'], | ||
config['nb_clusters'], | ||
config['recluster']['mod_epoch'], | ||
args['exp'] | ||
), | ||
'path': 'log/{}'.format(args['dir']) | ||
} | ||
|
||
# tkinter not installed on this system, use non-GUI backend | ||
matplotlib.use('agg') | ||
train.start(config) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
|
||
.DS_Store | ||
._.DS_Store | ||
__pycache__/ | ||
.ipynb_checkpoints/ | ||
._* | ||
.nfs* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from . import loss | ||
from . import utils | ||
from . import evaluation | ||
from . import similarity | ||
from . import model | ||
from . import data | ||
from . import clustering | ||
|
Oops, something went wrong.