In [None]:
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt

# Load the data set. Column headers are stripped of surrounding whitespace and
# of any double-quote characters before the columns are selected by name.
df = pd.read_csv('date_colesterol.csv')
df.columns = [c.strip().replace('"', '') for c in df.columns]

X = df['Ore_Exercitii'].values
y = df['Colesterol'].values

# Standardise predictor and response (zero mean, unit std) so that the
# unit-scale priors used below are on a comparable scale to the data.
X_s = (X - X.mean()) / X.std()
y_s = (y - y.mean()) / y.std()

# Candidate numbers of mixture components and the fitted result for each one.
k_values = [3, 4, 5]
idatas = {}

for K in k_values:
    with pm.Model() as model:
        # Mixing proportions with a flat Dirichlet prior over the K components.
        weights = pm.Dirichlet('w', a=np.ones(K))

        # Per-component quadratic regression coefficients and noise scale.
        alpha = pm.Normal('alpha', mu=0, sigma=1, shape=K)
        beta = pm.Normal('beta', mu=0, sigma=1, shape=K)
        gamma = pm.Normal('gamma', mu=0, sigma=1, shape=K)
        sigma = pm.HalfNormal('sigma', sigma=1, shape=K)

        # One Normal component per k, each with its own quadratic mean in X_s.
        components = [
            pm.Normal.dist(
                mu=alpha[k] + beta[k] * X_s + gamma[k] * X_s**2,
                sigma=sigma[k],
            )
            for k in range(K)
        ]

        pm.Mixture('y_mix', w=weights, comp_dists=components, observed=y_s)

        # Sample several chains: with a single chain ArviZ cannot compute
        # R-hat / multi-chain ESS and reports "Shape validation failed ...
        # minimum_shape: (chains=2, draws=4)", so convergence goes unchecked.
        # NOTE(review): the components carry no ordering constraint, so label
        # switching between chains may show up as large r_hat values -- confirm
        # by inspecting the diagnostics before interpreting per-component means.
        idata = pm.sample(
            3000,
            tune=1000,
            chains=4,
            return_inferencedata=True,
            progressbar=True,
            target_accept=0.95,
            random_seed=42,
        )
        # Attach pointwise log-likelihood values; WAIC-based model comparison
        # needs them.
        pm.compute_log_likelihood(idata)
        idatas[K] = idata

# (a) Posterior summary of the mixture parameters, one table per value of K.
PARAM_NAMES = ['w', 'alpha', 'beta', 'gamma', 'sigma']
SUMMARY_COLUMNS = ['mean', 'sd', 'hdi_3%', 'hdi_97%']

print("Estimarea parametrilor pentru fiecare K:")
for K in k_values:
    print(f"\nK = {K}")
    summary = az.summary(idatas[K], var_names=PARAM_NAMES)
    # Show only the point estimate, spread and HDI bounds.
    print(summary[SUMMARY_COLUMNS])

# (b) Compare the fitted models with WAIC, reported on the deviance scale.
print("\nCompararea modelelor (WAIC):")
comp_df = az.compare(idatas, ic="waic", scale="deviance")
print(comp_df)

# The first row of the comparison table is reported as the preferred model;
# its index label is the dictionary key, i.e. the number of components K.
best_K = comp_df.index[0]
print(f"\nCel mai bun model conform WAIC are K = {best_K}.")

Output()

Output()

Output()

Output()

Output()



Output()

Estimarea parametrilor pentru fiecare K:

K = 3


Shape validation failed: input_shape: (1, 3000), minimum_shape: (chains=2, draws=4)
Shape validation failed: input_shape: (1, 3000), minimum_shape: (chains=2, draws=4)


           mean     sd  hdi_3%  hdi_97%
w[0]      0.366  0.025   0.320    0.414
w[1]      0.290  0.039   0.215    0.359
w[2]      0.344  0.034   0.282    0.408
alpha[0] -1.108  0.028  -1.160   -1.056
alpha[1]  0.106  0.077  -0.044    0.246
alpha[2]  0.796  0.048   0.706    0.887
beta[0]  -0.425  0.019  -0.462   -0.389
beta[1]  -0.661  0.036  -0.728   -0.594
beta[2]  -0.392  0.037  -0.458   -0.319
gamma[0]  0.075  0.021   0.037    0.115
gamma[1] -0.070  0.053  -0.166    0.033
gamma[2]  0.247  0.029   0.192    0.301
sigma[0]  0.234  0.016   0.206    0.264
sigma[1]  0.301  0.040   0.230    0.378
sigma[2]  0.286  0.026   0.240    0.336

K = 4


Shape validation failed: input_shape: (1, 3000), minimum_shape: (chains=2, draws=4)


           mean     sd  hdi_3%  hdi_97%
w[0]      0.366  0.026   0.315    0.413
w[1]      0.219  0.071   0.068    0.346
w[2]      0.184  0.097   0.000    0.337
w[3]      0.231  0.082   0.081    0.388
alpha[0] -1.107  0.029  -1.156   -1.048
alpha[1]  0.124  0.244  -0.140    0.683
alpha[2]  0.581  0.395  -0.349    1.043
alpha[3]  0.651  0.371  -0.047    1.035
beta[0]  -0.419  0.020  -0.457   -0.383
beta[1]  -0.525  0.266  -0.776   -0.010
beta[2]  -0.222  0.298  -0.695    0.263
beta[3]  -0.411  0.220  -0.735   -0.051
gamma[0]  0.074  0.021   0.038    0.116
gamma[1] -0.022  0.183  -0.258    0.289
gamma[2]  0.175  0.267  -0.420    0.502
gamma[3]  0.171  0.186  -0.151    0.332
sigma[0]  0.234  0.016   0.204    0.263
sigma[1]  0.291  0.141   0.158    0.417
sigma[2]  0.285  0.210   0.037    0.466
sigma[3]  0.281  0.132   0.156    0.364

K = 5
           mean     sd  hdi_3%  hdi_97%
w[0]      0.208  0.123   0.000    0.385
w[1]      0.219  0.129   0.000    0.386
w[2]      0.154  0.106   0.000   

See http://arxiv.org/abs/1507.04544 for details
See http://arxiv.org/abs/1507.04544 for details
