In [1]:
from new_tools import *
from tools import animate_trajectories

In [2]:
# Hyperparameters
# Seed PyTorch before anything stochastic runs so that weight initialisation
# and the torch.randperm shuffling of mini-batches are repeatable after a
# kernel restart. torch.manual_seed covers the CPU and CUDA generators.
# NOTE(review): generate_sequences may draw from a different RNG (e.g. NumPy);
# if so, seed that one too for fully identical training data — TODO confirm.
SEED = 0
torch.manual_seed(SEED)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x_train = generate_sequences(n_sequences=2000, T=30)  # adjust as you like
x_train = x_train.to(device)  # keep the whole training set on the compute device

# Model and optimizer; the model is moved to the same device as the data.
model = DKF(x_dim=2, z_dim=4, h_dim=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

batch_size = 64   # sequences per optimisation step
n_epochs = 60     # full passes over x_train

def get_batches(x, batch_size):
    """Yield shuffled mini-batches of ``x`` taken along its first dimension.

    A fresh random permutation of the row indices is drawn each time the
    generator is started, so every row of ``x`` lands in exactly one batch.
    The last batch is smaller when ``x.size(0)`` is not a multiple of
    ``batch_size``; an empty ``x`` yields no batches.

    Args:
        x: Tensor to split; it is indexed along dimension 0.
        batch_size: Maximum number of rows per batch. Must be positive.

    Yields:
        Tensors containing up to ``batch_size`` rows of ``x``.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces on the first iteration.
    """
    # A negative step would make range() empty and silently produce zero
    # batches (the caller then divides by a batch count of 0), so reject it
    # explicitly together with 0.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    N = x.size(0)
    # Build the permutation on the same device as x so that indexing does not
    # need an implicit host-to-device copy of the index tensor for each batch.
    idx = torch.randperm(N, device=x.device)
    for i in range(0, N, batch_size):
        yield x[idx[i:i + batch_size]]

# Mini-batch training: one optimizer step per batch, and the mean batch loss
# is printed at the end of every epoch.
for epoch in range(1, n_epochs + 1):
    model.train()  # training mode for this epoch
    epoch_loss, n_batches = 0.0, 0

    # enumerate(start=1) leaves n_batches equal to the number of batches seen.
    for n_batches, batch in enumerate(get_batches(x_train, batch_size), start=1):
        optimizer.zero_grad()          # drop gradients from the previous step
        loss, elbo = model(batch)      # the model's forward returns a (loss, elbo) pair
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()      # accumulate as a Python float

    print(f"Epoch {epoch:03d} | loss = {epoch_loss / n_batches:.4f}")


Epoch 001 | loss = 7346.8047
Epoch 002 | loss = 2268.9953
Epoch 003 | loss = 811.6812
Epoch 004 | loss = 383.4318
Epoch 005 | loss = 279.5998
Epoch 006 | loss = 253.3435
Epoch 007 | loss = 239.8974
Epoch 008 | loss = 230.4829
Epoch 009 | loss = 221.4479
Epoch 010 | loss = 214.3080
Epoch 011 | loss = 208.8311
Epoch 012 | loss = 204.5058
Epoch 013 | loss = 201.3215
Epoch 014 | loss = 198.7505
Epoch 015 | loss = 197.6286
Epoch 016 | loss = 196.4361
Epoch 017 | loss = 195.3031
Epoch 018 | loss = 195.1415
Epoch 019 | loss = 194.3427
Epoch 020 | loss = 193.0683
Epoch 021 | loss = 192.4998
Epoch 022 | loss = 191.8087
Epoch 023 | loss = 190.9330
Epoch 024 | loss = 190.4963
Epoch 025 | loss = 190.0400
Epoch 026 | loss = 189.7372
Epoch 027 | loss = 188.3924
Epoch 028 | loss = 188.5069
Epoch 029 | loss = 187.8763
Epoch 030 | loss = 187.3302
Epoch 031 | loss = 187.5014
Epoch 032 | loss = 185.9418
Epoch 033 | loss = 185.5559
Epoch 034 | loss = 184.9533
Epoch 035 | loss = 184.5643
Epoch 036 | loss =

In [9]:
# Generate one test trajectory
T_test = 60
sim = BouncingBallSim(pos_start=[15, 15], vel_start=[2.0, 1.0])

# Every sim.step() call returns a 3-tuple: the first entry is kept as the
# observation and the last as the ground truth; the middle entry is not used.
rollout = [sim.step() for _ in range(T_test)]
obs_list = [obs for obs, _, _ in rollout]
gt_list = [gt for _, _, gt in rollout]

observations = np.array(obs_list)   # (T, 2)
ground_truth = np.array(gt_list)    # (T, 2)


In [10]:
# Run the trained model on the test trajectory: infer the posterior mean of
# the latent states, then decode those means back to observation space.
model.eval()
with torch.no_grad():
    # Add a leading batch axis and cast to float32 on `device` -> (1, T, 2).
    x_test = torch.from_numpy(observations)[None].to(device=device, dtype=torch.float32)

    # Posterior mean of z_t; per the original annotation this comes back as a
    # NumPy array of shape (T, z_dim).
    z_mu = model.infer_posterior_mean(x_test)

    # The decoder works on tensors, so convert the latent means back to torch.
    z_mu_torch = torch.from_numpy(z_mu).to(device=device, dtype=torch.float32)   # (T, z_dim)
    x_mu_torch = model.decoder(z_mu_torch)                                       # (T, 2)

    # Position estimates as a NumPy array on the host.
    dkf_pos_est = x_mu_torch.cpu().numpy()                                       # (T, 2)


In [11]:
from IPython.display import HTML

# Animate the ground-truth trajectory against the DKF estimate. The estimates
# and labels lists are parallel: one label per estimated trajectory.
anim_kwargs = dict(
    ground_truth=ground_truth,
    estimates=[dkf_pos_est],
    labels=["DKF mean"],
)
html_anim = animate_trajectories(**anim_kwargs)

# Keep this as the last expression of the cell so the notebook renders it.
HTML(html_anim.data)
