update fastreid v1.2
Summary:
1. refactor dataloader and heads
2. bugfix in fastattr, fastclas, fastface and partialreid
3. partial-fc supported in fastface
L1aoXingyu committed Apr 2, 2021
1 parent 9288db6 commit 44cee30
Showing 40 changed files with 862 additions and 478 deletions.
1 change: 0 additions & 1 deletion configs/VeRi/sbs_R50-ibn.yml
@@ -11,7 +11,6 @@ MODEL:
 
 SOLVER:
   OPT: SGD
-  NESTEROV: True
   BASE_LR: 0.01
   ETA_MIN_LR: 7.7e-5
 
4 changes: 2 additions & 2 deletions fastreid/config/defaults.py
@@ -73,8 +73,8 @@
 _C.MODEL.HEADS.CLS_LAYER = "Linear" # ArcSoftmax" or "CircleSoftmax"
 
 # Margin and Scale for margin-based classification layer
-_C.MODEL.HEADS.MARGIN = 0.15
-_C.MODEL.HEADS.SCALE = 128
+_C.MODEL.HEADS.MARGIN = 0.
+_C.MODEL.HEADS.SCALE = 1
 
 # ---------------------------------------------------------------------------- #
 # REID LOSSES options
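Note: since Linear now reads its scale and margin from the config instead of hard-coding s=1, m=0 (see the any_softmax.py hunk below), these neutral defaults keep the plain Linear head behaving exactly as before; margin-based heads override them in their own configs. An illustrative override in Python (values are typical choices, not taken from this commit):

from fastreid.config import get_cfg

cfg = get_cfg()
cfg.MODEL.HEADS.CLS_LAYER = "CircleSoftmax"  # or "ArcSoftmax" / "CosSoftmax"
cfg.MODEL.HEADS.MARGIN = 0.35                # illustrative margin
cfg.MODEL.HEADS.SCALE = 64                   # illustrative scale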
41 changes: 20 additions & 21 deletions fastreid/data/build.py
@@ -26,21 +26,19 @@
 _root = os.getenv("FASTREID_DATASETS", "datasets")
 
 
-def _train_loader_from_config(cfg, *, Dataset=None, transforms=None, sampler=None, **kwargs):
+def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=None, **kwargs):
     if transforms is None:
         transforms = build_transforms(cfg, is_train=True)
 
-    if Dataset is None:
-        Dataset = CommDataset
+    if train_set is None:
+        train_items = list()
+        for d in cfg.DATASETS.NAMES:
+            data = DATASET_REGISTRY.get(d)(root=_root, **kwargs)
+            if comm.is_main_process():
+                data.show_train()
+            train_items.extend(data.train)
 
-    train_items = list()
-    for d in cfg.DATASETS.NAMES:
-        data = DATASET_REGISTRY.get(d)(root=_root, **kwargs)
-        if comm.is_main_process():
-            data.show_train()
-        train_items.extend(data.train)
-
-    train_set = Dataset(train_items, transforms, relabel=True)
+        train_set = CommDataset(train_items, transforms, relabel=True)
 
     if sampler is None:
         sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
@@ -92,24 +90,25 @@ def build_reid_train_loader(
     return train_loader
 
 
-def _test_loader_from_config(cfg, dataset_name, *, Dataset=None, transforms=None, **kwargs):
+def _test_loader_from_config(cfg, *, dataset_name=None, test_set=None, num_query=0, transforms=None, **kwargs):
     if transforms is None:
         transforms = build_transforms(cfg, is_train=False)
 
-    if Dataset is None:
-        Dataset = CommDataset
-
-    data = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
-    if comm.is_main_process():
-        data.show_test()
-    test_items = data.query + data.gallery
+    if test_set is None:
+        assert dataset_name is not None, "dataset_name must be explicitly passed in when test_set is not provided"
+        data = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
+        if comm.is_main_process():
+            data.show_test()
+        test_items = data.query + data.gallery
+        test_set = CommDataset(test_items, transforms, relabel=False)
 
-    test_set = Dataset(test_items, transforms, relabel=False)
+        # Update query number
+        num_query = len(data.query)
 
     return {
         "test_set": test_set,
         "test_batch_size": cfg.TEST.IMS_PER_BATCH,
-        "num_query": len(data.query),
+        "num_query": num_query,
     }
 
 
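For reference, a minimal usage sketch of the refactored helpers, assuming the keyword-only signatures above ("Market1501" is only an example of a registered dataset name):

from fastreid.config import get_cfg
from fastreid.data import build_reid_train_loader, build_reid_test_loader

cfg = get_cfg()

# Config-driven path: datasets are resolved through DATASET_REGISTRY.
train_loader = build_reid_train_loader(cfg)
test_loader, num_query = build_reid_test_loader(cfg, dataset_name="Market1501")

# Dataset-driven path: hand over a pre-built CommDataset and its query count yourself.
# (my_test_set and my_num_query are placeholders.)
# test_loader, num_query = build_reid_test_loader(cfg, test_set=my_test_set, num_query=my_num_query)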
4 changes: 2 additions & 2 deletions fastreid/engine/defaults.py
@@ -85,7 +85,7 @@ def default_setup(cfg, args):
     PathManager.mkdirs(output_dir)
 
     rank = comm.get_rank()
-    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
+    # setup_logger(output_dir, distributed_rank=rank, name="fvcore")
     logger = setup_logger(output_dir, distributed_rank=rank)
 
     logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
@@ -423,7 +423,7 @@ def build_test_loader(cls, cfg, dataset_name):
         It now calls :func:`fastreid.data.build_reid_test_loader`.
         Overwrite it if you'd like a different data loader.
         """
-        return build_reid_test_loader(cfg, dataset_name)
+        return build_reid_test_loader(cfg, dataset_name=dataset_name)
 
     @classmethod
     def build_evaluator(cls, cfg, dataset_name, output_dir=None):
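The keyword form matters because dataset_name now sits behind the * in _test_loader_from_config (see the build.py hunk above), so a positional second argument would no longer be accepted once the configurable wrapper forwards the call. A hedged illustration, reusing cfg and the import from the sketch after the build.py diff ("VeRi" is only an example dataset name):

val_loader, num_query = build_reid_test_loader(cfg, dataset_name="VeRi")  # works
# build_reid_test_loader(cfg, "VeRi")  # would raise TypeError: dataset_name is keyword-only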
1 change: 1 addition & 0 deletions fastreid/engine/hooks.py
@@ -360,6 +360,7 @@ def _do_eval(self):
                     )
             self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
 
+        torch.cuda.empty_cache()
         # Evaluation may take different time among workers.
         # A barrier make them start the next iteration together.
         comm.synchronize()
113 changes: 40 additions & 73 deletions fastreid/layers/any_softmax.py
@@ -4,110 +4,77 @@
 @contact: sherlockliao01@gmail.com
 """
 
-import math
-
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 __all__ = [
-    'Linear',
-    'ArcSoftmax',
-    'CosSoftmax',
-    'CircleSoftmax'
+    "Linear",
+    "ArcSoftmax",
+    "CosSoftmax",
+    "CircleSoftmax"
 ]
 
 
 class Linear(nn.Module):
     def __init__(self, num_classes, scale, margin):
         super().__init__()
-        self._num_classes = num_classes
-        self.s = 1
-        self.m = 0
+        self.num_classes = num_classes
+        self.s = scale
+        self.m = margin
 
-    def forward(self, logits, *args):
+    def forward(self, logits, targets):
         return logits
 
     def extra_repr(self):
-        return 'num_classes={}, scale={}, margin={}'.format(self._num_classes, self.s, self.m)
+        return f"num_classes={self.num_classes}, scale={self.s}, margin={self.m}"
 
 
-class ArcSoftmax(nn.Module):
-    def __init__(self, num_classes, scale, margin):
-        super().__init__()
-        self._num_classes = num_classes
-        self.s = scale
-        self.m = margin
+class CosSoftmax(Linear):
+    r"""Implement of large margin cosine distance:
+    """
 
-        self.easy_margin = False
+    def forward(self, logits, targets):
+        index = torch.where(targets != -1)[0]
+        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
+        m_hot.scatter_(1, targets[index, None], self.m)
+        logits[index] -= m_hot
+        logits.mul_(self.s)
+        return logits
 
-        self.cos_m = math.cos(self.m)
-        self.sin_m = math.sin(self.m)
-        self.threshold = math.cos(math.pi - self.m)
-        self.mm = math.sin(math.pi - self.m) * self.m
 
+class ArcSoftmax(Linear):
+
     def forward(self, logits, targets):
-        sine = torch.sqrt(1.0 - torch.pow(logits, 2))
-        phi = logits * self.cos_m - sine * self.sin_m  # cos(theta + m)
-        if self.easy_margin:
-            phi = torch.where(logits > 0, phi, logits)
-        else:
-            phi = torch.where(logits > self.threshold, phi, logits - self.mm)
-        one_hot = torch.zeros(logits.size(), device=logits.device)
-        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
-        output = (one_hot * phi) + ((1.0 - one_hot) * logits)
-        output *= self.s
-        return output
+        index = torch.where(targets != -1)[0]
+        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
+        m_hot.scatter_(1, targets[index, None], self.m)
+        logits.acos_()
+        logits[index] += m_hot
+        logits.cos_().mul_(self.s)
+        return logits
 
-    def extra_repr(self):
-        return 'num_classes={}, scale={}, margin={}'.format(self._num_classes, self.s, self.m)
 
-
-class CircleSoftmax(nn.Module):
-    def __init__(self, num_classes, scale, margin):
-        super().__init__()
-        self._num_classes = num_classes
-        self.s = scale
-        self.m = margin
+class CircleSoftmax(Linear):
 
     def forward(self, logits, targets):
         alpha_p = torch.clamp_min(-logits.detach() + 1 + self.m, min=0.)
         alpha_n = torch.clamp_min(logits.detach() + self.m, min=0.)
         delta_p = 1 - self.m
         delta_n = self.m
 
-        s_p = self.s * alpha_p * (logits - delta_p)
-        s_n = self.s * alpha_n * (logits - delta_n)
+        # When use model parallel, there are some targets not in class centers of local rank
+        index = torch.where(targets != -1)[0]
+        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
+        m_hot.scatter_(1, targets[index, None], 1)
 
-        targets = F.one_hot(targets, num_classes=self._num_classes)
+        logits_p = alpha_p * (logits - delta_p)
+        logits_n = alpha_n * (logits - delta_n)
 
-        pred_class_logits = targets * s_p + (1.0 - targets) * s_n
+        logits[index] = logits_p[index] * m_hot + logits_n[index] * (1 - m_hot)
 
-        return pred_class_logits
+        neg_index = torch.where(targets == -1)[0]
+        logits[neg_index] = logits_n[neg_index]
 
-    def extra_repr(self):
-        return "num_classes={}, scale={}, margin={}".format(self._num_classes, self.s, self.m)
+        logits.mul_(self.s)
 
-
-class CosSoftmax(nn.Module):
-    r"""Implement of large margin cosine distance:
-    Args:
-        num_classes: size of each output sample
-    """
-
-    def __init__(self, num_classes, scale, margin):
-        super().__init__()
-        self._num_classes = num_classes
-        self.s = scale
-        self.m = margin
-
-    def forward(self, logits, targets):
-        phi = logits - self.m
-        targets = F.one_hot(targets, num_classes=self._num_classes)
-        output = (targets * phi) + ((1.0 - targets) * logits)
-        output *= self.s
-
-        return output
-
-    def extra_repr(self):
-        return "num_classes={}, scale={}, margin={}".format(self._num_classes, self.s, self.m)
+        return logits
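A minimal usage sketch of the refactored classes (sizes and hyper-parameters below are illustrative, not from this commit): the margin variants now subclass Linear, apply their margin in place to the cosine logits, and skip samples whose target is -1, i.e. samples whose class center lives on another rank under model parallelism (the partial-fc case).

import torch
from fastreid.layers.any_softmax import CircleSoftmax

num_classes, batch_size = 8, 4                                # illustrative sizes
cls_layer = CircleSoftmax(num_classes, scale=64, margin=0.35)

cosine = torch.randn(batch_size, num_classes).clamp(-1, 1)    # stand-in for cosine similarities
targets = torch.tensor([2, 5, -1, 0])                         # -1: class center not on this rank
outputs = cls_layer(cosine, targets)                          # margin and scale applied in place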
10 changes: 1 addition & 9 deletions fastreid/modeling/heads/embedding_head.py
@@ -83,15 +83,7 @@ def __init__(
         # Linear layer
         assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
                                                "but got {}".format(any_softmax.__all__, cls_type)
-        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
-        # Initialize weight parameters
-        if cls_type == "Linear":
-            nn.init.normal_(self.weight, std=0.001)
-        elif cls_type == "CircleSoftmax":
-            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-        elif cls_type == "ArcSoftmax" or cls_type == "CosSoftmax":
-            nn.init.xavier_uniform_(self.weigth)
-
+        self.weight = nn.Parameter(torch.normal(0, 0.01, (num_classes, feat_dim)))
         self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
 
     @classmethod
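The deleted branch for ArcSoftmax/CosSoftmax referenced self.weigth (a typo) and would have raised an AttributeError, so unifying on a single initializer is also a small bugfix. For context, a hedged sketch of how this parameter is consumed at training time, assuming the usual normalize-then-linear flow in EmbeddingHead.forward (not part of this hunk; sizes are illustrative):

import torch
import torch.nn.functional as F

num_classes, feat_dim, batch_size = 10, 512, 4
weight = torch.normal(0, 0.01, (num_classes, feat_dim))    # same init as above
feat = torch.randn(batch_size, feat_dim)                   # stand-in for the pooled/necked feature

logits = F.linear(F.normalize(feat), F.normalize(weight))  # cosine class scores
# logits then go through cls_layer (Linear / ArcSoftmax / CosSoftmax / CircleSoftmax) for margin and scale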
4 changes: 2 additions & 2 deletions projects/FastAttr/fastattr/__init__.py
@@ -4,9 +4,9 @@
@contact: sherlockliao01@gmail.com
"""

from .attr_baseline import AttrBaseline
from .attr_evaluation import AttrEvaluator
from .attr_head import AttrHead
from .config import add_attr_config
from .data_build import build_attr_train_loader, build_attr_test_loader
from .datasets import *
from .modeling import *
from .attr_dataset import AttrDataset
3 changes: 1 addition & 2 deletions projects/FastAttr/fastattr/config.py
@@ -10,8 +10,7 @@
 def add_attr_config(cfg):
     _C = cfg
 
-    _C.MODEL.LOSSES.BCE = CN()
-    _C.MODEL.LOSSES.BCE.WEIGHT_ENABLED = True
+    _C.MODEL.LOSSES.BCE = CN({"WEIGHT_ENABLED": True})
     _C.MODEL.LOSSES.BCE.SCALE = 1.
 
     _C.TEST.THRES = 0.5
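A quick note on the yacs idiom used here (a sketch, not from this commit): seeding a CfgNode with a dict is equivalent to creating the node and assigning the key afterwards, which is why the two old lines collapse into one.

from yacs.config import CfgNode as CN

a = CN()
a.BCE = CN()
a.BCE.WEIGHT_ENABLED = True

b = CN()
b.BCE = CN({"WEIGHT_ENABLED": True})

assert a == b  # CfgNode subclasses dict, so equality compares contents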
74 changes: 0 additions & 74 deletions projects/FastAttr/fastattr/data_build.py

This file was deleted.

9 changes: 9 additions & 0 deletions projects/FastAttr/fastattr/modeling/__init__.py
@@ -0,0 +1,9 @@
+# encoding: utf-8
+"""
+@author: xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+from .attr_baseline import AttrBaseline
+from .attr_head import AttrHead
+from .bce_loss import cross_entropy_sigmoid_loss
File renamed without changes.
File renamed without changes.
File renamed without changes.
