# Seismic Transformer V3.0 cookbook

*Author: Jason Jiang (Xunfun Lee)*

*Date: 2023.12.31*

> **Note:** This is a cookbook for Seismic Transformer V3.0. It differs substantially from the previous versions. I rewrote almost all of the code, including the MLP Block, MHA Block, `train_step()`, `validation_step()`, `train()`, `test()` and so on, so make sure you are reading the correct version of them inside PythonScripts.
> The reason is that SeT-3 is more complicated than SeT-1 and SeT-2 after adding the decoder and splicer to the model. In SeT-3 I also updated the mask handling (it was a mess in SeT-2, and SeT-3 uses two masks rather than the single mask in SeT-2), the frequency handling, and the token embedding (now more elegant). In conclusion, it is a brand-new and more efficient model.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary


from PythonScripts.utility import SetDevice

In [2]:
device = SetDevice()

GPU: cuda
CUDA device numbers:  1


## Encoder

### MLP Block

In [3]:
# Define the MLP block class
class MLPBlock(nn.Module):
    """Pre-norm feed-forward block with a residual connection.

    Computes ``x + Dropout(Linear(Dropout(GELU(Linear(LayerNorm(x))))))``.

    Args:
        hidden_size (int): Size of the last dimension of the input and output. Defaults to 768.
        fc_hidden_size (int): Width of the intermediate fully connected layer. Defaults to 3072.
        dropout_rate (float): Dropout probability applied after each linear layer. Defaults to 0.1.
    """

    def __init__(self,
                 hidden_size: int = 768,
                 fc_hidden_size: int = 3072,
                 dropout_rate: float = 0.1):
        super().__init__()

        # Sub-modules are listed in the order forward() applies them.
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.linear1 = nn.Linear(hidden_size, fc_hidden_size)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.linear2 = nn.Linear(fc_hidden_size, hidden_size)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return ``x`` plus the feed-forward transform of its layer-normalised value."""
        # Expand: LayerNorm -> Linear -> GELU -> Dropout
        hidden = self.gelu(self.linear1(self.layer_norm(x)))
        hidden = self.dropout1(hidden)
        # Project back to hidden_size: Linear -> Dropout
        projected = self.dropout2(self.linear2(hidden))
        # Residual connection around the whole block
        return projected + x


In [4]:
# Define test functions
def test_mlp_block_summary(mlp_block_instance:nn.Module,
                           input_data:torch.Tensor) -> None:
    """Print the torchinfo summary of a module for the given example input.

    Args:
        mlp_block_instance: Module to summarise.
        input_data: Example tensor handed to ``summary`` as ``input_data``.
    """
    # Using torchinfo to summarize the model
    print(summary(mlp_block_instance, input_data=input_data))

def test_mlp_block_forward(mlp_block_instance, input_tensor):
    """Run one forward pass through the block and print the output shape.

    Args:
        mlp_block_instance: Callable block invoked as ``mlp_block_instance(input_tensor)``.
        input_tensor: Input passed to the block.
    """
    # Forward pass and print output shape
    output = mlp_block_instance(input_tensor)
    print("Output shape:", output.shape)

In [5]:
# Batch size shared by the demo cells below
batch_size = 64

# Create MLP block instance for testing
mlp_block_instance = MLPBlock().to(device)

# Generate a random tensor for testing
input_tensor_for_testing = torch.rand(batch_size, 14, 768).to(device)  # shape: (batch_size, 14 tokens, 768 features)

# Call test functions
print("Testing MLPBlock with torchinfo.summary:")
test_mlp_block_summary(mlp_block_instance, input_tensor_for_testing)

print("\nTesting MLPBlock with forward pass:")
test_mlp_block_forward(mlp_block_instance, input_tensor_for_testing)


Testing MLPBlock with torchinfo.summary:
Layer (type:depth-idx)                   Output Shape              Param #
MLPBlock                                 [64, 14, 768]             --
├─LayerNorm: 1-1                         [64, 14, 768]             1,536
├─Linear: 1-2                            [64, 14, 3072]            2,362,368
├─GELU: 1-3                              [64, 14, 3072]            --
├─Dropout: 1-4                           [64, 14, 3072]            --
├─Linear: 1-5                            [64, 14, 768]             2,360,064
├─Dropout: 1-6                           [64, 14, 768]             --
Total params: 4,723,968
Trainable params: 4,723,968
Non-trainable params: 0
Total mult-adds (M): 302.33
Input size (MB): 2.75
Forward/backward pass size (MB): 33.03
Params size (MB): 18.90
Estimated Total Size (MB): 54.68

Testing MLPBlock with forward pass:
Output shape: torch.Size([64, 14, 768])


### MHA Block

In [6]:
class MHABlock(nn.Module):
    """Pre-norm multi-head attention block with a residual connection on the query.

    Can be used for encoder self-attention, decoder (masked) self-attention and
    decoder cross-attention; masks are supplied by the caller through ``forward``.

    Args:
        hidden_size (int): Hidden size of the input tensor. Defaults to 768.
        num_heads (int): Number of attention heads. Defaults to 12.
        dropout_attn (float): Dropout rate, used both inside ``nn.MultiheadAttention``
            and on the attention output before the residual is added. Defaults to 0.1.
        batch_first (bool): Forwarded to ``nn.MultiheadAttention``. Defaults to True.
    """
    def __init__(self,
                 hidden_size: int = 768,
                 num_heads: int = 12,
                 dropout_attn: float = 0.1,
                 batch_first: bool = True):

        super().__init__()

        # A single LayerNorm is shared by query, key and value.
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=hidden_size,
                                                    num_heads=num_heads,
                                                    dropout=dropout_attn,
                                                    batch_first=batch_first)
        self.dropout = nn.Dropout(dropout_attn)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=True):
        """Apply LayerNorm, multi-head attention, dropout and the residual connection.

        Args:
            query, key, value: Attention inputs. Pass the same tensor three times
                for self-attention.
            key_padding_mask: Forwarded unchanged to ``nn.MultiheadAttention``.
            attn_mask: Forwarded unchanged to ``nn.MultiheadAttention``.
            need_weights: Forwarded unchanged to ``nn.MultiheadAttention``.

        Returns:
            Tuple ``(attn_output + query, attn_output_weights)``.
        """
        # Save the residual
        residual = query

        # Pre-layer normalization. LayerNorm is deterministic, so when the caller
        # passes the same tensor object more than once we normalise it only once.
        # Besides saving work, this hands nn.MultiheadAttention identical objects,
        # which lets it recognise the self-attention case.
        normed_query = self.layer_norm(query)
        normed_key = normed_query if key is query else self.layer_norm(key)
        if value is key:
            normed_value = normed_key
        elif value is query:
            normed_value = normed_query
        else:
            normed_value = self.layer_norm(value)

        # Multi-head attention
        attn_output, attn_output_weights = self.multihead_attn(normed_query,
                                                               normed_key,
                                                               normed_value,
                                                               key_padding_mask=key_padding_mask,
                                                               attn_mask=attn_mask,
                                                               need_weights=need_weights)
        # Dropout on the attention output, then the residual connection
        attn_output = self.dropout(attn_output) + residual

        return attn_output, attn_output_weights

In [7]:
# Instantiate three different instances of MHABlock for encoder, masked MHA (typically used in decoder), and cross MHA.

# Instance for the encoder's MHA
# Instance for the encoder's MHA
encoder_mha = MHABlock(hidden_size=768, num_heads=12, dropout_attn=0.1, batch_first=True)

# Instance for the decoder's masked MHA
# For masked MHA in the decoder, we need to ensure that the attention mechanism 
# does not attend to subsequent positions. This is typically handled outside the MHABlock,
# by providing an appropriate mask to the 'forward' method.
decoder_masked_mha = MHABlock(hidden_size=768, num_heads=12, dropout_attn=0.1, batch_first=True)

# Instance for the cross MHA (used in decoder to attend over encoder's outputs)
cross_mha = MHABlock(hidden_size=768, num_heads=12, dropout_attn=0.1, batch_first=True)

# Last expression of the cell: displays the three instances (they stay on the CPU here)
(encoder_mha, decoder_masked_mha, cross_mha)


(MHABlock(
   (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
   (multihead_attn): MultiheadAttention(
     (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
   )
   (dropout): Dropout(p=0.1, inplace=False)
 ),
 MHABlock(
   (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
   (multihead_attn): MultiheadAttention(
     (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
   )
   (dropout): Dropout(p=0.1, inplace=False)
 ),
 MHABlock(
   (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
   (multihead_attn): MultiheadAttention(
     (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
   )
   (dropout): Dropout(p=0.1, inplace=False)
 ))

In [8]:
# Define test functions for each type of MHABlock instance with appropriate masks

def test_encoder_mha(encoder_mha_instance, input_tensor):
    """
    Run the encoder's MHABlock as self-attention (query = key = value) without any mask.

    Args:
        encoder_mha_instance: Block called as ``block(query, key, value)`` and
            returning an ``(output, weights)`` pair.
        input_tensor: Tensor used as query, key and value.

    Returns:
        The attention output; the attention weights are discarded.
    """
    print("Testing Encoder's MHABlock with forward pass:")
    output, _ = encoder_mha_instance(input_tensor, input_tensor, input_tensor)
    print("Output shape:", output.shape)
    return output

def test_encoder_mha_with_padding_mask(encoder_mha_instance, input_tensor, padding_mask):
    """
    Run the encoder's MHABlock as self-attention with a key padding mask.

    Args:
        encoder_mha_instance: Block called as ``block(query, key, value, key_padding_mask=...)``.
        input_tensor: Tensor used as query, key and value.
        padding_mask: Passed through as ``key_padding_mask`` so padded key
            positions can be ignored.

    Returns:
        The attention output; the attention weights are discarded.
    """
    print("Testing Encoder's MHABlock with Padding Mask and forward pass:")
    output, _ = encoder_mha_instance(input_tensor, input_tensor, input_tensor, key_padding_mask=padding_mask)
    print("Output shape:", output.shape)
    return output

def test_decoder_masked_mha(decoder_masked_mha_instance, input_tensor, attn_mask):
    """
    Run the decoder's masked MHABlock as self-attention with an attention mask.

    Args:
        decoder_masked_mha_instance: Block called as ``block(query, key, value, attn_mask=...)``.
        input_tensor: Tensor used as query, key and value.
        attn_mask: Passed through as ``attn_mask``; intended to stop positions
            from attending to subsequent positions.

    Returns:
        The attention output; the attention weights are discarded.
    """
    print("Testing Decoder's Masked MHABlock with Correctly Shaped Attn Mask and forward pass:")
    output, _ = decoder_masked_mha_instance(input_tensor, input_tensor, input_tensor, attn_mask=attn_mask)
    print("Output shape:", output.shape)
    return output

def test_cross_mha(cross_mha_instance, input_tensor, memory_tensor):
    """
    Run the cross-attention MHABlock without any mask.

    Args:
        cross_mha_instance: Block called as ``block(query, key, value)``.
        input_tensor: Used as the query.
        memory_tensor: Used as both key and value.

    Returns:
        The attention output; the attention weights are discarded.
    """
    print("Testing Cross MHABlock with forward pass:")
    output, _ = cross_mha_instance(input_tensor, memory_tensor, memory_tensor)
    print("Output shape:", output.shape)
    return output

#### Test encoder MHA

key_padding_mask.size = torch.Size([batch_size, seq_len]) (dtype = torch.bool)

In [9]:
encoder_input = torch.rand(batch_size, 14, 768)  # (N, L, E): encoder_mha was built with batch_first=True
output_encoder = test_encoder_mha(encoder_mha, encoder_input)

# All-False boolean mask of shape (batch_size, 14): no key position is flagged
padding_mask = torch.zeros(batch_size, 14, dtype=torch.bool)
output_encoder_mask = test_encoder_mha_with_padding_mask(encoder_mha, encoder_input, padding_mask)

# Positions 1..12 of the 14 tokens; the same slice is used later as the cross-attention memory
output_encoder[:,1:13,:].shape

Testing Encoder's MHABlock with forward pass:
Output shape: torch.Size([64, 14, 768])
Testing Encoder's MHABlock with Padding Mask and forward pass:
Output shape: torch.Size([64, 14, 768])


torch.Size([64, 12, 768])

#### Test decoder: masked MHA

attn_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()

`torch.triu` 是“triangular upper”的缩写，表示生成一个上三角矩阵。

参数 `diagonal=1` 表示对角线上方一行的元素（而不是对角线上的元素）开始保持原有的值（在这里是1），对角线及以下的元素设置为0。

`.bool()` 将上述上三角矩阵转换为布尔类型，其中1变成 `True` ，0变成 `False` 

In [10]:
# Perform tests
# Perform tests
decoder_input = torch.rand(batch_size, 12, 768)  # (N, L, E): decoder_masked_mha was built with batch_first=True

# Sequence length is dimension 1 (12 here); this global is reused by a later decoder cell
seq_length = decoder_input.shape[1]
# Boolean upper-triangular mask with the main diagonal excluded, shape (seq_length, seq_length)
attn_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()
output_decoder = test_decoder_masked_mha(decoder_masked_mha, decoder_input, attn_mask)

output_decoder.shape

Testing Decoder's Masked MHABlock with Correctly Shaped Attn Mask and forward pass:
Output shape: torch.Size([64, 12, 768])


torch.Size([64, 12, 768])

#### Test decoder: cross attention MHA

In [11]:
# Use positions 1..12 of the encoder self-attention output as key/value memory,
# and the masked decoder output as the query
memory_tensor = output_encoder[:,1:13,:]
output_cross = test_cross_mha(cross_mha, output_decoder, memory_tensor)
output_cross.shape

Testing Cross MHABlock with forward pass:
Output shape: torch.Size([64, 12, 768])


torch.Size([64, 12, 768])

### Patch Embedding Block (PE Block)

In [12]:
# Define a class to perform the operations on the left side of the diagram (slice, linear)
class PatchEmbeddingBlock(nn.Module):
    """Split a single-channel sequence into fixed-size patches and embed each patch linearly.

    Args:
        len_gm (int): Length of the input sequence. Defaults to 3000.
        patch_size (int): Number of samples per patch. Defaults to 250.
        output_size (int): Embedding size produced for each patch. Defaults to 768.

    The number of patches is ``len_gm // patch_size``; ``forward`` expects an input
    whose dimension 1 equals ``num_of_patches * patch_size``.
    """
    def __init__(self,
                 len_gm:int=3000,
                 patch_size:int=250,
                 output_size:int=768):
        super().__init__()

        self.patch_size = patch_size
        self.output_size = output_size
        # One linear layer shared by all patches: patch_size -> output_size
        self.linear = nn.Linear(patch_size, output_size)
        self.num_of_patches = len_gm // patch_size

    def forward(self, x):
        """Embed ``x`` patch by patch.

        Args:
            x: Tensor of shape ``[batch_size, sequence_length, 1]`` (a 2-D
                ``[batch_size, sequence_length]`` tensor is accepted as well).

        Returns:
            Tensor of shape ``[batch_size, num_of_patches, output_size]``.

        Raises:
            AssertionError: If dimension 1 is not ``num_of_patches * patch_size``.
            ValueError: If the input carries more than one value per time step.
        """
        expected_len = self.num_of_patches * self.patch_size

        # verify the input shape
        assert x.shape[1] == expected_len, \
            f'Input sequence length should be {expected_len}'

        # Pin the batch dimension instead of inferring it with -1: otherwise a
        # multi-channel input such as [N, L, 2] would silently become 2N batch rows.
        batch_size = x.shape[0]
        if x.numel() != batch_size * expected_len:
            raise ValueError(
                f'Expected a single-channel input of shape '
                f'[batch_size, {expected_len}, 1], got {tuple(x.shape)}')

        # [batch_size, sequence_length, 1] --> [batch_size, num_patches, patch_size]
        # e.g. [64, 3000, 1] --> [64, 12, 250]
        x = x.reshape(batch_size, self.num_of_patches, self.patch_size)

        # [batch_size, num_patches, patch_size] --> [batch_size, num_patches, output_size]
        return self.linear(x)

In [13]:
def test_PEBlock(PEBlock_instance, input_tensor):
    """
    Run one forward pass through a patch-embedding block and print the output shape.

    NOTE(review): the printed banner still reads "Cross MHABlock" (it matches the
    banner of test_cross_mha); the string is left untouched here.

    Args:
        PEBlock_instance: Callable block invoked as ``PEBlock_instance(input_tensor)``.
        input_tensor: Input passed to the block.

    Returns:
        The block's output.
    """
    print("Testing Cross MHABlock with forward pass:")
    output = PEBlock_instance(input_tensor)
    print("Output shape:", output.shape)
    return output

In [14]:
# Patch-embedding block with its default sizes, plus a random input of shape (batch_size, 3000, 1)
PEBlock_Instance = PatchEmbeddingBlock().to(device)
input_PE = torch.rand(batch_size, 3000, 1).to(device)
input_PE.shape

torch.Size([64, 3000, 1])

In [15]:
# Forward the random input through the patch-embedding block and show the result's shape
output_PE = test_PEBlock(PEBlock_Instance, input_PE)
output_PE.shape

Testing Cross MHABlock with forward pass:
Output shape: torch.Size([64, 12, 768])


torch.Size([64, 12, 768])

### Frequency Embedding Block (FE Block)

In [16]:
import torch
import numpy as np
import torch.nn as nn

class FreqEmbeddingBlock(nn.Module):
    """Embed the frequency content of a sequence into a single token.

    Pipeline: real FFT along the time axis -> keep the first
    ``2 * conv_output_size`` bins (real part) -> Conv1d(kernel_size=2, stride=2)
    -> Linear.

    Args:
        conv_output_size (int): Length of the convolution output, i.e. half the
            number of frequency bins that are kept. Defaults to 750.
        linear_output_size (int): Size of the produced embedding. Defaults to 768.
    """

    def __init__(self,
                 conv_output_size:int=750,
                 linear_output_size:int=768):

        super().__init__()

        # The stride-2, kernel-2 convolution halves the length, so it has to be
        # fed exactly 2 * conv_output_size bins (1500 with the defaults).
        self.num_freq_bins = 2 * conv_output_size
        self.conv1d = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=2, stride=2)
        self.linear = nn.Linear(conv_output_size, linear_output_size)

    def forward(self, x):
        """Return the frequency embedding of ``x``.

        Args:
            x: Tensor of shape ``[batch_size, sequence_length, 1]``.

        Returns:
            Tensor of shape ``[batch_size, 1, linear_output_size]``.

        Raises:
            ValueError: If the sequence is too short to provide
                ``2 * conv_output_size`` frequency bins.
        """
        # FFT along the time axis (dim=1). The default dim=-1 would transform the
        # size-1 channel axis, which leaves the signal unchanged.
        # Only the real part of the spectrum is used.
        spectrum = torch.fft.rfft(x, dim=1).real

        if spectrum.shape[1] < self.num_freq_bins:
            raise ValueError(
                f'Input of length {x.shape[1]} yields {spectrum.shape[1]} frequency bins, '
                f'but {self.num_freq_bins} are required')

        # rfft returns N/2+1 bins for a real input of length N; drop the surplus
        spectrum = spectrum[:, :self.num_freq_bins, :]

        # [batch_size, bins, 1] --> [batch_size, 1, bins] for Conv1d
        spectrum = spectrum.permute(0, 2, 1)

        # Convolution, then linear transformation
        return self.linear(self.conv1d(spectrum))


In [17]:
# Frequency-embedding block with its default sizes, plus a random input of shape (64, 3000, 1)
FEBlock_Instance = FreqEmbeddingBlock().to(device)
input_FE = torch.rand(64, 3000, 1).to(device)

In [18]:
def test_FEBlock(FEBlock_Instance, input_tensor):
    """
    Run one forward pass through a frequency-embedding block and print the output shape.

    NOTE(review): the printed banner still reads "Cross MHABlock" (it matches the
    banner of test_cross_mha); the string is left untouched here.

    Args:
        FEBlock_Instance: Callable block invoked as ``FEBlock_Instance(input_tensor)``.
        input_tensor: Input passed to the block.

    Returns:
        The block's output.
    """
    print("Testing Cross MHABlock with forward pass:")
    output = FEBlock_Instance(input_tensor)
    print("Output shape:", output.shape)
    return output

In [19]:
# Forward the random input through the frequency-embedding block
output = test_FEBlock(FEBlock_Instance, input_FE)

Testing Cross MHABlock with forward pass:


Output shape: torch.Size([64, 1, 768])


## Encoder Block

In [20]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention (MHABlock) followed by a feed-forward MLPBlock.

    Args:
        hidden_size (int): Hidden size of the tokens. Defaults to 768.
        num_heads (int): Number of attention heads. Defaults to 12.
        fc_hidden_size (int): Width of the MLP's intermediate layer. Defaults to 3072.
        dropout_attn (float): Dropout rate handed to the attention block. Defaults to 0.1.
        dropout_mlp (float): Dropout rate handed to the MLP block. Defaults to 0.1.
        batch_first (bool): Forwarded to the attention block. Defaults to True.
    """

    def __init__(self,
                 hidden_size: int = 768,
                 num_heads: int = 12,
                 fc_hidden_size: int = 3072,
                 dropout_attn: float = 0.1,
                 dropout_mlp: float = 0.1,
                 batch_first: bool = True):
        super().__init__()

        # Self-attention sub-layer
        self.mha_block = MHABlock(hidden_size=hidden_size,
                                  num_heads=num_heads,
                                  dropout_attn=dropout_attn,
                                  batch_first=batch_first)

        # Feed-forward sub-layer
        self.mlp_block = MLPBlock(hidden_size=hidden_size,
                                  fc_hidden_size=fc_hidden_size,
                                  dropout_rate=dropout_mlp)

    def forward(self, x, key_padding_mask=None, need_weights=True):
        """Return ``(output, attention_weights)`` for the input tokens ``x``."""
        # Self-attention: the same tensor serves as query, key and value
        attended, attn_weights = self.mha_block(query=x,
                                                key=x,
                                                value=x,
                                                key_padding_mask=key_padding_mask,
                                                need_weights=need_weights)

        # Feed-forward on the attended tokens
        return self.mlp_block(attended), attn_weights

