# Reverse NN

In [1]:
import numpy as np
import tensorflow as tf
import sympy as sp
from sympy.solvers.solveset import linsolve
from functools import partial

In [2]:
"""
Activation Functions
"""

def leaky_relu(x,alpha=0.2):
    if x>0:
        return x
    else:
        return alpha*x
    
def ilr_case1(y):
    return y

def ilr_case2(y,alpha=0.2):
    return y/alpha

def identity(x):
    return x
def minus_identity(x):
    return -x

In [3]:
"""
Preimages
"""
def linear_reverse(y):
    return y

def relu_reverse(y):
    if type(y)==sp.core.add.Add or type(y)==sp.core.symbol.Symbol:
        return sp.Function("iR")(y)
    if y>0:
        return y
    else:
        return sp.Symbol('R-')
    
def leaky_relu_reverse(y,alpha=0.2):
    if type(y)==sp.core.add.Add or type(y)==sp.core.symbol.Symbol:
        return sp.Function("iLR")(y)
    if y>0:
        return y
    else:
        return y*(1/alpha)
    
def tanh_reverse(y):
    return sp.atanh(y)
    

def sigmoid_reverse(y):
    if type(y)==sp.core.add.Add or type(y)==sp.core.symbol.Symbol:
        return sp.Function("logit")(y)
    return sp.log(y/(1-y))

In [4]:
class Symbol_dict:
    def __init__(self):
        self.variables={}
    def return_index(self,name):
        if name in self.variables.keys():
            res=self.variables[name]
            self.variables[name]+=1
            return res
        else:
            self.variables[name]=1
            return 0
    def give_families(self):
        return self.variables
    
class Set_dict:
    """
    Describe the constraints
    """
    def __init__(self):
        self.table=[]
    def add_new(self,sols,sets):
        self.table.append({"sols":sols,"sets":sets})
    def add_dict(self,dict):
        self.table.append(dict)
    def merge_with(self,other_dict):
        self.table=self.table+other_dict.table
    def copy(self):
        res=Set_dict()
        res.table=self.table.copy()
        return res
        

# Solving Leaky_ReLu equations

In [21]:
def convert2bin(n,length):
    res=[0]*length
    i=1
    while n!=0:
        res[-i]=n%2
        n=n//2
        i+=1
    return res
    
def set_cases(container,funs):
    """
    Retourne l'ensemble des 2^n inéquations (n = taille de container) à vérifier
    """
    res=[]
    length=len(container)
    n=2**length
    for i in range(n):
        template=convert2bin(i,length)
        temp=container.copy()
        for k in range(length):
            if template[k]:
                temp[k]=funs[0](temp[k])
            else:
                temp[k]=funs[1](temp[k])
        res.append(temp)
    return res
            

def process_system(system):
    """
    Solve system
    """
    ncols=system.cols
    target=system[:ncols-1,:]
    constraints=system[ncols-1:,:]
    target=linsolve(target)
    mat_temp=constraints[:,:ncols-1]*sp.Matrix(make_target(target))-constraints[:,ncols-1:]
    sol_temp=sp.solve(mat_temp)
    target=make_target(target)
    for i in range(len(target)):
        target[i]=target[i].subs(sol_temp)
    return target,sol_temp

            

In [95]:
def char_part(symb):
    """
    Return the name of the symbol "symb" without it index
    """
    res=''
    for char in str(symb):
        if char.isnumeric():
            return res
        else:
            res+=char
            
def get_i_weights(weights,i):
    """
    Return values of ith columns of the weights matrix
    """
    res=[]
    for w in weights:
        res.append(w[i])
    return res

def char_part(symb):
    """
    Return the name of the symbol "symb" without it index
    """
    res=''
    for char in str(symb):
        if char.isnumeric:
            return res
        else:
            res+=char
        

def get_free_symb(eqs,names):
    """
    Return free variables of the equation "eq" of name in "name"
    """
    res=set()
    for eq in eqs:
        symb_dict=eq.free_symbols
        for symb in symb_dict:
            if char_part(symb.name) in names:
                res.add(symb)
    return res


