In [16]:
import sys
sys.path.append("/vol/bitbucket/ad6013/Research/gp-causal")
import numpy as np
import matplotlib.pyplot as plt

import pickle
from sklearn.mixture import BayesianGaussianMixture
from data import get_data
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler

In [17]:
# Build the names of the per-chunk result files (20 pairs per file) written by each optimiser run.
def return_adam_files(data_name, num_pairs=300, chunk_size=20):
    """
    Return the result-file names written by the Adam GPLVM runs for `data_name`.

    The runs are split into consecutive chunks of `chunk_size` pairs; the file
    for the chunk starting at pair ``start`` ends in
    ``_start:{start}_end:{start + chunk_size}.p``. The defaults reproduce the
    original 15 chunks covering pairs 0-299.

    Parameters
    ----------
    data_name : str
        Dataset name used in the file prefix (e.g. "gauss").
    num_pairs : int
        Total number of pairs covered by the chunk files.
    chunk_size : int
        Number of pairs stored in each file.

    Returns
    -------
    list of str
        File names in chunk order.
    """
    prefix = f"fullscore-{data_name}_pairs-gplvm_adam-reinit2-numind200"
    # Integer `range` gives the exact chunk starts; no float -> int rounding.
    return [
        f"{prefix}_start:{start}_end:{start + chunk_size}.p"
        for start in range(0, num_pairs, chunk_size)
    ]


def return_bfgs_files(data_name, num_pairs=300, chunk_size=20):
    """
    Return the result-file names written by the BFGS GPLVM runs for `data_name`.

    The runs are split into consecutive chunks of `chunk_size` pairs; the file
    for the chunk starting at pair ``start`` ends in
    ``_start:{start}_end:{start + chunk_size}.p``. The defaults reproduce the
    original 15 chunks covering pairs 0-299.

    Parameters
    ----------
    data_name : str
        Dataset name used in the file prefix (e.g. "gauss").
    num_pairs : int
        Total number of pairs covered by the chunk files.
    chunk_size : int
        Number of pairs stored in each file.

    Returns
    -------
    list of str
        File names in chunk order.
    """
    prefix = f"fullscore-{data_name}_pairs-gplvm-reinit20-numind200"
    # Integer `range` gives the exact chunk starts; no float -> int rounding.
    return [
        f"{prefix}_start:{start}_end:{start + chunk_size}.p"
        for start in range(0, num_pairs, chunk_size)
    ]

In [18]:
def convert_file_scores_into_dict(
    files, work_dir="/vol/bitbucket/ad6013/Research/gp-causal", chunk_size=20
):
    """
    Merge the per-chunk result files of a run into one dict of scores.

    ``files[k]`` is taken to hold the scores for pairs
    ``chunk_size * k`` .. ``chunk_size * (k + 1) - 1``, so `files` must be in
    chunk order.

    Parameters
    ----------
    files : list of str
        File names relative to ``{work_dir}/results``. Each is a pickle whose
        ``"scores"`` entry is a sequence of per-pair scores.
    work_dir : str
        Project root containing the ``results`` directory.
    chunk_size : int
        Number of pairs per file.

    Returns
    -------
    dict
        Maps run index to that pair's score tuple, of the form
        ``( (x, y|x), (y, x|y) )``.

    Raises
    ------
    ValueError
        If a file holds more than `chunk_size` scores. Such a file would
        otherwise silently overwrite the next chunk's entries.
    """
    all_scores = {}
    for file_idx, file_name in enumerate(files):
        # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(f"{work_dir}/results/{file_name}", "rb") as f:
            file_scores = pickle.load(f)["scores"]
        if len(file_scores) > chunk_size:
            raise ValueError(
                f"File {file_name} holds {len(file_scores)} scores, "
                f"more than chunk_size={chunk_size}"
            )
        start = chunk_size * file_idx
        for offset, pair_scores in enumerate(file_scores):
            all_scores[start + offset] = pair_scores
    return all_scores

