In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from TrajectoryDataset import TrajectoryDataset
from DLModels import GraphTrajectoryLSTM
from Trainer import Trainer
from visualizer import visualize_predictions
import pickle
import os

def load_sequences(folder_path):
    """Concatenate the sequence lists stored in every ``.pkl`` file of a folder.

    Files are visited in ``os.listdir`` order (arbitrary, filesystem-dependent);
    only names ending in ``.pkl`` are read. Each file's unpickled object is
    appended item-by-item via ``list.extend``, so it must be iterable —
    presumably a list of sequences (TODO confirm against the dataset writer).

    Security: ``pickle.load`` can execute arbitrary code, so only point this
    at trusted, locally produced files.

    Args:
        folder_path: Directory to scan (non-recursively).

    Returns:
        A single list holding every item from every ``.pkl`` file.
    """
    pickle_names = [name for name in os.listdir(folder_path) if name.endswith('.pkl')]

    merged = []
    for name in pickle_names:
        with open(os.path.join(folder_path, name), 'rb') as handle:
            merged.extend(pickle.load(handle))

    print(f"Loaded {len(merged)} sequences")
    return merged

def train(
    *,
    data_folder="Dataset/Sequence_Dataset",
    num_epochs=50,
    batch_size=128,
    learning_rate=0.001,
    seed=42,
    model_path="graph_trajectory_model_gaussian.pth",
):
    """Train a GraphTrajectoryLSTM on the sequence dataset, then save, evaluate and visualize it.

    Calling ``train()`` with no arguments uses the same data folder,
    hyperparameters and checkpoint path as before; the keyword-only
    arguments only make them overridable.

    Args:
        data_folder: Folder handed to both ``TrajectoryDataset`` and
            ``load_sequences``.
        num_epochs: Number of epochs passed to ``Trainer.train``.
        batch_size: Batch size used by both DataLoaders.
        learning_rate: Learning rate passed to ``Trainer``.
        seed: Seed for the 80/20 ``random_split``. A fixed seed makes the
            held-out subset identical across runs, so a saved checkpoint can
            be re-evaluated on exactly the samples it never trained on.
            Pass ``None`` to draw from torch's global generator instead
            (the previous, unseeded behavior).
        model_path: File the trained model's ``state_dict`` is written to.
    """
    # Model hyperparameters. The per-input widths are consumed by
    # GraphTrajectoryLSTM; TODO confirm they match what TrajectoryDataset emits.
    input_sizes = {
        'node_features': 4,
        'position': 2,
        'velocity': 2,
        'steering': 1,
        'object_in_path': 1,
        'traffic_light_detected': 1
    }
    hidden_size = 128
    num_layers = 2
    input_seq_len = 3  # past trajectory length
    output_seq_len = 3  # future prediction length

    # Data loading
    dataset = TrajectoryDataset(data_folder)

    # 80/20 train/test split. Without an explicit generator the split depends on
    # global RNG state, so each run silently held out a different subset.
    split_generator = (
        torch.Generator().manual_seed(seed) if seed is not None else torch.default_generator
    )
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], generator=split_generator
    )

    print(f"Train set size: {len(train_dataset)}")
    print(f"Test set size: {len(test_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    # NOTE(review): test_loader is the only evaluation loader handed to Trainer, so it
    # presumably also drives the per-epoch "Val Loss". If Trainer chooses the model it
    # returns based on that loss, the "Test Loss" below is optimistic; consider carving
    # out a separate validation split.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

    # Model initialization
    model = GraphTrajectoryLSTM(input_sizes, hidden_size, num_layers, input_seq_len, output_seq_len)

    # Training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trainer = Trainer(model, train_loader, test_loader, learning_rate, device)
    trained_model = trainer.train(num_epochs)

    # Save the trained model's weights.
    torch.save(trained_model.state_dict(), model_path)

    # Evaluate on the held-out split.
    test_loss = trainer.validate()
    print(f"Test Loss: {test_loss:.4f}")

    # Visualization. NOTE(review): this receives the full dataset (train + test
    # samples), not test_dataset — confirm that is intended.
    all_sequences = load_sequences(data_folder)
    visualize_predictions(trained_model, dataset, device, all_sequences)

def collate_fn(batch):
    """Merge a list of (past, future, graph) samples into batched dicts of tensors.

    Every value is batched with ``torch.stack`` along a new leading dimension,
    so tensors under a given key must share a shape across samples. ``past``
    and ``future`` keep whichever keys the first sample carries; ``graph``
    keeps only ``'node_features'`` and ``'adj_matrix'`` (other graph entries
    are dropped).

    Args:
        batch: Sequence of samples indexable as ``sample[0]``, ``sample[1]``
            and ``sample[2]`` — presumably the (past, future, graph) dicts
            produced by TrajectoryDataset (TODO confirm).

    Returns:
        Tuple ``(past_batch, future_batch, graph_batch)`` of dicts of stacked tensors.
    """
    def stack_field(slot, keys):
        # Gather sample[slot][key] across the whole batch for each requested key.
        return {key: torch.stack([sample[slot][key] for sample in batch]) for key in keys}

    past_batch = stack_field(0, batch[0][0].keys())
    future_batch = stack_field(1, batch[0][1].keys())
    graph_batch = stack_field(2, ('node_features', 'adj_matrix'))
    return past_batch, future_batch, graph_batch

# Run training only when executed as a script or notebook cell (where __name__ is
# "__main__"); importing this module for its helpers no longer starts a full run.
if __name__ == "__main__":
    train()

Loaded 5661 sequences
Train set size: 4528
Test set size: 1133


Training: 100%|██████████| 36/36 [00:10<00:00,  3.51it/s]


Epoch 1/100, Train Loss: 1694.0852, Val Loss: 1064.3019


Training: 100%|██████████| 36/36 [00:10<00:00,  3.56it/s]


Epoch 2/100, Train Loss: 886.1770, Val Loss: 723.0309


Training: 100%|██████████| 36/36 [00:09<00:00,  3.61it/s]


Epoch 3/100, Train Loss: 660.7337, Val Loss: 663.8830


Training: 100%|██████████| 36/36 [00:10<00:00,  3.56it/s]


Epoch 4/100, Train Loss: 598.7108, Val Loss: 649.0522


Training: 100%|██████████| 36/36 [00:10<00:00,  3.40it/s]


Epoch 5/100, Train Loss: 603.7912, Val Loss: 570.8174


Training: 100%|██████████| 36/36 [00:10<00:00,  3.60it/s]


Epoch 6/100, Train Loss: 568.9760, Val Loss: 536.8103


Training: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s]