def reverse_layer(layer,table):
    """
    Invert the layer "layer" according to constraints given by table (dict)
    """
    target_output=table["sols"]
    sets=table["sets"]
    temp=layer.get_weights()
    weights=temp[0]
    bias=temp[1]
    fa_name=layer.activation.__name__
    sd=process_inverse(fa_name,target_output,sets)
    new_sd=Set_dict()
    for table in sd.table:
        list_eq=[]
        target_output=table["sols"]
        sets=table["sets"]
        for i in range(len(target_output)):
            extra=target_output[i]-bias[i]
            list_eq.append(get_i_weights(weights,i)+[extra])

        system=sp.Matrix(list_eq)
        if system.rows>=system.cols:
            if len(system.free_symbols)>system.rows-system.cols:
                targ,targ_sol=process_system(system)
                for symb in targ_sol:
                    sets[symb]["equal"]=targ_sol[symb]
                new_sd.add_new(targ,sets)
            else:
                new_sd.add_new(process_system_eps(system),sets)
        else:
            new_sd.add_new(make_target(linsolve(system)),sets)
    sd=new_sd.copy()
    return sd


def make_target(finiteset):
    """
    Convert finiteset in list
    """
    res=[]
    for obj in finiteset.args[0]:
        res.append(obj)
    return res

def reverse_network(model):
    """
    Give the preimage of model
    """
    n=model.layers[-1].output.shape[1]
    target_output=[]
    for i in range(n):
        target_output.append(sp.Symbol("y"+str(i)))
    list_layers=model.layers
    target=target_output
    table={"sols":target_output,"sets":{}}
    full_table=Set_dict()
    return reverse_layers(list_layers,table,len(list_layers)-1,full_table)

def reverse_layers(list_layers,table,i,full_table):
    """
    Iteratively inverts layers in list_layers
    """
    new_sd=Set_dict()
    temp_sd=reverse_layer(list_layers[i],table)
    new_sd.merge_with(temp_sd)
    
    if i==0:
        full_table.merge_with(new_sd)
    else:
        for new_table in new_sd.table:
            reverse_layers(list_layers,new_table,i-1,full_table)
        if i==len(list_layers)-1:
            return full_table
        
def standard_sets(symbols):
    res={}
    for symbol in symbols:
        res[symbol]={"lower_than":sp.oo, "greater_than":-sp.oo}
    return res
        
def process_inverse(fa_name,target_output,sets):
    """
    Choose the correct method to invert the layer according to the activation
    """
    sd=Set_dict()
    symbols=get_free_symb(target_output,["tau","y"])
    fa_inverse=fa_name+'_reverse'
    if fa_inverse=='linear_reverse':
        sets=standard_sets(symbols)
        sd.add_new(target_output,sets)
    if fa_inverse=='leaky_relu_reverse':
        container=target_output
        cases=set_cases(container,[identity,minus_identity])
        temp_target=set_cases(container,[ilr_case1,ilr_case2])
        for i in range(len(cases)):
            #new_sets=solve_linear_inequalities(cases[i],sets)
          
            new_sets=solve_linear_inequalities(cases[i],symbols)
            sd.add_new(temp_target[i],new_sets)
    return sd

def char_part(symb):
    """
    Return the name of the symbol "symb" without it index
    """
    res=''
    for char in str(symb):
        if char.isnumeric():
            return res
        else:
            res+=char
    return res

def get_free_symb(eqs,names):
    """
    Return free variables of the equation "eq" of name in "name"
    """
    res=set()
    for eq in eqs:
        symb_dict=eq.free_symbols
        for symb in symb_dict:
            if char_part(symb) in names:
                res.add(symb)
    return res

# Solving linear inequalities systems

In [97]:
        
def _find_pivot(inequalities,vars):
    """
    Return a variable that has at least two coefficients with opposite
    sign in a system of inequalities.

    Examples
    ========

    >>> eq1 = 2*x - 3*y + z + 1
    >>> eq2 = x - y + 2*z - 2
    >>> eq3 = x + y + 3*z + 4
    >>> eq4 = x - z

    >>> inequalities = [eq1, eq2, eq3, eq4]
    >>> vars={x,y,z}
    >>> _find_pivot(inequalities,symbols)
    y
    """
    memory = {}
    for eq in inequalities:
        symbols=vars.intersection(eq.free_symbols)
        for symbol in symbols:
            if not (symbol in memory.keys()):
                memory[symbol] = [False, False]
            coeff = eq.coeff(symbol)
            if coeff > 0:
                memory[symbol][0] = True
            else:
                memory[symbol][1] = True
            if memory[symbol] == [True, True]:
                return symbol


