In [1]:
import os

# Expose only the GPU with index 1 to this process. The variable is set here,
# in the first cell, so that it is already in the environment when torch is
# imported in the following cell.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [2]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# --- Load trajectory data ---
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load trajectory dumps that come from a trusted source.
with open("/workspace/hjs/python/lstm_train/car_trajectories.pkl", "rb") as f:
    data = pickle.load(f)

# Each sample is `input_len` observed positions followed by `pred_len`
# positions to predict, cut from one track with a sliding window of stride 1.
input_len, pred_len = 10, 5
window_len = input_len + pred_len
X_seqs, Y_seqs = [], []

for track_id, points in data.items():
    # Elements 1 and 2 of every point are used as the 2-D position
    # (presumably p = (frame, x, y, ...) — TODO confirm against the exporter).
    coords = np.array([[p[1], p[2]] for p in points])
    if len(coords) < window_len:
        # Track too short to yield even one (input, target) pair.
        continue
    for i in range(len(coords) - window_len + 1):
        X_seqs.append(coords[i:i+input_len])
        Y_seqs.append(coords[i+input_len:i+window_len])

# Fail early with a readable message: an empty list would otherwise become a
# shape-(0,) array and only break later with an opaque indexing error.
if not X_seqs:
    raise ValueError(
        f"No track has at least {window_len} points "
        f"(input_len={input_len} + pred_len={pred_len}); nothing to train on."
    )

X = np.array(X_seqs)  # [N, input_len, 2]
Y = np.array(Y_seqs)  # [N, pred_len, 2]
def add_kinematics(X):
    """
    Append first- and second-order finite differences to position sequences.

    X: [B, T, 2] → return [B, T, 6], laid out along the last axis as
    (position, velocity, acceleration).

    Each difference is taken along the time axis (axis=1) and is padded by
    prepending its own first frame, so the output keeps length T and the
    first velocity and first acceleration entries are always zero.
    """
    channels = [X]
    # Two passes: differencing positions gives velocity, differencing that
    # velocity gives acceleration.
    for _ in range(2):
        previous = channels[-1]
        channels.append(np.diff(previous, axis=1, prepend=previous[:, :1]))
    return np.concatenate(channels, axis=-1)  # [B, T, 6]


# Append velocity and acceleration channels to the observed positions.
X = add_kinematics(X)  # Now shape: [B, T, 6]
print(f"Prepared dataset: X={X.shape}, Y={Y.shape}")

# --- Reproducibility ---
# One seed drives both the split and torch's global RNG (DataLoader shuffling
# and any weight initialisation that happens afterwards).
SEED = 42
torch.manual_seed(SEED)

# --- Train/Val split ---
# NOTE(review): the samples are stride-1 sliding windows, so neighbouring
# windows of one track overlap almost completely. A random per-window split
# therefore places near-duplicates on both sides and makes validation metrics
# optimistic. Consider splitting by track instead.
X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.2, random_state=SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- DataLoaders ---
BATCH_SIZE = 64
# Explicit float32 conversion instead of the legacy torch.Tensor(ndarray) constructor.
train_ds = TensorDataset(
    torch.as_tensor(X_train, dtype=torch.float32),
    torch.as_tensor(Y_train, dtype=torch.float32),
)
val_ds = TensorDataset(
    torch.as_tensor(X_val, dtype=torch.float32),
    torch.as_tensor(Y_val, dtype=torch.float32),
)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

Prepared dataset: X=(609284, 10, 6), Y=(609284, 5, 2)
Using device: cuda


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

# --- Define the model ---
class AttentionLSTM(nn.Module):
    """
    LSTM encoder followed by multi-head attention and an MLP head.

    The LSTM encodes the input sequence, the encoding of the last time step
    attends over all time steps, and the attended vector is layer-normalised
    and mapped to `output_len` 2-D points.

    Args:
        input_size: number of features per input time step.
        hidden_size: LSTM hidden width and attention embedding size; must be
            divisible by `num_heads` (enforced by nn.MultiheadAttention).
        num_layers: number of stacked LSTM layers.
        output_len: number of future time steps to predict.
        dropout: dropout between stacked LSTM layers; ignored when
            `num_layers == 1`, where nn.LSTM has no place to apply it.
        num_heads: number of attention heads.

    Forward:
        x: [B, T, input_size] → returns [B, output_len, 2].
    """

    def __init__(self, input_size=6, hidden_size=128, num_layers=2, output_len=5, dropout=0.3,
                 num_heads=4):
        super().__init__()
        # nn.LSTM only applies dropout between layers and warns when a
        # non-zero value is passed for a single-layer network.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)
        self.output_len = output_len
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_len * 2)
        )

    def forward(self, x):
        out, _ = self.lstm(x)  # [B, T, H]
        # Only the last step's attention output is consumed, so query with
        # that step alone instead of computing all T x T attention rows.
        # The result matches attn(out, out, out)[:, -1] up to float rounding.
        query = out[:, -1:]  # [B, 1, H]
        attn_out, _ = self.attn(query, out, out, need_weights=False)
        last = self.norm(attn_out[:, 0])
        pred = self.fc(last)
        # Explicit batch size keeps the reshape well-defined for any batch.
        return pred.view(pred.size(0), self.output_len, 2)

