In [2]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBetterLSTM(nn.Module):
    def __init__(self, input_size=6, hidden_size=128, num_layers=2, output_len=5, dropout=0.3, attn_heads=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_len = output_len
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_size)

        # Multi-head attention over LSTM outputs
        self.attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=attn_heads, batch_first=True)

        # Output FC for both mean and log-variance prediction (for uncertainty-aware loss)
        self.fc_mean = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_len * 2)
        )
        self.fc_logvar = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_len * 2)
        )

    def forward(self, x):
        lstm_out, _ = self.lstm(x)  # [B, T, H]
        lstm_out = self.norm(lstm_out)

        # Attention over time (query = last step, key/value = all steps)
        query = lstm_out[:, -1:, :]  # [B, 1, H]
        attn_out, _ = self.attn(query, lstm_out, lstm_out)  # [B, 1, H]
        attn_out = attn_out.squeeze(1)  # [B, H]

        mean = self.fc_mean(attn_out).view(-1, self.output_len, 2)
        logvar = self.fc_logvar(attn_out).view(-1, self.output_len, 2)
        return mean, logvar


In [3]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# --- Load trajectory data ---
with open("/workspace/hjs/python/lstm_train/car_trajectories.pkl", "rb") as f:
    data = pickle.load(f)

input_len, pred_len = 10, 5
X_seqs, Y_seqs = [], []

for track_id, points in data.items():
    coords = np.array([[p[1], p[2]] for p in points])
    if len(coords) < input_len + pred_len:
        continue
    for i in range(len(coords) - input_len - pred_len + 1):
        X_seqs.append(coords[i:i+input_len])
        Y_seqs.append(coords[i+input_len:i+input_len+pred_len])

X = np.array(X_seqs)
Y = np.array(Y_seqs)
def add_kinematics(X):
    """
    X: [B, T, 2] â†’ return [B, T, 6] with velocity and acceleration
    """
    velocity = np.diff(X, axis=1, prepend=X[:, :1])  # [B, T, 2]
    acceleration = np.diff(velocity, axis=1, prepend=velocity[:, :1])  # [B, T, 2]
    return np.concatenate([X, velocity, acceleration], axis=-1)  # [B, T, 6]


X = add_kinematics(X)  # Now shape: [B, T, 6]
print(f"Prepared dataset: X={X.shape}, Y={Y.shape}")
# --- Train/Val split ---
X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.2, random_state=42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- DataLoaders ---
BATCH_SIZE = 64
train_ds = TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train))
val_ds = TensorDataset(torch.Tensor(X_val), torch.Tensor(Y_val))
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

Prepared dataset: X=(609284, 10, 6), Y=(609284, 5, 2)
Using device: cuda


In [7]:
def physics_aware_loss(pred_mean, pred_logvar, target):
    # Gaussian NLL Loss
    precision = torch.exp(-pred_logvar)
    nll = (precision * (target - pred_mean) ** 2 + pred_logvar).mean()

    # Cosine Similarity Loss on velocity vectors
    pred_vecs = pred_mean[:, 1:] - pred_mean[:, :-1]
    target_vecs = target[:, 1:] - target[:, :-1]
    cosine_loss = 1 - F.cosine_similarity(pred_vecs.flatten(1), target_vecs.flatten(1), dim=1).mean()

    return nll + 0.1 * cosine_loss  # You can tune the weight
def hybrid_loss(pred_mean, pred_logvar, target):
    # NLL loss
    precision = torch.exp(-pred_logvar)
    nll = (precision * (target - pred_mean) ** 2 + pred_logvar).mean()

    # Directional cosine loss
    pred_vecs = pred_mean[:, 1:] - pred_mean[:, :-1]
    target_vecs = target[:, 1:] - target[:, :-1]
    cosine_loss = 1 - F.cosine_similarity(pred_vecs.flatten(1), target_vecs.flatten(1), dim=1).mean()

    return nll + 0.1 * cosine_loss

from sklearn.metrics import mean_absolute_error
import numpy as np
from tqdm import tqdm
import torch

