In [9]:
import os 
# Set CUDA_VISIBLE_DEVICES to "1" here, before torch is imported in the next cell,
# so that only that device index is exposed to CUDA when the device is chosen later.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [10]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# --- Load trajectory data ---
# `data` is iterated below as (track_id, points) pairs; elements 1 and 2 of every
# point are taken as the two coordinates of that point.
with open("/workspace/hjs/python/lstm_train/car_trajectories.pkl", "rb") as f:
    data = pickle.load(f)

input_len, pred_len = 10, 5
X_seqs, Y_seqs = [], []

# One training sample = `input_len` observed steps followed by `pred_len` target steps.
window_len = input_len + pred_len

for track_id, points in data.items():
    coords = np.array([[p[1], p[2]] for p in points])
    # Tracks shorter than one full window give a non-positive count, so the
    # inner range is empty and the track contributes nothing.
    n_windows = len(coords) - window_len + 1
    for start in range(n_windows):
        split = start + input_len
        X_seqs.append(coords[start:split])
        Y_seqs.append(coords[split:split + pred_len])

X = np.array(X_seqs)
Y = np.array(Y_seqs)
def add_kinematics(X):
    """
    Append first- and second-order backward differences along the time axis.

    X: [B, T, 2] → returns [B, T, 6], laid out as
    (original features, velocity, acceleration) on the last axis.
    The first time step of both derived quantities is zero because each
    series is padded with a copy of its own first step before differencing.
    """
    def _backward_delta(series):
        # Repeat the first step in front so the output keeps length T.
        padded = np.concatenate((series[:, :1], series), axis=1)
        return padded[:, 1:] - padded[:, :-1]

    vel = _backward_delta(X)    # [B, T, 2]
    acc = _backward_delta(vel)  # [B, T, 2]
    return np.concatenate((X, vel, acc), axis=-1)  # [B, T, 6]


X = add_kinematics(X)  # Now shape: [B, T, 6]
print(f"Prepared dataset: X={X.shape}, Y={Y.shape}")

# --- Train/Val split ---
# NOTE(review): the split is done per window, not per track. The windows above are
# cut with stride 1, so overlapping windows from the same track can end up in both
# the training and the validation set. Consider a group-wise split by track id.
SEED = 42
X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.2, random_state=SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- DataLoaders ---
BATCH_SIZE = 64


def _to_float_tensor(array):
    """Convert a numpy array to a float32 tensor with an explicit dtype."""
    return torch.as_tensor(array, dtype=torch.float32)


train_ds = TensorDataset(_to_float_tensor(X_train), _to_float_tensor(Y_train))
val_ds = TensorDataset(_to_float_tensor(X_val), _to_float_tensor(Y_val))

# A dedicated seeded generator makes the training shuffle order repeatable across runs.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

Prepared dataset: X=(609284, 10, 6), Y=(609284, 5, 2)
Using device: cuda


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BetterAttentionLSTM(nn.Module):
    """Bidirectional LSTM encoder with attention pooling and an MLP prediction head.

    The input sequence is encoded by a stacked bidirectional LSTM. The encoding of
    the last time step is used as the single attention query over all time steps;
    the attended vector is added back to that query (residual connection) and a
    two-layer MLP maps it to ``output_len`` two-dimensional points.

    Args:
        input_size: Number of features per time step.
        hidden_size: LSTM hidden width per direction.
        num_layers: Number of stacked LSTM layers.
        output_len: Number of predicted steps.
        dropout: Dropout applied between stacked LSTM layers. Ignored (set to 0)
            when ``num_layers == 1``, where there is no inter-layer connection
            and PyTorch would otherwise emit a warning.
        num_heads: Number of attention heads; must divide ``hidden_size * 2``.
    """

    def __init__(self, input_size=6, hidden_size=128, num_layers=2, output_len=5,
                 dropout=0.3, num_heads=4):
        super().__init__()
        self.output_len = output_len
        self.hidden_size = hidden_size
        self.bi = 2  # Bidirectional: forward and backward features are concatenated

        # Inter-layer dropout only makes sense with more than one layer.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            dropout=lstm_dropout, batch_first=True, bidirectional=True)

        self.attention = nn.MultiheadAttention(embed_dim=hidden_size * self.bi,
                                               num_heads=num_heads, batch_first=True)

        self.fc = nn.Sequential(
            nn.Linear(hidden_size * self.bi, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_len * 2)
        )

    def forward(self, x):
        """Predict future points from an input sequence.

        Args:
            x: Tensor of shape [B, T, input_size].

        Returns:
            Tensor of shape [B, output_len, 2].
        """
        out, _ = self.lstm(x)  # out: [B, T, hidden*2]

        # Attention: query=last step, keys/values=entire sequence
        query = out[:, -1:, :]  # [B, 1, H*2]
        attn_out, _ = self.attention(query, out, out)  # [B, 1, H*2]

        # Residual connection around the attention layer
        attn_out = attn_out + query  # [B, 1, H*2]

        pooled = attn_out.squeeze(1)  # [B, H*2]
        pred = self.fc(pooled).view(-1, self.output_len, 2)  # [B, output_len, 2]
        return pred
