In [None]:
import numpy
import matplotlib.pyplot as plt
from numpy import inf
from utils import *
from classifiers.MultivariateGaussianClassifier import *
from classifiers.NaiveBayesClassifier import *
from classifiers.TiedCovarianceGaussianClassifier import *
from classifiers.TiedDiagCovGaussianClassifier import *
from classifiers.LogisticRegression import *
from classifiers.LinearSVM import *
from classifiers.KernelSVM import *
from classifiers.GaussianMixtureModel import *
from transformers.PCA import *
from transformers.Gaussianizer import *
from tabulate import tabulate
from itertools import combinations
import time
import concurrent.futures
from tqdm import tqdm
import sklearn.model_selection
from scipy.interpolate import interp1d

# Display names of the two classes, keyed by the integer label read by load_data.
labels_map = dict(enumerate([
    'Not a Pulsar',
    'Pulsar',
]))

# Display names of the features, keyed by their row index in the data matrix
# returned by load_data (column i of the input file becomes row i).
features_map = dict(enumerate([
    'Mean of the integrated profile',
    'Standard deviation of the integrated profile',
    'Excess kurtosis of the integrated profile',
    'Skewness of the integrated profile',
    'Mean of the DM-SNR curve',
    'Standard deviation of the DM-SNR curve',
    'Excess kurtosis of the DM-SNR curve',
    'Skewness of the DM-SNR curve',
]))


def load_data(filepath):
    """Load the pulsar dataset from a comma-separated text file.

    Every non-blank line must hold 8 numeric feature values followed by an
    integer class label (9 comma-separated fields). Blank lines, e.g. a
    trailing newline at the end of the file, are ignored.

    Args:
        filepath: path of the text file to read.

    Returns:
        ``(data, labels)`` where ``data`` is a float array with the features
        on the rows and the samples on the columns, and ``labels`` is a 1-D
        int array with one label per sample.
    """
    data = []
    labels = []
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line would otherwise crash in float()/fields[8].
                continue
            fields = line.split(',')
            data.append([float(feature) for feature in fields[0:8]])
            labels.append(int(fields[8]))
    # Transpose to have the features on the rows and the samples on the columns.
    data = numpy.array(data).T
    labels = numpy.array(labels)
    return data, labels


def plot_hist(D, L, folder='hist'):
    """Save one histogram per feature, overlaying the per-class distributions.

    Args:
        D: data matrix with the features on the rows and the samples on the columns.
        L: 1-D array of integer class labels, one per column of ``D``.
        folder: sub-folder of ``./plots/`` where the PNG files are written
            (assumed to already exist).
    """
    # Slice every class once. The legend uses the real label value, so a
    # subset containing a single class is still named correctly.
    per_class = [(label, D[:, L == label]) for label in numpy.unique(L)]

    for i in features_map.keys():
        plt.figure()
        plt.xlabel(features_map[i])
        for label, class_data in per_class:
            plt.hist(class_data[i, :], bins=20, density=True, ec='black', alpha=0.5, label=labels_map[label])
        plt.legend()
        plt.tight_layout()
        plt.savefig('./plots/' + folder + '/hist_%d.png' % i)
    plt.show()


def plot_scatter(D, L, folder='scatter'):
    """Save a per-class scatter plot for every ordered pair of distinct features.

    Args:
        D: data matrix with the features on the rows and the samples on the columns.
        L: 1-D array of integer class labels, one per column of ``D``.
        folder: sub-folder of ``./plots/`` where the PNG files are written
            (assumed to already exist).
    """
    # Slice every class once instead of once per feature pair. The legend
    # uses the real label value rather than the iteration position.
    per_class = [(label, D[:, L == label]) for label in numpy.unique(L)]

    for i in features_map.keys():
        for j in features_map.keys():
            if i == j:
                continue
            plt.figure()
            plt.xlabel(features_map[i])
            plt.ylabel(features_map[j])
            for label, class_data in per_class:
                plt.scatter(class_data[i, :], class_data[j, :], label=labels_map[label], alpha=0.5)
            plt.legend()
            plt.tight_layout()
            plt.savefig('./plots/' + folder + '/scatter_%d_%d.png' % (i, j))
        # Show the figures of feature i before moving on to the next feature.
        plt.show()


