In [None]:
import math
from typing import List
import warnings

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import string
import spacy

# Suppress every warning of category RuntimeWarning from this point on.
warnings.filterwarnings(action='ignore', category=RuntimeWarning)

In [None]:
def plot_surprisals(words:List[str], surprisals_df):
    """Plot one learning curve (mean surprisal against training step) per word.

    Parameters
    ----------
    words : List[str]
        Tokens to plot; each is matched exactly against the 'Token' column.
    surprisals_df : pd.DataFrame
        Must provide the columns 'Token', 'Steps' and 'MeanSurprisal'.

    Notes
    -----
    * The first matching row of every word is dropped before plotting
      (``.iloc[1:]``).
      NOTE(review): presumably this removes the step-0 row, which cannot be
      drawn on a log axis; that only holds if the frame is sorted by 'Steps'.
    * A word with no rows left after that is reported on stdout and gets no
      panel, so the grid never contains blank axes in the middle.
    * Nothing is drawn (and no figure is created) when no word has data,
      including when ``words`` is empty.
    * The y axis is inverted, so lower surprisal is drawn higher up.
    """
    # Resolve the data first so the grid is sized by what will actually be drawn.
    plottable = []
    for word in words:
        word_data = surprisals_df[surprisals_df['Token'] == word].iloc[1:]
        if word_data.empty:
            print(f'No data found for the word "{word}"')
            continue
        plottable.append((word, word_data))

    if not plottable:
        return

    cols = min(len(plottable), 3)
    rows = math.ceil(len(plottable) / cols)

    plt.style.use('ggplot')
    # squeeze=False always yields a 2-D array of axes, whatever rows/cols are.
    fig, axs = plt.subplots(rows, cols, figsize=(cols*3, rows*3), squeeze=False)
    flat_axes = axs.flatten()

    for ax, (word, word_data) in zip(flat_axes, plottable):
        ax.plot(word_data['Steps'], word_data['MeanSurprisal'], marker='o')
        ax.set_title(f'"{word}"')
        ax.set_xlabel('Steps (log10)')
        ax.set_ylabel('Mean surprisal')
        ax.set_xscale('log')
        ax.invert_yaxis()

    # Remove the unused trailing cells of the grid.
    for unused_ax in flat_axes[len(plottable):]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    plt.show()


In [None]:
# Locations of the two tab-separated surprisal tables; only the WikiText one is read here.
wikitext_surprisals = 'sample_data/wikitext/bert_surprisals_large.txt'
chang_bergen_surprisals = 'r_code/tacl_data/lm_data/bert_surprisals.txt'

surprisals = pd.read_csv(wikitext_surprisals, sep='\t')

In [None]:
# Order rows by token and step so that a per-token diff compares consecutive steps.
# Rows without a token are dropped explicitly (grouping by 'Token' would
# silently discard them anyway).
surprisals = (surprisals
    .dropna(subset=['Token'])
    .sort_values(['Token', 'Steps'])
    .reset_index(drop=True))

# Step-to-step change within each token; a token's first row has no
# predecessor and gets 0. Grouped .diff() is used instead of groupby.apply,
# which only kept the 'Token' column while apply still operated on the
# grouping column (deprecated in pandas 2.2).
_per_token = surprisals.groupby('Token')
surprisals = surprisals.assign(
    MeanSurprisalDiff=_per_token['MeanSurprisal'].diff().fillna(0),
    StdevSurprisalDiff=_per_token['StdevSurprisal'].diff().fillna(0))

# Put 'Token' first, keeping the remaining columns in their existing order.
surprisals = surprisals[['Token'] + [col for col in surprisals.columns if col != 'Token']]

# Keep only ASCII tokens that are not made up purely of digits.
surprisals = surprisals[surprisals['Token'].apply(lambda t: t.isascii() and not t.isdigit())]
surprisals

In [None]:
# Number of distinct tokens in the surprisal table.
surprisals['Token'].nunique()

In [None]:
import random

# Draw two *different* tokens, neither of them 'walk', so that no two panels
# show the same curve. random.sample raises ValueError if fewer than two
# candidate tokens exist.
_token_pool = [tok for tok in surprisals['Token'].unique().tolist() if tok != 'walk']
random_1, random_2 = random.sample(_token_pool, 2)
plot_surprisals(['walk', random_1, random_2], surprisals)

### Most vs least frequent words

In [None]:
# Rows whose token has exactly 512 examples vs. exactly one example.
# NOTE(review): 512 looks like a cap on examples per token, so "== 512"
# would mean "at least 512" - confirm against the data.
is_frequent = surprisals['NumExamples'] == 512
is_infrequent = surprisals['NumExamples'] == 1
frequent = surprisals[is_frequent]
infrequent = surprisals[is_infrequent]

