In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from matplotlib import pyplot as plt

sns.set(rc={'figure.figsize':(8,5),
            "font.size":16,
            "axes.titlesize":16,
            "axes.labelsize":16,
            "xtick.labelsize": 16.0,
            "ytick.labelsize": 16.0,
            "legend.fontsize": 16.0})

import torch
import torch.nn.functional as F
from fairseq.modules.adaptive_input import AdaptiveInput
from fairseq.modules.adaptive_softmax import AdaptiveSoftmax

pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 199)

from itertools import islice
import csv
import os
from collections import Counter
from random import shuffle, sample
from tqdm.notebook import tqdm
tqdm.pandas()

import wptools
from nltk.corpus import wordnet 
# import nltk
# nltk.download('wordnet')
import warnings
warnings.filterwarnings('ignore')

## Load WikiLM's Vocabulary

In [10]:
model_dir = "adaptive_lm_wiki103.v2/"
vocab = [['<s>', -1], ['<pad>', -1], ['</s>', -1], ['<unk>', -1]]
dict_path = os.path.join(model_dir, "dict.txt")
with open(dict_path) as fd:
    vocab.extend([line.strip('\n').split(' ') for line in fd.readlines()])
tok_to_idx = {tok[0]: idx for idx, tok in enumerate(vocab)}

## Load Data and Compute Saturation Events

In [2]:
pkl_path = "dim_coef_with_vectors_10000sents.pkl"
df = pd.read_pickle(pkl_path)

In [3]:
def get_prefix_target(row):
    tokens = row["text_"].split(" ")
    prefix = " ".join(tokens[:row["random_pos"] + 1])
    
    if row["random_pos"] + 1 < len(tokens):
        target = tokens[row["random_pos"] + 1]
    else:
        target = "</>"
    
    return prefix, target

df["text_"] = df.text.apply(lambda x: "<s> " + x)
df[["prefix", "target"]] = df.apply(lambda x: pd.Series(get_prefix_target(x)), axis=1)

In [7]:
df["residual_top_tokens"] = df.residual_top_tokens_idx.apply(lambda x: [
    [vocab[token_idx][0] for token_idx in layer_token_idx]
    for layer_token_idx in x
])
df[[f'residual_top_tokens_{i}' for i in range(16)]] = pd.DataFrame(df.residual_top_tokens.tolist(), index= df.index)
df[[f'residual_top_tokens_prob_{i}' for i in range(16)]] = pd.DataFrame(df.residual_top_tokens_prob.tolist(), index= df.index)
df["layer_top_tokens"] = df.layer_output_top_tokens_idx.apply(lambda x: [
    [vocab[token_idx][0] for token_idx in layer_token_idx]
    for layer_token_idx in x
])
df[[f'layer_top_tokens_{i}' for i in range(16)]] = pd.DataFrame(df.layer_top_tokens.tolist(), index= df.index)
df[[f'layer_top_tokens_prob_{i}' for i in range(16)]] = pd.DataFrame(df.layer_output_top_tokens_prob.tolist(), index= df.index)

df["layer_preds"] = df.layer_output_top_tokens_idx.apply(lambda x: [vocab[layer_preds[0]][0] for layer_preds in x])

In [8]:
def get_top_coeffs(coefs, k):
    max_coef_dims = []
    max_coef_vals = []
    for l in range(16):
        arr = np.array(coefs[l][:-1])
        ind = np.argpartition(arr, -k)[-k:]

        max_coef_dims_l = ind[np.argsort(arr[ind])][::-1]
        max_coef_vals_l = [round(arr[i], 4) for i in max_coef_dims_l]
        
        max_coef_dims.append(max_coef_dims_l)
        max_coef_vals.append(max_coef_vals_l)
    
    return max_coef_dims, max_coef_vals

for k in [1, 3, 10]:
    df[[f"max{k}_coef_dims", f"max{k}_coef_vals"]] = df[f"coeffs_vals"].apply(lambda x: pd.Series(get_top_coeffs(x, k)))

df[[f'max10_coef_dims_{i}' for i in range(16)]] = pd.DataFrame(df.max10_coef_dims.tolist(), index= df.index)
df[[f'max10_coef_vals_{i}' for i in range(16)]] = pd.DataFrame(df.max10_coef_vals.tolist(), index= df.index)

