In [1]:
import json
import sys
from collections import defaultdict

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.stats import linregress
from scipy.stats import spearmanr

In [68]:

# Load the experiment log and reduce it to one clipped, averaged influence
# score per target word, plus the binary / graded ground-truth labels.
log_path = "../outputs/checkpoints/vocab_3000-batch_size_256/scale_5000-recursion_depth_10000.json"
# Context manager so the file handle is closed as soon as parsing finishes.
with open(log_path) as log_file:
    outputs = json.load(log_file)

# Every record unpacks as
# (target_word, corpus, influence, binary_gt, grade_gt, n_removed).
influence_samples = defaultdict(list)  # target word -> clipped influence values
binary = {}                            # target word -> int(binary_gt)
grades = {}                            # target word -> float(grade_gt)
for target_word, corpus, influence, binary_gt, grade_gt, n_removed in outputs:
    binary[target_word] = int(binary_gt)
    grades[target_word] = float(grade_gt)

    # Clip: any influence that is not strictly below 1 is stored as exactly 1.
    # Disabled alternative for that case: influence / n_removed.
    influence_samples[target_word].append(influence if influence < 1 else 1.)

# Target word -> mean of its clipped influence values.
influences = {
    word: np.mean(samples) for word, samples in influence_samples.items()
}

In [69]:
# Sweep 10 evenly spaced thresholds between the smallest and largest mean
# influence, report binary-classification accuracy for each, and remember the
# first threshold that reaches the highest accuracy. Then report Spearman's
# rank correlation between the influences and the graded ground truth.
min_inf = min(influences.values())
max_inf = max(influences.values())

# These arrays do not depend on the threshold, so build them once up front
# instead of on every loop iteration.
keys = np.array(list(influences.keys()))
pred_inf = np.array([influences[k] for k in keys])
binary_gt = np.array([binary[k] for k in keys]) == 1

best_thres, best_acc = None, -np.inf
for thres in np.linspace(min_inf, max_inf, 10):
    print("==================================================")
    print(f"Threshold: {thres}")

    # A word is predicted positive when its influence is strictly above thres.
    correct = (pred_inf > thres) == binary_gt
    acc = np.mean(correct)
    if acc > best_acc:  # strict '>' keeps the earliest threshold on ties
        best_thres = thres
        best_acc = acc

    print(f"Accuracy: {acc:.2f}")

# NOTE: after the loop `thres` holds the LAST swept value (max_inf), not the
# best one; `best_thres` is the selected threshold.

inf = pred_inf.copy()
grad = np.array([grades[k] for k in keys])
corr = spearmanr(inf, grad)[0]
print(f"Spearman's correlation: {corr}")


Threshold: 0.008040632354095578
Accuracy: 0.48
Threshold: 0.01980238759683238
Accuracy: 0.70
Threshold: 0.03156414283956918
Accuracy: 0.64
Threshold: 0.04332589808230599
Accuracy: 0.52
Threshold: 0.05508765332504279
Accuracy: 0.55
Threshold: 0.0668494085677796
Accuracy: 0.58
Threshold: 0.0786111638105164
Accuracy: 0.58
Threshold: 0.0903729190532532
Accuracy: 0.58
Threshold: 0.10213467429599
Accuracy: 0.58
Threshold: 0.1138964295387268
Accuracy: 0.55
Spearman's correlation: 0.5986963145634498


In [70]:
# Scatter of mean influence (x) against graded ground truth (y) with a
# least-squares fitted line, saved as PNG and PDF.
# `linregress` comes from the notebook's top import cell.
res = linregress(inf, grad)

# Fixed x-range over which the fitted line is drawn.
fit_x = np.array([0, 0.15])
fit_y = res.intercept + res.slope * fit_x

fig = go.Figure()
fig.add_trace(go.Scatter(x=inf, y=grad, text=keys, showlegend=False, mode="markers"))
fig.add_trace(go.Scatter(x=fit_x, y=fit_y, showlegend=True, name="fitted line", mode="lines"))
fig.update_layout(template="plotly_white", width=400, height=300, margin=dict(t=10, l=50, b=50, r=10))
fig.update_layout(legend=dict(x=0.7, y=0.95))
fig.update_xaxes(showline=True, linewidth=2, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=False)
fig.update_xaxes(title=r'$\hat{d}_\text{MLM}$')
fig.update_yaxes(title="GT values")

