In [None]:
#@title Download CROP if necessary

import os
import sys
if 'crop.py' not in os.listdir('..'):
    !git clone https://github.com/Josuelmet/CROP.git
    sys.path.append('./CROP')
else:
    sys.path.append('..')

from crop import *

Cloning into 'CROP'...
remote: Enumerating objects: 56, done.[K
remote: Counting objects: 100% (56/56), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 56 (delta 16), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (56/56), 654.37 KiB | 3.41 MiB/s, done.
Resolving deltas: 100% (16/16), done.


# Testing CROP on a random high-dimensional model

In [None]:
import torch
from torch import nn
from collections import OrderedDict

# Fully connected network on flattened 1x32x32 inputs, assembled one named
# submodule at a time (same names and same order as an OrderedDict would give).
model = nn.Sequential()
model.add_module("flat1", nn.Flatten())
model.add_module("fc0", nn.Linear(1 * 32 * 32, 400))
# Affine layers, the first two followed by a LeakyReLU
model.add_module("fc1", nn.Linear(400, 120))
model.add_module("relu3", nn.LeakyReLU())
model.add_module("fc2", nn.Linear(120, 84))
model.add_module("relu4", nn.LeakyReLU())
model.add_module("fc3", nn.Linear(84, 10))

# Switch to the CROP wrapper so the model can be called with constraints too.
model = ConstrainedSequential.cast(model)

# c: 4 constraints (the next cell iterates over them), each a batch of
#    b random inputs of shape (1, 32, 32).
# x: one ordinary random input.
b = 100
c = torch.randn((4, b, 1, 32, 32))
x = torch.randn((1, 1, 32, 32))

# One forward pass with data and constraints; the result is not kept.
# NOTE(review): presumably this call is what makes the layers satisfy the
# constraints checked in the next cell -- confirm against crop.py.
model(x, c)

# Back to a plain model for the checks that follow.
model = ConstrainedSequential.uncast(model)

In [None]:
# Check every constraint (first axis of `c`) against the model separately.
for each_constraint in c:
    flags, unmatched_act = check_layerwise_signs(model, each_constraint)

    # The last layer's signs do not need to agree, so its entry is left out
    # of both summaries.
    hidden_signs_agree = all(flags[:-1])
    hidden_disagreement = sum(unmatched_act[:-1])

    print('All signs agree', hidden_signs_agree)
    print('Abs sum of activations that disagree', hidden_disagreement)

All signs agree True
Abs sum of activations that disagree tensor(0.)
All signs agree True
Abs sum of activations that disagree tensor(0.)
All signs agree True
Abs sum of activations that disagree tensor(0.)
All signs agree True
Abs sum of activations that disagree tensor(0.)


# Line-by-line testing

In [None]:
import torch
from torch import nn

# Toy affine layer: 10 input features -> 3 output units.
m = nn.Linear(in_features=10, out_features=3)

# R groups of V points each; together they must account for every row of x.
R = 2
V = 4
x = torch.randn(8, 10)
assert x.size(0) == R * V

# When False, a later cell zeroes part of the constraint activations before
# the bias shift is computed.
force_linearity = False

# Pre-activations of all rows, shape (8, 3).
h = m(x)

def sign(tensor):
    """Return the elementwise sign of ``tensor``, counting zeros as positive.

    Identical to ``tensor.sign()`` except that entries equal to 0 map to +1
    instead of 0. The input is left untouched and the result keeps its dtype.
    """
    return tensor.sign().masked_fill(tensor == 0, 1)

# `@torch.no_grad()` is left disabled, so the tensors below keep their
# autograd history.

# `h` holds the pre-activations of every row (data + constraints): its shape
# is (N + R * V, K), where K is the output dim (W maps R^D -> R^K).
# The constraint rows are the last R * V ones; work on a copy of them.
h_c = h[-R * V:].clone()
# Zero-free signs of the constraint rows, grouped as (R, V, K).
h_c_signs = sign(h_c).reshape(R, V, *h_c.shape[1:])

In [None]:
# Show the (R, V, K) sign tensor; every entry is +1 or -1 since sign() maps 0 to +1.
h_c_signs

tensor([[[-1.,  1.,  1.],
         [-1., -1., -1.],
         [-1.,  1.,  1.],
         [ 1., -1., -1.]],

        [[ 1.,  1., -1.],
         [-1., -1., -1.],
         [-1.,  1.,  1.],
         [ 1., -1.,  1.]]], grad_fn=<ReshapeAliasBackward0>)

In [None]:
# Flag the units/neurons that need intervention: a unit is in conflict when,
# in at least one of the R constraint regions, its V signs are not unanimous
# (a unanimous region sums to +V or -V along the V axis).
# conflict_dims is a length-K boolean vector.
conflict_dims = ~(h_c_signs.sum(dim=1).abs() == V).all(dim=0)
conflict_dims

tensor([True, True, True])

In [None]:
# Overall majority sign of each conflicting unit, pooled over all R * V
# constraint points; a tie resolves to +1 because sign() maps 0 to +1.
desired_signs = sign(h_c_signs.flatten(0, 1).sum(dim=0))[conflict_dims]
desired_signs

tensor([-1.,  1.,  1.], grad_fn=<IndexBackward0>)

In [None]:
# Majority sign of every unit within each region separately (sum over the V
# axis, ties resolve to +1), restricted to the conflicting units.
# Resulting shape: (R, K_conflict).
regionwise_majority = sign(h_c_signs.sum(dim=1))[:, conflict_dims]
regionwise_majority

tensor([[-1.,  1.,  1.],
        [ 1.,  1.,  1.]], grad_fn=<IndexBackward0>)

In [None]:
if not force_linearity:

    # Per (region, unit): True where the region's own majority sign equals the
    # overall desired sign. Shape (R, 1, K_conflict), so it broadcasts over V.
    region_agrees = (regionwise_majority == desired_signs).unsqueeze(1)

    # View the conflicting units of h_c as (R, V, K_conflict) and multiply by 0
    # every point of a region whose majority sign differs from the desired
    # one; regions that agree keep their values.
    per_region = h_c.reshape_as(h_c_signs)[:, :, conflict_dims]
    h_c[:, conflict_dims] = (per_region * region_agrees).flatten(0, 1)

    print(h_c)

tensor([[-0.2121,  0.3760,  0.6179],
        [-0.7720, -0.2179, -0.7318],
        [-1.3371,  0.9049,  0.5732],
        [ 1.3041, -0.1099, -1.1811],
        [ 0.0000,  0.0483, -0.6730],
        [-0.0000, -0.6817, -0.3817],
        [-0.0000,  0.0661,  1.1537],
        [ 0.0000, -0.4754,  0.3144]], grad_fn=<CopySlices>)


In [None]:
# For every conflicting unit, find the constraint point that violates the
# desired sign the most: after multiplying by desired_signs a violation is
# negative, so take the column-wise minimum and keep only its non-positive part.
worst_violation = (h_c[:, conflict_dims] * desired_signs).amin(dim=0).clamp(max=0)

# Turn it back into a signed offset, enlarged by 0.1% so that the worst
# point ends up strictly on the desired side instead of exactly at 0.
extra_bias = worst_violation * desired_signs * (1 + 1e-3)

# Shift the hyper-planes: every row of h moves by the same per-unit offset.
h[:, conflict_dims] -= extra_bias

h.sign()

tensor([[-1.,  1.,  1.],
        [-1.,  1.,  1.],
        [-1.,  1.,  1.],
        [-1.,  1.,  1.],
        [-1.,  1.,  1.],
        [-1.,  1.,  1.],
        [-1.,  1.,  1.],
        [-1.,  1.,  1.]], grad_fn=<SignBackward0>)