def hybrid_loss(pred, target, cos_weight=0.1, smooth_weight=0.1):
    """Position MSE plus direction-agreement and smoothness penalties.

    Args:
        pred: Predicted points, shape [B, T, 2].
        target: Ground-truth points, same shape as ``pred``.
        cos_weight: Weight of the direction (cosine) term.
        smooth_weight: Weight of the smoothness term.

    Returns:
        Scalar tensor ``mse + cos_weight * cos_loss + smooth_weight * smoothness_loss``.
        The direction term needs at least 2 steps and the smoothness term at
        least 3; for shorter horizons the affected term is taken as 0 instead of
        becoming NaN from the mean of an empty tensor.
    """
    mse = F.mse_loss(pred, target)

    horizon = pred.shape[1]
    cos_loss = pred.new_zeros(())
    smoothness_loss = pred.new_zeros(())

    if horizon >= 2:
        # Cosine similarity between predicted and target step-to-step displacements
        pred_vel = pred[:, 1:] - pred[:, :-1]
        true_vel = target[:, 1:] - target[:, :-1]
        cos_sim = F.cosine_similarity(pred_vel, true_vel, dim=-1)
        cos_loss = 1 - cos_sim.mean()

    if horizon >= 3:
        # Smoothness penalty: mean absolute second difference of the prediction
        pred_acc = pred[:, 2:] - 2 * pred[:, 1:-1] + pred[:, :-2]
        smoothness_loss = pred_acc.abs().mean()

    return mse + cos_weight * cos_loss + smooth_weight * smoothness_loss


In [17]:
from sklearn.metrics import mean_absolute_error

# --- Setup ---
# Build the network, move it to the selected device, then attach an Adam
# optimizer over all of its parameters.
model = BetterAttentionLSTM(input_size=6)
model = model.to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)


In [18]:

num_epochs = 200
patience = 10
best_val_loss = float('inf')
early_stop_counter = 0

# All running statistics below are accumulated per sample rather than per batch,
# so the final (possibly much smaller) batch is not over-weighted in the averages.
for epoch in range(1, num_epochs + 1):
    # --- Training ---
    model.train()
    train_loss_sum, train_seen = 0.0, 0
    train_loader = tqdm(train_dl, desc=f"[Epoch {epoch}] Training", leave=False)

    for X_batch, Y_batch in train_loader:
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        pred = model(X_batch)
        loss = hybrid_loss(pred, Y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_n = X_batch.size(0)
        train_loss_sum += loss.item() * batch_n
        train_seen += batch_n

    avg_train_loss = train_loss_sum / train_seen

    # --- Validation ---
    model.eval()
    val_loss_sum = mae_sum = ade_sum = fde_sum = 0.0
    val_seen = 0

    val_loader = tqdm(val_dl, desc=f"[Epoch {epoch}] Validation", leave=False)
    with torch.no_grad():
        for X_batch, Y_batch in val_loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
            pred = model(X_batch)
            batch_n = X_batch.size(0)

            # hybrid_loss is a batch mean; scale back to a per-sample sum.
            val_loss_sum += hybrid_loss(pred, Y_batch).item() * batch_n

            # --- Metrics (computed in torch, summed over samples) ---
            err = pred - Y_batch                    # [B, T, 2]
            dist = torch.linalg.norm(err, dim=-1)   # [B, T] L2 error per point

            # MAE: mean absolute coordinate error of each sample
            mae_sum += err.abs().mean(dim=(1, 2)).sum().item()
            # ADE: mean L2 error over the timesteps of each sample
            ade_sum += dist.mean(dim=1).sum().item()
            # FDE: L2 error at the last timestep of each sample
            fde_sum += dist[:, -1].sum().item()

            val_seen += batch_n

    avg_val_loss = val_loss_sum / val_seen
    avg_mae = mae_sum / val_seen
    avg_ade = ade_sum / val_seen
    avg_fde = fde_sum / val_seen

    print(f"Epoch {epoch:03} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
          f"MAE: {avg_mae:.4f} | ADE: {avg_ade:.4f} | FDE: {avg_fde:.4f}")

    # --- Early stopping ---
    # An epoch counts as an improvement only if it beats the best validation loss
    # by more than 1e-4; the best weights are checkpointed on every improvement.
    if avg_val_loss < best_val_loss - 1e-4:
        best_val_loss = avg_val_loss
        early_stop_counter = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch}")
            break


                                                                          

