In [1]:
import numpy as np
from MINE.gtm import GTM

import tensorflow as tf
from knnie import kraskov_mi
import pandas as pd
from tqdm import tqdm

from MINE.augmentation import *

In [2]:
def calculate_cmi(X, Y, Z):
    return kraskov_mi(X, np.concatenate((Y, Z), axis=1)) - kraskov_mi(X, Z)

def estimate_mi(X, Y):
    return kraskov_mi(X, Y)

In [3]:
def prepare_datasets(n):
    def create_uni(n, d):
        tmp = np.random.choice(range(d), n)
        tmp2 = np.zeros((tmp.size, tmp.max() + 1))
        tmp2[np.arange(tmp.size), tmp] = 1
        return tmp2
    aug = Augmentation()

    if n == 10000:
        print('uni')
    
    uni = create_uni(n, 16)
    
    if n == 10000:
        print('uni aug')
    
    aug.transform(uni, n=10, m=1)

    mean = np.array([0, 1])
    cov = np.array([[1, 0], [0, 2]])
    
    if n == 10000:
        print('norm not corr')
    norm_not_corr = np.random.multivariate_normal(mean=mean, cov=cov, size=n)
    if n == 10000:
        print('norm not corr aug')
    aug.transform(norm_not_corr, n=10, m=1)

    mean = np.array([0, 1])
    cov = np.array([[1, 0.75], [0.75, 2]])
    
    if n == 10000:
        print('norm corr')
    norm_corr = np.random.multivariate_normal(mean=mean, cov=cov, size=n)
    if n == 10000:
        print('norm corr aug')
    aug.transform(norm_corr, n=10, m=1)


    cov = np.array([[ 2.97, -0.36,  1.12, -0.97,  0.07,  0.96,  2.36, -0.55,  0.88],
       [-0.36,  1.27,  0.07, -0.2 , -0.98, -0.97, -0.49,  0.46,  0.59],
       [ 1.12,  0.07,  4.21,  0.27, -2.04, -1.01,  0.45,  0.26,  0.73],
       [-0.97, -0.2 ,  0.27,  2.52, -0.57, -1.22,  0.45,  0.41, -0.89],
       [ 0.07, -0.98, -2.04, -0.57,  2.73,  2.26,  0.47, -1.12, -0.01],
       [ 0.96, -0.97, -1.01, -1.22,  2.26,  2.82,  0.78, -2.03,  0.51],
       [ 2.36, -0.49,  0.45,  0.45,  0.47,  0.78,  3.22, -0.99,  0.96],
       [-0.55,  0.46,  0.26,  0.41, -1.12, -2.03, -0.99,  2.75, -0.98],
       [ 0.88,  0.59,  0.73, -0.89, -0.01,  0.51,  0.96, -0.98,  2.2 ]])

    mean = np.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
    
    if n == 10000:
        print('norm hd')

    norm_hd = np.random.multivariate_normal(mean=mean, cov=cov, size=n)
    
    if n == 10000:
        print('norm hd aug')
#     aug.transform(norm_hd, n=10, m=1)

    datasets = {'uni': uni,
               'norm_not_corr': norm_not_corr,
               'norm_corr': norm_corr,
               'norm_hd': norm_hd}
    return datasets

In [4]:
np.random.seed(77)
seeds = np.random.choice(1_000_000_000, size=100, replace=False)

In [5]:
seeds.shape

(100,)

# MI

In [None]:
result = []
for n in [100, 1000, 10000]:
    print(n)
    for seed in tqdm(seeds):
        tf.keras.utils.set_random_seed(int(seed))
        datasets = prepare_datasets(n)
        for dataset_name, dataset in datasets.items():
            if 'aug' in dataset_name:
                continue
            if 'uni' in dataset_name:
                x_ind = list(range(16))
                y_ind = list(range(16))
            else:
                if 'aug' not in dataset_name:
                    size = dataset.shape[1]//2
                    y_ind = list(range(size, dataset.shape[1]))
                else:
                    size = dataset[0].shape[1]//2
                    y_ind = list(range(size, dataset[0].shape[1]))
                x_ind = list(range(size))

                if 'norm_hd' in dataset_name:
                    y_ind = y_ind[1:]