def plot_heatmap(D, folder='heatmap', subtitle='', color='YlGn'):
    """Save a heat map of the Pearson correlation between the rows (features) of D.

    Each cell is annotated with its coefficient rounded to one decimal.

    Args:
        D: data matrix with the features on the rows and the samples on the columns.
        folder: sub-folder of ``./plots/`` where the PNG is written.
        subtitle: file-name suffix, giving ``corr_coeff_<subtitle>.png``.
        color: matplotlib colormap name.
    """
    # corrcoef returns a 0-d value for a single feature; keep it 2-D.
    corr_coef = numpy.atleast_2d(numpy.corrcoef(D))
    # Size the annotation loop from the matrix itself, so D may have any number
    # of rows (not only the entries of features_map).
    n_features = corr_coef.shape[0]

    fig, ax = plt.subplots()
    ax.imshow(corr_coef, cmap=color)
    for i in range(n_features):
        for j in range(n_features):
            ax.text(j, i, str(round(corr_coef[i, j], 1)), ha="center", va="center", color="r")

    fig.tight_layout()
    plt.savefig('./plots/' + folder + '/corr_coeff_' + subtitle + '.png')
    plt.show()


def compute_confusion_matrix(true, predicted, n_classes=None):
    """Build the confusion matrix of a classifier.

    ``confusion_matrix[p, t]`` counts the samples of true class ``t`` that
    were assigned to class ``p`` (rows = predicted, columns = true).

    Args:
        true: 1-D array of non-negative integer ground-truth labels.
        predicted: 1-D array of non-negative integer predicted labels, with
            the same length as ``true``.
        n_classes: size of the square matrix. Defaults to the largest label
            found in either array plus one. A class that does not occur in
            this particular sample (e.g. every prediction is 0) still gets its
            own row and column.

    Returns:
        ``(n_classes, n_classes)`` array of int64 counts. Empty inputs with no
        explicit ``n_classes`` give a ``(0, 0)`` array.
    """
    true = numpy.asarray(true, dtype=numpy.int64).reshape(-1)
    predicted = numpy.asarray(predicted, dtype=numpy.int64).reshape(-1)

    if n_classes is None:
        if true.size == 0 and predicted.size == 0:
            return numpy.zeros((0, 0), dtype=numpy.int64)
        n_classes = int(max(true.max(initial=0), predicted.max(initial=0))) + 1

    # Encode each (predicted, true) pair as one flat index and count them all
    # in a single O(n) pass. This works for any number of classes.
    flat_index = predicted * n_classes + true
    counts = numpy.bincount(flat_index, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes).astype(numpy.int64)


def DCFu(prior, cfn, cfp, confusion_matrix):
    """Return the un-normalized empirical detection cost (DCF_u).

    ``confusion_matrix`` is indexed ``[predicted, true]``, with class 1 as the
    positive class. Column 1 therefore holds the outcomes of the true
    positives and column 0 those of the true negatives.
    """
    false_negative_rate = confusion_matrix[0, 1] / sum(confusion_matrix[:, 1])
    false_positive_rate = confusion_matrix[1, 0] / sum(confusion_matrix[:, 0])
    miss_cost = prior * cfn * false_negative_rate
    false_alarm_cost = (1 - prior) * cfp * false_positive_rate
    return miss_cost + false_alarm_cost


def DCF(prior, cfn, cfp, confusion_matrix):
    """Return the normalized DCF: DCF_u divided by the best dummy-system cost.

    The dummy system always outputs the same class. Its cost is
    ``min(prior * cfn, (1 - prior) * cfp)``.
    """
    dummy_cost = min(prior * cfn, (1 - prior) * cfp)
    return DCFu(prior, cfn, cfp, confusion_matrix) / dummy_cost


def min_DCF(llr, labels, prior, cfn, cfp):
    """Return the minimum normalized DCF over all decision thresholds.

    A sample is assigned to class 1 when its score is strictly greater than
    the threshold. The candidate thresholds are every distinct score plus
    ``-inf``, so both trivial systems ("always 0" and "always 1") are
    evaluated.

    Rather than building one confusion matrix per threshold (O(N^2)), the
    class-wise scores are sorted once. The error counts for all thresholds
    are then obtained with ``searchsorted``, which is O(N log N).

    Args:
        llr: 1-D array of scores (higher means more likely class 1).
        labels: 1-D array of 0/1 ground-truth labels, aligned with ``llr``.
        prior: effective prior of class 1.
        cfn: cost of a false negative.
        cfp: cost of a false positive.

    Returns:
        The minimum normalized DCF, or ``None`` if ``llr`` is empty.
    """
    llr = numpy.asarray(llr, dtype=float).reshape(-1)
    labels = numpy.asarray(labels).reshape(-1)
    if llr.size == 0:
        return None

    positive_scores = numpy.sort(llr[labels == 1])
    negative_scores = numpy.sort(llr[labels == 0])
    n_pos = positive_scores.size
    n_neg = negative_scores.size

    thresholds = numpy.concatenate(([-numpy.inf], numpy.unique(llr)))
    # Samples with score <= threshold are predicted 0 (side='right' matches the
    # strict "llr > threshold" rule for tied scores).
    false_negatives = numpy.searchsorted(positive_scores, thresholds, side='right')
    false_positives = n_neg - numpy.searchsorted(negative_scores, thresholds, side='right')

    fnr = false_negatives / n_pos
    fpr = false_positives / n_neg
    dcf_u = prior * cfn * fnr + (1 - prior) * cfp * fpr
    dummy_cost = min(prior * cfn, (1 - prior) * cfp)
    return (dcf_u / dummy_cost).min()


