# central_resnet50.py
# model settings
# Image classifier whose backbone is the project-specific 'Central_Model',
# configured with two ResNet-50 branches, one per entry in `task_names`.
# NOTE(review): the meaning of the Central_Model-specific keys (trans_type,
# layer2channel, layer2auxlayers, trans_layers, channels, return_tuple) is
# defined by that backbone's implementation, which is not visible here.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='Central_Model',
        backbone_name='resnet',
        task_names=('gv_patch', 'gv_global'),
        main_task_name='gv_global',
        trans_type='crossconvhrnetlayer',
        frozen_stages=4,
        # Per-task ResNet settings; both branches are identical ResNet-50s.
        # out_indices=(3, ) requests only the last stage, whose ResNet-50
        # width (2048) is what head.in_channels below is set to.
        task_name_to_backbone={
            'gv_global': dict(
                depth=50,
                frozen_stages=4,
                num_stages=4,
                out_indices=(3, ),
                style='pytorch',
            ),
            'gv_patch': dict(
                depth=50,
                frozen_stages=4,
                num_stages=4,
                out_indices=(3, ),
                style='pytorch',
            ),
        },
        # Presumably the output width of each ResNet stage; the values equal
        # the standard ResNet-50 widths for layer1-layer3.
        layer2channel={
            'layer1': 256,
            'layer2': 512,
            'layer3': 1024,
        },
        # For each transfer point, the list of earlier layers it draws on.
        layer2auxlayers={
            'layer1': ['layer1'],
            'layer2': ['layer1', 'layer2'],
            'layer3': ['layer1', 'layer2', 'layer3'],
        },
        trans_layers=['layer1', 'layer2', 'layer3'],
        channels=[64, 128, 192],
        return_tuple=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='checkpoints/up-g/r50-cls-bn.pth',
        ),
    ),
    # NOTE(review): no `neck` (e.g. GlobalAveragePooling) is configured, so
    # this assumes Central_Model already returns pooled (N, 2048) features
    # suitable for a linear head — confirm against the backbone.
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        init_cfg=dict(
            type='Kaiming',
            # Without `layer`, mmcv/mmengine initializers match no
            # sub-module and the init is silently skipped; target the head's
            # nn.Linear explicitly so this config actually takes effect.
            layer='Linear',
            # `a` only enters the gain for nonlinearity='leaky_relu'; with
            # 'relu' it is ignored. Kept as in the original config.
            a=2.23606,
            mode='fan_out',
            nonlinearity='relu',
            distribution='uniform',
        ),
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ),
)