Импорты


In [None]:
import os
import math
import numpy as np
from tqdm import tqdm_notebook as tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
# color_palette() only builds and returns a palette; set_palette() makes it the default colour cycle.
sns.set_palette("bright")
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torch import Tensor
from torch import nn
from torch.nn  import functional as F
from torch.autograd import Variable

# True when PyTorch sees a CUDA device; the cells below then move models and tensors to the GPU.
use_cuda = torch.cuda.is_available()

Решатель ОДУ (явный метод Эйлера)

In [None]:
def ode_solve(z0, t0, t1, f, h_max=0.05):
    """Integrate dz/dt = f(z, t) from ``t0`` to ``t1`` with fixed-step explicit Euler.

    The interval is split into the smallest number of equal steps whose size
    does not exceed ``h_max``. Integrating backwards (``t1 < t0``) works too:
    the step ``h`` is then negative.

    Args:
        z0: initial state.
        t0, t1: start and end times (tensors or Python numbers). For tensors
            with several elements, the largest gap decides the step count.
        f: callable ``f(z, t)`` returning dz/dt with the same shape as ``z``.
        h_max: maximum step size (0.05 was the previously hard-coded value).

    Returns:
        The state at ``t1``; ``z0`` itself when the interval is empty.
    """
    span = t1 - t0
    # as_tensor lets plain Python floats go through the same .abs()/.max() path.
    n_steps = math.ceil((torch.as_tensor(span).abs() / h_max).max().item())
    if n_steps == 0:
        # Empty interval: nothing to integrate (and span / 0 would be NaN or raise).
        return z0

    h = span / n_steps
    t = t0
    z = z0
    for _ in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z

Метод сопряженного уравнения

