In [41]:
import math
from tqdm import tqdm
import torch
import pydpf
import model
import pathlib
from time import time
import pandas as pd
import numpy as np
from copy import deepcopy
from math import sqrt

In [42]:
# Dimension of the latent state and of the observation respectively.
dx = 25
dy = 1
# Number of rotated train/validation/test splits every experiment is averaged over.
n_repeats = 5
# Input data set and the CSV table the per-experiment results are written to.
data_path = pathlib.Path('.').parent.absolute() / 'data.csv'
result_path = pathlib.Path('.').parent.absolute() / 'proposal_learning_results.csv'
Ks = [None, 25, 100, 1000, 10000]
# Set to True to (re)simulate the data set before running the experiments.
generate_data = False
# Full list of supported experiments:
# ['Bootstrap', 'Optimal', 'DPF', 'Soft', 'Stop-Gradient', 'Marginal Stop-Gradient', 'Optimal Transport', 'Kernel']
experiment_list = ['Optimal Transport', 'Kernel']
# Prefer the first CUDA device, falling back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seeded generators: one living on `device`, one on the CPU (used by the data loaders).
cuda_gen = torch.Generator(device=device).manual_seed(0)
cpu_gen = torch.Generator().manual_seed(0)

In [43]:
def make_model_componets(dx, dy, generator, optimal_prop = True):
    """Build the prior, dynamic, observation and proposal components of the model.

    Args:
        dx: State dimension passed to every component.
        dy: Observation dimension passed to the observation and proposal models.
        generator: Generator handed to each component's constructor.
        optimal_prop: If truthy build ``model.GaussianOptimalProposal``,
            otherwise ``model.GaussianLearnedProposal``.

    Returns:
        Tuple ``(prior_model, dynamic_model, observation_model, proposal_model)``.
    """
    # The construction order (dynamic, observation, prior, proposal) is kept
    # fixed because all constructors receive the same generator.
    dynamics = model.GaussianDynamic(dx, generator)
    observations = model.GaussianObservation(dx, dy, generator)
    prior = model.GaussianPrior(dx, generator)
    proposal = (
        model.GaussianOptimalProposal(dx, dy, generator)
        if optimal_prop
        else model.GaussianLearnedProposal(dx, dy, generator)
    )
    return prior, dynamics, observations, proposal

In [44]:
if generate_data:
    # Simulate trajectories from the model and write them to `data_path`.
    if torch.cuda.is_available():
        gen_device = torch.device("cuda:0")
    else:
        gen_device = torch.device("cpu")
    # Dedicated generator with a fixed seed for the simulation.
    gen_generator = torch.Generator(device=gen_device).manual_seed(0)
    # The proposal component is not needed to simulate data, so it is discarded.
    prior_model, dynamic_model, observation_model, _ = make_model_componets(dx, dy, gen_generator)
    SSM = pydpf.FilteringModel(
        prior_model=prior_model,
        dynamic_model=dynamic_model,
        observation_model=observation_model,
    )
    pydpf.simulate_and_save(
        data_path,
        SSM=SSM,
        time_extent=1000,
        n_trajectories=2000,
        batch_size=100,
        device=gen_device,
    )

In [45]:
def get_DPF(DPF_type, SSM, dim):
    """Construct the particle filter named by ``DPF_type`` around ``SSM``.

    Args:
        DPF_type: One of 'DPF', 'Soft', 'Stop-Gradient', 'Marginal Stop-Gradient',
            'Optimal Transport' or 'Kernel'.
        SSM: Model passed to the filter constructor as ``SSM``.
        dim: Dimension given to the Gaussian kernel; only used for 'Kernel'.

    Returns:
        The constructed ``pydpf`` filter object.

    Raises:
        ValueError: If ``DPF_type`` is not one of the options above.
    """
    # Variants that differ only in the pydpf class; all of them take the
    # notebook-level `cuda_gen` as their resampling generator.
    generator_variants = (
        ('DPF', 'DPF'),
        ('Soft', 'SoftDPF'),
        ('Stop-Gradient', 'StopGradientDPF'),
        ('Marginal Stop-Gradient', 'MarginalStopGradientDPF'),
    )
    for label, class_name in generator_variants:
        if DPF_type == label:
            # The class is looked up only for the matching label.
            return getattr(pydpf, class_name)(SSM=SSM, resampling_generator=cuda_gen)

    if DPF_type == 'Optimal Transport':
        return pydpf.OptimalTransportDPF(SSM=SSM, regularisation=0.5)

    if DPF_type == 'Kernel':
        mixture = pydpf.KernelMixture(
            [('Gaussian', dim)],
            gradient_estimator='reparameterisation',
            generator=cuda_gen,
        )
        return pydpf.KernelDPF(SSM=SSM, kernel=mixture)

    raise ValueError('DPF_type should be one of the allowed options')

