In [1]:
import argparse
from fvcore.nn import FlopCountAnalysis, parameter_count_table

from conf_finetune import _C as cfg
from main_finetune import initialize_data_loader, initialize_model

In [2]:
import torch
import torch.nn.functional as F
import torch.optim
import torch.utils.data

import datasets
import models
import utils.transforms
from utils.lr_decay import param_groups_lrd
from utils.sampler import RandomIdentitySampler
from utils.scaler import NativeScalerWithGradNormCount
from utils.triplet_loss import TripletLoss

def initialize_model(cfg, num_classes, device_id):
    """Create the model together with its loss function, optimizer and scaler.

    Args:
        cfg: Config object. Reads ``MODEL.NAME``, ``MODEL.ID_LOSS_WEIGHT``,
            ``MODEL.TRI_LOSS_WEIGHT`` and ``OPTIMIZER.WEIGHT_DECAY``,
            ``OPTIMIZER.LAYER_DECAY``, ``OPTIMIZER.LR``, ``OPTIMIZER.BETAS``.
        num_classes: Forwarded to the model constructor.
        device_id: CUDA device handed to ``model.cuda``.

    Returns:
        ``(model, loss_func, optimizer, scaler)``.
        ``loss_func(feats, logits, target)`` returns
        ``(total, id_loss, tri_loss)`` where ``total`` is
        ``ID_LOSS_WEIGHT * id_loss + TRI_LOSS_WEIGHT * tri_loss``.
        ``feats`` and ``logits`` may each be a single output or a tuple/list
        of per-branch outputs; per-branch losses are averaged.
    """
    # The model constructor is looked up by name in the ``models`` package namespace.
    model = models.__dict__[cfg.MODEL.NAME](cfg, num_classes)
    model.cuda(device_id)

    triplet = TripletLoss()

    def _as_branches(outputs):
        # Normalise "one output" and "several branch outputs" to a tuple so both
        # arguments are handled independently (tuple feats with a single logits
        # tensor, or list outputs, no longer take the wrong code path).
        if isinstance(outputs, (tuple, list)):
            return tuple(outputs)
        return (outputs,)

    def _mean(losses):
        # A single loss is returned untouched; several are averaged.
        # An empty sequence still raises ZeroDivisionError, as before.
        if len(losses) == 1:
            return losses[0]
        return sum(losses) / len(losses)

    def loss_func(feats, logits, target):
        """Return ``(weighted total, id_loss, tri_loss)`` for one batch."""
        id_loss = _mean([F.cross_entropy(logit, target) for logit in _as_branches(logits)])
        # Only the first element of the triplet criterion's return value is the loss term used here.
        tri_loss = _mean([triplet(feat, target)[0] for feat in _as_branches(feats)])
        total = cfg.MODEL.ID_LOSS_WEIGHT * id_loss + cfg.MODEL.TRI_LOSS_WEIGHT * tri_loss
        return total, id_loss, tri_loss

    # Parameter groups come from param_groups_lrd, which receives the weight decay,
    # the model's no-weight-decay names and the layer-decay factor.
    param_groups = param_groups_lrd(
        model,
        cfg.OPTIMIZER.WEIGHT_DECAY,
        model.no_weight_decay(),
        cfg.OPTIMIZER.LAYER_DECAY,
    )
    optimizer = torch.optim.AdamW(param_groups, lr=cfg.OPTIMIZER.LR, betas=cfg.OPTIMIZER.BETAS)
    scaler = NativeScalerWithGradNormCount()
    return model, loss_func, optimizer, scaler


