In [None]:
import torch
from torch.cuda.amp import autocast
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class TransformerAE(torch.nn.Module):
  """Transformer-based sequence autoencoder with a linear bottleneck.

  Per-token pipeline (all encoder layers use ``batch_first=True``):
    in_dim -> Linear -> ReLU -> 2x TransformerEncoderLayer
    -> Linear (down to bottleneck_dim) -> Linear (back up to d_model)
    -> 2x TransformerEncoderLayer -> ReLU -> Linear -> in_dim

  Args:
    in_dim: size of the last dimension of the input (and of the output).
    d_model: hidden width of every transformer layer; must be divisible by
      ``nhead`` (enforced by ``torch.nn.MultiheadAttention``).
    nhead: attention heads per encoder layer. Defaults to 8, the value that
      was previously hard-coded.
    bottleneck_dim: width of the linear bottleneck between the encoder and
      decoder stacks. ``None`` (default) means ``d_model // 4``, the value
      that was previously hard-coded.
  """

  def __init__(self, in_dim, d_model, nhead=8, bottleneck_dim=None):
    super().__init__()
    if bottleneck_dim is None:
      bottleneck_dim = d_model // 4
    # Attribute names and assignment order are deliberately unchanged: they
    # determine state_dict keys (checkpoint compatibility) and the order in
    # which parameter initialisation consumes the global RNG.
    self.linear1 = torch.nn.Linear(in_dim, d_model)
    self.transformer1 = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
    self.transformer2 = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
    self.linear3 = torch.nn.Linear(d_model, bottleneck_dim)
    self.linear4 = torch.nn.Linear(bottleneck_dim, d_model)
    self.transformer3 = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
    self.transformer4 = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
    self.activation = torch.nn.ReLU()
    self.linear2 = torch.nn.Linear(d_model, in_dim)

  def forward(self, x, m=None):
    """Encode and reconstruct ``x``.

    Args:
      x: batch-first input whose last dimension is ``in_dim``
        (presumably (batch, seq_len, in_dim) — TODO confirm with callers).
      m: optional ``src_key_padding_mask`` forwarded unchanged to all four
        encoder layers. For a bool mask, True marks a position to ignore as an
        attention key. ``None`` means no padding.

    Returns:
      Tensor with the same leading dimensions as ``x`` and last dim ``in_dim``.
    """
    x = self.activation(self.linear1(x))
    x = self.transformer1(x, src_key_padding_mask=m)
    x = self.transformer2(x, src_key_padding_mask=m)
    # NOTE(review): there is no nonlinearity between linear3 and linear4, so the
    # bottleneck is a purely linear (rank <= bottleneck_dim) map — confirm intended.
    x = self.linear3(x)
    x = self.linear4(x)
    x = self.transformer3(x, src_key_padding_mask=m)
    x = self.transformer4(x, src_key_padding_mask=m)
    x = self.activation(x)
    return self.linear2(x)

In [None]:
def convert_sync_batchnorm(module):
  """Recursively report BatchNorm layers inside ``module``.

  Despite the name, nothing is converted. The module tree is walked
  depth-first (parent before its children). Every visited module is printed,
  and instances of ``torch.nn.modules.batchnorm._BatchNorm`` are flagged.
  For an actual conversion, use ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``.

  Args:
    module: the ``torch.nn.Module`` to inspect.

  Returns:
    List of the BatchNorm modules found, in visit order (empty if none).
  """
  print("Check module: ", module)
  found = []
  if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
    print("Find!: ", module)
    found.append(module)
  # children() instead of named_children(): the child names were never used.
  for child in module.children():
    found.extend(convert_sync_batchnorm(child))
  return found

In [None]:
# Seed before building the model: weight initialisation and torch.rand both
# draw from the global torch RNG, so this makes the cell's results reproducible.
SEED = 0
torch.manual_seed(SEED)

IN_DIM = 10   # feature size; must match the last dim of `data`
SEQ_LEN = 10  # number of tokens per sequence

device = 'cpu'
model = TransformerAE(in_dim=IN_DIM, d_model=128).to(device)
model.eval()  # disables dropout; does NOT disable autograd
# Create the input on the model's device so changing `device` (e.g. to 'cuda')
# doesn't cause a device-mismatch error. Batch-first: (batch, seq_len, in_dim).
data = torch.rand((1, SEQ_LEN, IN_DIM), device=device)

In [None]:
# Boolean key-padding mask of shape (batch, seq_len): True marks a position
# that is ignored as an attention key. Here only the last token is masked.
# Derived from `data` so it stays in sync if the input shape changes.
# Keep it boolean: a float mask is *added* to the attention scores, so a value
# of 1.0 would not mask anything.
src_key_mask = torch.zeros(data.shape[:2], dtype=torch.bool, device=data.device)
src_key_mask[:, -1] = True

# Pure inference: skip building the autograd graph.
with torch.no_grad():
  output = model(data, src_key_mask)
output

tensor([[[ 0.6527, -0.1148,  0.2900, -0.5350, -0.1129,  0.3372, -0.1269,
           0.0983, -0.0913, -0.5693],
         [ 0.5595, -0.3358,  0.0597, -0.4438, -0.2606,  0.1555, -0.2767,
           0.0903,  0.1831, -0.5175],
         [ 0.4743, -0.3141,  0.1776, -0.4160, -0.3058,  0.1261, -0.2394,
          -0.0543,  0.2081, -0.5527],
         [ 0.4412,  0.1224,  0.2447, -0.5129, -0.1260,  0.3266, -0.1891,
          -0.0256, -0.2262, -0.5772],
         [ 0.5774, -0.2624,  0.2801, -0.4644, -0.2591,  0.1187,  0.0482,
          -0.0228, -0.0529, -0.5919],
         [ 0.6015, -0.4583,  0.1572, -0.4470, -0.4000,  0.2482,  0.0099,
          -0.1215, -0.6056, -0.6185],
         [ 0.7189, -0.1084,  0.0248, -0.2436, -0.2130,  0.4159, -0.3751,
          -0.0783, -0.2534, -0.6384],
         [ 0.8287, -0.4687, -0.0795, -0.3876, -0.4481,  0.1981, -0.1872,
          -0.0986, -0.2466, -0.6868],
         [ 0.7980, -0.2646,  0.2422, -0.3622, -0.2688,  0.1366,  0.0313,
          -0.0995, -0.1208, -0.6254],
 

In [None]:
# Walk the model and print every submodule, flagging any BatchNorm layers.
# The trailing ';' stops Jupyter from echoing the call's return value.
convert_sync_batchnorm(model);

Check module:  TransformerAE(
  (linear1): Linear(in_features=128, out_features=128, bias=True)
  (transformer1): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=128, bias=True)
    (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer2): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2)