In [46]:
def training_loop(dpf, epochs, train_loader, validation_loader, repeat):
    """Train the proposal of ``dpf`` with SGD and keep the best-validating weights.

    Changes relative to the previous version:
      * the validation pass now iterates ``validation_loader`` (it previously
        re-used ``train_loader``, so model selection was done on training data);
      * all proposal parameters are taken from ``dpf.SSM.proposal_model`` instead
        of the notebook-level ``proposal_model`` global;
      * ``best_dict`` is initialised before the first epoch so that a non-finite
        first validation loss can no longer cause a ``NameError``.

    Args:
        dpf: Filter to train. Must expose ``SSM.proposal_model`` (with ``x_weight``,
            ``y_weight`` and ``dist.cholesky_covariance``) and, for the 'Kernel'
            experiment, ``resampler``.
        epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(state, observation)`` training batches.
        validation_loader: Iterable of ``(state, observation)`` batches used for
            model selection.
        repeat: Index of the current repeat; not used by this function.

    Returns:
        None. ``dpf`` is left holding the state dict with the lowest validation loss.

    Note:
        Still reads the notebook-level ``experiment`` variable to decide whether
        the resampler parameters are optimised and gradient-clipped.
    """
    ELBO_fun = pydpf.ElBO_Loss()
    proposal = dpf.SSM.proposal_model
    is_kernel = experiment == 'Kernel'

    # Per-parameter-group learning rates; the kernel experiment uses a smaller
    # rate for the second proposal group and additionally trains the resampler.
    param_groups = [
        {'params': [proposal.x_weight], 'lr': 1.},
        {'params': [proposal.y_weight, proposal.dist.cholesky_covariance], 'lr': 1. if is_kernel else 5.},
    ]
    if is_kernel:
        param_groups.append({'params': dpf.resampler.parameters(), 'lr': 0.01})
    opt = torch.optim.SGD(param_groups, lr=.5, momentum=0.9, nesterov=True)
    opt_scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)

    best_validation_loss = torch.inf
    # Fallback so load_state_dict below is always well defined.
    best_dict = deepcopy(dpf.state_dict())
    for epoch in tqdm(range(epochs)):
        dpf.train()
        total_size = 0
        train_loss = []
        for state, observation in train_loader:
            dpf.update()
            opt.zero_grad()
            # 100 particles over a time extent of 100 steps.
            ELBO = dpf(100, 100, ELBO_fun, observation=observation)
            loss = torch.mean(ELBO)
            loss.backward()
            if is_kernel:
                torch.nn.utils.clip_grad_norm_(dpf.resampler.parameters(), 1., norm_type='inf', error_if_nonfinite=True)
            # Weight each batch loss by state.size(1) (presumably the batch size - TODO confirm).
            train_loss.append(loss.item()*state.size(1))
            opt.step()
            total_size += state.size(1)
            # The learning rate is decayed after every batch, not every epoch.
            opt_scheduler.step()
        train_loss = np.sum(np.array(train_loss)) / total_size

        dpf.eval()
        dpf.update()
        total_size = 0
        validation_loss = []
        with torch.inference_mode():
            for state, observation in validation_loader:
                ELBO = dpf(100, 100, ELBO_fun, observation=observation)
                loss = torch.mean(ELBO)
                validation_loss.append(loss.item()*state.size(1))
                total_size += state.size(1)
            validation_loss = np.sum(np.array(validation_loss)) / total_size

        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            best_dict = deepcopy(dpf.state_dict())
        # NOTE(review): the best weights are restored after every epoch (as in
        # the original), so each epoch restarts from the best state so far -
        # confirm this is intended rather than a single restore after the loop.
        dpf.load_state_dict(best_dict)

In [47]:
def fractional_diff_exp(a, b):
    """Return ``|1 - exp(b - a)|`` element-wise.

    When ``a`` and ``b`` are logarithms of two quantities this is the absolute
    difference of those quantities relative to the one whose logarithm is ``a``.
    """
    log_ratio = b - a
    return (1 - log_ratio.exp()).abs()

