In [16]:
import torch
import pydpf
import model
import pathlib
from training_loop import train
import numpy as np
import pandas as pd

In [17]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Simulated trajectories are written to, and read back from, this file in the working directory.
data_path = pathlib.Path('.').parent.absolute() / 'data.csv'

# Seeded generators: one living on `device` (despite the name it is a CPU generator when
# CUDA is unavailable) and one that always lives on the CPU.
experiment_cuda_rng = torch.Generator(device=device)
experiment_cuda_rng.manual_seed(0)
experiment_cpu_rng = torch.Generator()
experiment_cpu_rng.manual_seed(0)

# Filters to evaluate. The names that get_DPF recognises are:
# 'DPF', 'Soft', 'Stop-Gradient', 'Marginal Stop-Gradient', 'Optimal Transport', 'Kernel'.
experiments = ['Marginal Stop-Gradient']

# Number of independent training runs per filter.
n_repeats = 10

In [18]:
def get_SSM():
    """Build a state-space model with freshly randomised, trainable parameters.

    Three uniform draws are taken from the module-level ``experiment_cuda_rng``
    on ``device``, in this order (the order fixes which part of the random
    stream each parameter receives):

    * ``alpha`` -- shape (1, 1), uniform on [0, 1)
    * ``sigma`` -- shape (1, 1), uniform on [0, 1) scaled by 5
    * ``beta``  -- shape (1,),   uniform on [0, 1) scaled by 2

    Returns:
        A 4-tuple ``(SSM, alpha, beta, sigma)`` where ``SSM`` is the object
        returned by ``model.make_SSM(sigma, alpha, beta, device, rng)`` and the
        remaining entries are the ``torch.nn.Parameter`` objects passed to it.
    """
    def _draw(shape):
        # One uniform sample tensor on the experiment device.
        return torch.rand(shape, device=device, generator=experiment_cuda_rng)

    # torch.nn.Parameter tracks gradients by default (requires_grad=True).
    alpha = torch.nn.Parameter(_draw((1, 1)))
    sigma = torch.nn.Parameter(_draw((1, 1)) * 5)
    beta = torch.nn.Parameter(_draw((1,)) * 2)

    ssm = model.make_SSM(sigma, alpha, beta, device, experiment_cuda_rng)
    return ssm, alpha, beta, sigma

In [19]:
def get_DPF(SSM, DPF_type):
    """Construct the pydpf particle filter selected by ``DPF_type``.

    Args:
        SSM: The state-space model handed to the filter constructor.
        DPF_type: One of 'DPF', 'Soft', 'Stop-Gradient',
            'Marginal Stop-Gradient', 'Optimal Transport' or 'Kernel'.

    Returns:
        The filter object built by the matching ``pydpf`` constructor. Filters
        that accept a generator are given the module-level
        ``experiment_cuda_rng`` as it is at call time.

    Raises:
        ValueError: If ``DPF_type`` is not one of the names listed above.
    """
    def _kernel_filter():
        # The kernel filter is built in stages: base kernel -> mixture -> filter.
        kernel = pydpf.StandardGaussian(1, experiment_cuda_rng, False, True)
        kernel_mixture = pydpf.KernelMixture(kernel, generator=experiment_cuda_rng)
        return pydpf.KernelDPF(SSM=SSM, kernel=kernel_mixture)

    # Ordered (name, factory) pairs. Factories are lazy so only the requested
    # filter is ever constructed.
    factories = (
        ('DPF', lambda: pydpf.DPF(SSM=SSM, resampling_generator=experiment_cuda_rng)),
        ('Soft', lambda: pydpf.SoftDPF(SSM=SSM, resampling_generator=experiment_cuda_rng)),
        ('Stop-Gradient', lambda: pydpf.StopGradientDPF(SSM=SSM, resampling_generator=experiment_cuda_rng)),
        ('Marginal Stop-Gradient', lambda: pydpf.MarginalStopGradientDPF(SSM=SSM, resampling_generator=experiment_cuda_rng)),
        ('Optimal Transport', lambda: pydpf.OptimalTransportDPF(SSM=SSM, regularisation=0.5, transport_gradient_clip=1.)),
        ('Kernel', _kernel_filter),
    )
    for label, build in factories:
        if DPF_type == label:
            return build()
    raise ValueError('DPF_type should be one of the allowed options')

In [20]:
# Aggregated results, one row per experiment, are kept in this CSV in the working directory.
result_path = pathlib.Path('.').parent.absolute().joinpath('multiple_parameters_results.csv')

# Column labels used ONLY when the results file does not exist yet and has to be
# created. NOTE(review): confirm these match the header of any results file that
# downstream code already expects.
result_columns = ['ELBO', 'alpha error', 'beta error', 'sigma error']

# Ground-truth parameters of the data-generating model. Defined once so the
# simulator and the error computation below cannot drift apart.
true_sigma, true_alpha, true_beta = 1., 0.91, 0.5

