In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Use GPU if available, else use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Default Transformer class
Given an input of 

`(batch_size, seq_length, d_model)`, 

generate 

`(batch_size, out_seq_length, d_model)`.

During training, target is also 

`(batch_size, out_seq_length, d_model)`

In [3]:
model = nn.Transformer(nhead=8, num_encoder_layers=2, num_decoder_layers=2, batch_first=True)
print(model)

Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, o

In [4]:
src = torch.rand((32, 10, 512))
tgt = torch.rand((32, 20, 512))
out = model(src, tgt)
print(out.shape)

torch.Size([32, 20, 512])


## A basic Transformer that could work for ECG

In [5]:
class ECGTransformer(nn.Module):
    def __init__(self, d_model, num_classes=63, nhead=8, num_encoder_layers=2, dim_feedforward=2048):
        super().__init__()
        
        # Define encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, batch_first=True)

        # Encoder stack
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, num_classes),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        encoded = self.transformer(x)
        # encoded shape: (batch_size, seq_len, d_model)
        # Pick out only the last in the sequence for classification
        encoded = encoded[:, -1, :]
        result = self.classifier(encoded)
        return result

In [6]:
model = ECGTransformer(d_model=12, nhead=4, num_classes=63, num_encoder_layers=6, dim_feedforward=512)
model.to(device)
print(model)

ECGTransformer(
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=12, out_features=63, bias=True)
    (1): Sigmoid()
  )
)


In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(model)
print(total_params)

81723


In [8]:
for name, module in model.named_modules():
    params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {params} parameters")

: 81723 parameters
transformer: 80904 parameters
transformer.layers: 80904 parameters
transformer.layers.0: 13484 parameters
transformer.layers.0.self_attn: 624 parameters
transformer.layers.0.self_attn.out_proj: 156 parameters
transformer.layers.0.linear1: 6656 parameters
transformer.layers.0.dropout: 0 parameters
transformer.layers.0.linear2: 6156 parameters
transformer.layers.0.norm1: 24 parameters
transformer.layers.0.norm2: 24 parameters
transformer.layers.0.dropout1: 0 parameters
transformer.layers.0.dropout2: 0 parameters
transformer.layers.1: 13484 parameters
transformer.layers.1.self_attn: 624 parameters
transformer.layers.1.self_attn.out_proj: 156 parameters
transformer.layers.1.linear1: 6656 parameters
transformer.layers.1.dropout: 0 parameters
transformer.layers.1.linear2: 6156 parameters
transformer.layers.1.norm1: 24 parameters
transformer.layers.1.norm2: 24 parameters
transformer.layers.1.dropout1: 0 parameters
transformer.layers.1.dropout2: 0 parameters
transformer.laye

In [None]:
inputs = torch.rand((2, 5000, 12)).to(device)
out = model(inputs)
print(out.shape)
print(out[0])

  proj = linear(q, w, b)


torch.Size([3, 63])
tensor([0.3445, 0.6134, 0.3308, 0.5061, 0.6985, 0.6163, 0.4630, 0.4196, 0.3788,
        0.3524, 0.5629, 0.4439, 0.5261, 0.7331, 0.4000, 0.7133, 0.3199, 0.2308,
        0.7299, 0.4106, 0.6710, 0.6334, 0.6253, 0.6197, 0.4280, 0.7458, 0.3168,
        0.5989, 0.4369, 0.5994, 0.5256, 0.4640, 0.7689, 0.5553, 0.6169, 0.4644,
        0.3360, 0.7362, 0.6294, 0.6007, 0.5448, 0.7205, 0.3767, 0.3987, 0.2950,
        0.4072, 0.6089, 0.4929, 0.2895, 0.5062, 0.6890, 0.2740, 0.5677, 0.4495,
        0.3671, 0.4324, 0.3318, 0.4679, 0.3701, 0.5554, 0.7422, 0.2648, 0.6657],
       device='cuda:0', grad_fn=<SelectBackward0>)


In [10]:
print(out)

tensor([[0.3445, 0.6134, 0.3308, 0.5061, 0.6985, 0.6163, 0.4630, 0.4196, 0.3788,
         0.3524, 0.5629, 0.4439, 0.5261, 0.7331, 0.4000, 0.7133, 0.3199, 0.2308,
         0.7299, 0.4106, 0.6710, 0.6334, 0.6253, 0.6197, 0.4280, 0.7458, 0.3168,
         0.5989, 0.4369, 0.5994, 0.5256, 0.4640, 0.7689, 0.5553, 0.6169, 0.4644,
         0.3360, 0.7362, 0.6294, 0.6007, 0.5448, 0.7205, 0.3767, 0.3987, 0.2950,
         0.4072, 0.6089, 0.4929, 0.2895, 0.5062, 0.6890, 0.2740, 0.5677, 0.4495,
         0.3671, 0.4324, 0.3318, 0.4679, 0.3701, 0.5554, 0.7422, 0.2648, 0.6657],
        [0.4353, 0.4194, 0.3749, 0.4381, 0.5999, 0.6827, 0.2940, 0.3459, 0.5030,
         0.5124, 0.5436, 0.6546, 0.4196, 0.4524, 0.6050, 0.6886, 0.3235, 0.3779,
         0.5857, 0.6351, 0.8313, 0.7327, 0.6514, 0.5495, 0.3315, 0.5373, 0.4017,
         0.5286, 0.6734, 0.3994, 0.6043, 0.3770, 0.6200, 0.6178, 0.5695, 0.7738,
         0.6191, 0.6293, 0.5350, 0.4410, 0.8078, 0.7634, 0.3866, 0.6132, 0.6401,
         0.2863, 0.6836, 0.