In [None]:
import mxnet as mx
import numpy as np

In [None]:
# Dimensionality of the control problem. The GP dynamics model defined below
# takes the state and action dimensions together as its input width.
n_state_dims, n_action_dims = 1, 1
n_total_input_dims = sum((n_state_dims, n_action_dims))

In [None]:
# Toy transition data for the GP dynamics model. Inputs are drawn uniformly
# from [0, 10); the target is the sum of the first two input columns
# (presumably column 0 is the state and column 1 the action, i.e. the toy
# dynamics are next_state = state + action — confirm if dimensions change).
SEED = 0
rng = np.random.default_rng(SEED)
# Seed MXNet's RNG as well so any sampling done by MXNet/MXFusion in later
# cells is repeatable under Restart & Run All.
mx.random.seed(SEED)

size = 500
X = rng.random((size, n_total_input_dims)) * 10
Y = (X[:, 0] + X[:, 1])[:, None]  # shape (size, 1)

In [None]:
# Sanity check: one regression-target column per state dimension, matching the
# (N, n_state_dims) shape the GP output m.Y is declared with.
assert Y.shape == (size, n_state_dims), Y.shape
Y.shape

In [None]:
from mxfusion.common import config
# Make float64 MXFusion's default dtype before the model is built; the data
# arrays and the policy created in later cells are explicitly float64 as well.
DTYPE = 'float64'
config.DEFAULT_DTYPE = DTYPE

In [None]:
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions.gp.kernels import RBF
from mxfusion.modules.gp_modules import GPRegression, GPRegressionSamplingPrediction


# GP dynamics model: N input rows of width n_total_input_dims are mapped to
# N output rows of width n_state_dims.
m = Model()
m.N = Variable()  # number of data points, left symbolic
m.X = Variable(shape=(m.N, n_total_input_dims))
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(),
                       initial_value=0.01)
m.kernel = RBF(input_dim=n_total_input_dims, variance=1, lengthscale=1)
m.Y = GPRegression.define_variable(X=m.X, kernel=m.kernel, noise_var=m.noise_var,
                                   shape=(m.N, n_state_dims))

# Register a sampling-based prediction algorithm on the GP factor under the
# name 'gp_predict'. NOTE(review): this reaches into the module's private
# graphs (_module_graph, _extra_graphs), so it may break across MXFusion versions.
gp = m.Y.factor
sampling_prediction = GPRegressionSamplingPrediction(
    gp._module_graph, gp._extra_graphs[0], [gp._module_graph.X])
gp.attach_prediction_algorithms(
    targets=gp.output_names,
    conditionals=gp.input_names,
    algorithm=sampling_prediction,
    alg_name='gp_predict',
)

In [None]:
import mxnet as mx
from mxfusion.inference import GradBasedInference, MAP

# Fit the GP's free parameters (noise variance, kernel hyper-parameters) by MAP
# on the observed (X, Y) transitions.
gp_training_data = {
    'X': mx.nd.array(X, dtype='float64'),
    'Y': mx.nd.array(Y, dtype='float64'),
}
infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.X, m.Y]))
infr.run(max_iter=200, learning_rate=0.1, verbose=True, **gp_training_data)

In [None]:
n_time_steps = 10

# Deterministic start state of shape (1, n_state_dims), every entry 10.
# TODO want a proposal distribution here instead of same.
initial_state = mx.nd.full((1, n_state_dims), 10., dtype='float64')

# Linear policy: state (n_state_dims inputs) -> action (n_action_dims outputs).
# The output width must follow n_action_dims rather than a hard-coded 1 so the
# notebook keeps working if the action space changes.
linear_policy = mx.gluon.nn.Dense(n_action_dims, in_units=n_state_dims, dtype='float64')
# With every parameter at zero, the initial policy outputs 0 for any state.
linear_policy.collect_params().initialize(mx.init.Constant(0))

In [None]:
class CostFunction(mx.gluon.HybridBlock):
    """Quadratic cost: squared Euclidean distance of a state from a target.

    Parameters
    ----------
    target : float, optional
        Value every state dimension is pulled towards. Defaults to 1.0,
        matching the previously hard-coded constant.
    **kwargs
        Forwarded to ``mx.gluon.HybridBlock`` (e.g. ``prefix``, ``params``).
    """

    def __init__(self, target=1.0, **kwargs):
        super().__init__(**kwargs)
        self._target = target

    def hybrid_forward(self, F, x):
        """Return ``sum((x - target) ** 2)`` reduced over the last axis of ``x``.

        ``F`` is the NDArray or Symbol namespace, depending on hybridization.
        The last axis is summed away; presumably ``x`` is (..., n_state_dims),
        so this yields one cost per state. Confirm against ModelBasedAlgorithm.
        """
        return F.sum((x - self._target) ** 2, axis=-1)


cost = CostFunction()

In [None]:
from mxfusion.inference import GradTransferInference, ModelBasedAlgorithm, BatchInferenceLoop
# Optimise the linear policy by rolling it out through the learned GP dynamics.
# The GP parameters fitted above (infr.params) are transferred in; train_params
# presumably restricts optimisation to the policy's parameters.
mb_alg = ModelBasedAlgorithm(
    model=m,
    observed=[m.X],
    cost_function=cost,
    policy=linear_policy,
    n_time_steps=n_time_steps,
    initial_state=initial_state,
    num_samples=10,
)
policy_params = linear_policy.collect_params()
infr_pred = GradTransferInference(mb_alg, infr_params=infr.params,
                                  train_params=policy_params)
infr_pred.run(X=mx.nd.array(X, dtype='float64'), max_iter=500,
              learning_rate=1e-1, verbose=True)

In [None]:
# NOTE(review): disabled, hand-rolled policy-training loop left over from
# debugging. It builds an Adam Trainer over the policy's parameters, calls
# infr_pred.run inside autograd.record(), back-propagates the first returned
# value, prints each policy parameter's gradient and value, and steps the
# trainer. It appears to duplicate the infr_pred.run call in the previous cell.
# Consider deleting it, or turning it into a working cell, so that
# Restart & Run All reflects what the notebook actually does.
# n_prints=10
# max_iter = 5
# verbose=True
# trainer = mx.gluon.Trainer(linear_policy.collect_params(),
#                            optimizer='adam',
#                            optimizer_params={'learning_rate':
#                                              1e-1})
# iter_step = max(max_iter // n_prints, 1)
# for i in range(max_iter):
#     with mx.autograd.record():
#         loss_for_gradient = infr_pred.run(X=mx.nd.array(X, dtype='float64'), verbose=True)[0]
#         loss_for_gradient.backward()
#         for p in linear_policy.collect_params().values():
#             print(p.grad(), p.data())
#     if verbose:
#         print('\rIteration {} loss: {}'.format(i + 1, loss_for_gradient.asscalar()),
#               end='')
#         if i % iter_step == 0 and i > 0:
#             print()
#     trainer.step(batch_size=1, ignore_stale_grad=True)