def _split_min_max(inequalities, pivot):
    """return expressions that are less than or greater than the pivot
    (have a coefficient on the pivot that is negative or positive).
    Inequalities that do not contain a pivot are returned as a list.

    Examples
    ========

    >>> eq1 = 2*x - 3*y + z + 1
    >>> eq2 = x - y + 2*z - 2
    >>> eq3 = x + y + 3*z + 4
    >>> eq4 = x - z

    >>> inequalities = [eq1, eq2, eq3, eq4]
    >>> pivot = y
    >>> _split_min_max(inequalities, pivot)
    (Min(2*x/3 + z/3 + 1/3, x + 2*z - 2), -x - 3*z - 4, [x - z])
    """

    greater_than = []
    lower_than = []
    extra = []
    for eq in inequalities:
        coeff = eq.coeff(pivot)
        if coeff > 0:
            greater_than.append(-(eq - (pivot * coeff)) / coeff)
        elif coeff < 0:
            lower_than.append(-(eq - (pivot * coeff)) / coeff)
        else:
            extra.append(eq)
    return sp.Min(*lower_than), sp.Max(*greater_than), extra


def _merge_mins_maxs(mins, maxs,symbols):
    """Build the system of inequalities which verify that all equations
    of maxs are greater than those of mins.

    Examples
    ========
    
    >>> maxs = -x - 3*z - 4
    >>> mins = Min((2*x + z + 1)/3, x + 2*z - 2)
    >>> symbols = {x,z}
    >>> _merge_mins_maxs(mins, maxs)
    [2*x + 5*z + 2, 5*x/3 + 10*z/3 + 13/3]
    """
    if not isinstance(mins, sp.Min):
        mins = [mins]
    else:
        mins = mins.args

    if not isinstance(maxs, sp.Max):
        maxs = [maxs]
    else:
        maxs = maxs.args
    return [i - j for i in mins for j in maxs]


def _fourier_motzkin(inequalities,symbols):
    """Eliminate variables of system of linear inequalities by using
    Fourier-Motzkin elimination algorithm

    Examples
    ========

    >>> eq1 = 2*x - 3*y + z + 1
    >>> eq2 = x - y + 2*z - 2
    >>> eq3 = x + y + 3*z + 4
    >>> eq4 = x - z
    >>> symbols = {x,y,z}

    >>> ie, d = _fourier_motzkin([eq1, eq2, eq3, eq4])
    >>> ie
    [3*x/2 + 13/10, 7*x/5 + 2/5]
    >>> assert set(d) == set([y, z])
    >>> d[y]
    (Min(2*x/3 + z/3 + 1/3, x + 2*z - 2) > y, y > -x - 3*z - 4)
    >>> d[z]
    (x > z, z > Max(-x/2 - 13/10, -2*x/5 - 2/5))
    """
    pivot = _find_pivot(inequalities,symbols)
    res = {}
    while pivot != None:
        mins, maxs, extra = _split_min_max(inequalities, pivot)
        res[pivot] = {"lower_than":mins, "greater_than":maxs}
        inequalities = _merge_mins_maxs(mins, maxs,symbols) + extra
        pivot = _find_pivot(inequalities,symbols)
    return inequalities, res


def _pick_var(inequalities,vars):
    """Return a free variable of the system of inequalities

    Examples
    ========

    >>> eq1 = 2*x - 3*y + z + 1
    >>> eq2 = x - y + 2*z - 2
    >>> eq3 = x + y + 3*z + 4
    >>> eq4 = x - z
    >>> vars={x,y,z}

    >>> inequalities = [eq1, eq2, eq3, eq4]
    >>> _pick_var(inequalities)
    x
    """
    for eq in inequalities:  # should already be in canonical order
        symbols=vars.intersection(eq.free_symbols)
        for symb in symbols:
            return symb


def _fourier_motzkin_extension(inequalities,symbols):
    """Extension of the Fourier-Motzkin algorithm to the case where
    inequalities do not contain variables that have at least two
    coefficients with opposite sign.

    Examples
    ========

    >>> eq1 = 2*x - 3*y + z + 1
    >>> eq2 = x - y + 2*z - 2
    >>> eq3 = x - y + 3*z + 4
    >>> eq4 = x + z
    >>> symbols = {x,y,z}

    >>> d = _fourier_motzkin_extension([eq1, eq2, eq3, eq4])
    >>> assert set(d) == {x}
    >>> d[x]
    (oo > x, x > Max(-z, y - 3*z - 4, y - 2*z + 2, 3*y/2 - z/2 - 1/2))
    >>> _fourier_motzkin_extension([x - 3, 5 - x])
    {x: (5 > x, x > 3)}
    """

    res = {}
    pivot = _pick_var(inequalities,symbols)
    while pivot and inequalities:
        mins, maxs, extra = _split_min_max(inequalities, pivot)
        res[pivot] = {"lower_than":mins, "greater_than":maxs}
        inequalities = extra
        pivot = _pick_var(inequalities,symbols)
    return res




