In [1]:
from dataclasses import dataclass
import numpy as np
import numba as nb
from numba import njit, prange
from tqdm.auto import tqdm
from numba.typed import List, Dict
from scipy.special import factorial
from prettytable import PrettyTable 

In [2]:
# Reload edited project modules (the src.* imports below) before each cell
# executes, so changes to the bound implementations are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

## Layer Norm

In [3]:
from src.blocks.LayerNorm import der_bound_variance, der_bound_g, der_bounds_g, \
    der_bounds_variance, der_bounds_g_circ_variance

In [4]:
from src.blocks.LayerNorm import der_bounds_layer_norm
from src.PrintUtil import pretty_results, pretty_results_latex, pretty_max_values, pretty_max_values_latex, get_max_vals

In [28]:
# Notebook-wide switch forwarded as the `by_level` argument of the
# der_bounds_* calls in the Layer Norm and Neural Network sections.
# Judging from the printed tables, True appears to report one bound per
# derivative order (all multi-indices of the same total order share a value),
# while False gives each multi-index its own bound — TODO confirm in src.blocks.
# NOTE: the Attention section reassigns this global to True, and later cells
# inherit that value.
by_level = False

In [16]:
n, k = 5, 10
# Bound tables for g, the variance, "g ° var" and the full layer norm, in that
# order. Every entry except g takes the same trailing (k, 10.0, 1.0) arguments.
results = {"g": der_bounds_g(n, by_level)}
results["var"] = der_bounds_variance(n, k, 10.0, 1.0, by_level)
results["g ° var"] = der_bounds_g_circ_variance(n, k, 10.0, 1.0, by_level)
results["layer norm"] = der_bounds_layer_norm(n, k, 10.0, 1.0, by_level)
print(pretty_results(n, k, dict(results)))

+-------------------------+-------+-------+----------+------------+
| Type                    | g     | var   | g ° var  | layer norm |
+-------------------------+-------+-------+----------+------------+
| (1, 0, ...)             | 0.75  | 20.00 | 32.17    | 321.71     |
| (1, 1, 0, ...)          | NAN   | 20.00 | 782.17   | 7.89E+03   |
| (2, 0, ...)             | 0.94  | 20.00 | 782.17   | 7.89E+03   |
| (1, 1, 1, 0, ...)       | NAN   | 20.00 | 7.99E+04 | 8.01E+05   |
| (2, 1, 0, ...)          | NAN   | 20.00 | 7.99E+04 | 8.01E+05   |
| (3, 0, ...)             | 1.64  | 20.00 | 7.99E+04 | 8.01E+05   |
| (1, 1, 1, 1, 0, ...)    | NAN   | 20.00 | 1.43E+07 | 1.44E+08   |
| (2, 1, 1, 0, ...)       | NAN   | 20.00 | 1.43E+07 | 1.44E+08   |
| (2, 2, 0, ...)          | NAN   | 20.00 | 1.43E+07 | 1.44E+08   |
| (3, 1, 0, ...)          | NAN   | 20.00 | 1.43E+07 | 1.44E+08   |
| (4, 0, ...)             | 3.69  | 20.00 | 1.43E+07 | 1.44E+08   |
| (1, 1, 1, 1, 1, 0, ...) | NAN   | 20.00 | 4.49

In [10]:
n = 5
# One-factor-at-a-time sensitivity study of der_bounds_layer_norm.
# Each label maps to (second positional arg "a" = k, third positional arg "b").
# "a" variants change only k and "b" variants change only b; everything else
# stays at the "base" setting (k=10, b=10.0, last argument 1.0).
layer_norm_variants = {
    "base": (10, 10.0),
    "a-": (5, 10.0),
    "a+": (20, 10.0),
    # Previously used k=20, which changed two factors at once; the "b" variants
    # are meant to keep the base k.
    "b--": (10, 0.1),
    "b-": (10, 1.0),
    "b+": (10, 100.0),
    "b++": (10, 1000.0),
}
results = {
    label: der_bounds_layer_norm(n, k_var, b_var, 1.0)
    for label, (k_var, b_var) in layer_norm_variants.items()
}
# Render rows for the smallest k in the study (5), as before, so every column
# has an entry for each displayed multi-index — presumably pretty_results
# enumerates multi-indices of length k; TODO confirm.
display_k = min(k_var for k_var, _ in layer_norm_variants.values())
print(pretty_results(n, display_k, dict(results)))

