In [None]:
import pandas as pd
import numpy as np

# Read the train and dev judgment tables. Later cells read the
# 'cosine_similarity' and 'median_judgment' columns from both frames.
TRAIN_CSV = 'data/train_judgments.csv'
DEV_CSV = 'data/dev_judgments.csv'
df_train_set, df_dev_set = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, DEV_CSV))

In [None]:
import matplotlib.pyplot as plt

# Report how many training rows were loaded.
print(df_train_set.shape[0])
# Scatter the proximity judgment against the cosine similarity
# ("." draws point markers only, with no connecting line).
x_col, y_col = 'cosine_similarity', 'median_judgment'
plt.plot(df_train_set[x_col], df_train_set[y_col], ".")
plt.xlabel(x_col)
plt.ylabel(y_col)

In [None]:
import pickle
import cloudpickle
import arviz as az

# Load the pickled linear model bundle (a dict holding the model and its trace).
# SECURITY: unpickling can execute arbitrary code embedded in the file, so only
# load pickles that come from a trusted source.
pickle_filepath = 'pickle.pkl'  # plain literal; the former f-string had no placeholders
with open(pickle_filepath, 'rb') as buff:
    model_dict = cloudpickle.load(buff)

# Unpack the two entries used by the rest of the notebook.
trace = model_dict['trace']
model = model_dict['model']

# Display summary statistics of the trace (kind='stats' omits diagnostics).
az.summary(trace, kind='stats')

In [None]:
# Flatten the posterior group into a DataFrame, then display the covariance
# matrix of its columns rounded to three decimals.
trace_df = trace.posterior.to_dataframe()
posterior_cov = trace_df.cov()
posterior_cov.round(3)

In [None]:
# Overlay the posterior-mean regression line on the training scatter.
# The predictor is centred on its training-set mean before applying the slope.
xbar = df_train_set['cosine_similarity'].mean()
train_x = df_train_set['cosine_similarity']

# Collapse the posterior samples of intercept and slope to scalar means.
a_mean = trace.posterior["a"].mean().item(0)
b_mean = trace.posterior["b"].mean().item(0)

plt.plot(train_x, df_train_set['median_judgment'], ".")
plt.plot(train_x, a_mean + b_mean * (train_x - xbar))
plt.xlabel('cosine_similarity')
plt.ylabel('median_judgment')

In [None]:
import pymc as pm

# Posterior-predictive check on the dev set: thin the trace, rebuild the linear
# model over the dev predictor, draw predicted proximities, and plot their HDI
# band against the observed dev judgments.
samp_size = 1000  # target number of samples kept along the `draw` dimension
RANDOM_SEED = 42  # fixed seed so the predictive draws are reproducible on re-run

n_draws = trace.posterior.sizes["draw"]
# Clamp to 1: with fewer than `samp_size` draws the integer ratio is 0, and a
# zero slice step raises ValueError.
slice_rate = max(1, n_draws // samp_size)
thin_data = trace.sel(draw=slice(None, None, slice_rate))

# Centre the dev predictor on the TRAINING mean, not the dev mean: the
# regression-line cell applies the posterior `a`/`b` to train-centred
# similarity, so the intercept refers to that centre.
# NOTE(review): assumes the pickled model was fit with train-mean centring — confirm.
xbar = df_train_set['cosine_similarity'].mean()
# Plain NumPy array so the model graph does not depend on pandas index semantics.
dev_similarity = df_dev_set['cosine_similarity'].to_numpy()

with pm.Model() as WiC_predict:
    # priors — redeclared so the variable names line up with the loaded trace;
    # presumably they mirror the fitted model (TODO confirm).
    a = pm.Normal("a", mu=0, sigma=4)
    b = pm.Lognormal("b", mu=0, sigma=1)
    sigma = pm.Uniform("sigma", 0, 1)

    mu = a + b * (dev_similarity - xbar)
    # No `observed=`: "proximity" is the quantity being predicted for the dev rows.
    proximity = pm.Normal("proximity", mu=mu, sigma=sigma)

    proximity_pred = pm.sample_posterior_predictive(
        thin_data, var_names=['proximity'], random_seed=RANDOM_SEED
    )

# 89% HDI band of the predicted proximity, with the observed dev judgments on top.
az.plot_hdi(dev_similarity, proximity_pred.posterior_predictive["proximity"], hdi_prob=0.89)
plt.scatter(df_dev_set['cosine_similarity'], df_dev_set['median_judgment'])
plt.xlabel("cosine_similarity")
plt.ylabel("median_judgment")