-
Notifications
You must be signed in to change notification settings - Fork 35
/
op_base.py
141 lines (116 loc) · 7.13 KB
/
op_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
(C) Copyright 2021 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on June 30, 2021
"""
from typing import Union, List, Optional
from abc import abstractmethod
from fuse.data.utils.sample import get_sample_id
from fuse.utils.ndict import NDict
from fuse.data.ops.hashable_class import HashableClass
import inspect
class OpBase(HashableClass):
    """
    Base class for all operators.

    Operators are the building blocks of the sample-processing pipeline.
    Each operator receives the sample_dict produced by the previous operator in the pipeline,
    modifies it (it may add, delete or modify fields) and hands it on to the next operator.
    """

    @abstractmethod
    def __call__(self, sample_dict: NDict, **kwargs) -> Union[None, dict, List[dict]]:
        """
        Apply the operation to one sample.

        :param sample_dict: the dictionary built so far by the previous ops in the pipeline.
                            The first op will typically receive only the sample id,
                            stored in sample_dict['data']['sample_id'].
        :param kwargs: additional, operation-specific arguments
        :return: typically the modified sample_dict.
                 Two special return values are supported only when the op runs in a static pipeline:
                 * None - drop the sample without raising an error
                 * a list of sample_dicts - the sample was split into several samples
                   (for example, an image split into patches)
        """
        # Subclasses must provide the actual operation.
        raise NotImplementedError
class OpReversibleBase(OpBase):
    """
    An operator that declares it can be reversed on demand
    (useful for undoing processing steps before presenting the output).

    If the op has nothing to undo, inherit from OpReversibleBase instead of OpBase and implement a
    trivial reverse() that returns sample_dict unchanged - that alone declares the op reversible.

    If undoing the op requires extra logic:
    (1) in __call__, record whatever is needed to undo it under the op_id key
        (sample_dict[op_id] = <information to record>);
    (2) override reverse(): read the recorded information from sample_dict[op_id] and use it to undo the op.
    """

    @abstractmethod
    def __call__(self, sample_dict: NDict, op_id: Optional[str], **kwargs) -> Union[None, dict, List[dict]]:
        """
        Same contract as OpBase.__call__; see it for more information.
        The only difference is the extra op_id argument, which can be used to "record"
        the information needed to reverse the operation.

        :param op_id: unique identifier of the operation.
                      To support reverse, it may be used as a sample_dict key:
                      sample_dict[op_id] = info_to_store
        """
        raise NotImplementedError

    def reverse(self, sample_dict: NDict, key_to_reverse: str, key_to_follow: str, op_id: Optional[str]) -> dict:
        """
        Undo the operation.

        If undoing is not needed (for example, an op that reads an image), implement a trivial
        reverse() that returns sample_dict as is.
        If undoing is needed but not yet required by the project, inherit from OpBase instead
        (a NotImplementedError is raised if reverse is ever requested for it).
        To support reverse, store in sample_dict[op_id] whatever __call__ needs to be undone,
        such as the key of the transformed value and the arguments of the transform;
        read those values back here.

        :param sample_dict: the dictionary as modified by the previous steps (in reversed order)
        :param key_to_reverse: key of the value to reverse
        :param key_to_follow: reverse according to the operation that was applied to this key's value
        :param op_id: see op_id in __call__
        :return: the modified sample_dict
        :raises NotImplementedError: always, in this default implementation
        """
        # Default: refuse to reverse, and explain how to make the op reversible.
        explanation = (
            f"op {self} is not reversible. "
            "If there is nothing to reverse, just implement simple reverse method that returns sample_dict as is. "
            "If extra logic required to reverse follow the instructions in OpReversibleBase"
        )
        raise NotImplementedError(explanation)
def op_call(op: OpBase, sample_dict: NDict, op_id: str, **kwargs):
    """
    Invoke a single pipeline op on sample_dict.

    op_id is forwarded only to reversible ops (OpReversibleBase); plain OpBase ops receive
    just sample_dict and kwargs.

    :param op: an *instance* of an OpBase subclass
    :param sample_dict: the sample being processed
    :param op_id: identifier of the op within the pipeline
    :param kwargs: op-specific arguments forwarded to op.__call__
    :return: whatever op.__call__ returns
    :raises TypeError: if op is a class object instead of an instance, or is not an OpBase instance
    Exceptions raised by the op itself are re-raised unchanged after a diagnostic banner is printed.
    """
    # Validate before the try block so a mis-configured pipeline is reported directly,
    # rather than being labeled as an error inside op.__call__.
    if inspect.isclass(op):
        raise TypeError(
            f"Error: expected an instance object, not a class object for {op}\n"
            "When creating a pipeline, such error can happen when you provide the following ops list description:\n"
            "[SomeOp, {}]\n"
            "instead of \n"
            "[SomeOp(), {}]\n"
        )
    if not isinstance(op, OpBase):
        raise TypeError(
            f"Ops are expected to be instances of classes or subclasses of OpBase. The following op is not: {op}"
        )

    try:
        if isinstance(op, OpReversibleBase):
            return op(sample_dict, op_id=op_id, **kwargs)
        return op(sample_dict, **kwargs)  # OpBase but not reversible
    except Exception:
        # error messages are cryptic without this. For example, you can get "TypeError: __call__() got an unexpected keyword argument 'key_out_input'" , without any reference to the relevant op!
        # Resolving the sample id must never mask the op's original exception.
        try:
            sample_id = get_sample_id(sample_dict)
        except Exception:
            sample_id = "<unavailable>"
        print(
            "************************************************************************************************************************************\n"
            + f"error in __call__ method of op={op}, op_id={op_id}, sample_id={sample_id} - more details below\n"
            + "*************************************************************************************************************************************\n"
        )
        raise
def op_reverse(op: OpBase, sample_dict: NDict, key_to_reverse: str, key_to_follow: str, op_id: Optional[str]):
    """
    Reverse a single pipeline op on sample_dict by delegating to op.reverse().

    :param op: the op instance to reverse
    :param sample_dict: the sample being reversed
    :param key_to_reverse: key of the value to reverse
    :param key_to_follow: reverse according to the operation that was applied to this key's value
    :param op_id: identifier of the op within the pipeline, forwarded to op.reverse()
    :return: whatever op.reverse() returns
    :raises NotImplementedError: if op is not an OpReversibleBase instance
    Exceptions raised by op.reverse() are re-raised unchanged after a diagnostic line is printed.
    """
    if not isinstance(op, OpReversibleBase):  # OpBase but not reversible
        raise NotImplementedError(
            f"op {op} is not reversible. If there is nothing to reverse, just inherit OpReversibleBase instead of OpBase and implement simple reverse method that returns sample_dict as is. If extra logic required to reverse follow the instructions in OpReversibleBase"
        )

    try:
        return op.reverse(sample_dict, key_to_reverse, key_to_follow, op_id)
    except Exception:
        # Errors raised inside reverse() (e.g. an unexpected-argument TypeError) carry no reference
        # to the failing op; print one before re-raising.
        # Resolving the sample id must never mask the op's original exception.
        try:
            sample_id = get_sample_id(sample_dict)
        except Exception:
            sample_id = "<unavailable>"
        print(f"error in reverse method of op={op}, op_id={op_id}, sample_id={sample_id} - more details below")
        raise