def solve_linear_inequalities(eqs,symbols):
    """Solve a system of linear inequalities

    Parameters
    ==========

    inequalities: list of sympy equations
        The system of inequalities to solve. All equations in the list
        are assumed to be linear and greater than 0. The system must
        be expressed as follows:

        2x - 3y +  z + 1 > 0
        x  -  y + 2z - 2 > 0
        x  +  y + 3z + 4 > 0
        x       -  z     > 0

    Examples
    ========

    >>> eq1 = 2*x - 3*y + z + 1
    >>> eq2 = x - y + 2*z - 2
    >>> eq3 = x + y + 3*z + 4
    >>> eq4 = x - z

    >>> symbols = {x,y,z}

    >>> d = solve_linear_inequalities([eq1, eq2, eq3, eq4])
    >>> assert set(d) == set([x, y, z])
    >>> d[x]
    (oo > x, x > -2/7)
    >>> d[y]
    (Min(x + 1/3, 3*x - 2) > y, y > -4*x - 4)
    >>> d[z]
    (x > z, z > Max(-2*x + 3*y - 1, -x/2 + y/2 + 1, -x/3 - y/3 - 4/3))

    Explanation
    ===========

    x = 2 is valid because: oo > 2 > -2/7
    y = -1 is valid because: Min(x + 1/3, 3*x - 2) > -1 > -4*x - 4
    z = 1 is valid because: x > 1 > Max(-2*x + 3*y - 1, -x/2 + y/2 + 1, -x/3 - y/3 - 4/3)
    """
    eqs, res1 = _fourier_motzkin(eqs,symbols)
    res2 = _fourier_motzkin_extension(eqs,symbols)
    return {**res1, **res2}    

# Tests

In [98]:
model=tf.keras.Sequential()
model.add(tf.keras.Input(shape=(1,)))
model.add(tf.keras.layers.Dense(2,activation=tf.nn.leaky_relu))
model.add(tf.keras.layers.Dense(1,activation="linear"))
model.compile(optimizer="Adam",loss="mse",metrics='mse')
model.summary()

Model: "sequential_15"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_32 (Dense)             (None, 2)                 4         
_________________________________________________________________
dense_33 (Dense)             (None, 1)                 3         
Total params: 7
Trainable params: 7
Non-trainable params: 0
_________________________________________________________________


In [99]:
res=reverse_network(model)
res.table

[{'sols': [15.0539522567302*y0],
  'sets': {tau0: {'lower_than': Min(0, -0.826757*y0),
    'greater_than': -oo,
    'equal': 1.25171072763636*y0}}},
 {'sols': [-2.97744188077888*y0],
  'sets': {tau0: {'lower_than': -0.826757*y0,
    'greater_than': 0,
    'equal': -1.23784634079017*y0},
   y0: {'lower_than': 0, 'greater_than': -oo}}},
 {'sols': [1.36160916902608*y0],
  'sets': {tau0: {'lower_than': 0,
    'greater_than': -0.826757*y0,
    'equal': 0.113215501999231*y0},
   y0: {'lower_than': oo, 'greater_than': 0}}},
 {'sols': [3.01079269359522*y0],
  'sets': {tau0: {'lower_than': oo,
    'greater_than': Max(0, -0.826757*y0),
    'equal': 1.25171203604075*y0}}}]

In [101]:
y0=1
x0=3.01079269359522*y0
print(model(np.asarray([[x0]])))

tf.Tensor([[1.0000005]], shape=(1, 1), dtype=float32)



Among the 4 given solutions above, only one would be valid according to the chosen y0.
In this example, y0 would be equal to 1.
The solutions have constraints given in "sets" : 

    {'sols': [3.01079269359522*y0],
    'sets': {tau0: {'lower_than': oo,
    'greater_than': Max(0, -0.826757*y0),
    'equal': 1.25171203604075*y0}}}
    
tau0 is the constraint and is equal to 1.25171203604075*y0 = 1.25171203604075 in this example since y0=1.
tau0 must be lower than oo (infinity) so this condition is valid. It also must be greater than the maximum of
0 and -0.826757*y0 which is 0. This second condition is valid, so the right solution must be given by:

    'sols': [3.01079269359522*y0]

    
Let's try this:

y0=1
x0=3.01079269359522*y0
print(model(np.asarray([[x0]])))

Output: tf.Tensor([[1.0000005]], shape=(1, 1), dtype=float32) 

It works!

You will see that with y0=1, other constraints will be False.

Now you can try yourself with other random neural networks.