Epoch 001 | Train Loss: 0.1631 | Val Loss: 0.0798 | MAE: 0.1058 | ADE: 0.1700 | FDE: 0.1924


                                                                          

Epoch 002 | Train Loss: 0.0802 | Val Loss: 0.0657 | MAE: 0.0821 | ADE: 0.1316 | FDE: 0.1951


                                                                          

Epoch 003 | Train Loss: 0.0675 | Val Loss: 0.0569 | MAE: 0.0647 | ADE: 0.1056 | FDE: 0.1568


                                                                          

Epoch 004 | Train Loss: 0.0659 | Val Loss: 0.0617 | MAE: 0.0802 | ADE: 0.1304 | FDE: 0.1427


                                                                          

Epoch 005 | Train Loss: 0.0674 | Val Loss: 0.0579 | MAE: 0.0691 | ADE: 0.1192 | FDE: 0.1447


                                                                          

Epoch 006 | Train Loss: 0.0639 | Val Loss: 0.0536 | MAE: 0.0558 | ADE: 0.0946 | FDE: 0.1594


                                                                          

Epoch 007 | Train Loss: 0.0599 | Val Loss: 0.0539 | MAE: 0.0610 | ADE: 0.1007 | FDE: 0.1262


                                                                          

Epoch 008 | Train Loss: 0.0615 | Val Loss: 0.0619 | MAE: 0.0729 | ADE: 0.1211 | FDE: 0.1491


                                                                          

Epoch 009 | Train Loss: 0.0584 | Val Loss: 0.0501 | MAE: 0.0528 | ADE: 0.0876 | FDE: 0.1316


                                                                           

Epoch 010 | Train Loss: 0.0588 | Val Loss: 0.0554 | MAE: 0.0564 | ADE: 0.0925 | FDE: 0.1156


                                                                           

Epoch 011 | Train Loss: 0.0576 | Val Loss: 0.0533 | MAE: 0.0602 | ADE: 0.0994 | FDE: 0.1461


                                                                           

Epoch 012 | Train Loss: 0.0554 | Val Loss: 0.0542 | MAE: 0.0590 | ADE: 0.0959 | FDE: 0.1250


                                                                           

Epoch 013 | Train Loss: 0.0553 | Val Loss: 0.0548 | MAE: 0.0702 | ADE: 0.1179 | FDE: 0.1628


                                                                           

Epoch 014 | Train Loss: 0.0596 | Val Loss: 0.0590 | MAE: 0.0660 | ADE: 0.1114 | FDE: 0.1764


                                                                           

Epoch 015 | Train Loss: 0.0560 | Val Loss: 0.0586 | MAE: 0.0675 | ADE: 0.1166 | FDE: 0.1265


                                                                           

Epoch 016 | Train Loss: 0.0547 | Val Loss: 0.0674 | MAE: 0.0828 | ADE: 0.1287 | FDE: 0.2001


                                                                           

Epoch 017 | Train Loss: 0.0547 | Val Loss: 0.0525 | MAE: 0.0566 | ADE: 0.0950 | FDE: 0.1236


                                                                           

Epoch 018 | Train Loss: 0.0558 | Val Loss: 0.0622 | MAE: 0.0664 | ADE: 0.1054 | FDE: 0.1169


                                                                           

