In [6]:
import numpy as np
from scipy.integrate import nquad
from scipy.stats import multivariate_normal

# Thanks chatgpt for this cell.

def squared_multivariate_normal_density(x, mean, cov):
    """
    Squared density function of a multivariate normal distribution.

    :param x: A point in n-dimensional space.
    :param mean: Mean vector of the multivariate normal distribution.
    :param cov: Covariance matrix of the multivariate normal distribution.
    :return: Squared density at the point x.
    """
    density = multivariate_normal.pdf(x, mean, cov)
    return density ** 2

def integrate_squared_density_multivariate(mean, cov, dims):
    """
    Numerically approximate the integrated squared density of a multivariate normal distribution.

    :param mean: Mean vector of the multivariate normal distribution.
    :param cov: Covariance matrix of the multivariate normal distribution.
    :param dims: Number of dimensions.
    :return: Numerical approximation of the integrated squared density.
    """
    # Integration limits for each dimension
    limits = [(-6, 6)] * dims

    # Wrapper function to adapt the multivariate function for nquad
    def integrand(*args):
        return squared_multivariate_normal_density(np.array(args), mean, cov)

    return nquad(integrand, limits)[0]

In [7]:
# quad_results = []
# # Integrated squared densities
# for d in [1, 2, 3]:
#     # Example usage
#     mean = np.zeros(d)
#     cov = np.eye(d)
#     
#     quad_result = integrate_squared_density_multivariate(mean, cov, d)
#     quad_results.append(quad_result)
#     print(f"with ndim={d}; \int p(x)^2 dx =", quad_result)

In [8]:
from chirho.robust.ops import fd_influence_fn
import pyro
import pyro.distributions as dist
import torch

In [9]:
def diagnormal(mean=0, std=1):
    return dict(x=pyro.sample('x', dist.Normal(mean, std)))

In [10]:
# Generate training data.
datas = []
N = 100

for d in [1, 2, 3]:
    with pyro.plate('N', N, dim=-2):
        with pyro.plate('d', d, dim=-1):
            datas.append(diagnormal(0, 1))

In [15]:
# Train models on the first half of the data.
inferred_models = []
for data in datas:
    # A model fit to the first half of the data.
    mean = torch.mean(data['x'][:N//2])
    std = torch.std(data['x'][:N//2])
    d = data['x'].shape[-1]
    def _model():
        with pyro.plate('d', d, dim=-1):
            return diagnormal(mean=mean, std=std)
    inferred_models.append(_model)

In [16]:
def squared_density_functional(model):
    def target(d, nmc=100):
        res = 0
        for _ in range(nmc):
            with pyro.poutine.trace() as tr:
                model()
            res += tr.trace.log_prob_sum().exp() / N
        return res
    return target

In [18]:
plugin_results = []
correction_results = []
corrected_results = []
for d, inferred_model, alldata in zip([1, 2, 3], inferred_models, datas):
    # Compute plugin.
    plugin_result = squared_density_functional(inferred_model)(d)
    plugin_results.append(plugin_result)
    print(f"plugin with ndim={d}; \int p(x)^2 dx =", plugin_result)
    
    # Estimate the expected influence function on the second half of the data.
    data = alldata['x'][N//2:]
    print(data)
    eif = fd_influence_fn(inferred_model, squared_density_functional, eps=torch.tensor(1e-3))
    correction_result = eif(data, d=d)
    correction_results.append(correction_result)
    print(f"correction with ndim={d}; \int p(x)^2 dx =", correction_result)
    
    corrected_result = plugin_result + correction_result
    corrected_results.append(corrected_result)
    print(f"corrected with ndim={d}; \int p(x)^2 dx =", corrected_result)

plugin with ndim=1; \int p(x)^2 dx = tensor(0.2571)
tensor([[-0.2172],
        [-0.8808],
        [-0.8434],
        [-0.8053],
        [-0.0255],
        [ 1.9628],
        [-0.6141],
        [ 1.5846],
        [ 0.5951],
        [-2.1053],
        [-0.1268],
        [-2.2127],
        [ 0.5829],
        [-0.0794],
        [-1.4608],
        [-2.0613],
        [-0.8474],
        [-1.6414],
        [ 0.2076],
        [ 0.8621],
        [ 0.3099],
        [ 0.4285],
        [-0.3761],
        [-0.4787],
        [ 0.5645],
        [ 0.8255],
        [ 0.2588],
        [-2.1734],
        [ 0.7421],
        [-0.5980],
        [-1.7096],
        [ 0.1787],
        [ 0.5939],
        [-0.6935],
        [-0.4807],
        [ 2.2234],
        [ 0.0458],
        [-0.6098],
        [-0.3696],
        [ 1.9040],
        [-1.0408],
        [-0.7124],
        [ 0.3746],
        [-0.5014],
        [ 0.3685],
        [-0.7191],
        [ 1.1040],
        [-0.1124],
        [-1.0923],
        [ 1.0834]

AttributeError: 'Tensor' object has no attribute 'items'