In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TemporalAttention(nn.Module):
    """Attention pooling that collapses the time axis of a sequence.

    A single linear layer scores every time step; the scores are
    softmax-normalised across time and used as the weights of a weighted
    sum over the input features.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # One scalar score per time step, computed from its feature vector.
        self.attn = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, x):
        """Pool ``x`` of shape [B, T, H] over its time dimension.

        Returns:
            A ``(context, weights)`` pair. ``context`` is the weighted sum
            over T with shape [B, H]; ``weights`` holds the softmax weights
            with shape [B, T, 1], which sum to 1 along T.
        """
        scores = self.attn(x)              # [B, T, 1]
        weights = scores.softmax(dim=1)    # normalise across time steps
        weighted = x * weights             # broadcast over H -> [B, T, H]
        context = weighted.sum(dim=1)      # [B, H]
        return context, weights


class KalmanLSTM(nn.Module):
    """LSTM sequence encoder with attention pooling and an MLP prediction head.

    Pipeline: stacked LSTM -> ``TemporalAttention`` pooling over time ->
    LayerNorm -> two-layer MLP (GELU in between) -> reshape into
    ``output_len`` steps of ``output_dim`` values each.

    Args:
        input_size: Number of features per input time step.
        hidden_size: LSTM hidden size; also the width of the pooled context
            and of the MLP's hidden layer.
        num_layers: Number of stacked LSTM layers.
        output_len: Number of output steps produced per sequence.
        dropout: Dropout probability applied between stacked LSTM layers.
            Ignored when ``num_layers`` is 1, because there is no
            inter-layer connection to apply it to.
        output_dim: Number of values predicted per output step.
    """

    def __init__(self, input_size: int = 6, hidden_size: int = 128,
                 num_layers: int = 2, output_len: int = 5,
                 dropout: float = 0.3, output_dim: int = 2):
        super().__init__()
        self.output_len = output_len
        self.output_dim = output_dim
        # nn.LSTM applies dropout only between stacked layers and emits a
        # UserWarning if a non-zero value is passed with a single layer,
        # where it has no effect. Pass 0 in that case to stay quiet.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)
        self.attn = TemporalAttention(hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_len * output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of sequences to per-step predictions.

        Args:
            x: Input of shape [B, T, input_size] (the LSTM is batch-first).

        Returns:
            Tensor of shape [B, output_len, output_dim].
        """
        lstm_out, _ = self.lstm(x)          # [B, T, H]
        context, _ = self.attn(lstm_out)    # [B, H]; attention weights unused
        context = self.norm(context)
        pred = self.fc(context)             # [B, output_len * output_dim]
        # Use the explicit batch size rather than -1 so the target shape is
        # fully specified.
        return pred.view(pred.size(0), self.output_len, self.output_dim)
