-
Notifications
You must be signed in to change notification settings - Fork 3
/
compression.py
89 lines (74 loc) · 2.58 KB
/
compression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# -*- coding: utf-8 -*-
from __future__ import print_function
from settings import logger
import torch
import numpy as np
import time
class NoneCompressor():
    """Identity compressor: passes tensors through untouched.

    Used when gradient compression is disabled; `compress` reports the
    tensor's dtype as its (trivial) compression context.
    """
    @staticmethod
    def compress(tensor, name=None):
        # Nothing to do — hand back the tensor and its dtype as context.
        return tensor, tensor.dtype

    @staticmethod
    def decompress(tensor, ctc, name=None):
        # Identity: the "compressed" form is already the full tensor.
        return tensor
class TopKCompressor():
    """Top-k gradient sparsifier with error feedback.

    Sparse Communication for Distributed Gradient Descent,
    Alham Fikri Aji et al., 2017.

    All state lives in class-level dicts keyed by a per-tensor `name`,
    so one registry is shared process-wide. The residual dict implements
    error feedback: mass dropped by sparsification is re-added to the
    same tensor on the next `compress` call.
    """
    residuals = {}        # name -> accumulated untransmitted values
    sparsities = []
    zero_conditions = {}
    values = {}           # name -> values selected by the last compress
    indexes = {}          # name -> indexes selected by the last compress
    c = 0
    t = 0.
    name = 'topk'

    @staticmethod
    def clear():
        """Reset every piece of per-tensor state."""
        TopKCompressor.residuals = {}
        TopKCompressor.sparsities = []
        TopKCompressor.zero_conditions = {}
        TopKCompressor.values = {}
        TopKCompressor.indexes = {}

    @staticmethod
    def compress(tensor, name=None, sigma_scale=2.5, ratio=0.05):
        """Keep only the `ratio` fraction of largest-magnitude entries.

        Mutates `tensor` in place (adds the accumulated residual), stores
        the untransmitted remainder as the new residual for `name`, and
        returns (tensor, indexes, values). `sigma_scale` is accepted for
        interface compatibility but not used by this compressor.
        NOTE(review): the fancy indexing assumes a 1-D tensor — confirm
        that callers flatten gradients before compressing.
        """
        with torch.no_grad():
            store = TopKCompressor
            if name not in store.residuals:
                store.residuals[name] = torch.zeros_like(tensor.data)
            # At least one element is always transmitted.
            k = max(int(tensor.numel() * ratio), 1)
            # Error feedback: fold the previous residual into the gradient.
            tensor.data.add_(store.residuals[name].data)
            _, top_idx = torch.topk(torch.abs(tensor.data), k=k)
            # Re-read through the index to recover the signed values.
            top_val = tensor.data[top_idx]
            # New residual = a copy of the tensor with the transmitted
            # slots zeroed out (everything we did NOT send).
            store.residuals[name].data = tensor.data + 0.0
            store.residuals[name].data[top_idx] = 0.
            store.values[name] = top_val
            store.indexes[name] = top_idx
            return tensor, top_idx, top_val

    @staticmethod
    def get_residuals(name, like_tensor):
        """Return the residual for `name`, creating a zero tensor shaped
        like `like_tensor` if none exists yet."""
        store = TopKCompressor.residuals
        if name not in store:
            store[name] = torch.zeros_like(like_tensor.data)
        return store[name]

    @staticmethod
    def add_residuals(included_indexes, name):
        """Move the last-selected values at `included_indexes` back into
        the residual (they were excluded from the aggregated update).

        `included_indexes` positions index into the stored `values[name]`
        vector; a numpy array is converted to a long tensor on the
        residual's device.
        """
        with torch.no_grad():
            res = TopKCompressor.residuals[name]
            if type(included_indexes) is np.ndarray:
                idx = torch.from_numpy(included_indexes).to(device=res.device).long()
            else:
                idx = included_indexes
            vals = TopKCompressor.values[name]
            # Zero the entries that WERE included, then add what remains
            # back at the original top-k positions.
            vals.data[idx] = 0.0
            res.data[TopKCompressor.indexes[name]] += vals.data

    @staticmethod
    def decompress(tensor, ctc, name=None):
        """Identity: values/indexes are communicated out-of-band."""
        return tensor
# Registry mapping a compressor name to its implementation class.
# Both 'none' and an unspecified (None) name resolve to the no-op compressor.
compressors = dict.fromkeys(['none', None], NoneCompressor)
compressors['topk'] = TopKCompressor