In [2]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import xarray as xr
%config InlineBackend.figure_format = 'retina'
# Use ArviZ's "arviz-darkgrid" Matplotlib style for the figures in this notebook.
az.style.use("arviz-darkgrid")
# Seeded NumPy random generator, so draws made through `rng` are reproducible.
rng = np.random.default_rng(1234)



例子：社会经济地位、社会关系地位与幸福感的关系

数据本身的可视化、贝叶斯模型、模型诊断、统计推断

In [31]:
# Load the data set and keep only the columns selected below.
# `variable` is later used as the outcome and `factor` as the High/Low grouping.
sms_columns = ['uID', 'variable', 'factor', 'Country']
SMS_data = pd.read_csv('SMS_Well_being.csv')[sms_columns]

In [102]:
# One ascending-sorted list per group, "Low" first and "High" second, built
# from the first 3000 `variable` values of each group. Sorting matters: the
# plotting cell reads the first/last element as the sample minimum/maximum.
plot_data = [
    sorted(SMS_data.query(f'factor=="{level}"').variable[0:3000])
    for level in ("Low", "High")
]

In [5]:
# To list the font families Matplotlib can find on this machine, run:
#   import matplotlib
#   for name in sorted(f.name for f in matplotlib.font_manager.fontManager.ttflist):
#       print(name)

# Font family for all figure text (the axis labels in this notebook are Chinese).
font = dict(family='Source Han Sans CN')
# Register it through Matplotlib's rc settings.
plt.rc('font', **font)

In [107]:
# Plot comparing well-being between the two social-status groups
def adjacent_values(vals, q1, q3):
    """Return the ``(lower, upper)`` whisker end points for a sample.

    The whiskers reach 1.5 times the inter-quartile range beyond each
    quartile, but are clipped so they never extend past the sample's
    extremes nor fall inside the box.

    Parameters
    ----------
    vals : sequence
        Sample values sorted in ascending order; ``vals[0]`` and ``vals[-1]``
        are used as the minimum and maximum.
    q1, q3 : number
        First and third quartile of ``vals``.

    Returns
    -------
    tuple
        ``(lower_adjacent_value, upper_adjacent_value)``.
    """
    reach = (q3 - q1) * 1.5
    # Lower whisker: bounded below by the smallest value, above by q1.
    lower = np.clip(q1 - reach, vals[0], q1)
    # Upper whisker: bounded below by q3, above by the largest value.
    upper = np.clip(q3 + reach, q3, vals[-1])
    return lower, upper

def set_axis_style(ax, labels):
    """Style the x axis of ``ax`` for a plot with one category per label.

    Ticks point outward along the bottom edge and sit at 1..len(labels),
    each carrying the matching entry of ``labels``. The x limits leave a
    margin on both sides and the axis title is set to a fixed Chinese string.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to modify in place.
    labels : sequence of str
        Tick labels, in plotting order.
    """
    n_labels = len(labels)
    x_axis = ax.xaxis
    x_axis.set_tick_params(direction='out')
    x_axis.set_ticks_position('bottom')
    # Categories are drawn at 1, 2, ..., n_labels.
    ax.set_xticks(np.arange(1, n_labels + 1), labels=labels)
    ax.set_xlim(0.25, n_labels + 0.75)
    ax.set_xlabel('社会关系地位')

# Single panel that holds one violin per status group.
fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(9, 4), sharey=True)

# Draw only the violin bodies; the median, quartile and whisker markers are
# added by hand further down.
parts = ax1.violinplot(
    plot_data,
    showmeans=False,
    showmedians=False,
    showextrema=False,
)

# Solid red bodies with a black outline.
for pc in parts['bodies']:
    pc.set_facecolor('#D43F3A')
    pc.set_edgecolor('black')
    pc.set_alpha(1)

# Per-group quartiles (axis=1 runs over the values inside each group).
quartile1, medians, quartile3 = np.percentile(plot_data, [25, 50, 75], axis=1)

# Whisker end points for each group, one (lower, upper) row per group.
whiskers = np.array([
    adjacent_values(group_values, q1, q3)
    for group_values, q1, q3 in zip(plot_data, quartile1, quartile3)
])
whiskers_min, whiskers_max = whiskers[:, 0], whiskers[:, 1]

# Violins sit at x = 1, 2, ...; overlay the summary markers at the same spots:
# white dot = median, thick bar = inter-quartile range, thin bar = whiskers.
inds = np.arange(1, len(medians) + 1)
ax1.scatter(inds, medians, marker='o', color='white', s=30, zorder=3)
ax1.vlines(inds, quartile1, quartile3, color='k', linestyle='-', lw=5)
ax1.vlines(inds, whiskers_min, whiskers_max, color='k', linestyle='-', lw=1)

# Axis styling: tick labels follow the order of `plot_data`.
labels = ['低','高']
ax1.set_xticks(np.arange(2) + 1)
ax1.set_xticklabels(labels)
ax1.set_xlabel('社会地位')
ax1.set_ylabel('幸福感')

fig.subplots_adjust(bottom=0.15, wspace=0.05)
plt.show()



In [108]:
# Encode the predictor explicitly: High -> 0, Low -> 1.
# pd.factorize numbers labels by order of first appearance, so with it this
# coding would silently flip whenever a "Low" row came first in the data,
# while later cells rely on 1 meaning "Low".
x = SMS_data.factor.map({"High": 0, "Low": 1})
if x.isna().any():
    raise ValueError("SMS_data.factor must contain only 'High' or 'Low'")
x = x.astype(int).to_numpy()