def k_fold_min_DCF(D, L, K, Classifier, prior, class_args=(), transformers=None, transf_args=None):
    """Estimate the min DCF of a model with K-fold cross-validation.

    The samples are shuffled with a fixed seed and split into K folds whose
    sizes differ by at most one. For every held-out fold, each transformer is
    fitted on the remaining folds and applied to both parts. A classifier is
    then trained on the transformed training folds, and its scores for the
    held-out fold are stored. The min DCF (cfn = cfp = 1) is finally computed
    on the pooled scores of all samples.

    Args:
        D: data matrix with the features on the rows and the samples on the columns.
        L: 1-D label array, one entry per column of ``D``.
        K: number of folds, ``2 <= K <= D.shape[1]``.
        Classifier: built as ``Classifier(DTR, LTR, *class_args)`` and queried
            through ``.llr(DTE)``.
        prior: effective prior of the target application.
        class_args: extra positional arguments for ``Classifier``.
        transformers: sequence of transformer classes. Each is instantiated
            without arguments and fitted with ``fit(DTR, *transf_args[j])``.
            ``None`` means no preprocessing.
        transf_args: one tuple of extra ``fit`` arguments per transformer.

    Returns:
        The min DCF of the pooled cross-validation scores.

    Raises:
        Exception: if K is not in ``[2, D.shape[1]]``.
    """
    transformers = [] if transformers is None else transformers
    transf_args = [] if transf_args is None else transf_args

    n_samples = D.shape[1]
    if K < 2 or K > n_samples:
        raise Exception("K-Fold : K should be > 1 and <= " + str(n_samples))

    # NOTE(review): this seeds numpy's *global* generator, as before. It is kept
    # in case a transformer/classifier later draws from numpy.random — confirm.
    numpy.random.seed(0)
    # array_split gives every sample exactly one held-out score, even when
    # n_samples is not a multiple of K. With fixed-size slices the leftover
    # samples would keep llr = 0 and still be counted by min_DCF.
    folds = numpy.array_split(numpy.random.permutation(n_samples), K)

    llr = numpy.zeros(n_samples)
    for i in range(K):
        # Rotating window over the folds: train on folds i .. i+K-2 (cyclically)
        # and test on fold i+K-1.
        idxTrain = numpy.concatenate([folds[(i + m) % K] for m in range(K - 1)])
        idxTest = folds[(i + K - 1) % K]

        DTR = D[:, idxTrain]
        DTE = D[:, idxTest]
        LTR = L[idxTrain]

        # Fit the preprocessing on the training folds only, then apply it to both splits.
        for j, T in enumerate(transformers):
            transformer = T().fit(DTR, *transf_args[j])
            DTR = transformer.transform(DTR)
            DTE = transformer.transform(DTE)

        classifier = Classifier(DTR, LTR, *class_args)
        llr[idxTest] = classifier.llr(DTE)

    # The error costs are fixed to cfn = cfp = 1; only the prior varies.
    return min_DCF(llr, L, prior, 1, 1)


def gaussianize(D):
    """Fit a Gaussianizer on D and return D mapped through it (fit and transform on the same data)."""
    fitted_gaussianizer = Gaussianizer().fit(D)
    return fitted_gaussianizer.transform(D)


