# Detecting Independence

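In *Neural Conditional Probability for Uncertainty Quantification* (Kostic et al., 2024), the authors claim that the (deflated) conditional expectation operator can be used to detect whether two random variables X and Y are independent: X and Y are independent if and only if this operator is zero. Here, we show this equivalence in practice by training NCP on data whose degree of dependence we control, and tracking the norm (largest singular value) of the learned operator.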

## Dataset

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_dataset(n_samples: int = 200, t: float = 0.0):
    """Sample pairs from the data model Y = t * X + (1 - t) * X_ and split them into train/val.

    X and X_ are drawn independently from a standard normal distribution. At t = 0, Y is just the
    independent noise X_, so X and Y are independent. As t moves towards 1, Y moves towards X and
    the pair becomes increasingly dependent (at t = 1, Y equals X).

    Args:
        n_samples (int, optional): Total number of (X, Y) pairs to draw. Defaults to 200.
        t (float, optional): Interpolation factor between the noise X_ and X. Defaults to 0.0.

    Returns:
        tuple: ``(train_ds, val_ds)``, random 85% / 15% subsets of a ``TensorDataset(X, Y)``
        in which X and Y both have shape ``(n_samples, 1)``.
    """
    shape = (n_samples, 1)
    # X is drawn before the noise so that results are reproducible under a fixed seed.
    x = torch.normal(mean=0, std=1, size=shape)
    noise = torch.normal(mean=0, std=1, size=shape)
    y = t * x + (1 - t) * noise

    full_ds = TensorDataset(x, y)
    # Fractional lengths: 85% of the pairs for training, 15% for validation.
    train_part, val_part = random_split(full_ds, lengths=[0.85, 0.15])
    return train_part, val_part

## Code for training NCP

In [None]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from torch.nn import Module
from torch.optim import Adam, Optimizer

from linear_operator_learning.nn import MLP, NCP, L2ContrastiveLoss
from linear_operator_learning.nn.functional import orthonormality_regularization


class NCPTrainingModule(L.LightningModule):
    """Optional LightningModule for training NCP. It isn't necessary to use lightning to run NCP!

    Training and validation both compute ``loss(u, (Dr @ v.T).T)``, where ``(u, v) = ncp(*batch)``
    and ``Dr = ncp.S.weights``. When fitting ends, the whitening buffers of ``ncp`` are updated
    from freshly drawn samples of the toy data model, and ``ncp.sing_val`` is stored in
    ``results[run_id]``.

    Args:
        results (dict): Dictionary shared across runs. ``on_fit_end`` writes
            ``results[run_id] = ncp.sing_val``.
        run_id (tuple): Key identifying this run in ``results``. ``run_id[0]`` is read as the
            interpolation factor ``t`` of the data model ``Y = tX + (1-t)X_``.
        ncp (NCP): NCP Module. See modules/ncp.py for more details.
        loss (type[Module], optional): Loss *class* for training NCP, instantiated as
            ``loss(**loss_kwargs)``. Defaults to L2ContrastiveLoss, the loss used in the paper.
        loss_kwargs (dict, optional): Keyword arguments to be passed to the loss. Defaults to
            None (no arguments), as the L2ContrastiveLoss doesn't implement regularization.
        gamma (float, optional): Orthonormality regularization strength. Defaults to 1e-3.
            The regularization term is currently fixed to 0, so gamma has no effect yet.
        optimizer (type[Optimizer], optional): Torch optimizer *class* for optimizing NCP.
            Defaults to Adam.
        optimizer_kwargs (dict, optional): Keyword arguments to be passed to the optimizer.
            Defaults to None, which means ``{"lr": 5e-4}``.
    """

    def __init__(
        self,
        # Hack to store the results of different runs without heavy machinery.
        results: dict,
        run_id: tuple,
        # NCP training interface begins here:
        ncp: NCP,
        loss: "type[Module]" = L2ContrastiveLoss,
        loss_kwargs: "dict | None" = None,
        gamma: float = 1e-3,
        optimizer: "type[Optimizer]" = Adam,
        optimizer_kwargs: "dict | None" = None,
    ):
        super().__init__()
        self.results = results
        self.run_id = run_id
        self.ncp = ncp
        # None defaults avoid sharing one mutable dict object across every instance.
        self.loss = loss(**({} if loss_kwargs is None else loss_kwargs))
        self.gamma = gamma
        self._optimizer = optimizer
        self._optimizer_kwargs = {"lr": 5e-4} if optimizer_kwargs is None else optimizer_kwargs

    def configure_optimizers(self):
        """Build the optimizer over all module parameters from the stored class and kwargs."""
        return self._optimizer(self.parameters(), **self._optimizer_kwargs)

    def _shared_step(self, batch, stage: str):
        """Compute the NCP loss on one batch and log it under ``loss/{stage}``."""
        u, v = self.ncp(*batch)

        Dr = self.ncp.S.weights  # TODO: Is there a better name?
        # (Dr @ v.T).T applies Dr to every row of v, i.e. "Dr @ v" per sample
        # (assumes v is (batch, dim); TODO confirm against NCP.forward).
        _loss = self.loss(u, (Dr @ v.T).T)

        # TODO: Pass params to orthonormality_regularization(); until then the term is 0.
        reg = 0

        loss = _loss + self.gamma * reg
        self.log(f"loss/{stage}", loss, prog_bar=False)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``loss/train``)."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``loss/val``)."""
        return self._shared_step(batch, "val")

    def on_fit_end(self):
        """Update the whitening buffers and record this run's singular values.

        Fresh samples are drawn from the data model with ``t = run_id[0]``. In real-world
        applications, this would use the entire training set instead.
        """
        WHITENING_N_SAMPLES = 2000
        t = self.run_id[0]
        x = torch.normal(mean=0, std=1, size=(WHITENING_N_SAMPLES, 1))
        x_ = torch.normal(mean=0, std=1, size=(WHITENING_N_SAMPLES, 1))
        y = t * x + (1 - t) * x_
        self.ncp._update_whitening_buffers(x=x, y=y)
        self.results[self.run_id] = self.ncp.sing_val

## Detection happens here

In [None]:
from pathlib import Path

import torch
from lightning import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

RUN_PATH = Path("runs")
RUN_PATH.mkdir(exist_ok=True)

SEED = 1
REPEATS = 1
BATCH_SIZE = 256
N_SAMPLES = 5000
# Architecture shared by the X and Y embeddings.
# NOTE(review): iterative_whitening is forwarded to MLP through **NCP_PARAMS below;
# confirm that MLP (and not only NCP) is meant to receive it.
NCP_PARAMS = dict(
    output_shape=2,
    n_hidden=2,
    layer_size=32,
    activation=torch.nn.ELU,
    bias=False,
    iterative_whitening=False,
)

# run_id = (t, repeat) -> ncp.sing_val of the trained model, written by the training module
# in on_fit_end.
results = dict()
for t in torch.linspace(start=0, end=1, steps=11):
    # Round once and reuse the value everywhere. The training data below and the whitening
    # samples drawn in on_fit_end (which reads run_id[0]) must come from the same data model,
    # without float32 linspace noise in one but not the other.
    t_val = round(t.item(), 2)
    for r in range(REPEATS):
        run_id = (t_val, r)
        print(f"run_id = {run_id}")

        # Load data_________________________________________________________________________________
        # Offset the seed by the repeat index. With one fixed seed every repeat would get the
        # same data and initialization, making REPEATS > 1 pointless. r = 0 keeps SEED.
        seed_everything(seed=SEED + r)
        train_ds, val_ds = make_dataset(n_samples=N_SAMPLES, t=t_val)

        train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)
        val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

        # Build NCP_________________________________________________________________________________
        ncp = NCP(
            embedding_x=MLP(input_shape=1, **NCP_PARAMS),
            embedding_dim_x=NCP_PARAMS["output_shape"],
            embedding_y=MLP(input_shape=1, **NCP_PARAMS),
            embedding_dim_y=NCP_PARAMS["output_shape"],
        )

        # Train NCP_________________________________________________________________________________
        # Training module for lightning
        model = NCPTrainingModule(results=results, run_id=run_id, ncp=ncp)

        # Create logger. Use a filesystem-friendly string version instead of the raw tuple.
        logger = CSVLogger(
            save_dir=RUN_PATH,
            name="detecting_independence",
            version=f"t_{t_val}_r_{r}",
        )

        # Create callbacks
        # TODO: Add ModelCheckpoint and EarlyStopping
        # ckpt_call = ModelCheckpoint()
        # early_call = EarlyStopping()

        trainer = L.Trainer(
            accelerator="cpu",
            # "bf16" is only a deprecated alias of "bf16-mixed" in Lightning 2.x.
            precision="bf16-mixed",
            logger=logger,
            # callbacks=[ckpt_call, early_call],
            max_epochs=100,
            check_val_every_n_epoch=25,
            enable_model_summary=False,
            enable_progress_bar=False,
        )

        trainer.fit(model, train_dl, val_dl)

## Plots

In [None]:
import pandas as pd
import seaborn as sns

# One row per run: the interpolation factor t, the repeat index r, and the largest entry of the
# recorded singular values (ncp.sing_val), which serves as the norm of the learned operator.
# At t = 0 (independent X and Y) this norm is expected to be close to zero.
results_df = pd.DataFrame(
    columns=["t", "r", "norm"],
    data=[
        (t_key, rep_key, sing_vals.max().item())
        for (t_key, rep_key), sing_vals in results.items()
    ],
)
sns.pointplot(data=results_df, x="t", y="norm")