<a href="https://colab.research.google.com/github/HeningWang/numpyro_adjective_modelling/blob/main/mix_pyro_slider_MCMC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import modules and dependencies


In [1]:
!pip install pyro-ppl
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
!pip install funsor

Collecting pyro-ppl
  Downloading pyro_ppl-1.8.5-py3-none-any.whl (732 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/732.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m732.5/732.5 kB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
Collecting pyro-api>=0.1.1 (from pyro-ppl)
  Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)
Installing collected packages: pyro-api, pyro-ppl
Successfully installed pyro-api-0.1.2 pyro-ppl-1.8.5
  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for numpyro (setup.py) ... [?25l[?25hdone
Collecting funsor
  Downloading funsor-0.4.5-py3-none-any.whl (174 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.9/174.9 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting makefun (from funsor)
  Downloading makefun-1.15.1-py2.py3-none-any.whl (22 kB)
Installing collected packages: makefun, funsor
Successfully installed funsor-0.4.5 mak

In [2]:
! git clone https://github.com/HeningWang/numpyro_adjective_modelling.git

Cloning into 'numpyro_adjective_modelling'...
remote: Enumerating objects: 28, done.[K
remote: Counting objects:   3% (1/28)[Kremote: Counting objects:   7% (2/28)[Kremote: Counting objects:  10% (3/28)[Kremote: Counting objects:  14% (4/28)[Kremote: Counting objects:  17% (5/28)[Kremote: Counting objects:  21% (6/28)[Kremote: Counting objects:  25% (7/28)[Kremote: Counting objects:  28% (8/28)[Kremote: Counting objects:  32% (9/28)[Kremote: Counting objects:  35% (10/28)[Kremote: Counting objects:  39% (11/28)[Kremote: Counting objects:  42% (12/28)[Kremote: Counting objects:  46% (13/28)[Kremote: Counting objects:  50% (14/28)[Kremote: Counting objects:  53% (15/28)[Kremote: Counting objects:  57% (16/28)[Kremote: Counting objects:  60% (17/28)[Kremote: Counting objects:  64% (18/28)[Kremote: Counting objects:  67% (19/28)[Kremote: Counting objects:  71% (20/28)[Kremote: Counting objects:  75% (21/28)[Kremote: Counting objects:  78% (22/28)

In [3]:
import os

from IPython.display import set_matplotlib_formats
import jax
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import math

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

plt.style.use("bmh")
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

assert numpyro.__version__.startswith("0.12.1")

In [4]:
import sys
sys.path.append('/content/numpyro_adjective_modelling')

Some helper functions:

In [5]:
# Mutate the dataset to include the states of the objects
# ... states are independent variables for models

def extract_states(line):
    states = []
    for i in range(6):
      color = 1 if line[10 + i] == "blue" else 0
      form = 1 if line[16 + i] == "circle" else 0
      new_obj = (line[4 + i], color, form)
      states.append(new_obj)
    return jnp.array(states)


# Transform/rescale slider value from range 0 to 100 to 0 to 1
# ... in order to match predicted probability from models

def transformation_data(slider_value, link = None):
    if link == "identity":
      slider_value = jnp.clip(slider_value, 0, 100)
      transformed_prob = slider_value / 100
    elif link == "logit":
        transformed_prob = 1 / (1 + math.exp(-slider_value))
    return transformed_prob

def link_function(x, param = 1):
    return 1 / (1 + jnp.exp(param * -(x - 0.5)))

def compute_alpha_beta_concentration(mu, v):
    alpha = mu * v
    beta = (1 - mu) * v
    return alpha, beta

def Marginal(fn):
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))

def plot_dist(d, ax=None):
    support = d.enumerate_support()
    data = [d.log_prob(s).exp().item() for s in d.enumerate_support()]
    names = list(map(str, support))

    if ax is None:
        ax = plt.subplot(111)

    width = 0.3
    bins = [x-width/2 for x in range(1, len(data) + 1)]
    ax.bar(bins,data,width=width)
    ax.set_xticks(list(range(1, len(data) + 1)))
    ax.set_xticklabels(names, rotation=45, rotation_mode="anchor", ha="right")

def get_results(posterior):
    results = {}
    support = posterior.enumerate_support()
    data = [posterior.log_prob(s).exp().item() for s in posterior.enumerate_support()]
    results["support"] = support
    results["probs"] = data
    return results

def normalize(arr, axis=1):
    """
    Normalize arr along axis
    """
    return arr / arr.sum(axis, keepdims=True)

In [6]:
# Import dataset
dataset_url = "/content/numpyro_adjective_modelling/dataset/dataset_slider.csv"
df = pd.read_csv(dataset_url)

# subset data to only include combination dimension_color
df = df[df['combination'] == 'dimension_color']
df.reset_index(inplace=True, drop=True)

# Mutate the dataset to include the states of the objects
df_experiment = df.copy()
df_experiment["states"] = df_experiment.apply(lambda row: extract_states(row), axis=1)
#print(df_experiment.prefer_first_1st.describe())

df_experiment.prefer_first_1st = jnp.clip(df_experiment.prefer_first_1st.to_numpy(), 0, 100)
df_experiment.prefer_first_1st = df_experiment.prefer_first_1st/100
print(df_experiment.prefer_first_1st.describe())




count    3166.000000
mean        0.704643
std         0.380797
min         0.000000
25%         0.500000
50%         0.920000
75%         1.000000
max         1.000000
Name: prefer_first_1st, dtype: float64


In [7]:
from sklearn.model_selection import train_test_split
# split the dataset into training and test sets
#train, test = train_test_split(df_experiment, test_size=0.99, random_state=42)

# use the whole dataset as training set
train = df_experiment

print(train.shape)

states_train = jnp.stack([cell for cell in train.states])
empirical_train = jnp.array(train.prefer_first_1st.to_numpy())

(3166, 27)


In [8]:
def get_threshold_kp(current_state_prior, k=0.5):
    objs = current_state_prior
    #measures_array = jnp.sort([x[1] for obj in objs for x in obj if x[0] == 'size'])
    min_val = jnp.min(objs)
    max_val = jnp.max(objs)
    threshold = max_val - k * (max_val - min_val)
    return threshold

def adjMeaning(word, obj, current_state_prior, color_semvalue=0.98, form_semvalue=0.98, wf=0.6, k=0.5):
    colors = [1]  # Specify the color values
    sizes = [0]  # Specify the size values

    if word == 1:
        return numpyro.sample("color", numpyro.distributions.Bernoulli(color_semvalue)) if word == obj[1] else numpyro.sample("color", numpyro.distributions.Bernoulli(1 - color_semvalue))
    elif word == 0:
        threshold = get_threshold_kp(current_state_prior, k)
        size = obj[0]
        prob_big = 1 - dist.Normal(size - threshold, wf * jnp.sqrt(size ** 2 + threshold ** 2)).cdf(jnp.array([0.0]))
        return numpyro.sample("size", numpyro.distributions.Bernoulli(prob_big))


In [9]:
# Example usage for meaning function
states = jnp.array([[10., 1., 1.],
                   [3., 1., 1.],
                   [3., 1., 1.],
                   [3., 1., 0.],
                   [3., 1., 0.],
                   [3., 0., 1.]], dtype=jnp.float32)

word = 0 # Example word, 0 for size
obj = states[5]  # Example object from states

# Example prior values
color_semvalue = 0.98
form_semvalue = 0.98
wf = 0.6
k = 0.5

states[0][1]
#get_threshold_kp(states)
# Call the meaning function
with handlers.seed(rng_seed=27):
 meaning = adjMeaning(word, obj, states, color_semvalue, form_semvalue, wf, k)

print(meaning)
print(obj)

[1]
[3. 0. 1.]


In [10]:
utterances = jnp.array([
    [0],
    [1]
    ])

def utterance_prior(bias=1):
    probs = jnp.array([bias,1])/(bias+1)
    n = numpyro.sample("utterance_index", dist.Categorical(probs=probs),infer={"enumerate": "parallel"})
    return n

def state_prior(states):
    length = len(states)
    n = numpyro.sample("state", dist.Categorical(probs=jnp.ones(length) / length))
    return states[n]

In [11]:
def literal_listener(states, color_semvalue = 0.98, form_semvalue = 0.98, wf = 0.6, k = 0.5):
  probs_blue = jnp.where((1. == states[:, 1]), color_semvalue, 1 - color_semvalue)
  threshold = get_threshold_kp(states, k)
  probs_big = jnp.array([1 - dist.Normal(obj[0] - threshold, wf * jnp.sqrt(obj[0] ** 2 + threshold ** 2)).cdf(0.0) for obj in states])
  probs = normalize(jnp.array([probs_big,probs_blue]))
  return probs


def speaker(states, alpha = 1, bias = 1, color_semvalue = 0.98, form_semvalue = 0.98, wf = 0.6, k = 0.5):
  listener = literal_listener(states, color_semvalue, form_semvalue,wf,k)
  bias_weights = jnp.array([0, 1]) * bias
  util_speaker = jnp.log(jnp.transpose(listener)) - bias_weights
  softmax_result = jax.nn.softmax(util_speaker)
  return softmax_result[0][0]


In [12]:
index = 14

states_example = df_experiment.iloc[index, df_experiment.columns.get_loc("states")]
condition = df_experiment.iloc[index, df_experiment.columns.get_loc("conditions")]
distribution = df_experiment.iloc[index, df_experiment.columns.get_loc("sharpness")]
preference = df_experiment.iloc[index, df_experiment.columns.get_loc("prefer_first_1st")]
print(states_example)
print(condition + " " + distribution)
print(preference)
print(f"literal listener: {literal_listener(states_example)}")
model_speaker = speaker(states_example, bias=0)
print(f"model_prediction: {model_speaker}")
print(f"speaker: {model_speaker}")


[[ 9.  1.  1.]
 [10.  0.  1.]
 [10.  0.  1.]
 [10.  0.  1.]
 [10.  0.  0.]
 [ 4.  0.  1.]]
zrdc blurred
0.98
literal listener: [[0.17540349 0.18264773 0.18264773 0.18264773 0.18264773 0.09400556]
 [0.90740746 0.01851852 0.01851852 0.01851852 0.01851852 0.01851852]]
model_prediction: 0.1619890332221985
speaker: 0.1619890332221985


In [13]:
vectorized_speaker = jax.vmap(speaker, in_axes=(0,None,None,None,None,None,None))
model_prob = vectorized_speaker(states_train, 1,1,0.5,0.5,0.5,0.5)
print(model_prob)
slider_predict = jax.vmap(link_function, in_axes = (0,None))(model_prob,20)

slider_predict = jnp.clip(slider_predict, 1e-5, 1 - 1e-5)
print(slider_predict)

[0.7527358  0.8800761  0.75607145 ... 0.76912326 0.802139   0.81604856]
[0.9936613  0.9995005  0.9940679  ... 0.99542457 0.9976306  0.99820507]


In [14]:
# define the conditioned model for MCMC
vectorized_speaker = jax.vmap(speaker, in_axes=(0,None,None,None,None,None,None))

def model_inc_utt_parallel_normal(states = None, data = None):
    gamma = numpyro.sample("gamma", dist.HalfNormal(5))
    color_semvalue = numpyro.sample("color_semvalue", dist.Uniform(0, 1))
    form_semvalue = color_semvalue
    k = numpyro.sample("k", dist.Uniform(0, 1))
    wf = 0.5
    bias = numpyro.sample("bias", dist.HalfNormal(5))
    steepness = numpyro.sample("steepness", dist.HalfNormal(0.5))
    sigma = numpyro.sample("sigma", dist.Uniform(0,1))

    with numpyro.plate("data",len(states)):
      model_prob = vectorized_speaker(states_train, gamma, bias, color_semvalue, form_semvalue, wf, k)
      slider_predict = jax.vmap(link_function, in_axes = (0,None))(model_prob, steepness)
      slider_predict = jnp.clip(slider_predict, 1e-5, 1 - 1e-5)
      obs = jnp.clip(data, 1e-5, 1 - 1e-5)
      numpyro.sample("obs", dist.TruncatedNormal(slider_predict, sigma, low = 0, high = 1), obs=obs) # use this for inference
    #pyro.sample("obs_{}".format(i), dist.Beta(alpha,beta)) # use this for prior predictive


In [15]:
# define the conditioned model for MCMC
def model_inc_utt_serial_beta(states, data):
    gamma = numpyro.sample("gamma", dist.HalfNormal(5))
    color_semvalue = numpyro.sample("color_semvalue", dist.Uniform(0, 1))
    form_semvalue = color_semvalue
    k = numpyro.sample("k", dist.Uniform(0, 1))
    wf = 0.5
    bias = numpyro.sample("bias", dist.HalfNormal(5))
    steepness = numpyro.sample("steepness", dist.HalfNormal(0.5))
    v = numpyro.sample("v", dist.Uniform(1e-5,5))
    for i in range(len(data)):
        model = speaker(states[i], gamma, bias, color_semvalue, form_semvalue, wf, k)
        model_prob = model[0][0]
        slider_predict = link_function(model_prob, link = "rapidlogit", param = steepness)
        slider_predict = jnp.clip(slider_predict, 1e-5, 1 - 1e-5)
        obs = jnp.clip(data[i], 1e-5, 1 - 1e-5)
        alpha, beta = compute_alpha_beta_concentration(slider_predict, v)
        numpyro.sample("obs_{}".format(i), dist.Beta(alpha,beta), obs=obs) # use this for inference
        #pyro.sample("obs_{}".format(i), dist.Beta(alpha,beta)) # use this for prior predictive


In [None]:
# define the MCMC kernel and the number of samples
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

kernel = NUTS(model_inc_utt_parallel_normal, target_accept_prob=0.8)
mcmc_inc = MCMC(kernel, num_warmup=1000,num_samples=1000)
mcmc_inc.run(rng_key_, states_train, empirical_train)

# print the summary of the posterior distribution
mcmc_inc.print_summary()

# Get the MCMC samples and convert to a DataFrame
posterior_inc = mcmc_inc.get_samples()
df_inc = pd.DataFrame(posterior_inc)

# Save the DataFrame to a CSV file
df_inc.to_csv('posterior_inc_utt_slider.csv', index=False)

sample: 100%|██████████| 2000/2000 [01:51<00:00, 17.92it/s, 7 steps of size 4.43e-01. acc. prob=0.91]



                      mean       std    median      5.0%     95.0%     n_eff     r_hat
            bias      7.39      2.62      7.16      3.57     12.00    726.39      1.00
  color_semvalue      0.42      0.29      0.40      0.00      0.84    759.24      1.00
           gamma      3.90      3.08      3.24      0.00      8.25    978.92      1.00
               k      0.48      0.28      0.47      0.03      0.90   1075.56      1.00
       steepness      5.85      0.27      5.85      5.45      6.38    851.08      1.00
               v      0.63      0.02      0.63      0.60      0.67   1131.21      1.00

Number of divergences: 0