+-------------------+----------+----------+----------+--------+----------+----------+----------+
| Type              | base     | a-       | a+       | b--    | b-       | b+       | b++      |
+-------------------+----------+----------+----------+--------+----------+----------+----------+
| (1, 0, ...)       | 321.71   | 321.71   | 321.71   | 1.73   | 18.67    | 1.67E+04 | 1.52E+06 |
| (1, 1, 0, ...)    | 7.75E+03 | 7.75E+03 | 7.75E+03 | 36.52  | 63.51    | 7.50E+06 | 7.50E+09 |
| (2, 0, ...)       | 5.87E+03 | 5.87E+03 | 5.87E+03 | 36.44  | 60.89    | 5.63E+06 | 5.63E+09 |
| (1, 1, 1, 0, ...) | 7.91E+05 | 7.91E+05 | 7.91E+05 | 58.08  | 185.68   | 7.88E+09 | 7.88E+13 |
| (2, 1, 0, ...)    | 5.28E+05 | 5.28E+05 | 5.28E+05 | 57.25  | 151.18   | 5.25E+09 | 5.25E+13 |
| (3, 0, ...)       | 2.86E+05 | 2.86E+05 | 2.86E+05 | 55.67  | 118.12   | 2.84E+09 | 2.84E+13 |
| (1, 1, 1, 1, 0)   | 1.42E+08 | 1.42E+08 | 1.42E+08 | 77.64  | 2.04E+03 | 1.42E+13 | 1.42E+18 |
| (2, 1, 1, 0, ...) | 8.89E+07

## Neural Network

In [18]:
from src.blocks.NeuralNet import der_bounds_neural_network

In [29]:
%%time
n = 6
k = 10
softplus = der_bounds_neural_network(n, k, 64, 1., 1.0, "softplus", by_level)
tanh = der_bounds_neural_network(n, k, 64, 1., 1.0, "tanh", by_level)
GeLU = der_bounds_neural_network(n, k, 64, 1., 1.0, "GeLU", by_level)
SWISH = der_bounds_neural_network(n, k, 64, 1., 1.0, "SWISH", by_level)
results = dict(softplus=softplus, tanh=tanh, GeLU=GeLU, SWISH=SWISH)
print(pretty_results(n, k, results))

+----------------------------+----------+----------+----------+----------+
| Type                       | softplus | tanh     | GeLU     | SWISH    |
+----------------------------+----------+----------+----------+----------+
| (1, 0, ...)                | 17.00    | 129.00   | 73.25    | 71.39    |
| (1, 1, 0, ...)             | 22.16    | 384.00   | 103.22   | 102.39   |
| (2, 0, ...)                | 50.47    | 1.02E+03 | 237.41   | 236.78   |
| (1, 1, 1, 0, ...)          | 30.16    | 896.00   | 151.46   | 122.11   |
| (2, 1, 0, ...)             | 76.31    | 2.82E+03 | 399.40   | 283.67   |
| (3, 0, ...)                | 236.94   | 8.96E+03 | 1.25E+03 | 870.75   |
| (1, 1, 1, 1, 0, ...)       | 38.33    | 1.92E+03 | 257.73   | 154.11   |
| (2, 1, 1, 0, ...)          | 125.68   | 9.98E+03 | 1.15E+03 | 500.23   |
| (2, 2, 0, ...)             | 267.71   | 2.20E+04 | 2.52E+03 | 1.06E+03 |
| (3, 1, 0, ...)             | 393.39   | 3.20E+04 | 3.67E+03 | 1.56E+03 |
| (4, 0, ...)            

In [30]:
# Summary of the neural-network bounds: one row per activation, one column per
# derivative order 1..n. Judging from the output, each entry is the largest
# bound of that order in the full table above.
# Relies on `n`, `k` and `results` from the preceding cell — run that cell first.
print(pretty_max_values(n, k, results))

+----------+--------+----------+----------+----------+----------+----------+
|          | 1      | 2        | 3        | 4        | 5        | 6        |
+----------+--------+----------+----------+----------+----------+----------+
| softplus | 17.00  | 50.47    | 236.94   | 1.34E+03 | 1.33E+04 | 1.61E+05 |
| tanh     | 129.00 | 1.02E+03 | 8.96E+03 | 9.83E+04 | 1.34E+06 | 4.98E+07 |
| GeLU     | 73.25  | 237.41   | 1.25E+03 | 1.16E+04 | 1.81E+05 | 4.27E+06 |
| SWISH    | 71.39  | 236.78   | 870.75   | 5.33E+03 | 4.13E+04 | 5.91E+05 |
+----------+--------+----------+----------+----------+----------+----------+


In [39]:
from src.blocks.NeuralNet import der_bound_activation

## Attention

In [32]:
from src.blocks.Attention import der_bounds_dp, der_bounds_softmax, der_bounds_softmax_circ_dp, der_bounds_attention, der_bounds_multi_head_attention

In [37]:
# NOTE: this reassigns the notebook-wide `by_level` flag; cells below this one
# (including the Transformer Block section) inherit True from here.
by_level = True
n = 5
input_dim = 3
sequence_length = 5
k = input_dim * sequence_length
ndim_k = 3