# NOTE(review): both ratios are computed over rows (one per token and step),
# not over distinct tokens.
total_rows = len(surprisals)
print(f'Percentage of frequent words in the dataset (occuring at least 512 times): {len(frequent)/total_rows*100:.2f}%')
print(f'Percentage of infrequent words in the dataset (occuring only once): {len(infrequent)/total_rows*100:.2f}%')

# One randomly chosen token from each group, plotted side by side.
one_frequent = frequent['Token'].drop_duplicates().sample(1).tolist()
one_infrequent = infrequent['Token'].drop_duplicates().sample(1).tolist()
plot_surprisals(one_frequent + one_infrequent, surprisals)

In [None]:
# Mean surprisal over all tokens at every training step.
avg_surprisals = (
    surprisals
    .groupby('Steps')['MeanSurprisal']
    .mean()
    .reset_index()
)

# Upper y-limit is the minimum and lower y-limit the maximum (first step
# ignored), so the axis is inverted: lower surprisal is drawn higher up.
mean_by_step = avg_surprisals['MeanSurprisal']
ax = avg_surprisals.plot(x='Steps', y='MeanSurprisal', logx=True)
ax.set_ylim(mean_by_step.iloc[1:].max(), mean_by_step.min())

In [None]:
# Constant added to every step count before plotting on the log-scaled x axis.
# NOTE(review): presumably this keeps a step-0 checkpoint drawable, since
# log(0) is undefined - confirm.
STEP_OFFSET = 10000


def _mean_surprisal_by_step(rows):
    """Average 'MeanSurprisal' per step, plus the step-to-step change in 'Diffs'."""
    return (rows.groupby('Steps')['MeanSurprisal']
                .mean()
                .reset_index()
                .assign(Diffs=lambda x: x['MeanSurprisal'].diff().fillna(0)))


def _print_summary(heading, avg):
    """Print min / max / mean step-to-step change for one averaged curve."""
    print(heading)
    print(f"Min surprisal: {avg['MeanSurprisal'].min():.2f}")
    print(f"Max surprisal (excluding the first step): {avg['MeanSurprisal'].iloc[1:].max():.2f}")
    # iloc[2:] skips the 0 filled in for the first row and the change coming
    # out of the first step.
    print(f"Average variability (excluding the first step): {avg['Diffs'].iloc[2:].mean():.2f}")


# for each step, average surprisal across the frequent words (NumExamples == 512)
avg_freq_surprisals = _mean_surprisal_by_step(frequent)

# for each step, average surprisal across the infrequent words (NumExamples == 1)
avg_infreq_surprisals = _mean_surprisal_by_step(infrequent)

plt.figure()
plt.xscale('log')

for curve_label, avg in (('Frequent Words', avg_freq_surprisals),
                         ('Infrequent Words', avg_infreq_surprisals)):
    plt.plot(avg['Steps'] + STEP_OFFSET, avg['MeanSurprisal'], marker='o', label=curve_label)

# Inverted y axis: the largest mean surprisal at the bottom, 0 at the top.
plt.ylim(max(avg_freq_surprisals['MeanSurprisal'].max(), avg_infreq_surprisals['MeanSurprisal'].max()), 0)
plt.xlabel('Steps (log 10)')
plt.ylabel('Mean Surprisal')
plt.legend()

_print_summary('Frequent words (>= 512 examples):', avg_freq_surprisals)
print()
_print_summary('Infrequent words (= 1 example):', avg_infreq_surprisals)

plt.show()

In [None]:
# plot_surprisals(bert_surprisals['Token'].drop_duplicates().tolist(), bert_surprisals)

### Words with different POS

In [None]:
def get_pos_tags(doc_path, model='en_core_web_sm', max_length=2000000):
    """Collect every part-of-speech tag spaCy assigns to each token text in a file.

    Parameters
    ----------
    doc_path : str or path-like
        UTF-8 text file. It is read in full and processed as a single spaCy
        document.
    model : str, optional
        Name of the spaCy pipeline passed to ``spacy.load``.
    max_length : int, optional
        Value assigned to ``nlp.max_length`` before the text is processed.

    Returns
    -------
    pd.DataFrame
        One row per distinct ``token.text`` (case-sensitive) with columns
        'Token' and 'POS'. 'POS' holds the list of distinct ``token.pos_``
        values seen for that text, in order of first appearance.
    """
    nlp = spacy.load(model)
    nlp.max_length = max_length

    with open(doc_path, 'r', encoding='utf-8') as file:
        text = file.read()

    pos_dict = {}
    for token in nlp(text):
        # Accumulate tags; a repeated tag must never reset the list built so far.
        seen_tags = pos_dict.setdefault(token.text, [])
        if token.pos_ not in seen_tags:
            seen_tags.append(token.pos_)

    return pd.DataFrame(list(pos_dict.items()), columns=['Token', 'POS'])