def train_attention_lstm(model, train_dl, val_dl, device, optimizer, hybrid_loss, num_epochs=200, patience=10):
    best_val_loss = float('inf')
    early_stop_counter = 0

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0.0
        train_loader = tqdm(train_dl, desc=f"[Epoch {epoch}] Training", leave=False)
        
        for X_batch, Y_batch in train_loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
            mean, logvar = model(X_batch)
            loss = hybrid_loss(mean, logvar, Y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_dl)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        mae_list, ade_list, fde_list = [], [], []

        val_loader = tqdm(val_dl, desc=f"[Epoch {epoch}] Validation", leave=False)
        with torch.no_grad():
            for X_batch, Y_batch in val_loader:
                X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
                mean, logvar = model(X_batch)
                loss = hybrid_loss(mean, logvar, Y_batch)
                val_loss += loss.item()

                # --- Metrics ---
                pred_np = mean.cpu().numpy()
                true_np = Y_batch.cpu().numpy()

                mae = mean_absolute_error(true_np.reshape(-1, 2), pred_np.reshape(-1, 2))
                ade = np.mean(np.linalg.norm(pred_np - true_np, axis=-1))
                fde = np.mean(np.linalg.norm(pred_np[:, -1] - true_np[:, -1], axis=-1))

                mae_list.append(mae)
                ade_list.append(ade)
                fde_list.append(fde)

        avg_val_loss = val_loss / len(val_dl)
        avg_mae = np.mean(mae_list)
        avg_ade = np.mean(ade_list)
        avg_fde = np.mean(fde_list)

        print(f"Epoch {epoch:03} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
              f"MAE: {avg_mae:.4f} | ADE: {avg_ade:.4f} | FDE: {avg_fde:.4f}")

        # --- Early Stopping ---
        if avg_val_loss < best_val_loss - 1e-4:
            best_val_loss = avg_val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), "best_model_bettter.pt")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch}")
                break


In [8]:
# --- 1. Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 2. Model ---
model = AttentionBetterLSTM(input_size=6, hidden_size=128, num_layers=2, output_len=5, dropout=0.3).to(device)

# --- 3. Optimizer ---
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_attention_lstm(
    model=model,
    train_dl=train_dl,
    val_dl=val_dl,
    device=device,
    optimizer=optimizer,
    hybrid_loss=hybrid_loss,
    num_epochs=200,
    patience=10
)


                                                                          

Epoch 001 | Train Loss: -2.9204 | Val Loss: -4.0181 | MAE: 0.0638 | ADE: 0.1047 | FDE: 0.1278


                                                                          

Epoch 002 | Train Loss: -3.9885 | Val Loss: -4.9558 | MAE: 0.0483 | ADE: 0.0856 | FDE: 0.1068


                                                                          

Epoch 003 | Train Loss: -4.3408 | Val Loss: -4.9421 | MAE: 0.0424 | ADE: 0.0714 | FDE: 0.0981


                                                                          

Epoch 004 | Train Loss: -4.5538 | Val Loss: -1.4916 | MAE: 0.2382 | ADE: 0.3837 | FDE: 0.5016


                                                                          

Epoch 005 | Train Loss: -4.5301 | Val Loss: -3.4692 | MAE: 0.0980 | ADE: 0.1606 | FDE: 0.1781


                                                                          

Epoch 006 | Train Loss: -4.7104 | Val Loss: -5.3173 | MAE: 0.0420 | ADE: 0.0723 | FDE: 0.0894


                                                                          

Epoch 007 | Train Loss: -4.9131 | Val Loss: -4.9315 | MAE: 0.0524 | ADE: 0.0899 | FDE: 0.1081


                                                                          

Epoch 008 | Train Loss: -4.9320 | Val Loss: -4.7061 | MAE: 0.0598 | ADE: 0.0995 | FDE: 0.1151


                                                                          

Epoch 009 | Train Loss: -5.0134 | Val Loss: -4.9280 | MAE: 0.0467 | ADE: 0.0784 | FDE: 0.0983


                                                                           

Epoch 010 | Train Loss: -4.9960 | Val Loss: -4.4827 | MAE: 0.0577 | ADE: 0.0980 | FDE: 0.1249


                                                                           

Epoch 011 | Train Loss: -5.0946 | Val Loss: -5.4390 | MAE: 0.0374 | ADE: 0.0638 | FDE: 0.0836


                                                                           

Epoch 012 | Train Loss: -5.1614 | Val Loss: -2.9940 | MAE: 0.0856 | ADE: 0.1473 | FDE: 0.1675


                                                                           

Epoch 013 | Train Loss: -5.2137 | Val Loss: -4.3854 | MAE: 0.0628 | ADE: 0.1052 | FDE: 0.1182


                                                                           

Epoch 014 | Train Loss: -5.2510 | Val Loss: -5.6367 | MAE: 0.0448 | ADE: 0.0825 | FDE: 0.1044


                                                                           

