Jan 2021 Release
Co-authored-by: Zhouxing Shi <zhouxingshichn@gmail.com>
Co-authored-by: Huan Zhang <huan@huan-zhang.com>
Co-authored-by: Yihan Wang <wangyihan617@gmail.com>
4 people committed Jan 13, 2021
1 parent 1d7a278 commit c8935c6
Showing 30 changed files with 812 additions and 912 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -115,10 +115,10 @@ model = BoundedModule(model, my_input)
ptb = PerturbationLpNorm(norm=np.inf, eps=0.1)
# Make the input a BoundedTensor with perturbation
my_input = BoundedTensor(my_input, ptb)
# Forward propagation using BoundedTensor
# Regular forward propagation using BoundedTensor works as usual.
prediction = model(my_input)
# Compute LiRPA bounds
lb, ub = model.compute_bounds(method="backward")
lb, ub = model.compute_bounds(x=(my_input,), method="backward")
```
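
The updated README passes the perturbed input explicitly via `x=(my_input,)`, and the constructor default changes to `device='auto'` (see the `bound_general.py` diff below), which follows the device of the wrapped model's parameters. A minimal end-to-end sketch under those assumptions; the toy model, input shape, and import paths are illustrative, not part of the commit:

```python
import torch
import torch.nn as nn
import numpy as np
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

# A toy model; any nn.Module that BoundedModule can trace would do.
model = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
my_input = torch.zeros(1, 28 * 28)

# device='auto' is the new constructor default: it picks up the device of the
# wrapped model's parameters instead of assuming CPU.
model = BoundedModule(model, my_input, device='auto')
ptb = PerturbationLpNorm(norm=np.inf, eps=0.1)
my_input = BoundedTensor(my_input, ptb)
prediction = model(my_input)  # regular forward propagation works as usual
# The perturbed input is now passed explicitly when computing bounds:
lb, ub = model.compute_bounds(x=(my_input,), method='backward')
```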

Checkout
67 changes: 54 additions & 13 deletions auto_LiRPA/bound_general.py
@@ -13,7 +13,7 @@


class BoundedModule(nn.Module):
def __init__(self, model, global_input, bound_opts={}, auto_batch_dim=True, device='cpu',
def __init__(self, model, global_input, bound_opts={}, auto_batch_dim=True, device='auto',
verbose=False):
super(BoundedModule, self).__init__()
if isinstance(model, BoundedModule):
@@ -23,7 +23,7 @@ def __init__(self, model, global_input, bound_opts={}, auto_batch_dim=True, devi
self.verbose = verbose
self.bound_opts = bound_opts
self.auto_batch_dim = auto_batch_dim
self.device = device
self.device = device if device != 'auto' else next(model.parameters()).device
self.global_input = global_input
if auto_batch_dim:
# logger.warning('Using automatic batch dimension inferring, which may not be correct')
@@ -41,10 +41,11 @@ def __call__(self, *input, **kwargs):
kwargs.pop("method_opt")
else:
opt = "forward"
if "disable_multi_gpu" in kwargs:
kwargs.pop("disable_multi_gpu")
if "no_replicas" in kwargs:
kwargs.pop("no_replicas")
for kwarg in [
'disable_multi_gpu', 'no_replicas', 'get_property',
'node_class', 'att_name']:
if kwarg in kwargs:
kwargs.pop(kwarg)
if opt == "compute_bounds":
return self.compute_bounds(**kwargs)
else:
@@ -471,6 +472,9 @@ def _convert(self, model, global_input):
model.load_state_dict(self.ori_state_dict)
delattr(self, 'ori_state_dict')

# The final node used in the last time calling `compute_bounds`
self.last_final_node = None

logger.debug('NodesOP:')
for node in nodesOP:
logger.debug('{}'.format(node._replace(param=None)))
@@ -531,11 +535,27 @@ def compute_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, met
A_dict = {} if return_A else None
if x is not None:
self._set_input(*x, new_interval=new_interval)

# Several shortcuts.
method = method.lower() if method is not None else method
if method == 'ibp':
# Pure IBP bounds.
method = None
IBP = True
elif method == 'ibp+backward' or method == 'ibp+crown' or method == 'crown-ibp':
method = 'backward'
IBP = True
elif method == 'crown':
method = 'backward'
elif method == 'forward':
forward = True
elif method == 'forward+backward':
method = 'backward'
forward = True

if IBP and method is None and reuse_ibp:
# directly return the previously saved ibp bounds
return self.ibp_lower, self.ibp_upper
if method == 'forward':
forward = True
root = [self._modules[name] for name in self.root_name]
batch_size = root[0].fv.shape[0]
dim_in = 0
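
The shortcut block above lets callers pick the bounding algorithm with a single `method` string instead of combining the `IBP` and `forward` flags by hand; strings are lower-cased internally, and `'IBP+backward'`, `'IBP+CROWN'` and `'CROWN-IBP'` are aliases. A sketch of the rough equivalences, reusing the `model` and `my_input` names from the README example above:

```python
# Rough equivalences introduced by the shortcut block (a sketch, not exhaustive):
# each call on the left expands to the flag combination noted on the right.
lb, ub = model.compute_bounds(x=(my_input,), method='IBP')               # IBP=True,  method=None
lb, ub = model.compute_bounds(x=(my_input,), method='IBP+backward')      # IBP=True,  method='backward' (CROWN-IBP)
lb, ub = model.compute_bounds(x=(my_input,), method='CROWN')             # method='backward'
lb, ub = model.compute_bounds(x=(my_input,), method='forward')           # forward=True
lb, ub = model.compute_bounds(x=(my_input,), method='forward+backward')  # forward=True, method='backward'
```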
@@ -578,14 +598,32 @@ def compute_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, met
# check whether weights are perturbed and set nonlinear for the BoundMatMul operation
for n in self._modules.values():
if isinstance(n, (BoundLinear, BoundConv, BoundBatchNormalization)):
n.nonlinear = False
for l_name in n.input_name[1:]:
node = self._modules[l_name]
if hasattr(node, 'perturbation'):
if node.perturbation is not None:
n.nonlinear = True

# BFS to find out whether each node is used given the current final node
if final != self.last_final_node:
self.last_final_node = final
for i in self._modules.values():
i.used = False
final.used = True
queue = deque([final])
while len(queue) > 0:
n = queue.popleft()
for n_pre_name in n.input_name:
n_pre = self._modules[n_pre_name]
if not n_pre.used:
n_pre.used = True
queue.append(n_pre)

for i in self._modules.values(): # for all nodes
if hasattr(i, 'nonlinear') and i.nonlinear: # if node.nonlinear
if not i.used:
continue
if hasattr(i, 'nonlinear') and i.nonlinear:
for l_name in i.input_name:
node = self._modules[l_name]
if not hasattr(node, 'lower'):
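
The BFS added above walks backwards from the requested output node over `input_name` edges and flags every reachable node as `used`; intermediate bounds are then skipped for nodes that cannot influence that output. A standalone sketch of the same reachability idea on a toy graph, with purely illustrative node names:

```python
from collections import deque

# Each node lists the names of its inputs, mirroring `input_name` in BoundedModule.
graph = {
    'input':  [],
    'linear': ['input'],
    'relu':   ['linear'],
    'output': ['relu'],
    'unused': ['input'],   # not reachable backwards from 'output'
}
used = {name: False for name in graph}
final = 'output'
used[final] = True
queue = deque([final])
while queue:
    node = queue.popleft()
    for pred in graph[node]:
        if not used[pred]:
            used[pred] = True
            queue.append(pred)
# used == {'input': True, 'linear': True, 'relu': True, 'output': True, 'unused': False}
```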
@@ -607,7 +645,8 @@ def compute_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, met
node.lower = node.forward(self._modules[node.input_name[0]].lower)
node.upper = node.forward(self._modules[node.input_name[0]].upper)
elif isinstance(node, BoundReshape) and \
hasattr(self._modules[node.input_name[0]], 'lower'):
hasattr(self._modules[node.input_name[0]], 'lower') and \
hasattr(self._modules[node.input_name[1]], 'value'):
# Node for input value.
val_input = self._modules[node.input_name[0]]
# Node for input parameter (e.g., shape, permute)
@@ -633,15 +672,15 @@ def compute_bounds(self, x=None, aux=None, C=None, IBP=False, forward=False, met
newC = Patches(None, 1, 0, [batch_size, node.default_shape[-2] * node.default_shape[-1], node.default_shape[-3], node.default_shape[-3], 1, 1], 1)
elif isinstance(node, BoundAdd) and node.mode == "patches":
num_channel = node.default_shape[-3]
patches = (torch.eye(num_channel)).unsqueeze(0).unsqueeze(0).unsqueeze(4).unsqueeze(5) # now [1 * 1 * in_C * in_C * 1 * 1]
patches = (torch.eye(num_channel, device=self.device)).unsqueeze(0).unsqueeze(0).unsqueeze(4).unsqueeze(5) # now [1 * 1 * in_C * in_C * 1 * 1]
newC = Patches(patches, 1, 0, [batch_size] + list(patches.shape[1:]))
else:
newC = torch.eye(dim, device=self.device)\
.unsqueeze(0).repeat(batch_size, 1, 1)\
.view(batch_size, dim, *node.default_shape[1:])
if return_A:
_, _, A_dict = self._backward_general(C=newC, node=node, root=root,
return_A=return_A, A_dict=A_dict)
return_A=return_A, A_dict=A_dict)
else:
self._backward_general(C=newC, node=node, root=root, return_A=return_A)

@@ -1085,7 +1124,9 @@ def forward(self, *inputs, **kwargs):
kwargs.pop("no_replicas")

if not self.device_ids or disable_multi_gpu:
return self.module(*inputs, **kwargs)
if kwargs.pop("get_property", False):
return self.get_property(self, *inputs, **kwargs)
return self.module(*inputs, **kwargs)

if kwargs.pop("get_property", False):
if self._replicas is None:
27 changes: 20 additions & 7 deletions auto_LiRPA/bound_ops.py
@@ -220,7 +220,7 @@ def get_bias(self, A, bias):

# the shape of A.patches is [batch, L, out_c, in_c, K, K]

if self.batch_dim != -1: # Here we only support batch_dim == 0
if self.batch_dim != -1:
batch_size = bias.shape[0]
bias = F.unfold(bias, kernel_size=A.patches.size(-1), stride=A.stride, padding=A.padding).transpose(-2, -1).unsqueeze(-2)
# Here the size of bias is [batch_size, L, 1, in_c * K * K]
@@ -232,7 +232,12 @@ def get_bias(self, A, bias):
bias_new = prod.sum(-1).transpose(-2, -1)
bias_new = bias_new.view(batch_size, bias_new.size(-2), int(math.sqrt(bias_new.size(-1))), int(math.sqrt(bias_new.size(-1))))
else:
raise NotImplementedError()
# Similar to BoundConstant
patches = A.patches
patches_reshape = torch.sum(patches, dim=(-1, -2, -3)) * bias.to(self.device)
patches_reshape = patches_reshape.transpose(-1, -2)
return patches_reshape.view(patches_reshape.size(0), patches_reshape.size(1), int(math.sqrt(patches_reshape.size(2))), -1).transpose(0, 1)

return bias_new
else:
return NotImplementedError()
@@ -914,7 +919,7 @@ def _bound_oneside(last_A):
else:
# we should create a real identity Patch
num_channel = tmp_weight.view(-1).size(0)
patches = (torch.eye(num_channel) * tmp_weight.view(-1)).unsqueeze(0).unsqueeze(0).unsqueeze(4).unsqueeze(5) # now [1 * 1 * in_C * in_C * 1 * 1]
patches = (torch.eye(num_channel, device=tmp_weight.device) * tmp_weight.view(-1)).unsqueeze(0).unsqueeze(0).unsqueeze(4).unsqueeze(5) # now [1 * 1 * in_C * in_C * 1 * 1]
next_A = Patches(patches, 1, 0, [1, 1, num_channel, 1, 1])
sum_bias = tmp_bias.unsqueeze(1).unsqueeze(2).unsqueeze(3) # squeezing batch dim, now [C * 1 * 1 * 1]
else:
@@ -2119,8 +2124,16 @@ def bound_backward(self, last_lA, last_uA):
def _bound_oneside(A):
if A is None:
return 0.0
while A.ndim > 2:
A = torch.sum(A, dim=-1)

if type(A) == torch.Tensor:
while A.ndim > 2:
A = torch.sum(A, dim=-1)
elif type(A) == Patches:
patches = A.patches
patches_reshape = torch.sum(patches, dim=(-1, -2, -3)) * self.value.to(self.device)
patches_reshape = patches_reshape.transpose(-1, -2)
return patches_reshape.view(patches_reshape.size(0), patches_reshape.size(1), int(math.sqrt(patches_reshape.size(2))), -1).transpose(0, 1)

return A * self.value.to(self.device)

lbias = _bound_oneside(last_lA)
@@ -2255,10 +2268,10 @@ def _bound_oneside(last_A):
if last_A is None:
return None
A = torch.zeros(
last_A.shape[0], last_A.shape[1], *x.lower.shape[1:], device=last_A.device)
last_A.shape[0], last_A.shape[1], *x.default_shape[1:], device=last_A.device)
A.scatter_(
dim=dim + 1,
index=index.lower.unsqueeze(0).repeat(A.shape[0], *([1] * (A.ndim - 1))),
index=self.index.unsqueeze(0).repeat(A.shape[0], *([1] * (A.ndim - 1))),
src=last_A)
return A

19 changes: 7 additions & 12 deletions examples/language/train.py
@@ -167,7 +167,7 @@ def step(model, ptb, batch, eps=1.0, train=False):
if args.method == 'IBP+backward_train':
lb, ub = model_loss.compute_bounds(
x=(labels, embeddings, mask), aux=aux,
IBP=True, C=None, method='backward', bound_lower=False)
C=None, method='IBP+backward', bound_lower=False)
else:
raise NotImplementedError
loss_robust = torch.log(ub).mean()
@@ -184,21 +184,16 @@ def step(model, ptb, batch, eps=1.0, train=False):
torch.eye(num_class).type_as(embeddings).unsqueeze(0)
I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
c = (c[I].view(embeddings.size(0), num_class - 1, num_class))
if args.method == 'IBP':
lb, ub = model_bound.compute_bounds(aux=aux, IBP=True, C=c, method=None)
elif args.method == 'IBP+backward':
lb, ub = model_bound.compute_bounds(aux=aux, IBP=True, C=c, method='backward', bound_upper=False)
if args.method in ['IBP', 'IBP+backward', 'forward', 'forward+backward']:
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method=args.method, bound_upper=False)
elif args.method == 'IBP+backward_train':
# CROWN-IBP
if 1 - eps > 1e-4:
lb, ub = model_bound.compute_bounds(aux=aux, IBP=True, C=c, method='backward', bound_upper=False)
ilb, iub = model_bound.compute_bounds(aux=aux, IBP=True, C=c, method=None, reuse_ibp=True)
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP+backward', bound_upper=False)
ilb, iub = model_bound.compute_bounds(aux=aux, C=c, method='IBP', reuse_ibp=True)
lb = eps * ilb + (1 - eps) * lb
else:
lb, ub = model_bound.compute_bounds(aux=aux, IBP=True, C=c, method=None)
elif args.method == 'forward':
lb, ub = model_bound.compute_bounds(aux=aux, IBP=False, C=c, method='forward', bound_upper=False)
elif args.method == 'forward+backward':
lb, ub = model_bound.compute_bounds(aux=aux, IBP=False, forward=True, C=c, method='backward', bound_upper=False)
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP')
else:
raise NotImplementedError
lb_padded = torch.cat((torch.zeros(size=(lb.size(0),1), dtype=lb.dtype, device=lb.device), lb), dim=1)
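
The branches above now map directly onto the new `method` strings, and the CROWN-IBP schedule interpolates between the pure IBP bound and the backward (CROWN-IBP) bound while reusing the cached IBP result instead of recomputing it. A hedged sketch of that pattern, assuming `model_bound`, `aux`, `c`, and the scheduling coefficient `eps` from the surrounding training script:

```python
# CROWN-IBP interpolation as used in the training loop above.
lb, ub = model_bound.compute_bounds(aux=aux, C=c, method='IBP+backward', bound_upper=False)
# reuse_ibp=True returns the IBP bounds saved by the previous call, so no
# second interval propagation pass is needed.
ilb, iub = model_bound.compute_bounds(aux=aux, C=c, method='IBP', reuse_ibp=True)
# Convex combination of the two lower bounds, annealed via eps during training.
lb = eps * ilb + (1 - eps) * lb
```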