# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Self-attention layer: the input serves as query, key and value.

    Wraps ``nn.MultiheadAttention`` configured with ``batch_first=True``.

    Args:
        embed_dim: Size of the last (feature) dimension of the input.
        num_heads: Number of attention heads handed to
            ``nn.MultiheadAttention``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, x):
        """Attend ``x`` to itself.

        Args:
            x: Tensor laid out as (batch, seq_len, embed_dim).

        Returns:
            The ``(attn_output, attn_weights)`` pair produced by the wrapped
            ``nn.MultiheadAttention`` module with its default call options.
        """
        # Query, key and value are all the same tensor -> self-attention.
        return self.attention(query=x, key=x, value=x)

class DSA(nn.Module):
    """Two self-attention branches followed by a BiLSTM and a sigmoid head.

    Numeric features are concatenated with an embedding of one categorical
    index per time step. The result is passed through two independent
    self-attention modules whose outputs are multiplied element-wise, then
    summarised by a bidirectional LSTM and mapped to a single sigmoid score.

    Args:
        num_vars: Accepted for interface compatibility; not used by this class.
        num_timesteps: Accepted for interface compatibility; not used by this
            class.
        num_categorical: Number of rows in the categorical embedding table.
        cat_embed_dim: Width of the categorical embedding.
        input_dim: Size of the last dimension of ``x_num``.
        hidden_dim: Hidden size of each LSTM direction.
        num_heads: Number of heads for both attention modules; must be
            accepted by ``nn.MultiheadAttention`` for an embedding size of
            ``input_dim + cat_embed_dim``.
    """

    def __init__(self, num_vars, num_timesteps, num_categorical, cat_embed_dim, input_dim, hidden_dim, num_heads):
        super().__init__()
        # NOTE(review): num_vars and num_timesteps are never read below.
        self.cat_embedding = nn.Embedding(num_categorical, cat_embed_dim)
        # Feature width after concatenating numeric inputs and the embedding.
        self.input_dim = input_dim + cat_embed_dim

        # Self-attention over variables
        self.var_attention = MultiHeadSelfAttention(self.input_dim, num_heads)
        # Self-attention over time
        self.time_attention = MultiHeadSelfAttention(self.input_dim, num_heads)

        # BiLSTM
        self.bilstm = nn.LSTM(self.input_dim, hidden_dim, batch_first=True, bidirectional=True)

        # Output layer (forward + reverse hidden states -> one logit)
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x_num, x_cat):
        """Score a batch of sequences.

        Args:
            x_num: Numeric features, (batch, time, input_dim).
            x_cat: Integer categorical indices, (batch, time).

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        # Embed categorical and concatenate
        x_cat_embed = self.cat_embedding(x_cat)  # (batch, time, cat_embed_dim)
        x = torch.cat([x_num, x_cat_embed], dim=-1)  # (batch, time, self.input_dim)

        # NOTE(review): both branches receive the same (batch, time, features)
        # tensor, so both attend across the time axis; the "variable-level"
        # branch does not attend across variables. Left unchanged because
        # fixing it would alter the architecture -- confirm intent.
        x_var_attn, _ = self.var_attention(x)  # (batch, time, self.input_dim)
        x_time_attn, _ = self.time_attention(x)  # (batch, time, self.input_dim)

        # Combine attentions (element-wise product)
        x_combined = x_var_attn * x_time_attn

        # BiLSTM. Summarise with the final hidden state of each direction:
        # indexing the per-step output at the last time index would pair the
        # forward state (has seen the whole sequence) with a reverse state
        # that has processed only a single step.
        _, (h_n, _) = self.bilstm(x_combined)  # h_n: (2, batch, hidden_dim)
        summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # (batch, hidden_dim*2)

        # Output
        out = self.fc(summary)
        return self.sigmoid(out)