fig.write_image("influences_grade_scatter.png")
fig.write_image("influences_grade_scatter.pdf")
fig.show()

In [75]:

# Two-panel bar chart of mean influence per target word: words labelled as
# semantically changed on the left, unchanged on the right, each sorted by
# descending influence with the selected threshold drawn as a dashed red line.
# Boolean-mask indexing below requires `keys` to be a NumPy array.
pred_inf = np.array([influences[k] for k in keys])
binary_gt = np.array([binary[k] for k in keys]) == 1

# Classify with the selected threshold. The sweep's loop variable `thres` is
# left at the last value tried, so it must not be used here; `best_thres` is
# also the value drawn in the panels below.
pred = pred_inf > best_thres
correct = pred == binary_gt


fig = make_subplots(rows=1, cols=2, shared_yaxes=True, subplot_titles=("w/ Semantic change", "w/o Semantic change"), y_title=r'$\hat{d}_\text{MLM}$')

# Column 1: ground-truth positives; column 2: ground-truth negatives.
for col, mask in enumerate((binary_gt, ~binary_gt), start=1):
    x = keys[mask]
    y = pred_inf[mask]
    idx = y.argsort()[::-1]  # indices for descending influence
    fig.add_trace(go.Bar(x=x[idx], y=y[idx], showlegend=False), row=1, col=col)
    fig.add_trace(go.Scatter(x=x[idx], y=np.ones(len(idx))*best_thres, mode="lines", line=dict(dash="dash", color="red"), showlegend=False), row=1, col=col)

fig.update_layout(template="plotly_white", width=700, height=400, margin=dict(t=30, l=70, b=50, r=10))
fig.update_xaxes(showline=True, linewidth=2, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=False)

fig.write_image("influences_binary.png")
fig.write_image("influences_binary.pdf")
fig.show()

In [76]:

# Regroup the raw records as corpus -> {target word -> clipped influence}.
# NOTE: this rebinds `influences`, which earlier cells use as a flat
# target word -> mean influence mapping.
influences = {
    "ccoha1.txt": {},
    "ccoha2.txt": {},
}
# Only three fields of each record are needed here. The unused ones go to
# underscore-prefixed names so the module-level `binary_gt` mask is not
# overwritten by a scalar from the last record.
for target_word, corpus, influence, _binary_label, _grade, _n_removed in outputs:
    # Clip: any influence that is not strictly below 1 is stored as exactly 1.
    # A record whose corpus is not one of the two keys above raises KeyError.
    influences[corpus][target_word] = influence if influence < 1 else 1.


In [83]:
# Grouped bar chart comparing each target word's clipped influence in the two
# corpora; a word missing from a corpus is drawn with height 0.
fig = go.Figure()

# Sort the union of target words: iterating a bare set of strings gives an
# order that can change between kernel sessions (string hash randomisation),
# which would reshuffle the bars on every re-run.
keys = sorted(set(influences["ccoha1.txt"]) | set(influences["ccoha2.txt"]))

# (corpus file, legend label) for each bar series, in drawing order.
for corpus_name, period_label in (("ccoha1.txt", "1810-1860"), ("ccoha2.txt", "1960-2010")):
    values = np.array([influences[corpus_name].get(k, 0.) for k in keys])
    fig.add_trace(go.Bar(x=keys, y=values, name=period_label))


fig.update_layout(template="plotly_white", width=600, height=400, margin=dict(t=30, l=50, b=50, r=10))
fig.update_xaxes(showline=True, linewidth=2, linecolor='black', mirror=False)
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', mirror=False)
fig.update_layout(legend=dict(orientation="h", x=0.05, y=0.93))
fig.update_yaxes(title=r'$\Delta L_\text{MLM}(\hat{\theta}, .)$')

fig.write_image("influences_by_corpus.png")
fig.write_image("influences_by_corpus.pdf")
fig.show()
