In [1]:
import mxnet as mx
import numpy as np

In [2]:
n_state_dims = 1
n_action_dims = 1
n_total_input_dims = n_state_dims + n_action_dims

In [3]:
size = 500
X = np.random.rand(size, n_total_input_dims) * 10
# Y = np.random.rand(size, n_state_dims)
Y = (X[:,0] + X[:,1])[:, None]

In [4]:
Y.shape

(500, 1)

In [5]:
from mxfusion.common import config
config.DEFAULT_DTYPE = 'float64'

In [6]:
from mxfusion import Model, Variable
from mxfusion.components.variables import PositiveTransformation
from mxfusion.components.distributions.gp.kernels import RBF
from mxfusion.modules.gp_modules import GPRegression


m = Model()
m.N = Variable()
m.X = Variable(shape=(m.N, n_total_input_dims))
m.noise_var = Variable(shape=(1,), transformation=PositiveTransformation(), initial_value=0.01)
m.kernel = RBF(input_dim=n_total_input_dims, variance=1, lengthscale=1)
m.Y = GPRegression.define_variable(X=m.X, kernel=m.kernel, noise_var=m.noise_var, shape=(m.N, n_state_dims))

In [7]:
import mxnet as mx
from mxfusion.inference import GradBasedInference, MAP

infr = GradBasedInference(inference_algorithm=MAP(model=m, observed=[m.X, m.Y]))
infr.run(X=mx.nd.array(X, dtype='float64'), Y=mx.nd.array(Y, dtype='float64'), 
         max_iter=100, learning_rate=0.05, verbose=True)



Iteration 11 loss: 227.56065261411027
Iteration 21 loss: -216.23025838967354
Iteration 31 loss: -489.17760037412126
Iteration 41 loss: -693.3077083550918
Iteration 51 loss: -866.6408734765444
Iteration 61 loss: -1024.3471825672368
Iteration 71 loss: -1173.0854992340664
Iteration 81 loss: -1316.0865405219145
Iteration 91 loss: -1455.0728049039021
Iteration 100 loss: -1577.5677420449078

In [8]:
n_time_steps = 10
initial_state = mx.nd.array([[1.]], dtype='float64')
linear_policy = mx.gluon.nn.Dense(1, in_units=n_state_dims, dtype='float64')
linear_policy.collect_params().initialize(mx.init.Constant(0))

In [9]:
class CostFunction(mx.gluon.HybridBlock):
    def hybrid_forward(self, F, x):
        return F.sum(F.abs(x - 1), axis=-1)
    
cost = CostFunction()

In [10]:
from mxfusion.inference import GradTransferInference, ModelBasedAlgorithm, BatchInferenceLoop
mb_alg = ModelBasedAlgorithm(model=m, 
                             observed=[m.X], 
                             cost_function=cost, 
                             policy=linear_policy, 
                             n_time_steps=n_time_steps,
                             initial_state=initial_state)
infr_pred = GradTransferInference(mb_alg, 
                              infr_params=infr.params)

In [16]:
n_prints=10
max_iter = 100
verbose=True
trainer = mx.gluon.Trainer(linear_policy.collect_params(),
                           optimizer='adam',
                           optimizer_params={'learning_rate':
                                             1e-3})