Epoch 019 | Train Loss: 0.0554 | Val Loss: 0.0479 | MAE: 0.0444 | ADE: 0.0736 | FDE: 0.1043


                                                                           

Epoch 020 | Train Loss: 0.0543 | Val Loss: 0.0504 | MAE: 0.0539 | ADE: 0.0920 | FDE: 0.1399


                                                                           

Epoch 021 | Train Loss: 0.0532 | Val Loss: 0.0535 | MAE: 0.0556 | ADE: 0.0965 | FDE: 0.1522


                                                                           

Epoch 022 | Train Loss: 0.0535 | Val Loss: 0.0606 | MAE: 0.0534 | ADE: 0.0915 | FDE: 0.1073


                                                                           

Epoch 023 | Train Loss: 0.0547 | Val Loss: 0.0479 | MAE: 0.0482 | ADE: 0.0800 | FDE: 0.1009


                                                                           

Epoch 024 | Train Loss: 0.0527 | Val Loss: 0.0514 | MAE: 0.0528 | ADE: 0.0866 | FDE: 0.1247


                                                                           

Epoch 025 | Train Loss: 0.0529 | Val Loss: 0.0552 | MAE: 0.0768 | ADE: 0.1232 | FDE: 0.1443


                                                                           

Epoch 026 | Train Loss: 0.0522 | Val Loss: 0.0479 | MAE: 0.0458 | ADE: 0.0754 | FDE: 0.1144


                                                                           

Epoch 027 | Train Loss: 0.0527 | Val Loss: 0.0536 | MAE: 0.0540 | ADE: 0.0907 | FDE: 0.1259


                                                                           

Epoch 028 | Train Loss: 0.0542 | Val Loss: 0.0523 | MAE: 0.0600 | ADE: 0.1004 | FDE: 0.1461


                                                                           

Epoch 029 | Train Loss: 0.0537 | Val Loss: 0.0491 | MAE: 0.0485 | ADE: 0.0812 | FDE: 0.1071
Early stopping triggered at epoch 29




In [7]:
# Rebuild the architecture and restore the best checkpoint written by the training loop.
# NOTE(review): the output recorded under this cell shows an `AttentionLSTM` whose LSTM
# is not bidirectional, i.e. it was produced with a different class definition than the
# one in this notebook. Re-run after Restart & Run All before trusting the saved output.
model = BetterAttentionLSTM(input_size=6).to(device)
# map_location lets a checkpoint saved on GPU be restored on whatever `device` is now.
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.eval()


AttentionLSTM(
  (lstm): LSTM(6, 128, num_layers=2, batch_first=True, dropout=0.3)
  (attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [8]:
from sklearn.metrics import mean_absolute_error
import numpy as np
from tqdm import tqdm

# Make sure dropout is disabled even if an earlier cell left the model in training mode.
model.eval()

# Metrics are accumulated as per-sample sums and divided once at the end, so a
# smaller final batch does not get the same weight as a full batch.
mae_sum = ade_sum = fde_sum = 0.0
n_samples = 0

with torch.no_grad():
    val_loader = tqdm(val_dl, desc="Evaluating Best Model", leave=False)
    for X_batch, Y_batch in val_loader:
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        pred = model(X_batch)

        err = pred - Y_batch                    # [B, T, 2]
        dist = torch.linalg.norm(err, dim=-1)   # [B, T] L2 error per point

        # MAE: mean absolute coordinate error of each sample
        mae_sum += err.abs().mean(dim=(1, 2)).sum().item()
        # ADE: average L2 error per point of each sample
        ade_sum += dist.mean(dim=1).sum().item()
        # FDE: L2 error at the final point of each sample
        fde_sum += dist[:, -1].sum().item()

        n_samples += X_batch.size(0)

# --- Final Results ---
# With an empty loader the metrics print as nan instead of raising.
denom = n_samples if n_samples else float("nan")
print("\n✅ Best Model Evaluation:")
print(f"MAE: {mae_sum / denom:.4f}")
print(f"ADE: {ade_sum / denom:.4f}")
print(f"FDE: {fde_sum / denom:.4f}")


                                                                           


✅ Best Model Evaluation:
MAE: 0.0632
ADE: 0.1054
FDE: 0.1427


