diff --git a/configs/Market1501/bagtricks_vit.yml b/configs/Market1501/bagtricks_vit.yml new file mode 100644 index 00000000..e4945b7c --- /dev/null +++ b/configs/Market1501/bagtricks_vit.yml @@ -0,0 +1,88 @@ + +MODEL: + META_ARCHITECTURE: Baseline + PIXEL_MEAN: [127.5, 127.5, 127.5] + PIXEL_STD: [127.5, 127.5, 127.5] + + BACKBONE: + NAME: build_vit_backbone + DEPTH: base + FEAT_DIM: 768 + PRETRAIN: True + PRETRAIN_PATH: /export/home/lxy/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth + STRIDE_SIZE: (16, 16) + DROP_PATH_RATIO: 0.1 + DROP_RATIO: 0.0 + ATT_DROP_RATE: 0.0 + + HEADS: + NAME: EmbeddingHead + NORM: BN + WITH_BNNECK: True + POOL_LAYER: Identity + NECK_FEAT: before + CLS_LAYER: Linear + + LOSSES: + NAME: ("CrossEntropyLoss", "TripletLoss",) + + CE: + EPSILON: 0. # no smooth + SCALE: 1. + + TRI: + MARGIN: 0.0 + HARD_MINING: True + NORM_FEAT: False + SCALE: 1. + +INPUT: + SIZE_TRAIN: [ 256, 128 ] + SIZE_TEST: [ 256, 128 ] + + REA: + ENABLED: True + PROB: 0.5 + + FLIP: + ENABLED: True + + PADDING: + ENABLED: True + +DATALOADER: + SAMPLER_TRAIN: NaiveIdentitySampler + NUM_INSTANCE: 4 + NUM_WORKERS: 8 + +SOLVER: + AMP: + ENABLED: False + OPT: SGD + MAX_EPOCH: 120 + BASE_LR: 0.008 + WEIGHT_DECAY: 0.0001 + IMS_PER_BATCH: 64 + + SCHED: CosineAnnealingLR + ETA_MIN_LR: 0.000016 + + WARMUP_FACTOR: 0.01 + WARMUP_ITERS: 1000 + + CLIP_GRADIENTS: + ENABLED: True + + CHECKPOINT_PERIOD: 30 + +TEST: + EVAL_PERIOD: 5 + IMS_PER_BATCH: 128 + +CUDNN_BENCHMARK: True + +DATASETS: + NAMES: ("Market1501",) + TESTS: ("Market1501",) + +OUTPUT_DIR: logs/market1501/sbs_vit_base diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py index d5c119f4..e6bc403d 100644 --- a/fastreid/config/defaults.py +++ b/fastreid/config/defaults.py @@ -23,7 +23,7 @@ _C.MODEL.DEVICE = "cuda" _C.MODEL.META_ARCHITECTURE = "Baseline" -_C.MODEL.FREEZE_LAYERS = [''] +_C.MODEL.FREEZE_LAYERS = [] # MoCo memory size _C.MODEL.QUEUE_SIZE = 8192 @@ -46,6 +46,12 @@ _C.MODEL.BACKBONE.WITH_SE = False # If use Non-local block in backbone _C.MODEL.BACKBONE.WITH_NL = False +# Vision Transformer options +_C.MODEL.BACKBONE.SIE_COE = 3.0 +_C.MODEL.BACKBONE.STRIDE_SIZE = (16, 16) +_C.MODEL.BACKBONE.DROP_PATH_RATIO = 0.1 +_C.MODEL.BACKBONE.DROP_RATIO = 0.0 +_C.MODEL.BACKBONE.ATT_DROP_RATE = 0.0 # If use ImageNet pretrain model _C.MODEL.BACKBONE.PRETRAIN = False # Pretrain model path @@ -128,8 +134,10 @@ # ----------------------------------------------------------------------------- _C.KD = CN() -_C.KD.MODEL_CONFIG = ['',] -_C.KD.MODEL_WEIGHTS = ['',] +_C.KD.MODEL_CONFIG = [] +_C.KD.MODEL_WEIGHTS = [] +_C.KD.EMA = CN({"ENABLED": False}) +_C.KD.EMA.MOMENTUM = 0.999 # ----------------------------------------------------------------------------- # INPUT @@ -223,14 +231,25 @@ _C.SOLVER.MAX_EPOCH = 120 _C.SOLVER.BASE_LR = 3e-4 -_C.SOLVER.BIAS_LR_FACTOR = 1. + +# This LR is applied to the last classification layer if +# you want to 10x higher than BASE_LR. _C.SOLVER.HEADS_LR_FACTOR = 1. _C.SOLVER.MOMENTUM = 0.9 _C.SOLVER.NESTEROV = False _C.SOLVER.WEIGHT_DECAY = 0.0005 -_C.SOLVER.WEIGHT_DECAY_BIAS = 0. +# The weight decay that's applied to parameters of normalization layers +# (typically the affine transformation) +_C.SOLVER.WEIGHT_DECAY_NORM = 0.0 + +# The previous detection code used a 2x higher LR and 0 WD for bias. +# This is not useful (at least for recent models). You should avoid +# changing these and they exists only to reproduce previous model +# training if desired. 
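# With the defaults below the bias LR factor is 1x and the bias weight decay equals
# WEIGHT_DECAY, so biases end up treated the same as every other parameter unless
# a config overrides these values.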
+_C.SOLVER.BIAS_LR_FACTOR = 1.0 +_C.SOLVER.WEIGHT_DECAY_BIAS = _C.SOLVER.WEIGHT_DECAY # Multi-step learning rate options _C.SOLVER.SCHED = "MultiStepLR" @@ -251,33 +270,31 @@ # Backbone freeze iters _C.SOLVER.FREEZE_ITERS = 0 -# FC freeze iters -_C.SOLVER.FREEZE_FC_ITERS = 0 - - -# SWA options -# _C.SOLVER.SWA = CN() -# _C.SOLVER.SWA.ENABLED = False -# _C.SOLVER.SWA.ITER = 10 -# _C.SOLVER.SWA.PERIOD = 2 -# _C.SOLVER.SWA.LR_FACTOR = 10. -# _C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6 -# _C.SOLVER.SWA.LR_SCHED = False - _C.SOLVER.CHECKPOINT_PERIOD = 20 # Number of images per batch across all machines. -# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will -# see 2 images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 256, each GPU will +# see 32 images per batch _C.SOLVER.IMS_PER_BATCH = 64 -# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will -# see 2 images per batch +# Gradient clipping +_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False}) +# Type of gradient clipping, currently 2 values are supported: +# - "value": the absolute values of elements of each gradients are clipped +# - "norm": the norm of the gradient for each parameter is clipped thus +# affecting all elements in the parameter +_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm" +# Maximum absolute value used for clipping gradients +_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 5.0 +# Floating point number p for L-p norm to be used with the "norm" +# gradient clipping type; for L-inf, please specify .inf +_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0 + _C.TEST = CN() _C.TEST.EVAL_PERIOD = 20 -# Number of images per batch in one process. +# Number of images per batch across all machines. _C.TEST.IMS_PER_BATCH = 64 _C.TEST.METRIC = "cosine" _C.TEST.ROC = CN({"ENABLED": False}) diff --git a/fastreid/layers/drop.py b/fastreid/layers/drop.py new file mode 100644 index 00000000..5a6750f3 --- /dev/null +++ b/fastreid/layers/drop.py @@ -0,0 +1,161 @@ +""" DropBlock, DropPath +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. 
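    # For intuition (illustrative numbers, not from the patch): with drop_prob=0.1,
    # block_size=7 and a 16x16 feature map, gamma = 0.1 * 256 / 49 / 100 ~= 0.0052,
    # i.e. about 0.52 seed points per channel; each seed zeroes a 7x7 = 49-position
    # block, so the expected masked area is ~25.6 positions, roughly 10% of the map,
    # recovering drop_prob.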
+ w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + if batchwise: + # one mask for whole batch, quite a bit faster + block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma + else: + # mask per batch element + block_mask = torch.rand_like(x) < gamma + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf + """ + + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False, + inplace=False, + batchwise=False, + fast=True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/fastreid/layers/helpers.py b/fastreid/layers/helpers.py new file mode 100644 index 00000000..af54b408 --- /dev/null +++ b/fastreid/layers/helpers.py @@ -0,0 +1,31 @@ +""" Layer/Module Helpers +Hacked together by / Copyright 2020 Ross Wightman +""" +import collections.abc +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
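    # e.g. make_divisible(40, divisor=32) first rounds to 32 (illustrative values),
    # a 20% drop from 40, so the guard below bumps the result up to 64.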
+ if new_v < 0.9 * v: + new_v += divisor + return new_v diff --git a/fastreid/layers/weight_init.py b/fastreid/layers/weight_init.py new file mode 100644 index 00000000..390039f1 --- /dev/null +++ b/fastreid/layers/weight_init.py @@ -0,0 +1,122 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +import math +import warnings + +import torch +from torch import nn + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif classname.find('BatchNorm') != -1: + if m.affine: + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0.0) + + +def weights_init_classifier(m): + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + nn.init.normal_(m.weight, std=0.001) + if m.bias is not None: + nn.init.constant_(m.bias, 0.0) + + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') diff --git a/fastreid/modeling/backbones/__init__.py b/fastreid/modeling/backbones/__init__.py index 5387ef10..0b24e913 100644 --- a/fastreid/modeling/backbones/__init__.py +++ b/fastreid/modeling/backbones/__init__.py @@ -14,3 +14,4 @@ from .shufflenet import build_shufflenetv2_backbone from .mobilenet import build_mobilenetv2_backbone from .repvgg import build_repvgg_backbone +from .vision_transformer import build_vit_backbone diff --git a/fastreid/modeling/backbones/vision_transformer.py b/fastreid/modeling/backbones/vision_transformer.py new file mode 100644 index 00000000..821483de --- /dev/null +++ b/fastreid/modeling/backbones/vision_transformer.py @@ -0,0 +1,399 @@ +""" Vision Transformer (ViT) in PyTorch +A PyTorch implement of Vision Transformers as described in +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 +The official jax code is released and available at https://github.com/google-research/vision_transformer +Status/TODO: +* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. +* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. +* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. +* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fastreid.layers import DropPath, trunc_normal_, to_2tuple +from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message +from .build import BACKBONE_REGISTRY + +logger = logging.getLogger(__name__) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. 
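                # The dry run below switches the backbone to eval mode (so BatchNorm
                # running stats and dropout are untouched), pushes a zero tensor through
                # to read off the output spatial size and channel count, then restores
                # the original training/eval mode.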
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, 1) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed_overlap(nn.Module): + """ Image to Patch Embedding with overlapping patches + """ + + def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride_size_tuple = to_2tuple(stride_size) + self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 + self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 + num_patches = self.num_x * self.num_y + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.InstanceNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + B, C, H, W = x.shape + + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
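        # With the Market1501 config above (256x128 input, 16x16 patches, stride 16)
        # this yields num_y = (256 - 16) // 16 + 1 = 16 and num_x = (128 - 16) // 16 + 1 = 8,
        # i.e. 128 patch tokens per image (plus the class token added later).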
+ x = self.proj(x) + + x = x.flatten(2).transpose(1, 2) # [64, 8, 768] + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, embed_dim=768, + depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., camera=0, drop_path_rate=0., hybrid_backbone=None, + norm_layer=partial(nn.LayerNorm, eps=1e-6), sie_xishu=1.0): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed_overlap( + img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans, + embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.cam_num = camera + self.sie_xishu = sie_xishu + # Initialize SIE Embedding + if camera > 1: + self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim)) + trunc_normal_(self.sie_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, camera_id=None): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if self.cam_num > 0: + x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id] + else: + x = x + self.pos_embed + + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x[:, 0].reshape(x.shape[0], -1, 1, 1) + + +def resize_pos_embed(posemb, posemb_new, hight, width): + # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = posemb_new.shape[1] + + posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + + gs_old = int(math.sqrt(len(posemb_grid))) + logger.info('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, + posemb_new.shape, + hight, + width)) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) + posemb = torch.cat([posemb_token, posemb_grid], dim=1) + return posemb + + +@BACKBONE_REGISTRY.register() +def build_vit_backbone(cfg): + """ + Create a Vision Transformer instance from config. + Returns: + SwinTransformer: a :class:`SwinTransformer` instance. + """ + # fmt: off + input_size = cfg.INPUT.SIZE_TRAIN + pretrain = cfg.MODEL.BACKBONE.PRETRAIN + pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH + depth = cfg.MODEL.BACKBONE.DEPTH + sie_xishu = cfg.MODEL.BACKBONE.SIE_COE + stride_size = cfg.MODEL.BACKBONE.STRIDE_SIZE + drop_ratio = cfg.MODEL.BACKBONE.DROP_RATIO + drop_path_ratio = cfg.MODEL.BACKBONE.DROP_PATH_RATIO + attn_drop_rate = cfg.MODEL.BACKBONE.ATT_DROP_RATE + # fmt: on + + num_depth = { + 'small': 8, + 'base': 12, + }[depth] + + num_heads = { + 'small': 8, + 'base': 12, + }[depth] + + mlp_ratio = { + 'small': 3., + 'base': 4. + }[depth] + + qkv_bias = { + 'small': False, + 'base': True + }[depth] + + qk_scale = { + 'small': 768 ** -0.5, + 'base': None, + }[depth] + + model = VisionTransformer(img_size=input_size, sie_xishu=sie_xishu, stride_size=stride_size, depth=num_depth, + num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop_path_rate=drop_path_ratio, drop_rate=drop_ratio, attn_drop_rate=attn_drop_rate) + + if pretrain: + try: + state_dict = torch.load(pretrain_path, map_location=torch.device('cpu')) + logger.info(f"Loading pretrained model from {pretrain_path}") + + if 'model' in state_dict: + state_dict = state_dict.pop('model') + if 'state_dict' in state_dict: + state_dict = state_dict.pop('state_dict') + for k, v in state_dict.items(): + if 'head' in k or 'dist' in k: + continue + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + if 'distilled' in pretrain_path: + logger.info("distill need to choose right cls token in the pth.") + v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1) + v = resize_pos_embed(v, model.pos_embed.data, model.patch_embed.num_y, model.patch_embed.num_x) + state_dict[k] = v + except FileNotFoundError as e: + logger.info(f'{pretrain_path} is not found! Please check this path.') + raise e + except KeyError as e: + logger.info("State dict keys error! 
Please check the state dict.") + raise e + + incompatible = model.load_state_dict(state_dict, strict=False) + if incompatible.missing_keys: + logger.info( + get_missing_parameters_message(incompatible.missing_keys) + ) + if incompatible.unexpected_keys: + logger.info( + get_unexpected_parameters_message(incompatible.unexpected_keys) + ) + + return model diff --git a/fastreid/utils/weight_init.py b/fastreid/utils/weight_init.py deleted file mode 100644 index 34871921..00000000 --- a/fastreid/utils/weight_init.py +++ /dev/null @@ -1,36 +0,0 @@ -# encoding: utf-8 -""" -@author: xingyu liao -@contact: sherlockliao01@gmail.com -""" - -from torch import nn - -__all__ = [ - 'weights_init_classifier', - 'weights_init_kaiming', -] - - -def weights_init_kaiming(m): - classname = m.__class__.__name__ - if classname.find('Linear') != -1: - nn.init.normal_(m.weight, 0, 0.01) - if m.bias is not None: - nn.init.constant_(m.bias, 0.0) - elif classname.find('Conv') != -1: - nn.init.kaiming_normal_(m.weight, mode='fan_out') - if m.bias is not None: - nn.init.constant_(m.bias, 0.0) - elif classname.find('BatchNorm') != -1: - if m.affine: - nn.init.constant_(m.weight, 1.0) - nn.init.constant_(m.bias, 0.0) - - -def weights_init_classifier(m): - classname = m.__class__.__name__ - if classname.find('Linear') != -1: - nn.init.normal_(m.weight, std=0.001) - if m.bias is not None: - nn.init.constant_(m.bias, 0.0)
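For reference, a minimal sketch (not part of the patch) of how the new backbone could be exercised end to end, assuming fastreid's get_cfg() entry point and that fastreid.layers re-exports the new DropPath / trunc_normal_ / to_2tuple helpers; pretraining is disabled so no checkpoint path is needed:

import torch

from fastreid.config import get_cfg
from fastreid.modeling.backbones import build_vit_backbone

# Build a config with the new ViT keys plus the Market1501 YAML from this patch.
cfg = get_cfg()
cfg.merge_from_file("configs/Market1501/bagtricks_vit.yml")
cfg.MODEL.BACKBONE.PRETRAIN = False  # skip loading the ImageNet checkpoint for a quick smoke test

# 256x128 input -> 16x8 patch grid -> 129 tokens; the backbone returns the class
# token reshaped to NCHW so it can feed the usual EmbeddingHead.
backbone = build_vit_backbone(cfg)
backbone.eval()
with torch.no_grad():
    feat = backbone(torch.randn(2, 3, 256, 128))
print(feat.shape)  # expected: torch.Size([2, 768, 1, 1])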