Epoch 015 | Train Loss: -5.3661 | Val Loss: -5.7008 | MAE: 0.0342 | ADE: 0.0584 | FDE: 0.0751


                                                                           

Epoch 016 | Train Loss: -5.3577 | Val Loss: -5.3231 | MAE: 0.0423 | ADE: 0.0716 | FDE: 0.0898


                                                                           

Epoch 017 | Train Loss: -5.4185 | Val Loss: -5.3108 | MAE: 0.0508 | ADE: 0.0929 | FDE: 0.1058


                                                                           

Epoch 018 | Train Loss: -5.3975 | Val Loss: -5.7383 | MAE: 0.0334 | ADE: 0.0567 | FDE: 0.0776


                                                                           

Epoch 019 | Train Loss: -5.4710 | Val Loss: -5.0552 | MAE: 0.0400 | ADE: 0.0673 | FDE: 0.0829


                                                                           

Epoch 020 | Train Loss: -5.5296 | Val Loss: -5.1331 | MAE: 0.0359 | ADE: 0.0615 | FDE: 0.0824


                                                                           

Epoch 021 | Train Loss: -5.5756 | Val Loss: -4.0589 | MAE: 0.0457 | ADE: 0.0771 | FDE: 0.0963


                                                                           

Epoch 022 | Train Loss: -5.6038 | Val Loss: -2.5856 | MAE: 0.0400 | ADE: 0.0658 | FDE: 0.0842


                                                                           

Epoch 023 | Train Loss: -5.6375 | Val Loss: 1.2603 | MAE: 0.0481 | ADE: 0.0803 | FDE: 0.1003


                                                                           

Epoch 024 | Train Loss: -5.6219 | Val Loss: 9.0808 | MAE: 0.0521 | ADE: 0.0832 | FDE: 0.0957


                                                                           

Epoch 025 | Train Loss: -5.7556 | Val Loss: 4.6372 | MAE: 0.0945 | ADE: 0.1569 | FDE: 0.1592


                                                                           

Epoch 026 | Train Loss: -5.6727 | Val Loss: 5.7274 | MAE: 0.0627 | ADE: 0.0998 | FDE: 0.1155


                                                                           

Epoch 027 | Train Loss: -5.7487 | Val Loss: 10.7427 | MAE: 0.0525 | ADE: 0.0819 | FDE: 0.0991


                                                                           

Epoch 028 | Train Loss: -5.7738 | Val Loss: 24.5039 | MAE: 0.0657 | ADE: 0.1021 | FDE: 0.1129
Early stopping triggered at epoch 28




In [None]:
import os
import pickle
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# from kalman_lstm_temporal import KalmanLSTM_TemporalAttention  # your model class here

# ------------------ Utilities ------------------

def add_kinematics(X):
    vel = np.diff(X, axis=1, prepend=X[:, :1])
    acc = np.diff(vel, axis=1, prepend=vel[:, :1])
    return np.concatenate([X, vel, acc], axis=-1)

def compute_metrics(pred, target):
    ade = torch.mean(torch.norm(pred - target, dim=-1))
    fde = torch.mean(torch.norm(pred[:, -1] - target[:, -1], dim=-1))
    mse = F.mse_loss(pred, target)
    return {'ade': ade.item(), 'fde': fde.item(), 'mse': mse.item()}

def compute_loss(pred, target, lambda_phy=0.1, lambda_smooth=0.05, lambda_dir=0.1):
    mse = F.mse_loss(pred, target)

    # --- physics loss ---
    vel = pred[:, 1:] - pred[:, :-1]
    acc = vel[:, 1:] - vel[:, :-1]
    phy_loss = acc.pow(2).mean()

    # --- smoothness loss ---
    smooth_loss = (pred[:, 2:] - 2*pred[:, 1:-1] + pred[:, :-2]).pow(2).mean()

    # --- direction loss ---
    pred_dir = F.normalize(pred[:, 1:] - pred[:, :-1], dim=-1)
    target_dir = F.normalize(target[:, 1:] - target[:, :-1], dim=-1)
    dir_loss = (1 - (pred_dir * target_dir).sum(-1)).mean()

    total_loss = mse + lambda_phy * phy_loss + lambda_smooth * smooth_loss + lambda_dir * dir_loss
    return total_loss

# ------------------ Training ------------------