In [11]:
def get_pred_fix_event_info(row):
    layer_preds = row["layer_preds"]
    final_pred = layer_preds[-1]
    j = tok_to_idx[final_pred]
    
    # Find the layer where the model outputs its final prediction.
    i = 0
    while i < len(layer_preds):
        if layer_preds[len(layer_preds)-1-i] != final_pred:
            break
        i += 1
    l = 16 - i   # l is between 0 and 15.
    
    # Get the candidate tokens at that layer -- the tokens that the predicted token eliminated.
    if final_pred in row[f"residual_top_tokens_{l}"]:
        final_pred_idx_in_res = row[f"residual_top_tokens_{l}"].index(final_pred)
    else:
        final_pred_idx_in_res = len(row[f"residual_top_tokens_{l}"])
    ts = row[f"residual_top_tokens_{l}"][:final_pred_idx_in_res]
    ts_ids = [
        tok_to_idx[tok]
        for tok in ts
    ]
    
    return (l, final_pred_idx_in_res, ts, ts_ids, j, final_pred)


df["pred_fix_event_info"] = df.apply(
    lambda row: get_pred_fix_event_info(row),
    axis=1
)

In [12]:
my_df = df.copy()
my_df['pred_fix_event_info_'] = 0
my_df['val_to_cluster'] = None
my_df['val_to_cluster_real'] = None
my_df['val_to_m'] = None
for index, row in my_df.iterrows():
    my_df.at[index,'pred_fix_event_info_'] = int(row['pred_fix_event_info'][0])

## Organize Dataset, Split, and Perform Training

In [13]:
from sklearn import datasets, svm, metrics
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve

def organize_df_for_model(df, vec_type = 'sum', train=True):
    vecs_data = {i:[] for i in range(num_layers)}
    for idx in range(len(df)):
        if vec_type=='sum':
            curr_vecs = df.at[idx, 'residual_vectors'] + df.at[idx, 'ffn_output_vectors']
        else:
            curr_vecs = df.at[idx, vec_type]
        curr_fixation_layer = df.at[idx,'pred_fix_event_info_']
        for l in range(num_layers):
            vec = np.squeeze(np.array(curr_vecs[l]))
            if l == curr_fixation_layer:
                vecs_data[l].append((vec,1))
            else:
                vecs_data[l].append((vec,0))

    return vecs_data