In [21]:
def test_EBBlock(EncoderBlock_Instance, input_tensor):
    """
    Run one forward pass through an encoder block and print the output shape.

    NOTE(review): the printed banner still reads "Cross MHABlock" (it matches the
    banner of test_cross_mha); the string is left untouched here.

    Args:
        EncoderBlock_Instance: Callable block invoked as ``EncoderBlock_Instance(input_tensor)``
            and returning an ``(output, attn_weights)`` pair.
        input_tensor: Input passed to the block.

    Returns:
        The ``(output, attn_weights)`` pair returned by the block.
    """
    print("Testing Cross MHABlock with forward pass:")
    output, attn_weights = EncoderBlock_Instance(input_tensor)
    print("Output shape:", output.shape)
    return output, attn_weights

In [22]:
# Encoder block with its default sizes, plus a random token tensor of shape (64, 14, 768)
EncoderBlock_Instance = EncoderBlock().to(device)
input_EB = torch.rand(64, 14, 768).to(device)

In [23]:
# Forward the random tokens through the encoder block and show both result shapes
output, attn_weights = test_EBBlock(EncoderBlock_Instance, input_EB)
output.shape, attn_weights.shape

Testing Cross MHABlock with forward pass:
Output shape: torch.Size([64, 14, 768])


(torch.Size([64, 14, 768]), torch.Size([64, 14, 14]))

