In [13]:
import pandas as pd
import pymc as pm
import arviz as az
import numpy as np
import xarray as xr

# Load the Iris dataset
data = pd.read_csv('iris.csv')
features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# Dictionary to store WAIC results
waic_results = {}

for feature in features:
    print(f"\n=== Analyzing feature: {feature} ===")
    X = data[feature].values
    
    with pm.Model() as model:
        # Cluster weights (Dirichlet prior)
        weights = pm.Dirichlet('weights', a=[1, 1, 1])
        
        # Cluster means (Normal prior)
        means = pm.Normal('means', mu=X.mean(), sigma=10, shape=3)
        
        # Cluster standard deviations (HalfNormal prior)
        stds = pm.HalfNormal('stds', sigma=X.std(), shape=3)
        
        # Gaussian Mixture Model
        likelihood = pm.NormalMixture(
            'obs',  # This is our likelihood variable
            w=weights,
            mu=means,
            sigma=stds,
            observed=X
        )
        
        # Sample
        trace = pm.sample(
            tune=10, 
            draws=10, 
            chains=1, 
            progressbar=False, 
            random_seed=42
        )
    
    print(f"\nModel summary for {feature}:")
    print(az.summary(trace, kind="stats"))
    
    # Compute log likelihood for each posterior sample
    log_likelihood = []
    for draw in range(trace.posterior.sizes["draw"]):
        # Extract posterior samples for this draw
        weights_sample = trace.posterior["weights"].sel(draw=draw).values
        means_sample = trace.posterior["means"].sel(draw=draw).values
        stds_sample = trace.posterior["stds"].sel(draw=draw).values
        
        # Compute log likelihood for this draw
        log_likelihood_draw = []
        for i in range(len(X)):
            # Compute the log likelihood for each observation
            log_likelihood_draw.append(
                np.log(np.sum(weights_sample * np.exp(-0.5 * ((X[i] - means_sample) / stds_sample)**2) / (stds_sample * np.sqrt(2 * np.pi))))
            )
        log_likelihood.append(log_likelihood_draw)
    
    # Stack log likelihood into an array with shape (chain, draw, obs)
    log_likelihood = np.stack(log_likelihood)  # Shape: (draw, obs)
    log_likelihood = log_likelihood[np.newaxis, :, :]  # Add chain dimension
    
    # Add log likelihood to trace with explicit dimensions
    trace.sample_stats["log_likelihood"] = xr.DataArray(
        log_likelihood,
        dims=["chain", "draw", "obs"],
        coords={
            "chain": trace.posterior.coords["chain"],
            "draw": trace.posterior.coords["draw"],
            "obs": np.arange(len(X))
        }
    )
    
    # Calculate WAIC
    waic = az.waic(trace)
    waic_results[feature] = waic.elpd_waic  # Use the correct attribute
    print(f"\nWAIC for {feature}: {waic.elpd_waic:.2f}")

# Identify the best separating feature
best_feature = min(waic_results, key=lambda k: waic_results[k])
print(f"\nBest separating feature: '{best_feature}' (WAIC: {waic_results[best_feature]:.2f})")


=== Analyzing feature: sepal_length ===


Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [weights, means, stds]
Sampling 1 chain for 10 tune and 10 draw iterations (10 + 10 draws total) took 1 seconds.
The number of samples is too small to check convergence reliably.
Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.



Model summary for sepal_length:
             mean     sd  hdi_3%  hdi_97%
means[0]    6.165  0.071   6.051    6.267
means[1]    5.295  0.777   4.302    6.191
means[2]    5.062  0.286   4.843    5.798
weights[0]  0.632  0.185   0.164    0.869
weights[1]  0.166  0.182   0.006    0.598
weights[2]  0.202  0.089   0.047    0.299
stds[0]     0.759  0.143   0.626    1.110
stds[1]     0.581  0.505   0.082    1.666
stds[2]     0.368  0.089   0.248    0.558

WAIC for sepal_length: -182.74

=== Analyzing feature: sepal_width ===


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [weights, means, stds]
Sampling 1 chain for 10 tune and 10 draw iterations (10 + 10 draws total) took 1 seconds.
The number of samples is too small to check convergence reliably.
See http://arxiv.org/abs/1507.04544 for details
Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.



Model summary for sepal_width:
             mean     sd  hdi_3%  hdi_97%
means[0]    2.987  0.371   2.510    3.821
means[1]    2.967  0.057   2.907    3.089
means[2]    3.144  0.265   2.641    3.497
weights[0]  0.443  0.261   0.039    0.729
weights[1]  0.160  0.127   0.056    0.397
weights[2]  0.397  0.258   0.096    0.906
stds[0]     0.396  0.139   0.157    0.545
stds[1]     0.128  0.068   0.059    0.226
stds[2]     0.385  0.046   0.299    0.452

WAIC for sepal_width: -89.17

=== Analyzing feature: petal_length ===


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [weights, means, stds]
Sampling 1 chain for 10 tune and 10 draw iterations (10 + 10 draws total) took 1 seconds.
The number of samples is too small to check convergence reliably.
See http://arxiv.org/abs/1507.04544 for details
Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.



Model summary for petal_length:
             mean     sd  hdi_3%  hdi_97%
means[0]    1.463  0.022   1.429    1.489
means[1]    5.009  0.373   4.596    5.532
means[2]    5.067  0.451   4.482    5.572
weights[0]  0.347  0.033   0.313    0.421
weights[1]  0.358  0.121   0.164    0.484
weights[2]  0.295  0.111   0.180    0.523
stds[0]     0.176  0.021   0.144    0.215
stds[1]     0.813  0.130   0.660    0.969
stds[2]     0.720  0.136   0.590    0.974

WAIC for petal_length: -205.79

=== Analyzing feature: petal_width ===


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [weights, means, stds]
Sampling 1 chain for 10 tune and 10 draw iterations (10 + 10 draws total) took 1 seconds.
The number of samples is too small to check convergence reliably.



Model summary for petal_width:
             mean     sd  hdi_3%  hdi_97%
means[0]   -0.353  0.142  -0.734   -0.227
means[1]    0.233  0.016   0.214    0.255
means[2]    1.653  0.066   1.523    1.708
weights[0]  0.013  0.004   0.004    0.016
weights[1]  0.271  0.046   0.234    0.346
weights[2]  0.717  0.044   0.645    0.752
stds[0]     0.111  0.018   0.079    0.128
stds[1]     0.076  0.014   0.069    0.117
stds[2]     0.438  0.052   0.384    0.532

WAIC for petal_width: -117.44

Best separating feature: 'petal_length' (WAIC: -205.79)


See http://arxiv.org/abs/1507.04544 for details
