In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn.utils.rnn import pad_sequence

In [2]:
# Define the Transformer model for classification with padding and masking
class TransformerClassifier(nn.Module):
    """Sequence classifier: linear embedding -> positional encoding -> Transformer encoder -> pooled linear head.

    Inputs are sequence-first, i.e. ``src`` is (seq_len, batch, input_dim), matching the
    default ``batch_first=False`` of ``TransformerEncoderLayer`` used below.
    """

    def __init__(self, input_dim, model_dim, num_heads, num_layers, num_classes, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        self.model_dim = model_dim

        # Project raw per-step features to the model dimension.
        self.embedding = nn.Linear(input_dim, model_dim)

        # Sinusoidal positional encoding added to the embedded sequence.
        self.pos_encoder = PositionalEncoding(model_dim, dropout)

        # Stack of encoder layers; the feed-forward width is set equal to model_dim.
        encoder_layers = TransformerEncoderLayer(model_dim, num_heads, dim_feedforward=model_dim, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)

        # Classification head applied to the pooled sequence representation.
        self.fc_out = nn.Linear(model_dim, num_classes)

    def forward(self, src, src_mask=None):
        """Classify a batch of (possibly zero-padded) sequences.

        Args:
            src: (seq_len, batch, input_dim) float tensor.
            src_mask: optional (seq_len, batch) validity mask -- nonzero/True for real time
                steps, 0/False for padding (the format produced by ``create_mask``).
                ``None`` treats every position as valid.

        Returns:
            (batch, num_classes) tensor of unnormalized logits.
        """
        x = self.embedding(src) * (self.model_dim ** 0.5)
        x = self.pos_encoder(x)

        if src_mask is None:
            encoded = self.transformer_encoder(x)
            return self.fc_out(encoded.mean(dim=0))

        valid = src_mask.to(torch.bool)  # (seq_len, batch)
        # nn.TransformerEncoder's `mask` argument is a (seq_len, seq_len) attention mask;
        # per-sequence padding must go through `src_key_padding_mask`, which is
        # (batch, seq_len) with True marking positions attention should ignore.
        # NOTE: a sequence with no valid steps at all would make attention produce NaN.
        padding_mask = ~valid.transpose(0, 1)
        encoded = self.transformer_encoder(x, src_key_padding_mask=padding_mask)

        # Masked mean pooling: average only over real time steps so padding does not
        # dilute the sequence representation. clamp avoids division by zero.
        weights = valid.unsqueeze(-1).to(encoded.dtype)  # (seq_len, batch, 1)
        pooled = (encoded * weights).sum(dim=0) / weights.sum(dim=0).clamp(min=1.0)

        return self.fc_out(pooled)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input, then apply dropout.

    The precomputed table ``pe`` has shape (max_len, 1, model_dim), so it broadcasts over
    the batch dimension of an input shaped (seq_len, batch, model_dim). Both even and odd
    ``model_dim`` are supported.
    """

    def __init__(self, model_dim, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, model_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # One frequency 1 / 10000^(2i / model_dim) per even feature index.
        div_term = torch.exp(
            torch.arange(0, model_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / model_dim)
        )
        angles = position * div_term  # (max_len, ceil(model_dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd model_dim there is one fewer odd column than even column; trim the
        # angles so the shapes agree (no-op for even model_dim).
        pe[:, 1::2] = torch.cos(angles[:, : model_dim // 2])
        # (max_len, 1, model_dim): the singleton dim broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add encodings for the first ``x.size(0)`` positions to ``x`` and apply dropout.

        ``x`` is expected to be (seq_len, batch, model_dim) with seq_len <= max_len;
        the output has the same shape.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Example usage with variable-length sequences:
SEED = 0
torch.manual_seed(SEED)  # make example data, weight init and dropout reproducible on Restart & Run All

input_dim = 10  # Number of features
model_dim = 32  # Transformer model dimension
num_heads = 4   # Number of attention heads
num_layers = 3  # Number of transformer layers
num_classes = 5  # Number of classes for classification
dropout = 0.1   # Dropout rate

# Example of a batch of sequences with different lengths, each [time_steps, input_dim]
seq1 = torch.randn(8, input_dim)  # 8 time steps
seq2 = torch.randn(5, input_dim)  # 5 time steps
seq3 = torch.randn(7, input_dim)  # 7 time steps

# Pad the sequences (with zeros) to match the longest sequence length
padded_sequences = pad_sequence([seq1, seq2, seq3], batch_first=False)  # Shape: [max_seq_len, batch_size, input_dim]

# Create a mask to ignore the padded positions
def create_mask(padded_seqs):
    """Build a float32 validity mask from zero-padded data.

    Elementwise: entries equal to 0 map to 0.0 (padding) and all other entries map to
    1.0 (real data). The result has the same shape as ``padded_seqs``.
    Caveat: a genuine value of exactly 0 is indistinguishable from padding here.
    """
    is_real = padded_seqs.ne(0)
    return is_real.to(torch.float32)

# Padding mask: a time step counts as padding only if ALL of its features are zero
# (pad_sequence fills with 0). Looking at feature 0 alone would misclassify a real step
# whose first feature happens to be 0.
step_max_abs = padded_sequences.abs().amax(dim=-1)  # Shape: [max_seq_len, batch_size]
src_mask = create_mask(step_max_abs)  # Shape: [max_seq_len, batch_size]; 1 = real step, 0 = padding
assert src_mask.shape == padded_sequences.shape[:2]

# Initialize the transformer classification model
model = TransformerClassifier(input_dim, model_dim, num_heads, num_layers, num_classes, dropout)

# Forward pass with masking
output = model(padded_sequences, src_mask)
assert output.shape == (padded_sequences.size(1), num_classes)
print(output.shape)  # Output will have shape: [batch_size, num_classes]

# Example of using CrossEntropyLoss for training: logits [batch_size, num_classes]
# against integer class labels [batch_size]
criterion = nn.CrossEntropyLoss()
labels = torch.randint(0, num_classes, (padded_sequences.size(1),))  # Random labels for testing
loss = criterion(output, labels)
print(f"Loss: {loss.item()}")


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 12)]    0           []                               
                                                                                                  
 conv1d (Conv1D)                (None, 128, 256)     3328        ['input_1[0][0]']                
                                                                                                  
 layer_normalization (LayerNorm  (None, 128, 256)    512         ['conv1d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 multi_head_attention (MultiHea  (None, 128, 256)    1051904     ['layer_normalization[0][0]',