## Encoder

In [24]:
class EncoderV1(nn.Module):
    """Encoder of the Seismic Transformer.

    Builds the token sequence ``[CLS] + time patches + frequency token``, adds the
    learnable positional embedding, applies dropout and runs the result through a
    stack of ``EncoderBlock`` layers.

    Args:
        len_gm (int): Length of the input sequence. Defaults to 3000.
        patch_size (int): Number of samples per time patch. Defaults to 250.
        hidden_size (int): Hidden size of every token. Defaults to 768.
        num_heads (int): Number of attention heads per layer. Defaults to 12.
        num_layers (int): Number of encoder layers. Defaults to 12.
        dropout_attn (float): Dropout rate of the attention blocks. Defaults to 0.1.
        dropout_mlp (float): Dropout rate of the MLP blocks. Defaults to 0.1.
        dropout_embed (float): Dropout rate applied to the embedded tokens. Defaults to 0.1.

    Attributes:
        attention_weights_list (list): Attention weights returned by each layer
            during the most recent forward pass.
    """

    def __init__(self,
                 len_gm:int=3000,
                 patch_size:int=250,
                 hidden_size:int=768,
                 num_heads:int=12,
                 num_layers:int=12,
                 dropout_attn:float=0.1,
                 dropout_mlp:float=0.1,
                 dropout_embed:float=0.1):

        super().__init__()

        # Calculate the number of patches
        self.num_of_patch = len_gm // patch_size

        # Initialize a variable to store the attention weights
        self.attention_weights_list = []

        # BLOCK
        # patch embedding
        self.PatchEmbedding = PatchEmbeddingBlock(len_gm=len_gm,
                                                  patch_size=patch_size,
                                                  output_size=hidden_size)

        # frequency embedding
        self.FreqEmbedding = FreqEmbeddingBlock(conv_output_size=len_gm // 2 // 2,         # default is 750
                                                linear_output_size=hidden_size)

        # encoder layers
        # ModuleList rather than Sequential: each layer takes keyword arguments and
        # returns a tuple, so the stack is iterated manually in forward(). The
        # parameter names ("EncoderLayers.<i>....") are the same for both containers.
        self.EncoderLayers = nn.ModuleList([EncoderBlock(hidden_size=hidden_size,
                                                         num_heads=num_heads,
                                                         fc_hidden_size=hidden_size*4,
                                                         dropout_attn=dropout_attn,
                                                         dropout_mlp=dropout_mlp) for _ in range(num_layers)])

        # [TOKEN]
        # [TIME] - time token
        self.time_token = nn.Parameter(torch.randn(1, 1, hidden_size),
                                       requires_grad=True)  # trainable parameter

        # [FREQ] - frequency token
        self.freq_token = nn.Parameter(torch.randn(1, 1, hidden_size),
                                       requires_grad=True)  # trainable parameter

        # [CLS] - class token
        self.class_token = nn.Parameter(torch.randn(1, 1, hidden_size),
                                        requires_grad=True)  # trainable parameter

        # POSITION
        # positional embedding: one vector per token ([CLS] + patches + frequency token)
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_of_patch+2, hidden_size),
                                                  requires_grad=True)  # trainable parameter

        # Dropout
        self.embedding_dropout = nn.Dropout(dropout_embed)


    def forward(self, x, key_padding_mask=None, need_weights=True):
        """Encode an input sequence.

        Args:
            x: Input passed to both the patch embedding and the frequency embedding.
            key_padding_mask: Forwarded to every encoder layer.
            need_weights: Forwarded to every encoder layer.

        Returns:
            Tensor of shape ``[batch_size, num_of_patch + 2, hidden_size]``.
            The per-layer attention weights are stored in ``attention_weights_list``.
        """
        # Get the batch size
        batch_size = x.shape[0]

        # clear the attention weights list
        self.attention_weights_list = []

        # patch embedding; the [TIME] token is ADDED to every patch (broadcast
        # over the batch and patch dimensions)
        time_sequence_with_token = self.PatchEmbedding(x) + self.time_token        # [batch_size, num_of_patch, hidden_size]

        # frequency embedding; the [FREQ] token is ADDED to it (broadcast over the batch)
        freq_sequence_with_token = self.FreqEmbedding(x) + self.freq_token         # [batch_size, 1, hidden_size]

        # [CLS] token, expanded over the batch ("-1" keeps that dimension's size)
        class_tokens = self.class_token.expand(batch_size, -1, -1)

        # concatenate along the token dimension: [CLS], time patches, frequency token
        sequence_combine_with_cls = torch.cat((class_tokens,
                                               time_sequence_with_token,
                                               freq_sequence_with_token), dim=1)   # [batch_size, num_of_patch + 2, hidden_size]

        # positional embedding (the parameter was created in __init__ but
        # previously never applied)
        sequence_combine_with_cls = sequence_combine_with_cls + self.positional_embedding

        # embedding dropout
        x = self.embedding_dropout(sequence_combine_with_cls)

        # Encoder layers
        for layer in self.EncoderLayers:
            x, attn_weights = layer(x, key_padding_mask=key_padding_mask, need_weights=need_weights)
            self.attention_weights_list.append(attn_weights)

        return x