In [None]:
class ODEAdjoint(torch.autograd.Function):
    """Autograd function that solves an ODE with ``ode_solve`` and
    back-propagates through it with the adjoint method.

    ``forward`` keeps only the states at the requested time points;
    ``backward`` integrates an augmented system backwards in time instead of
    differentiating through every solver step.
    """

    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):
        """Integrate ``func`` from ``z0`` through every time point of ``t``.

        Args:
            ctx: autograd context.
            z0: initial state; dimension 0 is treated as the batch.
            t: time points, indexed along dimension 0.
            flat_parameters: ``func.flatten_parameters()``. Not used in the
                computation; passed so autograd routes a gradient to
                ``func``'s parameters.
            func: dynamics ``f(z, t)``; must be an ``ODEF``.

        Returns:
            Tensor of shape ``(len(t), *z0.shape)``; entry ``i`` is the state
            at ``t[i]`` and entry 0 is ``z0``.
        """
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """Propagate ``dLdz`` back to ``(z0, t, flat_parameters, func)``.

        Args:
            ctx: autograd context filled by ``forward``.
            dLdz: gradient of the loss w.r.t. the trajectory returned by ``forward``.

        Returns:
            ``(adj_z, adj_t, adj_p, None)``:
            - ``adj_z``: the adjoint for ``z0``, reshaped like ``z0``.
            - ``adj_t``: per-time-point adjoints for ``t``, shape ``(len(t), bs, 1)``.
            - ``adj_p``: the accumulated parameter adjoint, shape ``(bs, n_params)``.
            - ``func`` receives no gradient.
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        n_dim = np.prod(z_shape)
        n_params = flat_parameters.size(0)

        # Right-hand side of the augmented state [z | a | a_p | a_t], flattened
        # per sample. It is integrated backwards between consecutive time points.
        def augmented_dynamics(aug_z_i, t_i):
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            # The caller runs under no_grad; re-enable autograd locally to get
            # the vector-Jacobian products of func.
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)

                faug = func.forward_with_grad(z_i, t_i, grad_outputs=a)
                func_eval, adfdz, adfdt, adfdp = faug

                # forward_with_grad may return None for inputs func does not use.
                adfdz = adfdz if adfdz is not None else torch.zeros(bs, *z_shape)
                adfdp = adfdp if adfdp is not None else torch.zeros(bs, n_params)
                adfdt = adfdt if adfdt is not None else torch.zeros(bs, 1)
                adfdz = adfdz.to(z_i)
                adfdp = adfdp.to(z_i)
                adfdt = adfdt.to(z_i)

            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.view(time_len, bs, n_dim)
        with torch.no_grad():
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            # Walk the intervals from the last time point back to the first.
            for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Add the direct loss gradients at t[i_t] to the adjoints.
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2),
                                   f_i.unsqueeze(-1))[:, 0]

                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                aug_z = torch.cat((
                    z_i.view(bs, n_dim),
                    adj_z, torch.zeros(bs, n_params).to(z),
                    adj_t[i_t]),
                    dim=-1
                )

                # Integrate the augmented system from t[i_t] back to t[i_t-1].
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            # Direct loss gradients at the first time point. The dynamics must
            # be evaluated at (z[0], t[0]). The f_i left over from the loop
            # belongs to (z[1], t[1]), and it does not exist when len(t) == 1.
            f_0 = func(z[0], t[0]).view(bs, n_dim)
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2),
                               f_0.unsqueeze(-1))[:, 0]

            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None

NeuralODE

In [None]:
class NeuralODE(nn.Module):
    """Module wrapper that solves the ODE defined by an ``ODEF`` with ``ODEAdjoint``."""

    def __init__(self, func):
        super(NeuralODE, self).__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=None, return_whole_sequence=False):
        """Integrate from ``z0`` over the time points ``t``.

        Args:
            z0: initial state; dimension 0 is the batch.
            t: time points indexed along dimension 0. Defaults to
                ``Tensor([0., 1.])``. Converted to ``z0``'s dtype and device.
            return_whole_sequence: if True, return the states at every time
                in ``t``; otherwise only the state at the last time.
        """
        if t is None:
            # Build the default per call instead of sharing one tensor that
            # was created once when the function was defined.
            t = Tensor([0., 1.])
        t = t.to(z0)
        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
        if return_whole_sequence:
            return z
        return z[-1]

ODEF и NNODEF

In [None]:
class ODEF(nn.Module):
    """Base class for dynamics ``f(z, t)`` that ``ODEAdjoint`` can differentiate.

    Subclasses implement ``forward(z, t)``. This class adds the
    vector-Jacobian products needed by the adjoint method and a helper that
    flattens all parameters into one vector.
    """

    def forward_with_grad(self, z, t, grad_outputs):
        """Evaluate f(z, t) together with a·∂f/∂z, a·∂f/∂t and a·∂f/∂p.

        Args:
            z: state tensor that requires grad; dimension 0 is the batch.
            t: time tensor that requires grad.
            grad_outputs: the vector ``a``, shaped like ``f(z, t)``.

        Returns:
            ``(out, adfdz, adfdt, adfdp)``:
            - ``adfdz``: shaped like ``z``, or None if f ignores ``z``.
            - ``adfdt``: expanded to ``(batch, 1)``, or None if f ignores ``t``.
            - ``adfdp``: ``(batch, n_params)``, or None when the module has no
              parameters. Parameters that do not affect the output contribute
              zeros. ``adfdt`` and ``adfdp`` are divided by the batch size,
              presumably so that summing over the batch recovers the total.
        """
        batch_size = z.shape[0]

        out = self.forward(z, t)

        params = tuple(self.parameters())
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + params, grad_outputs=(grad_outputs,),
            allow_unused=True, retain_graph=True
        )
        if params:
            # allow_unused=True yields None for parameters f does not use;
            # substitute zeros so every parameter keeps its slot in the vector.
            adfdp = torch.cat([
                (grad if grad is not None else torch.zeros_like(param)).flatten()
                for grad, param in zip(adfdp, params)
            ]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        else:
            adfdp = None
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        """Return all parameters concatenated into one 1-D tensor that stays connected to autograd."""
        return torch.cat([p.flatten() for p in self.parameters()])

class NNODEF(ODEF):
    """Three-layer MLP dynamics f(x, t) with ELU after the first two layers.

    Args:
        in_dim: state size (both input and output of the network).
        hid_dim: hidden width.
        time_invariant: if True, f ignores ``t``; otherwise ``t`` is appended
            to ``x`` as one extra input feature.
    """

    def __init__(self, in_dim, hid_dim, time_invariant=False):
        super(NNODEF, self).__init__()
        self.time_invariant = time_invariant
        if time_invariant:
            self.lin1 = nn.Linear(in_dim, hid_dim)
        else:
            self.lin1 = nn.Linear(in_dim + 1, hid_dim)
        self.lin2 = nn.Linear(hid_dim, hid_dim)
        self.lin3 = nn.Linear(hid_dim, in_dim)
        self.elu = nn.ELU(inplace=True)

    def forward(self, x, t):
        if not self.time_invariant:
            if t.dim() != x.dim():
                # ODEAdjoint/ode_solve pass a single time point (e.g. t[i] of a
                # (T, 1) time grid, shape (1,)) while x is batched. Give every
                # sample its own copy so the ranks match for torch.cat.
                t = t.reshape(*([1] * x.dim())).expand(*x.shape[:-1], 1)
            x = torch.cat((x, t), dim=-1)
        h = self.elu(self.lin1(x))
        h = self.elu(self.lin2(h))
        out = self.lin3(h)
        return out

def to_np(x):
    """Detach ``x`` from autograd, copy it to host memory and return it as a NumPy array."""
    host_copy = x.detach().cpu()
    return host_copy.numpy()

Энкодер и декодер

In [None]:
class RNNEncoder(nn.Module):
    """GRU encoder mapping an observed sequence to the mean and log-variance
    of the latent initial state.

    Args:
        input_dim: size of one observation.
        hidden_dim: GRU hidden size.
        latent_dim: latent size; the output head produces 2 * latent_dim values.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(RNNEncoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        # One extra input feature carries the time gap appended to each observation.
        self.rnn = nn.GRU(input_dim + 1, hidden_dim)
        self.hid2lat = nn.Linear(hidden_dim, 2 * latent_dim)

    def forward(self, x, t):
        """Encode ``x`` observed at times ``t``; return ``(z0_mean, z0_log_var)``.

        ``x`` is sequence-first (the GRU is built without ``batch_first``). The
        gaps are concatenated to ``x`` and the sequence is flipped along dim 0,
        so the GRU reads it from the last observation back to the first.
        """
        # Gap to the previous time point, t[i-1] - t[i]; zero for the first point.
        steps = t.unsqueeze(-1)
        gaps = torch.zeros_like(steps)
        gaps[1:] = steps[:-1] - steps[1:]

        reversed_seq = torch.flip(torch.cat((x, gaps), dim=-1), dims=(0,))
        _, h_last = self.rnn(reversed_seq)
        stats = self.hid2lat(h_last[0])
        return stats[:, :self.latent_dim], stats[:, self.latent_dim:]

class NeuralODEDecoder(nn.Module):
    """Decode a latent initial state into a trajectory of observations.

    The latent state is evolved by a time-invariant ``NNODEF`` through every
    time point in ``t``. Each latent state is then mapped to the output space
    by ``l2h`` followed by ``h2o`` (two linear layers with no activation
    between them).
    """

    def __init__(self, output_dim, hidden_dim, latent_dim):
        super(NeuralODEDecoder, self).__init__()
        self.output_dim, self.hidden_dim, self.latent_dim = output_dim, hidden_dim, latent_dim
        self.ode = NeuralODE(NNODEF(latent_dim, hidden_dim, time_invariant=True))
        self.l2h = nn.Linear(latent_dim, hidden_dim)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, z0, t):
        latent_path = self.ode(z0, t, return_whole_sequence=True)
        return self.h2o(self.l2h(latent_path))