def train(model, train_dl, val_dl, config):
    model = model.to(config['device'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    best_val_loss = float('inf')

    for epoch in range(config['epochs']):
        model.train()
        total_loss = 0
        pbar = tqdm(train_dl, desc=f"Epoch {epoch+1}/{config['epochs']} [Train]")
        for xb, yb in pbar:
            xb, yb = xb.to(config['device']), yb.to(config['device'])

            optimizer.zero_grad()
            pred = model(xb)
            loss = compute_loss(pred, yb,
                                config['lambda_phy'],
                                config['lambda_smooth'],
                                config['lambda_dir'])
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_dl)

        # Validation
        model.eval()
        val_loss, val_ade, val_fde, val_mse = 0, 0, 0, 0
        with torch.no_grad():
            for xb, yb in tqdm(val_dl, desc=f"Epoch {epoch+1} [Val]", leave=False):
                xb, yb = xb.to(config['device']), yb.to(config['device'])
                pred = model(xb)
                loss = compute_loss(pred, yb,
                                    config['lambda_phy'],
                                    config['lambda_smooth'],
                                    config['lambda_dir'])
                val_loss += loss.item()
                metrics = compute_metrics(pred, yb)
                val_ade += metrics['ade']
                val_fde += metrics['fde']
                val_mse += metrics['mse']

        val_loss /= len(val_dl)
        val_ade /= len(val_dl)
        val_fde /= len(val_dl)
        val_mse /= len(val_dl)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_temporal_model.pt")

        print(f"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | ADE: {val_ade:.4f} | FDE: {val_fde:.4f} | MSE: {val_mse:.4f}")

# ------------------ Main ------------------

if __name__ == "__main__":
    config = {
        'data_path': "/workspace/hjs/python/lstm_train/car_trajectories.pkl",
        'input_len': 10,
        'pred_len': 5,
        'batch_size': 64,
        'epochs': 50,
        'lr': 1e-3,
        'lambda_phy': 0.1,
        'lambda_smooth': 0.05,
        'lambda_dir': 0.1,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }

    with open(config['data_path'], "rb") as f:
        data = pickle.load(f)

    X_seqs, Y_seqs = [], []
    for points in data.values():
        coords = np.array([[p[1], p[2]] for p in points])
        if len(coords) < config['input_len'] + config['pred_len']:
            continue
        for i in range(len(coords) - config['input_len'] - config['pred_len'] + 1):
            X_seqs.append(coords[i:i+config['input_len']])
            Y_seqs.append(coords[i+config['input_len']:i+config['input_len']+config['pred_len']])

    X = add_kinematics(np.array(X_seqs))
    Y = np.array(Y_seqs)

    X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.2, random_state=42)
    train_dl = DataLoader(TensorDataset(torch.Tensor(X_train), torch.Tensor(Y_train)), batch_size=config['batch_size'], shuffle=True)
    val_dl = DataLoader(TensorDataset(torch.Tensor(X_val), torch.Tensor(Y_val)), batch_size=config['batch_size'])

    model = KalmanLSTM_TemporalAttention(input_size=X.shape[-1], output_len=config['pred_len'])
    print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.")
    train(model, train_dl, val_dl, config)


Model has 283,531 trainable parameters.


Epoch 1/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:26<00:00, 282.23it/s, loss=0.104]
                                                                   

Epoch 01 | Train Loss: 0.4714 | Val Loss: 0.0744 | ADE: 0.1841 | FDE: 0.2029 | MSE: 0.0306


Epoch 2/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:26<00:00, 291.36it/s, loss=0.138]
                                                                   

Epoch 02 | Train Loss: 0.2512 | Val Loss: 0.0977 | ADE: 0.2766 | FDE: 0.2920 | MSE: 0.0507


Epoch 3/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:25<00:00, 296.97it/s, loss=0.0454]
                                                                   

Epoch 03 | Train Loss: 0.2220 | Val Loss: 0.0673 | ADE: 0.1851 | FDE: 0.2330 | MSE: 0.0269


Epoch 4/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:25<00:00, 297.38it/s, loss=0.159] 
                                                                   

Epoch 04 | Train Loss: 0.2031 | Val Loss: 0.0692 | ADE: 0.1654 | FDE: 0.1752 | MSE: 0.0242


Epoch 5/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:26<00:00, 290.47it/s, loss=0.235]
                                                                   

Epoch 05 | Train Loss: 0.1879 | Val Loss: 0.0632 | ADE: 0.1395 | FDE: 0.1806 | MSE: 0.0178


Epoch 6/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:25<00:00, 295.55it/s, loss=0.425] 
                                                                   

Epoch 06 | Train Loss: 0.1752 | Val Loss: 0.0928 | ADE: 0.2865 | FDE: 0.2763 | MSE: 0.0508


Epoch 7/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:25<00:00, 296.46it/s, loss=0.162] 
                                                                   