def initialize_data_loader(cfg):
    """Construct the training and validation ``DataLoader`` objects.

    Transforms and the dataset class are looked up by name (``cfg.INPUT.TRANSFORM``,
    ``cfg.VALIDATE.TRANSFORM``, ``cfg.DATASET.NAME``) in the ``utils.transforms``
    and ``datasets`` namespaces.

    Returns:
        ``(train_loader, val_loader, num_classes, num_queries)`` where
        ``num_classes`` is read from the training dataset and ``num_queries``
        from the validation dataset.
    """
    # Options shared by both loaders.
    loader_common = {
        'num_workers': cfg.DATALOADER.NUM_WORKERS,
        'pin_memory': True,
    }

    # Training split: batches are drawn through RandomIdentitySampler.
    train_tf = utils.transforms.__dict__[cfg.INPUT.TRANSFORM](cfg)
    dataset_factory = datasets.__dict__[cfg.DATASET.NAME]
    train_set = dataset_factory(cfg, train_tf, is_train=True)
    num_classes = train_set.num_classes
    identity_sampler = RandomIdentitySampler(
        train_set, cfg.ENGINE.BATCH_SIZE, cfg.DATALOADER.NUM_INSTANCES
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=cfg.ENGINE.BATCH_SIZE,
        sampler=identity_sampler,
        **loader_common,
    )

    # Validation split: same dataset class, validation transform, fixed order.
    val_tf = utils.transforms.__dict__[cfg.VALIDATE.TRANSFORM](cfg)
    val_set = dataset_factory(cfg, val_tf, is_train=False)
    num_queries = val_set.num_queries
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=cfg.VALIDATE.BATCH_SIZE,
        shuffle=False,
        **loader_common,
    )

    return train_loader, val_loader, num_classes, num_queries

In [7]:
# Parse a fixed argument list (rather than the process command line) to pick the
# config file, merge it and any extra options into cfg, then freeze cfg.
parser = argparse.ArgumentParser(description='Antelope fine-tuning')
parser.add_argument(
    '--config_file',
    type=str,
    default='',
    help='path to config file',
)
parser.add_argument(
    'opts',
    nargs=argparse.REMAINDER,
    default=None,
    help='modify config options using the command-line',
)
args = parser.parse_args(
    args=[
        '--config_file',
        'configs/finetune/MSMT17/mae_inet_lup_vitb_ep800_ratio_optimized/baseline.yaml',
    ]
)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

In [8]:
# Build the loaders first: num_classes returned here is passed on to the model setup.
train_loader, val_loader, num_classes, num_queries = initialize_data_loader(cfg)
# criterion is the loss callable returned alongside the model; CUDA device 0 is used.
model, criterion, optimizer, scaler = initialize_model(cfg, num_classes, device_id=0)

In [16]:
# Take one batch from the training loader and run fvcore's FLOP counter on the
# model with the batch's first element (assumed to be the image tensor — TODO confirm).
# NOTE(review): the total covers the whole batch that is fed in, so it scales with
# the loader's batch size — confirm whether a per-image figure is intended.
# NOTE(review): the recorded warnings include aten::bernoulli_, which suggests
# stochastic layers were active while tracing (model not in eval mode) — confirm.
batch = next(iter(train_loader))
flops = FlopCountAnalysis(model, batch[0].cuda())
print(flops.total())

  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
Unsupported operator aten::add encountered 25 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::mul encountered 34 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
Unsupported operator aten::bernoulli_ encountered 22 time(s)
Unsupported operator aten::div_ encountered 22 time(s)


726526623744


In [17]:
# Print fvcore's per-module parameter-count table for the current model.
print(parameter_count_table(model))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 86.5M                |
|  cls_token                 |  (1, 1, 768)         |
|  pos_embed                 |  (1, 129, 768)       |
|  patch_embed               |  0.6M                |
|   patch_embed.proj         |   0.6M               |
|    patch_embed.proj.weight |    (768, 3, 16, 16)  |
|    patch_embed.proj.bias   |    (768,)            |
|  blocks                    |  85.1M               |
|   blocks.0                 |   7.1M               |
|    blocks.0.norm1          |    1.5K              |
|    blocks.0.attn           |    2.4M              |
|    blocks.0.norm2          |    1.5K              |
|    blocks.0.mlp            |    4.7M              |
|   blocks.1                 |   7.1M               |
|    blocks.1.norm1          |    1.5K              |
|    blocks.1.attn           |    2.4M              |
|    blocks.1.norm2         

