In [None]:
# Change directory to VSCode workspace root so that relative path loads work correctly. Turn this addition off with the DataScience.changeDirOnImportExport setting
# ms-toolsai.jupyter added
import os
# Best-effort move to the parent directory so relative paths and the `src`
# package resolve from there. Note that every re-run of this cell climbs one
# more level, so run it only once per kernel session.
try:
    os.chdir(os.path.join(os.getcwd(), '..'))
    print(os.getcwd())
except OSError as err:
    # Only filesystem errors are tolerated here; anything else (including
    # KeyboardInterrupt) propagates instead of being silently swallowed.
    print(f"Could not change working directory: {err}")

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches


from numpy.random import dirichlet
from scipy.stats import invwishart, multivariate_normal

from src.data.data import genereate_splitnet_dataset, genereate_splitnet_dataset_multi
from src.data.data import *
from src.utils.metrics import *

from sklearn.decomposition import PCA


### utils and plotting functions

In [None]:
def plot_dataset(X, y, niw_prior, alpha, title, num_ax=3):
    """Scatter-plot up to ``num_ax**2`` datasets on a square grid of axes.

    Each panel shows one dataset with label 0 drawn in blue and label 1 in
    red, and is titled with the value returned by ``log_hasting_ratio`` for
    that dataset. Datasets with more than two features are projected to 2-D
    with PCA for display only.

    Args:
        X: Indexable collection of datasets; ``X[i]`` must expose ``.shape``
            and support boolean-mask indexing (e.g. a NumPy array).
        y: Indexable collection of label arrays aligned with ``X``.
        niw_prior: Forwarded unchanged to ``log_hasting_ratio``.
        alpha: Forwarded unchanged to ``log_hasting_ratio``.
        title: Figure-level title.
        num_ax: Number of rows and of columns in the grid.
    """
    # squeeze=False keeps `ax` two-dimensional even when num_ax == 1.
    fig, ax = plt.subplots(
        num_ax,
        num_ax,
        figsize=(4 * num_ax, 4 * num_ax),
        facecolor='white',
        squeeze=False,
    )

    # Never index past the end of the inputs when fewer than num_ax**2
    # datasets are supplied.
    n_panels = min(num_ax ** 2, len(X), len(y))

    for i in range(n_panels):
        # Labels come from the `y` argument (not a notebook-level global), and
        # a distinct local name keeps the parameter from being overwritten.
        points, labels = X[i], y[i]

        # Project to 2-D for display only; the ratio below uses the raw points.
        points_2d = points
        if points.shape[-1] > 2:
            points_2d = PCA(n_components=2).fit_transform(points)

        r, c = divmod(i, num_ax)
        panel = ax[r, c]
        panel.scatter(points_2d[labels == 0, 0], points_2d[labels == 0, 1], c='b', alpha=0.3)
        panel.scatter(points_2d[labels == 1, 0], points_2d[labels == 1, 1], c='r', alpha=0.3)

        H = log_hasting_ratio(points, labels, niw_prior, alpha)
        panel.set_title(f"Log Hasting Ratio: {H:.2f}")
        # Apply tight limits to this panel rather than to pyplot's current axes.
        panel.axis('tight')

    # Hide grid cells that received no dataset.
    for unused in ax.ravel()[n_panels:]:
        unused.set_visible(False)

    fig.suptitle(title)
    plt.tight_layout()
    plt.subplots_adjust(top=0.92)
    plt.show()

In [4]:
def plot_loghr_histograms(X, Y, niw_prior, alpha, title):
    """Plot KDE curves of two sets of Hastings-ratio values for a dataset.

    The "gt" series is whatever ``calc_hastings_ratio_gt`` returns for
    ``(X, Y)``; the "kmeans" series is whatever ``calc_hastings_ratio_kmeans``
    returns for ``X`` alone. A text-only legend lists the prior's ``k``, ``v``
    and the first dimension of ``psi``, together with ``alpha``.

    Args:
        X: Datasets forwarded to both ratio helpers.
        Y: Labels forwarded to ``calc_hastings_ratio_gt``.
        niw_prior: Prior object exposing ``k``, ``v`` and ``psi``.
        alpha: Forwarded to both ratio helpers and shown in the legend.
        title: Prefix of the figure title.
    """
    # Keys name the two series handed to seaborn.
    sns.displot(
        {
            "gt": calc_hastings_ratio_gt(X, Y, niw_prior, alpha),
            "kmeans": calc_hastings_ratio_kmeans(X, niw_prior, alpha),
        },
        kind="kde",
    )

    # One fully transparent patch, used for both entries, so the legend shows
    # only text.
    blank = mpl_patches.Rectangle((0, 0), 1, 1, fc="white", ec="white", lw=0, alpha=0)

    annotations = [
        f"NIW Prior: \n kappa={niw_prior.k} \n nu={niw_prior.v} \n D={niw_prior.psi.shape[0]}",
        f"alpha={alpha}",
    ]

    legend_style = dict(
        loc="right",
        fontsize="medium",
        fancybox=True,
        framealpha=0.7,
        handlelength=0,
        handletextpad=0,
    )
    plt.legend([blank, blank], annotations, **legend_style)

    plt.suptitle(f"{title} | log Hastings Ratio Histogram")

    plt.show()


# Easy Data:

In [5]:
# "Easy" 3-D setting. D sizes the prior mean vector and the identity scale matrix.
D = 3
nu, k = 19, 0.1
mu, psi = np.zeros(D), np.eye(D)

# Both values are forwarded to the generator; alpha is reused by the plotting cell.
dp_alpha, alpha = 100, 10

easy_niw_prior = niw_hyperparams(k=k, v=nu, mu=mu, psi=psi)

generator_kwargs = dict(
    niw_params=easy_niw_prior,
    dataset_size=100,
    num_points=2048,
    dp_alpha=dp_alpha,
    alpha=alpha,
    hr_threshold=0.001,
)
X, Y = genereate_splitnet_dataset(**generator_kwargs)




Generated data, shape: (100, 2048, 3)
Generated labels, shape: (100, 2048)


In [6]:
# Density plot first, then a 2x2 grid of sample datasets, both under the same title.
title = f"Easy {D}D Data"
plot_loghr_histograms(X, Y, niw_prior=easy_niw_prior, alpha=alpha, title=title)
plot_dataset(X, Y, niw_prior=easy_niw_prior, alpha=alpha, title=title, num_ax=2)


# Challenging Data:

In [None]:
# "Challenging" 3-D setting. D sizes the prior mean vector and the identity scale matrix.
D = 3
nu, k = 5, 1
mu, psi = np.zeros(D), np.eye(D)

# Both values are forwarded to the generator; alpha is reused by the plotting cell.
dp_alpha, alpha = 100, 5

niw_prior = niw_hyperparams(k=k, v=nu, mu=mu, psi=psi)

# No hr_threshold is passed here, so the generator's own default applies.
generator_kwargs = dict(
    niw_params=niw_prior,
    dataset_size=100,
    num_points=2048,
    alpha=alpha,
    dp_alpha=dp_alpha,
)
X, Y = genereate_splitnet_dataset(**generator_kwargs)

In [None]:
# Density plot first, then a 3x3 grid of sample datasets, both under the same title.
title = f"Less Easy {D}D Data"
plot_loghr_histograms(X, Y, niw_prior=niw_prior, alpha=alpha, title=title)
plot_dataset(X, Y, niw_prior=niw_prior, alpha=alpha, title=title, num_ax=3)


# Hard Data:

In [None]:
# "Hard" 3-D setting. D sizes the prior mean vector and the identity scale matrix.
D = 3
nu, k = 4, 10
mu, psi = np.zeros(D), np.eye(D)

# Both values are forwarded to the generator; alpha is reused by the plotting cell.
dp_alpha, alpha = 100, 1

niw_prior = niw_hyperparams(k=k, v=nu, mu=mu, psi=psi)

generator_kwargs = dict(
    niw_params=niw_prior,
    dataset_size=100,
    num_points=2048,
    alpha=alpha,
    hr_threshold=0.01,
    dp_alpha=dp_alpha,
)
X, Y = genereate_splitnet_dataset(**generator_kwargs)

In [None]:
title = f"Hard {D}D Data"
# Evaluate the ratios under the prior this dataset was generated with
# (niw_prior), the same one the scatter grid below uses, rather than the
# prior left over from the "easy" section.
plot_loghr_histograms(X, Y, niw_prior, alpha, title)
plot_dataset(X, Y, niw_prior, alpha, title, num_ax=3)


In [None]:
# "Hard" 10-D setting. D sizes the prior mean vector and the identity scale matrix.
D = 10
nu, k = 10, 0.5
mu, psi = np.zeros(D), np.eye(D)

# Both values are forwarded to the generator; alpha is reused by the plotting cell.
dp_alpha, alpha = 100, 2

niw_prior = niw_hyperparams(k=k, v=nu, mu=mu, psi=psi)

generator_kwargs = dict(
    niw_params=niw_prior,
    dataset_size=100,
    num_points=1000,
    alpha=alpha,
    dp_alpha=dp_alpha,
    hr_threshold=0.01,
)
X, Y = genereate_splitnet_dataset(**generator_kwargs)

In [None]:
# Only the density plot is drawn for this setting.
title = f"Hard {D}D Data"
plot_loghr_histograms(X, Y, niw_prior=niw_prior, alpha=alpha, title=title)

# Multi-K | Easy Data:

In [None]:
# Multi-K "easy" 10-D setting. D sizes the prior mean vector and the identity scale matrix.
D = 10
nu, k = 20, 0.1
mu, psi = np.zeros(D), np.eye(D)

# Both values are forwarded to the generator; alpha is reused by the plotting cell.
dp_alpha, alpha = 100, 10

easy_niw_prior = niw_hyperparams(k=k, v=nu, mu=mu, psi=psi)

# NOTE(review): num_points is 2024 here while the 3-D cells use 2048 - confirm this is intended.
generator_kwargs = dict(
    niw_params=easy_niw_prior,
    dataset_size=1000,
    num_points=2024,
    dp_alpha=dp_alpha,
    alpha=alpha,
    hr_threshold=0.001,
    K_max=6,
)
X, Y = genereate_splitnet_dataset_multi(**generator_kwargs)

In [None]:
# Density plot first, then a 3x3 grid of sample datasets, both under the same title.
title = f"Easy {D}D Data"
plot_loghr_histograms(X, Y, niw_prior=easy_niw_prior, alpha=alpha, title=title)
plot_dataset(X, Y, niw_prior=easy_niw_prior, alpha=alpha, title=title, num_ax=3)
