# Model Experiments

In [None]:
import torch
import torch.nn as nn

In [None]:

class TransformerEncoder(nn.Module):
    """Encode padded transaction sequences into one fixed-size embedding per customer.

    A stack of batch-first ``nn.TransformerEncoderLayer`` blocks is followed by a
    mean over the sequence dimension. When a padding mask is given, the mean is
    taken over real (non-padded) positions only, so a customer's embedding does
    not depend on how much padding its sequence received.

    Args:
        input_dim: Feature size of each sequence element (``d_model``).
        num_heads: Number of attention heads; PyTorch requires it to divide
            ``input_dim``.
        hidden_dim: Width of the position-wise feed-forward sublayer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability inside each encoder layer.
        debug: If True, ``forward`` prints the input and mask shapes.
    """

    def __init__(self, input_dim=128, num_heads=4, hidden_dim=256, num_layers=4,
                 dropout=0.1, debug=False):
        super().__init__()
        self.debug = debug

        # Template layer. nn.TransformerEncoder deep-copies it num_layers times,
        # so this instance's own parameters are not used in forward.
        # NOTE(review): it is kept as an attribute only so existing state_dicts
        # (which contain "encoder_layer.*" keys) still load.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,  # tensors are (batch, seq_len, feature_dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        # Plain mean over the sequence; used only when no padding mask is given.
        self.pooling_layer = nn.AdaptiveAvgPool1d(1)

    def forward(self, x, mask=None):
        """Return one embedding per sequence.

        Args:
            x: Float tensor of shape ``(batch_size, seq_len, input_dim)``.
            mask: Optional ``(batch_size, seq_len)`` tensor with 1/True for real
                positions and 0/False for padding. ``None`` means no padding.

        Returns:
            Tensor of shape ``(batch_size, input_dim)``: the mean of the encoder
            outputs over real positions. The position count is clamped to at
            least 1, so a row with no real positions does not divide by zero.
        """
        # PyTorch's src_key_padding_mask marks positions to IGNORE with True,
        # the opposite of this method's mask convention.
        padding_mask = (mask == 0) if mask is not None else None

        if self.debug:
            print(f"Mask Shape: {padding_mask.shape if padding_mask is not None else None}")
            print(f"Input Shape: {x.shape}")

        x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)

        if padding_mask is None:
            # (batch, seq, dim) -> (batch, dim, seq) -> (batch, dim, 1) -> (batch, dim)
            return self.pooling_layer(x.transpose(1, 2)).squeeze(-1)

        # Masked mean. Padded outputs are zeros on the nested-tensor fast path
        # and arbitrary values otherwise; exclude them either way. masked_fill
        # (rather than multiplying by the mask) keeps a NaN at a padded position
        # from propagating into the sum.
        x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        counts = (~padding_mask).sum(dim=1, keepdim=True).clamp(min=1).to(x.dtype)
        return x.sum(dim=1) / counts


In [None]:
# Demo: embed a small synthetic batch of padded customer transaction sequences.

# Number of real transactions per customer; the rest of each row is padding.
lengths = torch.tensor([5, 10, 7])

batch_size = len(lengths)
max_seq_len = 10  # Longest sequence in this batch; shorter ones are padded
input_dim = 128

# Simulated batch of customer transaction sequences (padded where needed)
x = torch.randn(batch_size, max_seq_len, input_dim)

# Padding mask derived from `lengths` (1 for real transactions, 0 for padding),
# so it cannot drift out of sync with the per-customer counts:
#   [1,1,1,1,1,0,0,0,0,0], [1]*10, [1,1,1,1,1,1,1,0,0,0]
mask = (torch.arange(max_seq_len) < lengths.unsqueeze(1)).long()

model = TransformerEncoder(input_dim=input_dim)
model.eval()  # disable dropout so the extracted embeddings are deterministic
with torch.no_grad():  # inference only; no autograd graph needed
    output_embeddings = model(x, mask)

print("Final Customer Embeddings Shape:", output_embeddings.shape)  # (batch_size, input_dim)