In [3]:
# Parse a fixed argument list (rather than the process command line) to pick the
# config file, merge it and any extra options into cfg, then freeze cfg.
parser = argparse.ArgumentParser(description='Antelope fine-tuning')
parser.add_argument(
    '--config_file',
    type=str,
    default='',
    help='path to config file',
)
parser.add_argument(
    'opts',
    nargs=argparse.REMAINDER,
    default=None,
    help='modify config options using the command-line',
)
args = parser.parse_args(
    args=[
        '--config_file',
        'configs/finetune/MSMT17/mae_inet_lup_vitb_ep800_ratio_optimized/lem_pool.yaml',
    ]
)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

In [4]:
# Build the loaders first: num_classes returned here is passed on to the model setup.
train_loader, val_loader, num_classes, num_queries = initialize_data_loader(cfg)
# criterion is the loss callable returned alongside the model; CUDA device 0 is used.
model, criterion, optimizer, scaler = initialize_model(cfg, num_classes, device_id=0)

In [5]:
# Take one batch from the training loader and run fvcore's FLOP counter on the
# model with the batch's first element (assumed to be the image tensor — TODO confirm).
# NOTE(review): the total covers the whole batch that is fed in, so it scales with
# the loader's batch size — confirm whether a per-image figure is intended.
# NOTE(review): the recorded warnings include aten::bernoulli_, which suggests
# stochastic layers were active while tracing (model not in eval mode) — confirm.
batch = next(iter(train_loader))
flops = FlopCountAnalysis(model, batch[0].cuda())
print(flops.total())

  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
Unsupported operator aten::add encountered 25 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::mul encountered 34 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
Unsupported operator aten::bernoulli_ encountered 22 time(s)
Unsupported operator aten::div_ encountered 22 time(s)
Unsupported operator aten::mean encountered 1 time(s)
Unsupported operator aten::add_ encountered 2 time(s)


726578528256


In [6]:
# Print fvcore's per-module parameter-count table for the current model.
print(parameter_count_table(model))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 87.4M                |
|  cls_token                 |  (1, 1, 768)         |
|  pos_embed                 |  (1, 129, 768)       |
|  patch_embed               |  0.6M                |
|   patch_embed.proj         |   0.6M               |
|    patch_embed.proj.weight |    (768, 3, 16, 16)  |
|    patch_embed.proj.bias   |    (768,)            |
|  blocks                    |  85.1M               |
|   blocks.0                 |   7.1M               |
|    blocks.0.norm1          |    1.5K              |
|    blocks.0.attn           |    2.4M              |
|    blocks.0.norm2          |    1.5K              |
|    blocks.0.mlp            |    4.7M              |
|   blocks.1                 |   7.1M               |
|    blocks.1.norm1          |    1.5K              |
|    blocks.1.attn           |    2.4M              |
|    blocks.1.norm2         

In [3]:
# Parse a fixed argument list (rather than the process command line) to pick the
# config file, merge it and any extra options into cfg, then freeze cfg.
parser = argparse.ArgumentParser(description='Antelope fine-tuning')
parser.add_argument(
    '--config_file',
    type=str,
    default='',
    help='path to config file',
)
parser.add_argument(
    'opts',
    nargs=argparse.REMAINDER,
    default=None,
    help='modify config options using the command-line',
)
args = parser.parse_args(
    args=[
        '--config_file',
        'configs/finetune/MSMT17/mae_inet_lup_vitb_ep800_ratio_optimized/lem_tran.yaml',
    ]
)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

In [4]:
# Build the loaders first: num_classes returned here is passed on to the model setup.
train_loader, val_loader, num_classes, num_queries = initialize_data_loader(cfg)
# criterion is the loss callable returned alongside the model; CUDA device 0 is used.
model, criterion, optimizer, scaler = initialize_model(cfg, num_classes, device_id=0)

In [5]:
# Take one batch from the training loader and run fvcore's FLOP counter on the
# model with the batch's first element (assumed to be the image tensor — TODO confirm).
# NOTE(review): the total covers the whole batch that is fed in, so it scales with
# the loader's batch size — confirm whether a per-image figure is intended.
# NOTE(review): the recorded warnings include aten::bernoulli_, which suggests
# stochastic layers were active while tracing (model not in eval mode) — confirm.
batch = next(iter(train_loader))
flops = FlopCountAnalysis(model, batch[0].cuda())
print(flops.total())

  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