Итоговая модель

In [None]:
class ODEVAE(nn.Module):
    """Latent-ODE variational autoencoder.

    ``RNNEncoder`` maps an observed sequence to the mean and log-variance of a
    latent initial state. ``NeuralODEDecoder`` evolves a latent state through
    the requested time points and maps it back to observations.
    """

    def __init__(self, output_dim, hidden_dim, latent_dim):
        super(ODEVAE, self).__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.encoder = RNNEncoder(output_dim, hidden_dim, latent_dim)
        self.decoder = NeuralODEDecoder(output_dim, hidden_dim, latent_dim)

    @staticmethod
    def _sample_latent(z_mean, z_log_var):
        """Reparameterisation trick: mean + N(0, 1) noise * exp(0.5 * log_var)."""
        return z_mean + torch.randn_like(z_mean) * torch.exp(0.5 * z_log_var)

    def forward(self, x, t, MAP=False):
        """Encode ``x`` at times ``t`` and reconstruct it at the same times.

        Args:
            x: observed sequence.
            t: its time points.
            MAP: if True, decode the posterior mean instead of a random sample.

        Returns:
            ``(x_p, z, z_mean, z_log_var)``: the reconstruction, the latent
            state that was decoded, and the encoder's mean and log-variance.
        """
        z_mean, z_log_var = self.encoder(x, t)
        z = z_mean if MAP else self._sample_latent(z_mean, z_log_var)
        x_p = self.decoder(z, t)
        return x_p, z, z_mean, z_log_var

    def generate_with_seed(self, seed_x, t):
        """Encode ``seed_x`` (observed at ``t[:len(seed_x)]``) and decode the
        posterior mean over all of ``t``."""
        seed_t_len = seed_x.shape[0]
        z_mean, _ = self.encoder(seed_x, t[:seed_t_len])
        return self.decoder(z_mean, t)

    def predict_future(self, seed_x, t_seed, t_future, MAP=False):
        """Encode ``seed_x`` observed at ``t_seed`` and decode over ``t_future``.

        By default a random latent sample is decoded (as before). Pass
        ``MAP=True`` to decode the posterior mean for a deterministic forecast.
        """
        z_mean, z_log_var = self.encoder(seed_x, t_seed)
        z = z_mean if MAP else self._sample_latent(z_mean, z_log_var)
        return self.decoder(z, t_future)


