In [96]:
import numpy as np
import sklearn.datasets as dst
from numpy import linalg as LA

class ConditionalGaussEstimator:
    """Linear estimator of ``x`` from a noise-free linear observation ``y = A x``.

    The estimate is ``mu_x + Phi (y - A mu_x)`` with the gain
    ``Phi = var_x A^T (A var_x A^T)^{-1}``. For ``x ~ N(mu_x, var_x)`` this is
    the conditional mean ``E[x | A x = y]``.

    Parameters
    ----------
    A : array_like, shape (dim_y, dim_x)
        Observation matrix. ``A var_x A^T`` must be invertible.
    mu_x : array_like, shape (dim_x,)
        Prior mean of ``x``.
    var_x : array_like, shape (dim_x, dim_x)
        Prior covariance of ``x``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``A var_x A^T`` is singular (also for a single-row ``A``, instead of
        silently producing infinite gains).
    """

    def __init__(self, A, mu_x, var_x):
        self._A = np.asarray(A)
        self._A_T = self._A.transpose()
        self._mu_x = np.asarray(mu_x)
        self._mu_y = self._A.dot(self._mu_x)
        self._var_x = np.asarray(var_x)

        # NOTE: despite its name, ``_var_y`` stores the *inverse* of the
        # covariance of y, (A var_x A^T)^{-1}; the attribute name is kept for
        # compatibility. LA.inv covers the 1x1 (single-row A) case as well, so
        # no special branch is needed.
        self._var_y = LA.inv(self._A.dot(self._var_x).dot(self._A_T))

        # Gain matrix, shape (dim_x, dim_y).
        self._Phi = self._var_x.dot(self._A_T).dot(self._var_y)

    def estimate(self, y):
        """Return ``mu_x + Phi (y - A mu_x)``.

        ``y`` may be one observation of shape (dim_y,) or a batch of shape
        (n, dim_y); the result then has shape (dim_x,) or (n, dim_x).
        """
        innovation = np.asarray(y) - self._mu_y
        # innovation @ Phi^T equals Phi @ innovation for a single vector and
        # applies the gain row-wise for a batch.
        return self._mu_x + innovation.dot(self._Phi.T)

    def estimate_0(self, y):
        """Return the prior mean ``mu_x``; the observation ``y`` is ignored."""
        return self._mu_x
    
def eigen_norm(A):
    """Return the spectral radius of the square matrix ``A``.

    The spectral radius is the largest absolute value (modulus, for complex
    eigenvalues) among the eigenvalues of ``A``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``A`` is not square or the eigenvalue computation does not converge.
    """
    # Only eigenvalues are needed; LA.eigvals skips the eigenvector
    # computation that LA.eig would also perform.
    eigen_values = LA.eigvals(A)
    return np.abs(eigen_values).max()




# --- Experiment 1: x in R^2, observed through one random linear functional.
nb_samples = 10000
dim_x, dim_y = 2, 1
check = True  # when set, verify that every estimate reproduces its observation

A = np.random.rand(dim_y, dim_x)
A_T = A.transpose()
mu_x = np.array([3.0, 0.5])
var_x = dst.make_spd_matrix(dim_x)

cond_gauss_est = ConditionalGaussEstimator(A, mu_x, var_x)
rnd_x = np.random.multivariate_normal(mu_x, var_x, nb_samples)

# samples[i, 0] holds the i-th draw x_i; samples[i, 1] holds its estimate
# computed from the observation A x_i.
samples = np.zeros((nb_samples, 2, dim_x))
for i, x_draw in enumerate(rnd_x):
    y_i = A.dot(x_draw)
    samples[i, 0] = x_draw
    samples[i, 1] = cond_gauss_est.estimate(y_i)

    if check:
        # The estimate must map back onto the observed y_i.
        mismatch = LA.norm(y_i - A.dot(samples[i, 1]))
        assert mismatch < 1e-4, 'Estimators is not contained in subspace!'

# np.cov treats rows as variables, hence the transposes.
var_0 = np.cov((samples[:, 0] - mu_x).T)
var_1 = np.cov((samples[:, 1] - mu_x).T)

# Spectral radius of the empirical covariance: raw draws vs. estimates.
eigen_norm_0 = eigen_norm(var_0)
eigen_norm_1 = eigen_norm(var_1)
print(eigen_norm_0)
print(eigen_norm_1)



# --- Experiment 2 setup: x in R^3, observed through a fixed 2x3 matrix.
A = np.array([[1.0, -1.0, 4.0],
              [1.0, 1.0, 0.0]])
dim_y, dim_x = A.shape
mu_x = np.array([3.0, 0.5, -1.0])
var_x = dst.make_spd_matrix(dim_x)

# NOTE(review): `samples` is not reallocated here and still has shape
# (nb_samples, 2, 2) from the first experiment; reallocate it with the new
# dim_x before reusing it for this setup.
cond_gauss_est = ConditionalGaussEstimator(A, mu_x, var_x)
rnd_x = np.random.multivariate_normal(mu_x, var_x, nb_samples)

2.6755697569159853
0.7081900627374043