Unsupported operator aten::add encountered 27 time(s)
Unsupported operator aten::div encountered 13 time(s)
Unsupported operator aten::mul encountered 37 time(s)
Unsupported operator aten::softmax encountered 13 time(s)
Unsupported operator aten::gelu encountered 13 time(s)
Unsupported operator aten::bernoulli_ encountered 24 time(s)
Unsupported operator aten::div_ encountered 24 time(s)
Unsupported operator aten::mean encountered 1 time(s)
Unsupported operator aten::add_ encountered 2 time(s)


786234114048


In [6]:
# Print fvcore's per-module parameter-count table for the current model.
print(parameter_count_table(model))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 94.4M                |
|  cls_token                 |  (1, 1, 768)         |
|  pos_embed                 |  (1, 129, 768)       |
|  patch_embed               |  0.6M                |
|   patch_embed.proj         |   0.6M               |
|    patch_embed.proj.weight |    (768, 3, 16, 16)  |
|    patch_embed.proj.bias   |    (768,)            |
|  blocks                    |  85.1M               |
|   blocks.0                 |   7.1M               |
|    blocks.0.norm1          |    1.5K              |
|    blocks.0.attn           |    2.4M              |
|    blocks.0.norm2          |    1.5K              |
|    blocks.0.mlp            |    4.7M              |
|   blocks.1                 |   7.1M               |
|    blocks.1.norm1          |    1.5K              |
|    blocks.1.attn           |    2.4M              |
|    blocks.1.norm2         

In [3]:
# Parse a fixed argument list (rather than the process command line) to pick the
# config file, merge it and any extra options into cfg, then freeze cfg.
parser = argparse.ArgumentParser(description='Antelope fine-tuning')
parser.add_argument(
    '--config_file',
    type=str,
    default='',
    help='path to config file',
)
parser.add_argument(
    'opts',
    nargs=argparse.REMAINDER,
    default=None,
    help='modify config options using the command-line',
)
args = parser.parse_args(
    args=[
        '--config_file',
        'configs/finetune/MSMT17/mae_inet_lup_vitb_ep800_ratio_optimized/lem.yaml',
    ]
)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

In [4]:
# Build the loaders first: num_classes returned here is passed on to the model setup.
train_loader, val_loader, num_classes, num_queries = initialize_data_loader(cfg)
# criterion is the loss callable returned alongside the model; CUDA device 0 is used.
model, criterion, optimizer, scaler = initialize_model(cfg, num_classes, device_id=0)

In [5]:
# Take one batch from the training loader and run fvcore's FLOP counter on the
# model with the batch's first element (assumed to be the image tensor — TODO confirm).
# NOTE(review): the total covers the whole batch that is fed in, so it scales with
# the loader's batch size — confirm whether a per-image figure is intended.
# NOTE(review): the recorded warnings include aten::bernoulli_, which suggests
# stochastic layers were active while tracing (model not in eval mode) — confirm.
batch = next(iter(train_loader))
flops = FlopCountAnalysis(model, batch[0].cuda())
print(flops.total())

  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
Unsupported operator aten::add encountered 25 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::mul encountered 34 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
Unsupported operator aten::bernoulli_ encountered 22 time(s)
Unsupported operator aten::div_ encountered 22 time(s)
Unsupported operator aten::add_ encountered 15 time(s)


745756999680


In [6]:
# Print fvcore's per-module parameter-count table for the current model.
print(parameter_count_table(model))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 96.6M                |
|  cls_token                 |  (1, 1, 768)         |
|  pos_embed                 |  (1, 129, 768)       |
|  patch_embed               |  0.6M                |
|   patch_embed.proj         |   0.6M               |
|    patch_embed.proj.weight |    (768, 3, 16, 16)  |
|    patch_embed.proj.bias   |    (768,)            |
|  blocks                    |  85.1M               |
|   blocks.0                 |   7.1M               |
|    blocks.0.norm1          |    1.5K              |
|    blocks.0.attn           |    2.4M              |
|    blocks.0.norm2          |    1.5K              |
|    blocks.0.mlp            |    4.7M              |
|   blocks.1                 |   7.1M               |
|    blocks.1.norm1          |    1.5K              |
|    blocks.1.attn           |    2.4M              |
|    blocks.1.norm2         

