/
node_input_output.py
56 lines (36 loc) · 1.66 KB
/
node_input_output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import TYPE_CHECKING, List
from .node import Node
if TYPE_CHECKING:
from .cartesian_graph import CartesianGraph
class InputNode(Node):
    """An input node of the computational graph.

    Input nodes have arity zero. Calling one is an error; the
    ``format_output_str*`` methods set the node's output string to an
    expression that indexes the input ``x`` with this node's index.
    """

    _arity = 0

    def __init__(self, idx: int, input_nodes: List[int]) -> None:
        """Pass ``idx`` and ``input_nodes`` unchanged to the parent initializer."""
        super().__init__(idx, input_nodes)

    def __call__(self, x: List[float], graph: "CartesianGraph") -> None:
        """Always fail: an input node must not be called.

        Raises:
            AssertionError: unconditionally.
        """
        # Raised explicitly instead of using ``assert False`` so the guard is
        # not stripped when Python runs with ``-O``; the exception type is
        # the same one the assert statement produced.
        raise AssertionError("InputNode must not be called")

    def format_output_str(self, graph: "CartesianGraph") -> None:
        """Set the output string to ``x[<idx>]``; ``graph`` is not used."""
        self._output_str = f"x[{self._idx}]"

    def format_output_str_numpy(self, graph: "CartesianGraph") -> None:
        """Set the output string exactly as ``format_output_str`` does."""
        self.format_output_str(graph)

    def format_output_str_torch(self, graph: "CartesianGraph") -> None:
        """Set the output string to ``x[:, <idx>]``; ``graph`` is not used.

        NOTE(review): the leading ``:`` presumably spans a batch axis of the
        torch input -- confirm against the code that evaluates this string.
        """
        self._output_str = f"x[:, {self._idx}]"

    def format_output_str_sympy(self, graph: "CartesianGraph") -> None:
        """Set the output string exactly as ``format_output_str`` does."""
        self.format_output_str(graph)
class OutputNode(Node):
    """An output node of the computational graph.

    Output nodes have arity one: they copy the output value and the output
    string of the graph node found at their first address.
    """

    _arity = 1

    def __init__(self, idx: int, input_nodes: List[int]) -> None:
        """Pass ``idx`` and ``input_nodes`` unchanged to the parent initializer."""
        super().__init__(idx, input_nodes)

    def __call__(self, x: List[float], graph: "CartesianGraph") -> None:
        """Store the output of the node at the first address; ``x`` is not used."""
        source_address = self._addresses[0]
        self._output = graph[source_address].output

    def format_output_str(self, graph: "CartesianGraph") -> None:
        """Store the output string of the node at the first address."""
        source_node = graph[self._addresses[0]]
        # Formatted through an f-string so the stored value is always a str.
        self._output_str = f"{source_node.output_str}"

    def format_output_str_numpy(self, graph: "CartesianGraph") -> None:
        """Set the output string exactly as ``format_output_str`` does."""
        self.format_output_str(graph)

    def format_output_str_torch(self, graph: "CartesianGraph") -> None:
        """Set the output string exactly as ``format_output_str`` does."""
        self.format_output_str(graph)

    def format_output_str_sympy(self, graph: "CartesianGraph") -> None:
        """Set the output string exactly as ``format_output_str`` does."""
        self.format_output_str(graph)