In [1]:
import numpy as np
import torch
import pytorch_model_summary

from models.transformer import get_transformer_encoder
from models.anomaly_transformer import get_anomaly_transformer

  from .autonotebook import tqdm as notebook_tqdm


## Check Transformer Encoder (before adding relative position embedding)

In [8]:
# Baseline encoder built with positional_encoding=None; summarize it on an
# all-zero dummy batch of shape (10, 512, 512).
model = get_transformer_encoder(positional_encoding=None)
dummy_input = torch.zeros(10, 512, 512)
model_summary = pytorch_model_summary.summary(model, dummy_input, show_input=False)
print(model_summary)

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
    EncoderLayer-1      [10, 512, 512]       3,152,384       3,152,384
    EncoderLayer-2      [10, 512, 512]       3,152,384       3,152,384
    EncoderLayer-3      [10, 512, 512]       3,152,384       3,152,384
    EncoderLayer-4      [10, 512, 512]       3,152,384       3,152,384
    EncoderLayer-5      [10, 512, 512]       3,152,384       3,152,384
    EncoderLayer-6      [10, 512, 512]       3,152,384       3,152,384
Total params: 18,914,304
Trainable params: 18,914,304
Non-trainable params: 0
-----------------------------------------------------------------------


In [6]:
# Same encoder, now requested with positional_encoding='Sinusoidal';
# summarized on an all-zero dummy batch of shape (10, 512, 512).
model = get_transformer_encoder(positional_encoding='Sinusoidal')
dummy_input = torch.zeros(10, 512, 512)
model_summary = pytorch_model_summary.summary(model, dummy_input, show_input=False)
print(model_summary)

--------------------------------------------------------------------------------------
                     Layer (type)        Output Shape         Param #     Tr. Param #
   SinusoidalPositionalEncoding-1      [10, 512, 512]               0               0
                   EncoderLayer-2      [10, 512, 512]       3,152,384       3,152,384
                   EncoderLayer-3      [10, 512, 512]       3,152,384       3,152,384
                   EncoderLayer-4      [10, 512, 512]       3,152,384       3,152,384
                   EncoderLayer-5      [10, 512, 512]       3,152,384       3,152,384
                   EncoderLayer-6      [10, 512, 512]       3,152,384       3,152,384
                   EncoderLayer-7      [10, 512, 512]       3,152,384       3,152,384
Total params: 18,914,304
Trainable params: 18,914,304
Non-trainable params: 0
--------------------------------------------------------------------------------------


In [7]:
# Same encoder, now requested with positional_encoding='Absolute';
# summarized on an all-zero dummy batch of shape (10, 512, 512).
model = get_transformer_encoder(positional_encoding='Absolute')
dummy_input = torch.zeros(10, 512, 512)
model_summary = pytorch_model_summary.summary(model, dummy_input, show_input=False)
print(model_summary)

-----------------------------------------------------------------------------------
                  Layer (type)        Output Shape         Param #     Tr. Param #
   AbsolutePositionEmbedding-1      [10, 512, 512]         262,144         262,144
                EncoderLayer-2      [10, 512, 512]       3,152,384       3,152,384
                EncoderLayer-3      [10, 512, 512]       3,152,384       3,152,384
                EncoderLayer-4      [10, 512, 512]       3,152,384       3,152,384
                EncoderLayer-5      [10, 512, 512]       3,152,384       3,152,384
                EncoderLayer-6      [10, 512, 512]       3,152,384       3,152,384
                EncoderLayer-7      [10, 512, 512]       3,152,384       3,152,384
Total params: 19,176,448
Trainable params: 19,176,448
Non-trainable params: 0
-----------------------------------------------------------------------------------


## Relative Position Embedding

In [12]:
# Set position index
# coords[i, j] = i + (max_seq_len - 1 - j): constant along each diagonal and
# ranging over 0 .. 2*max_seq_len - 2.
max_seq_len = 10  # Maximum sequence length example
coords_h = np.arange(max_seq_len)           # 0, 1, ..., max_seq_len - 1
coords_w = (max_seq_len - 1) - coords_h     # max_seq_len - 1, ..., 1, 0
coords = np.add.outer(coords_h, coords_w)   # pairwise sums, shape (max_seq_len, max_seq_len)
print(coords)

