In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm


In [None]:
# Ground-truth covariance: unit variances on the diagonal, correlation rho off it.
rho = 0.5
C = np.full((2, 2), rho)
np.fill_diagonal(C, 1.0)
def MFID(C1, C2):
    """FID-style term between the coordinate-1 marginals of N(0, C1) and N(0, C2).

    Only the variances ``C1[1, 1]`` and ``C2[1, 1]`` are read. The value is the
    scalar case of the FID trace term, ``a + b - 2*sqrt(sqrt(a)*b*sqrt(a))``,
    returned through ``sqrt(x**2)`` (i.e. as an absolute value).

    Parameters
    ----------
    C1, C2 : array_like, at least 2x2
        Covariance matrices; only entry ``[1, 1]`` of each is used.

    Returns
    -------
    numpy.float64
        Non-negative value (or NaN for invalid, e.g. negative, variances).
    """
    var_1 = np.array(C1[1, 1])
    var_2 = np.array(C2[1, 1])
    root_1 = np.sqrt(var_1)
    # Scalar analogue of tr(A + B - 2 (A^{1/2} B A^{1/2})^{1/2}).
    gap = var_1 + var_2 - 2 * np.sqrt(root_1 * var_2 * root_1)
    # sqrt(gap**2) == |gap|: guards against tiny negative roundoff.
    return np.sqrt(gap ** 2)
def RFID(C1, C2):
    """FID trace term between the zero-mean Gaussians N(0, C1) and N(0, C2).

    Computes ``|tr(C1 + C2 - 2 (C2^{1/2} C1 C2^{1/2})^{1/2})|``. Means are not
    involved (both Gaussians are treated as zero-mean), so the mean term of the
    FID is absent.

    Parameters
    ----------
    C1, C2 : array_like, shape (d, d)
        Covariance matrices (assumed symmetric positive semi-definite).

    Returns
    -------
    numpy.float64
        Non-negative real value.
    """
    # Computed once; the original evaluated sqrtm(C2) twice.
    root_C2 = sqrtm(C2)
    cross = sqrtm(root_C2 @ C1 @ root_C2)
    # sqrtm may return a complex array (with negligible imaginary parts) when
    # the product is numerically singular; the distance itself is real, so keep
    # the real part instead of leaking a complex scalar to the caller.
    fid = np.real(np.trace(C1 + C2 - 2 * cross))
    # abs guards against tiny negative roundoff (same effect as sqrt(x**2)).
    return np.abs(fid)

def CFID(C1, C2):
    """Compare cross-covariances and conditional variances of two covariances.

    For each matrix ``C`` it uses the cross-covariance ``t = C[0, 1]`` and the
    Schur complement ``v = C[1, 1] - C[0, 1]**2 / C[0, 0]`` (for a Gaussian,
    the variance of coordinate 1 conditional on coordinate 0). It returns
    ``(t2 - t1)**2 + v1 + v2 - 2*sqrt(v1*v2)``, i.e. the squared difference of
    cross-covariances plus ``(sqrt(v1) - sqrt(v2))**2``.

    NOTE(review): the first term compares raw cross-covariances rather than
    regression slopes ``C[0, 1] / C[0, 0]``; the two agree only when
    ``C[0, 0] == 1`` -- confirm this is the intended definition.

    Parameters
    ----------
    C1, C2 : array_like, at least 2x2
        Covariance matrices; entries ``[0, 0]``, ``[0, 1]`` and ``[1, 1]`` are
        read, and ``[0, 0]`` must be non-zero.

    Returns
    -------
    numpy.float64
        Non-negative value.
    """
    t1 = C1[0, 1]
    t2 = C2[0, 1]
    # The Schur complement is >= 0 for any PSD matrix, but roundoff on
    # (near-)singular sample covariances can make it slightly negative, which
    # would turn the square root below into NaN; clip it at zero.
    C1_x = np.maximum(C1[1, 1] - (t1 * t1 / C1[0, 0]), 0.0)
    C2_x = np.maximum(C2[1, 1] - (t2 * t2 / C2[0, 0]), 0.0)
    return (t2 - t1) ** 2 + (C1_x + C2_x - 2 * np.sqrt(C1_x * C2_x))
    
def SC(Z, n):
    """Sample second-moment matrix ``Z^T Z / n``.

    No mean is subtracted, so this is the covariance estimate for data known to
    be zero-mean. Rows of ``Z`` are the samples.

    Parameters
    ----------
    Z : ndarray, shape (n_samples, d)
        Sample matrix.
    n : number
        Normalizer (the caller passes the number of samples).

    Returns
    -------
    ndarray, shape (d, d)
    """
    gram = Z.T.dot(Z)
    return gram / n
def NSC1(Z, n):
    """Second-moment matrix with every coordinate except coordinate 0 rescaled.

    Computes ``S = Z^T Z / n`` and returns ``D S D``, where ``D`` is diagonal
    with ``1/sqrt(S[k, k])`` for k >= 1 and ``1`` for k == 0. Coordinate 1 (and
    beyond) therefore gets unit variance, while coordinate 0 keeps its scale.

    Parameters
    ----------
    Z : ndarray, shape (n_samples, d)
        Sample matrix, one sample per row.
    n : number
        Normalizer (the caller passes the number of samples).

    Returns
    -------
    ndarray, shape (d, d)
    """
    second_moment = Z.T.dot(Z) / n
    scale = 1 / np.sqrt(np.diag(second_moment))
    scale[0] = 1.0  # leave coordinate 0 unnormalized
    D = np.diag(scale)
    return D.dot(second_moment.dot(D))