Тестирование sin

In [None]:
t_max = 10
n_points = 500
time_np = np.linspace(0, t_max, num=n_points)
time = torch.from_numpy(time_np[:, None]).to(torch.float32)

orig_traj = torch.sin(time).unsqueeze(-1)
noise_std = 0.02
samp_traj = orig_traj + torch.randn_like(orig_traj) * noise_std

vae = ODEVAE(output_dim=1, hidden_dim=64, latent_dim=6)
if use_cuda:
    vae = vae.cuda()
    # The training data never changes: move it to the GPU once, not every epoch.
    samp_traj = samp_traj.cuda()
    time = time.cuda()

optim = torch.optim.Adam(vae.parameters(), lr=0.001)

best_loss = float('inf')
best_epoch = -1
patience = 150
patience_counter = 0
save_path = "best_model.pth"

for epoch in range(3000):
    optim.zero_grad()

    x_p, z, z_mean, z_log_var = vae(samp_traj, time)
    # Reconstruction MSE plus 0.1 x the element-averaged Gaussian KL term.
    loss = ((samp_traj - x_p) ** 2).mean() + 0.1 * (-0.5 * torch.mean(
        1 + z_log_var - z_mean.pow(2) - torch.exp(z_log_var)))

    loss.backward()

    loss_value = loss.item()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

    if loss_value < best_loss:
        best_loss = loss_value
        best_epoch = epoch
        patience_counter = 0
        # Save the weights that produced this loss, before optim.step() changes
        # them. No CPU round-trip is needed: the file is loaded with map_location.
        torch.save(vae.state_dict(), save_path)
    else:
        patience_counter += 1

    optim.step()

    if epoch == 499:
        print(f"Best model so far saved at epoch {best_epoch} with loss {best_loss}")
    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch}. No improvement for {patience} epochs.")
        print(f"Best model saved at epoch {best_epoch} with loss {best_loss}")
        break