for experiment in experiments:
    # Per-repeat outcomes for this filter type.
    ELBOs = np.empty(n_repeats)
    alphas = np.empty(n_repeats)
    betas = np.empty(n_repeats)
    sigmas = np.empty(n_repeats)
    for n in range(n_repeats):
        # Every repeat re-seeds all generators; the three generators share the
        # same seed value for a given repeat. The rebinding of the module-level
        # experiment_cuda_rng is what get_SSM/get_DPF pick up below.
        seed = n * 10
        experiment_cuda_rng = torch.Generator(device).manual_seed(seed)
        generation_rng = torch.Generator(device).manual_seed(seed)
        experiment_cpu_rng = torch.Generator().manual_seed(seed)

        # Simulate a fresh data set from the true model and overwrite data_path with it.
        true_SSM = model.make_SSM(
            torch.tensor([[true_sigma]], device=device),
            torch.tensor([[true_alpha]], device=device),
            torch.tensor([true_beta], device=device),
            device,
            generation_rng,
        )
        pydpf.simulate_and_save(data_path, SSM=true_SSM, time_extent=1000, n_trajectories=500, batch_size=100, device=device, bypass_ask=True)

        # Randomly initialised model to be trained, wrapped in the chosen filter.
        SSM, alpha, beta, sigma = get_SSM()
        dpf = get_DPF(SSM, experiment)

        # Each parameter gets its own learning rate; because every group sets
        # 'lr' explicitly, the optimiser-level lr=0.2 default is never used.
        param_groups = [
            {'params': [alpha], 'lr': 0.05},
            {'params': [beta], 'lr': 0.1},
            {'params': [sigma], 'lr': 0.25},
        ]
        if experiment == 'Kernel':
            # The kernel filter has additional learnable parameters in its resampling mixture.
            param_groups.append({'params': dpf.resampler.mixture.parameters(), 'lr': 0.1})
        opt = torch.optim.SGD(param_groups, lr=0.2, momentum=0.9, nesterov=True)
        # Multiplies every group's learning rate by 0.95 on each scheduler step
        # (the stepping itself happens inside train()).
        opt_schedule = torch.optim.lr_scheduler.ExponentialLR(opt, 0.95)

        dataset = pydpf.StateSpaceDataset(data_path, state_prefix='state', device=device)
        # 20 epochs (the captured log prints 'epoch k/20'). NOTE(review): the three
        # tuples and the scalar 1. are positional arguments whose meaning is defined
        # in training_loop.train -- check there before changing them.
        _, ELBO = train(dpf, opt, dataset, 20, (100, 100, 100), (30, 100, 100), (0.5, 0.25, 0.25), 1., experiment_cpu_rng, target='ELBO', time_extent=100, lr_scheduler=opt_schedule)

        ELBOs[n] = ELBO
        # Explicit scalar extraction instead of relying on NumPy's implicit
        # float() coercion of a (possibly CUDA, grad-tracking) one-element tensor.
        alphas[n] = alpha.item()
        betas[n] = beta.item()
        sigmas[n] = sigma.item()

    # Load previous results if present; otherwise start a new table so that a
    # missing file does not throw away the runs that were just computed.
    if result_path.exists():
        results = pd.read_csv(result_path, index_col=0)
    else:
        results = pd.DataFrame(columns=result_columns, dtype=float)

    # Mean ELBO and mean absolute error of each learned parameter over the repeats.
    row = np.array([
        np.mean(ELBOs),
        np.mean(np.abs(alphas - true_alpha)),
        np.mean(np.abs(betas - true_beta)),
        np.mean(np.abs(sigmas - true_sigma)),
    ])
    # Insert or overwrite this experiment's row, then persist straight away so
    # earlier experiments are saved even if a later one fails.
    results.loc[experiment] = row
    results.to_csv(result_path)

Done                  

epoch 1/20, train loss: -3.197660551071167, validation MSE: 3.3586248397827148, validation ELBO: -130.47197570800782
epoch 2/20, train loss: -3.413773717880249, validation MSE: 1.7812508344650269, validation ELBO: -109.29358215332032
epoch 3/20, train loss: -3.5692336559295654, validation MSE: 1.4107524394989013, validation ELBO: -107.21873321533204
epoch 4/20, train loss: -3.5785078144073488, validation MSE: 1.2746970891952514, validation ELBO: -104.98376159667968
epoch 5/20, train loss: -3.582836322784424, validation MSE: 1.2979756593704224, validation ELBO: -104.84027709960938
epoch 6/20, train loss: -3.5833983612060547, validation MSE: 1.2848281860351562, validation ELBO: -104.79761657714843
epoch 7/20, train loss: -3.584835557937622, validation MSE: 1.3324397802352905, validation ELBO: -105.68380279541016
epoch 8/20, train loss: -3.5829285907745363, validation MSE: 1.3629448890686036, validation ELBO: -105.36447143554688
epoch 9/20, train loss: -3.580100078