if __name__ == '__main__':
    DTR, LTR = load_data('./data/Train.txt')

    # Uncomment to run every experiment on a random 1/8 subsample of the training set.
    # DTR, _, LTR, _ = sklearn.model_selection.train_test_split(DTR.T, LTR, train_size=1 / 8, random_state=42)
    # DTR = DTR.T

    print_plots = False
    # Cache flags, one per stage: [0] Gaussianized training set, [1] Gaussian-model minDCF,
    # [2] logistic-regression minDCF, [3] linear-SVM minDCF.
    load_precomputed_data = [False, False, False, False]
    store_computed_data = [True, True, True, True]

    if load_precomputed_data[0]:
        DTR_G = numpy.load('./data/TrainGAU.npy')
    else:
        DTR_G = gaussianize(DTR)
        if store_computed_data[0]:
            numpy.save('./data/TrainGAU.npy', DTR_G)

    if print_plots:
        plot_hist(DTR, LTR)
        plot_scatter(DTR, LTR)

        plot_hist(DTR_G, LTR, folder='hist_GAU')
        plot_scatter(DTR_G, LTR, folder='scatter_GAU')

        plot_heatmap(DTR, subtitle='all', color='binary')
        plot_heatmap(DTR[:, LTR == 1], subtitle='pulsar', color='Blues')
        plot_heatmap(DTR[:, LTR == 0], subtitle='not_pulsar', color='Greens')

    #######################################################################################
    # Gaussian Models
    #######################################################################################

    classifier_name = numpy.array([
        'Full-Cov',
        'Diag-Cov',
        'Tied Full-Cov',
        'Tied Diag-Cov'
    ])
    classifiers = numpy.array([
        MultivariateGaussianClassifier,
        NaiveBayesClassifier,
        TiedCovarianceGaussianClassifier,
        TiedDiagCovGaussianClassifier
    ])

    priors = numpy.array([0.5, 0.1, 0.9])
    data = [DTR for _ in range(6)]
    mindcf = numpy.zeros((len(data), classifiers.shape[0], priors.shape[0]))
    # Preprocessing pipeline applied inside every K-fold split for each entry of
    # `data`, and the extra fit() arguments of each transformer.
    transformers = [
        [Gaussianizer],
        [PCA, Gaussianizer],
        [PCA, Gaussianizer],
        [PCA, Gaussianizer],
        [],
        [PCA]
    ]
    transf_args = [
        [()],
        [(7,), ()],
        [(6,), ()],
        [(5,), ()],
        [()],
        [(7,)]
    ]

    if len(data) != len(transformers) or len(transformers) != len(transf_args):
        raise Exception("Length of data/transformers/transf_args incorrect")
    elif classifiers.shape[0] != classifier_name.shape[0]:
        raise Exception("Length of classifiers/classifier_name incoherent")

    if load_precomputed_data[1]:
        mindcf = numpy.load('./data/minDCF_GAU_models.npy')

    for d, D in enumerate(data):
        # Futures for this dataset only. They are submitted in C order of
        # mindcf[d] (classifier, prior), so position i maps to unravel_index(i).
        results = []
        with concurrent.futures.ProcessPoolExecutor() as executor:
            if not load_precomputed_data[1]:
                for i, c in enumerate(classifiers):
                    for j, p in enumerate(priors):
                        print(classifier_name[i] + " - prior = " + str(p) + " - data id = " + str(d))
                        results.append(executor.submit(k_fold_min_DCF, D, LTR, 5, c, p, (), transformers[d], transf_args[d]))
            for i, r in enumerate(tqdm(results)):
                mindcf[d][numpy.unravel_index(i, mindcf[d].shape, 'C')] = round(r.result(), 3)
            table = numpy.hstack((vcol(classifier_name), mindcf[d]))
            print(tabulate(table, headers=[""] + list(priors), tablefmt='fancy_grid'))

    if store_computed_data[1]:
        numpy.save('./data/minDCF_GAU_models.npy', mindcf)

    #######################################################################################
    # Logistic Regression
    #######################################################################################

    classifier_name = numpy.array([
        'Log Reg',
        'Log Reg'
    ])
    classifiers = numpy.array([
        LogisticRegression,
        LogisticRegression
    ])
    transformers = [
        [Gaussianizer],
        []
    ]
    transf_args = [
        [()],
        [()]
    ]
    data = [DTR for _ in range(2)]

    lamb = numpy.array([10 ** i for i in range(-6, 6)])
    # NOTE(review): consecutive linspace segments share their endpoints, so each
    # inner power of ten appears twice in `lamb`. It is left unchanged to stay
    # compatible with the cached minDCF_LogReg_lamb.npy shape.
    lamb = numpy.array([numpy.linspace(lamb[i], lamb[i + 1], 10) for i in range(lamb.shape[0] - 1)]).reshape(-1)
    priors = numpy.array([0.5, 0.1, 0.9])

    if load_precomputed_data[2]:
        mindcf = numpy.load('./data/minDCF_LogReg_lamb.npy')
    else:
        mindcf = numpy.zeros((len(data), classifiers.shape[0], priors.shape[0], lamb.shape[0]))

    if len(data) != len(transformers) or len(transformers) != len(transf_args):
        raise Exception("Length of data/transformers/transf_args incoherent")
    elif classifiers.shape[0] != classifier_name.shape[0]:
        raise Exception("Length of classifiers/classifier_name incoherent")

    for d, D in enumerate(data):
        # Futures for this dataset only, in C order of mindcf[d] (classifier, prior, lambda).
        results = []
        with concurrent.futures.ProcessPoolExecutor() as executor:
            if not load_precomputed_data[2]:
                for i, c in enumerate(classifiers):
                    for j, p in enumerate(priors):
                        print(classifier_name[i] + " - prior = " + str(p) + " - data id = " + str(d))
                        for k, l in enumerate(lamb):
                            results.append(executor.submit(k_fold_min_DCF, D, LTR, 5, c, p, (l,), transformers[d], transf_args[d]))
            for i, r in enumerate(tqdm(results)):
                mindcf[d][numpy.unravel_index(i, mindcf[d].shape, 'C')] = round(r.result(), 3)
            # Best value over the lambda grid for each (classifier, prior).
            table = numpy.hstack((vcol(classifier_name), mindcf[d].min(axis=2, initial=inf)))
            print(tabulate(table, headers=[""] + list(priors), tablefmt='fancy_grid'))

    if store_computed_data[2]:
        numpy.save('./data/minDCF_LogReg_lamb.npy', mindcf)

    for d in range(len(data)):
        # NOTE(review): both classifier entries are identical LogisticRegression runs,
        # so only one curve set (classifier index 1) is plotted per dataset.
        for i in range(mindcf[d].shape[0] // 2):
            plt.figure()
            for j, p in enumerate(priors):
                plt.plot(lamb, mindcf[d, i + 1, j], label='minDCF (π = ' + str(p) + ')')
            plt.xlabel('λ')
            plt.ylabel('min DCF')
            plt.legend()
            plt.xscale('log')
            plt.tight_layout()
            if store_computed_data[2]:
                plt.savefig('./plots/mindcf_training/LogReg_lamb_' + str(d) + '_' + str(i) + '.png')
            plt.show()

    ##################################################################################
    # Linear SVM
    ##################################################################################

    classifier_name = numpy.array([
        'SVM (no class balancing)',
        'SVM (with class balancing)'
    ])
    classifiers = numpy.array([
        LinearSVM,
        LinearSVM
    ])
    transformers = [
        [],
    ]
    transf_args = [
        [()],
    ]
    data = [DTR]

    Ci = numpy.array([10 ** i for i in range(-3, 3)])
    priors = numpy.array([0.5, 0.1, 0.9])

    if load_precomputed_data[3]:
        mindcf = numpy.load('./data/minDCF_SVM_C.npy')
    else:
        mindcf = numpy.zeros((len(data), classifiers.shape[0], priors.shape[0], Ci.shape[0]))

    if len(data) != len(transformers) or len(transformers) != len(transf_args):
        raise Exception("Length of data/transformers/transf_args incoherent")
    elif classifiers.shape[0] != classifier_name.shape[0]:
        raise Exception("Length of classifiers/classifier_name incoherent")

    for d, D in enumerate(data):
        # Futures for this dataset only, in C order of mindcf[d] (classifier, prior, C).
        results = []
        with concurrent.futures.ProcessPoolExecutor() as executor:
            if not load_precomputed_data[3]:
                for i, c in enumerate(classifiers):
                    for j, p in enumerate(priors):
                        print(classifier_name[i] + " - prior = " + str(p) + " - data id = " + str(d))
                        for k, C in enumerate(Ci):
                            # NOTE(review): both the "no balancing" and "with balancing" entries pass
                            # identical class_args, so they currently train the same model. Check
                            # LinearSVM's signature; the unbalanced run probably needs a different
                            # value in place of p.
                            results.append(
                                executor.submit(k_fold_min_DCF, D, LTR, 5, c, p, (1, C, p, None,), transformers[d], transf_args[d]))
            for i, r in enumerate(tqdm(results)):
                mindcf[d][numpy.unravel_index(i, mindcf[d].shape, 'C')] = round(r.result(), 3)
            # Best value over the C grid for each (classifier, prior).
            table = numpy.hstack((vcol(classifier_name), mindcf[d].min(axis=2, initial=inf)))
            print(tabulate(table, headers=[""] + list(priors), tablefmt='fancy_grid'))

    if store_computed_data[3]:
        numpy.save('./data/minDCF_SVM_C.npy', mindcf)

    for d in range(len(data)):
        for i in range(mindcf[d].shape[0]):
            plt.figure()
            for j, p in enumerate(priors):
                # x axis is the C grid: mindcf[d, i, j] has one entry per value of Ci.
                plt.plot(Ci, mindcf[d, i, j], label='minDCF (π = ' + str(p) + ')')
            plt.xlabel('C')
            plt.ylabel('min DCF')
            plt.legend()
            plt.xscale('log')
            plt.tight_layout()
            if store_computed_data[3]:
                plt.savefig('./plots/mindcf_training/SVM_C_' + str(d) + '_' + str(i) + '.png')
            plt.show()



Bad key "text.kerning_factor" on line 4 in
/opt/anaconda3/envs/bigdatalab_cpu_202101/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.1.2/matplotlibrc.template
or from the matplotlib source distribution


[[0.06460643 0.7275781  0.10133244 ... 0.10536334 0.33635651 0.40857687]
 [0.06505431 0.02216997 0.03728586 ... 0.01645952 0.33467697 0.01578771]
 [0.95252491 0.25193147 0.85164035 ... 0.93214646 0.6264696  0.58414511]
 ...
 [0.46456164 0.71458963 0.39626022 ... 0.89564438 0.14735192 0.1025641 ]
 [0.56455044 0.26346434 0.67237711 ... 0.05576083 0.86720412 0.88131228]
 [0.55380137 0.29380808 0.66554697 ... 0.05576083 0.86462882 0.886015  ]]
[[-1.51721324  0.60550455 -1.27399588 ... -1.25156975 -0.42242748
  -0.23120733]
 [-1.51367373 -2.0108629  -1.78308962 ... -2.13306925 -0.42703486
  -2.14974474]
 [ 1.6698443  -0.66842405  1.04349457 ...  1.4919697   0.32251735
   0.21250923]
 ...
 [-0.08894795  0.56684317 -0.26303911 ...  1.25711712 -1.04785842
  -1.26707575]
 [ 0.16251653 -0.63270137  0.44648662 ... -1.59139074  1.1132717
   1.18157234]
 [ 0.13527144 -0.54229374  0.42764986 ... -1.59139074  1.1013546
   1.20560458]]
Full-Cov - prior = 0.5 - data id = 0


  0%|          | 0/12 [00:00<?, ?it/s]

Full-Cov - prior = 0.1 - data id = 0
Full-Cov - prior = 0.9 - data id = 0
Diag-Cov - prior = 0.5 - data id = 0
Diag-Cov - prior = 0.1 - data id = 0
Diag-Cov - prior = 0.9 - data id = 0
Tied Full-Cov - prior = 0.5 - data id = 0
Tied Full-Cov - prior = 0.1 - data id = 0
Tied Full-Cov - prior = 0.9 - data id = 0
Tied Diag-Cov - prior = 0.5 - data id = 0
Tied Diag-Cov - prior = 0.1 - data id = 0
Tied Diag-Cov - prior = 0.9 - data id = 0
[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]]


  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]][[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]]



  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]]