if os.path.exists(save_path):
    vae.load_state_dict(torch.load(save_path, map_location="cpu"))
    print("Best model loaded.")

Восстановление

In [None]:
# Encode the first 50 noisy samples, then decode over the whole time grid.
with torch.no_grad():
    generated_traj = vae.generate_with_seed(seed_x=samp_traj[:50], t=time)

plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(time).squeeze(), to_np(generated_traj).squeeze(),
         label="Generated", color='red', linestyle='--')
plt.legend()
plt.show()

Предсказание

In [None]:
future_time_np = np.linspace(0, t_max * 2, num=100)
future_time = torch.from_numpy(future_time_np[:, None]).to(torch.float32)

# Seed with the whole observed trajectory and forecast over twice the horizon.
# Everything goes to the GPU when one is available; otherwise it stays on the CPU.
seed_x, t_seed, t_future = (
    (samp_traj.cuda(), time.cuda(), future_time.cuda())
    if use_cuda
    else (samp_traj.cpu(), time.cpu(), future_time.cpu())
)

with torch.no_grad():
    future_traj = vae.predict_future(seed_x, t_seed, t_future)

plt.figure(figsize=(12, 6))
plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(future_time).squeeze(), to_np(future_traj).squeeze(),
         label="Future Prediction", color='green', linestyle='--')
# Vertical marker at the last seed time point.
plt.axvline(x=t_seed[-1].item(), color='black', linestyle='--', label="Seed Boundary")
plt.legend()
plt.title("Future Prediction")
plt.show()

Тестирование cos

In [None]:
t_max = 10
n_points = 500
time_np = np.linspace(0, t_max, num=n_points)
time = torch.from_numpy(time_np[:, None]).to(torch.float32)

orig_traj = torch.cos(time).unsqueeze(-1)
noise_std = 0.02
samp_traj = orig_traj + torch.randn_like(orig_traj) * noise_std

vae = ODEVAE(output_dim=1, hidden_dim=64, latent_dim=6)
if use_cuda:
    vae = vae.cuda()
    # The training data never changes: move it to the GPU once, not every epoch.
    samp_traj = samp_traj.cuda()
    time = time.cuda()

optim = torch.optim.Adam(vae.parameters(), lr=0.001)

best_loss = float('inf')
best_epoch = -1
patience = 150
patience_counter = 0
save_path = "best_model.pth"

for epoch in range(3000):
    optim.zero_grad()

    x_p, z, z_mean, z_log_var = vae(samp_traj, time)
    # Reconstruction MSE plus 0.1 x the element-averaged Gaussian KL term.
    loss = ((samp_traj - x_p) ** 2).mean() + 0.1 * (-0.5 * torch.mean(
        1 + z_log_var - z_mean.pow(2) - torch.exp(z_log_var)))

    loss.backward()

    loss_value = loss.item()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

    if loss_value < best_loss:
        best_loss = loss_value
        best_epoch = epoch
        patience_counter = 0
        # Save the weights that produced this loss, before optim.step() changes
        # them. No CPU round-trip is needed: the file is loaded with map_location.
        torch.save(vae.state_dict(), save_path)
    else:
        patience_counter += 1

    optim.step()

    if epoch == 499:
        print(f"Best model so far saved at epoch {best_epoch} with loss {best_loss}")
    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch}. No improvement for {patience} epochs.")
        print(f"Best model saved at epoch {best_epoch} with loss {best_loss}")
        break

if os.path.exists(save_path):
    vae.load_state_dict(torch.load(save_path, map_location="cpu"))
    print("Best model loaded.")


