# Purpose
Examine how the simulated acc/sse data relate to word frequency, modeled as a sigmoid function of frequency. See Bry

In [None]:
import os

import altair as alt
import ipywidgets as widgets
import numpy as np
import pandas as pd
import wordfreq
from scipy.optimize import curve_fit
from scipy.special import expit

In [None]:
# Pull data from BQ

# Set up the BigQuery client that load_raw_data() uses.
from google.cloud import bigquery
# The credentials path must be in the environment before the client is created.
# NOTE(review): hard-coded, machine-specific path to a credentials JSON file;
# confirm it exists on the machine running this notebook.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"]="/home/jupyter/tf/secret/majestic-camp-303620-e8cb3a12037b.json"
client = bigquery.Client(location="US", project="majestic-camp-303620")

def load_raw_data():
    """Fetch averaged training results from BigQuery as a DataFrame.

    The query averages ``wf``, ``acc`` and ``sse`` within each
    (epoch, sample, word) group of ``slow_op_10.train``, keeping only rows
    where ``unit_time`` equals 4.0. It is run through the module-level
    ``client`` and the job result is returned via ``to_dataframe()``.
    """
    sql = """
    SELECT 
        epoch,
        sample,
        word,
        AVG(wf) AS wf ,  
        AVG(acc) AS acc, 
        AVG(sse) AS sse, 
    FROM 
        slow_op_10.train
    WHERE 
        unit_time=4.0
    GROUP BY
        epoch,
        sample,
        word;
    """
    return client.query(sql).to_dataframe()

# df = load_raw_data()
# df.to_csv("op10_ave_results.csv")

In [None]:
# Load the cached query results and rename the raw "wf" column to "wf_dynamic".
df = pd.read_csv("op10_ave_results.csv").rename(columns={'wf': 'wf_dynamic'})

In [None]:
# Get OP measure (unconditional surprisal)
op = (
    pd.read_csv('noam/supplementary_material.csv')[['word', 'uncond.surprisal']]
    .rename(columns={'uncond.surprisal': 'op'})
)
# Left join on word: every row of df is kept; words missing from the OP
# table end up with NaN in the "op" column.
df = df.merge(op, on='word', how='left')

In [None]:
# Obsolete, calculate Zipf from WSJ frequency

# df_train = pd.read_csv("../../dataset/df_train.csv")
# df_train = df_train[['word', 'wf', 'img']]
# df_train.rename({'wf':'wf_wsj'}, axis=1, inplace=True)
# df = df.merge(df_train, 'left', 'word')

# df['zipf_wsj'] = np.log10((df.wf_wsj/1000) + 1)
# df['log_wf_wsj'] = np.log10(df.wf_wsj+1)
# df['log_wf_dynamic'] = np.log10(df.wf_dynamic+1)