[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]]
[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.3

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]





  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]]
[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]][[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]]
[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]][[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]]


  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]



[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]]

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]





  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.30800448]
 [0.52448922 0.49594179 0.42303387 ... 0.42163448 0.22166247 0.18891688]
 [0.71088721 0.44640358 0.36705849 ... 0.94080605 0.54757907 0.78379513]
 ...
 [0.74993003 0.88497061 0.55625525 ... 0.6602295  0.40778058 0.25944584]
 [0.36957739 0.06647075 0.43436888 ... 0.42807165 0.58116429 0.63895886]
 [0.34214945 0.0677302  0.45060174 ... 0.38441086 0.59739715 0.66890568]][[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]][[0.56171285 0.65855024 0.56157291 ... 0.02728799 0.40064372 0.308

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]][[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]]



  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]
  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]]
[[ 0.15531333  0.40850968  0.15495831 ... -1.92223665 -0.25168127
  -0.50151467]
 [ 0.06142399 -0.01017261 -0.19413812 ... -0.19771384 -0.76659065
  -0.8818947 ]
 [ 0.55597846 -0.134753   -0.33965416 ...  1.56157599  0.11954717
   0.78507476]
 ...
 [ 0.67426958  1.20020747  0.14148159 ...  0.41308955 -0.23325789
  -0.64505478]
 [-0.33297286 -1.50260278 -0.16526201 ... -0.18128574  0.20487291
   0.35567725]
 [-0.40660393 -1.49291134 -0.12414121 ... -0.29391649  0.24661563
   0.43689343]]