In [19]:
def return_best_between_two_scores(scores_1, scores_2):
    """
    Combine two runs by keeping, term by term, the lower (better) score.

    Parameters
    ----------
    scores_1, scores_2 : dict
        Map run index to ``( (x, y|x), (y, x|y) )``.

    Returns
    -------
    dict
        Same keys and layout. Each of the four terms is the smaller of the two
        runs' values. A NaN in one run is ignored in favour of the other run's
        value; the term is NaN only if both runs are NaN.

    Raises
    ------
    ValueError
        If the two dicts do not contain exactly the same run indices.
    """
    def _min_ignoring_nan(a, b):
        # Plain min(a, b) returns `a` whenever either value is NaN, because
        # every comparison with NaN is False. The outcome would then depend on
        # which run produced the NaN, so prefer the non-NaN value explicitly.
        if np.isnan(a):
            return b
        if np.isnan(b):
            return a
        return min(a, b)

    # Check both directions: a one-sided check would silently drop runs that
    # appear only in `scores_2`.
    for idx in scores_1:
        if idx not in scores_2:
            raise ValueError(f"Run idx mismatch for run {idx}")
    for idx in scores_2:
        if idx not in scores_1:
            raise ValueError(f"Run idx mismatch for run {idx}")

    all_scores = {}
    for idx, run_1 in scores_1.items():
        run_2 = scores_2[idx]
        all_scores[idx] = (
            (_min_ignoring_nan(run_1[0][0], run_2[0][0]),
             _min_ignoring_nan(run_1[0][1], run_2[0][1])),
            (_min_ignoring_nan(run_1[1][0], run_2[1][0]),
             _min_ignoring_nan(run_1[1][1], run_2[1][1])),
        )
    return all_scores

In [20]:
def return_best_between_marginal_and_full_scores(marginal_score, full_score):
    """
    Replace each marginal term of the full scores with the better alternative.

    Parameters
    ----------
    marginal_score : dict
        Maps run index to ``(x, y)`` marginal scores, e.g. from BayesGMM.
    full_score : dict
        Maps run index to ``( (x, y|x), (y, x|y) )``.

    Returns
    -------
    dict
        Maps run index to ``( (x, y|x), (y, x|y) )``. Each marginal term is the
        smaller of the two sources; the conditional terms ``y|x`` and ``x|y``
        are taken from `full_score` unchanged. A NaN in one source is ignored
        in favour of the other.

    Raises
    ------
    ValueError
        If the two dicts do not contain exactly the same run indices.

    NOTE(review): taking the min is only meaningful if both sources are on the
    same scale (same data normalisation, same constant terms). Confirm this
    holds for the marginal model used.
    """
    def _min_ignoring_nan(a, b):
        # Plain min(a, b) returns `a` whenever either value is NaN, so NaN
        # handling would depend on argument order. Prefer the non-NaN value.
        if np.isnan(a):
            return b
        if np.isnan(b):
            return a
        return min(a, b)

    # Check both directions so runs present in only one dict are not silently dropped.
    for idx in marginal_score:
        if idx not in full_score:
            raise ValueError(f"Run idx mismatch for run {idx}")
    for idx in full_score:
        if idx not in marginal_score:
            raise ValueError(f"Run idx mismatch for run {idx}")

    all_scores = {}
    for idx, (marg_x, marg_y) in marginal_score.items():
        (full_x, full_y_given_x), (full_y, full_x_given_y) = full_score[idx]
        all_scores[idx] = (
            (_min_ignoring_nan(marg_x, full_x), full_y_given_x),
            (_min_ignoring_nan(marg_y, full_y), full_x_given_y),
        )
    return all_scores

In [21]:
def get_auc_scores(data_name, scores, work_dir="/vol/bitbucket/ad6013/Research/gp-causal"):
    """
    Compute the ROC AUC of the causal-direction decisions for a pairs dataset.

    For each run, each direction's total score is the sum of its terms. The
    decision score is ``anti_causal - causal``, so a larger value means the
    ``(x, y|x)`` direction has the lower (better) total. Labels come from
    ``target[idx][0]`` of ``get_data.get_{data_name}_pairs_dataset``.

    Parameters
    ----------
    data_name : str
        Dataset name, e.g. "gauss".
    scores : dict
        Maps run index to ``( (x, y|x), (y, x|y) )``.
    work_dir : str
        Project root containing ``data/{data_name}_pairs/files``.

    Returns
    -------
    float
        Area under the ROC curve.

    Raises
    ------
    ValueError
        If `scores` is empty.
    """
    if not scores:
        raise ValueError("scores is empty; cannot compute an AUC")

    data_get = getattr(get_data, f"get_{data_name}_pairs_dataset")
    # Only the labels are needed. NOTE(review): the loader also returns
    # per-pair weights, which are ignored here. Confirm whether the AUC should
    # be weighted (roc_curve's `sample_weight`).
    _, _, _, target = data_get(data_path=f"{work_dir}/data/{data_name}_pairs/files")

    y_labels = []
    y_scores = []
    for idx, pair_scores in scores.items():
        causal = sum(pair_scores[0])
        anti_causal = sum(pair_scores[1])
        y_labels.append(target[idx][0])
        y_scores.append(anti_causal - causal)

    fpr, tpr, _ = roc_curve(y_labels, y_scores)
    return auc(fpr, tpr)


