In [1]:
import os
import numpy as np
import pandas as pd
import random
import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer, TransformerDecoder
import matplotlib.pyplot as plt
from scipy.signal import cont2discrete, lti, dlti, dstep
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
import math
import scipy.integrate
import random

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=201):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

class KalmanFormer(nn.Module):
    def __init__(self, f, h, m, n, hidden_dim=64, num_heads=2, num_layers=2, dropout=0.1):
        super(KalmanFormer, self).__init__()

        self.f = f
        self.h = h
        self.m = m
        self.n = n

        feature_dim_enc = 2 * n
        feature_dim_dec = 2 * m

        self.encoder_embedding = nn.Linear(feature_dim_enc, hidden_dim)
        self.decoder_embedding = nn.Linear(feature_dim_dec, hidden_dim)

        self.pos_encoder = PositionalEncoding(hidden_dim)
        self.pos_decoder = PositionalEncoding(hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=64,
            dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=64,
            dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.kalman_gain_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.Tanh(),
            nn.Linear(64, m * n)
        )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def InitSequence(self, M1_0, T):
        self.T = T
        self.m1x_posterior = M1_0.to(self.device)
        self.m1x_posterior_previous = self.m1x_posterior
        self.m1x_prior_previous = self.m1x_posterior
        self.y_previous = self.h(self.m1x_posterior)

    def step_prior(self):
        self.m1x_prior = self.f(self.m1x_posterior)
        self.m1y = self.h(self.m1x_prior)

    def compute_kalmanformer_inputs(self, x_post, x_post_prev, x_prior_prev, y_prior, y_previous, y):
        F1 = (y - y_previous).squeeze(-1)
        F2 = (y - y_prior).squeeze(-1)
        F3 = (x_post - x_post_prev).squeeze(-1)
        F4 = (x_post - x_prior_prev).squeeze(-1)

        
        F_encoder = torch.cat([F1, F2], dim=-1)  
        F_decoder = torch.cat([F3, F4], dim=-1) 

        F_encoder = F.normalize(F_encoder, p=2, dim=1)
        F_decoder = F.normalize(F_decoder, p=2, dim=1)

        F_encoder = F_encoder.unsqueeze(1)
        F_decoder = F_decoder.unsqueeze(1)


        return F_encoder, F_decoder

    def Kalman_step(self, y_t):
        self.step_prior()

        F_encoder, F_decoder = self.compute_kalmanformer_inputs(
            self.m1x_posterior, self.m1x_posterior_previous,
            self.m1x_prior_previous, self.m1y, self.y_previous, y_t
        )

        enc = self.encoder_embedding(F_encoder)
        enc = self.pos_encoder(enc)
        memory = self.encoder(enc)

        dec = self.decoder_embedding(F_decoder)
        dec = self.pos_decoder(dec)
        decoded = self.decoder(dec, memory)

        kalman_gain = self.kalman_gain_head(decoded).squeeze(1)
        kalman_gain = kalman_gain.view(-1, self.m, self.n)

        dy = y_t - self.m1y
        INOV = torch.bmm(kalman_gain, dy)

        self.m1x_posterior_previous = self.m1x_posterior
        self.m1x_posterior = self.m1x_prior + INOV
        self.m1x_prior_previous = self.m1x_prior
        self.y_previous = y_t

        return self.m1x_posterior

    def forward(self, y_sequence):
        batch_size, seq_len, _, _ = y_sequence.shape
        preds = []
        for t in range(seq_len):
            y_t = y_sequence[:, t, :, :]
            x_hat = self.Kalman_step(y_t)
            preds.append(x_hat.unsqueeze(1))

        return torch.cat(preds, dim=1)

