In [1]:
#@title Download CROP if necessary

import os
import sys
if 'crop.py' not in os.listdir('..'):
    !git clone https://github.com/Josuelmet/CROP.git
    sys.path.append('./CROP')
else:
    sys.path.append('..')

from crop import *

import torch
from torch import nn
import torchvision

Cloning into 'CROP'...
remote: Enumerating objects: 88, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (81/81), done.[K
remote: Total 88 (delta 28), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (88/88), 665.41 KiB | 4.62 MiB/s, done.
Resolving deltas: 100% (28/28), done.


# Testing convolutional CROP on a randomly initialized, high-dimensional model (VGG16)

In [2]:
from collections import OrderedDict

# Seed the global torch RNG so the random weights below (and any random
# tensors drawn afterwards in this section) are reproducible.
torch.manual_seed(0)

# Optional experiment: replace every MaxPool2d in the feature extractor with
# an AvgPool2d that uses the same shared hyperparameters.
REPLACE_MAXPOOL_WITH_AVGPOOL = False

model = torchvision.models.vgg16()

# We have to convert to Sequential to add the Flatten layer
# because casting replaces the model's original forward() function
model = nn.Sequential(OrderedDict([
    ("features",   model.features),
    ("avgpool",    model.avgpool),
    ("flatten",    nn.Flatten(1)),
    ("classifier", model.classifier),
]))

if REPLACE_MAXPOOL_WITH_AVGPOOL:
    for i, layer in enumerate(model.features):
        if isinstance(layer, nn.MaxPool2d):
            # Carry over only the attributes AvgPool2d also accepts as arguments.
            shared_kwargs = {k: v for k, v in vars(layer).items()
                             if k in nn.AvgPool2d.__constants__}
            model.features[i] = nn.AvgPool2d(**shared_kwargs)

# Overwrite every conv/linear weight and bias with i.i.d. standard-normal
# values so the test runs on an arbitrary, non-trained network.
# nn.init.normal_ fills in place without tracking gradients.
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(m.weight)
        if m.bias is not None:
            nn.init.normal_(m.bias)

In [3]:
# Put the network in eval mode *before* casting: in training mode its
# dropout layers would make the forward pass random.
model.eval()
model = ConstrainedSequential.cast(model, constrain_last=True)

In [4]:
# R constraint regions with V vertices each; every vertex is a random
# 3x32x32 image.
R = 4
V = 10
c = torch.randn((R, V, 3, 32, 32))
# A single random 3x32x32 input (batch of one).
x = torch.randn((1, 3, 32, 32))

In [5]:
# Feed all R*V vertices through the cast model as one flat batch, passing
# the constraint tensor c alongside, then split the batch axis back into
# (R, V).
out = model(c.reshape((R * V, *c.shape[2:])), c)
out = out.reshape((R, V, *out.shape[1:]))

In [6]:
# Undo the cast from In [3]; the layer-wise sign checks below run on the
# result. NOTE(review): presumably this restores the plain nn.Sequential
# built in In [2] (with the extra biases folded in?) — confirm in
# crop.ConstrainedSequential.uncast.
model = ConstrainedSequential.uncast(model)

In [7]:
# For each constraint region (one slice c[r] of shape (V, 3, 32, 32)),
# check layer by layer whether the activation signs agree across that
# region's vertices. The network here has random, not pretrained, weights.
# NOTE(review): flags / unmatched_act / acts come from
# crop.check_layerwise_signs — presumably one entry per layer; confirm there.
for each_constraint in c:
    flags, unmatched_act, acts = check_layerwise_signs(model, each_constraint)

    # Index up to :-1 because the last layer's signs do not need to agree.
    print('All signs agree', all(flags[:-1]))
    print('Abs sum of activations that disagree', sum(unmatched_act[:-1]))

# After the loop, `acts` holds the activations of the *last* constraint
# region only; the next cell inspects those.

All signs agree True
Abs sum of activations that disagree tensor(0.)
All signs agree True
Abs sum of activations that disagree tensor(0.)
All signs agree True
Abs sum of activations that disagree tensor(0.)
All signs agree True
Abs sum of activations that disagree tensor(0.)


In [8]:
def channel_sign_totals_consistent(a: torch.Tensor) -> bool:
    """Check that per-channel sign totals of one layer's activations line up.

    The signs of ``a`` are summed over dim 0 (the vertices) and then over
    every remaining dimension after the channel/unit dimension, giving one
    integer-valued total per channel. The check passes when every total is
    an integer multiple of the largest total magnitude (e.g. each channel is
    entirely positive, entirely negative, or entirely zero).

    Args:
        a: activations with the vertex dimension first, e.g. (V, F) for a
           linear layer or (V, C, H, W) for a conv layer.

    Returns:
        True if the totals are consistent as described above.
    """
    totals = a.sign().sum(0)
    if totals.ndim > 1:
        # Collapse all non-channel dims (spatial ones for conv layers).
        totals = totals.flatten(1).sum(1)
    reference = totals.abs().max()
    if reference == 0:
        # Every total is zero: trivially consistent (and avoids x % 0 -> NaN).
        return True
    return bool((totals % reference == 0).all())


# `acts` is the list left over from the previous cell (last constraint region).
for i, a in enumerate(acts):
    print(f'{i}:\t{channel_sign_totals_consistent(a)}')

0:	True
1:	True
2:	True
3:	True
4:	True
5:	True
6:	True
7:	True
8:	True
9:	True
10:	True
11:	True
12:	True
13:	True
14:	True
15:	True
16:	True
17:	True
18:	True
19:	True
20:	True
21:	True
22:	True
23:	True
24:	True
25:	True
26:	True
27:	True
28:	True
29:	True
30:	True
31:	True
32:	True
33:	True
34:	True
35:	True
36:	True
37:	True
38:	True
39:	True


# Line-by-line testing of the bias computation on a single convolutional layer

In [9]:
# Seed the global torch RNG so this standalone walk-through is reproducible.
torch.manual_seed(0)

model = torchvision.models.vgg16()

# Work with VGG16's first conv layer, adding standard-normal noise to its
# default-initialized weights.
m = model.features[0]
with torch.no_grad():
    m.weight += torch.randn_like(m.weight)

# R constraint regions with V vertices each, stacked along the batch axis of
# x as consecutive blocks of V rows per region. Deriving the batch size from
# R and V keeps the two in sync.
R, V = (3, 2)
x = torch.randn(R * V, 3, 32, 32)

# When True, every region is forced to be linear, not just the regions in
# the majority-sign coalition (see the masking step further down).
force_linearity = True

# Conv preactivations, shape (R*V, C, H, W).
h = m(x)

In [10]:
# Get the number of channels
C = h.shape[1]

# Isolate the preactivations belonging to constraint region vertices
h_c = torch.clone(h[-R * V:])
# Make the first dimension the channel dimension
h_c = h_c.transpose(0, 1)
# Flatten all but the first two dimensions
h_c_flat = h_c.reshape((C, R, -1))


# The bias applied to each channel's output is a scalar.
# Therefore, we need to take the aggregate sum of *all* entry signs from
# *all* vertices of *each* constraint region for *each* channel.
# We do this via sign(h_c_flat).sum(2), which returns a (C x R) matrix.
# Taking the sign() of that matrix indicates the majority sign of each
# channel of each constraint region.
regionwise_majority = sign(sign(h_c_flat).sum(2))
regionwise_majority

tensor([[ 1., -1.,  1.],
        [-1., -1., -1.],
        [ 1.,  1.,  1.],
        [-1., -1., -1.],
        [ 1.,  1.,  1.],
        [-1.,  1., -1.],
        [-1., -1., -1.],
        [ 1.,  1., -1.],
        [-1., -1.,  1.],
        [-1.,  1.,  1.],
        [ 1., -1., -1.],
        [-1., -1., -1.],
        [-1.,  1.,  1.],
        [-1.,  1., -1.],
        [ 1., -1.,  1.],
        [ 1.,  1., -1.],
        [-1., -1., -1.],
        [ 1., -1., -1.],
        [ 1., -1.,  1.],
        [ 1.,  1.,  1.],
        [-1.,  1., -1.],
        [-1., -1., -1.],
        [ 1.,  1.,  1.],
        [ 1., -1.,  1.],
        [-1., -1., -1.],
        [-1.,  1., -1.],
        [-1., -1.,  1.],
        [-1., -1., -1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [-1., -1., -1.],
        [-1.,  1.,  1.],
        [-1., -1., -1.],
        [-1.,  1., -1.],
        [-1.,  1., -1.],
        [ 1., -1.,  1.],
        [ 1., -1.,  1.],
        [-1., -1., -1.],
        [-1.,  1., -1.],
        [ 1.,  1., -1.],


In [11]:
# regionwise_majority is a (C x R) tensor with entries in {-1, 1}. A second
# majority vote, this time across the R regions, yields one target sign per
# channel (a length-C vector).
desired_signs = sign(regionwise_majority.sum(dim=1))
desired_signs

tensor([ 1., -1.,  1., -1.,  1., -1., -1.,  1., -1.,  1., -1., -1.,  1., -1.,
         1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1., -1., -1., -1.,
         1.,  1., -1.,  1., -1., -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1.,
        -1.,  1., -1.,  1., -1., -1.,  1.,  1.,  1., -1., -1.,  1., -1., -1.,
         1.,  1., -1.,  1.,  1., -1., -1., -1.], grad_fn=<AddBackward0>)

In [12]:
# If we don't need to force linearity in all regions,
# then we only need to look at points that are part of the majority-sign coalition.
# All other points will be multiplied by zero.
# The (C, R) mask is True where a region's majority sign matches its channel's
# desired sign; unsqueeze(2) broadcasts it over each region's entries.
if not force_linearity:
    h_c_flat *= (regionwise_majority == desired_signs[:, None]).unsqueeze(2)


# Per-channel bias shift. Multiplying by desired_signs flips each channel so
# that "correct" entries are positive; amin over (region, entry) finds the
# worst violation, and clamp(max=0) leaves channels without violations
# untouched. Multiplying by desired_signs again maps back to the original
# orientation, so that h - extra_bias has the desired sign in every entry.
# The (1 + 1e-3) factor is a small relative margin that pushes the worst
# entry strictly past zero instead of landing exactly on it.
extra_bias = (
    (h_c_flat * desired_signs[:, None, None]).amin((1, 2)).clamp(max=0)
    * desired_signs
    * (1 + 1e-3)
)
# Reshape to (1, C, 1, ..., 1) so it broadcasts against h. Use plain ints for
# the trailing singleton dims rather than a tuple of 0-dim tensors.
extra_bias = extra_bias.reshape((1, C) + (1,) * (h.ndim - 2))

extra_bias.flatten()

tensor([-19.6577,  14.3502, -22.4661,  15.1722, -19.3050,  17.9165,  25.0987,
        -20.9312,  15.7228, -18.6171,  20.3821,  19.0568, -20.1554,  19.4312,
        -17.6010, -20.3626,  21.3508,  16.7045, -20.4194, -15.5700,  21.1873,
         12.5803, -19.3323, -18.4732,  15.0266,  21.7979,  18.4891,  18.0978,
        -17.6602, -17.2189,  22.2625, -20.7817,  19.7589,  21.6744,  13.8114,
        -16.2153, -14.2942,  21.4378,  22.4948, -20.2694, -15.6798,  21.7140,
         18.1330, -21.3525,  15.1057, -18.7802,  23.1020,  21.9895, -25.5568,
        -15.3556, -23.2312,  23.9515,  16.1611, -20.5781,  19.5204,  25.2621,
        -18.3033, -19.9708,  19.0404, -21.5399, -16.8333,  19.0248,  14.7314,
         15.8933], grad_fn=<ReshapeAliasBackward0>)

In [13]:
# After subtracting extra_bias, every channel of h should carry one sign
# across all samples and positions. Then |sum of signs| over every
# non-channel axis equals the number of entries per channel, so each entry
# of sign_deficit (one per channel) must be 0.
reduce_dims = (0,) + tuple(range(2, h.ndim))
sign_deficit = (h - extra_bias).sign().sum(reduce_dims).abs() - h[:, 0].numel()
assert (sign_deficit == 0).all(), 'some channels still contain entries of mixed sign'
sign_deficit

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       grad_fn=<SubBackward0>)