Epoch 07 | Train Loss: 0.1636 | Val Loss: 0.0633 | ADE: 0.1518 | FDE: 0.1831 | MSE: 0.0218


Epoch 8/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:26<00:00, 292.47it/s, loss=0.287] 
                                                                   

Epoch 08 | Train Loss: 0.1548 | Val Loss: 0.0586 | ADE: 0.1522 | FDE: 0.1763 | MSE: 0.0200


Epoch 9/50 [Train]: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7617/7617 [00:25<00:00, 300.11it/s, loss=0.118] 
                                                                   

Epoch 09 | Train Loss: 0.1469 | Val Loss: 0.0660 | ADE: 0.1923 | FDE: 0.2262 | MSE: 0.0279


Epoch 10/50 [Train]:  28%|â–ˆâ–ˆâ–Š       | 2138/7617 [00:07<00:18, 295.76it/s, loss=0.105] 

In [14]:
import torch
import pickle
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
# from model import OptimizedKalmanLSTM  # Make sure your model class is imported
import torch.nn.functional as F
from tqdm import tqdm

# --- Config ---
config = {
    'data_path': "/workspace/hjs/python/lstm_train/car_trajectories.pkl",
    'input_len': 10,
    'pred_len': 5,
    'batch_size': 64,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'model_path': 'best_model.pt'
}

# --- Utility Functions ---
def add_kinematics(X):
    vel = np.diff(X, axis=1, prepend=X[:, :1])
    acc = np.diff(vel, axis=1, prepend=vel[:, :1])
    return np.concatenate([X, vel, acc], axis=-1)

def compute_metrics(pred, target):
    ade = torch.mean(torch.norm(pred - target, dim=-1))
    fde = torch.mean(torch.norm(pred[:, -1] - target[:, -1], dim=-1))
    mse = F.mse_loss(pred, target)
    return {'ade': ade.item(), 'fde': fde.item(), 'mse': mse.item()}

# --- Load Data ---
with open(config['data_path'], "rb") as f:
    data = pickle.load(f)

X_seqs, Y_seqs = [], []
for _, points in data.items():
    coords = np.array([[p[1], p[2]] for p in points])
    if len(coords) < config['input_len'] + config['pred_len']:
        continue
    for i in range(len(coords) - config['input_len'] - config['pred_len'] + 1):
        X_seqs.append(coords[i:i+config['input_len']])
        Y_seqs.append(coords[i+config['input_len']:i+config['input_len']+config['pred_len']])

X = add_kinematics(np.array(X_seqs))
Y = np.array(Y_seqs)

_, X_test, _, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
test_dl = DataLoader(TensorDataset(torch.Tensor(X_test), torch.Tensor(Y_test)), batch_size=config['batch_size'])

# --- Load Model ---
model = OptimizedKalmanLSTM(input_size=X.shape[-1], output_len=config['pred_len'])
model.load_state_dict(torch.load(config['model_path'], map_location=config['device']))
model = model.to(config['device'])
model.eval()

# --- Evaluate ---
total_ade, total_fde, total_mse = 0.0, 0.0, 0.0
all_preds, all_targets = [], []

with torch.no_grad():
    for xb, yb in tqdm(test_dl, desc="Evaluating"):
        xb, yb = xb.to(config['device']), yb.to(config['device'])
        pred = model(xb)
        metrics = compute_metrics(pred, yb)
        total_ade += metrics['ade']
        total_fde += metrics['fde']
        total_mse += metrics['mse']
        all_preds.append(pred.cpu().numpy())
        all_targets.append(yb.cpu().numpy())

# --- Results ---
num_batches = len(test_dl)
ade = total_ade / num_batches
fde = total_fde / num_batches
mse = total_mse / num_batches

print("\nðŸ“Š --- Final Evaluation Report ---")
print(f"Average Displacement Error (ADE): {ade:.4f}")
print(f"Final Displacement Error (FDE):   {fde:.4f}")
print(f"Mean Squared Error (MSE):         {mse:.4f}")

# # Optional: Save predictions for later visualization or analysis
# np.savez("eval_outputs.npz",
#          preds=np.concatenate(all_preds, axis=0),
#          targets=np.concatenate(all_targets, axis=0))


Evaluating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1905/1905 [00:01<00:00, 1029.87it/s]


ðŸ“Š --- Final Evaluation Report ---
Average Displacement Error (ADE): 0.0574
Final Displacement Error (FDE):   0.0773
Mean Squared Error (MSE):         0.0086