Epoch 7/100, Train Loss: 533.9791, Val Loss: 547.4940


Training: 100%|██████████| 36/36 [00:10<00:00,  3.59it/s]


Epoch 8/100, Train Loss: 504.0445, Val Loss: 432.8837


Training: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s]


Epoch 9/100, Train Loss: 431.3593, Val Loss: 431.4651


Training: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s]


Epoch 10/100, Train Loss: 390.9179, Val Loss: 339.5757


Training: 100%|██████████| 36/36 [00:10<00:00,  3.60it/s]


Epoch 11/100, Train Loss: 342.3403, Val Loss: 404.5852


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 12/100, Train Loss: 311.0886, Val Loss: 309.8230


Training: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s]


Epoch 13/100, Train Loss: 271.2093, Val Loss: 215.4867


Training: 100%|██████████| 36/36 [00:09<00:00,  3.62it/s]


Epoch 14/100, Train Loss: 219.2923, Val Loss: 189.9727


Training: 100%|██████████| 36/36 [00:09<00:00,  3.61it/s]


Epoch 15/100, Train Loss: 215.1028, Val Loss: 180.4789


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 16/100, Train Loss: 101.3025, Val Loss: 218.3227


Training: 100%|██████████| 36/36 [00:09<00:00,  3.60it/s]


Epoch 17/100, Train Loss: 251.0866, Val Loss: 84.5778


Training: 100%|██████████| 36/36 [00:09<00:00,  3.62it/s]


Epoch 18/100, Train Loss: 88.8720, Val Loss: 280.0453


Training: 100%|██████████| 36/36 [00:09<00:00,  3.60it/s]


Epoch 19/100, Train Loss: 109.9570, Val Loss: 69.4432


Training: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s]


Epoch 20/100, Train Loss: -8.1922, Val Loss: 68.8712


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 21/100, Train Loss: 56.5134, Val Loss: 97.2110


Training: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s]


Epoch 22/100, Train Loss: -68.2323, Val Loss: -102.4574


Training: 100%|██████████| 36/36 [00:09<00:00,  3.61it/s]


Epoch 23/100, Train Loss: 300.6701, Val Loss: 3282.4917


Training: 100%|██████████| 36/36 [00:09<00:00,  3.62it/s]


Epoch 24/100, Train Loss: 696.7898, Val Loss: 339.7554


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 25/100, Train Loss: 199.7150, Val Loss: 183.2760


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 26/100, Train Loss: 32.6441, Val Loss: 213.5414


Training: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s]


Epoch 27/100, Train Loss: -39.4023, Val Loss: -26.0699


Training: 100%|██████████| 36/36 [00:09<00:00,  3.62it/s]