def test_dpf(dpf, test_loader, KalmanFilter):
    """Evaluate ``dpf`` on ``test_loader`` against a Kalman filter reference.

    Prints every named parameter of ``dpf`` (the diagonal only for 2-D
    parameters), then runs the filter with 100 particles over 1000 time steps
    on each batch.

    Args:
        dpf: Filter under test.
        test_loader: Iterable of ``(state, observation)`` batches.
        KalmanFilter: Callable taking ``observation`` and ``time_extent`` and
            returning ``(state, covariance, likelihood)``.

    Returns:
        Tuple ``(negated weighted ELBO, filtering-mean squared error,
        fractional likelihood error)``, each divided by the accumulated
        ``state.size(1)`` over all batches.
    """
    aggregation_fun = {
        'ELBO': pydpf.ElBO_Loss(),
        'Filtering Mean': pydpf.FilteringMean(),
        'Likelihood_factors': pydpf.LogLikelihoodFactors(),
    }
    dpf.update()

    # Dump the current parameters for inspection.
    for param_name, param in dpf.named_parameters():
        print(param_name)
        print(torch.diag(param) if param.dim() == 2 else param)

    elbo_terms, state_error_terms, likelihood_error_terms = [], [], []
    n_seen = 0
    with torch.inference_mode():
        for state, observation in test_loader:
            # Weight of this batch (presumably the batch size - TODO confirm).
            weight = state.size(1)
            outputs = dpf(n_particles = 100, time_extent=1000, aggregation_function=aggregation_fun, observation=observation)
            # NOTE(review): the ELBO is summed here whereas training averages it - confirm intended.
            elbo_terms.append(outputs['ELBO'].sum().item() * weight)
            kalman_state, kalman_cov, kalman_likelihood = KalmanFilter(observation=observation, time_extent=1000)
            squared_error = torch.sum((outputs['Filtering Mean'] - kalman_state)**2, dim=-1)
            state_error_terms.append(squared_error.mean().item() * weight)
            likelihood_error = fractional_diff_exp(kalman_likelihood, outputs['Likelihood_factors'].squeeze()).mean()
            likelihood_error_terms.append(likelihood_error.item() * weight)
            n_seen += weight
    mean_neg_elbo = -sum(elbo_terms)/n_seen
    mean_state_error = sum(state_error_terms)/n_seen
    mean_likelihood_error = sum(likelihood_error_terms)/n_seen
    return mean_neg_elbo, mean_state_error, mean_likelihood_error
    

In [48]:
def max_wass_dist(x_weight, y_weight, prop_cov):
    """Score how far a proposal's parameters are from the optimal proposal's.

    The reference ("optimal") values are hard-coded here: state weights and
    diagonal covariance of 1 everywhere except the first ``dy`` entries, which
    are 0.5, and an observation weight of 0.5. ``dx``, ``dy`` and ``device`` are
    read from the notebook globals.

    Args:
        x_weight: Proposal state weights, broadcastable against a length-``dx`` tensor.
        y_weight: Proposal observation weights.
        prop_cov: Diagonal of the proposal covariance (must be non-negative
            because its square root is taken).

    Returns:
        Scalar tensor: a mean-discrepancy term plus a covariance term of the
        form ``sum(c_opt + c - 2 * sqrt(c_opt * c))``.
    """
    optimal_x_weight = torch.ones(dx, device=device)
    optimal_x_weight[:dy] = .5
    optimal_cov = torch.ones(dx, device=device)
    optimal_cov[:dy] = .5

    a = x_weight - optimal_x_weight
    b = y_weight - .5

    if torch.all(a == 0):
        # Length dx so the shape matches the non-degenerate branch below
        # (this was previously length dy; the returned value is unaffected).
        mean_div = torch.zeros(dx, device=device)
    else:
        # Squared deviations normalised to sum to one.
        mean_div = a**2/torch.sum(a**2)

    if torch.all(b == 0):
        y_mean_div_contr = 0
    else:
        y_mean_div_contr = b**2/torch.sum(b**2)
    # The observation-weight term only affects the first dy state components.
    mean_div[:dy] += y_mean_div_contr
    mean_div = torch.sum(mean_div**2)

    cov_div = torch.sum((optimal_cov + prop_cov - 2*torch.sqrt(optimal_cov*prop_cov)))
    return mean_div + cov_div
    

In [49]:
def chain(*its):
    """Concatenate the elements of every iterable in ``its`` into one new list."""
    return [element for iterable in its for element in iterable]
        