In [None]:
def plot_avg(dfs: List[pd.DataFrame], labels=None, step_offset=10000):
    """Overlay the per-step average surprisal curve of several DataFrames.

    Parameters
    ----------
    dfs : List[pd.DataFrame]
        Each frame needs the columns 'Steps', 'MeanSurprisal' and 'POS'
        (a list of tags per row). Empty frames are reported and skipped.
    labels : list of str, optional
        Legend name for each frame, in the same order as ``dfs``. When
        omitted, the first tag of the frame's first-step row is used, so two
        frames with the same tag get identical names.
    step_offset : int, optional
        Constant added to the step counts before plotting on the log axis.
        NOTE(review): presumably keeps a step-0 checkpoint drawable - confirm.

    Raises
    ------
    ValueError
        If ``labels`` is given but its length differs from ``dfs``.

    Notes
    -----
    Each legend entry also shows "var": the mean of the step-to-step changes
    of the averaged curve, skipping its first two entries. The y axis is
    inverted (lower surprisal is drawn higher up).
    """
    if labels is not None and len(labels) != len(dfs):
        raise ValueError('labels must contain exactly one entry per DataFrame')

    plt.figure()
    max_y = 0
    min_y = float('inf')
    plotted_any = False
    for position, df in enumerate(dfs):
        if df.empty:
            print(f'No data to plot for DataFrame #{position}')
            continue

        avg = (df.groupby('Steps')
                 .agg({'MeanSurprisal': 'mean', 'POS': 'first'})
                 .reset_index()
                 .assign(Diffs=lambda x: x['MeanSurprisal'].diff().fillna(0)))

        name = labels[position] if labels is not None else avg['POS'].values[0][0]
        plt.plot(avg['Steps'] + step_offset, avg['MeanSurprisal'],
                 label=f"{name} (var: {avg['Diffs'].iloc[2:].mean():.2f})")
        max_y = max(max_y, avg['MeanSurprisal'].max())
        min_y = min(min_y, avg['MeanSurprisal'].min())
        plotted_any = True

    if plotted_any:
        # Inverted axis with one unit of head-room above the best (lowest) value.
        plt.ylim(max_y, min_y - 1)
    plt.xlabel('Steps (log 10)')
    plt.ylabel('Mean Surprisal')
    plt.xscale('log')
    plt.legend()
    plt.show()

In [None]:
# Tag the WikiText-103 test split; pos_tags has one row per distinct token
# text with the list of POS tags returned for it by get_pos_tags.
document = "sample_data/wikitext/wikitext103_test.txt"
pos_tags = get_pos_tags(document)
pos_tags

In [None]:
# Attach the POS lists to the surprisal rows; the inner join keeps only
# tokens present in both tables.
merged_df = surprisals.merge(pos_tags, on='Token', how='inner')
merged_df

In [None]:
# Every distinct tag that occurs in any row's POS list (non-list cells are ignored).
all_pos_tags = {
    pos
    for pos_list in merged_df['POS']
    if isinstance(pos_list, list)
    for pos in pos_list
}
all_pos_tags

In [None]:
def _count_tokens_tagged(tag):
    """Number of distinct tokens in merged_df whose POS list contains `tag`."""
    has_tag = merged_df['POS'].apply(lambda pos_list: tag in pos_list)
    return merged_df[has_tag]['Token'].nunique()


# A token with several tags is counted once under each of them.
num_nouns = _count_tokens_tagged('NOUN')
num_verbs = _count_tokens_tagged('VERB')
num_adjs = _count_tokens_tagged('ADJ')
num_advs = _count_tokens_tagged('ADV')

print(f"Total number of nouns: {num_nouns}")
print(f"Total number of verbs: {num_verbs}")
print(f"Total number of adjectives: {num_adjs}")
print(f"Total number of adverbs: {num_advs}")

In [None]:
def _rows_with_single_tag(tag):
    """Rows of merged_df whose POS list has exactly one entry, namely `tag`."""
    only_this_tag = merged_df['POS'].apply(
        lambda pos_list: len(pos_list) == 1 and tag in pos_list)
    return merged_df[only_this_tag]


