# Stochastic volatility

Particle filtering of a stochastic-volatility state-space model with `filterflow`.

In [12]:
import gc
import os, sys
# add to path
sys.path.append("../..")

import attr
import datetime

from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import seaborn

import tensorflow as tf
import tensorflow_probability as tfp
import tqdm

# Short alias for the TFP distributions namespace (used by the model classes).
tfd = tfp.distributions

# Apply seaborn's default plotting theme and fix the global TensorFlow seed.
seaborn.set()
tf.random.set_seed(50)

In [13]:
from filterflow.smc import SMC
from filterflow.base import State, StateSeries, DTYPE_TO_OBSERVATION_SERIES

from filterflow.observation.base import ObservationModelBase, ObservationSampler

from filterflow.transition.random_walk import RandomWalkModel
from filterflow.transition.base import TransitionModelBase
from filterflow.proposal import BootstrapProposalModel
from filterflow.proposal.auxiliary_proposal import AuxiliaryProposal

from filterflow.resampling.criterion import NeffCriterion, AlwaysResample, NeverResample
from filterflow.resampling.standard import SystematicResampler, MultinomialResampler
from filterflow.resampling.differentiable import RegularisedTransform, CorrectedRegularizedTransform
from filterflow.resampling.differentiable.ricatti.solver import RicattiSolver

from filterflow.resampling.base import NoResampling

from filterflow.state_space_model import StateSpaceModel

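### Model definition

The latent log-volatility follows a drifted AR(1) process, $x_t = \mu(1-\phi) + \phi\, x_{t-1} + \eta_t$, and each observation component is $y_t = \beta \exp(x_t/2)\, \varepsilon_t$ with $\varepsilon_t \sim \mathcal{N}(0, 1)$.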

In [24]:
class StochVolObservationModel(ObservationModelBase):
    """Observation model of a stochastic-volatility state-space model.

    Each observation component is modelled as ``y_i = beta_i * exp(x_i / 2) * eps_i``
    with ``eps_i ~ N(0, 1)``, where ``x`` is the latent log-volatility stored in
    ``state.particles``.

    :param beta_tensor: per-dimension observation scale ``beta``; it also fixes the
        event shape of the noise distribution.
    :param name: name forwarded to the base class.
    """

    def __init__(self, beta_tensor: tf.Tensor, name='StochVolObservationModel'):
        super(StochVolObservationModel, self).__init__(name=name)
        # Distribution of the volatility-normalised observation y / exp(x / 2).
        self._error_rv = tfd.MultivariateNormalDiag(loc=tf.zeros_like(beta_tensor), scale_diag=beta_tensor)

    def loglikelihood(self, state: State, observation: tf.Tensor):
        """Log-density ``log p(observation | state.particles)`` for every particle.

        :param state: state whose ``particles`` hold the log-volatility ``x``; the last
            axis is assumed to be the state/event dimension (TODO confirm against ``State``).
        :param observation: observation ``y``, broadcastable against ``state.particles``.
        :return: log-likelihood with the last (event) axis reduced.
        """
        log_volatility = state.particles
        error = observation / tf.math.exp(log_volatility / 2)
        # Change of variables: y = exp(x / 2) * error, so
        #   log p(y | x) = log p_error(y / exp(x / 2)) - sum_i x_i / 2.
        # The Jacobian term depends on x, so omitting it biases the particle weights.
        log_jacobian = -0.5 * tf.reduce_sum(log_volatility, axis=-1)
        return self._error_rv.log_prob(error) + log_jacobian

class StochVolTransition(TransitionModelBase):
    """Drifted AR(1) transition for the stochastic-volatility log-variance.

    Delegates to a ``RandomWalkModel`` built from ``phi_tensor`` and the noise
    distribution, then shifts its output by the constant drift ``mu * (1 - phi)``.
    Assuming ``RandomWalkModel`` samples ``phi * x + noise`` (TODO confirm), this gives
    ``x_t = mu + phi * (x_{t-1} - mu) + noise``.

    :param mu_tensor: long-run mean ``mu`` of the log-volatility.
    :param phi_tensor: persistence ``phi``.
    :param transition_noise_stddev: noise distribution passed to ``RandomWalkModel``.
    :param name: name forwarded to the base class.
    """

    def __init__(self, mu_tensor: tf.Tensor, phi_tensor: tf.Tensor, transition_noise_stddev: tfp.distributions.Distribution, name='StochVolTransition'):
        super(StochVolTransition, self).__init__(name=name)
        self._drift = mu_tensor * (1 - phi_tensor)
        # Build from the constructor argument; the previous code read an undefined
        # notebook-global `phi`, which fails on a fresh kernel.
        self._random_walk_model = RandomWalkModel(phi_tensor, transition_noise_stddev)

    def loglikelihood(self, prior_state: State, proposed_state: State, inputs: tf.Tensor):
        """Log-density of ``proposed_state`` given ``prior_state``.

        The drift is removed from the proposed particles and the result is scored by
        the wrapped zero-drift random-walk model.
        """
        undrifted_particles = proposed_state.particles - self._drift
        undrifted_state = attr.evolve(proposed_state, particles=undrifted_particles)
        return self._random_walk_model.loglikelihood(prior_state, undrifted_state, inputs)

    def sample(self, state: State, inputs: tf.Tensor):
        """Sample the next state: zero-drift random-walk step, then add the drift."""
        zero_drift_sampled_state = self._random_walk_model.sample(state, inputs)
        drifted_particles = zero_drift_sampled_state.particles + self._drift
        sampled_state = attr.evolve(zero_drift_sampled_state, particles=drifted_particles)
        return sampled_state

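### Simulation parameters

Fix the random seed, the state/observation dimensions, the horizon $T$, and the matrices used to generate the data.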

In [22]:
# Fix the global TensorFlow seed for this cell's random draws.
tf.random.set_seed(0)

# Problem sizes and length of the simulated series.
state_dim = 2
observation_dim = 2
T = 150

# Initial latent state, drawn uniformly from [-1, 1) in each dimension.
initial_state = tf.random.uniform([state_dim], -1., 1.)

# Transition parameters: scaled-identity dynamics and isotropic noise covariance.
transition_matrix = 0.5 * tf.eye(state_dim)
transition_covariance = 0.5 * tf.eye(state_dim)
chol_transition_covariance = tf.linalg.cholesky(transition_covariance)

# Observation parameters: scaled-identity map and isotropic noise covariance.
observation_matrix = 0.5 * tf.eye(observation_dim)
observation_covariance = 0.1 * tf.eye(observation_dim)
chol_observation_covariance = tf.linalg.cholesky(observation_covariance)