Восстановление

In [None]:
# Encode the first 50 noisy samples, then decode over the whole time grid.
with torch.no_grad():
    generated_traj = vae.generate_with_seed(seed_x=samp_traj[:50], t=time)

plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(time).squeeze(), to_np(generated_traj).squeeze(),
         label="Generated", color='red', linestyle='--')
plt.legend()
plt.show()

Предсказание

In [None]:

future_time_np = np.linspace(0, t_max * 2, num=100)
future_time = torch.from_numpy(future_time_np[:, None]).to(torch.float32)

# Seed with the whole observed trajectory and forecast over twice the horizon.
# Everything goes to the GPU when one is available; otherwise it stays on the CPU.
seed_x, t_seed, t_future = (
    (samp_traj.cuda(), time.cuda(), future_time.cuda())
    if use_cuda
    else (samp_traj.cpu(), time.cpu(), future_time.cpu())
)

with torch.no_grad():
    future_traj = vae.predict_future(seed_x, t_seed, t_future)

plt.figure(figsize=(12, 6))
plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(future_time).squeeze(), to_np(future_traj).squeeze(),
         label="Future Prediction", color='green', linestyle='--')
# Vertical marker at the last seed time point.
plt.axvline(x=t_seed[-1].item(), color='black', linestyle='--', label="Seed Boundary")
plt.legend()
plt.title("Future Prediction")
plt.show()

Подбор гиперпараметров

In [None]:
from itertools import product

t_max = 10
n_points = 1500
time_np = np.linspace(0, t_max, num=n_points)
time = torch.from_numpy(time_np[:, None]).to(torch.float32)
orig_traj = torch.sin(time).unsqueeze(-1)
if use_cuda:
    # The time grid is shared by every run: move it to the GPU once.
    time = time.cuda()

# The former "time_invariant" axis was dropped. Its value was never passed to
# ODEVAE (NeuralODEDecoder always builds a time-invariant NNODEF), so it only
# made every configuration train twice with identical settings.
# NOTE(review): noise_std changes the training target itself, so losses of runs
# with different noise levels are not directly comparable. The search will tend
# to favour the smallest noise level.
hyperparams_grid = {
    "hidden_dim": [64, 128],
    "latent_dim": [4, 6, 8],
    "learning_rate": [1e-3, 1e-4],
    "noise_std": [0.01, 0.02, 0.03],
}

best_loss = float('inf')
best_params = {}
save_path_base = "best_model_epoch"
criterion = nn.MSELoss()