[[ 9  8  7  6  5  4  3  2  1  0]
 [10  9  8  7  6  5  4  3  2  1]
 [11 10  9  8  7  6  5  4  3  2]
 [12 11 10  9  8  7  6  5  4  3]
 [13 12 11 10  9  8  7  6  5  4]
 [14 13 12 11 10  9  8  7  6  5]
 [15 14 13 12 11 10  9  8  7  6]
 [16 15 14 13 12 11 10  9  8  7]
 [17 16 15 14 13 12 11 10  9  8]
 [18 17 16 15 14 13 12 11 10  9]]


In [18]:
# Relative position embedding
# The table has 2*max_seq_len-1 rows and one column per head.
# NOTE(review): the values are drawn without a seed in this cell, so the
# printed numbers differ between runs unless a seed is set elsewhere.
n_head = 2  # Number of heads example
table_shape = (2*max_seq_len-1, n_head)
relative_position_embedding_table = torch.rand(*table_shape)
print('table :')
for row_idx in range(relative_position_embedding_table.shape[0]):
    print(row_idx, ':', relative_position_embedding_table[row_idx].numpy())

table :
0 : [0.14965618 0.16788995]
1 : [0.9504504  0.10648322]
2 : [0.3820259  0.26660728]
3 : [0.5858676 0.9569891]
4 : [0.30918825 0.55566347]
5 : [0.5504618 0.6121726]
6 : [0.42828113 0.7249437 ]
7 : [0.69184756 0.5860372 ]
8 : [0.81165934 0.7614068 ]
9 : [0.147039   0.13363111]
10 : [0.93855715 0.44990999]
11 : [0.8766833 0.7013564]
12 : [0.8035398 0.8917592]
13 : [0.7712254  0.45665962]
14 : [0.54262006 0.34000492]
15 : [0.9725495 0.4922533]
16 : [0.8967813  0.36961406]
17 : [0.82937473 0.42501265]
18 : [0.6240979  0.92276263]


In [21]:
# Gather one table row for every entry of `coords`, reshape the result to
# (max_seq_len, max_seq_len, n_head), then move the head axis to the front.
flat_index = coords.flatten()
gathered_rows = relative_position_embedding_table[flat_index]
relative_position_embedding = (
    gathered_rows
    .view(max_seq_len, max_seq_len, -1)
    .permute(2, 0, 1)
    .contiguous()
)
print('rel pos :')
print('<head 1>')
# Show the (max_seq_len, max_seq_len) matrix of the first head, one row per line.
first_head = relative_position_embedding[0]
for row_idx in range(first_head.shape[0]):
    print(row_idx, ':', first_head[row_idx].numpy())

rel pos :
<head 1>
0 : [0.147039   0.81165934 0.69184756 0.42828113 0.5504618  0.30918825
 0.5858676  0.3820259  0.9504504  0.14965618]
1 : [0.93855715 0.147039   0.81165934 0.69184756 0.42828113 0.5504618
 0.30918825 0.5858676  0.3820259  0.9504504 ]
2 : [0.8766833  0.93855715 0.147039   0.81165934 0.69184756 0.42828113
 0.5504618  0.30918825 0.5858676  0.3820259 ]
3 : [0.8035398  0.8766833  0.93855715 0.147039   0.81165934 0.69184756
 0.42828113 0.5504618  0.30918825 0.5858676 ]
4 : [0.7712254  0.8035398  0.8766833  0.93855715 0.147039   0.81165934
 0.69184756 0.42828113 0.5504618  0.30918825]
5 : [0.54262006 0.7712254  0.8035398  0.8766833  0.93855715 0.147039
 0.81165934 0.69184756 0.42828113 0.5504618 ]
6 : [0.9725495  0.54262006 0.7712254  0.8035398  0.8766833  0.93855715
 0.147039   0.81165934 0.69184756 0.42828113]
7 : [0.8967813  0.9725495  0.54262006 0.7712254  0.8035398  0.8766833
 0.93855715 0.147039   0.81165934 0.69184756]
