In [12]:
import sys

import jax
import jax.numpy as jnp
import numpy as np
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401

from mcmc.lmc import run_simple_lmc_numpyro  # noqa: F401


jax.config.update("jax_enable_x64", True)


def is_zero(x):
    return x == 0


def is_one(x):
    return x == 1


def balance_and_shuffle(x, y, n_samples=2000, zero_pred=is_zero, one_pred=is_one):
    half_n = n_samples // 2
    # we want the labels y (in 0,1) to be mostly balanced
    ones = np.squeeze(one_pred(y))
    zeros = np.squeeze(zero_pred(y))
    n_ones = np.sum(ones)
    n_zeros = np.sum(zeros)
    print(f"ones len: {n_ones}, zeros len: {n_zeros}")

    x_ones = x[ones]
    x_zeros = x[zeros]

    if n_ones < half_n:
        print("Using all ones")
        x = np.concatenate([x_ones, x_zeros[: (n_samples - n_ones)]])
        _y = np.concatenate([np.ones((n_ones, 1)), np.zeros((n_samples - n_ones, 1))])

    elif n_zeros < half_n:
        print("Using all zeros")
        x = np.concatenate([x_zeros, x_ones[: (n_samples - n_zeros)]])
        _y = np.concatenate([np.zeros((n_ones, 1)), np.ones((n_samples - n_ones, 1))])

    else:
        print("Using half of each")
        x = np.concatenate([x_ones[:half_n], x_zeros[:half_n]])
        _y = np.concatenate([np.ones((half_n, 1)), np.zeros((half_n, 1))])

    assert (np.shape(x)[0] == n_samples) and (np.shape(_y)[0] == n_samples)

    # now we shuffle the data
    perm = np.arange(n_samples)
    np.random.shuffle(perm)
    x = x[perm]
    _y = _y[perm]

    return x, _y


def filter_and_normalise(x, y, threshold=1e-6):
    # get rid of all columns with var < 1e-5
    var = np.var(x, axis=0)
    keep_indices = var > threshold
    x = x[:, keep_indices]
    # and normalise
    x = (x - x.mean(axis=0)) / x.std(axis=0)

    print(f"x shape: {x.shape}, y.shape: {y.shape}")

    return x, y

In [None]:
from ucimlrepo import fetch_ucirepo


# fetch dataset
# tbp = fetch_ucirepo(id=572)
isolet = fetch_ucirepo(id=54)

In [3]:
# data (as pandas dataframes)
x = np.array(isolet.data.features)
y = np.array(isolet.data.targets)

np.set_printoptions(threshold=sys.maxsize)
print(jnp.squeeze(y))

[ 1.  1.  2.  2.  3.  3.  4.  4.  5.  5.  6.  6.  7.  7.  8.  8.  9.  9.
 10. 10. 11. 11. 12. 12. 13. 13. 14. 14. 15. 15. 16. 16. 17. 17. 18. 18.
 19. 19. 20. 20. 21. 21. 22. 22. 23. 23. 24. 24. 25. 25. 26. 26.  1.  1.
  2.  2.  3.  3.  4.  4.  5.  5.  6.  6.  7.  7.  8.  8.  9.  9. 10. 10.
 11. 11. 12. 12. 13. 13. 14. 14. 15. 15. 16. 16. 17. 17. 18. 18. 19. 19.
 20. 20. 21. 21. 22. 22. 23. 23. 24. 24. 25. 25. 26. 26.  1.  1.  2.  2.
  3.  3.  4.  4.  5.  5.  6.  6.  7.  7.  8.  8.  9.  9. 10. 10. 11. 11.
 12. 12. 13. 13. 14. 14. 15. 15. 16. 16. 17. 17. 18. 18. 19. 19. 20. 20.
 21. 21. 22. 22. 23. 23. 24. 24. 25. 25. 26. 26.  1.  1.  2.  2.  3.  3.
  4.  4.  5.  5.  6.  6.  7.  7.  8.  8.  9.  9. 10. 10. 11. 11. 12. 12.
 13. 13. 14. 14. 15. 15. 16. 16. 17. 17. 18. 18. 19. 19. 20. 20. 21. 21.
 22. 22. 23. 23. 24. 24. 25. 25. 26. 26.  1.  1.  2.  2.  3.  3.  4.  4.
  5.  5.  6.  6.  7.  7.  8.  8.  9.  9. 10. 10. 11. 11. 12. 12. 13. 13.
 14. 14. 15. 15. 16. 16. 17. 17. 18. 18. 19. 19. 20

In [13]:
vowels = [1, 5, 9, 15, 21]
non_vowels = list(set(range(1, 27)) - set(vowels))
print(f"non_vowels: {non_vowels}")


def zero_pred_isolet(y):
    # is not a vowel
    return np.isin(y, non_vowels)


def one_pred_isolet(y):
    return np.isin(y, vowels)


x1, y1 = balance_and_shuffle(x, y, 1600, zero_pred_isolet, one_pred_isolet)

non_vowels: [2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 25, 26]
ones len: 1500, zeros len: 6297
Using half of each


In [18]:
x2, y2 = filter_and_normalise(x1, y1, threshold=1e-3)

x shape: (1600, 617), y.shape: (1600, 1)


In [19]:
np.save("mcmc_data/isolet_x.npy", x2)
np.save("mcmc_data/isolet_y.npy", y2)