In [3]:
# Parse a fixed argument list (rather than the process command line) to pick the
# config file, merge it and any extra options into cfg, then freeze cfg.
parser = argparse.ArgumentParser(description='Antelope fine-tuning')
parser.add_argument(
    '--config_file',
    type=str,
    default='',
    help='path to config file',
)
parser.add_argument(
    'opts',
    nargs=argparse.REMAINDER,
    default=None,
    help='modify config options using the command-line',
)
args = parser.parse_args(
    args=[
        '--config_file',
        'configs/finetune/MSMT17/mae_inet_lup_vitb_ep800_ratio_optimized/lem_plus.yaml',
    ]
)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

In [4]:
# Build the loaders first: num_classes returned here is passed on to the model setup.
train_loader, val_loader, num_classes, num_queries = initialize_data_loader(cfg)
# criterion is the loss callable returned alongside the model; CUDA device 0 is used.
model, criterion, optimizer, scaler = initialize_model(cfg, num_classes, device_id=0)

In [5]:
# Take one batch from the training loader and run fvcore's FLOP counter on the
# model with the batch's first element (assumed to be the image tensor — TODO confirm).
# NOTE(review): the total covers the whole batch that is fed in, so it scales with
# the loader's batch size — confirm whether a per-image figure is intended.
# NOTE(review): the recorded warnings include aten::bernoulli_, which suggests
# stochastic layers were active while tracing (model not in eval mode) — confirm.
batch = next(iter(train_loader))
flops = FlopCountAnalysis(model, batch[0].cuda())
print(flops.total())

  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
Unsupported operator aten::add encountered 25 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::mul encountered 34 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
Unsupported operator aten::bernoulli_ encountered 22 time(s)
Unsupported operator aten::div_ encountered 22 time(s)
Unsupported operator aten::add_ encountered 9 time(s)


881330798592


In [6]:
# Print fvcore's per-module parameter-count table for the current model.
print(parameter_count_table(model))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 0.2G                 |
|  cls_token                 |  (1, 1, 768)         |
|  pos_embed                 |  (1, 129, 768)       |
|  patch_embed               |  0.6M                |
|   patch_embed.proj         |   0.6M               |
|    patch_embed.proj.weight |    (768, 3, 16, 16)  |
|    patch_embed.proj.bias   |    (768,)            |
|  blocks                    |  85.1M               |
|   blocks.0                 |   7.1M               |
|    blocks.0.norm1          |    1.5K              |
|    blocks.0.attn           |    2.4M              |
|    blocks.0.norm2          |    1.5K              |
|    blocks.0.mlp            |    4.7M              |
|   blocks.1                 |   7.1M               |
|    blocks.1.norm1          |    1.5K              |
|    blocks.1.attn           |    2.4M              |
|    blocks.1.norm2         

In [7]:
# NOTE(review): this import sits outside the notebook's import cell; consider moving it there.
from fvcore.nn import parameter_count
# Bare last expression, so the notebook displays the returned mapping of
# parameter names/prefixes to raw element counts (the '' key holds the model total).
parameter_count(model)

defaultdict(int,
            {'': 163667490,
             'cls_token': 768,
             'pos_embed': 99072,
             'patch_embed': 590592,
             'patch_embed.proj': 590592,
             'patch_embed.proj.weight': 589824,
             'patch_embed.proj.bias': 768,
             'blocks': 85054464,
             'blocks.0': 7087872,
             'blocks.0.norm1': 1536,
             'blocks.0.norm1.weight': 768,
             'blocks.0.norm1.bias': 768,
             'blocks.0.attn': 2362368,
             'blocks.0.attn.qkv': 1771776,
             'blocks.0.attn.qkv.weight': 1769472,
             'blocks.0.attn.qkv.bias': 2304,
             'blocks.0.attn.proj': 590592,
             'blocks.0.attn.proj.weight': 589824,
             'blocks.0.attn.proj.bias': 768,
             'blocks.0.norm2': 1536,
             'blocks.0.norm2.weight': 768,
             'blocks.0.norm2.bias': 768,
             'blocks.0.mlp': 4722432,
             'blocks.0.mlp.fc1': 2362368,
             'block