In [22]:
data_name = "gauss"

# Adam run: per-chunk result files -> {run_idx: ((x, y|x), (y, x|y))}.
adam_files = return_adam_files(data_name=data_name)
adam_scores = convert_file_scores_into_dict(adam_files)

# BFGS run: same layout, different file-name pattern.
bfgs_files = return_bfgs_files(data_name=data_name)
bfgs_scores = convert_file_scores_into_dict(bfgs_files)

# Term-wise best of the two runs, then the ROC AUC of the resulting decisions.
scores = return_best_between_two_scores(scores_1=adam_scores, scores_2=bfgs_scores)
adam_bfgs_auc = get_auc_scores(data_name=data_name, scores=scores)

In [23]:
def bayesgmm_score(train_data, n_components, max_iter=int(1e6), random_state=None):
    """
    Score data by the negated variational lower bound of a Bayesian GMM.

    Parameters
    ----------
    train_data : array-like
        Data passed straight to ``BayesianGaussianMixture.fit``.
    n_components : int
        Maximum number of mixture components.
    max_iter : int
        Maximum number of inference iterations (default preserves the previous
        hard-coded value).
    random_state : None, int or numpy RandomState
        Forwarded to sklearn. The default ``None`` keeps the previous
        (unseeded) behaviour; pass an int for reproducible scores.

    Returns
    -------
    float
        ``-model.lower_bound_``, so lower is better.

    NOTE(review): scikit-learn's BayesianGaussianMixture lower bound is
    documented as omitting constant terms. This score may therefore not be on
    the same scale as other model-evidence scores it is compared against.
    Confirm before mixing them.
    """
    model = BayesianGaussianMixture(
        n_components=n_components,
        max_iter=max_iter,
        random_state=random_state,
    ).fit(train_data)
    return -model.lower_bound_

In [24]:
def return_bayesian_gmm_scores(
    data_name, n_components, work_dir="/vol/bitbucket/ad6013/Research/gp-causal"
):
    """
    Return Bayesian GMM marginal scores for every pair of a dataset.

    Each variable of each pair is standardised with ``StandardScaler`` and then
    scored independently with ``bayesgmm_score``.

    Parameters
    ----------
    data_name : str
        Dataset name, e.g. "gauss".
    n_components : int
        Maximum number of mixture components per fit.
    work_dir : str
        Project root containing ``data/{data_name}_pairs/files``.

    Returns
    -------
    dict
        Maps pair index to ``(x_score, y_score)``.
    """
    data_get = getattr(get_data, f"get_{data_name}_pairs_dataset")
    # Only the inputs are needed; weights and targets are discarded.
    x, y, _, _ = data_get(data_path=f"{work_dir}/data/{data_name}_pairs/files")

    all_scores = {}
    for idx in tqdm(range(len(x)), desc="Running BayesGMM"):
        # StandardScaler expects 2-D input; presumably x[idx] / y[idx] are
        # (n_samples, 1) arrays. TODO confirm against the data loader.
        train_x = StandardScaler().fit_transform(x[idx]).astype(np.float64)
        train_y = StandardScaler().fit_transform(y[idx]).astype(np.float64)
        all_scores[idx] = (
            bayesgmm_score(train_x, n_components=n_components),
            bayesgmm_score(train_y, n_components=n_components),
        )
    return all_scores

In [25]:
# Fit a BayesGMM (at most 5 components) to each standardised marginal of every pair.
bayesgmm_scores = return_bayesian_gmm_scores(data_name=data_name, n_components=5)

Running BayesGMM: 100%|██████████| 300/300 [02:32<00:00,  1.96it/s]


In [26]:
# Inspect the per-pair BayesGMM marginal scores, keyed by run index: (x_score, y_score).
bayesgmm_scores

