In [1]:
from tqdm import tqdm
import numpy as np
from mutinfo.distributions.base import CorrelatedNormal, CorrelatedUniform, CorrelatedStudent, SmoothedUniform

def run_tests(distribution_factory, estimator, MI_grid, n_samples, n_runs, n_sigma=3.0):
    """
    Benchmark a mutual-information estimator over a grid of ground-truth MI values.

    For every value in ``MI_grid`` the estimator is evaluated ``n_runs`` times.
    Each run builds a distribution with ``distribution_factory(value)``, draws
    ``n_samples`` points with ``rvs`` and calls ``estimator(x, y)``.

    Parameters
    ----------
    distribution_factory : callable
        Maps a mutual-information value to an object whose ``rvs(n_samples)``
        returns a pair ``(x, y)``.
    estimator : callable
        Called as ``estimator(x, y)``; expected to return a scalar estimate.
    MI_grid : iterable of float
        Ground-truth mutual-information values to test.
    n_samples : int
        Number of samples drawn per run.
    n_runs : int
        Number of runs per grid value; must be >= 1.
    n_sigma : float, optional
        Half-width of the reported interval, in units of the standard error
        of the mean. Defaults to 3, matching the previous hard-coded value.

    Returns
    -------
    numpy.ndarray of shape (len(MI_grid), 2)
        Column 0 is the mean estimate. Column 1 is ``n_sigma`` times the
        standard error of the mean. It is NaN when ``n_runs == 1``, because
        the spread cannot be estimated from a single run.

    Raises
    ------
    ValueError
        If ``n_runs < 1``.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    rows = []
    for mutual_information in tqdm(MI_grid):
        estimates = []
        for _ in range(n_runs):
            random_variable = distribution_factory(mutual_information)
            x, y = random_variable.rvs(n_samples)
            estimates.append(estimator(x, y))

        estimates = np.asarray(estimates)
        mean = np.mean(estimates)
        if n_runs > 1:
            # Sample standard deviation (ddof=1) is the conventional basis for
            # the standard error of the mean; ddof=0 would underestimate it.
            sem = np.std(estimates, ddof=1) / np.sqrt(n_runs)
        else:
            sem = np.nan
        rows.append([mean, n_sigma * sem])

    # reshape keeps the (k, 2) layout even for an empty grid, so callers can
    # always index the columns with [:, 0] and [:, 1].
    return np.asarray(rows, dtype=float).reshape(-1, 2)

In [2]:
import os

import jax
# Pin JAX to the CPU before any other module can initialise a JAX backend.
# bmi is imported right after; if it uses JAX, the platform must already be
# set. NOTE(review): presumably this keeps the GPUs free for the PyTorch-based
# MINDE — confirm.
jax.config.update('jax_platforms', 'cpu')
import bmi
from minde.minde import MINDE
from minde.scripts.helper import get_data_loader, get_default_config

# MINDE hyper-parameters: start from the package defaults, then override.
args = get_default_config()
args.type = "j"  # NOTE(review): MINDE variant selector; the meaning of "j" is not visible here — confirm
args.importance_sampling = True
args.use_ema = True
# Checkpoint location. It can be overridden with the MINDE_CHECKPOINT_DIR
# environment variable so the notebook runs outside the author's machine;
# the fallback is the original hard-coded path.
args.checkpoint_dir = os.environ.get("MINDE_CHECKPOINT_DIR", r"/home/foresti/minde/checkpoints")



In [3]:
import torch

# Seed the RNGs so the benchmark is reproducible.
# NOTE(review): this assumes the mutinfo distributions sample from NumPy's
# global RNG — confirm. GPU kernels used by MINDE may still be nondeterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Ground-truth mutual-information grid, sample size per run, and number of
# repetitions per grid point.
MI_grid = np.linspace(0.0, 10.0, 11)
print(MI_grid)
n_samples = 1000
n_runs = 10

X_dimension = 10
Y_dimension = 10

var_list = {"x": X_dimension, "y": Y_dimension}
# NOTE(review): one MINDE instance is shared by every estimator(x, y) call in
# run_tests. If calling it trains the model in place, later runs start from
# earlier runs' weights and are not independent — confirm against MINDE.
estimator = MINDE(args, var_list=var_list)

# Not used by the cells visible in this notebook; kept for later cells.
dimension = max(X_dimension, Y_dimension)

estimated_MI = run_tests(
    lambda mutual_information: CorrelatedNormal(mutual_information, X_dimension, Y_dimension),
    estimator=estimator,
    MI_grid=MI_grid,
    n_samples=n_samples,
    n_runs=n_runs,
)

[0.]


  0%|          | 0/1 [00:00<?, ?it/s]GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type           | Params
---------------------------------------------
0 | score     | UnetMLP_simple | 172 K 
1 | model_ema | EMA            | 172 K 
---------------------------------------------
344 K     Trainable params
0         Non-trainable params
344 K     Total params
1.376     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/foresti/miniconda3/envs/minde/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/home/foresti/miniconda3/envs/minde/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/home/foresti/miniconda3/envs/minde/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Epoch:  1  GT:  Not given MINDE_estimate:  0.001 MINDE_sigma_estimate:  -0.0
Epoch:  2  GT:  Not given MINDE_estimate:  0.002 MINDE_sigma_estimate:  0.001
Epoch:  3  GT:  Not given MINDE_estimate:  0.005 MINDE_sigma_estimate:  -0.0
Epoch:  4  GT:  Not given MINDE_estimate:  0.009 MINDE_sigma_estimate:  0.004
Epoch:  5  GT:  Not given MINDE_estimate:  0.014 MINDE_sigma_estimate:  -0.002
Epoch:  6  GT:  Not given MINDE_estimate:  0.02 MINDE_sigma_estimate:  0.009
Epoch:  7  GT:  Not given MINDE_estimate:  0.027 MINDE_sigma_estimate:  0.002
Epoch:  8  GT:  Not given MINDE_estimate:  0.035 MINDE_sigma_estimate:  0.017
Epoch:  9  GT:  Not given MINDE_estimate:  0.044 MINDE_sigma_estimate:  -0.001
Epoch:  10  GT:  Not given MINDE_estimate:  0.055 MINDE_sigma_estimate:  0.002
Epoch:  11  GT:  Not given MINDE_estimate:  0.066 MINDE_sigma_estimate:  0.033
Epoch:  12  GT:  Not given MINDE_estimate:  0.08 MINDE_sigma_estimate:  0.019
Epoch:  13  GT:  Not given MINDE_estimate:  0.097 MINDE_sigma_e

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch:  49  GT:  Not given MINDE_estimate:  1.158 MINDE_sigma_estimate:  4.665
Epoch:  50  GT:  Not given MINDE_estimate:  1.21 MINDE_sigma_estimate:  4.602
Epoch:  51  GT:  Not given MINDE_estimate:  1.25 MINDE_sigma_estimate:  4.72
Epoch:  52  GT:  Not given MINDE_estimate:  1.289 MINDE_sigma_estimate:  4.665
Epoch:  53  GT:  Not given MINDE_estimate:  1.315 MINDE_sigma_estimate:  4.708
Epoch:  54  GT:  Not given MINDE_estimate:  1.344 MINDE_sigma_estimate:  4.64
Epoch:  55  GT:  Not given MINDE_estimate:  1.379 MINDE_sigma_estimate:  4.645
Epoch:  56  GT:  Not given MINDE_estimate:  1.405 MINDE_sigma_estimate:  4.704


In [None]:
import matplotlib
import matplotlib.pyplot as plt


def plot_estimated_MI(MI, estimated_MI, title):
    """
    Plot estimated mutual information against the ground-truth values.

    Parameters
    ----------
    MI : array-like of shape (k,)
        Ground-truth mutual-information values, used as the x-axis.
    estimated_MI : numpy.ndarray of shape (k, 2)
        Column 0 is the mean estimate. Column 1 is the half-width of the
        shaded band drawn around it.
    title : str
        Figure title.
    """
    centre = estimated_MI[:, 0]
    half_width = estimated_MI[:, 1]
    upper = centre + half_width
    lower = centre - half_width

    fig, ax = plt.subplots()
    fig.set_size_inches(16, 11)

    # Faint major and fainter minor grid lines.
    grid_specs = (("major", 0.15, 1), ("minor", 0.1, 0.5))
    for which, alpha, linewidth in grid_specs:
        ax.grid(color='#000000', alpha=alpha, linestyle='-', linewidth=linewidth, which=which)

    ax.set(title=title, xlabel="$I(X,Y)$", ylabel="$\\hat I(X,Y)$")
    ax.minorticks_on()

    # The identity line is the ideal estimator; the band shows the spread of the estimate.
    ax.plot(MI, MI, label="$I(X,Y)$", color='red')
    ax.plot(MI, centre, label="$\\hat I(X,Y)$")
    ax.fill_between(MI, upper, lower, alpha=0.2)

    ax.legend(loc='upper left')

    ax.set_xlim(left=0.0)
    ax.set_ylim(bottom=0.0)

    plt.show()

In [4]:
# The results come from CorrelatedNormal, so the title must say "normal", not "uniform".
plot_estimated_MI(MI_grid, estimated_MI, "Multivariate correlated normal distribution")

ModuleNotFoundError: No module named 'mutinfo.utils.plots'