In [1]:
import torch
from torch.utils.data import DataLoader, random_split
from TrajectoryDataset import TrajectoryDataset
from DLModels import GraphTrajectoryLSTM
from Trainer import Trainer
from visualizer import visualize_predictions
import pickle
import os

def load_sequences(folder_path):
    all_sequences = []
    for filename in os.listdir(folder_path):
        if filename.endswith('.pkl'):
            file_path = os.path.join(folder_path, filename)
            with open(file_path, 'rb') as f:
                sequences = pickle.load(f)
                all_sequences.extend(sequences)
    print(f"Loaded {len(all_sequences)} sequences")
    return all_sequences

def train():
    # Hyperparameters
    input_sizes = {
        'node_features': 4,
        'position': 2,
        'velocity': 2,
        'steering': 1,
        'object_in_path': 1,
        'traffic_light_detected': 1
    }
    hidden_size = 128
    num_layers = 2
    input_seq_len = 3  # past trajectory length
    output_seq_len = 3  # future prediction length
    batch_size = 128
    num_epochs = 10
    learning_rate = 0.001
    
    # Data loading
    data_folder = "Dataset/Sequence_Dataset"
    dataset = TrajectoryDataset(data_folder)
    
    # Split the dataset
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    print(f"Train set size: {len(train_dataset)}")
    print(f"Test set size: {len(test_dataset)}")
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)
    
    # Model initialization
    model = GraphTrajectoryLSTM(input_sizes, hidden_size, num_layers, input_seq_len, output_seq_len)
    
    # Training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trainer = Trainer(model, train_loader, test_loader, learning_rate, device)
    trained_model = trainer.train(num_epochs)
    
    # Save the trained model
    torch.save(trained_model.state_dict(), "graph_trajectory_model_gaussian.pth")
    
    # Evaluate on test set
    test_loss = trainer.validate()
    #print(f"Test Loss: {test_loss:.4f}")

    # Visualization
    all_sequences = load_sequences(data_folder)
    visualize_predictions(trained_model, dataset, device, all_sequences)

def collate_fn(batch):
    past_batch = {k: torch.stack([item[0][k] for item in batch]) for k in batch[0][0].keys()}
    future_batch = {k: torch.stack([item[1][k] for item in batch]) for k in batch[0][1].keys()}
    graph_batch = {
        'node_features': torch.stack([item[2]['node_features'] for item in batch]),
        'adj_matrix': torch.stack([item[2]['adj_matrix'] for item in batch])
    }
    return past_batch, future_batch, graph_batch

train()

Loaded 5661 sequences
Train set size: 4528
Test set size: 1133


Training: 100%|██████████| 36/36 [00:10<00:00,  3.35it/s]



Epoch 1/10
Train Loss: 11.3339, Val Loss: 2.4314
Train Individual Losses:
  position: 3.5006
  velocity: 2.7747
  steering: 1.4292
  object_in_path: 1.9113
  traffic_light_detected: 1.7181
Val Individual Losses:
  position: -0.0591
  velocity: -0.5148
  steering: -0.4210
  object_in_path: 1.8193
  traffic_light_detected: 1.6070


Training: 100%|██████████| 36/36 [00:10<00:00,  3.58it/s]



Epoch 2/10
Train Loss: 1.1997, Val Loss: -3.8597
Train Individual Losses:
  position: -0.8782
  velocity: -0.7389
  steering: -0.6500
  object_in_path: 1.8692
  traffic_light_detected: 1.5977
Val Individual Losses:
  position: -2.8659
  velocity: -2.7691
  steering: -1.6451
  object_in_path: 1.8149
  traffic_light_detected: 1.6056


Training: 100%|██████████| 36/36 [00:10<00:00,  3.58it/s]



Epoch 3/10
Train Loss: -2.8117, Val Loss: -6.3703
Train Individual Losses:
  position: -2.0671
  velocity: -2.9188
  steering: -1.2793
  object_in_path: 1.8649
  traffic_light_detected: 1.5887