In [25]:
def test_Encoder(Encoder_Instance, input_tensor):
    """
    Run one forward pass through an encoder and print the output shape.

    NOTE(review): the printed banner still reads "Cross MHABlock" (it matches the
    banner of test_cross_mha); the string is left untouched here.

    Args:
        Encoder_Instance: Callable encoder invoked as ``Encoder_Instance(input_tensor)``.
        input_tensor: Input passed to the encoder.

    Returns:
        The encoder's output.
    """
    print("Testing Cross MHABlock with forward pass:")
    output = Encoder_Instance(input_tensor)
    print("Output shape:", output.shape)
    return output

In [26]:
# Encoder with its default sizes, plus a random input of shape (64, 3000, 1)
EncoderV1_Instance = EncoderV1().to(device)
input_Encoder = torch.rand(64, 3000, 1).to(device)

In [27]:
# Forward the random input through the encoder and show the result's shape
output = test_Encoder(EncoderV1_Instance, input_Encoder)
output.shape

Testing Cross MHABlock with forward pass:
Output shape: torch.Size([64, 14, 768])


torch.Size([64, 14, 768])

## Decoder

### Decoder Block

In [28]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then an MLP.

    Args:
        hidden_size (int): Hidden size of the tokens. Defaults to 768.
        num_heads (int): Number of attention heads. Defaults to 12.
        fc_hidden_size (int): Width of the MLP's intermediate layer. Defaults to 3072.
        dropout_attn (float): Dropout rate of both attention blocks. Defaults to 0.1.
        dropout_mlp (float): Dropout rate of the MLP block. Defaults to 0.1.
        batch_first (bool): Forwarded to both attention blocks. Defaults to True.
    """

    def __init__(self,
                 hidden_size: int = 768,
                 num_heads: int = 12,
                 fc_hidden_size: int = 3072,
                 dropout_attn: float = 0.1,
                 dropout_mlp: float = 0.1,
                 batch_first: bool = True):

        super().__init__()

        # Masked multi-head self-attention over the decoder sequence
        self.mmha_block = MHABlock(hidden_size=hidden_size,
                                   num_heads=num_heads,
                                   dropout_attn=dropout_attn,
                                   batch_first=batch_first)

        # Cross multi-head attention: decoder queries attend over the encoder output
        self.cmha_block = MHABlock(hidden_size=hidden_size,
                                   num_heads=num_heads,
                                   dropout_attn=dropout_attn,
                                   batch_first=batch_first)

        self.mlp_block = MLPBlock(hidden_size=hidden_size,
                                  fc_hidden_size=fc_hidden_size,
                                  dropout_rate=dropout_mlp)

    def forward(self, query, key, value, output_encoder, attn_mask=None, need_weights=True,
                cross_attn_mask=None):
        """Run the decoder layer.

        Args:
            query, key, value: Inputs of the masked self-attention.
            output_encoder: Encoder output, used as key and value of the cross-attention.
            attn_mask: Mask for the masked SELF-attention (e.g. a causal mask that
                stops positions from attending to subsequent positions).
            need_weights: Forwarded to both attention blocks.
            cross_attn_mask: Optional mask for the cross-attention. Defaults to
                None, i.e. every decoder position may attend to every encoder token.

        Returns:
            Tuple ``(output, mmha_attn_weights, cmha_attn_weights)``.
        """
        # Masked multi-head attention block. The mask belongs here: it was
        # previously handed to the cross-attention instead, which left this
        # "masked" block unmasked.
        mmha_output, mmha_attn_weights = self.mmha_block(query, key, value,
                                                         attn_mask=attn_mask,
                                                         need_weights=need_weights)

        # Cross multi-head attention block
        cmha_output, cmha_attn_weights = self.cmha_block(mmha_output, output_encoder, output_encoder,
                                                         attn_mask=cross_attn_mask,
                                                         need_weights=need_weights)

        # MLP block
        output = self.mlp_block(cmha_output)

        return output, mmha_attn_weights, cmha_attn_weights

In [29]:
class DecoderV1(nn.Module):
    """Decoder of the Seismic Transformer.

    Patch-embeds a target sequence, adds a learnable positional embedding, applies
    dropout and runs the result through a stack of ``DecoderBlock`` layers that
    also receive the encoder output.

    Args:
        len_gm (int): Length of the target sequence. Defaults to 3000.
        patch_size (int): Number of samples per patch. Defaults to 250.
        hidden_size (int): Hidden size of every token. Defaults to 768.
        num_heads (int): Number of attention heads per layer. Defaults to 12.
        num_layers (int): Number of decoder layers. Defaults to 12.
        dropout_attn (float): Dropout rate of the attention blocks. Defaults to 0.1.
        dropout_mlp (float): Dropout rate of the MLP blocks. Defaults to 0.1.
        dropout_embed (float): Dropout rate applied to the embedded tokens. Defaults to 0.1.
        device: Stored as ``self.device`` for backward compatibility. The initial
            sequence is now created on the device of the module's own parameters.

    Attributes:
        mmha_attn_weights_list (list): Masked-attention weights of each layer from
            the most recent forward pass.
        cmha_attn_weights_list (list): Cross-attention weights of each layer from
            the most recent forward pass.
    """

    def __init__(self,
                 len_gm:int=3000,
                 patch_size:int=250,
                 hidden_size:int=768,
                 num_heads:int=12,
                 num_layers:int=12,
                 dropout_attn:float=0.1,
                 dropout_mlp:float=0.1,
                 dropout_embed:float=0.1,
                 device:torch.device="cuda"):

        super().__init__()

        # Calculate the number of patches
        self.num_of_patch = len_gm // patch_size

        # Sequence length accepted by the patch embedding (3000 with the defaults);
        # used to size the initial sequence in inference mode
        self._init_seq_len = self.num_of_patch * patch_size

        # Initialize variables to store the attention weights
        self.mmha_attn_weights_list = []
        self.cmha_attn_weights_list = []

        # BLOCK
        # patch embedding
        self.PatchEmbedding = PatchEmbeddingBlock(len_gm=len_gm,
                                                  patch_size=patch_size,
                                                  output_size=hidden_size)

        # decoder layers
        # ModuleList rather than Sequential: each layer takes keyword arguments and
        # returns a tuple, so the stack is iterated manually in forward(). The
        # parameter names ("DecoderLayers.<i>....") are the same for both containers.
        self.DecoderLayers = nn.ModuleList([DecoderBlock(hidden_size=hidden_size,
                                                         num_heads=num_heads,
                                                         fc_hidden_size=hidden_size*4,
                                                         dropout_attn=dropout_attn,
                                                         dropout_mlp=dropout_mlp) for _ in range(num_layers)])

        # POSITION
        # positional embedding: one vector per patch
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_of_patch, hidden_size),
                                                  requires_grad=True)  # trainable parameter

        # Dropout
        self.embedding_dropout = nn.Dropout(dropout_embed)

        # Set device (kept for backward compatibility)
        self.device = device

    def forward(self, output_encoder, target_sequence=None, attn_mask=None, need_weights=True):
        """Decode with the help of the encoder output.

        Args:
            output_encoder: Encoder output handed to every decoder layer.
            target_sequence: Sequence to embed. When None, an all-zero initial
                sequence is used instead.
            attn_mask: Forwarded to every decoder layer as ``attn_mask``.
            need_weights: Forwarded to every decoder layer.

        Returns:
            Output of the last decoder layer. The per-layer attention weights are
            stored in ``mmha_attn_weights_list`` and ``cmha_attn_weights_list``.
        """
        if target_sequence is not None:
            # Training mode: use the target sequence as the decoder input
            x = target_sequence
        else:
            # Inference mode: start from an initial sequence
            x = self._init_sequence(batch_size=output_encoder.shape[0])

        # patch embedding + positional embedding
        time_sequence = self.PatchEmbedding(x) + self.positional_embedding

        # embedding dropout
        x = self.embedding_dropout(time_sequence)

        # clear the attention weights lists
        self.mmha_attn_weights_list = []
        self.cmha_attn_weights_list = []

        # Decoder layers
        for layer in self.DecoderLayers:
            x, mmha_attn_weights, cmha_attn_weights = layer(query=x,
                                                            key=x,
                                                            value=x,
                                                            output_encoder=output_encoder,
                                                            attn_mask=attn_mask,
                                                            need_weights=need_weights)
            # Store attention weights
            self.mmha_attn_weights_list.append(mmha_attn_weights)
            self.cmha_attn_weights_list.append(cmha_attn_weights)

        return x

    def generate_sequence(self, output_encoder, attn_mask=None):
        """Placeholder for step-by-step sequence generation.

        NOTE(review): ``forward`` returns the decoder-layer output (one embedding
        per patch), not a sequence of the length the patch embedding accepts, so
        feeding that result back in as ``target_sequence`` fails the patch
        embedding's shape check from the second iteration on. A projection back to
        the input space is needed before this loop can work; the loop is kept as
        written until that design decision is made.
        """
        # Initial sequence generation for inference mode
        generated_sequence = self._init_sequence(batch_size=output_encoder.shape[0])
        for _ in range(self.num_of_patch):
            # Assume that you append to generated_sequence at each step
            # You may need to modify this loop to match your actual sequence generation process
            generated_sequence = self.forward(output_encoder, generated_sequence, attn_mask)

        return generated_sequence

    def _init_sequence(self, batch_size):
        """Return the all-zero sequence ``[batch_size, sequence_length, 1]`` that starts decoding.

        The tensor is created with the device and dtype of the module's own
        parameters, so it follows the model through ``.to(...)``, and its length is
        derived from the constructor arguments instead of being fixed at 3000.
        """
        reference = self.positional_embedding
        return torch.zeros((batch_size, self._init_seq_len, 1),
                           device=reference.device,
                           dtype=reference.dtype)


In [30]:
def test_Decoder(Decoder_Instance, input_decoder, output_encoder, attn_mask=None):
    """
    Run one forward pass through a decoder and print the output shape.

    NOTE(review): the printed banner still reads "Cross MHABlock" (it matches the
    banner of test_cross_mha); the string is left untouched here.

    Args:
        Decoder_Instance: Callable decoder invoked as
            ``Decoder_Instance(output_encoder, input_decoder, attn_mask=..., need_weights=True)``.
        input_decoder: Passed as the decoder's second positional argument.
        output_encoder: Passed as the decoder's first positional argument.
        attn_mask: Passed through as ``attn_mask``.

    Returns:
        The decoder's output.
    """
    print("Testing Cross MHABlock with forward pass:")
    output = Decoder_Instance(output_encoder, input_decoder, attn_mask=attn_mask, need_weights=True)
    print("Output shape:", output.shape)
    return output

In [31]:
# Decoder with its default sizes and random stand-ins for its inputs
DecoderV1_Instance = DecoderV1().to(device)
input_decoder = torch.rand(64, 3000, 1).to(device)
output_encoder = torch.rand(64, 12, 768).to(device)   # random tensor standing in for an encoder output
# NOTE: `seq_length` is the global set in the masked-MHA test cell above (decoder_input.shape[1])
attn_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool().to(device)

In [32]:
# training mode, with target sequence
# A target sequence (input_decoder) is supplied, so the decoder embeds it directly
output = test_Decoder(DecoderV1_Instance, input_decoder, output_encoder, attn_mask=attn_mask)

Testing Cross MHABlock with forward pass:
Output shape: torch.Size([64, 12, 768])


In [33]:
# inference mode, generate sequence
# No target sequence (None): the decoder's forward() starts from its own initial
# sequence. Note that this is a single forward pass; generate_sequence() is not called.
output = test_Decoder(DecoderV1_Instance, None, output_encoder, attn_mask=attn_mask)

Testing Cross MHABlock with forward pass:
Output shape: torch.Size([64, 12, 768])


## Classifier

In [34]:
class ClassifierV1(nn.Module):
    """Classification head: LayerNorm followed by a linear projection to class logits.

    Both layers act on the last dimension only, so any leading dimensions of the
    input are preserved in the output.

    Args:
        hidden_size (int): Size of the input's last dimension. Defaults to 768.
        num_of_classes (int): Number of logits produced. Defaults to 5.
    """

    def __init__(self,
                 hidden_size:int=768,
                 num_of_classes:int=5) -> None:
        super().__init__()

        # Normalisation over the hidden dimension
        self.LayerNorm = nn.LayerNorm(hidden_size)
        # Projection: hidden_size -> num_of_classes
        self.Linear = nn.Linear(hidden_size, num_of_classes)

    def forward(self, x):
        """Return the logits for ``x`` (last dimension ``hidden_size`` -> ``num_of_classes``)."""
        normalised = self.LayerNorm(x)
        return self.Linear(normalised)

## Splicer

In [35]:
class SplicerV1(nn.Module):
    """Turn per-patch embeddings back into one flat sequence.

    Pipeline: LayerNorm -> Linear (hidden_size -> patch_size) with GELU ->
    flatten all patches -> Linear (len_gm -> len_gm) -> add a trailing
    dimension of size 1.

    Args:
        hidden_size (int): Size of each incoming patch embedding. Defaults to 768.
        patch_size (int): Number of samples produced per patch. Defaults to 250.
        len_gm (int): Length of the flattened sequence. Defaults to 3000.
    """

    def __init__(self,
                 hidden_size:int=768,
                 patch_size:int=250,
                 len_gm:int=3000) -> None:
        super().__init__()

        # Normalisation over the hidden dimension
        self.LayerNorm = nn.LayerNorm(hidden_size)

        # Per-patch projection, e.g. [batch_size, 12, 768] --> [batch_size, 12, 250]
        self.Linear1 = nn.Linear(hidden_size, patch_size)

        # Mixing over the whole flattened sequence, e.g. [batch_size, 3000] --> [batch_size, 3000]
        self.Linear2 = nn.Linear(len_gm, len_gm)

    def forward(self, x):
        """Map ``[N, num_patches, hidden_size]`` to ``[N, len_gm, 1]``."""
        batch = x.size(0)

        # LayerNorm, then the per-patch projection with GELU
        patches = F.gelu(self.Linear1(self.LayerNorm(x)))

        # [N, num_patches, patch_size] --> [N, num_patches * patch_size]
        flat = patches.reshape(batch, -1)

        # Sequence-wide linear layer, then [N, len_gm] --> [N, len_gm, 1]
        return self.Linear2(flat).reshape(batch, -1, 1)
        

In [36]:
# Splicer with its default sizes, plus random patch embeddings of shape (64, 12, 768)
model = SplicerV1().to(device)
input_tensor = torch.rand(64, 12, 768).to(device)

In [37]:
# Forward the random patch embeddings through the splicer and show the result's shape
output = model(input_tensor)
output.shape

torch.Size([64, 3000, 1])

## Final test


In [38]:
# Random stand-ins of shape (64, 3000, 1) for the two sequences used in the final test
input_gm = torch.rand(64, 3000, 1).to(device)
input_floorResponse = torch.rand(64, 3000, 1).to(device)

In [39]:
# NOTE(review): each assignment rebinds a class name to an INSTANCE of that class.
# After this cell `EncoderV1`, `DecoderV1`, `ClassifierV1` and `SplicerV1` no longer
# refer to the classes, so later code that constructs them by name (for example
# `EncoderV1(len_gm=...)` inside SeismicTransformerV3.__init__) calls the instance
# instead. Consider distinct instance names such as `encoder_v1`.
EncoderV1 = EncoderV1().to(device)
DecoderV1 = DecoderV1().to(device)
ClassifierV1 = ClassifierV1().to(device)
SplicerV1 = SplicerV1().to(device)

In [40]:
# All-False boolean padding mask over the 14 encoder tokens, shape (64, 14)
key_padding_mask = torch.zeros(64, 14, dtype=torch.bool).to(device)
# Boolean upper-triangular mask (main diagonal excluded) over the 12 decoder patches
attn_mask = torch.triu(torch.ones(12, 12), diagonal=1).bool().to(device)

### Encoder

In [41]:
# `EncoderV1` is the instance created above (the name was rebound), so this is a forward pass
output_encoder = EncoderV1(input_gm, key_padding_mask=key_padding_mask)
output_encoder.shape

torch.Size([64, 14, 768])

### Classifier

In [42]:
# The classifier is applied to all 14 encoder tokens here, so it yields one set of
# logits per token. NOTE(review): presumably only the [CLS] position (index 0) is
# meant to be classified — confirm.
output_classfier = ClassifierV1(output_encoder)
output_classfier.shape

torch.Size([64, 14, 5])

### Decoder

In [43]:
# input_decoder = torch.rand(64, 3000, 1).to(device)
# output_decoder = DecoderV1(output_encoder[:,1:13,:], input_decoder, attn_mask=attn_mask)
# output_decoder.shape

### Splicer

In [44]:
# output_splicer = SplicerV1(output_decoder)
# output_splicer.shape

> **Note:** Haha, now our SeismicTransformer is ready to go! Let's build the entire model and the train step for it!

----------------

## Training

This is a multi task learning problem, so we need to define the loss function for each task.

- Task 1: Classification (damage state of the building)

- Task 2: Regression (dynamic response of the top floor)

It is wise to create one big class containing all the modules and calculate the loss for each task separately, while using a single total loss and optimizer to train the model.

### SeismicTransformer V3.0

In [45]:
class SeismicTransformerV3(nn.Module):
    """Multi-task model wiring together encoder, decoder, classifier and splicer.

    The encoder processes the input sequence; its full output feeds the
    classifier (damage-state logits) and a slice of it feeds the decoder,
    whose output the splicer turns into the dynamic response.

    Args:
        len_gm: Length of the input sequence (forwarded to encoder, decoder, splicer).
        patch_size: Patch length (forwarded to encoder, decoder, splicer).
        hidden_size: Embedding width shared by all sub-modules.
        num_heads: Attention heads for encoder and decoder.
        num_layers: Number of layers for encoder and decoder.
        dropout_attn: Attention dropout for encoder and decoder.
        dropout_mlp: MLP dropout for encoder and decoder.
        dropout_embed: Embedding dropout for encoder and decoder.
        num_of_classes: Number of classes produced by the classifier.
    """

    def __init__(self,
                 len_gm:int=3000,
                 patch_size:int=250,
                 hidden_size:int=768,
                 num_heads:int=12,
                 num_layers:int=12,
                 dropout_attn:float=0.1,
                 dropout_mlp:float=0.1,
                 dropout_embed:float=0.1,
                 num_of_classes:int=5):

        super().__init__()

        # Number of patch tokens handed to the decoder. With the defaults this
        # is 3000 // 250 = 12, i.e. the slice 1:13 that used to be hard-coded.
        self.num_patches = len_gm // patch_size

        # Encoder
        self.encoder = EncoderV1(len_gm=len_gm,
                                 patch_size=patch_size,
                                 hidden_size=hidden_size,
                                 num_heads=num_heads,
                                 num_layers=num_layers,
                                 dropout_attn=dropout_attn,
                                 dropout_mlp=dropout_mlp,
                                 dropout_embed=dropout_embed)

        # Decoder
        self.decoder = DecoderV1(len_gm=len_gm,
                                 patch_size=patch_size,
                                 hidden_size=hidden_size,
                                 num_heads=num_heads,
                                 num_layers=num_layers,
                                 dropout_attn=dropout_attn,
                                 dropout_mlp=dropout_mlp,
                                 dropout_embed=dropout_embed)

        # Classifier
        self.classifier = ClassifierV1(hidden_size=hidden_size,
                                       num_of_classes=num_of_classes)

        # Splicer
        self.splicer = SplicerV1(hidden_size=hidden_size,
                                 patch_size=patch_size,
                                 len_gm=len_gm)

    def forward(self, input_sequence, target_sequence=None, key_padding_mask=None, attn_mask=None):
        """Run the full model and return ``(damage_state, dynamic_response)``.

        Args:
            input_sequence: Encoder input.
            target_sequence: Decoder target. Pass a tensor for training
                (teacher forcing) or ``None`` for inference; the value is
                forwarded to the decoder unchanged either way.
            key_padding_mask: Forwarded to the encoder.
            attn_mask: Forwarded to the decoder.

        Returns:
            Tuple of the classifier logits and the splicer output.
        """
        # Encoder output
        encoder_output = self.encoder(input_sequence, key_padding_mask=key_padding_mask)

        # Keep only the patch tokens for the decoder: skip token 0 and take the
        # next `num_patches` tokens.
        # NOTE(review): assumes the encoder emits exactly one leading special
        # token before the patch tokens — confirm for non-default configs.
        encoder_output_to_decoder = encoder_output[:, 1:1 + self.num_patches, :]

        # Training and inference use the same call; only `target_sequence`
        # differs (a tensor vs. None), so a single code path is enough.
        decoder_output = self.decoder(output_encoder=encoder_output_to_decoder,
                                      target_sequence=target_sequence,
                                      attn_mask=attn_mask,
                                      need_weights=True)

        # Splicer forward pass to generate the dynamic response
        dynamic_response = self.splicer(decoder_output)

        # Classifier forward pass to determine the damage state
        # damage_state is logits
        damage_state = self.classifier(encoder_output)

        return damage_state, dynamic_response
        

In [46]:
# Build the full model with its default hyper-parameters on the selected device.
SeismicTransformerV3_instance = SeismicTransformerV3().to(device)

# Random [64, 3000, 1] tensors used as encoder input and decoder target.
input_gm = torch.rand(64, 3000, 1).to(device)
input_floorResponse = torch.rand(64, 3000, 1).to(device)
# All-False boolean mask of shape [64, 14], passed to the encoder.
key_padding_mask = torch.zeros(64, 14, dtype=torch.bool).to(device)
# 12x12 boolean mask that is True strictly above the diagonal, passed to the decoder.
attn_mask = torch.triu(torch.ones(12, 12), diagonal=1).bool().to(device)

TypeError: EncoderV1.forward() got an unexpected keyword argument 'len_gm'

In [None]:
# training mode
# Forward pass with a target sequence supplied (the model's "training mode" call).
damage_state, dynamic_response = SeismicTransformerV3_instance(input_sequence=input_gm, 
                                                            target_sequence=input_floorResponse, 
                                                            key_padding_mask=key_padding_mask, 
                                                            attn_mask=attn_mask)

In [None]:
# Display the shapes of both model outputs.
damage_state.shape, dynamic_response.shape

(torch.Size([64, 14, 5]), torch.Size([64, 3000, 1]))

In [None]:
# inference mode
# Forward pass without a target sequence, with autograd disabled via inference_mode.
with torch.inference_mode():
    damage_state, dynamic_response = SeismicTransformerV3_instance(input_sequence=input_gm, 
                                                                target_sequence=None, 
                                                                key_padding_mask=key_padding_mask, 
                                                                attn_mask=attn_mask)

In [None]:
# Display the shapes of both model outputs from the inference-mode call.
damage_state.shape, dynamic_response.shape

(torch.Size([64, 14, 5]), torch.Size([64, 3000, 1]))

### Loss Function and Optimizer

In [None]:
from torch.nn import CrossEntropyLoss, MSELoss

# One loss per task: cross-entropy for classification, MSE for regression.
loss_fn_classification = CrossEntropyLoss()
loss_fn_regression = MSELoss()

# A single AdamW optimizer over all parameters of the combined model.
optimizer = torch.optim.AdamW(SeismicTransformerV3_instance.parameters(), lr=1e-4, weight_decay=0.01)

### Training Step

In [None]:
import torch
from typing import Tuple

# A `global` statement at module level is a no-op, so the shared counters are
# created here instead. Values already defined by an earlier cell are kept,
# which means re-running this cell does not reset an ongoing warmup.
global_step = globals().get("global_step", 0)
warmup_done = globals().get("warmup_done", False)


def train_step_set3(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader, 
               loss_fn_classification: torch.nn.Module, 
               loss_fn_regression: torch.nn.Module,
               loss_fn_weight_classification: float,
               optimizer: torch.optim.Optimizer, 
               lr_scheduler_warmup: torch.optim.lr_scheduler.LambdaLR, 
               num_warmup_steps: int, 
               device: torch.device) -> Tuple[float, float, float]:
    """Train ``model`` for one pass over ``dataloader``.

    Each batch must unpack into
    ``(gm_sequence, label, floor_sequence, key_padding_mask, attn_mask)``.
    The total loss is
    ``w * classification_loss + (1 - w) * regression_loss`` with
    ``w = loss_fn_weight_classification``.

    Args:
        model: Model returning ``(damage_state_pred, dynamic_response)``.
        dataloader: Iterable of batches; must support ``len()``.
        loss_fn_classification: Loss applied to ``(damage_state_pred, label)``.
        loss_fn_regression: Loss applied to ``(dynamic_response, floor_sequence)``.
        loss_fn_weight_classification: Weight ``w`` in ``[0, 1]``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler_warmup: Scheduler stepped once per batch.
        num_warmup_steps: Global step at which warmup is reported as complete.
        device: Device the batch tensors are moved to.

    Returns:
        ``(train_loss, train_acc_classification, train_mse_regression)``, each
        averaged over ``len(dataloader)``. ``train_loss`` is NaN if a NaN loss
        was encountered (the loop stops at that batch).
    """
    global global_step, warmup_done

    # The weight does not change between batches, so validate it once up front.
    assert 0 <= loss_fn_weight_classification <= 1, \
        "loss_fn_weight_classification should be between 0 and 1"

    # Set model to training mode
    model.train()

    # Initialize the loss and classification accuracy
    train_loss, train_acc_classification, train_mse_regression = 0.0, 0.0, 0.0

    for gm_sequence, label, floor_sequence, key_padding_mask, attn_mask in dataloader:

        gm_sequence = gm_sequence.to(device)
        label = label.to(device)
        floor_sequence = floor_sequence.to(device)
        key_padding_mask = key_padding_mask.to(device)
        attn_mask = attn_mask.to(device)

        # Forward pass. The two sequences are passed positionally because
        # SeismicTransformerV3.forward names them input_sequence/target_sequence,
        # not encoder_input/decoder_input.
        damage_state_pred, dynamic_response = model(gm_sequence,
                                                    floor_sequence,
                                                    key_padding_mask=key_padding_mask,
                                                    attn_mask=attn_mask)

        # Calculate classification and regression losses
        loss_classification = loss_fn_classification(damage_state_pred, label)
        loss_regression = loss_fn_regression(dynamic_response, floor_sequence)

        # Combine losses
        loss = loss_fn_weight_classification * loss_classification + \
               (1 - loss_fn_weight_classification) * loss_regression

        # Accumulate loss (a NaN here propagates to the returned train_loss,
        # which is how the caller detects the failure)
        train_loss += loss.item()

        # Stop BEFORE backward/step on a NaN loss so the NaN gradients never
        # reach the weights or the scheduler.
        if torch.isnan(loss):
            print("Loss is nan, stopping training.")
            break

        # Zero gradients, backward pass, and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update learning rate scheduler
        lr_scheduler_warmup.step()
        global_step += 1

        # Check for warmup completion
        if not warmup_done and global_step >= num_warmup_steps:
            print(f"Warmup completed at step {global_step}")
            warmup_done = True

        # Metrics only: no autograd graph needed.
        with torch.no_grad():
            # argmax of the logits equals argmax of their softmax.
            # NOTE(review): dim=1 assumes logits shaped (batch, num_classes) —
            # confirm against the model's classifier output.
            y_pred_class = torch.argmax(damage_state_pred, dim=1)
            train_acc_classification += (y_pred_class == label).sum().item() / label.size(0)

            # Regression MSE
            mse = torch.nn.functional.mse_loss(dynamic_response, floor_sequence, reduction='sum').item()
            train_mse_regression += mse / floor_sequence.numel()

    # Average the accumulated loss and classification accuracy over all batches
    train_loss /= len(dataloader)
    train_acc_classification /= len(dataloader)
    train_mse_regression /= len(dataloader)

    return train_loss, train_acc_classification, train_mse_regression


### Validation step

In [None]:
def validation_step_set3(model: torch.nn.Module,
                         dataloader: torch.utils.data.DataLoader,
                         loss_fn_classification: torch.nn.Module,
                         loss_fn_regression: torch.nn.Module,
                         loss_fn_weight_classification: float,
                         device: torch.device) -> Tuple[float, float, float]:
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Each batch must unpack into
    ``(gm_sequence, label, floor_sequence, key_padding_mask, attn_mask)``.
    The loss is combined exactly as in training:
    ``w * classification_loss + (1 - w) * regression_loss``.

    Args:
        model: Model returning ``(damage_state_pred, dynamic_response)``.
        dataloader: Iterable of batches; must support ``len()``.
        loss_fn_classification: Loss applied to ``(damage_state_pred, label)``.
        loss_fn_regression: Loss applied to ``(dynamic_response, floor_sequence)``.
        loss_fn_weight_classification: Weight ``w`` of the classification loss.
        device: Device the batch tensors are moved to.

    Returns:
        ``(val_loss, val_acc_classification, val_mse_regression)``, each
        averaged over ``len(dataloader)``.
    """
    model.eval()
    val_loss, val_acc_classification, val_mse_regression = 0.0, 0.0, 0.0

    # inference mode
    with torch.inference_mode():
        for gm_sequence, label, floor_sequence, key_padding_mask, attn_mask in dataloader:
            # Move data to device
            gm_sequence = gm_sequence.to(device)
            label = label.to(device)
            floor_sequence = floor_sequence.to(device)
            key_padding_mask = key_padding_mask.to(device)
            attn_mask = attn_mask.to(device)

            # Forward pass. The two sequences are passed positionally because
            # SeismicTransformerV3.forward names them
            # input_sequence/target_sequence, not encoder_input/decoder_input.
            damage_state_pred, dynamic_response = model(gm_sequence,
                                                        floor_sequence,
                                                        key_padding_mask=key_padding_mask,
                                                        attn_mask=attn_mask)

            # Calculate classification and regression losses
            loss_classification = loss_fn_classification(damage_state_pred, label)
            loss_regression = loss_fn_regression(dynamic_response, floor_sequence)
            loss = loss_fn_weight_classification * loss_classification + \
                   (1 - loss_fn_weight_classification) * loss_regression

            val_loss += loss.item()

            # Classification accuracy; argmax of the logits equals argmax of
            # their softmax, so the softmax is skipped.
            # NOTE(review): dim=1 assumes logits shaped (batch, num_classes) —
            # confirm against the model's classifier output.
            y_pred_class = torch.argmax(damage_state_pred, dim=1)
            val_acc_classification += (y_pred_class == label).sum().item() / label.size(0)

            # Regression MSE
            mse = torch.nn.functional.mse_loss(dynamic_response, floor_sequence, reduction='sum').item()
            val_mse_regression += mse / floor_sequence.numel()

    val_loss /= len(dataloader)
    val_acc_classification /= len(dataloader)
    val_mse_regression /= len(dataloader)

    return val_loss, val_acc_classification, val_mse_regression