iter_step = max(max_iter // n_prints, 1)
for i in range(max_iter):
    with mx.autograd.record():
        loss_for_gradient = infr_pred.run(X=mx.nd.array(X, dtype='float64'), verbose=True)[0]
        loss_for_gradient.backward()
        for p in linear_policy.collect_params().values():
            print(p.grad(), p.data())
    if verbose:
        print('\rIteration {} loss: {}'.format(i + 1, loss_for_gradient.asscalar()),
              end='')
        if i % iter_step == 0 and i > 0:
            print()
    trainer.step(batch_size=1, ignore_stale_grad=True)




[[-152.77540911]]
<NDArray 1x1 @cpu(0)> 
[[-0.00963272]]
<NDArray 1x1 @cpu(0)>

[-155.15811178]
<NDArray 1 @cpu(0)> 
[0.00545203]
<NDArray 1 @cpu(0)>
Iteration 1 loss: 0.6756819160650106
[[-126.54064346]]
<NDArray 1x1 @cpu(0)> 
[[-0.00863272]]
<NDArray 1x1 @cpu(0)>

[-127.14768806]
<NDArray 1 @cpu(0)> 
[0.00645204]
<NDArray 1 @cpu(0)>
Iteration 2 loss: 0.27397350627524486
[[87.60351719]]
<NDArray 1x1 @cpu(0)> 
[[-0.00764197]]
<NDArray 1x1 @cpu(0)>

[86.89072975]
<NDArray 1 @cpu(0)> 
[0.00744201]
<NDArray 1 @cpu(0)>
Iteration 3 loss: 0.2891619995082866
[[159.68648655]]
<NDArray 1x1 @cpu(0)> 
[[-0.00719971]]
<NDArray 1x1 @cpu(0)>

[157.40243849]
<NDArray 1 @cpu(0)> 
[0.00789005]
<NDArray 1 @cpu(0)>
Iteration 4 loss: 0.6111319629757708
[[52.46965811]]
<NDArray 1x1 @cpu(0)> 
[[-0.00725296]]
<NDArray 1x1 @cpu(0)>

[51.88041686]
<NDArray 1 @cpu(0)> 
[0.00784794]
<NDArray 1 @cpu(0)>
Iteration 5 loss: 0.3396051112975562
[[-20.24317063]]
<NDArray 1x1 @cpu(0)> 
[[-0.00740158]]
<NDArray 1x1 @cpu

Iteration 45 loss: 0.5122537707107756
[[-27.70681858]]
<NDArray 1x1 @cpu(0)> 
[[-0.00747371]]
<NDArray 1x1 @cpu(0)>

[-28.9050933]
<NDArray 1 @cpu(0)> 
[0.00806363]
<NDArray 1 @cpu(0)>
Iteration 46 loss: 0.3897505214286269
[[56.95331834]]
<NDArray 1x1 @cpu(0)> 
[[-0.00725042]]
<NDArray 1x1 @cpu(0)>

[55.80731155]
<NDArray 1 @cpu(0)> 
[0.0082995]
<NDArray 1 @cpu(0)>
Iteration 47 loss: 0.46708184373578443
[[145.96491697]]
<NDArray 1x1 @cpu(0)> 
[[-0.00712327]]
<NDArray 1x1 @cpu(0)>

[143.72758569]
<NDArray 1 @cpu(0)> 
[0.00843949]
<NDArray 1 @cpu(0)>
Iteration 48 loss: 0.7101628772896161
[[157.35921121]]
<NDArray 1x1 @cpu(0)> 
[[-0.00719804]]
<NDArray 1x1 @cpu(0)>

[155.94143904]
<NDArray 1 @cpu(0)> 
[0.00837854]
<NDArray 1 @cpu(0)>
Iteration 49 loss: 0.46454890978841534
[[159.14268659]]
<NDArray 1x1 @cpu(0)> 
[[-0.00745782]]
<NDArray 1x1 @cpu(0)>

[157.11949532]
<NDArray 1 @cpu(0)> 
[0.00813205]
<NDArray 1 @cpu(0)>
Iteration 50 loss: 0.6670247377447207
[[-38.8730461]]
<NDArray 1x1 @cpu(

Iteration 89 loss: 0.5433009136065516
[[-145.08769272]]
<NDArray 1x1 @cpu(0)> 
[[-0.00900757]]
<NDArray 1x1 @cpu(0)>

[-145.89952428]
<NDArray 1 @cpu(0)> 
[0.00709773]
<NDArray 1 @cpu(0)>
Iteration 90 loss: 0.3287068673745539
[[-77.18047766]]
<NDArray 1x1 @cpu(0)> 
[[-0.00886503]]
<NDArray 1x1 @cpu(0)>

[-77.70745544]
<NDArray 1 @cpu(0)> 
[0.00725411]
<NDArray 1 @cpu(0)>
Iteration 91 loss: 0.2643481296510367

[[-99.28003347]]
<NDArray 1x1 @cpu(0)> 
[[-0.00864594]]
<NDArray 1x1 @cpu(0)>

[-100.91872629]
<NDArray 1 @cpu(0)> 
[0.00748642]
<NDArray 1 @cpu(0)>
Iteration 92 loss: 0.4385847916083434
[[-120.80044243]]
<NDArray 1x1 @cpu(0)> 
[[-0.00833278]]
<NDArray 1x1 @cpu(0)>

[-121.81381758]
<NDArray 1 @cpu(0)> 
[0.0078135]
<NDArray 1 @cpu(0)>
Iteration 93 loss: 0.42239143521393463
[[53.49249034]]
<NDArray 1x1 @cpu(0)> 
[[-0.0079117]]
<NDArray 1x1 @cpu(0)>

[51.61457346]
<NDArray 1 @cpu(0)> 
[0.00824833]
<NDArray 1 @cpu(0)>
Iteration 94 loss: 0.5086641958867164
[[20.16312995]]
<NDArray 1x1 

In [17]:
# with mx.autograd.record():
#     res = infr_pred.run(X=mx.nd.array(X, dtype='float64'), verbose=True)[0]
#     res.backward()
#     print(res)

In [18]:
infr_pred.params.param_dict

(
  Parameter fb747366_3ae8_4dc9_af11_617e73f1e83e (shape=(1,), dtype=float64)
  Parameter acea51df_b4f6_4909_b4b6_9aaece3c6fa0 (shape=(1,), dtype=float64)
  Parameter 1bc1f90e_9e4d_472c_bfbc_27e148263f9b (shape=(1,), dtype=float64)
  Parameter 1858779c_b9a9_4dfd_a251_24f8eb58eefb (shape=(500, 500), dtype=float64)
  Parameter 6ab8e233_6de5_4d8c_ae69_14f422a2e1ec (shape=(500, 1), dtype=float64)
  Parameter 82224a32_fac4_4980_86a3_46b6ddcb96af (shape=(500, 2), dtype=float64)
)

In [19]:
for p in linear_policy.collect_params().values():
    print(p.data())


[[-0.00790285]]
<NDArray 1x1 @cpu(0)>

[0.00836033]
<NDArray 1 @cpu(0)>