Epoch 28/100, Train Loss: -111.0168, Val Loss: -105.5951


Training: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s]


Epoch 29/100, Train Loss: -170.3695, Val Loss: -56.3255


Training: 100%|██████████| 36/36 [00:09<00:00,  3.68it/s]


Epoch 30/100, Train Loss: -198.0363, Val Loss: -150.0708


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 31/100, Train Loss: -289.9885, Val Loss: -315.6989


Training: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s]


Epoch 32/100, Train Loss: -388.6701, Val Loss: 954.4425


Training: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s]


Epoch 33/100, Train Loss: -109.9937, Val Loss: -263.3002


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 34/100, Train Loss: -417.8024, Val Loss: -6.2356


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 35/100, Train Loss: -276.7150, Val Loss: -397.8456


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 36/100, Train Loss: -534.3649, Val Loss: -523.0569


Training: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s]


Epoch 37/100, Train Loss: -623.8942, Val Loss: -479.5083


Training: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s]


Epoch 38/100, Train Loss: -627.4600, Val Loss: -605.4258


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 39/100, Train Loss: -459.9621, Val Loss: -413.2849


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 40/100, Train Loss: -621.3979, Val Loss: -632.1389


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 41/100, Train Loss: -705.4481, Val Loss: -618.9527


Training: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s]


Epoch 42/100, Train Loss: -804.7483, Val Loss: -832.6432


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 43/100, Train Loss: -887.4441, Val Loss: -669.6609


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 44/100, Train Loss: -866.2258, Val Loss: -832.7043


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 45/100, Train Loss: -674.9363, Val Loss: -736.0942


Training: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s]


Epoch 46/100, Train Loss: -950.1647, Val Loss: -991.9484


Training: 100%|██████████| 36/36 [00:09<00:00,  3.62it/s]


Epoch 47/100, Train Loss: -588.6013, Val Loss: -608.9380


Training: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s]


Epoch 48/100, Train Loss: -825.0601, Val Loss: -798.4391


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 49/100, Train Loss: -946.2619, Val Loss: -873.8851


Training: 100%|██████████| 36/36 [00:10<00:00,  3.47it/s]


Epoch 50/100, Train Loss: -1084.7017, Val Loss: -486.3537


Training: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s]


Epoch 51/100, Train Loss: -957.0410, Val Loss: -966.9374


Training: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s]


Epoch 52/100, Train Loss: -596.4628, Val Loss: -973.7483


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 53/100, Train Loss: -1133.9703, Val Loss: -1027.2591


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 54/100, Train Loss: -1129.1047, Val Loss: -191.0214


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 55/100, Train Loss: -795.7222, Val Loss: -1006.0741


Training: 100%|██████████| 36/36 [00:09<00:00,  3.63it/s]


Epoch 56/100, Train Loss: -1196.3757, Val Loss: -1097.8191


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 57/100, Train Loss: -1225.0013, Val Loss: -1110.1755


Training: 100%|██████████| 36/36 [00:09<00:00,  3.67it/s]


Epoch 58/100, Train Loss: -1263.5033, Val Loss: -1171.6883


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 59/100, Train Loss: -1314.8858, Val Loss: -1185.2631


Training: 100%|██████████| 36/36 [00:09<00:00,  3.66it/s]


Epoch 60/100, Train Loss: -1295.6833, Val Loss: -1103.7763


Training: 100%|██████████| 36/36 [00:09<00:00,  3.65it/s]


Epoch 61/100, Train Loss: -1168.4218, Val Loss: -816.1104


Training: 100%|██████████| 36/36 [00:09<00:00,  3.68it/s]


Epoch 62/100, Train Loss: -1222.4897, Val Loss: -936.9605


Training: 100%|██████████| 36/36 [00:09<00:00,  3.68it/s]


Epoch 63/100, Train Loss: -1239.3564, Val Loss: -1227.2053


Training: 100%|██████████| 36/36 [00:09<00:00,  3.69it/s]


Epoch 64/100, Train Loss: -1319.6064, Val Loss: -1295.4149


Training: 100%|██████████| 36/36 [00:09<00:00,  3.68it/s]


Epoch 65/100, Train Loss: -1375.9440, Val Loss: -1150.3407


Training: 100%|██████████| 36/36 [00:09<00:00,  3.64it/s]


Epoch 66/100, Train Loss: -1382.4530, Val Loss: -1246.2598


Training: 100%|██████████| 36/36 [00:09<00:00,  3.68it/s]


Epoch 67/100, Train Loss: -1371.0446, Val Loss: -1393.5069


Training:   3%|▎         | 1/36 [00:00<00:09,  3.67it/s]