In [1]:
import sys
sys.path.insert(0, '../')

import math
import os
import numpy as np
from scipy.special import gammaln
import torch
import beer

import warnings
#warnings.filterwarnings('ignore')


# For plotting.
from bokeh.io import show, output_notebook, export_png
from bokeh.plotting import figure, gridplot
from bokeh.models import Arrow, OpenHead, NormalHead, VeeHead
from bokeh.events import Tap, Pan, PanEnd, PanStart
from bokeh.application.handlers import FunctionHandler
from bokeh.application import Application
from bokeh.models import ColumnDataSource
from bokeh.models.widgets import Button, RadioButtonGroup, Dropdown, Slider
from bokeh.layouts import layout
from bokeh.palettes import viridis


output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

# Model

\begin{align}
p(h_k) &= \mathcal{N}\big(h_k | 0, I\big)  \\
p(\psi_k | W, \mu, h_k, \tau) &= \mathcal{N}\big(\psi_k | W^T h_k + \mu, \frac{1}{\tau} I \big) \\
p(\psi_k | W, \mu, \tau) &= \mathcal{N}\big(\psi_k | \mu, W^T W + \frac{1}{\tau} I \big) \\
\pi_{k,i} &= \frac{e^{\psi_{k,i}}}{1 + \sum_{j=1}^{D-1} e^{\psi_{k,j}}}
\end{align}


In [191]:
# Plotting window and ground-truth parameters of the toy model.
x_range = y_range = (-5, 5)
nsamples = 500
mu = np.array([1, -.5])
weights = np.array([[1., 2.]])
W = weights
tau = 100

# Marginal covariance of psi (W^T W + I / tau) and its Cholesky factor.
C = W.T @ W + np.identity(2) / tau
L = np.linalg.cholesky(C)

# psi ~ N(mu, C): colour standard normal noise with the Cholesky factor.
psis = mu + np.random.randn(nsamples, 2) @ L.T

# Map every psi to the first two coordinates of a point on the simplex.
exp_psis = np.exp(psis)
pi = exp_psis / (1 + exp_psis.sum(axis=1)[:, None])

# The same psi read as parameters of a Normal distribution
# (first column = precision * mean, last column = log(precision / 2)).
precs = 1e-2 + 2 * np.exp(psis[:, -1])
means = psis[:, 0] / precs
n_mean, n_var = means, 1 / precs

# Left: samples in psi space.
fig1 = figure(width=400, height=400, x_range=x_range, y_range=y_range)
fig1.circle(psis[:, 0], psis[:, 1], alpha=.5)

# Right: the corresponding probabilities, with the simplex edge drawn in black.
fig2 = figure(width=400, height=400, x_range=(0, 1), y_range=(0, 1))
fig2.line([1, 0], [0, 1], line_width=2, color='black')
fig2.circle(pi[:, 0], pi[:, 1], alpha=.5)

show(gridplot([[fig1, fig2]]))

In [192]:
def generate_data(pi, nsamples=100):
    """Sample categorical data and return accumulated sufficient statistics.

    Args:
        pi: array of shape (n_models, K - 1) holding the probabilities of
            the first K - 1 categories of each model. The probability of
            the last category is the remainder ``1 - pi.sum(axis=-1)``.
        nsamples: number of samples drawn for each model.

    Returns:
        Float tensor of shape (n_models, K). The first K - 1 columns are
        the number of draws of each of the first K - 1 categories and the
        last column is ``nsamples``.
    """
    probs = np.c_[pi, 1 - np.sum(pi, axis=-1)]

    # The number of categories follows the input instead of being fixed
    # to 3, so the function works for any simplex dimension.
    ncategories = probs.shape[-1]

    # One draw of `nsamples` labels per model, immediately reduced to
    # per-category counts (the sum of the one-hot encoded labels).
    counts = np.stack([
        np.bincount(np.random.choice(ncategories, size=nsamples, p=p),
                    minlength=ncategories)
        for p in probs
    ])

    # The count of the last category is implied by the total and dropped.
    T = torch.from_numpy(counts[:, :-1]).float()
    totals = torch.full((len(T), 1), float(nsamples), dtype=torch.float32)
    return torch.cat([T, totals], dim=-1)

