In [1]:
from dataclasses import dataclass
import numpy as np
import numba as nb
from numba import njit, prange
from tqdm.auto import tqdm
from numba.typed import List, Dict
from scipy.special import factorial
from prettytable import PrettyTable 

In [2]:
# Re-import edited modules (e.g. the src.blocks.* modules imported below) before each
# cell executes, so changes to the bound implementations are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

## Layer Norm

In [3]:
from src.blocks.LayerNorm import der_bound_variance, der_bound_g, der_bounds_g, \
    der_bounds_variance, der_bounds_g_circ_variance

In [8]:
from src.blocks.LayerNorm import der_bounds_layer_norm
from src.PrintUtil import pretty_results

In [79]:
# Tabulate derivative bounds for LayerNorm and the pieces it is composed of.
n = 5  # rows in the printed table go up to derivative order n
k = 10
# The same two trailing positional arguments are fed to every bound that accepts them;
# their meaning is defined in src.blocks.LayerNorm (TODO: document here).
shared_args = (10.0, 1.0)

results = {}
results["g"] = der_bounds_g(n)
results["var"] = der_bounds_variance(n, k, *shared_args)
results["g ° var"] = der_bounds_g_circ_variance(n, k, *shared_args)
results["layer norm"] = der_bounds_layer_norm(n, k, *shared_args)

# Hand pretty_results a shallow copy so it cannot alter `results`.
table_input = results.copy()
print(pretty_results(n, k, table_input))

+-------------------------+-------+-------+----------+------------+
| Type                    | g     | var   | g ° var  | layer norm |
+-------------------------+-------+-------+----------+------------+
| (1, 0, ...)             | 0.75  | 20.00 | 32.17    | 321.71     |
| (1, 1, 0, ...)          | NAN   | 2.00  | 768.67   | 64.34      |
| (2, 0, ...)             | 0.94  | 2.00  | 580.42   | 64.34      |
| (1, 1, 1, 0, ...)       | NAN   | 0.00  | 7.89E+04 | 2.31E+03   |
| (2, 1, 0, ...)          | NAN   | 0.00  | 5.26E+04 | 2.12E+03   |
| (3, 0, ...)             | 1.64  | 0.00  | 2.85E+04 | 1.74E+03   |
| (1, 1, 1, 1, 0, ...)    | NAN   | 0.00  | 1.42E+07 | 3.16E+05   |
| (2, 1, 1, 0, ...)       | NAN   | 0.00  | 8.87E+06 | 2.63E+05   |
| (2, 2, 0, ...)          | NAN   | 0.00  | 5.62E+06 | 2.10E+05   |
| (3, 1, 0, ...)          | NAN   | 0.00  | 4.33E+06 | 1.86E+05   |
| (4, 0, ...)             | 3.69  | 0.00  | 1.85E+06 | 1.14E+05   |
| (1, 1, 1, 1, 1, 0, ...) | NAN   | 0.00  | 3.55

## Neural Network

In [11]:
from src.blocks.NeuralNet import der_bounds_neural_network

In [78]:
%%time
n = 6
k = 10
results = der_bounds_neural_network(n, k, 64, 1.0, 1.0)
print(pretty_results(n, k, dict(Bounds=results)))

+----------------------------+----------+
| Type                       | Bounds   |
+----------------------------+----------+
| (1, 0, ...)                | 65.00    |
| (1, 1, 0, ...)             | 128.00   |
| (2, 0, ...)                | 320.00   |
| (1, 1, 1, 0, ...)          | 192.00   |
| (2, 1, 0, ...)             | 512.00   |
| (3, 0, ...)                | 1.60E+03 |
| (1, 1, 1, 1, 0, ...)       | 256.00   |
| (2, 1, 1, 0, ...)          | 896.00   |
| (2, 2, 0, ...)             | 1.92E+03 |
| (3, 1, 0, ...)             | 2.82E+03 |
| (4, 0, ...)                | 9.41E+03 |
| (1, 1, 1, 1, 1, 0, ...)    | 320.00   |
| (2, 1, 1, 1, 0, ...)       | 2.18E+03 |
| (2, 2, 1, 0, ...)          | 4.74E+03 |
| (3, 1, 1, 0, ...)          | 6.91E+03 |
| (3, 2, 0, ...)             | 1.16E+04 |
| (4, 1, 0, ...)             | 1.97E+04 |
| (5, 0, ...)                | 6.53E+04 |
| (1, 1, 1, 1, 1, 1, 0, ...) | 384.00   |
| (2, 1, 1, 1, 1, 0, ...)    | 8.45E+03 |
| (2, 2, 1, 1, 0, ...)       | 1.8

## Attention

In [81]:
from src.blocks.Attention import der_bounds_dp, der_bounds_softmax, der_bounds_softmax_circ_dp, der_bounds_attention, der_bounds_multi_head_attention

In [83]:
# Derivative bounds for the attention building blocks, from the dot product up to
# multi-head attention.
n = 5
input_dim = 64
sequence_length = 40
ndim_k = 16
# Presumably the flattened input size. It is passed to the flat-input bounds (dp, Softmax)
# and is also the `k` shown in the table, so both must use this one definition.
k = input_dim * sequence_length

# Query/key weight and input-domain bounds shared by every attention variant below.
qk_bounds = dict(max_weight_k=1.0, max_weight_q=1.0, domain_bound=1.0)

results = {
    "dp": der_bounds_dp(n, k, ndim_k, **qk_bounds),
    "Softmax": der_bounds_softmax(n, k),
    "Softmax°dp": der_bounds_softmax_circ_dp(
        n, input_dim, sequence_length, ndim_k, **qk_bounds
    ),
    "Attention": der_bounds_attention(
        n, input_dim, sequence_length, ndim_k, max_weight_v=1.0, **qk_bounds
    ),
    "Multi-Head Attention": der_bounds_multi_head_attention(
        n, input_dim, sequence_length, ndim_k,
        ndim_v=ndim_k, max_weight_w=1.0, max_weight_v=1.0, **qk_bounds,
    ),
}
print(pretty_results(n, k, results))

+-------------------------+----------+---------+------------+-----------+----------------------+
| Type                    | dp       | Softmax | Softmax°dp | Attention | Multi-Head Attention |
+-------------------------+----------+---------+------------+-----------+----------------------+
| (1, 0, ...)             | 8.19E+04 | 1.00    | 1.25E+08   | 3.20E+11  | 5.12E+12             |
| (1, 1, 0, ...)          | 32.00    | 2.00    | 4.19E+13   | 1.07E+17  | 1.72E+18             |
| (2, 0, ...)             | 32.00    | 2.00    | 3.14E+13   | 8.04E+16  | 2.57E+18             |
| (1, 1, 1, 0, ...)       | 0.00     | 6.00    | 1.18E+21   | 3.01E+24  | 4.82E+25             |
| (2, 1, 0, ...)          | 0.00     | 6.00    | 7.84E+20   | 2.01E+24  | 6.42E+25             |
| (3, 0, ...)             | 0.00     | 6.00    | 4.25E+20   | 1.09E+24  | 1.04E+26             |
| (1, 1, 1, 1, 0, ...)    | 0.00     | 24.00   | 6.06E+28   | 1.55E+32  | 2.48E+33             |
| (2, 1, 1, 0, ...)       | 0.