distiller.py
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import logging

import torch
import torch.nn.functional as F

from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY, build_model, Baseline
from fastreid.utils.checkpoint import Checkpointer

logger = logging.getLogger(__name__)


@META_ARCH_REGISTRY.register()
class Distiller(Baseline):
    def __init__(self, cfg):
        super().__init__(cfg)

        # Get teacher model config
        model_ts = []
        for i in range(len(cfg.KD.MODEL_CONFIG)):
            cfg_t = get_cfg()
            cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
            cfg_t.defrost()
            cfg_t.MODEL.META_ARCHITECTURE = "Baseline"
            # Change syncBN to BN because the teacher model is not wrapped in DDP
            if cfg_t.MODEL.BACKBONE.NORM == "syncBN":
                cfg_t.MODEL.BACKBONE.NORM = "BN"
            if cfg_t.MODEL.HEADS.NORM == "syncBN":
                cfg_t.MODEL.HEADS.NORM = "BN"

            model_t = build_model(cfg_t)

            # No gradients for teacher model
            for param in model_t.parameters():
                param.requires_grad_(False)

            logger.info("Loading teacher model weights ...")
            Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])

            model_ts.append(model_t)

        self.ema_enabled = cfg.KD.EMA.ENABLED
        self.ema_momentum = cfg.KD.EMA.MOMENTUM
        if self.ema_enabled:
            cfg_self = cfg.clone()
            cfg_self.defrost()
            cfg_self.MODEL.META_ARCHITECTURE = "Baseline"
            if cfg_self.MODEL.BACKBONE.NORM == "syncBN":
                cfg_self.MODEL.BACKBONE.NORM = "BN"
            if cfg_self.MODEL.HEADS.NORM == "syncBN":
                cfg_self.MODEL.HEADS.NORM = "BN"
            model_self = build_model(cfg_self)

            # No gradients for the self-distillation model
            for param in model_self.parameters():
                param.requires_grad_(False)

            if cfg_self.MODEL.WEIGHTS != '':
                logger.info("Loading self distillation model weights ...")
                Checkpointer(model_self).load(cfg_self.MODEL.WEIGHTS)
            else:
                # Make sure the initial states are the same
                for param_q, param_k in zip(self.parameters(), model_self.parameters()):
                    param_k.data.copy_(param_q.data)

            model_ts.insert(0, model_self)

        # Do not register the teacher models as `nn.Module`, so that
        # their weights are not saved in checkpoints
        self.model_ts = model_ts
    @torch.no_grad()
    def _momentum_update_key_encoder(self, m=0.999):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.parameters(), self.model_ts[0].parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)
    def forward(self, batched_inputs):
        if self.training:
            images = self.preprocess_image(batched_inputs)

            # Student model forward
            s_feat = self.backbone(images)
            assert "targets" in batched_inputs, "Labels are missing in training!"
            targets = batched_inputs["targets"].to(self.device)

            if targets.sum() < 0: targets.zero_()

            s_outputs = self.heads(s_feat, targets)

            t_outputs = []
            # Teacher model forward
            with torch.no_grad():
                if self.ema_enabled:
                    self._momentum_update_key_encoder(self.ema_momentum)  # update self-distillation model
                for model_t in self.model_ts:
                    t_feat = model_t.backbone(images)
                    t_output = model_t.heads(t_feat, targets)
                    t_outputs.append(t_output)

            losses = self.losses(s_outputs, t_outputs, targets)
            return losses

        # Eval mode, just conventional reid feature extraction
        else:
            return super().forward(batched_inputs)
    def losses(self, s_outputs, t_outputs, gt_labels):
        """
        Compute losses from the model's outputs; the loss function's input
        arguments must match the outputs of the model forward pass.
        """
        loss_dict = super().losses(s_outputs, gt_labels)

        s_logits = s_outputs['pred_class_logits']
        loss_jsdiv = 0.
        for t_output in t_outputs:
            t_logits = t_output['pred_class_logits'].detach()
            loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)

        loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)
        return loss_dict
    @staticmethod
    def _kldiv(y_s, y_t, t):
        p_s = F.log_softmax(y_s / t, dim=1)
        p_t = F.softmax(y_t / t, dim=1)
        loss = F.kl_div(p_s, p_t, reduction="sum") * (t ** 2) / y_s.shape[0]
        return loss

    def jsdiv_loss(self, y_s, y_t, t=16):
        loss = (self._kldiv(y_s, y_t, t) + self._kldiv(y_t, y_s, t)) / 2
        return loss