num_layers = 16
saved_layers_results = {'residual_vectors':[],'ffn_output_vectors':[], 'sum':[]}
accuracy_results = {'residual_vectors':[],'ffn_output_vectors':[], 'sum':[]}
for random_seed in tqdm([38,39,40,41,42]):
    train_df, validate_df, test_df = \
              np.split(my_df.sample(frac=1, random_state=random_seed), 
                       [int(.8*len(my_df)), int(0.9*len(my_df))])
    train_df = train_df.reset_index()
    validate_df = validate_df.reset_index()
    test_df = test_df.reset_index()
    for feature_vec in tqdm(['residual_vectors','ffn_output_vectors', 'sum'], leave=False):
        train_data = organize_df_for_model(train_df, vec_type=feature_vec)
        validation_data = organize_df_for_model(validate_df, train=False, vec_type=feature_vec)
        test_data= organize_df_for_model(test_df, train=False, vec_type=feature_vec)
        best_scores = []
        best_estimators = []
        scalers = []
        for l in tqdm(range(num_layers-1), leave=False):
            sc = StandardScaler()
            X = train_data[l]
            X_train, y_train = [x[0] for x in X], [x[1] for x in X]
            model = LogisticRegression(solver='liblinear')
            X_train = sc.fit_transform(X_train)
            scalers.append(sc)
            grid={"C":np.logspace(-3,3,7), "penalty":["l1","l2"],"class_weight":['balanced']}# l1 lasso l2 ridge
            logreg=LogisticRegression()
            logreg_cv=GridSearchCV(logreg,grid,cv=8)
            logreg_cv.fit(X_train,y_train)

            best_scores.append(logreg_cv.best_score_)
            best_estimators.append(logreg_cv.best_estimator_)
        valid_data_probas = []
        threshes = []
        for l in range(num_layers-1):
            sc = StandardScaler()
            valid_data_scaled_l = np.array(sc.fit_transform(np.array([x[0] for x in validation_data[l]])))
            valid_labels = [x[1] for x in validation_data[l]]
            model = best_estimators[l]
            probas = model.predict_proba(valid_data_scaled_l)[:,1]
            precision, recall, thresholds = precision_recall_curve(valid_labels, probas)
            f1 = 2*precision*recall/(precision+recall)
            threshes.append(thresholds[np.where(precision>=0.75)[0][0]-1])
        test_data_scaled = []
        for l in range(num_layers-1):
            test_data_scaled.append(scalers[l].transform(np.array([x[0] for x in test_data[l]])))
        to_analyize = []
        df_to_test = test_df.copy()[:].reset_index()
        cnt = 0
        tot = 0
        df_to_test['final_pred_rank'] = 0
        df_to_test['pred_freq'] = 0
        df_to_test['correct_early_exit'] = False
        df_to_test['correct_pred'] = False
        df_to_test['saved_layers'] = 0
        df_to_test['intersection'] = 0
        per_layer_score = {i:([],[]) for i in range(num_layers)}
        per_layer_acc = {i:[0,0] for i in range(num_layers)}
        y_test, y_pred = [], []
        saved_layers = []
        succ = []
        def all_layers_but(x):
            a1 = [a for a in range(x+1,num_layers)]
            return a1 


        for idx in tqdm(range(len(df_to_test)), leave=False):
            for l in range(16):
                if l == 15:
                    break
                model = best_estimators[l]
                X_test = test_data_scaled[l][idx]
                if model.predict_proba(X_test.reshape(1,-1))[0][1]>threshes[l]:
                    break
            curr_fixation_layer = df_to_test.at[idx,'pred_fix_event_info_']
            fix_pred_top_tokens = df_to_test.at[idx,'layer_top_tokens_{}'.format(curr_fixation_layer)]
            df_to_test.at[idx,'pred_freq'] = tok_to_idx[fix_pred_top_tokens[0]]
            last_l_preds = df_to_test.at[idx,'layer_top_tokens_{}'.format(num_layers-1)]
            curr_l_preds = df_to_test.at[idx,'layer_top_tokens_{}'.format(l)]
            if fix_pred_top_tokens[0] in last_l_preds:
                df_to_test.at[idx,'final_pred_rank'] = last_l_preds.index(fix_pred_top_tokens[0])
            else:
                df_to_test.at[idx,'final_pred_rank'] = -1
            if l == df_to_test.at[idx,'pred_fix_event_info_']:
                per_layer_score[l][0].append(1)
                per_layer_score[l][1].append(1)
                df_to_test.at[idx,'correct_early_exit'] = True
            else:
                per_layer_score[l][0].append(1)
                per_layer_score[l][1].append(0)
                per_layer_score[df_to_test.at[idx,'pred_fix_event_info_']][0].append(0)
                per_layer_score[df_to_test.at[idx,'pred_fix_event_info_']][1].append(1)
                if curr_l_preds[0] in last_l_preds:
                    df_to_test.at[idx,'final_pred_rank'] = last_l_preds.index(curr_l_preds[0])
                else:
                    df_to_test.at[idx,'final_pred_rank'] = -1

            df_to_test.at[idx,'saved_layers'] = 15-l
            df_to_test.at[idx,'correct_pred'] = True if curr_l_preds[0] == last_l_preds[0] else False
            y_test.append(df_to_test.at[idx,'pred_fix_event_info_'])
            y_pred.append(l)

            if l == df_to_test.at[idx,'pred_fix_event_info_']:
                cnt += 1
                per_layer_acc[df_to_test.at[idx,'pred_fix_event_info_']][0] += 1
            tot += 1
            per_layer_acc[df_to_test.at[idx,'pred_fix_event_info_']][1] += 1
            t_test = df_to_test[~df_to_test['correct_pred']]
        p_test = df_to_test[df_to_test['correct_pred']]
        accuracy_results[feature_vec].append(len(p_test)/len(df_to_test))
        saved_layers_results[feature_vec].append(df_to_test['saved_layers'].mean())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [16]:
results = {'baseline':[], 'accuracy':[], 'accuracy (std)':[],'saved layers':[], 'saved layers (std)':[]}
for k in saved_layers_results.keys():
    results['baseline'].append(k)
    arr_save = np.array(saved_layers_results[k])
    arr_acc = 100*np.array(accuracy_results[k])
    results['saved layers'].append(arr_save.mean())
    results['accuracy'].append(arr_acc.mean())
    results['saved layers (std)'].append(arr_save.std())
    results['accuracy (std)'].append(arr_acc.std())
results_df = pd.DataFrame.from_dict(results)
results_df

Unnamed: 0,baseline,accuracy,accuracy (std),saved layers,saved layers (std)
0,residual_vectors,94.4,2.128849,3.6046,0.542955
1,ffn_output_vectors,92.9,1.473771,3.8034,0.445841
2,sum,94.4,2.128849,3.6046,0.542955