Val Individual Losses:
  position: -3.2341
  velocity: -4.7323
  steering: -1.8280
  object_in_path: 1.8151
  traffic_light_detected: 1.6091


Training: 100%|██████████| 36/36 [00:10<00:00,  3.58it/s]



Epoch 4/10
Train Loss: -6.2035, Val Loss: -8.6062
Train Individual Losses:
  position: -2.5883
  velocity: -5.5216
  steering: -1.5461
  object_in_path: 1.8644
  traffic_light_detected: 1.5881
Val Individual Losses:
  position: -3.3309
  velocity: -6.8678
  steering: -1.8225
  object_in_path: 1.8144
  traffic_light_detected: 1.6007


Training: 100%|██████████| 36/36 [00:10<00:00,  3.39it/s]



Epoch 5/10
Train Loss: -8.2861, Val Loss: -9.8835
Train Individual Losses:
  position: -2.8020
  velocity: -7.2606
  steering: -1.6751
  object_in_path: 1.8648
  traffic_light_detected: 1.5868
Val Individual Losses:
  position: -3.4468
  velocity: -7.9539
  steering: -1.8939
  object_in_path: 1.8153
  traffic_light_detected: 1.5959


Training: 100%|██████████| 36/36 [00:09<00:00,  3.62it/s]



Epoch 6/10
Train Loss: -9.2368, Val Loss: -10.5463
Train Individual Losses:
  position: -2.8891
  velocity: -8.0604
  steering: -1.7316
  object_in_path: 1.8679
  traffic_light_detected: 1.5763
Val Individual Losses:
  position: -3.4633
  velocity: -8.5562
  steering: -1.9356
  object_in_path: 1.8166
  traffic_light_detected: 1.5923


Training: 100%|██████████| 36/36 [00:10<00:00,  3.60it/s]



Epoch 7/10
Train Loss: -9.8154, Val Loss: -10.8916
Train Individual Losses:
  position: -2.9474
  velocity: -8.5394
  steering: -1.7685
  object_in_path: 1.8625
  traffic_light_detected: 1.5774
Val Individual Losses:
  position: -3.4523
  velocity: -8.9004
  steering: -1.9388
  object_in_path: 1.8149
  traffic_light_detected: 1.5850


Training: 100%|██████████| 36/36 [00:10<00:00,  3.58it/s]



Epoch 8/10
Train Loss: -9.9591, Val Loss: -10.8849
Train Individual Losses:
  position: -2.9476
  velocity: -8.6864
  steering: -1.7643
  object_in_path: 1.8655
  traffic_light_detected: 1.5737
Val Individual Losses:
  position: -3.4520
  velocity: -8.8740
  steering: -1.9541
  object_in_path: 1.8150
  traffic_light_detected: 1.5802


Training: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s]



Epoch 9/10
Train Loss: -10.3080, Val Loss: -11.5126
Train Individual Losses:
  position: -2.9337
  velocity: -9.0372
  steering: -1.7613
  object_in_path: 1.8648
  traffic_light_detected: 1.5594
Val Individual Losses:
  position: -3.4536
  velocity: -9.5417
  steering: -1.9128
  object_in_path: 1.8148
  traffic_light_detected: 1.5807


Training: 100%|██████████| 36/36 [00:09<00:00,  3.62it/s]



Epoch 10/10
Train Loss: -10.0399, Val Loss: -7.7596
Train Individual Losses:
  position: -2.9834
  velocity: -8.7205
  steering: -1.7619
  object_in_path: 1.8716
  traffic_light_detected: 1.5543
Val Individual Losses:
  position: -3.2724
  velocity: -5.9739
  steering: -1.8840
  object_in_path: 1.8158
  traffic_light_detected: 1.5549


TypeError: unsupported format string passed to tuple.__format__