# Linear regression of the outcome on the binary status indicator:
#   y ~ Normal(β0 + β1 * x, sigma)
# so β0 is the mean of the High group (x = 0) and β1 the Low - High difference.
with pm.Model() as linear_regression:
    # Prior on the residual standard deviation.
    sigma = pm.HalfCauchy("sigma", beta=2)
    # Priors on the intercept and the slope.
    β0 = pm.Normal("β0", 0, sigma=5)
    β1 = pm.Normal("β1", 0, sigma=5)
    # x = pm.MutableData("x", x, dims="uID")
    # μ = pm.Deterministic("μ", β0 + β1 * x, dims="uID")
    # Likelihood of the observed `variable` values.
    pm.Normal("y", mu=β0 + β1 * x, sigma=sigma, observed=SMS_data.variable)

In [44]:
# Draw the model as a Graphviz graph to inspect how its variables are connected.
pm.model_to_graphviz(linear_regression)

In [54]:
with linear_regression:
    # 4000 posterior draws per chain after 2000 tuning steps, returned as an
    # ArviZ InferenceData object. The seed is fixed (same value as the
    # notebook-level generator) so that Restart & Run All reproduces the draws.
    idata = pm.sample(
        4000,
        tune=2000,
        target_accept=0.9,
        return_inferencedata=True,
        random_seed=1234,
    )

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [β1, β0, sigma]


Sampling 4 chains for 2_000 tune and 4_000 draw iterations (8_000 + 16_000 draws total) took 12 seconds.


In [55]:
# Trace plots of the sampled parameters, as a visual check of the chains.
# The trailing semicolon suppresses the textual repr of the returned axes.
az.plot_trace(idata);

In [67]:
# Leave-one-out cross-validation estimate (elpd_loo, p_loo) computed by ArviZ
# from the log-likelihood stored in `idata`.
az.loo(idata)

Computed from 16000 by 6905 log-likelihood matrix

         Estimate       SE
elpd_loo -6936.45    61.84
p_loo        3.19        -

In [68]:
# R-hat convergence diagnostic for each sampled parameter.
az.rhat(idata)

In [57]:
# Merge the "chain" and "draw" dimensions into a single "chain_draw" dimension,
# so every posterior sample is addressed by one index.
reg_post = idata.posterior.stack(chain_draw=("chain", "draw"))

In [120]:
# Hand-made posterior predictive sample: every posterior draw is assigned to
# one group and produces one simulated observation for that group.
n_samples = len(reg_post.sigma)
n_low = n_samples // 2

# First block of draws -> x = 0 (High), remaining draws -> x = 1 (Low).
# Integer repeat counts that always add up to n_samples keep `ppc_x` aligned
# with the posterior even for an odd number of draws (a float count such as
# n_samples / 2 is not a valid repeat count).
ppc_x = np.repeat([0, 1], [n_samples - n_low, n_low])

# Group mean implied by each draw ...
ppc_mu = reg_post['β0'] + reg_post['β1'] * ppc_x
# ... plus observation noise drawn with that draw's sigma. Without this term
# the result is only the posterior of the mean, whose spread is not comparable
# to the spread of the observed data. Uses the notebook-level seeded `rng`.
ppc_y = ppc_mu + reg_post['sigma'] * rng.standard_normal(n_samples)

In [131]:
labels = ['低', '高']

# Observed outcome values per group.
obs_low, obs_high = (
    SMS_data.query(f'factor=="{level}"').variable for level in ("Low", "High")
)
# Model-based values per group; x == 1 marks the Low group, x == 0 the High group.
ppc_low = ppc_y[ppc_x == 1].values
ppc_high = ppc_y[ppc_x == 0].values

# Violin order: observed/model for Low, then observed/model for High.
data = [list(values) for values in (obs_low, ppc_low, obs_high, ppc_high)]
# Each pair sits side by side, with a gap between the two groups.
pos = [1, 2, 4, 5]

fig, ax = plt.subplots()
ax.violinplot(
    data,
    positions=pos,
    points=20,
    widths=0.3,
    showmeans=True,
    showextrema=True,
    showmedians=True,
)

# Axis label, title and one tick label centred under each pair of violins.
ax.set_ylabel('幸福感')
ax.set_title('Posterior predictive check')
ax.set_xticks([1.5, 4.5])
ax.set_xticklabels(labels)

fig.tight_layout()
plt.show()



In [125]:
labels = ['低', '高']

# Observed outcome values per group.
obs_low, obs_high = (
    SMS_data.query(f'factor=="{level}"').variable for level in ("Low", "High")
)
# Model-based values per group; x == 1 marks the Low group, x == 0 the High group.
ppc_low = ppc_y[ppc_x == 1].values
ppc_high = ppc_y[ppc_x == 0].values

# Bar heights (means) and error bars (standard deviations), Low first.
# NOTE(review): every mean is shifted by +1 before plotting; the reason is not
# visible in this cell - confirm that the offset is intentional.
obs_means = [np.mean(group) + 1 for group in (obs_low, obs_high)]
obs_std = [np.std(group) for group in (obs_low, obs_high)]
ppc_means = [np.mean(group) + 1 for group in (ppc_low, ppc_high)]
ppc_std = [np.std(group) for group in (ppc_low, ppc_high)]

x = np.arange(len(labels))  # the label locations
width = 0.5  # the width of the bars

# Observed and model-based bars share each tick, offset by half a bar width.
fig, ax = plt.subplots()
rects1 = ax.bar(x - width/2, obs_means, width, yerr=obs_std, label='观测值')
rects2 = ax.bar(x + width/2, ppc_means, width, yerr=ppc_std, label='预测值')

# Axis label, title, group tick labels and legend.
ax.set_ylabel('幸福感')
ax.set_title('Posterior predictive check')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()

# Print each bar's height just above it.
for rects in (rects1, rects2):
    ax.bar_label(rects, padding=3)

fig.tight_layout()

plt.show()