In [3]:
def data_generation(num_sequences, sequence_length, number_masses):
    dim_y = number_masses
    dim_x= 2*dim_y 

    X_data_array = np.empty((num_sequences, sequence_length, dim_x))
    Y_data_array = np.empty((num_sequences, sequence_length, dim_y))

    m = np.ones(dim_y)
    m = 10*m
   
    k = np.ones(dim_y)
    k = 800*k

    d = np.ones(dim_y)
    d = 6*d

    A_c = np.zeros((dim_x,dim_x))

    offset = 0
    for i in range(dim_x):
        if i % 2 == 0:
            A_c[i,i+1] = 1

        if i % 2 == 1:
            if i != dim_x-1:
                A_c[i,i-1] = -(k[i-1-offset]+k[i-offset])/m[i-1-offset]
                A_c[i,i] = -(d[i-1-offset]+d[i-offset])/m[i-1-offset]
                A_c[i,i+1] = k[i-offset]/m[i-1-offset]
                A_c[i,i+2] = d[i-offset]/m[i-1-offset]
            else:
                A_c[i,i-1] = -k[i-dim_y]/m[i-dim_y]
                A_c[i,i] = -d[i-dim_y]/m[i-dim_y]

            if i != 1:
                A_c[i,i-3] = k[i-1-offset]/m[i-1-offset]
                A_c[i,i-2] = d[i-1-offset]/m[i-1-offset]

            offset += 1

    B_c = np.zeros((dim_x,dim_y))

    H_c = np.zeros((dim_y,dim_x))
    offset = 0
    for i in range(dim_y):
        H_c[i,i+offset] = 1

        offset += 1

    D_c = np.array([[0.]])

   
    dt = 0.1 
    d_system = cont2discrete((A_c, B_c, H_c, D_c),dt)
    A = d_system[0] 
    H = d_system[2] 

    def is_schur(matrix):
       
        eigenvalues, _ = np.linalg.eig(matrix)
        if np.all(np.abs(eigenvalues) < 1):
            print(np.abs(eigenvalues))
            return True
        else:
            return False

    # if is_schur(A):
    #     print("The matrix is Schur.")
    # else:
    #     print("The matrix is not Schur.")


    sigma_p = 0.01 
    sigma_p_diag = (sigma_p**2)*np.ones(dim_x)
    Q = np.diag(sigma_p_diag)

    sigma_m = 0.01 
    sigma_m_diag = (sigma_m**2)*np.ones(dim_y)
    R = np.diag(sigma_m_diag)

    sigma_x = 0.01 
    sigma_x_diag = (sigma_p**2)*np.ones(dim_x)
    P = np.diag(sigma_x_diag)


    for s in range(num_sequences):
        mu_x0 = np.random.uniform(-10,10,size=dim_x) 
        x=np.random.multivariate_normal(mu_x0,P)
       
        X_data_array[s,0,:] = np.squeeze(np.asarray(x))

       
        v_0 = np.random.multivariate_normal(np.zeros(dim_y),R).reshape(-1,1)

        y = H.dot(x.reshape(-1,1)) + v_0


        W = np.random.multivariate_normal(np.zeros(dim_x), Q, sequence_length)
        V = np.random.multivariate_normal(np.zeros(dim_y), R, sequence_length)
        Y_data_array[s,0,:] = np.squeeze(np.asarray(y))

        for t in range(1,sequence_length+1):
            w = W[t-1].reshape(-1,1) 
            v = V[t-1].reshape(-1,1)
        
            x = A.dot(x.reshape(-1,1)) + w 
            y = H.dot(x.reshape(-1,1)) + v 

            X_data_array[s,t:t+1,:] = x.T
            Y_data_array[s,t:t+1,:] = y.T

    return X_data_array, Y_data_array, A, H

print(data_generation(200, 100, 2)[0].shape)

(200, 100, 4)


In [4]:
def data_loaders(X_np, Y_np, batch_size, train_ratio, val_ratio):
   
  
    X = torch.tensor(X_np, dtype=torch.float32).unsqueeze(-1)  
    Y = torch.tensor(Y_np, dtype=torch.float32).unsqueeze(-1)  

    dataset = TensorDataset(Y, X)

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    test_size = total_size - train_size - val_size

    train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])

    train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_data = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_data = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    return train_data, val_data, test_data

    

In [5]:
def train_kalmanformer(model, data_loader, optimizer, epochs):
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)
    model.train()
    

    for epoch in range(epochs):
        total_loss = 0
        for y_seq, x_true in data_loader:
            _, T, n, _ = y_seq.shape

            x0 = torch.zeros_like(x_true[:, 0, :, :])
            model.InitSequence(x0, T)