for hidden_dim, latent_dim, learning_rate, noise_std in product(*hyperparams_grid.values()):
    print(f"\nОбучение с параметрами: hidden_dim={hidden_dim}, latent_dim={latent_dim}, "
          f"lr={learning_rate}, noise_std={noise_std}")

    samp_traj = orig_traj + torch.randn_like(orig_traj) * noise_std
    vae = ODEVAE(output_dim=1, hidden_dim=hidden_dim, latent_dim=latent_dim)
    if use_cuda:
        vae = vae.cuda()
        # Move this run's data once instead of on every epoch.
        samp_traj = samp_traj.cuda()

    optim = torch.optim.Adam(vae.parameters(), lr=learning_rate)
    patience = 20
    patience_counter = 0
    best_epoch_loss = float('inf')
    best_state = None

    for epoch in range(100):
        vae.train()
        optim.zero_grad()

        x_p, z, z_mean, z_log_var = vae(samp_traj, time)
        loss_recon = criterion(x_p, samp_traj)
        kl_loss = -0.5 * torch.mean(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
        loss = loss_recon + 0.1 * kl_loss

        loss.backward()

        loss_value = loss.item()
        if loss_value < best_epoch_loss:
            best_epoch_loss = loss_value
            patience_counter = 0
            # Snapshot the weights that produced this loss before optim.step()
            # updates them in place. Previously the final-epoch weights were saved.
            best_state = {k: v.detach().clone() for k, v in vae.state_dict().items()}
        else:
            patience_counter += 1

        optim.step()

        if patience_counter >= patience:
            break

    print(f"Лучшая потеря на этих параметрах: {best_epoch_loss}")
    if best_epoch_loss < best_loss:
        best_loss = best_epoch_loss
        best_params = {
            "hidden_dim": hidden_dim,
            "latent_dim": latent_dim,
            "learning_rate": learning_rate,
            "noise_std": noise_std,
        }
        # Encode every searched value in the file name so runs that share
        # hidden/latent sizes do not overwrite each other's checkpoint.
        torch.save(best_state, f"{save_path_base}_hid{hidden_dim}_lat{latent_dim}"
                               f"_lr{learning_rate}_noise{noise_std}.pth")

print("\nЛучшие гиперпараметры:")

print(best_params)

Тестирование cos после подбора гиперпараметров

In [None]:
t_max = 10
n_points = 500
time_np = np.linspace(0, t_max, num=n_points)
time = torch.from_numpy(time_np[:, None]).to(torch.float32)

orig_traj = torch.cos(time).unsqueeze(-1)
noise_std = 0.02
samp_traj = orig_traj + torch.randn_like(orig_traj) * noise_std

vae = ODEVAE(output_dim=1, hidden_dim=64, latent_dim=4)
if use_cuda:
    vae = vae.cuda()
    # The training data never changes: move it to the GPU once, not every epoch.
    samp_traj = samp_traj.cuda()
    time = time.cuda()

optim = torch.optim.Adam(vae.parameters(), lr=0.001)

best_loss = float('inf')
best_epoch = -1
patience = 150
patience_counter = 0
save_path = "best_model.pth"

for epoch in range(3000):
    optim.zero_grad()

    x_p, z, z_mean, z_log_var = vae(samp_traj, time)
    # Reconstruction MSE plus 0.1 x the element-averaged Gaussian KL term.
    loss = ((samp_traj - x_p) ** 2).mean() + 0.1 * (-0.5 * torch.mean(
        1 + z_log_var - z_mean.pow(2) - torch.exp(z_log_var)))

    loss.backward()

    loss_value = loss.item()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

    if loss_value < best_loss:
        best_loss = loss_value
        best_epoch = epoch
        patience_counter = 0
        # Save the weights that produced this loss, before optim.step() changes
        # them. No CPU round-trip is needed: the file is loaded with map_location.
        torch.save(vae.state_dict(), save_path)
    else:
        patience_counter += 1

    optim.step()

    if epoch == 499:
        print(f"Best model so far saved at epoch {best_epoch} with loss {best_loss}")
    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch}. No improvement for {patience} epochs.")
        print(f"Best model saved at epoch {best_epoch} with loss {best_loss}")
        break

if os.path.exists(save_path):
    vae.load_state_dict(torch.load(save_path, map_location="cpu"))
    print("Best model loaded.")

Восстановление

In [None]:
# Encode the first 50 noisy samples, then decode over the whole time grid.
with torch.no_grad():
    generated_traj = vae.generate_with_seed(seed_x=samp_traj[:50], t=time)

plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(time).squeeze(), to_np(generated_traj).squeeze(),
         label="Generated", color='red', linestyle='--')
plt.legend()
plt.show()

Предсказание

In [None]:

future_time_np = np.linspace(0, t_max * 2, num=100)
future_time = torch.from_numpy(future_time_np[:, None]).to(torch.float32)

# Seed with the whole observed trajectory and forecast over twice the horizon.
# Everything goes to the GPU when one is available; otherwise it stays on the CPU.
seed_x, t_seed, t_future = (
    (samp_traj.cuda(), time.cuda(), future_time.cuda())
    if use_cuda
    else (samp_traj.cpu(), time.cpu(), future_time.cpu())
)

with torch.no_grad():
    future_traj = vae.predict_future(seed_x, t_seed, t_future)

