-
Notifications
You must be signed in to change notification settings - Fork 1
/
initialization_scaling_with_aggregation_size.py
69 lines (53 loc) · 2.44 KB
/
initialization_scaling_with_aggregation_size.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import matplotlib.pyplot as plt
import tikzplotlib
def mean_norm(sequence):
    """Average Euclidean norm of the vectors along the last axis of `sequence`."""
    vector_norms = np.linalg.norm(sequence, axis=-1)
    return vector_norms.mean()
def softmax(logs):
    """Softmax over the last axis of `logs`.

    The per-row maximum is subtracted before exponentiating so that large
    logits cannot overflow to inf/nan (the shift cancels in the ratio), and
    the exponential is computed once instead of twice.
    """
    exps = np.exp(logs - np.max(logs, axis=-1, keepdims=True))
    return exps / np.sum(exps, axis=-1, keepdims=True)
def normalize(logs, axis=-1, epsilon=0.001):
    """Standardize `logs` along `axis` to roughly zero mean / unit std.

    `epsilon` keeps the division finite when the spread is zero. A trailing
    axis of size one carries no variance, so those inputs map to all-ones.
    """
    if logs.shape[-1] == 1:
        return np.ones_like(logs)
    centered = logs - logs.mean(axis=axis, keepdims=True)
    spread = logs.std(axis=axis, keepdims=True) + epsilon
    return centered / spread
def plot(init_samples=1024*16, max_sequence_length_exp=12, model_dim=128):
    """Measure how several aggregation schemes scale with sequence length.

    For each length 2**0 .. 2**(max_sequence_length_exp - 1), draws random
    normal keys/queries/values, aggregates the values with each scheme, and
    records the output's standard deviation and mean vector norm. Produces
    two log-log figures and saves them as 'std_scaling.tex' and
    'norm_scaling.tex' via tikzplotlib.

    Args:
        init_samples: total per-length token budget; the batch size is
            init_samples // sequence_length so work stays roughly constant.
        max_sequence_length_exp: number of powers of two to sweep.
        model_dim: dimensionality of the key/query/value vectors.
    """
    methods = ['softmax attention', 'mean pooling', 'sum pooling',
               'max pooling', 'normalized']
    std_results = {method: [] for method in methods}
    norm_results = {method: [] for method in methods}
    sequence_lengths = [2 ** exp for exp in range(max_sequence_length_exp)]
    for sequence_length in sequence_lengths:
        # Shrink the batch as sequences grow to keep total sample count fixed.
        samples = init_samples // sequence_length
        values = np.random.normal(size=[samples, sequence_length, model_dim])
        keys = np.random.normal(size=[samples, sequence_length, model_dim])
        # Fixed typo: 'querries' -> 'queries' (local name only).
        queries = np.random.normal(size=[samples, model_dim, sequence_length])
        # Scaled dot-product logits, as in standard attention.
        logits = np.matmul(keys, queries) / np.sqrt(model_dim)
        attention = softmax(logits)
        out = {'softmax attention': np.matmul(attention, values),
               'mean pooling': np.mean(values, axis=1, keepdims=True),
               'sum pooling': np.sum(values, axis=1, keepdims=True),
               'max pooling': np.max(values, axis=1, keepdims=True),
               'normalized': normalize(np.sum(values, axis=1, keepdims=True))}
        for method in methods:
            std_results[method].append(np.std(out[method]))
            norm_results[method].append(mean_norm(out[method]))
    plt.figure('Standard Deviation')
    for result in std_results.values():
        plt.loglog(sequence_lengths, result)
    plt.ylabel('Standard Deviation')
    plt.xlabel('Sequence Length')
    # Bug fix: this figure was saved without a legend, unlike the norm figure,
    # leaving its TikZ output with unlabeled lines.
    plt.legend(std_results.keys(), loc='lower left')
    tikzplotlib.save('std_scaling.tex')
    plt.figure('Norm')
    for result in norm_results.values():
        plt.loglog(sequence_lengths, result)
    plt.ylabel('Norm')
    plt.xlabel('Sequence Length')
    plt.legend(norm_results.keys(), loc='lower left')
    tikzplotlib.save('norm_scaling.tex')
# Guard the script entry so importing this module for its helpers does not
# trigger the (slow, file-writing) experiment.
if __name__ == '__main__':
    plot()
    plt.show()