#             g, i, u = model(y_seq[:, 1:], return_gain=True)
#      
#             K = torch.stack(g, dim=1)          
#             dz = torch.stack(i, dim=1)    
#             dx = torch.stack(u, dim=1)       
#             pred_dx = torch.matmul(K, dz)           
           
#             loss = F.mse_loss(pred_dx, dx)    
       
            preds = model(y_seq[:, 1:])  

            loss = F.mse_loss(preds, x_true[:, 1:, :, :]) 

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        
        scheduler.step()
        print(f"Epoch {epoch + 1}: Loss = {total_loss:.6f}")


In [None]:

X_np, Y_np, A, H = data_generation(num_sequences=200, sequence_length = 500, number_masses = 1)

A = torch.tensor(A, dtype=torch.float32)
H = torch.tensor(H, dtype=torch.float32)

def f(x): return A @ x
def h(x): return H @ x
    
import time
start = time.time(
)
model = KalmanFormer(f=f, h=h, m=A.shape[0], n=H.shape[0], hidden_dim=64, num_heads=4, num_layers=2)

train_data, val_data, test_data = data_loaders(X_np, Y_np, batch_size=60, train_ratio =0.8, val_ratio =0.1)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay = 0.001)
train_kalmanformer(model, train_data, optimizer, epochs = 20)
end = time.time()

print(end-start)

torch.Size([60, 500, 2, 1])
torch.Size([60, 500, 2, 1])
torch.Size([40, 500, 2, 1])
Epoch 1: Loss = 32.454988
torch.Size([60, 500, 2, 1])
torch.Size([60, 500, 2, 1])
torch.Size([40, 500, 2, 1])
Epoch 2: Loss = 30.094148
torch.Size([60, 500, 2, 1])
torch.Size([60, 500, 2, 1])
torch.Size([40, 500, 2, 1])
Epoch 3: Loss = 28.035199
torch.Size([60, 500, 2, 1])
torch.Size([60, 500, 2, 1])
torch.Size([40, 500, 2, 1])
Epoch 4: Loss = 26.107437
torch.Size([60, 500, 2, 1])
torch.Size([60, 500, 2, 1])


In [None]:
def evaluate_kalmanformer_true_mse(model, data_loader):
    model.eval()
    total_loss = 0.0
    total_seqs = 0

    for batch in data_loader:
        y_seq, x_true = batch  
        b, T, m, _ = x_true.shape
        x0 = torch.zeros_like(x_true[:, 0, :, :])  
#         x0 = x_true[:, 0, :, :]  
        model.InitSequence(x0, T)

        x_pred = model(y_seq[:, 1:])   
    

        x_gt = x_true[:, 1:, :, :]           
#         print(x_pred[0] - x_gt[0])
        loss = F.mse_loss(x_pred, x_gt)  
        total_loss += loss.item()
#         total_seqs += b  


    final_mse = total_loss 
    print(f"[MSE] = {final_mse:.6f}")
    return final_mse

In [None]:
evaluate_kalmanformer_true_mse(model, test_data)


In [None]:
def plot_pred_vs_true_simple(model, batch, title_prefix="Train"):
    model.eval()
    y_seq, x_true = batch
    x0 = torch.zeros_like(x_true[:, 0, :, :])  
#     x0 = x_true[:, 0, :, :]
    T = y_seq.shape[1]

    model.InitSequence(x0, T)
    
    with torch.no_grad():
        x_pred = model(y_seq[:, 1:]) 

    x_true = x_true[:, 1:, :, :]     

    idx = 5
    pred_np = x_pred[idx].squeeze(-1).cpu().numpy()  
    true_np = x_true[idx].squeeze(-1).cpu().numpy()  

    time = np.arange(pred_np.shape[0])
    for state in range(pred_np.shape[1]-1):
        plt.plot(time, true_np[:, state], label=f'True State {idx}')
        plt.plot(time, pred_np[:, state], '--', label=f'Pred State {idx}')

    plt.title(f'{title_prefix} Trajectory')
    plt.xlabel('Time')
    plt.ylabel('State')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


In [None]:

# plot_pred_vs_true_simple(model, next(iter(train_data)), title_prefix="Train")
plot_pred_vs_true_simple(model, next(iter(test_data)), title_prefix="Test")