plt.figure(figsize=(12, 6))
plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(future_time).squeeze(), to_np(future_traj).squeeze(),
         label="Future Prediction", color='green', linestyle='--')
# Vertical marker at the last seed time point.
plt.axvline(x=t_seed[-1].item(), color='black', linestyle='--', label="Seed Boundary")
plt.legend()
plt.title("Future Prediction")
plt.show()

Тестирование sin после подбора гиперпараметров

In [None]:
t_max = 10
n_points = 500
time_np = np.linspace(0, t_max, num=n_points)
time = torch.from_numpy(time_np[:, None]).to(torch.float32)

orig_traj = torch.sin(time).unsqueeze(-1)
noise_std = 0.02
samp_traj = orig_traj + torch.randn_like(orig_traj) * noise_std

vae = ODEVAE(output_dim=1, hidden_dim=64, latent_dim=4)
if use_cuda:
    vae = vae.cuda()
    # The training data never changes: move it to the GPU once, not every epoch.
    samp_traj = samp_traj.cuda()
    time = time.cuda()

optim = torch.optim.Adam(vae.parameters(), lr=0.001)

best_loss = float('inf')
best_epoch = -1
patience = 150
patience_counter = 0
save_path = "best_model.pth"

for epoch in range(3000):
    optim.zero_grad()

    x_p, z, z_mean, z_log_var = vae(samp_traj, time)
    # Reconstruction MSE plus 0.1 x the element-averaged Gaussian KL term.
    loss = ((samp_traj - x_p) ** 2).mean() + 0.1 * (-0.5 * torch.mean(
        1 + z_log_var - z_mean.pow(2) - torch.exp(z_log_var)))

    loss.backward()

    loss_value = loss.item()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

    if loss_value < best_loss:
        best_loss = loss_value
        best_epoch = epoch
        patience_counter = 0
        # Save the weights that produced this loss, before optim.step() changes
        # them. No CPU round-trip is needed: the file is loaded with map_location.
        torch.save(vae.state_dict(), save_path)
    else:
        patience_counter += 1

    optim.step()

    if epoch == 499:
        print(f"Best model so far saved at epoch {best_epoch} with loss {best_loss}")
    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch}. No improvement for {patience} epochs.")
        print(f"Best model saved at epoch {best_epoch} with loss {best_loss}")
        break

if os.path.exists(save_path):
    vae.load_state_dict(torch.load(save_path, map_location="cpu"))
    print("Best model loaded.")

Восстановление

In [None]:
# Encode the first 50 noisy samples, then decode over the whole time grid.
with torch.no_grad():
    generated_traj = vae.generate_with_seed(seed_x=samp_traj[:50], t=time)

plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(time).squeeze(), to_np(generated_traj).squeeze(),
         label="Generated", color='red', linestyle='--')
plt.legend()
plt.show()


Предсказание

In [None]:
future_time_np = np.linspace(0, t_max * 2, num=100)
future_time = torch.from_numpy(future_time_np[:, None]).to(torch.float32)

# Seed with the whole observed trajectory and forecast over twice the horizon.
# Everything goes to the GPU when one is available; otherwise it stays on the CPU.
seed_x, t_seed, t_future = (
    (samp_traj.cuda(), time.cuda(), future_time.cuda())
    if use_cuda
    else (samp_traj.cpu(), time.cpu(), future_time.cpu())
)

with torch.no_grad():
    future_traj = vae.predict_future(seed_x, t_seed, t_future)

plt.figure(figsize=(12, 6))
plt.plot(to_np(time).squeeze(), to_np(orig_traj).squeeze(),
         label="Original", color='blue')
plt.plot(to_np(future_time).squeeze(), to_np(future_traj).squeeze(),
         label="Future Prediction", color='green', linestyle='--')
# Vertical marker at the last seed time point.
plt.axvline(x=t_seed[-1].item(), color='black', linestyle='--', label="Seed Boundary")
plt.legend()
plt.title("Future Prediction")
plt.show()