{0: (762.1480443197194, 760.6897816002702),
 1: (-571.1022835520236, 726.6546543093609),
 2: (624.4273110213992, 140.5453067595722),
 3: (427.38879475376166, 475.67074959890357),
 4: (-158.4897002450028, 521.1092608275321),
 5: (582.7384735902588, -103.0394652119129),
 6: (220.34943347650196, 350.139896713779),
 7: (633.3337593478742, 720.2403394652795),
 8: (733.1218941975048, 329.01176608626355),
 9: (697.7131195661434, 538.6364314797013),
 10: (743.7334138210357, 673.2466117587403),
 11: (72.83318061570638, 449.0175652191142),
 12: (1.843446101545453, 359.0796672189982),
 13: (-70.74753906683029, 481.1346917766029),
 14: (60.95321991899405, 709.5786776650826),
 15: (718.4095923219297, 545.9791072976379),
 16: (423.96033259695804, 677.7873432075389),
 17: (737.4781889287898, 393.65482268865867),
 18: (724.7936359977261, 694.7995024358656),
 19: (106.68274002863697, 254.270698547305),
 20: (673.406084855834, -78.71295553411798),
 21: (655.7755973051461, 702.166603547343),
 22: (702.93

In [27]:
# Inspect the term-wise best of the Adam and BFGS GPLVM scores: ((x, y|x), (y, x|y)).
scores

{0: ((2128.407857868433, 1161.0878201043197),
  (2128.4078116252413, 1148.4036079883244)),
 1: ((570.5971714916177, 1838.5358821425198),
  (2112.0086279249554, 580.7883091767762)),
 2: ((2005.6700882446257, 322.6656308062661),
  (1501.5561124811306, 546.9319079182459)),
 3: ((1818.9317943955566, 608.2462765807238),
  (1892.3150285955735, 552.4405312524282)),
 4: ((1242.8962607426438, 1198.644474791121),
  (1915.2626741640383, 571.618946652819)),
 5: ((1960.4431778792696, 284.74672856934285),
  (1220.5994948699245, 942.6566576545633)),
 6: ((1619.4635750043549, 81.01549898760186),
  (1686.3509386314536, -30.510824818379774)),
 7: ((2023.1581160308178, 1480.8638876358823),
  (2112.0899607198558, 1385.5095547160531)),
 8: ((2111.807490240761, 974.3174307300137),
  (1714.4528111335503, 1273.929300188136)),
 9: ((2088.168792159807, 1252.853572184294),
  (1910.2712065961941, 1382.3287235646176)),
 10: ((2120.1316699746, 1675.5651575881989),
  (2062.9305273472696, 1720.5958837282549)),
 11: (

In [28]:
# Swap in the BayesGMM marginal wherever it beats the GPLVM marginal, keep the
# GPLVM conditional terms, and evaluate the AUC of the combined scores.
best_gmm_full_scores = return_best_between_marginal_and_full_scores(
    marginal_score=bayesgmm_scores,
    full_score=scores,
)
bayes_gmm_auc = get_auc_scores(data_name=data_name, scores=best_gmm_full_scores)

In [29]:
# Inspect the combined scores: marginal terms are the better of BayesGMM and GPLVM; conditional terms come from GPLVM.
best_gmm_full_scores

{0: ((762.1480443197194, 1161.0878201043197),
  (760.6897816002702, 1148.4036079883244)),
 1: ((-571.1022835520236, 1838.5358821425198),
  (726.6546543093609, 580.7883091767762)),
 2: ((624.4273110213992, 322.6656308062661),
  (140.5453067595722, 546.9319079182459)),
 3: ((427.38879475376166, 608.2462765807238),
  (475.67074959890357, 552.4405312524282)),
 4: ((-158.4897002450028, 1198.644474791121),
  (521.1092608275321, 571.618946652819)),
 5: ((582.7384735902588, 284.74672856934285),
  (-103.0394652119129, 942.6566576545633)),
 6: ((220.34943347650196, 81.01549898760186),
  (350.139896713779, -30.510824818379774)),
 7: ((633.3337593478742, 1480.8638876358823),
  (720.2403394652795, 1385.5095547160531)),
 8: ((733.1218941975048, 974.3174307300137),
  (329.01176608626355, 1273.929300188136)),
 9: ((697.7131195661434, 1252.853572184294),
  (538.6364314797013, 1382.3287235646176)),
 10: ((743.7334138210357, 1675.5651575881989),
  (673.2466117587403, 1720.5958837282549)),
 11: ((72.83318

In [30]:
# Compare the GPLVM-only AUC (best of Adam/BFGS) with the AUC after substituting BayesGMM marginals where they score lower.
print(f"AUC: {adam_bfgs_auc}, GMM AUC: {bayes_gmm_auc}")

AUC: 0.8935242839352429, GMM AUC: 0.8626578900551504