[[ 0.15531333  0

100%|██████████| 12/12 [00:31<00:00,  2.59s/it]

╒═══════════════╤═══════╤═══════╤═══════╕
│               │   0.5 │   0.1 │   0.9 │
╞═══════════════╪═══════╪═══════╪═══════╡
│ Full-Cov      │ 0.154 │ 0.247 │ 0.706 │
├───────────────┼───────┼───────┼───────┤
│ Diag-Cov      │ 0.153 │ 0.278 │ 0.606 │
├───────────────┼───────┼───────┼───────┤
│ Tied Full-Cov │ 0.131 │ 0.235 │ 0.536 │
├───────────────┼───────┼───────┼───────┤
│ Tied Diag-Cov │ 0.163 │ 0.293 │ 0.611 │
╘═══════════════╧═══════╧═══════╧═══════╛
Full-Cov - prior = 0.5 - data id = 1



  0%|          | 0/24 [00:00<?, ?it/s]

Full-Cov - prior = 0.1 - data id = 1
Full-Cov - prior = 0.9 - data id = 1
Diag-Cov - prior = 0.5 - data id = 1
Diag-Cov - prior = 0.1 - data id = 1
Diag-Cov - prior = 0.9 - data id = 1
Tied Full-Cov - prior = 0.5 - data id = 1
Tied Full-Cov - prior = 0.1 - data id = 1
Tied Full-Cov - prior = 0.9 - data id = 1
Tied Diag-Cov - prior = 0.5 - data id = 1
Tied Diag-Cov - prior = 0.1 - data id = 1
Tied Diag-Cov - prior = 0.9 - data id = 1
[[0.65813042 0.93171005 0.547719   ... 0.62202631 0.40316261 0.32941506]
 [0.54813882 0.17968094 0.61321019 ... 0.96529527 0.80422614 0.8288553 ]
 [0.38133221 0.39056815 0.38189197 ... 0.92009516 0.53274559 0.64833473]
 ...
 [0.46739435 0.35838231 0.35348447 ... 0.86314022 0.24867058 0.24979009]
 [0.43059054 0.40078366 0.47788973 ... 0.9184159  0.70696893 0.72334173]
 [0.54701931 0.18247971 0.33683179 ... 0.59165967 0.39266723 0.24377274]]
[[0.65813042 0.93171005 0.547719   ... 0.62202631 0.40316261 0.32941506]
 [0.54813882 0.17968094 0.61321019 ... 0.96529

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]





  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]
[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]
[[0.65813042 0.9

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[0.65813042 0.93171005 0.547719   ... 0.62202631 0.40316261 0.32941506]
 [0.54813882 0.17968094 0.61321019 ... 0.96529527 0.80422614 0.8288553 ]
 [0.38133221 0.39056815 0.38189197 ... 0.92009516 0.53274559 0.64833473]
 ...
 [0.46739435 0.35838231 0.35348447 ... 0.86314022 0.24867058 0.24979009]
 [0.43059054 0.40078366 0.47788973 ... 0.9184159  0.70696893 0.72334173]
 [0.54701931 0.18247971 0.33683179 ... 0.59165967 0.39266723 0.24377274]][[0.65813042 0.93171005 0.547719   ... 0.62202631 0.40316261 0.32941506]
 [0.54813882 0.17968094 0.61321019 ... 0.96529527 0.80422614 0.8288553 ]
 [0.38133221 0.39056815 0.38189197 ... 0.92009516 0.53274559 0.64833473]
 ...
 [0.46739435 0.35838231 0.35348447 ... 0.86314022 0.24867058 0.24979009]
 [0.43059054 0.40078366 0.47788973 ... 0.9184159  0.70696893 0.72334173]
 [0.54701931 0.18247971 0.33683179 ... 0.59165967 0.39266723 0.24377274]]

[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]
  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]
[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]
[[0.65813042 0.9

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[0.65813042 0.93171005 0.547719   ... 0.62202631 0.40316261 0.32941506]
 [0.54813882 0.17968094 0.61321019 ... 0.96529527 0.80422614 0.8288553 ]
 [0.38133221 0.39056815 0.38189197 ... 0.92009516 0.53274559 0.64833473]
 ...
 [0.46739435 0.35838231 0.35348447 ... 0.86314022 0.24867058 0.24979009]
 [0.43059054 0.40078366 0.47788973 ... 0.9184159  0.70696893 0.72334173]
 [0.54701931 0.18247971 0.33683179 ... 0.59165967 0.39266723 0.24377274]]

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]
  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]





  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]][[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]

[[ 0.40736605  1

  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]
[[0.65813042 0.93171005 0.547719   ... 0.62202631 0.40316261 0.32941506]
 [0.54813882 0.17968094 0.61321019 ... 0.96529527 0.80422614 0.8288553 ]
 [0.38133221 0.39056815 0.38189197 ... 0.92009516 0.53274559 0.64833473]
 ...
 [0.46739435 0.35838231 0.35348447 ... 0.86314022 0.24867058 0.24979009]
 [0.43059054 0.40078366 0.47788973 ... 0.9184159  0.70696893 0.72334173]
 [0.54701931 0.18247971 0.33683179 ... 0.59165967 0.39266723 0.24377274]]


  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]
[[0.65813042 0.93171005 0.547719   ... 0.62202631 0.40316261 0.32941506]
 [0.54813882 0.17968094 0.61321019 ... 0.96529527 0.80422614 0.8288553 ]
 [0.38133221 0.39056815 0.38189197 ... 0.92009516 0.53274559 0.64833473]
 ...
 [0.46739435 0.35838231 0.35348447 ... 0.86314022 0.24867058 0.24979009]
 [0.43059054 0.40078366 0.47788973 ... 0.9184159  0.70696893 0.72334173]
 [0.54701931 0.18247971 0.33683179 ... 0.59165967 0.39266723 0.24377274]]


  return [np.extract(cond, arr1 * expand_arr) for arr1 in newargs]


[[ 0.40736605  1.48864868  0.11990047 ...  0.31080695 -0.24516947
  -0.44152894]
 [ 0.12096045 -0.91658169  0.28769577 ...  1.81574525  0.85681394
   0.94965144]
 [-0.30198378 -0.27783855 -0.30051555 ...  1.40571193  0.08217341
   0.38082847]
 ...
 [-0.08182146 -0.36278617 -0.37593    ...  1.09453692 -0.67867917
  -0.67515045]
 [-0.17487089 -0.25131923 -0.05545063 ...  1.39449491  0.54455134
   0.59279771]
 [ 0.11813414 -0.90595548 -0.4211253  ...  0.23181635 -0.27237401
  -0.69421804]]
[[0.40862021 0.2647635  0.53778338 ... 0.76616289 0.19815281 0.02812762]
 [0.82059894 0.76952141 0.4806885  ... 0.69283515 0.63951861 0.02924713]
 [0.56143297 0.53736356 0.19521411 ... 0.76924153 0.28169605 0.10915197]
 ...
 [0.50055975 0.07010915 0.43898685 ... 0.77693815 0.4184159  0.76266443]
 [0.53456479 0.7672824  0.29722922 ... 0.72292191 0.27246012 0.65771061]
 [0.39686538 0.51357403 0.8616009  ... 0.71774419 0.95633921 0.02434929]]
[[0.40862021 0.2647635  0.53778338 ... 0.76616289 0.19815281 0.0

 50%|█████     | 12/24 [00:17<00:17,  1.45s/it]Process ForkProcess-119:
Process ForkProcess-115:
Process ForkProcess-132:
Process ForkProcess-118:
Process ForkProcess-139:
Process ForkProcess-140:
Process ForkProcess-120:
Process ForkProcess-117:
Process ForkProcess-107:
Process ForkProcess-126:
Process ForkProcess-122:
Process ForkProcess-137:
Process ForkProcess-138:
Process ForkProcess-104:
Process ForkProcess-94:
Process ForkProcess-113:
Process ForkProcess-101:
Process ForkProcess-133:
Process ForkProcess-131:
Process ForkProcess-129:
Process ForkProcess-121:
Process ForkProcess-125:
Process ForkProcess-99:
Process ForkProcess-123:
Process ForkProcess-96:
Process ForkProcess-127:
Process ForkProcess-136:
Process ForkProcess-141:
Process ForkProcess-116:
Process ForkProcess-135:
Process ForkProcess-124:
Process ForkProcess-103:
Process ForkProcess-105:
Process ForkProcess-98:
Process ForkProcess-106:
Process ForkProcess-114:
Process ForkProcess-97:
Process ForkProcess-100:
Process 