def rotate_range(c_repeat, rel_start, rel_end, repeats, total_elements):
    """Shift the index window ``[rel_start, rel_end)`` cyclically for a given repeat.

    The window is moved forward by ``c_repeat * (total_elements // repeats)``
    positions, wrapping around modulo ``total_elements``.

    Args:
        c_repeat: Index of the current repeat.
        rel_start: Start of the window before rotation (inclusive).
        rel_end: End of the window before rotation (exclusive).
        repeats: Total number of repeats; sets the shift per repeat.
        total_elements: Number of elements being indexed.

    Returns:
        A ``range`` when the rotated window is contiguous, otherwise a list of
        indices running to the end of the data and continuing from 0.
    """
    shift = c_repeat * (total_elements // repeats)
    first, last = ((bound + shift) % total_elements for bound in (rel_start, rel_end))
    # An end that lands exactly on the wrap point means "up to the last element".
    if last == 0:
        return range(first, total_elements)
    if first <= last:
        return range(first, last)
    # The window wraps around the end of the data.
    return chain(range(first, total_elements), range(0, last))
    

In [50]:
# Run every configured experiment over `n_repeats` rotated data splits and
# write the averaged metrics into the results CSV (one row per experiment).
dataset = pydpf.StateSpaceDataset(data_path=data_path, 
                        series_id_column='series_id', 
                        state_prefix='state', 
                        observation_prefix='observation', 
                        device=device)
if dy > dx:
    raise ValueError('The dimension of the observations cannot be more than the dimension of the states.')

for experiment in experiment_list:
    mean_wass_dist = torch.tensor(0., device=device)
    mean_epsilon_l = 0
    mean_epsilon_x = 0
    mean_ELBO = 0
    for repeat in range(n_repeats):
        # Re-seed both generators per repeat so each repeat is reproducible on its own.
        cpu_gen = torch.Generator().manual_seed(10*repeat)
        cuda_gen = torch.Generator(device=device).manual_seed(10*repeat)
        # 1000 / 500 / 500 train / validation / test split over 2000 indices,
        # rotated by repeat.
        train_set = torch.utils.data.Subset(dataset, rotate_range(repeat, 0, 1000, n_repeats, 2000))
        validation_set = torch.utils.data.Subset(dataset, rotate_range(repeat, 1000, 1500, n_repeats, 2000))
        test_set = torch.utils.data.Subset(dataset, rotate_range(repeat, 1500, 2000, n_repeats, 2000))
        prior_model, dynamic_model , observation_model, proposal_model = make_model_componets(dx, dy, cuda_gen, experiment == 'Optimal')
        if experiment == 'Bootstrap':
            # No proposal model is passed for the bootstrap filter.
            SSM = pydpf.FilteringModel(prior_model=prior_model, dynamic_model=dynamic_model, observation_model=observation_model)
            dpf = get_DPF('DPF', SSM, dx)
            if repeat == 0:
                # Same value for every repeat, so compute it once and pre-multiply
                # by n_repeats to cancel the division applied after the loop.
                mean_wass_dist = max_wass_dist(torch.ones(dx, device=device), torch.zeros(dy, device=device), torch.ones(dx, device=device)) * n_repeats
        elif experiment == 'Optimal':
            SSM = pydpf.FilteringModel(prior_model=prior_model, dynamic_model=dynamic_model, observation_model=observation_model, proposal_model=proposal_model)
            dpf = get_DPF('DPF', SSM, dx)
        else:
            # Learned proposal: train it, then measure its distance to the optimal one.
            trained_model = pydpf.FilteringModel(prior_model=prior_model, dynamic_model=dynamic_model, observation_model=observation_model, proposal_model=proposal_model)
            dpf = get_DPF(experiment, trained_model, dx)
            train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, generator=cpu_gen, collate_fn=dataset.collate)
            validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=32, shuffle=False, generator=cpu_gen, collate_fn=dataset.collate)
            training_loop(dpf, 5, train_loader, validation_loader, repeat)
            cholesky_prop_cov = torch.diag(proposal_model.dist.cholesky_covariance)
            prop_cov = cholesky_prop_cov**2
            mean_wass_dist += max_wass_dist(proposal_model.x_weight, proposal_model.y_weight, prop_cov)
            

        # Evaluate on the held-out test split. This previously used
        # `validation_set`, leaving `test_set` unused and reporting metrics on
        # the data used for model selection.
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False, generator=cpu_gen, collate_fn=dataset.collate)
        kalman_filter = pydpf.KalmanFilter(prior_model=prior_model, dynamic_model=dynamic_model, observation_model=observation_model)
        ELBO, e_x, e_l = test_dpf(dpf, test_loader, kalman_filter)
        mean_ELBO += ELBO
        mean_epsilon_l += e_l
        mean_epsilon_x += e_x
    # Average over repeats; the distance is reported as the square root of the mean.
    mean_wass_dist = sqrt(mean_wass_dist.item() / n_repeats)
    mean_ELBO = mean_ELBO / n_repeats
    mean_epsilon_x = mean_epsilon_x / n_repeats
    mean_epsilon_l = mean_epsilon_l / n_repeats
    # The results file must already exist with four value columns; the row for
    # this experiment is inserted or overwritten.
    results_df = pd.read_csv(result_path, index_col=0)
    row = np.array([mean_epsilon_x, mean_epsilon_l, mean_wass_dist, mean_ELBO])
    results_df.loc[experiment] = row
    results_df.to_csv(result_path)
    print(results_df)

 40%|████      | 2/5 [04:57<07:25, 148.56s/it]


KeyboardInterrupt: 