# In [None]:
import torch
import torch.nn as nn

class KalmanAttention(nn.Module):
    """Additive attention pooling over a sequence of hidden states.

    Each time step ``h_t`` is scored as ``vt(tanh(Watt(c) + Uatt(h_t)))``.
    The scores are softmax-normalised along the sequence axis and used to
    form a weighted sum of the hidden states.

    NOTE(review): the context ``c`` passed to ``Watt`` is an all-zeros
    tensor that is never updated, so ``Watt`` effectively contributes only
    its bias. If a running context was intended, this needs a deliberate
    (behaviour-changing) follow-up; the current outputs are preserved here.

    Args:
        hidden_dim: Size of the last axis of the hidden states; also the
            width of the internal projections.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.Watt = nn.Linear(hidden_dim, hidden_dim)
        self.Uatt = nn.Linear(hidden_dim, hidden_dim)
        self.vt = nn.Linear(hidden_dim, 1)

    def forward(self, hidden_states: torch.Tensor):
        """Pool ``hidden_states`` into one context vector per batch element.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            A tuple ``(context_vector, att_weights)`` where
            ``context_vector`` has shape ``(batch, hidden_dim)`` and
            ``att_weights`` has shape ``(batch, seq_len)`` and sums to 1
            along the sequence axis.
        """
        batch_size, _, hidden_dim = hidden_states.size()

        # new_zeros inherits both device and dtype from the input, so the
        # module also works after .double() / .half() casts.
        ct_prev = hidden_states.new_zeros(batch_size, hidden_dim)

        # The context is identical for every time step, so project it once
        # and broadcast it over the sequence axis instead of looping.
        context_term = self.Watt(ct_prev).unsqueeze(1)
        energies = torch.tanh(context_term + self.Uatt(hidden_states))

        att_scores = self.vt(energies).squeeze(-1)
        att_weights = torch.softmax(att_scores, dim=1)

        # (batch, 1, seq_len) @ (batch, seq_len, hidden_dim) -> (batch, 1, hidden_dim)
        context_vector = torch.bmm(att_weights.unsqueeze(1), hidden_states).squeeze(1)

        return context_vector, att_weights