### `Train()`

In [None]:
import math
from PythonScripts.utility import LogEpochDataV3
from typing import Dict, List
import tqdm

def train_set3(model: torch.nn.Module,
               train_loader: torch.utils.data.DataLoader,
               val_loader: torch.utils.data.DataLoader,
               loss_fn_classification: torch.nn.Module,
               loss_fn_regression: torch.nn.Module,
               loss_fn_weight_classification: float,
               optimizer: torch.optim.Optimizer,
               lr_scheduler_warmup: torch.optim.lr_scheduler.LambdaLR,
               lr_scheduler_decay: torch.optim.lr_scheduler.ReduceLROnPlateau,
               num_warmup_steps: int,
               num_epochs: int,
               device: torch.device,
               log_filename: str) -> Dict[str, List]:
    """Run the full training loop: train step, validation step, logging.

    For every epoch this calls ``train_step_set3`` and
    ``validation_step_set3``, feeds the validation loss to
    ``lr_scheduler_decay``, logs the epoch through ``LogEpochDataV3`` and
    prints a one-line summary. Training stops early if either the training
    or the validation loss is NaN.

    Args:
        model: Model to train.
        train_loader: Batches for the training step.
        val_loader: Batches for the validation step.
        loss_fn_classification: Classification loss.
        loss_fn_regression: Regression loss.
        loss_fn_weight_classification: Weight of the classification loss in
            the combined loss.
        optimizer: Optimizer used by the training step.
        lr_scheduler_warmup: Per-batch scheduler used by the training step.
        lr_scheduler_decay: Scheduler stepped once per epoch with the
            validation loss.
        num_warmup_steps: Forwarded to the training step.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to.
        log_filename: Forwarded to ``LogEpochDataV3``.

    Returns:
        Dict of per-epoch metric lists (``train_loss``, ``train_acc``,
        ``train_mse``, ``validation_loss``, ``validation_acc``,
        ``validation_mse``) plus ``is_nan``, which receives ``"yes"`` only
        when training was stopped because of a NaN loss.
    """

    # Create empty results dictionary
    results = {"train_loss": [],
               "train_acc": [],
               "train_mse": [],
               "validation_loss": [],
               "validation_acc": [],
               "validation_mse": [],
               "is_nan": []
    }

    # `tqdm` is imported as a module (`import tqdm`), so the progress-bar class
    # must be reached as `tqdm.tqdm`; calling the module itself raises
    # "TypeError: 'module' object is not callable".
    for epoch in tqdm.tqdm(range(num_epochs)):
        # train step
        train_loss, train_acc, train_mse = train_step_set3(model, train_loader,
                                                           loss_fn_classification,
                                                           loss_fn_regression,
                                                           loss_fn_weight_classification,
                                                           optimizer,
                                                           lr_scheduler_warmup,
                                                           num_warmup_steps,
                                                           device)

        # if train loss = nan, break
        if math.isnan(train_loss):
            print(f"Epoch {epoch}:Train loss is NaN. Stopping training.")
            results["is_nan"].append("yes")
            break

        # validation step
        val_loss, val_acc, val_mse = validation_step_set3(model, val_loader,
                                                         loss_fn_classification,
                                                         loss_fn_regression,
                                                         loss_fn_weight_classification,
                                                         device)

        # if validation loss = nan, break
        if math.isnan(val_loss):
            print(f"Epoch {epoch}:Validation loss is NaN. Stopping training.")
            results["is_nan"].append("yes")
            break

        # put validation_loss to lr_scheduler
        lr_scheduler_decay.step(val_loss)

        # update log file(csv file), need to modify and re-import
        LogEpochDataV3(epoch=epoch,
                     train_loss=train_loss,
                     train_acc=train_acc,
                     train_mse=train_mse,
                     validation_loss=val_loss,
                     validation_acc=val_acc,
                     validation_mse=val_mse,
                     log_filename=log_filename)

        print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"train_mse: {train_mse:.4f} | "
          f"validation_loss: {val_loss:.4f} | "
          f"validation_acc: {val_acc:.4f} | "
          f"validation_mse: {val_mse:.4f}"
        )

        # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["train_mse"].append(train_mse)
        results["validation_loss"].append(val_loss)
        results["validation_acc"].append(val_acc)
        results["validation_mse"].append(val_mse)

    return results