def generate_data_normal(mean, var, nsamples=100):
    """Sample from independent Normals and accumulate their statistics.

    Args:
        mean: 1-D array with the mean of each model.
        var: 1-D array with the variance of each model.
        nsamples: number of samples drawn for each model.

    Returns:
        Float tensor with one row per model holding the sum of the
        samples, the sum of their squares and ``nsamples``.
    """
    std = np.sqrt(var)
    samples = mean[:, None] + std[:, None] * np.random.randn(len(mean), nsamples)

    # Stack x and x^2 along a new axis and sum over the samples.
    sums = np.stack((samples, samples ** 2), axis=1).sum(axis=-1)
    stats = torch.from_numpy(sums).float()

    totals = torch.full((len(stats), 1), float(nsamples), dtype=torch.float32)
    return torch.cat([stats, totals], dim=-1)

In [193]:
class CategoricalLogLikelihood:
    """Categorical log-likelihood as a function of the log-odds.

    The parameters ``nparams`` are the log-odds of the first K - 1
    categories with respect to the last one, whose log-odds is implicitly
    zero. ``stats_counts`` stores, for every row, the per-category counts
    of the first K - 1 categories followed by the total number of
    observations in its last column.
    """

    @staticmethod
    def _with_reference(nparams):
        """Append the implicit zero log-odds of the reference category."""
        zeros = torch.zeros_like(nparams[..., :1])
        return torch.cat([nparams, zeros], dim=-1)

    def __call__(self, nparams, stats_counts):
        """Return the log-likelihood of each row of statistics.

        Args:
            nparams: tensor of log-odds, one row per model.
            stats_counts: tensor of statistics with the counts in the
                last column.

        Returns:
            Tensor with one log-likelihood value per row.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        # log(1 + sum_j exp(psi_j)) computed with logsumexp so that large
        # log-odds do not overflow to inf.
        lnorm = torch.logsumexp(self._with_reference(nparams), dim=-1)
        return (nparams * stats).sum(dim=-1) - counts * lnorm

    def argmax(self, stats_counts):
        """Return the log-odds maximizing the likelihood of each row.

        A small constant keeps the logarithms finite when a category
        (including the last one) was never observed.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        pi = stats / counts[:, None]
        pi_K = pi.sum(dim=1)
        eps = torch.tensor(1e-6, dtype=stats.dtype,
                            device=stats.device)
        remainder = torch.max(eps, 1 - pi_K)
        return torch.log(eps + pi) - torch.log(remainder)[:, None]

    def hessian(self, nparams, stats_counts, mode='diagonal'):
        """Return the Hessian of the log-likelihood w.r.t. the log-odds.

        Args:
            nparams: 2-D tensor of log-odds, one row per model.
            stats_counts: tensor of statistics; only the counts stored in
                the last column are used.
            mode: 'full' for the complete matrices, 'diagonal' for their
                diagonals or 'scalar' for the mean of each diagonal.

        Returns:
            Tensor of shape (n, d, d), (n, d) or (n,) depending on `mode`.

        Raises:
            ValueError: if `mode` is not one of the values above.
        """
        counts = stats_counts[:, -1]
        # Probabilities of the first K - 1 categories; softmax over the
        # padded log-odds is the overflow-safe form of
        # exp(psi) / (1 + sum_j exp(psi_j)).
        mean = torch.softmax(self._with_reference(nparams), dim=-1)[..., :-1]
        I = torch.eye(nparams.shape[-1], dtype=nparams.dtype,
                      device=nparams.device)
        # NOTE(review): the 1e-6 offset is added to every entry of the
        # identity, not only to its diagonal - kept as in the original.
        hessians = (1e-6 + I)[None] * mean[:, :, None]
        hessians = hessians - mean[:, :, None] * mean[:, None, :]
        hessians = -counts[:, None, None] * hessians
        if mode == 'full':
            return hessians
        idxs = tuple(range(hessians.shape[-1]))
        if mode == 'diagonal':
            return hessians[:, idxs, idxs]
        if mode == 'scalar':
            return hessians[:, idxs, idxs].mean(dim=-1)
        raise ValueError(f'Unknown Hessian approximation mode: "{mode}"')