# --- Define the loss ---
def hybrid_loss(pred_seq, true_seq, alpha=1.0, beta=0.3, gamma=0.3):
    """
    Weighted sum of a position error, a smoothness penalty and a heading error.

    Args:
        pred_seq: predicted sequence, indexed as [batch, time, coord].
        true_seq: ground-truth sequence with the same shape as `pred_seq`.
        alpha: weight of the MSE between predicted and true positions.
        beta: weight of the mean squared jerk (third finite difference along
            time) of the prediction.
        gamma: weight of the mean (1 - cosine similarity) between predicted
            and true step-to-step displacement vectors.

    Returns:
        Scalar tensor `alpha * mse + beta * jerk + gamma * direction`.

    A term that cannot be formed for the given horizon (jerk needs at least 4
    time steps, direction at least 2) contributes zero; taking the mean of an
    empty tensor would otherwise return NaN.
    """
    mse_loss = F.mse_loss(pred_seq, true_seq)
    num_steps = pred_seq.size(1)
    zero = torch.zeros_like(mse_loss)

    # Step-to-step displacement of the prediction, shared by both extra terms.
    pred_vel = pred_seq[:, 1:] - pred_seq[:, :-1]

    if num_steps >= 4:
        acc = pred_vel[:, 1:] - pred_vel[:, :-1]
        jerk = acc[:, 1:] - acc[:, :-1]
        physics_loss = torch.mean(jerk ** 2)
    else:
        physics_loss = zero

    if num_steps >= 2:
        true_vel = true_seq[:, 1:] - true_seq[:, :-1]
        cos_sim = F.cosine_similarity(pred_vel, true_vel, dim=-1)
        dir_loss = torch.mean(1 - cos_sim)
    else:
        dir_loss = zero

    return alpha * mse_loss + beta * physics_loss + gamma * dir_loss


In [7]:
from sklearn.metrics import mean_absolute_error

# --- Setup ---
# Build the network with 6 input features per time step and move it to the
# selected device.
model = AttentionLSTM(input_size=6)
model = model.to(device)

# Adam over all model parameters, with L2 weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)


In [None]:

num_epochs = 200
patience = 10  # epochs without sufficient validation improvement before stopping
best_val_loss = float('inf')
early_stop_counter = 0

for epoch in range(1, num_epochs + 1):
    # --- Training ---
    model.train()
    train_loss_sum, train_seen = 0.0, 0
    train_loader = tqdm(train_dl, desc=f"[Epoch {epoch}] Training", leave=False)

    for X_batch, Y_batch in train_loader:
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        pred = model(X_batch)
        loss = hybrid_loss(pred, Y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight every batch by its size so a short final batch does not
        # count as much as a full one in the epoch average.
        batch_n = X_batch.size(0)
        train_loss_sum += loss.item() * batch_n
        train_seen += batch_n

    avg_train_loss = train_loss_sum / max(train_seen, 1)

    # --- Validation ---
    model.eval()
    val_loss_sum, val_seen = 0.0, 0
    abs_err_sum, n_coords = 0.0, 0  # MAE: absolute error per coordinate
    ade_sum, n_points = 0.0, 0      # ADE: L2 error per predicted point
    fde_sum = 0.0                   # FDE: L2 error of the last point per sample

    val_loader = tqdm(val_dl, desc=f"[Epoch {epoch}] Validation", leave=False)
    with torch.no_grad():
        for X_batch, Y_batch in val_loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
            pred = model(X_batch)
            loss = hybrid_loss(pred, Y_batch)

            batch_n = X_batch.size(0)
            val_loss_sum += loss.item() * batch_n
            val_seen += batch_n

            # --- Metrics ---
            # Accumulated as sums and divided once per epoch, which gives an
            # exact per-sample average regardless of batch sizes.
            err = pred - Y_batch                  # [B, T, 2]
            dist = err.pow(2).sum(dim=-1).sqrt()  # [B, T] L2 error per point

            abs_err_sum += err.abs().sum().item()
            n_coords += err.numel()

            ade_sum += dist.sum().item()
            n_points += dist.numel()

            fde_sum += dist[:, -1].sum().item()

    avg_val_loss = val_loss_sum / max(val_seen, 1)
    avg_mae = abs_err_sum / max(n_coords, 1)
    avg_ade = ade_sum / max(n_points, 1)
    avg_fde = fde_sum / max(val_seen, 1)

    print(f"Epoch {epoch:03} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
          f"MAE: {avg_mae:.4f} | ADE: {avg_ade:.4f} | FDE: {avg_fde:.4f}")

    # --- Early stopping ---
    # An epoch only counts as an improvement if it beats the best validation
    # loss by more than 1e-4; the best weights are checkpointed to disk.
    if avg_val_loss < best_val_loss - 1e-4:
        best_val_loss = avg_val_loss
        early_stop_counter = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch}")
            break


                                                                          

Epoch 001 | Train Loss: 0.3043 | Val Loss: 0.1874 | MAE: 0.1352 | ADE: 0.2192 | FDE: 0.2779


                                                                          

Epoch 002 | Train Loss: 0.1686 | Val Loss: 0.1501 | MAE: 0.0865 | ADE: 0.1448 | FDE: 0.2167


                                                                          

Epoch 003 | Train Loss: 0.1602 | Val Loss: 0.1357 | MAE: 0.0737 | ADE: 0.1249 | FDE: 0.1801


                                                                          

Epoch 004 | Train Loss: 0.1546 | Val Loss: 0.1909 | MAE: 0.1410 | ADE: 0.2196 | FDE: 0.2146


                                                                          

Epoch 005 | Train Loss: 0.1475 | Val Loss: 0.1617 | MAE: 0.0806 | ADE: 0.1332 | FDE: 0.1730


[Epoch 6] Training:  60%|██████    | 4572/7617 [00:13<00:09, 337.09it/s]