In [1]:
import numpy as np
import pandas as pd
from scipy.stats import multivariate_normal, invwishart, wishart
from scipy.linalg import eig, eigh
from timeit import timeit
from datetime import datetime
import logging

logger = logging.getLogger()

In [13]:
def sim_data(num_of_clusters, data_size):
    # generate random centroids
    LB=0
    UB=1000

    # make K centroids
    _centroids = np.random.randint(LB, UB, [num_of_clusters, 2])

    # make K variance-covariance matrices
    _covs = np.tile(np.eye(2), (num_of_clusters, 1, 1))
#     _covs  = wishart(df=2, scale=np.eye(2)).rvs(size=data_size)
    
    # Draw N random (two-dimensional) points 
    data = np.random.randint(LB, UB, [data_size, 2])
    
    # assign each of the points to one of the clusters
    parent_cluster = np.random.randint(0, num_of_clusters, data_size)

    # the centroids and cov matrices of the corresponding clusters
    centroids = _centroids[parent_cluster]
    covs = _covs[parent_cluster]
    
    return data, centroids, covs, parent_cluster

In [14]:
K = 100000
N = 10000000
data, centroids, covs, parent_cluster = sim_data(num_of_clusters=K, data_size=N)

In [4]:
def myFun(data, centroids, covs):
    param = list(zip(*[data, centroids, covs]))
    out = [multivariate_normal.logpdf(p[0], p[1], p[2]) for i, p in enumerate(param)]
    return out

In [5]:
def myFun_2(data, centroids, covs, parent_cluster):
    res = np.nan * np.ones(data.shape[0])
    pcs = np.unique(parent_cluster)
    for pc in pcs:
        mask = parent_cluster == pc
        res[mask] = multivariate_normal.logpdf(data[mask], centroids[mask][0], covs[mask][0])
    return res
    

In [17]:
def multiple_logpdfs(x, means, covs):
    # Thankfully, NumPy broadcasts `eigh`.
    print('1: %s' % datetime.now())
#     ucovs, idx = np.unique(covs, axis=0, return_inverse=True)
#     uvals, uvecs = np.linalg.eigh(ucovs)
#     vals = uvals[idx]
#     vecs = uvecs[idx]
    vals, vecs = np.linalg.eigh(covs)
    
    print('2: %s' % datetime.now())
    # Compute the log determinants across the second axis.
    logdets    = np.sum(np.log(vals), axis=1)

    print('3: %s' % datetime.now())
    # Invert the eigenvalues.
    valsinvs   = 1./vals
    
    print('4: %s' % datetime.now())
    # Add a dimension to `valsinvs` so that NumPy broadcasts appropriately.
    Us         = vecs * np.sqrt(valsinvs)[:, None]
    devs       = x - means

    print('5: %s' % datetime.now())
    # Use `einsum` for matrix-vector multiplications across the first dimension.
    devUs      = np.einsum('ni,nij->nj', devs, Us)

    print('6: %s' % datetime.now())
    # Compute the Mahalanobis distance by squaring each term and summing.
    mahas      = np.sum(np.square(devUs), axis=1)
    
    print('7: %s' % datetime.now())
    # Compute and broadcast scalar normalizers.
    dim        = len(vals[0])
    log2pi     = np.log(2 * np.pi)
    
    print('8: %s' % datetime.now())
    return -0.5 * (dim * log2pi + mahas + logdets)

In [None]:
assert max(abs(multiple_logpdfs(data, centroids, covs) - myFun_2(data, centroids, covs, parent_cluster))) < 1e-6

1: 2021-08-14 23:38:24.213362
2: 2021-08-14 23:38:24.684358
3: 2021-08-14 23:38:24.985365
4: 2021-08-14 23:38:25.042358
5: 2021-08-14 23:38:25.348360
6: 2021-08-14 23:38:26.588358
7: 2021-08-14 23:38:26.793332
8: 2021-08-14 23:38:26.793332


In [8]:
def f():
    return myFun(data, centroids, covs)
print(timeit(f, number=3))

2.5790328999999996


In [None]:
def f_2():
    return myFun_2(data, centroids, covs, parent_cluster)
print(timeit(f_2, number=1))

In [18]:
def f_3():
    return multiple_logpdfs(data, centroids, covs)
print(timeit(f_3, number=1))

1: 2021-08-14 23:53:25.557792
2: 2021-08-14 23:53:31.517818
3: 2021-08-14 23:53:31.808828
4: 2021-08-14 23:53:31.875817
5: 2021-08-14 23:53:32.196792
6: 2021-08-14 23:53:33.413825
7: 2021-08-14 23:53:33.607817
8: 2021-08-14 23:53:33.607817
8.262327000000028


In [11]:
myFun_2(data, centroids, covs, parent_cluster).sum()

-207661327622.14847

In [12]:
multiple_logpdfs(data, centroids, covs).sum()

1: 2021-08-14 23:35:39.965203
2: 2021-08-14 23:35:39.965203
3: 2021-08-14 23:35:39.966205
4: 2021-08-14 23:35:39.966205
5: 2021-08-14 23:35:39.967220
6: 2021-08-14 23:35:39.969203
7: 2021-08-14 23:35:39.970216
8: 2021-08-14 23:35:39.970216


-207661327622.14847

In [13]:
multiple_logpdfs(data, centroids, covs)

1: 2021-08-14 23:35:45.866886
2: 2021-08-14 23:35:45.867866
3: 2021-08-14 23:35:45.867866
4: 2021-08-14 23:35:45.867866
5: 2021-08-14 23:35:45.868867
6: 2021-08-14 23:35:45.870895
7: 2021-08-14 23:35:45.871866
8: 2021-08-14 23:35:45.871866


array([-2.06600413e+07, -3.26272831e+04, -3.05835865e+03, ...,
       -8.11232295e+05, -2.06523115e+04, -9.47856745e+03])

In [14]:
myFun_2(data, centroids, covs, parent_cluster)

array([-2.06600413e+07, -3.26272831e+04, -3.05835865e+03, ...,
       -8.11232295e+05, -2.06523115e+04, -9.47856745e+03])

In [10]:
np.linalg.eigh(covs)

(array([[3.15383976e-01, 6.85693169e-01],
        [7.17175877e-03, 2.00194008e+00],
        [1.34820650e-02, 2.28038011e+00],
        ...,
        [5.96179841e-02, 8.36992125e+00],
        [4.85815571e-02, 4.20676176e-01],
        [2.34164849e-01, 6.31572848e-01]]),
 array([[[-0.9746998 , -0.22351803],
         [-0.22351803,  0.9746998 ]],
 
        [[-0.93109978,  0.36476457],
         [ 0.36476457,  0.93109978]],
 
        [[-0.05950777, -0.99822784],
         [-0.99822784,  0.05950777]],
 
        ...,
 
        [[ 0.02229099, -0.99975153],
         [-0.99975153, -0.02229099]],
 
        [[-0.75654242, -0.65394462],
         [-0.65394462,  0.75654242]],
 
        [[-0.95576615, -0.29412764],
         [-0.29412764,  0.95576615]]]))

In [11]:
eigh(covs)

ValueError: expected square "a" matrix

In [12]:
np.unique(covs, axis=0)
covs

array([[[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]],

       ...,

       [[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]],

       [[1., 0.],
        [0., 1.]]])