In [2]:
using MPSTime
using Random
using Plots
using PrettyTables
using Zygote
using Distributions
using ITensors
using Base.Threads

# Seeded RNG, shared by the train/test generators in the next cell so the
# generated data is reproducible.
rng = Xoshiro(1)

# MPS hyperparameters. NOTE(review): `d` is presumably the local encoding
# dimension and `chi_max` the maximum bond dimension — confirm in MPSTime docs.
opts = MPSOptions(; d = 10, chi_max = 40, sigmoid_transform = false)

# Series length and number of train / test instances.
ntimepoints, ntrain_instances, ntest_instances = 10, 500, 300

# Passed as the `sigma` keyword of `trendy_sine` (presumably a noise level).
sigma = 0;

In [None]:
# Train set is drawn first, then the test set, from the same `rng` stream; the
# second value returned by `trendy_sine` is discarded.
X_train, _ = trendy_sine(ntimepoints, ntrain_instances; sigma, rng);
X_test, _ = trendy_sine(ntimepoints, ntest_instances; sigma, rng);

In [8]:
"""
    twoD_feature_map(x::Float64)

Encode `x` as the two-component state `[cos(πx/2), sin(πx/2)]`, i.e. a point
on the unit circle at angle `πx/2`.
"""
function twoD_feature_map(x::Float64)
    θ = pi / 2 * x
    return [cos(θ), sin(θ)]
end

twoD_feature_map (generic function with 1 method)

In [38]:
"""
    adlegendre_feature_map(x, dim, norm=true)

Encode the scalar `x` as the vector of Legendre polynomials
`[P_0(x), P_1(x), …, P_{dim-1}(x)]`, built with Bonnet's recursion
`(n + 1) P_{n+1}(x) = (2n + 1) x P_n(x) - n P_{n-1}(x)`.

When `norm == true`, degree `n` is scaled by `sqrt(n + 1/2)`. Since
`∫_{-1}^{1} P_n(x)^2 dx = 2 / (2n + 1)`, this makes the basis orthonormal on `[-1, 1]`.

# Arguments
- `x::Real`: value to encode (the basis is orthogonal on `[-1, 1]`).
- `dim::Integer`: number of basis functions; must be at least 1.
- `norm = true`: apply the per-degree orthonormalisation factor when `== true`.

# Returns
A `Vector` of length `dim` whose `k`-th entry holds the degree-`(k - 1)` term.

# Throws
- `ArgumentError` if `dim < 1`.
"""
function adlegendre_feature_map(x::Real, dim::Integer, norm = true)
    dim >= 1 || throw(ArgumentError("dim must be at least 1, got $dim"))

    T = float(typeof(x))
    state = Vector{T}(undef, dim)
    state[1] = one(T)                 # P_0
    if dim >= 2
        state[2] = x                  # P_1
    end

    # state[k] holds P_{k-1}. With n = k - 2 (so state[k-1] = P_n and
    # state[k-2] = P_{n-1}), Bonnet's recursion yields P_{n+1}.
    for k in 3:dim
        n = k - 2
        state[k] = ((2n + 1) * x * state[k-1] - n * state[k-2]) / (n + 1)
    end

    if norm == true
        # Per-degree factor sqrt(n + 1/2) for n = 0, …, dim - 1.
        state .*= sqrt.((0:dim-1) .+ 0.5)
    end

    return state
end

adlegendre_feature_map (generic function with 3 methods)

In [None]:
# Train the MPS on the training set with the options defined above.
(mps, info, test_states) = fitMPS(X_train, opts);

Generating initial weight MPS with bond dimension χ_init = 4
        using random state 1234.
Initialising train states.
Using 1 iterations per update.
Training KL Div. 14.555960013828624 | Training acc. 1.0.
Using optimiser CustomGD with the "TSGO" algorithm
Starting backward sweeep: [1/5]
Backward sweep finished.
Starting forward sweep: [1/5]
Finished sweep 1. Time for sweep: 35.43s
Training KL Div. -0.6141226945387014 | Training acc. 1.0.


In [None]:
# Imputation target: class label and instance index (NOTE(review): presumably
# the instance's position within X_test — confirm), imputing only the final
# time point with the `:ITS` method.
class, instance_idx = 0, 10
impute_sites = [ntimepoints]
method = :ITS

# Pair the trained MPS with the test data for the imputation calls below.
imp = init_imputation_problem(mps, X_test);

In [None]:
# Impute the chosen sites of the chosen instance; also request fit plots.
imputed_ts, pred_err, target_ts, stats, nice_plots = MPS_impute(
    imp, class, instance_idx, impute_sites, method;
    plot_fits = true,
    num_trajectories = 10,
    rejection_threshold = 2.5,
);

In [None]:
# Tabulate the first entry of the imputation statistics.
pretty_table(
    stats[1];
    header = ["Metric", "Value"],
    header_crayon = crayon"yellow bold",
    tf = tf_unicode_rounded,
)

In [None]:
# Combine every figure returned by MPS_impute into one plot (the cell's value).
let panels = nice_plots
    plot(panels...)
end