In [8]:
import pandas as pd
import numpy as np
import pymc as pm

if __name__ == '__main__':

    # Load the iris measurements from a CSV in the current working directory.
    iris = pd.read_csv('iris.csv')

    # Give every species label an integer code, numbered in the order the
    # labels first appear in the file, and store it as a new column.
    species_names = iris['species'].unique()
    species_map = dict(zip(species_names, range(len(species_names))))
    iris['species_num'] = iris['species'].map(species_map)

    # Build the design matrix and z-score each feature column (column mean
    # removed, then divided by the column's NumPy-default standard deviation).
    features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    X = iris[features].values
    column_means = X.mean(axis=0)
    column_stds = X.std(axis=0)
    X = (X - column_means) / column_stds
    y = iris['species_num'].values

    # Linear regression of the integer species code on the standardized
    # features: Normal likelihood with an intercept, one slope per feature
    # and a shared noise scale.
    model = pm.Model()
    with model:
        # Weakly-informative priors on intercept, slopes and noise scale.
        alpha = pm.Normal('alpha', mu=0, sigma=10)
        beta = pm.Normal('beta', mu=0, sigma=10, shape=len(features))
        sigma = pm.HalfNormal('sigma', sigma=1)

        # Linear predictor for every row of X.
        mu = alpha + pm.math.dot(X, beta)

        # Observed species codes are tied to the predictor here.
        y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y)

        # Draw posterior samples: one chain, 10 tuning steps, 20 kept draws.
        # NOTE(review): presumably a quick smoke-test configuration; chains
        # this short cannot support convergence diagnostics -- confirm
        # before relying on the estimates.
        sampler_settings = {
            'draws': 20,
            'tune': 10,
            'chains': 1,
            'progressbar': False,
        }
        trace = pm.sample(**sampler_settings)

    print("Model sampling completed")

    # Summarize each slope by its posterior mean (averaged over chains and
    # draws) and rank features by the magnitude of that mean.
    with model:
        posterior_beta = trace.posterior['beta']
        beta_samples = posterior_beta.mean(dim=["chain", "draw"]).values
        importance = np.abs(beta_samples)

    print("\nFeature importance (absolute value of coefficients):")
    # One "<feature>: <value>" line per feature, in the order of `features`.
    for feat, imp in zip(features, importance):
        line = f"{feat}: {imp:.3f}"
        print(line)

Only 20 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [alpha, beta, sigma]
Sampling 1 chain for 10 tune and 20 draw iterations (10 + 20 draws total) took 0 seconds.
The number of samples is too small to check convergence reliably.


Model sampling completed

Feature importance (absolute value of coefficients):
sepal_length: 0.052
sepal_width: 0.032
petal_length: 0.286
petal_width: 0.537
