# Iris classification in Pyro

Aims of this notebook:
* Demonstrate the use of Pyro for probabilistic classification on the Iris benchmark dataset.

Other benchmark datasets might be tried in this or another notebook.

In [None]:
import numpy as np
import pandas as pd
import torch
import pyro
import pyro.distributions as dist

In [None]:
# Download the Iris measurements from the UCI repository. The file is read
# with header=None, so the column names are supplied here. The 'species'
# column is dropped straight away, leaving only the four numeric
# measurement columns in `iris`.
iris = pd.read_csv(
    'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',
    header=None,
    names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'],
).drop(columns='species')


In [None]:
# Define a probabilistic model using Pyro
def model(data):
    # Priors over the parameters of a Gaussian mixture model
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(3)))
    means = pyro.sample('means', dist.Normal(torch.zeros(3), 10 * torch.ones(3)))
    scales = pyro.sample('scales', dist.LogNormal(torch.zeros(3), 10 * torch.ones(3)))

    # Mixture components
    with pyro.plate('data', len(data)):
        z = pyro.sample('z', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(means[z], scales[z]), obs=data)


In [None]:

# Define a variational distribution using Pyro
def guide(data):
    # Approximate posterior over the parameters of the Gaussian mixture model
    alpha = pyro.param('alpha', torch.ones(3), constraint=dist.constraints.positive)
    beta = pyro.param('beta', torch.randn(3), constraint=dist.constraints.positive)
    gamma = pyro.param('gamma', torch.randn(3), constraint=dist.constraints.positive)
    weights = pyro.sample('weights', dist.Dirichlet(alpha))
    means = pyro.sample('means', dist.Normal(beta, gamma))
    scales = pyro.sample('scales', dist.LogNormal(beta, gamma))

    


In [None]:


# Start from an empty parameter store so that re-running this cell trains
# from fresh initial values instead of resuming from parameters left over
# from an earlier run
pyro.clear_param_store()

# Define an optimizer
optimizer = pyro.optim.Adam({'lr': 0.01})

# Create a Pyro SVI object for variational inference
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())

# Train the model using variational inference
data = torch.tensor(iris.values, dtype=torch.float32)
for i in range(1000):
    loss = svi.step(data)
    # Report the loss returned by every 100th step
    if i % 100 == 0:
        print('iteration', i, 'loss', loss)

# Read the fitted variational parameters back from the parameter store.
# 'alpha' is the concentration of the Dirichlet the guide samples the weights
# from; normalising it to sum to one gives that Dirichlet's mean.
alpha = pyro.param('alpha').detach().numpy()
weights = alpha / np.sum(alpha)
means = pyro.param('beta').detach().numpy()
scales = pyro.param('gamma').detach().numpy()

# Print the results
print('Inferred weights:', weights)
print('Inferred means:', means)
print('Inferred scales:', scales)
