###################################################################################################
# Copyright (C) 2021 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
###################################################################################################
"""
Command line tool to add a passthrough layer, implemented as an identity Conv2d kernel,
to a quantized model checkpoint.
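
Example invocation (illustrative only; the paths and layer names below are
hypothetical, not taken from a specific model):

    python add_fake_passthrough.py \
        --input-checkpoint-path trained/model_q8.pth.tar \
        --output-checkpoint-path trained/model_q8_fakept.pth.tar \
        --layer-name fake_pt --layer-depth 64 --layer-name-after-pt conv2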
"""
import argparse
import copy
from collections import OrderedDict
import torch
from torch import nn


def parse_arguments():
"""Parses command line arguments"""
parser = argparse.ArgumentParser(description="Fake Passthrough Layer Insertion")
    parser.add_argument('--input-checkpoint-path', metavar='S', required=True,
                        help='path to the input checkpoint file')
    parser.add_argument('--output-checkpoint-path', metavar='S', required=True,
                        help='path to save the modified checkpoint file')
    parser.add_argument('--layer-name', metavar='S', required=True,
                        help='name of the added passthrough layer')
    parser.add_argument('--layer-depth', type=int, required=True,
                        help='number of channels (depth) of the passthrough layer')
    parser.add_argument('--layer-name-after-pt', metavar='S', required=True,
                        help='name of the layer that follows the inserted passthrough layer')
    parser.add_argument('--low-memory-footprint', action='store_true', default=False,
                        help='use 2-bit instead of 8-bit weights for the passthrough kernel')
args = parser.parse_args()
    return args


def passthrough_faker(n_channels, low_memory_footprint=False):
    """Creates a 1x1 Conv2d layer whose kernel acts as a quantized identity"""
a = nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=1, bias=False)
a.weight.data = torch.zeros_like(a.weight.data)
    # Diagonal value representing unity gain after quantization. The scaling
    # rationale is an assumption about the target hardware: 64 is 0.5 in
    # 8-bit weight scaling (1 is 0.5 in 2-bit scaling), and the output_shift
    # of 1 written by main() doubles it back to 1.0.
    identity_value = 1 if low_memory_footprint else 64
    for i in range(a.weight.data.shape[0]):
        a.weight.data[i, i, :, :] = identity_value
return a
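
# Illustrative sanity check (not part of the original tool): with 2-bit
# weights the raw kernel is an exact identity, so the layer reproduces its
# input before any hardware scaling is applied.
#
#   >>> x = torch.randn(1, 8, 4, 4)
#   >>> pt = passthrough_faker(8, low_memory_footprint=True)
#   >>> bool(torch.allclose(pt(x), x))
#   True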
def main():
"""Main function to add passthrough layer"""
args = parse_arguments()
device = torch.device('cpu')
    checkpoint = torch.load(args.input_checkpoint_path, map_location=device)
passthrough_kernel = passthrough_faker(args.layer_depth, args.low_memory_footprint)
new_checkpoint = copy.deepcopy(checkpoint)
    # remove the `module.` prefix from state dictionary keys if the model was
    # trained on multiple GPUs with DataParallel (see:
    # https://discuss.pytorch.org/t/prefix-parameter-names-in-saved-model-if-trained-by-multi-gpu/494)
new_state_dict = OrderedDict()
for k, v in new_checkpoint['state_dict'].items():
name = k.replace("module.", '')
new_state_dict[name] = v
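    # Register the new layer's quantization metadata alongside its weights.
    # These keys mirror the per-layer fields a quantized checkpoint carries:
    # output shift, weight/bias bit widths, activation quantization flag and
    # shift quantile.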
new_state_dict[f'{args.layer_name}.output_shift'] = torch.Tensor([1.]).to(device)
if args.low_memory_footprint:
new_state_dict[f'{args.layer_name}.weight_bits'] = torch.Tensor([2.]).to(device)
else:
new_state_dict[f'{args.layer_name}.weight_bits'] = torch.Tensor([8.]).to(device)
new_state_dict[f'{args.layer_name}.bias_bits'] = torch.Tensor([8.]).to(device)
new_state_dict[f'{args.layer_name}.quantize_activation'] = torch.Tensor([1.]).to(device)
new_state_dict[f'{args.layer_name}.adjust_output_shift'] = torch.Tensor([0.]).to(device)
new_state_dict[f'{args.layer_name}.shift_quantile'] = torch.Tensor([1.]).to(device)
new_state_dict[f'{args.layer_name}.op.weight'] = passthrough_kernel.weight.data.to(device)
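    # The new keys were appended at the end of the OrderedDict. Restore layer
    # order by rotating every key from `--layer-name-after-pt` onward to the
    # end, stopping once the freshly added passthrough keys are reached; the
    # passthrough layer then sits directly before `--layer-name-after-pt`.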
move_layer = False
for key in list(new_state_dict.keys()):
if not move_layer and key.startswith(args.layer_name_after_pt):
move_layer = True
if move_layer and key.startswith(args.layer_name):
move_layer = False
if move_layer:
new_state_dict.move_to_end(key)
new_checkpoint['state_dict'] = new_state_dict
    torch.save(new_checkpoint, args.output_checkpoint_path)


if __name__ == '__main__':
main()
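
# To verify the output checkpoint (illustrative; the path and layer name
# match the hypothetical example in the module docstring):
#
#   ckpt = torch.load('trained/model_q8_fakept.pth.tar', map_location='cpu')
#   print([k for k in ckpt['state_dict'] if k.startswith('fake_pt')])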