def _three_random_tokens(rows):
    """Three distinct tokens drawn at random from `rows` (raises if fewer exist)."""
    return rows['Token'].drop_duplicates().sample(3).tolist()


# Tokens that were only ever tagged with one part of speech.
exclusive_noun = _rows_with_single_tag('NOUN')
exclusive_verb = _rows_with_single_tag('VERB')
adj = _rows_with_single_tag('ADJ')
adv = _rows_with_single_tag('ADV')

noun_sample = _three_random_tokens(exclusive_noun)
verb_sample = _three_random_tokens(exclusive_verb)
adj_sample = _three_random_tokens(adj)
adv_sample = _three_random_tokens(adv)

# Twelve panels: nouns, then verbs, adjectives and adverbs.
words_to_plot = noun_sample + verb_sample + adj_sample + adv_sample
plot_surprisals(words_to_plot, merged_df)

In [None]:
# Average learning curve of each single-tag group (verbs, nouns, adverbs, adjectives).
plot_avg([exclusive_verb, exclusive_noun, adv, adj])

In [None]:
# All single-tag verbs vs. the rows for the token 'walk' alone (first curve = all verbs).
# NOTE(review): plot_avg names curves after the first POS tag, so both legend
# entries start with the same tag here.
plot_avg([exclusive_verb, exclusive_verb[exclusive_verb['Token'] == 'walk']])

### Top 5 largest and smallest absolute diffs

In [None]:
# The five rows with the largest change in mean surprisal, regardless of sign.
abs_mean_diff = merged_df['MeanSurprisalDiff'].abs()
top5_row_labels = abs_mean_diff.nlargest(5).index
largest_abs_diffs = merged_df.loc[top5_row_labels]
largest_abs_diffs

In [None]:
# The five rows with the smallest absolute change, ignoring step 0.
# NOTE(review): step-0 rows presumably carry the 0 filled in for a token's
# first row, which would otherwise dominate this list - confirm.
nonzero_step_rows = merged_df[merged_df['Steps'] != 0]
smallest_abs_diffs = (
    nonzero_step_rows
    .sort_values(by='MeanSurprisalDiff', key=abs)
    .head(5)
)
smallest_abs_diffs

In [None]:
# Five largest (most positive) MeanSurprisalDiff rows at every step.
# Sorting and then taking the head of each group keeps the 'Steps' column
# without relying on groupby.apply passing the grouping column through
# (deprecated in pandas 2.2).
largest_diffs = (merged_df
    .sort_values(['Steps', 'MeanSurprisalDiff'], ascending=[True, False])
    .groupby('Steps')
    .head(5)
    .reset_index(drop=True))
largest_diffs[largest_diffs['Steps'] != 0]

In [None]:
# Five smallest (most negative) MeanSurprisalDiff rows at every step.
# Sorting and then taking the head of each group keeps the 'Steps' column
# without relying on groupby.apply passing the grouping column through
# (deprecated in pandas 2.2).
smallest_diffs = (merged_df
    .sort_values(['Steps', 'MeanSurprisalDiff'], ascending=[True, True])
    .groupby('Steps')
    .head(5)
    .reset_index(drop=True))
smallest_diffs[smallest_diffs['Steps'] != 0]

In [None]:
# For every step, the single row with the largest absolute MeanSurprisalDiff.
# This rebinds `largest_abs_diffs`. The result keeps merged_df's own row
# labels (no extra 'Steps' index level) and is ordered by step.
_abs_diff_rows = (merged_df
    .dropna(subset=['MeanSurprisalDiff'])
    .assign(AbsDiff=lambda d: d['MeanSurprisalDiff'].abs()))
_peak_row_labels = _abs_diff_rows.groupby('Steps')['AbsDiff'].idxmax().tolist()
largest_abs_diffs = merged_df.loc[_peak_row_labels]
largest_abs_diffs[largest_abs_diffs['Steps'] != 0]

In [None]:
# For every step, the single row with the smallest absolute MeanSurprisalDiff.
# This rebinds `smallest_abs_diffs`. The result keeps merged_df's own row
# labels (no extra 'Steps' index level) and is ordered by step.
_abs_diff_rows = (merged_df
    .dropna(subset=['MeanSurprisalDiff'])
    .assign(AbsDiff=lambda d: d['MeanSurprisalDiff'].abs()))
_flattest_row_labels = _abs_diff_rows.groupby('Steps')['AbsDiff'].idxmin().tolist()
smallest_abs_diffs = merged_df.loc[_flattest_row_labels]
smallest_abs_diffs[smallest_abs_diffs['Steps'] != 0]