aggregation.py
import abc

import torch

from ..utils.registry import Registry
from .typing import TabularData
from .utils.torch_utils import calculate_batch_size_from_input_size

aggregation_registry: Registry = Registry.class_registry("torch.aggregation_registry")
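# The classes below register themselves under string aliases ("concat",
# "stack", ...) so they can be resolved by name. A minimal sketch, assuming
# the Registry util in ..utils.registry supports dict-style lookup (the exact
# lookup API may differ):
#
#     >>> agg_cls = aggregation_registry["concat"]
#     >>> agg = agg_cls()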
class FeatureAggregation(torch.nn.Module):
    """Base class for modules that aggregate a dict of tensors into a single tensor."""

    # Default so that output_size() hits the warning branch instead of raising
    # AttributeError when build() has not been called yet.
    input_size = None

    def __rrshift__(self, other):
        from .block.base import right_shift_block

        return right_shift_block(self, other)

    def forward(self, inputs: TabularData) -> torch.Tensor:
        return super(FeatureAggregation, self).forward(inputs)

    @abc.abstractmethod
    def forward_output_size(self, input_size):
        raise NotImplementedError

    def build(self, input_size, device=None):
        if device:
            self.to(device)
        self.input_size = input_size

        return self

    def output_size(self):
        if not self.input_size:
            # TODO: log warning here
            pass

        return self.forward_output_size(self.input_size)
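# A minimal usage sketch (not part of the original file): build() records the
# input size so output_size() can be queried without running a forward pass.
# This assumes calculate_batch_size_from_input_size reads the leading batch
# dimension from one of the per-feature sizes:
#
#     >>> agg = ConcatFeatures().build({"a": torch.Size([2, 8]), "b": torch.Size([2, 4])})
#     >>> agg.output_size()
#     (2, 12)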
@aggregation_registry.register("concat")
class ConcatFeatures(FeatureAggregation):
    """Aggregates features by concatenating them along ``axis``, in sorted-key order."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def forward(self, inputs):
        tensors = []
        for name in sorted(inputs.keys()):
            tensors.append(inputs[name])

        return torch.cat(tensors, dim=self.axis)

    def forward_output_size(self, input_size):
        batch_size = calculate_batch_size_from_input_size(input_size)

        return batch_size, sum([i[1] for i in input_size.values()])
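# Example (illustrative, not from the original file): with the default
# axis=-1, two (batch, d_i) inputs are concatenated into a single
# (batch, d_a + d_b) tensor:
#
#     >>> concat = ConcatFeatures()
#     >>> out = concat({"a": torch.randn(2, 8), "b": torch.randn(2, 4)})
#     >>> tuple(out.shape)
#     (2, 12)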
@aggregation_registry.register("sequential_concat")
class SequentialConcatFeatures(FeatureAggregation):
    """Concatenates sequential features along the last axis, expanding 2-D inputs to 3-D first."""

    def forward(self, inputs):
        tensors = []
        for name in sorted(inputs.keys()):
            val = inputs[name]
            if val.ndim == 2:
                val = val.unsqueeze(dim=-1)
            tensors.append(val)

        return torch.cat(tensors, dim=-1)

    def forward_output_size(self, input_size):
        batch_size = calculate_batch_size_from_input_size(input_size)
        converted_input_size = {}
        for key, val in input_size.items():
            if len(val) == 2:
                converted_input_size[key] = val + (1,)
            else:
                converted_input_size[key] = val

        return (
            batch_size,
            list(input_size.values())[0][1],
            sum([i[-1] for i in converted_input_size.values()]),
        )
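# Example (illustrative): 2-D inputs are expanded to 3-D before the concat, so
# a (batch, seq_len, dim) sequence and a (batch, seq_len) feature merge on the
# last axis:
#
#     >>> seq_concat = SequentialConcatFeatures()
#     >>> out = seq_concat({"emb": torch.randn(2, 10, 8), "flag": torch.randn(2, 10)})
#     >>> tuple(out.shape)
#     (2, 10, 9)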
@aggregation_registry.register("stack")
class StackFeatures(FeatureAggregation):
    """Aggregates features by stacking them along a new ``axis``, in sorted-key order."""

    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def forward(self, inputs):
        tensors = []
        for name in sorted(inputs.keys()):
            tensors.append(inputs[name])

        return torch.stack(tensors, dim=self.axis)

    def forward_output_size(self, input_size):
        batch_size = calculate_batch_size_from_input_size(input_size)
        last_dim = [i for i in input_size.values()][0][-1]

        return batch_size, len(input_size), last_dim
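# Example (illustrative): with the default axis=-1, two (batch, d) inputs are
# stacked along a new trailing axis into (batch, d, num_features). Note that
# forward_output_size above reports (batch, num_features, d), which matches
# stacking along axis=1 rather than the default axis:
#
#     >>> stack = StackFeatures()
#     >>> out = stack({"a": torch.randn(2, 8), "b": torch.randn(2, 8)})
#     >>> tuple(out.shape)
#     (2, 8, 2)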
@aggregation_registry.register("element-wise-sum")
class ElementwiseSum(FeatureAggregation):
    """Sums the input tensors element-wise; all inputs must share the same shape."""

    def __init__(self):
        super().__init__()
        self.stack = StackFeatures(axis=0)

    def forward(self, inputs):
        return self.stack(inputs).sum(dim=0)

    def forward_output_size(self, input_size):
        batch_size = calculate_batch_size_from_input_size(input_size)
        last_dim = [i for i in input_size.values()][0][-1]

        return batch_size, last_dim
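# Example (illustrative): the element-wise sum stacks same-shaped inputs along
# a new leading axis and reduces over it, preserving the per-feature shape:
#
#     >>> ew_sum = ElementwiseSum()
#     >>> out = ew_sum({"a": torch.ones(2, 8), "b": torch.ones(2, 8)})
#     >>> tuple(out.shape)
#     (2, 8)
#     >>> float(out[0, 0])
#     2.0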