class NormalLogLikelihood:
    """Normal log-likelihood as a function of unconstrained parameters.

    With mean ``m`` and precision ``p``, the parameters are
    ``psi = (p * m, log(p / 2))``; the matching natural parameters are
    ``(psi[0], -exp(psi[1])) = (p * m, -p / 2)``. ``stats_counts`` stores,
    for every row, the sum of the observations, the sum of their squares
    and the number of observations.
    """

    def __call__(self, psi_params, stats_counts):
        """Return the log-likelihood of each row of statistics.

        Args:
            psi_params: tensor of shape (n, 2) of unconstrained parameters.
            stats_counts: tensor of shape (n, 3) of statistics and counts.

        Returns:
            Tensor with one log-likelihood value per row.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        # Map psi to the natural parameters of the Normal distribution.
        nparams = torch.zeros_like(psi_params)
        nparams[:, 0] = psi_params[:, 0]
        nparams[:, -1] = -torch.exp(psi_params[:, -1])
        return (nparams * stats).sum(dim=-1) - counts * self.log_norm(nparams)

    def log_norm(self, nparams):
        """Return the log-normalizer for natural parameters `nparams`.

        The last column of `nparams` is ``-p / 2``, so the precision is
        recovered as ``-2 * nparams[:, -1]`` (it must not be exponentiated
        a second time). The result is
        ``m^2 * p / 2 - log(p) / 2 + log(2 * pi) / 2``.
        """
        precs = -2 * nparams[:, -1]
        means = nparams[:, 0] / precs
        return (.5 * (means ** 2) * precs - .5 * torch.log(precs)
                + .5 * math.log(2 * math.pi))

    def argmax(self, stats_counts):
        """Return the parameters maximizing the likelihood of each row.

        The mean and the variance are estimated from the first and second
        moments; the result is expressed in the `psi` parameterization.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        mT = stats / counts[:, None]
        means = mT[:, 0]
        precs = 1/(mT[:, -1] - means**2)
        return torch.cat([
            (precs * means)[:, None],
            torch.log(.5 * precs)[:, None]],
        dim=-1)

    def hessian(self, psi_params, stats_counts, mode='diagonal'):
        """Return the Hessian of the log-likelihood w.r.t. `psi_params`.

        Args:
            psi_params: tensor of shape (n, 2) of unconstrained parameters.
            stats_counts: tensor of shape (n, 3) of statistics and counts.
            mode: 'full' for the complete matrices, 'diagonal' for their
                diagonals or 'scalar' for the mean of each diagonal.

        Returns:
            Tensor of shape (n, 2, 2), (n, 2) or (n,) depending on `mode`.

        Raises:
            ValueError: if `mode` is not one of the values above.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        H = torch.zeros(len(psi_params), 2 * 2, dtype=psi_params.dtype,
                        device=psi_params.device)

        # exp(-psi_1) = 2 / precision appears in every entry.
        e_psi1 = torch.exp(- psi_params[:, -1])
        H[:, 0] = -.5 * counts * e_psi1
        H[:, 1] = .5 * counts * psi_params[:, 0] * e_psi1
        H[:, 2] = .5 * counts * psi_params[:, 0] * e_psi1
        H[:, 3] = -stats[:, 1] / e_psi1 -.25 * counts * ((psi_params[:, 0])**2) * e_psi1
        hessians = H.reshape(len(psi_params), 2, 2)
        if mode == 'full':
            return hessians
        idxs = tuple(range(hessians.shape[-1]))
        if mode == 'diagonal':
            return hessians[:, idxs, idxs]
        if mode == 'scalar':
            return hessians[:, idxs, idxs].mean(dim=-1)
        raise ValueError(f'Unknown Hessian approximation mode: "{mode}"')

# Inference

In [194]:
# Dimension of the parameter space and of the subspace.
dim = 2
s_dim = 3

llh = CategoricalLogLikelihood()

# Initial values shared by the three models (plain tensors, no gradient).
M = torch.zeros(s_dim, dim, dtype=torch.float32)
global_mean = torch.zeros(dim, dtype=torch.float32)
prec = torch.ones(dim, dtype=torch.float32)
scale = torch.ones(s_dim, dtype=torch.float32)

# Prior strengths used with each Hessian approximation.
hessian_settings = {
    'full': dict(prior_strength=1, hyper_prior_strength=1e-2),
    'diagonal': dict(prior_strength=1e-3, hyper_prior_strength=1e-3),
    'scalar': dict(prior_strength=1, hyper_prior_strength=1e-3),
}

# One model per Hessian approximation, created in the order above.
gsm_models = {
    hessian_type: beer.models.GeneralizedSubspaceModel.create(
        llh,
        M,
        global_mean,
        prec,
        scale,
        hessian_type=hessian_type,
        noise_std=0.1,
        **strengths,
    ).double()
    for hessian_type, strengths in hessian_settings.items()
}
model_full = gsm_models['full']
model_diagonal = gsm_models['diagonal']
model_scalar = gsm_models['scalar']

# Categorical statistics; the Normal alternative is
# generate_data_normal(n_mean, n_var, nsamples=100).double()
data = generate_data(pi, nsamples=100).double()
#data = generate_data_normal(n_mean, n_var, nsamples=100).double()

In [195]:
def callback(previous_L, L):
    """Print `L`, the second value handed to the callback; `previous_L` is ignored."""
    print(L)


cache = model_full.latent_posteriors(
    data,
    max_iter=100,
    conv_threshold=1e-12,
    callback=callback,
)

-42988.83471263356
-42986.76280053244
-42986.757894816714
-42986.75788003472
-42986.75787998203
-42986.757879981815


In [462]:
# Re-create the full-Hessian model with a one-dimensional subspace.
dim = 2
s_dim = 1

M = torch.zeros(s_dim, dim, dtype=torch.float32)
global_mean = torch.zeros(dim, dtype=torch.float32)
prec = torch.ones(dim, dtype=torch.float32)
scale = torch.ones(s_dim, dtype=torch.float32)

model_full = beer.models.GeneralizedSubspaceModel.create(
    llh, M, global_mean, prec, scale,
    hessian_type='full',
    noise_std=0.1,
    prior_strength=1e-3,
    hyper_prior_strength=1,
).double()

params = model_full.mean_field_factorization()
# All four parameters are updated together, as a single group.
optim_full = beer.BayesianModelOptimizer(
    [[model_full.subspace, model_full.mean,
      model_full.precision, model_full.scale]],
    lrate=1.,
)

epochs = list(range(1, 100 + 1))
elbos_full, elbos_diagonal, elbos_scalar = [], [], []

# Total number of observations, used to report the ELBO per observation.
total_count = float(torch.sum(data[:, -1]))

for epoch in epochs:
    optim_full.init_step()
    elbo_full = beer.evidence_lower_bound(
        model_full, data.double(),
        max_iter=100, conv_threshold=1e-7, fast_eval=False,
    )
    # NOTE(review): the label says q(W) but the value printed is the
    # KL divergence of `model_full.precision` - confirm which is intended.
    precision_kl = float(model_full.precision.kl_div())
    print(f'{float(elbo_full)/total_count:.5f}', 'kl q(W):', precision_kl)

    elbo_full.backward()
    elbos_full.append(float(elbo_full))
    optim_full.step()

-20.84360 kl q(W): 0.0
-0.93887 kl q(W): 16.524280218965032
-0.90157 kl q(W): 16.51615649649193
-0.89109 kl q(W): 16.514176211364884
-0.88623 kl q(W): 16.512998588533037
-0.88577 kl q(W): 16.512873379345706
-0.88571 kl q(W): 16.512857069482607
-0.88570 kl q(W): 16.512854537818612
-0.88570 kl q(W): 16.512854083803404
-0.88570 kl q(W): 16.51285397655022
-0.88570 kl q(W): 16.512853943369578
-0.88570 kl q(W): 16.512853931274094
-0.88570 kl q(W): 16.51285392651164
-0.88570 kl q(W): 16.51285392446448
-0.88570 kl q(W): 16.51285392324121
-0.88570 kl q(W): 16.5128539217078
-0.88570 kl q(W): 16.512853918412816
-0.88570 kl q(W): 16.51285391022259
-0.88570 kl q(W): 16.512853889396524
-0.88569 kl q(W): 16.512853836613544
-0.88568 kl q(W): 16.512853705322414
-0.88564 kl q(W): 16.51285339541471
-0.88556 kl q(W): 16.51285277025022
-0.88538 kl q(W): 16.512852161380692
-0.88503 kl q(W): 16.512855509691235
-0.88452 kl q(W): 16.51288086350371
-0.88389 kl q(W): 16.512972178484972
-0.88325 kl q(W): 16.51318

In [463]:
# Expected values of the mean, subspace, precision and scale parameters of the model.
model_full.mean.expected_value(), model_full.subspace.expected_value(), model_full.precision.expected_value(), model_full.scale.expected_value()

(tensor([ 0.9260, -0.5123], dtype=torch.float64),
 tensor([[0.8464, 1.7258]], dtype=torch.float64),
 tensor([11.2467,  3.3269], dtype=torch.float64),
 tensor([0.7024], dtype=torch.float64))

In [301]:
# Expected value of the precision under its prior and under its posterior.
model_full.precision.prior.expected_value(), model_full.precision.posterior.expected_value()

(tensor([100.0000, 100.0000], dtype=torch.float64),
 tensor([-0.3566,  0.3369], dtype=torch.float64))

In [299]:
# Natural parameters of the precision prior and of the precision posterior.
model_full.precision.prior.natural_parameters, model_full.precision.posterior.natural_parameters

(tensor([-0.0100, -0.0100,  0.0000,  0.0000], dtype=torch.float64),
 tensor([-306.9253,  478.9771,  250.0000,  250.0000], dtype=torch.float64))

In [240]:
epochs = list(range(1, 30 + 1))

# One optimizer per model, each over that model's mean-field factorization.
optim_full = beer.BayesianModelOptimizer(model_full.mean_field_factorization(), lrate=1.)
optim_diagonal = beer.BayesianModelOptimizer(model_diagonal.mean_field_factorization(), lrate=1.)
optim_scalar = beer.BayesianModelOptimizer(model_scalar.mean_field_factorization(), lrate=1.)

elbos_full, elbos_diagonal, elbos_scalar = [], [], []

trained_models = [model_full, model_diagonal, model_scalar]
optimizers = [optim_full, optim_diagonal, optim_scalar]
elbo_histories = [elbos_full, elbos_diagonal, elbos_scalar]

# Total number of observations, used to store the ELBO per observation.
total_count = float(torch.sum(data[:, -1]))

for epoch in epochs:
    # Each phase is run for the three models before moving to the next one.
    for optimizer in optimizers:
        optimizer.init_step()

    elbo_full, elbo_diagonal, elbo_scalar = [
        beer.evidence_lower_bound(gsm, data, max_iter=200, conv_threshold=1e-5)
        for gsm in trained_models
    ]
    epoch_elbos = [elbo_full, elbo_diagonal, elbo_scalar]

    for elbo in epoch_elbos:
        elbo.backward()
    for history, elbo in zip(elbo_histories, epoch_elbos):
        history.append(float(elbo) / total_count)
    for optimizer in optimizers:
        optimizer.step()

KeyError: 'params_diag_m2'

In [357]:
# ELBO per epoch, one curve per Hessian approximation.
fig = figure(width=400, height=400)
elbo_curves = (
    (elbos_full, 'red', 'full H'),
    (elbos_diagonal, 'blue', 'diag H'),
    (elbos_scalar, 'green', 'scalar H'),
)
for elbo_values, line_color, curve_label in elbo_curves:
    fig.line(epochs, elbo_values, color=line_color, line_width=2,
             legend=curve_label)

# Legend.
fig.legend.location = 'bottom_right'
fig.legend.label_text_font_size = '16pt'

# Axes: the y axis is hidden, the x axis shows a tick every 10 epochs.
fig.yaxis.visible = False
fig.yaxis.axis_label = 'ELBO'
fig.yaxis.major_label_text_font_size = '16pt'
fig.yaxis.axis_label_text_font_size = '16pt'
fig.xaxis.axis_label = 'epoch'
fig.xaxis.major_label_text_font_size = '16pt'
fig.xaxis.axis_label_text_font_size = '16pt'
fig.xaxis.ticker = [0,  10, 20, 30]

# No grid, and some room around the plot.
fig.xgrid.visible = False
fig.ygrid.visible = False
fig.min_border_left = 20
fig.min_border_top = 20
fig.min_border_right = 20

show(fig)
#outpath = '/Users/lucasondel/Desktop/gsm_cat_elbo.png'
#export_png(fig, outpath)

In [358]:
# Zero mean and a diagonal covariance with variances 0.1 and 5.
mean = np.zeros(2, dtype=int)
cov = np.diag([.1, 5.])

def det_jacobian(x, y):
    """Return ``d / (sqrt(d) * d)`` with ``d = x**2 + y**2``.

    For non-zero ``d`` this is the inverse of the Euclidean norm of
    ``(x, y)``. Works element-wise on arrays.
    """
    squared_norm = x ** 2 + y ** 2
    return squared_norm / (np.sqrt(squared_norm) * squared_norm)

def plot_normal_polar(fig_mag, fig_angle, mean, cov, r_range, a_range, res=1000, r_legend='', a_legend='', **kwargs):
    """Plot the magnitude and angle marginals of a 2-D Normal density.

    The density is evaluated on a polar grid, turned into a density over
    (r, alpha) and numerically marginalized and normalized along each
    coordinate.

    Args:
        fig_mag: figure receiving the density of the magnitude r.
        fig_angle: figure receiving the density of the angle alpha,
            plotted against the index of the angle sample.
        mean: mean of the Normal (length 2).
        cov: 2 x 2 covariance matrix of the Normal.
        r_range: (min, max) of the magnitude grid.
        a_range: (min, max) of the angle grid.
        res: number of grid points along each coordinate.
        r_legend: legend of the magnitude curve.
        a_legend: legend of the angle curve.
        **kwargs: forwarded to both ``line`` calls.
    """
    prec = np.linalg.inv(cov)
    r = np.linspace(r_range[0], r_range[1], res)
    alpha = np.linspace(a_range[0], a_range[1], res)
    # Rows of the grid index the angle, columns index the magnitude.
    R, A = np.meshgrid(r, alpha)
    dx = R * np.cos(A) - mean[0]
    dy = R * np.sin(A) - mean[1]

    # Log-density up to a constant: -1/2 (x - m)^T P (x - m). The two
    # off-diagonal entries contribute with the same negative sign as the
    # diagonal ones.
    Z = - .5 * (prec[0, 0] * dx**2 + prec[1, 1] * dy**2)
    Z -= .5 * (prec[0, 1] + prec[1, 0]) * dx * dy
    Z -= .5 * np.log(np.linalg.det(cov))

    # Change of variables (x, y) -> (r, alpha): the density is multiplied
    # by the Jacobian determinant of the polar map, which is r.
    Z = np.exp(Z) * R

    # Marginal over the angle, normalized with a Riemann sum over r.
    dr = r_range[1] - r_range[0]
    r_vals = Z.sum(axis=0)
    norm = r_vals.sum() * (dr / Z.shape[1])
    r_vals /= norm
    fig_mag.line(r, r_vals, legend=r_legend, **kwargs)

    # Marginal over the magnitude, normalized with a Riemann sum over alpha.
    dalpha = a_range[1] - a_range[0]
    alpha_vals = Z.sum(axis=1)
    norm = alpha_vals.sum() * (dalpha / Z.shape[0])
    alpha_vals /= norm
    fig_angle.line(range(len(alpha_vals)), alpha_vals, legend=a_legend, **kwargs)

In [360]:
# Compare, for every global parameter of the model, its prior (blue), its
# posterior (red) and the ground-truth value (green) in a 2 x 2 grid of figures.
model = model_full

# Standard parameters of the subspace prior (p_M, p_U) and posterior (M, U),
# as returned by to_std_parameters(), converted to numpy.
p_M, p_U = model.subspace.prior.to_std_parameters()
p_U_diag = p_U.diag().numpy()
p_M, p_U = p_M.numpy(), p_U.numpy()
M, U = model.subspace.posterior.to_std_parameters()
U_diag = U.diag().numpy()
M, U = M.numpy(), U.numpy()

# Standard parameters of the mean prior and posterior; the second returned
# value is multiplied by the 2 x 2 identity before plotting.
p_mean, p_cov = model.mean.prior.to_std_parameters()
p_cov = p_cov * torch.eye(2, dtype=p_cov.dtype)
mean, cov = model.mean.posterior.to_std_parameters()
cov = cov * torch.eye(2, dtype=cov.dtype)

# --- Figure 1: prior / posterior over the mean ---
fig_mean = figure(width=350, height=350, x_range=(-2, 2), y_range=(-2, 2))
fig_mean.xgrid.visible = False
fig_mean.ygrid.visible = False
fig_mean.xaxis.ticker = [-1,  0, 1]
fig_mean.yaxis.ticker = [-1,  0, 1]
fig_mean.xaxis.major_label_text_font_size = '16pt'
fig_mean.xaxis.axis_label_text_font_size = '16pt'
fig_mean.yaxis.major_label_text_font_size = '16pt'
fig_mean.yaxis.axis_label_text_font_size = '16pt'
fig_mean.min_border_left = 20
fig_mean.min_border_right = 20
fig_mean.xaxis.axis_label = 'μ'


plotting.plot_normal(fig_mean, p_mean.numpy(), p_cov.numpy(), n_std_dev=2, fill_alpha=.15, 
                     line_alpha=0., color='blue', legend='p(μ)')
plotting.plot_normal(fig_mean, mean.numpy(), cov.numpy(), n_std_dev=2, fill_alpha=.2, 
                     line_alpha=0., color='red', legend='q(μ)')
# Ground-truth mean used to generate the data.
fig_mean.cross(mu[0], mu[1], color='green', size=10, line_width=2, legend='μ*')

fig_mean.legend.label_text_font_size = '16pt'
fig_mean.legend.location = 'top_left'


# --- Figures 2 and 3: magnitude and angle of the subspace basis ---
r_range = (0.01, 3)
res = 1000
fig_mag = figure(width=350, height=350, x_range=(0, r_range[1]), y_range=(0, 10))
fig_mag.xaxis.ticker = [0, 1, 2, 3]
fig_mag.yaxis.visible = False
fig_mag.xgrid.visible = False
fig_mag.ygrid.visible = False
fig_mag.xaxis.major_label_text_font_size = '16pt'
fig_mag.xaxis.axis_label_text_font_size = '16pt'
#fig_mag.yaxis.major_label_text_font_size = '16pt'
fig_mag.yaxis.axis_label_text_font_size = '16pt'
fig_mag.min_border_left = 20
fig_mag.min_border_right = 20
fig_mag.xaxis.axis_label = 'r'

# The x axis of the angle figure is the index of the angle sample (0 .. res).
fig_angle = figure(width=350, height=350, x_range=(0, res), y_range=(0, 5))
fig_angle.yaxis.visible = False
fig_angle.xaxis.ticker = [0, res//4, res//2, (3 * res) // 4, res]
# NOTE(review): plot_normal_polar is called below with a_range=(0, 2π), so
# index 0 corresponds to α = 0 rather than -π — confirm these tick labels.
fig_angle.xaxis.major_label_overrides = {0: '-π', res//4:'-π/2', res//2: '0', 
                                         (3 * res) // 4:'π/2', res: 'π'}
fig_angle.xgrid.visible = False
fig_angle.ygrid.visible = False
fig_angle.xaxis.major_label_text_font_size = '16pt'
fig_angle.xaxis.axis_label_text_font_size = '16pt'
fig_angle.min_border_left = 20
fig_angle.min_border_right = 20
fig_angle.xaxis.axis_label = 'α'

# Sample index of the angle of the ground-truth basis vector, and of the
# opposite direction (angle + π).
angle = int(res * np.arctan2(weights[0, 1], weights[0, 0]) / (2 * np.pi))
neg_angle = int(res * (np.arctan2(weights[0, 1], weights[0, 0]) + np.pi) / (2 * np.pi))
x = np.zeros(100) + angle
neg_x = np.zeros(100) + neg_angle
y = np.linspace(0, 10, 100)
fig_angle.line(x, y, line_width=2, alpha=.75, color='green', line_dash='dashed',
              legend='±α*')
fig_angle.line(neg_x, y, line_width=2, alpha=.75, color='green', line_dash='dashed')

# Vertical marker at the norm of the ground-truth basis vector.
norm = np.linalg.norm(weights[0])
x = np.zeros(100) + norm
y = np.linspace(0, 20, 100)
fig_mag.line(x, y, line_width=2, alpha=.75, color='green', line_dash='dashed',
             legend='r*')

# Prior over the first basis vector: isotropic covariance built from the
# first diagonal entry of p_U. `cov` is reused here and overwrites the
# covariance of the mean posterior computed above.
cov = np.eye(2) * p_U_diag[0]
plot_normal_polar(fig_mag, fig_angle, p_M[0], cov, r_range=r_range, 
                  a_range=(0, 2 * np.pi), res=res, color='blue', alpha=.75, 
                  line_width=2, a_legend='p(α)', r_legend='p(r)')

# Posterior over the first basis vector; note the wider magnitude range
# (0.01, 10) compared with the one used for the prior.
cov = np.eye(2) * U_diag[0]
plot_normal_polar(fig_mag, fig_angle, M[0], cov, r_range=(0.01, 10), 
                  a_range=(0, 2 * np.pi), res=res, color='red', alpha=.75, 
                  line_width=2, a_legend='q(α)', r_legend='q(r)')

fig_angle.legend.location = 'top_left'
fig_angle.legend.label_text_font_size = '16pt'
fig_mag.legend.location = 'top_left'
fig_mag.legend.label_text_font_size = '16pt'

# --- Figure 4: prior / posterior over the precision ---
# The two standard parameters are used below as the shape and the rate of
# a Gamma density.
p_shape, p_rate = model.precision.prior.to_std_parameters()
p_shape, p_rate = p_shape.numpy(), p_rate.numpy()
shape, rate = model.precision.posterior.to_std_parameters()
shape, rate = shape.numpy(), rate.numpy()

# Gamma log-density: a log(b) + (a - 1) log(t) - b t - log Γ(a).
t_range = (0.01, 12)
t = np.linspace(t_range[0], t_range[1], 10000)
log_p_tau = p_shape * np.log(p_rate) + (p_shape -1) * np.log(t)  -p_rate * t
log_p_tau -= gammaln(p_shape)
p_tau = np.exp(log_p_tau)
log_q_tau = shape * np.log(rate) + (shape -1) * np.log(t)  -rate * t
log_q_tau -= gammaln(shape)
q_tau = np.exp(log_q_tau)

fig_tau = figure(width=350, height=350, x_range=(0, t_range[1]), y_range=(0, 2.5))
fig_tau.yaxis.visible = False
fig_tau.xaxis.ticker = [0, 5, 10]
fig_tau.xgrid.visible = False
fig_tau.ygrid.visible = False
fig_tau.xaxis.major_label_text_font_size = '16pt'
fig_tau.xaxis.axis_label_text_font_size = '16pt'
fig_tau.min_border_left = 20
fig_tau.min_border_right = 20
fig_tau.xaxis.axis_label = 'τ'

# NOTE(review): `tau` is set to 100 in the data-generation cell while the
# x range of this figure ends at t_range[1] = 12, so the τ* marker may fall
# outside the visible range — confirm.
x = np.zeros(100) + tau
y = np.linspace(0, 2.5, 100)
fig_tau.line(x, y, line_width=2, alpha=.75, color='green', line_dash='dashed',
              legend='τ*')

fig_tau.line(t, p_tau, line_width=2, alpha=.75, color='blue', legend='p(τ)')
fig_tau.line(t, q_tau, line_width=2, alpha=.75, color='red', legend='q(τ)')
fig_tau.legend.location = 'top_left' 
fig_tau.legend.label_text_font_size = '16pt'

# Assemble the four figures in a 2 x 2 grid.
fig = gridplot([[fig_mean, fig_tau], [fig_mag, fig_angle]], toolbar_location=None)
#outpath = '/Users/lucasondel/Desktop/gsm_norm_M100_N1000.png'
#export_png(fig, outpath)
show(fig)

In [350]:
# Select the model first so that the latent posteriors and the
# maximum-likelihood estimates below come from the same model (the
# original cell used `model` before assigning it).
model = model_full

# Parameters of the latent posteriors of every data point.
cache = model.latent_posteriors(data=data, conv_threshold=1e-8)
means, covs = cache['params_mean'], cache['params_cov']

# Maximum-likelihood parameters of every data point, for reference.
max_psis = model.llh_func.argmax(data).numpy()

fig1 = figure(width=400, height=400, x_range=x_range, y_range=y_range)
#fig1.cross(means.numpy()[:, 0], means.numpy()[:, 1], color='red')
# One ellipse per latent posterior (red), the true parameters (blue) and
# the maximum-likelihood estimates (green).
for m, c in zip(means.numpy(), covs.numpy()):
    plotting.plot_normal(fig1, m, c * np.eye(2), n_std_dev=1, fill_alpha=.1, color='red')
fig1.circle(psis[:, 0], psis[:, 1], alpha=.5, color='blue')
fig1.cross(max_psis[:, 0], max_psis[:, 1], color='green', size=7)
show(fig1)

In [353]:
# Expected value of the scale parameter of the full-Hessian model.
model = model_full
mean_scale = model.scale.expected_value()
mean_scale

tensor([231.8316], dtype=torch.float64)

In [354]:
# Mean of the subspace posterior next to the ground-truth weights.
model.subspace.posterior.mean, weights

(tensor([[3.1267e-13, 5.7760e-13]], dtype=torch.float64), array([[1., 2.]]))