### `test()`

In [None]:
from sklearn.metrics import f1_score, recall_score, mean_squared_error
import torch
import numpy as np
from typing import Dict, Tuple

def test_set3(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              device: torch.device) -> Tuple[Dict[str, float], Tuple[np.ndarray, np.ndarray]]:
    '''
    Evaluate the model on a test set and report classification and regression metrics.

    All predictions and true values are concatenated across batches and the
    metrics are computed once over the whole set (not batch by batch). This
    works as long as every sample has the same, fixed sequence length
    (3000 for example).

    Each batch must unpack into
    (gm_sequence, label, floor_sequence, key_padding_mask, attn_mask).

    Args:
        model: Model returning (damage_state_pred, dynamic_response).
        dataloader: Iterable of test batches.
        device: Device the model and the batches are moved to.

    Returns:
        A tuple (results, (y_label_preds, y_label_trues)) where results holds
        'test_accuracy', 'f1_score' (macro), 'recall_score' (macro) and
        'mse_dynamic_response', and the second element holds the predicted and
        true class labels as numpy arrays.
    '''

    model.to(device)
    model.eval()  # Set the model to evaluation mode

    # Initialize the correct predictions count and all predictions for classification
    correct_label_preds = 0
    total_label_preds = 0

    y_label_preds = []
    y_label_trues = []

    # Initialize lists for dynamic response predictions and true values
    dynamic_response_preds = []
    dynamic_response_trues = []

    with torch.inference_mode():

        for gm_sequence, label, floor_sequence, key_padding_mask, attn_mask in dataloader:

            gm_sequence = gm_sequence.to(device)
            label = label.to(device)
            floor_sequence = floor_sequence.to(device)
            key_padding_mask = key_padding_mask.to(device)
            attn_mask = attn_mask.to(device)

            # Forward pass. The two sequences are passed positionally because
            # SeismicTransformerV3.forward names them
            # input_sequence/target_sequence, not encoder_input/decoder_input.
            damage_state_pred, dynamic_response = model(gm_sequence,
                                                        floor_sequence,
                                                        key_padding_mask=key_padding_mask,
                                                        attn_mask=attn_mask)

            # Calculate the correct predictions count (classification accuracy)
            # NOTE(review): dim=1 assumes logits shaped (batch, num_classes) —
            # confirm against the model's classifier output.
            _, predicted_labels = torch.max(damage_state_pred, dim=1)
            correct_label_preds += (predicted_labels == label).sum().item()
            total_label_preds += label.size(0)

            y_label_preds.append(predicted_labels.cpu())
            y_label_trues.append(label.cpu())

            # Append dynamic response predictions and true values
            dynamic_response_preds.append(dynamic_response.cpu())
            dynamic_response_trues.append(floor_sequence.cpu())

    # Convert predictions and actual labels from list to single tensor
    y_label_preds_tensor = torch.cat(y_label_preds)
    y_label_trues_tensor = torch.cat(y_label_trues)

    # convert to numpy array for classification metrics
    y_label_preds_numpy = y_label_preds_tensor.numpy()
    y_label_trues_numpy = y_label_trues_tensor.numpy()

    # Calculate overall accuracy for classification
    test_acc = correct_label_preds / total_label_preds if total_label_preds > 0 else 0.0

    # Calculate F1 score for classification
    test_f1 = f1_score(y_label_trues_numpy, y_label_preds_numpy, average='macro')

    # Calculate Recall score for classification
    test_Recall = recall_score(y_label_trues_numpy, y_label_preds_numpy, average='macro')

    # Convert dynamic response predictions and actual values from list to single tensor
    dynamic_response_preds_tensor = torch.cat(dynamic_response_preds)
    dynamic_response_trues_tensor = torch.cat(dynamic_response_trues)

    # Flatten every sample to one row before handing the data to scikit-learn:
    # mean_squared_error only accepts 1-D/2-D arrays and rejects the
    # [N, 3000, 1] tensors produced here. The uniform average over the
    # flattened columns equals the mean over all elements.
    num_samples = dynamic_response_preds_tensor.size(0)
    dynamic_response_preds_numpy = dynamic_response_preds_tensor.numpy().reshape(num_samples, -1)
    dynamic_response_trues_numpy = dynamic_response_trues_tensor.numpy().reshape(num_samples, -1)

    # Calculate MSE for dynamic response predictions
    test_mse = mean_squared_error(dynamic_response_trues_numpy, dynamic_response_preds_numpy)

    # Build the results dictionary
    results = {
        'test_accuracy': test_acc,
        'f1_score': test_f1,
        'recall_score': test_Recall,
        'mse_dynamic_response': test_mse,
    }

    # Return results and prediction values
    return results, (y_label_preds_numpy, y_label_trues_numpy)