# Keyword arguments shared by every weighted attention bound below.
shared_weight_kwargs = dict(max_weight_k=1.0, max_weight_q=1.0, domain_bound=1.0, by_level=by_level)

results = {}
results["dp"] = der_bounds_dp(n, k, ndim_k, **shared_weight_kwargs)
results["Softmax"] = der_bounds_softmax(n, k)
results["Softmax°dp"] = der_bounds_softmax_circ_dp(
    n, input_dim, sequence_length, ndim_k, **shared_weight_kwargs
)
results["Attention"] = der_bounds_attention(
    n, input_dim, sequence_length, ndim_k, max_weight_v=1.0, **shared_weight_kwargs
)
results["Multi-Head Attention"] = der_bounds_multi_head_attention(
    n, input_dim, sequence_length, ndim_k,
    ndim_v=ndim_k, max_weight_w=1.0, max_weight_v=1.0, **shared_weight_kwargs
)
print(pretty_results(n, k, results))

+-------------------------+-------+---------+------------+-----------+----------------------+
| Type                    | dp    | Softmax | Softmax°dp | Attention | Multi-Head Attention |
+-------------------------+-------+---------+------------+-----------+----------------------+
| (1, 0, ...)             | 90.00 | 1.00    | 5.35E+03   | 8.03E+04  | 2.41E+05             |
| (1, 1, 0, ...)          | 90.00 | 2.00    | 2.75E+05   | 4.29E+06  | 2.57E+07             |
| (2, 0, ...)             | 90.00 | 2.00    | 2.75E+05   | 4.29E+06  | 2.57E+07             |
| (1, 1, 1, 0, ...)       | 90.00 | 6.00    | 1.09E+09   | 1.64E+10  | 2.96E+11             |
| (2, 1, 0, ...)          | 90.00 | 6.00    | 1.09E+09   | 1.64E+10  | 2.96E+11             |
| (3, 0, ...)             | 90.00 | 6.00    | 1.09E+09   | 1.64E+10  | 2.96E+11             |
| (1, 1, 1, 1, 0, ...)    | 90.00 | 24.00   | 6.36E+12   | 9.55E+13  | 6.88E+15             |
| (2, 1, 1, 0, ...)       | 90.00 | 24.00   | 6.36E+12   | 9

In [45]:
from src.blocks.TransfomerBlock import der_bounds_tblock

## Transformer Block

In [46]:
n = 5
input_dim = 5
output_dim = 1
sequence_length = 1
k = input_dim * sequence_length
ndim_k = 3
# Maximal weight magnitudes:
# - mw_activ: the attention projections (k, q, v, w)
# - mw_nn: max_weight_b
# - mw_ln: the layer norms and max_weight_a
mw_activ = .01
mw_nn = .001
mw_ln = .01
# Set explicitly instead of silently inheriting the global `by_level` that the
# Attention section leaves at True. True reproduces the Restart & Run All result.
tblock_by_level = True

# verbose=2 prints the intermediate per-component tables (attention, LN1, NN, LN2).
result, all_res = der_bounds_tblock(
    n, input_dim, output_dim, sequence_length, ndim_k, ndim_v=ndim_k, ndim_nn=64,
    max_weight_k=mw_activ, max_weight_q=mw_activ, max_weight_v=mw_activ, max_weight_w=mw_activ,
    # NOTE(review): LN1 gets a 10x smaller bound than LN2, and max_weight_a uses
    # mw_ln while max_weight_b uses mw_nn — confirm both are intended.
    max_weight_ln1=mw_ln/10, max_weight_a=mw_ln, max_weight_b=mw_nn, max_weight_ln2=mw_ln,
    domain_bound=1.0, verbose=2, by_level=tblock_by_level,
)
print(pretty_max_values(n, k, dict(tblock=result)))

Computing Bound for Multi-Head Attention
+--------+------+------+------+------+-------+
|        | 1    | 2    | 3    | 4    | 5     |
+--------+------+------+------+------+-------+
| tblock | 0.08 | 0.46 | 1.84 | 9.18 | 55.08 |
+--------+------+------+------+------+-------+
Computing Bound for Layer Norm 1
+--------+------+------+------+------+------+
|        | 1    | 2    | 3    | 4    | 5    |
+--------+------+------+------+------+------+
| tblock | 0.02 | 0.05 | 0.07 | 0.09 | 0.10 |
+--------+------+------+------+------+------+
Computing Bound for Neural Net
+--------+------+------+------+------+------+
|        | 1    | 2    | 3    | 4    | 5    |
+--------+------+------+------+------+------+
| tblock | 0.02 | 0.06 | 0.29 | 1.54 | 9.60 |
+--------+------+------+------+------+------+
Computing Bound for Layer Norm 2
+--------+------+------+------+------+------+
|        | 1    | 2    | 3    | 4    | 5    |
+--------+------+------+------+------+------+
| tblock | 0.17 | 0.52 | 0.69