8 : [0.82937473 0.8967813  0.9725495  0.54262006 

In [22]:
# Row-major 1-D copy of `coords`: the same index vector that is used to gather
# rows from relative_position_embedding_table.
coords.flatten()

array([ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0, 10,  9,  8,  7,  6,  5,  4,
        3,  2,  1, 11, 10,  9,  8,  7,  6,  5,  4,  3,  2, 12, 11, 10,  9,
        8,  7,  6,  5,  4,  3, 13, 12, 11, 10,  9,  8,  7,  6,  5,  4, 14,
       13, 12, 11, 10,  9,  8,  7,  6,  5, 15, 14, 13, 12, 11, 10,  9,  8,
        7,  6, 16, 15, 14, 13, 12, 11, 10,  9,  8,  7, 17, 16, 15, 14, 13,
       12, 11, 10,  9,  8, 18, 17, 16, 15, 14, 13, 12, 11, 10,  9])

In [2]:
# Check Transformer encoder with relative position embedding
# (positional_encoding=None, relative_position_embedding=True), summarized on
# an all-zero dummy batch of shape (10, 512, 512).
model = get_transformer_encoder(positional_encoding=None, relative_position_embedding=True)
dummy_input = torch.zeros(10, 512, 512)
model_summary = pytorch_model_summary.summary(model, dummy_input, show_input=False)
print(model_summary)

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
    EncoderLayer-1      [10, 512, 512]       3,160,568       3,160,568
    EncoderLayer-2      [10, 512, 512]       3,160,568       3,160,568
    EncoderLayer-3      [10, 512, 512]       3,160,568       3,160,568
    EncoderLayer-4      [10, 512, 512]       3,160,568       3,160,568
    EncoderLayer-5      [10, 512, 512]       3,160,568       3,160,568
    EncoderLayer-6      [10, 512, 512]       3,160,568       3,160,568
Total params: 18,963,408
Trainable params: 18,963,408
Non-trainable params: 0
-----------------------------------------------------------------------


## Check Anomaly Transformer

In [2]:
# Summarize the example model.
# Keyword arguments are collected in one place so the example configuration
# is easy to read and tweak.
example_config = dict(
    d_data=10,
    d_embed=64,
    hidden_dim_rate=4.,
    max_seq_len=100,
    mask_token_rate=(0.05,0.15),
    positional_encoding=None,
    relative_position_embedding=True,
    transformer_n_layer=3,
    transformer_n_head=4,
    dropout=0.1,
)
model = get_anomaly_transformer(**example_config)

# All-zero dummy batch of shape (16, 100, 10); the last two dimensions match
# max_seq_len and d_data above.
example_batch = torch.zeros(16, 100, 10)
print(pytorch_model_summary.summary(model, example_batch, show_input=False))

----------------------------------------------------------------------------
           Layer (type)        Output Shape         Param #     Tr. Param #
               Linear-1       [16, 100, 64]             704             704
   TransformerEncoder-2       [16, 100, 64]         152,340         152,340
               Linear-3      [16, 100, 256]          16,640          16,640
                 GELU-4      [16, 100, 256]               0               0
               Linear-5       [16, 100, 10]           2,570           2,570
Total params: 172,254
Trainable params: 172,254
Non-trainable params: 0
----------------------------------------------------------------------------


In [3]:
# Check model instances.
# Print the first 10 entries of mask_token_table and the mask_token parameter,
# separated by one blank line.
print(model.mask_token_table[:10], model.mask_token, sep='\n\n')

[14, 6, 14, 12, 8, 12, 10, 11, 10, 14]

Parameter containing:
tensor([[ 0.0053, -0.0057, -0.0327,  0.0048,  0.0043, -0.0022, -0.0176, -0.0310,
         -0.0053,  0.0069, -0.0078,  0.0106,  0.0044,  0.0415, -0.0048, -0.0029,
          0.0119, -0.0117,  0.0170,  0.0006, -0.0185,  0.0253, -0.0154, -0.0067,
         -0.0159, -0.0109,  0.0060, -0.0160,  0.0078,  0.0099,  0.0308,  0.0719,
          0.0068, -0.0020, -0.0179,  0.0115,  0.0029,  0.0234, -0.0004,  0.0065,
          0.0188, -0.0159, -0.0118, -0.0079, -0.0005, -0.0236, -0.0139,  0.0228,
          0.0358,  0.0157, -0.0106,  0.0102, -0.0012,  0.0266,  0.0271, -0.0069,
          0.0165, -0.0182, -0.0007,  0.0148,  0.0243,  0.0238, -0.0059, -0.0005]],
       requires_grad=True)


In [5]:
# Check autograd functions.
# Run one forward pass on a random batch and compute an L1 reconstruction loss
# between the model output and that same batch.
l1_loss = torch.nn.L1Loss()

x = torch.rand(16, 100, 10)  # random batch with values in [0, 1)
y = model(x)

# L1Loss is called as loss(input, target): the prediction goes first and the
# reference data second. The mean absolute error is symmetric, so the value is
# the same as with the arguments swapped, but this order follows the
# documented signature and keeps the tensor that requires grad in the `input`
# slot.
loss = l1_loss(y, x)
print('loss :', loss.item())

loss : 2.413358211517334


In [6]:
# Back-propagate the loss, then show the gradient stored on the
# `linear_embedding` bias to confirm that gradients reached it.
loss.backward()
embedding_bias_grad = model.linear_embedding.bias.grad
print(embedding_bias_grad)

tensor([-0.2086, -0.2762,  0.2535,  1.0324,  0.1024,  0.0914,  0.9866, -0.5405,
         0.3809, -0.0159,  0.4031,  0.1038,  0.8071, -0.4886, -0.3818, -0.1702,
         1.0533,  0.5017, -0.2825, -0.7611, -0.3886,  0.7570,  0.8553, -0.0265,
         0.0247, -0.4464,  0.1729, -0.0541,  0.0444,  0.1384, -0.3599, -0.4125,
         0.0700,  0.4699, -0.2983, -0.3459,  0.1212,  0.3143, -0.2653,  0.5975,
         0.3497,  0.0197,  0.1166, -0.0435, -0.5129,  0.4491, -0.4168,  0.1174,
        -0.2196, -0.2135, -0.0068, -1.2207,  0.1179, -0.0373, -0.1633, -0.3557,
        -0.5474, -0.8561,  0.5296,  0.0873, -0.8980, -0.6438,  0.2070, -0.0990])


In [7]:
# Gradient stored on `model.mask_token` (populated once backward() has run).
print(model.mask_token.grad)

tensor([[-0.5545, -0.7614,  0.4503,  0.7524,  0.0604, -0.4823,  0.4422,  0.8438,
          1.0116,  0.7272,  0.1171,  0.6688,  0.7854,  0.0084,  0.6613,  0.5270,
          0.6445,  0.1250,  1.0520, -1.3416,  1.2106,  0.8090,  0.1353,  0.4056,
          0.4472, -1.2009, -0.8700, -1.2629,  0.5274, -0.5598, -0.2111,  0.1449,
          0.7595,  1.2007,  0.3101, -1.2605,  0.2475, -0.2101, -0.7676,  1.1885,
          0.7706,  0.4547, -0.6664, -0.9339, -1.1345, -0.4515, -0.1796,  1.0323,
         -0.5746, -0.0408,  0.2519, -0.4176,  0.1028, -1.4091, -0.1960, -1.2700,
         -1.6909, -1.5927, -0.2550,  0.4100, -0.2361,  0.4618,  0.0729,  0.6517]])


In [8]:
# Gradient of the relative position embedding table in the first encoder
# layer's attention layer; only the first 10 rows are shown.
first_attention_layer = model.transformer_encoder.encoder_layers[0].attention_layer
table_grad = first_attention_layer.relative_position_embedding_table.grad
print(table_grad[:10])

tensor([[-1.7156e-05, -6.6179e-06,  2.1336e-06, -1.0048e-05],
        [-2.6890e-05, -2.2301e-06,  1.1036e-05,  1.2869e-06],
        [-2.6311e-05,  6.2818e-06,  9.8580e-06, -3.8717e-07],
        [-1.3054e-05, -4.7467e-06,  1.4798e-05,  8.6596e-06],
        [-2.4838e-05, -5.9972e-06,  2.3861e-05,  3.2242e-05],
        [-2.4469e-05,  3.4905e-06,  2.1670e-05,  2.0059e-05],
        [-3.3681e-05, -6.1303e-06,  4.3800e-05,  2.1766e-05],
        [-2.6824e-05,  1.8091e-05,  6.1962e-05,  1.0274e-05],
        [-4.8427e-05,  1.8547e-06,  7.4820e-05,  3.0014e-05],
        [-4.1165e-05, -1.0241e-05,  9.8454e-05,  3.9157e-05]])