- When converting WSJ frequency to Zipf, the resulting Zipf range is 0-3.4, which is a bit off from the usual range of 0-7; perhaps the raw WSJ data is not on a wpm scale.
- To get the Zipf scale, I used the [wordfreq](https://github.com/LuminosoInsight/wordfreq/) library, which is based on [exquisite-corpus](https://github.com/LuminosoInsight/exquisite-corpus), an aggregation of corpora from Wikipedia, SUBTLEX, News, Books, Web, Twitter, Reddit, and MISC content.

In [None]:
def get_zipf(x):
    """Return the English Zipf-scale frequency of ``x`` according to wordfreq.

    ``x`` is converted with ``str()`` before the lookup, and ``minimum=0``
    is passed through to ``wordfreq.zipf_frequency``.
    """
    word = str(x)
    return wordfreq.zipf_frequency(word, lang='en', minimum=0)

def get_wf(x):
    """Return the English word frequency of ``x`` according to wordfreq.

    ``x`` is converted with ``str()`` before the lookup, and ``minimum=0``
    is passed through to ``wordfreq.word_frequency``.
    """
    word = str(x)
    return wordfreq.word_frequency(word, lang='en', minimum=0)

# Attach both wordfreq-based measures to every row via its word.
df['zipf'] = df['word'].apply(get_zipf)
df['wf'] = df['word'].apply(get_wf)

# Peek at 1M sample
df[df['epoch'] == 100][['wf_dynamic', 'wf', 'zipf', 'op']].describe()

In [None]:
# Save the augmented frame. index=False is not passed, so the DataFrame index
# is written as the first CSV column.
df.to_csv("parsed_df.csv")

## Plot acc vs. frequency measures

In [None]:
# Both Zipf sliders share the same 0-8 range so that the default max_zipf=8
# is reachable and words with Zipf above 7 can be included.
@widgets.interact(x_var=['wf_dynamic', "wf", "zipf"],
                  epoch=(10, 100, 10),
                  min_op=(0, 10, 0.1),
                  max_op=(0, 10, 0.1),
                  min_zipf=(0, 8, 0.01),
                  max_zipf=(0, 8, 0.01),
                  loess_bandwidth=(0, 1, 0.1))
def plot_exploratory(x_var="zipf", epoch=100, min_op=0, max_op=0, min_zipf=0, max_zipf=8, loess_bandwidth=0.3):
    """Interactive scatter of accuracy against a frequency measure with a LOESS trend.

    Args:
        x_var: Column of ``df`` plotted on the x axis.
        epoch: Only rows whose ``epoch`` equals this value are shown.
        min_op, max_op: Inclusive bounds on the ``op`` column.
        min_zipf, max_zipf: Inclusive bounds on the ``zipf`` column.
        loess_bandwidth: Bandwidth passed to Altair's ``transform_loess``.

    Returns:
        A layered Altair chart (points plus a red LOESS line) titled with the
        active filter settings.
    """
    # NOTE(review): the default max_op=0 together with min_op=0 keeps only rows
    # with op == 0, so the initial view is probably empty until the slider is
    # moved; confirm whether the default should be the slider maximum (10).
    in_range = (
        (df['epoch'] == epoch)
        & (df['op'] >= min_op)
        & (df['op'] <= max_op)
        & (df['zipf'] >= min_zipf)
        & (df['zipf'] <= max_zipf)
    )
    subset = df.loc[in_range]

    title = f'Epoch: {epoch}; OP surprisal within: [{min_op}, {max_op}]; Zipf within: [{min_zipf}, {max_zipf}]'

    # Cap the number of plotted points with a random subsample of 1000 rows.
    if len(subset) > 1000:
        subset = subset.sample(1000)

    points = alt.Chart(subset).encode(
        x=x_var,
        y=alt.Y("acc", scale=alt.Scale(domain=(0, 1))),
        tooltip=["word", "wf", "wf_dynamic", "zipf", "op"],
    ).mark_point()
    trend = points.transform_loess(x_var, 'acc', bandwidth=loess_bandwidth).mark_line(color='red')

    return (points + trend).properties(title=title)

## Curve fit

### Fit accuracy to 2-PL IRT like equation

$P(X=1|\theta, a, b)= \frac{e^{(a(\theta -b))}}{1+e^{(a(\theta -b))}}$

where 
$\theta$: frequency (zipf scale)

$a$: max slope (IRT: discriminability)

$b$: x-shift (IRT: difficulty)

In [None]:
def f2pl(theta, a, b):
    """2PL equation: the logistic function of ``a * (theta - b)``.

    Args:
        theta: Predictor value(s), scalar or array-like.
        a: Slope (IRT: discriminability).
        b: Shift along the x axis (IRT: difficulty).

    Returns:
        ``exp(x) / (1 + exp(x))`` with ``x = a * (theta - b)``.
    """
    # expit is the same logistic function, but it does not overflow to
    # inf / inf = nan when x is large, which would break curve fitting.
    return expit(a * (theta - b))

In [None]:
# Fit the 2PL curve to accuracy as a function of Zipf frequency, using only
# the rows from epoch 10.
tmp = df[df['epoch'] == 10]

params, _ = curve_fit(f2pl, tmp['zipf'], tmp['acc'])

### Fit sse to 4-PL IRT like equation

$P(X=1|\theta, a, b, c, d)= c + (d-c) \frac{e^{(a(\theta -b))}}{1+e^{(a(\theta -b))}}$

where 
$\theta$: frequency (zipf scale)

$a$: max slope (IRT: discriminability)

$b$: x-shift (IRT: difficulty)

$c$: lower asymptote

$d$: upper asymptote

In [None]:
def f4pl(theta, a, b, c, d):
    """4PL equation: a logistic curve rescaled to lie between ``c`` and ``d``.

    Args:
        theta: Predictor value(s), scalar or array-like.
        a: Slope (IRT: discriminability).
        b: Shift along the x axis (IRT: difficulty).
        c: Lower asymptote.
        d: Upper asymptote.

    Returns:
        ``c + (d - c) * exp(x) / (1 + exp(x))`` with ``x = a * (theta - b)``.
    """
    # expit is the same logistic function, but it does not overflow to
    # inf / inf = nan when x is large, which would break curve fitting.
    return c + (d - c) * expit(a * (theta - b))