In [1]:
import sys, os
import torch
from torch import Tensor
from torch.utils.data import DataLoader, random_split
from trajdata import AgentBatch, UnifiedDataset

sys.path.append(os.path.abspath("../"))
from src.data.batch_proccessing import make_model_collate
from src.models.ode_baseline import ODEBaseline
from torch import nn
from tqdm import tqdm
import pytorch_lightning as pl
pl.seed_everything(42)

  import pkg_resources
Global seed set to 42


42

In [2]:
dataset = UnifiedDataset(
        desired_data=["eupeds_eth","eupeds_hotel","eupeds_univ","eupeds_zara1","eupeds_zara2"],
        data_dirs={
            "eupeds_eth":  "../data/eth",
            "eupeds_hotel":"../data/eth",
            "eupeds_univ": "../data/eth",
            "eupeds_zara1":"../data/eth",
            "eupeds_zara2":"../data/eth",
        },
        desired_dt=0.1,
        state_format='x,y',
        obs_format='x,y',
        centric="scene",
        history_sec=(0.8,0.8),
        future_sec=(0.8,0.8),
        standardize_data=False,
    )
collate_fn = make_model_collate(dataset=dataset, memory=4, dim=2)

In [3]:
N = len(dataset)
n_test = int(0.1 * N)
n_train = N - n_test
train_ds, test_ds = random_split(dataset, [n_train, n_test],
                                     generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True,
                              num_workers=os.cpu_count(), collate_fn=collate_fn, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False,
                              num_workers=os.cpu_count(), collate_fn=collate_fn, pin_memory=True)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
model = ODEBaseline(dim=2, w=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

In [None]:
EPOCHS = 5
for epoch in tqdm(range(1, EPOCHS + 1)):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        if batch is None:
            continue
        x0, x0_class, x1_full, x1_next, t0, t1 = batch
        x0, x1_next, t1 = x0.to(device), x1_next.to(device), t1.to(device)
        t1_scalar = t1.view(-1)[0]

        optimizer.zero_grad()
        x_pred = model(x0, t1_scalar)
        loss = criterion(x_pred, x1_next.squeeze())
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x0.size(0)

    mean_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch:03d}: train_loss = {mean_loss:.6f}")

    # ---- simple test each epoch ----
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            if batch is None:
                continue
            x0, x0_hist, x1, t0, t1 = batch
            x0, x1, t1 = x0.to(device), x1.to(device), t1.to(device)
            t1_scalar = t1.view(-1)[0]
            x_pred = model(x0, t1_scalar)
            loss = criterion(x_pred, x1.squeeze())
            test_loss += loss.item() * x0.size(0)

    test_loss /= len(test_loader.dataset)
    print(f"          test_loss = {test_loss:.6f}")

print("Training finished.")
torch.save(model.state_dict(), "ode_baseline.pth")

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 001: train_loss = 0.077097


 20%|██        | 1/5 [03:53<15:32, 233.22s/it]

          test_loss = 0.076630
Epoch 002: train_loss = 0.076009


 40%|████      | 2/5 [07:46<11:40, 233.54s/it]

          test_loss = 0.076421
Epoch 003: train_loss = 0.075840


 60%|██████    | 3/5 [11:41<07:47, 233.81s/it]

          test_loss = 0.076265
Epoch 004: train_loss = 0.075691


 80%|████████  | 4/5 [15:36<03:54, 234.24s/it]

          test_loss = 0.076115
Epoch 005: train_loss = 0.075525


100%|██████████| 5/5 [19:30<00:00, 234.05s/it]

          test_loss = 0.076010
Training finished.



