In [4]:
import numpy as np
from scipy.io import loadmat
import numpy as np
from sklearn.metrics import mean_squared_error
from statistics import variance

In [13]:
def estimate_variance(xs: np.ndarray, ys: np.ndarray, affine: np.ndarray,
                      translation: np.ndarray, responsibility: np.ndarray) -> float:
    """
    Estimate the variance of GMM
    :param xs: a set of points with size (N, D), N is the number of samples, D is the dimension of points
    :param ys: a set of points with size (M, D), M is the number of samples, D is the dimension of points
    :param affine: an affine matrix with size (D, D)
    :param translation: a translation vector with size (1, D)
    :param responsibility: the responsibility matrix with size (N, M)
    :return:
    """
    # TODO: implement a method to estimate the variance of the GMMs
    N, D = xs.shape
    M = ys.shape[0]
    xs_hat = ys @ affine.T + translation
    variance = mean_squared_error(xs, xs_hat)
    return variance

In [28]:
fish = loadmat('fish.mat')
xs = fish['X']
ys = fish['Y']

In [7]:
responsibility = np.ones((xs.shape[0], ys.shape[0])) / ys.shape[0]

In [9]:
affine = np.ones((2, 2))
translation = np.ones((1, 2))

In [14]:
estimate_variance(xs, ys, affine = affine, translation=translation, responsibility=responsibility)

0.5115057865232507

In [26]:
np.sum(xs, axis=1).reshape(-1, 1)

array([[-1.08076995],
       [-0.9926374 ],
       [-0.95942381],
       [-0.97362699],
       [-1.07086414],
       [-1.08907335],
       [-1.35667582],
       [-1.18871412],
       [-1.036558  ],
       [-0.93261986],
       [-0.79626933],
       [-0.69983338],
       [-0.54607485],
       [-0.41722652],
       [-0.20774782],
       [-0.0373825 ],
       [ 0.09226704],
       [ 0.30334815],
       [ 0.48041446],
       [ 0.89456462],
       [ 1.26129801],
       [ 1.6114246 ],
       [ 1.74617272],
       [ 1.99236117],
       [ 1.86671766],
       [ 1.6537428 ],
       [ 1.44827013],
       [ 1.21948967],
       [ 1.0215192 ],
       [ 0.83265332],
       [ 0.64298624],
       [ 0.51814393],
       [ 0.44741938],
       [ 0.4514254 ],
       [ 0.45463022],
       [ 0.44122825],
       [ 0.43612967],
       [ 0.40022112],
       [ 0.35600917],
       [ 0.36001519],
       [ 0.44975016],
       [ 0.61101088],
       [ 0.78647478],
       [ 1.05837463],
       [ 1.35358226],
       [ 1

In [32]:
xs[:, np.newaxis, :] - (ys @ affine.T + translation)

array([[[-0.37659708,  0.3734713 ],
        [-0.41683249,  0.33323588],
        [-0.44524657,  0.3048218 ],
        ...,
        [-0.21722028,  0.5328481 ],
        [-0.83928302, -0.08921465],
        [-0.96021137, -0.21014299]],

       [[-0.35168688,  0.43669366],
        [-0.3919223 ,  0.39645824],
        [-0.42033638,  0.36804416],
        ...,
        [-0.19231009,  0.59607046],
        [-0.81437283, -0.02599229],
        [-0.93530117, -0.14692063]],

       [[-0.31847329,  0.43669366],
        [-0.35870871,  0.39645824],
        [-0.38712279,  0.36804416],
        ...,
        [-0.1590965 ,  0.59607046],
        [-0.78115924, -0.02599229],
        [-0.90208758, -0.14692063]],

       ...,

       [[-0.51775483,  0.24702658],
        [-0.55799025,  0.20679116],
        [-0.58640433,  0.17837708],
        ...,
        [-0.35837804,  0.40640338],
        [-0.98044078, -0.21565937],
        [-1.10136912, -0.33658771]],

       [[ 0.48834016, -0.21923833],
        [ 0.44810475, -0.25

In [43]:
xs[:, np.newaxis, :].shape

(91, 1, 2)

In [48]:
np.sum((xs[:, np.newaxis, :]-(ys @ affine.T)), axis=2).shape

(91, 91)

In [41]:
np.sum((xs[:, np.newaxis, :] - (ys @ affine.T + translation)) ** 2, axis = 2)

array([[0.28130617, 0.28479548, 0.29116084, ..., 0.33111174, 0.71235525,
        0.96616595],
       [0.31438501, 0.31078223, 0.31213918, ..., 0.39228316, 0.66387871,
        0.89637396],
       [0.29212659, 0.28585108, 0.28532056, ..., 0.38061168, 0.61088536,
        0.83534768],
       ...,
       [0.3290922 , 0.3541157 , 0.37568842, ..., 0.29359852, 1.00777309,
        1.32630524],
       [0.28654156, 0.26812449, 0.25901966, ..., 0.42312067, 0.46567886,
        0.65364949],
       [0.4557391 , 0.42523908, 0.40760133, ..., 0.64018005, 0.49592893,
        0.64758403]])

In [22]:
np.sum(xs, axis = 2)

array([[-1.08076995],
       [-0.9926374 ],
       [-0.95942381],
       [-0.97362699],
       [-1.07086414],
       [-1.08907335],
       [-1.35667582],
       [-1.18871412],
       [-1.036558  ],
       [-0.93261986],
       [-0.79626933],
       [-0.69983338],
       [-0.54607485],
       [-0.41722652],
       [-0.20774782],
       [-0.0373825 ],
       [ 0.09226704],
       [ 0.30334815],
       [ 0.48041446],
       [ 0.89456462],
       [ 1.26129801],
       [ 1.6114246 ],
       [ 1.74617272],
       [ 1.99236117],
       [ 1.86671766],
       [ 1.6537428 ],
       [ 1.44827013],
       [ 1.21948967],
       [ 1.0215192 ],
       [ 0.83265332],
       [ 0.64298624],
       [ 0.51814393],
       [ 0.44741938],
       [ 0.4514254 ],
       [ 0.45463022],
       [ 0.44122825],
       [ 0.43612967],
       [ 0.40022112],
       [ 0.35600917],
       [ 0.36001519],
       [ 0.44975016],
       [ 0.61101088],
       [ 0.78647478],
       [ 1.05837463],
       [ 1.35358226],
       [ 1