#             print(dataset_name, x_ind, y_ind)

            mi = estimate_mi(dataset[:, x_ind], dataset[:, y_ind])
            row = [dataset_name, 'orig', n, 'ksg', None, None, seed, mi, None]
            result.append(row)

  3%|██▍                                                                               | 3/100 [00:00<00:03, 27.15it/s]

100


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 27.13it/s]
  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

1000


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:38<00:00,  2.62it/s]
  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

10000
uni
uni aug
norm not corr
norm not corr aug
norm corr
norm corr aug
norm hd
norm hd aug


In [7]:
kraskov_mi(dataset[:, x_ind], dataset[:, y_ind])

ValueError: math domain error

# CMI

In [5]:
result = []
for gamma in [0.6, 0.75, 0.9]:
    print(gamma)
    for n in [1000, 5000, 10000]:
        print('\t', n)
        for seed in seeds:
            tf.keras.utils.set_random_seed(int(seed))
            gtm = GTM(12, gamma)
            X, Y = gtm.generate(n)
            for i in range(1, 11):
                if i > 1:
                    continue
                result.append([gamma, n, i, calculate_cmi(X[:, i].reshape(-1, 1), Y.reshape(-1, 1), X[:, :i])])

0.6
	 1000
	 5000
	 10000
0.75
	 1000
	 5000
	 10000
0.9
	 1000
	 5000
	 10000


In [16]:
[(x, y) for x in range(3) for y in range(100, 103)]

[(0, 100),
 (0, 101),
 (0, 102),
 (1, 100),
 (1, 101),
 (1, 102),
 (2, 100),
 (2, 101),
 (2, 102)]

In [6]:
result2 = pd.DataFrame(result)
result2['seed'] = [seed for gamma in [0.6, 0.75, 0.9] for n in [1000, 5000, 10_000] for seed in seeds for c in range(1, 2)]
result2.columns = ['gamma', 'n', 'c', 'cmi', 'seed']

In [7]:
result2.to_csv('KSG3.csv')

In [21]:
result1 = pd.read_csv('KSG.csv', index_col=0)
result1 = result1.loc[result1.c <= 10]

result2 = pd.read_csv('KSG3.csv', index_col=0)

result3 = pd.read_csv('KSG2.csv', index_col=0)

In [23]:
pd.concat((result1, result2, result3)).to_csv('KSG.csv')

# Selection

In [5]:
def select(x, y):
    y = y.reshape(-1, 1)
    selected = []
    remaining = set(list(range(x.shape[1])))
    max_ = -np.inf
    best_ = None
    for i in range(x.shape[1]):
        mi = estimate_mi(x[:, [i]], y.reshape(-1, 1))
        if mi > max_:
            max_ = mi
            best_ = i

    selected.append(best_)
    remaining.remove(best_)

    while len(remaining) > 1:
        max_ = -np.inf
        best_ = None
        for r in remaining:
            cmi = calculate_cmi(x[:, [r]], y.reshape(-1, 1), x[:, selected])
            if cmi > max_:
                max_ = cmi
                best_ = r
        selected.append(best_)
        remaining.remove(best_)
    selected.append(remaining.pop())
    return selected

In [6]:
result = []
for gamma in [0.6, 0.75, 0.9]:
    print(gamma)
    for n in [1000, 5000, 10000]:
        print('\t', n)
        for seed in seeds:
            tf.keras.utils.set_random_seed(int(seed))
            gtm = GTM(10, gamma)
            X, Y = gtm.generate(n)
            
            result.append([gamma, n, select(X, Y)])

0.6
	 1000
	 5000
	 10000
0.75
	 1000
	 5000
	 10000
0.9
	 1000
	 5000
	 10000


In [10]:
result = pd.DataFrame(result, columns=['gamma', 'n', 'selection'])

In [11]:
result['seed'] = [seed for gamma in [0.6, 0.75, 0.9] for n in [1000, 5000, 10_000] for seed in seeds]

In [13]:
result.to_csv('ksg_selection.csv')