# SDT.py
from typing import Tuple, Union

import torch
import torch.nn as nn

class SDT(nn.Module):
"""
Fast implementation of a soft decision tree in PyTorch.
Attributes:
input_dim (int): Number of input dimensions.
output_dim (int): Number of output dimensions, e.g., number of classes in classification.
depth (int): Depth of the tree, affecting its complexity.
lamda (float): Regularization coefficient for the loss function.
device (torch.device): Computation device (CPU or GPU).
internal_node_num_ (int): Number of internal nodes in the tree.
leaf_node_num_ (int): Number of leaf nodes in the tree.
penalty_list (List[float]): Coefficients for regularization penalty of nodes at different depths.
inner_nodes (nn.Sequential): Sequential model for internal nodes.
leaf_nodes (nn.Linear): Linear layer representing leaf nodes.
"""
def __init__(self, input_dim: int, output_dim: int, depth: int = 5, lamda: float = 1e-3, use_cuda: bool = False):
"""
Initializes the Soft Decision Tree model.
Parameters:
input_dim (int): The number of features in the input data.
output_dim (int): The number of target outputs or classes.
depth (int): The depth of the tree, affecting the number of nodes.
lamda (float): Regularization coefficient to control model complexity.
use_cuda (bool): Flag to enable CUDA (GPU) computation.
"""
super(SDT, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.depth = depth
self.lamda = lamda
self.device = torch.device("cuda" if use_cuda else "cpu")
self._validate_parameters()
self.internal_node_num_ = 2 ** self.depth - 1
self.leaf_node_num_ = 2 ** self.depth
        # Penalty coefficient halves with each deeper layer: lamda * 2**(-layer).
        self.penalty_list = [
            self.lamda * (2 ** (-layer)) for layer in range(self.depth)]
        # All internal nodes share a single linear layer over the
        # bias-augmented input; the sigmoid turns each output into the
        # probability of routing to that node's left child.
        self.inner_nodes = nn.Sequential(
            nn.Linear(self.input_dim + 1, self.internal_node_num_, bias=False),
            nn.Sigmoid(),
        )
        # Leaf layer: maps leaf-arrival probabilities to output scores.
        self.leaf_nodes = nn.Linear(
            self.leaf_node_num_, self.output_dim, bias=False)
    def forward(self, X: torch.Tensor, is_training_data: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Performs a forward pass of the model.
        Parameters:
            X (torch.Tensor): Input data tensor.
            is_training_data (bool): Indicates if the pass is for training.
        Returns:
            torch.Tensor: The model's predictions. When is_training_data is
            True, a tuple (predictions, penalty) is returned instead.
        """
_mu, _penalty = self._forward(X)
y_pred = self.leaf_nodes(_mu)
if is_training_data:
return y_pred, _penalty
else:
return y_pred
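    # Note: during training, callers typically add the returned penalty to the
    # task loss, e.g. loss = criterion(y_pred, y) + penalty; see the usage
    # sketch at the bottom of this file.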
    def _forward(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Core implementation of the model's forward pass.
Parameters:
X (torch.Tensor): Input data tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of path probabilities and total penalty.
"""
        batch_size = X.size(0)
        X = self._data_augment(X)
        # Left-routing probability at every internal node; pairing each value
        # with its complement gives shape (batch_size, internal_node_num_, 2).
        path_prob = self.inner_nodes(X)
        path_prob = torch.unsqueeze(path_prob, dim=2)
        path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)
        # _mu holds the probability of reaching each node in the current layer;
        # every sample starts at the root with probability 1.
        _mu = X.data.new(batch_size, 1, 1).fill_(1.0)
        _penalty = torch.tensor(0.0, device=self.device)
        # Sweep the layers top-down, multiplying arrival probabilities by the
        # routing probabilities of that layer's nodes.
        begin_idx = 0
        end_idx = 1
        for layer_idx in range(self.depth):
            _path_prob = path_prob[:, begin_idx:end_idx, :]
            _penalty += self._cal_penalty(layer_idx, _mu, _path_prob)
            _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2) * _path_prob
            begin_idx = end_idx
            end_idx = begin_idx + 2 ** (layer_idx + 1)
        mu = _mu.view(batch_size, self.leaf_node_num_)
        return mu, _penalty
def _cal_penalty(self, layer_idx: int, _mu: torch.Tensor, _path_prob: torch.Tensor) -> torch.Tensor:
"""
Computes regularization penalty for a given layer.
Parameters:
layer_idx (int): Index of the current tree layer.
_mu (torch.Tensor): Path probabilities up to the current layer.
_path_prob (torch.Tensor): Probabilities for routing at the current layer.
Returns:
torch.Tensor: Computed regularization penalty for the layer.
"""
        penalty = torch.tensor(0.0, device=self.device)
        batch_size = _mu.size(0)
        _mu = _mu.view(batch_size, 2 ** layer_idx)
        _path_prob = _path_prob.view(batch_size, 2 ** (layer_idx + 1))
        for node in range(2 ** (layer_idx + 1)):
            # alpha is the batch-averaged routing probability at this node,
            # weighted by the probability of reaching its parent.
            alpha = torch.sum(
                _path_prob[:, node] * _mu[:, node // 2], dim=0
            ) / torch.sum(_mu[:, node // 2], dim=0)
            # Clamp to the open interval (0, 1) so the logarithms stay finite.
            alpha = torch.clamp(alpha, 1e-6, 1 - 1e-6)
            coeff = self.penalty_list[layer_idx]
            # Penalize unbalanced splits: -0.5 * coeff * (log(a) + log(1 - a))
            # is minimized when alpha = 0.5.
            penalty -= 0.5 * coeff * (torch.log(alpha) + torch.log(1 - alpha))
        return penalty
def _data_augment(self, X: torch.Tensor) -> torch.Tensor:
"""
Adds a constant bias term to the input data.
Parameters:
X (torch.Tensor): Original input data.
Returns:
torch.Tensor: Augmented input data.
"""
batch_size = X.size(0)
X = X.view(batch_size, -1)
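        # Prepend a constant column of ones so the bias-free linear layer in
        # inner_nodes can still learn a per-node threshold.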
bias = torch.ones(batch_size, 1).to(self.device)
X = torch.cat((bias, X), 1)
return X
def _validate_parameters(self):
"""
Validates model parameters.
"""
if not self.depth > 0:
raise ValueError(
f"The tree depth should be strictly positive, got {self.depth} instead.")
if not self.lamda >= 0:
raise ValueError(
f"The coefficient of the regularization term should not be negative, got {self.lamda} instead.")
    def compute_nfm(self, X: torch.Tensor):
        """
        Computes a neural feature map (NFM): the gradient of the model's
        outputs with respect to the input, averaged over output dimensions.
        Parameters:
            X (torch.Tensor): Input data tensor.
        Returns:
            numpy.ndarray: Averaged input-gradient map, same shape as X.
        """
        # Evaluation mode for consistent output.
        self.eval()
        # Gradients with respect to the input are required for the NFM.
        X.requires_grad_(True)
        # Forward pass through the model.
        mu, penalty = self._forward(X)
        y_pred = self.leaf_nodes(mu)
        # Accumulate input gradients, one output dimension at a time.
        nfm = torch.zeros_like(X)
        for i in range(self.output_dim):
            self.zero_grad()  # clear parameter gradients
            if X.grad is not None:
                # Clear input gradients left over from the previous output;
                # otherwise they would accumulate and be double-counted.
                X.grad.zero_()
            # Backpropagate from output dimension i, summed over the batch.
            y_pred[:, i].sum().backward(retain_graph=True)
            nfm += X.grad.data
        # Average over output dimensions to get the mean influence per feature.
        nfm /= self.output_dim
        # Detach from the graph and disable input gradients again.
        nfm = nfm.detach()
        X.requires_grad_(False)
        return nfm.cpu().numpy()  # NumPy array for downstream analysis
def compute_nfm_for_target(model, data_loader, target_class, device):
    """
    Compute the Neural Feature Map (NFM) for a specific target class.
    Args:
        model: The Soft Decision Tree model.
        data_loader: DataLoader providing the dataset.
        target_class: The target class for which to compute the NFM.
        device: The device (CPU or CUDA) on which to perform computations.
    Returns:
        A tensor representing the NFM for the specified target class.
    """
    model.eval()
    feature_contributions = []  # per-sample contribution vectors
    with torch.no_grad():
        for data, targets in data_loader:
            data, targets = data.to(device), targets.to(device)
            data = data.view(data.size(0), -1)  # flatten the data if necessary
            # The leaf-arrival probabilities mu describe how each sample is
            # routed through the tree; they serve as the per-sample path
            # signal here. Substitute another contribution measure if your
            # implementation exposes one.
            mu, _ = model._forward(data)
            # Keep only the samples belonging to the target class.
            mask = targets == target_class
            if mask.any():
                feature_contributions.append(mu[mask])
    # Aggregate contributions across all matching samples.
    nfm = torch.mean(torch.cat(feature_contributions, dim=0), dim=0)
    return nfm
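

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module). It trains the tree for one step on random data and then computes
# the two feature maps defined above; all shapes, hyperparameters, and the
# synthetic dataset below are placeholder assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)
    input_dim, output_dim, batch_size = 20, 3, 16
    tree = SDT(input_dim=input_dim, output_dim=output_dim, depth=4, lamda=1e-3)

    # Synthetic classification batch.
    X = torch.randn(batch_size, input_dim)
    y = torch.randint(0, output_dim, (batch_size,))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(tree.parameters(), lr=1e-2)

    # One training step: add the routing penalty to the task loss.
    y_pred, penalty = tree(X, is_training_data=True)
    loss = criterion(y_pred, y) + penalty
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss: {loss.item():.4f} (penalty: {penalty.item():.4f})")

    # Input-gradient feature map averaged over output dimensions.
    nfm = tree.compute_nfm(X.clone())
    print("NFM shape:", nfm.shape)

    # Class-conditional feature map aggregated over a DataLoader; it returns
    # one value per leaf (the mean leaf-arrival probability for that class).
    loader = DataLoader(TensorDataset(X, y), batch_size=8)
    nfm_class0 = compute_nfm_for_target(tree, loader, target_class=0,
                                        device=tree.device)
    print("per-class NFM shape:", tuple(nfm_class0.shape))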