def NSC2(Z, n):
    """Second-moment matrix rescaled to unit diagonal (a correlation matrix).

    Computes ``S = Z^T Z / n`` and returns ``D S D`` with
    ``D = diag(1/sqrt(diag(S)))``.

    Parameters
    ----------
    Z : ndarray, shape (n_samples, d)
        Sample matrix, one sample per row.
    n : number
        Normalizer (the caller passes the number of samples).

    Returns
    -------
    ndarray, shape (d, d)
    """
    second_moment = Z.T.dot(Z) / n
    inv_std = np.diag(1 / np.sqrt(np.diag(second_moment)))
    return inv_std.dot(second_moment.dot(inv_std))
# ---- Experiment configuration ------------------------------------------------
metrics = [MFID, RFID, CFID]
estimators = [SC, NSC1, NSC2]
SAMPLE_SIZES = np.arange(10, 110, 10)  # sample sizes n swept on the x-axis
T = 10000                              # Monte-Carlo trials per sample size
SEED = 0                               # fixed so Restart & Run All reproduces the figure
rng = np.random.default_rng(SEED)

# ---- Simulation ----------------------------------------------------------------
# res[i, j, t] is the mean of metrics[i](estimators[j](Z, n), C) over the trials
# in which the metric returned a finite value, for n = SAMPLE_SIZES[t].
res = np.zeros((len(metrics), len(estimators), len(SAMPLE_SIZES)))
counts = np.zeros_like(res)
for t, n in enumerate(SAMPLE_SIZES):
    # The sample size is n, not the enumerate index t: using t would simulate
    # n = 0..9 (n = 0 gives 0/0 = NaN) while the x-axis is labelled 10..100.
    # All T samples of size n are drawn at once and shared by every
    # metric/estimator pair, which is cheaper and makes the curves comparable.
    samples = rng.multivariate_normal(mean=np.zeros(2), cov=C, size=(T, n))
    for Z in samples:
        for j, estimator in enumerate(estimators):
            C_hat = estimator(Z, n)
            for i, metric in enumerate(metrics):
                try:
                    # np.real: sqrtm-based metrics can return complex values
                    # with negligible imaginary parts.
                    value = np.real(metric(C_hat, C))
                except (ValueError, np.linalg.LinAlgError):
                    continue  # metric undefined for this draw; skip the trial
                if np.isfinite(value):
                    res[i, j, t] += value
                    counts[i, j, t] += 1
# Average over the successful trials; cells with no successful trial become NaN
# instead of dividing by zero.
res = np.divide(res, counts, out=np.full_like(res, np.nan), where=counts > 0)

# ---- Plots ---------------------------------------------------------------------
# Left column: one panel per metric, comparing estimators.
# Right column: one panel per estimator, comparing metrics.
n_rows = max(len(metrics), len(estimators))
fig, axes = plt.subplots(n_rows, 2, figsize=(10, 10), squeeze=False)
for i, metric in enumerate(metrics):
    ax = axes[i, 0]
    for j, estimator in enumerate(estimators):
        ax.plot(SAMPLE_SIZES, res[i, j], label=estimator.__name__)
    ax.set(title=metric.__name__, xlabel="sample size n", ylabel="mean distance to C")
    ax.legend()
for j, estimator in enumerate(estimators):
    ax = axes[j, 1]
    for i, metric in enumerate(metrics):
        ax.plot(SAMPLE_SIZES, res[i, j], label=metric.__name__)
    ax.set(title=estimator.__name__, xlabel="sample size n", ylabel="mean distance to C")
    ax.legend()
fig.tight_layout()




In [148]:
# Inspection cell: the name shown for MFID in the plot titles and legends.
getattr(MFID, "__name__")

'MFID'

array([[-0.46653867, -1.05632679],
       [ 0.11416508,  2.19309505],
       [ 0.77699085,  0.60628687],
       [-0.59909016,  1.20797125],
       [-0.86428019, -1.47845019],
       [ 0.93394752,  0.49609784],
       [ 0.72165765, -1.05056689],
       [-0.21852747, -0.79962951],
       [-0.02315532,  0.33353888],
       [ 1.02647221, -0.68874768],
       [-0.13312907, -0.64417393],
       [-0.88668391, -1.26900132],
       [-0.21358428,  0.21881977],
       [-0.6732335 ,  0.31221667],
       [-1.16642699, -1.27001953],
       [ 0.2256514 , -1.43499737],
       [-1.47320041, -1.51690977],
       [ 1.62079535, -0.59649674],
       [ 1.07774467, -0.55754609],
       [ 0.16987038,  0.33830111],
       [ 2.46549424,  1.20133866],
       [ 1.00330193, -1.69304573],
       [-0.38820512, -0.5840512 ],
       [ 0.9355233 ,  0.91990445],
       [ 0.39268736, -1.59089082],
       [-0.75557523,  2.19060454],
       [-0.16165888, -0.78897459],
       [-0.3825136 ,  1.09387011],
       [-0.31576527,

In [141]:
# Inspect the sampler object itself rather than calling it: a call with no
# arguments raises TypeError because `mean` and `cov` are required.
np.random.multivariate_normal

<function RandomState.multivariate_normal>