In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from matplotlib import pyplot as plt

sns.set(rc={'figure.figsize':(8,5),
            "font.size":16,
            "axes.titlesize":16,
            "axes.labelsize":16,
            "xtick.labelsize": 16.0,
            "ytick.labelsize": 16.0,
            "legend.fontsize": 16.0})

import torch
import torch.nn.functional as F
from fairseq.modules.adaptive_input import AdaptiveInput
from fairseq.modules.adaptive_softmax import AdaptiveSoftmax

pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 199)

from itertools import islice
import csv
import os
from collections import Counter
from random import shuffle, sample
from tqdm.notebook import tqdm
tqdm.pandas()

import wptools
from nltk.corpus import wordnet 
# import nltk
# nltk.download('wordnet')
import warnings
warnings.filterwarnings('ignore')

## Load WikiLM's Vocabulary

In [3]:
model_dir = "adaptive_lm_wiki103.v2/"
# Special symbols occupy ids 0-3 ahead of the dict.txt entries (this mirrors how
# fairseq's Dictionary reserves <s>, <pad>, </s>, <unk>), so a row index in
# `vocab` is the model's token id. The -1 is a placeholder count.
vocab = [['<s>', -1], ['<pad>', -1], ['</s>', -1], ['<unk>', -1]]
dict_path = os.path.join(model_dir, "dict.txt")
# Each dict.txt line is "<token> <count>". Decode explicitly as UTF-8: the
# platform default encoding may not be able to read non-ASCII Wikipedia tokens.
with open(dict_path, encoding="utf-8") as fd:
    vocab.extend([line.strip('\n').split(' ') for line in fd])
# token string -> token id (the row index in `vocab`).
tok_to_idx = {tok[0]: idx for idx, tok in enumerate(vocab)}

## Load Data and Compute Saturation Events

In [5]:
pkl_path = "dim_coef_with_vectors_10000sents.pkl"
# SECURITY: unpickling can execute arbitrary code -- only load files you produced yourself.
df = pd.read_pickle(filepath_or_buffer=pkl_path)

In [6]:
def get_prefix_target(row):
    """Split a row's ``text_`` at ``random_pos`` into (prefix, next-token target).

    The prefix is the space-joined tokens up to and including position
    ``random_pos``. The target is the token right after it, or the sentinel
    "</>" when the prefix already reaches the end of the text.

    NOTE(review): "</>" is not a vocab entry (the end-of-sentence symbol in
    ``vocab`` is "</s>"); confirm this sentinel is intended.
    """
    words = row["text_"].split(" ")
    cut = row["random_pos"] + 1
    prefix = " ".join(words[:cut])
    target = words[cut] if cut < len(words) else "</>"
    return prefix, target

# Prepend the BOS marker so token 0 of `text_` is "<s>"; get_prefix_target then
# cuts `text_` just after position `random_pos`.
df["text_"] = df["text"].map(lambda s: "<s> " + s)
df[["prefix", "target"]] = df.apply(lambda r: pd.Series(get_prefix_target(r)), axis=1)

In [7]:
def _ids_to_tokens(per_layer_ids):
    """Translate a per-layer list of vocab-id lists into per-layer token-string lists."""
    return [[vocab[token_id][0] for token_id in layer_ids] for layer_ids in per_layer_ids]

# Residual-stream top tokens: one list per layer, then one column per layer for
# the tokens and one per layer for their probabilities.
df["residual_top_tokens"] = df["residual_top_tokens_idx"].apply(_ids_to_tokens)
df[[f'residual_top_tokens_{i}' for i in range(16)]] = pd.DataFrame(df["residual_top_tokens"].tolist(), index=df.index)
df[[f'residual_top_tokens_prob_{i}' for i in range(16)]] = pd.DataFrame(df["residual_top_tokens_prob"].tolist(), index=df.index)

# Same expansion for the per-layer output top tokens.
df["layer_top_tokens"] = df["layer_output_top_tokens_idx"].apply(_ids_to_tokens)
df[[f'layer_top_tokens_{i}' for i in range(16)]] = pd.DataFrame(df["layer_top_tokens"].tolist(), index=df.index)
df[[f'layer_top_tokens_prob_{i}' for i in range(16)]] = pd.DataFrame(df["layer_output_top_tokens_prob"].tolist(), index=df.index)

# Top-1 token of every layer output.
df["layer_preds"] = df["layer_output_top_tokens_idx"].apply(
    lambda per_layer: [vocab[layer_ids[0]][0] for layer_ids in per_layer]
)

In [8]:
def get_top_coeffs(coefs, k):
    """Return, for each of 16 layers, the k largest coefficients and their dims.

    Each ``coefs[layer]`` has its final entry dropped before ranking
    (presumably an extra/bias slot -- TODO confirm).

    Returns:
        (dims, vals): two 16-element lists. ``dims[layer]`` is a numpy array of
        the top-k dimension indices in descending coefficient order.
        ``vals[layer]`` lists the matching coefficients rounded to 4 decimals.
    """
    dims_per_layer, vals_per_layer = [], []
    for layer in range(16):
        values = np.array(coefs[layer][:-1])
        top_k = np.argpartition(values, -k)[-k:]   # unordered top-k indices
        ascending = top_k[np.argsort(values[top_k])]
        ranked = ascending[::-1]                    # largest first
        dims_per_layer.append(ranked)
        vals_per_layer.append([round(values[d], 4) for d in ranked])
    return dims_per_layer, vals_per_layer

# Top-1/3/10 coefficient dims and values per layer, as two list-valued columns each.
for k in (1, 3, 10):
    top_k_cols = [f"max{k}_coef_dims", f"max{k}_coef_vals"]
    df[top_k_cols] = df["coeffs_vals"].apply(lambda coefs: pd.Series(get_top_coeffs(coefs, k)))

# Expand the top-10 lists into one column per layer.
for col in ("max10_coef_dims", "max10_coef_vals"):
    df[[f'{col}_{i}' for i in range(16)]] = pd.DataFrame(df[col].tolist(), index=df.index)

In [9]:
def get_pred_fix_event_info(row):
    """Describe where the model's final prediction becomes fixed.

    The fixation layer ``l`` is the first layer of the trailing run of layers
    whose top-1 prediction (``row["layer_preds"]``) equals the final layer's
    prediction. The layer count is taken from ``len(layer_preds)`` rather than
    hard-coded.

    Returns:
        (l, final_pred_idx_in_res, ts, ts_ids, j, final_pred) where
        - l: fixation layer (0 .. len(layer_preds) - 1),
        - final_pred_idx_in_res: rank of final_pred among
          ``row[f"residual_top_tokens_{l}"]`` (the list length if it is absent),
        - ts / ts_ids: the residual top tokens ranked above final_pred at layer l,
          and their vocab ids,
        - j: vocab id of final_pred,
        - final_pred: the last layer's top-1 token.

    Raises:
        KeyError: if a token is missing from ``tok_to_idx``.
    """
    layer_preds = row["layer_preds"]
    num_layers = len(layer_preds)
    final_pred = layer_preds[-1]
    j = tok_to_idx[final_pred]

    # Count the trailing layers that already agree with the final prediction.
    i = 0
    while i < num_layers and layer_preds[num_layers - 1 - i] == final_pred:
        i += 1
    l = num_layers - i  # i >= 1 (the last layer agrees with itself), so l <= num_layers - 1

    # Candidate tokens at the fixation layer: those ranked above the final prediction.
    candidates = row[f"residual_top_tokens_{l}"]
    if final_pred in candidates:
        final_pred_idx_in_res = candidates.index(final_pred)
    else:
        final_pred_idx_in_res = len(candidates)
    ts = candidates[:final_pred_idx_in_res]
    ts_ids = [tok_to_idx[tok] for tok in ts]

    return (l, final_pred_idx_in_res, ts, ts_ids, j, final_pred)

# One tuple per example:
# (fix layer, rank of the final prediction among that layer's residual top tokens,
#  tokens ranked above it, their ids, final prediction id, final prediction).
df["pred_fix_event_info"] = df.apply(get_pred_fix_event_info, axis=1)

## Load and Process Clusters; Organize the Dataset

In [11]:
clusters = 200
# Assumed layout: one cluster id per (layer, dim) pair, flattened layer-major as
# 16 layers x 4096 dims -- the index maps below rely on that. TODO confirm.
predicted_clusters = np.load(f'cosine_{clusters}_projected_values.npy')

# d: flat index -> (layer, dim); inv_d: (layer, dim) -> flat index.
d = {}
inv_d = {}
cnt = 0
for i in range(16):
    for j in range(4096):
        d[cnt] = (i, j)
        inv_d[d[cnt]] = cnt
        cnt += 1

# clsters: cluster id -> flat indices; clsters_dicted: cluster id -> (layer, dim) pairs.
clsters = {c: [] for c in range(clusters)}
clsters_dicted = {c: [] for c in range(clusters)}
for i, x in enumerate(predicted_clusters):
    clsters[x].append(i)
    clsters_dicted[x].append(d[i])

# inv_map: (layer, dim) -> cluster id.
inv_map = {
    coord: cluster_id
    for cluster_id, coords in clsters_dicted.items()
    for coord in coords
}

In [12]:
# Build `my_df`: a copy of `df` with per-example, per-layer features used by the
# early-exit rule below. For every layer k in 0..15:
#   val_to_cluster[k]      -- the top-10 coefficient dims at layer k
#   val_to_m[k]            -- the matching top-10 coefficient values
#   val_to_cluster_real[k] -- the cluster id of each (k, dim), via `inv_map`
# The columns are built in one pass over column values rather than with per-cell
# `.at` writes inside `iterrows`. All reads come from `my_df`, which is a copy of `df`.
my_df = df.copy()
my_df['pred_fix_event_info_'] = my_df['pred_fix_event_info'].map(lambda info: int(info[0]))

# Per-row tuples of the 16 per-layer dim arrays / value lists.
dims_per_row = zip(*(my_df[f'max10_coef_dims_{k}'] for k in range(16)))
vals_per_row = zip(*(my_df[f'max10_coef_vals_{k}'] for k in range(16)))

val_to_cluster, val_to_cluster_real, val_to_m = [], [], []
for row_dims, row_vals in zip(dims_per_row, vals_per_row):
    val_to_cluster.append({k: list(dims) for k, dims in enumerate(row_dims)})
    val_to_m.append({
        k: [vals[i] for i in range(len(dims))]
        for k, (dims, vals) in enumerate(zip(row_dims, row_vals))
    })
    val_to_cluster_real.append({
        k: [inv_map[(k, x)] for x in dims]
        for k, dims in enumerate(row_dims)
    })

my_df['val_to_cluster'] = val_to_cluster
my_df['val_to_cluster_real'] = val_to_cluster_real
my_df['val_to_m'] = val_to_m

## Training and Evaluation Function

In [15]:
def _ordered_inter(x1, x2):
    """Count the positions i < len(x1) at which x1[i] == x2[i]."""
    return sum(1 for i in range(len(x1)) if x1[i] == x2[i])


def _choose_exit_layer(curr_feats, train_by_fix, comparison_func=_ordered_inter, num_layers=16):
    """Return the early-exit layer for one example.

    ``curr_feats[l]`` is the example's cluster-id list at layer l.
    ``train_by_fix[x]`` lists the feature dicts of training examples whose
    prediction fixed at layer x. The result is the first layer l in
    0..num_layers-2 where the training examples fixed at l score strictly
    higher than those of every later fixation layer, in both max and mean
    ``comparison_func(train_feats[l], curr_feats[l])`` score. If no layer
    qualifies, the last layer is returned.

    Raises:
        ValueError: (from ``max``) if a fixation layer >= l has no training examples.
    """
    for l in range(num_layers - 1):
        curr = curr_feats[l]
        same = [comparison_func(a[l], curr) for a in train_by_fix.get(l, [])]
        curr_max = max(same)
        curr_mean = np.mean(same)
        # Score every later fixation layer before deciding, so an empty
        # training group raises regardless of the comparison outcome.
        later = [[comparison_func(a[l], curr) for a in train_by_fix.get(x, [])]
                 for x in range(l + 1, num_layers)]
        later_max = [max(s) for s in later]
        later_mean = [np.mean(s) for s in later]
        if all(curr_max > m and curr_mean > mu for m, mu in zip(later_max, later_mean)):
            return l
    return num_layers - 1


def train_and_test(train, test):
    """Evaluate the cluster-matching early-exit rule on ``test`` using ``train``.

    Each test example exits at the layer chosen by ``_choose_exit_layer``.
    That layer is picked from the example's ``val_to_cluster_real`` features,
    compared against the training examples grouped by ``pred_fix_event_info_``.

    Args:
        train: frame with ``pred_fix_event_info_`` and ``val_to_cluster_real``.
        test: frame with ``val_to_cluster_real`` and ``layer_top_tokens_0..15``.

    Returns:
        (accuracy, saved_layers):
        - accuracy is the fraction of test examples whose top-1 token at the
          exit layer equals the top-1 token at layer 15.
        - saved_layers is the mean of (15 - exit layer).

    Raises:
        ValueError: if ``test`` is empty, or if some fixation layer has no
            training examples.
    """
    if len(test) == 0:
        raise ValueError("test set is empty")
    clustering_format = 'val_to_cluster_real'

    # Group training features by fixation layer once (row order preserved),
    # instead of copying and re-filtering the training frame per example and layer.
    train_by_fix = {}
    for fix, feats in zip(train['pred_fix_event_info_'].tolist(), train[clustering_format].tolist()):
        train_by_fix.setdefault(fix, []).append(feats)

    test_feats = test[clustering_format].tolist()
    top_tokens = {l: test[f'layer_top_tokens_{l}'].tolist() for l in range(16)}

    correct = []
    saved = []
    for idx in tqdm(range(len(test)), leave=False):
        l = _choose_exit_layer(test_feats[idx], train_by_fix)
        correct.append(top_tokens[l][idx][0] == top_tokens[15][idx][0])
        saved.append(15 - l)

    return sum(correct) / len(correct), np.mean(saved)

In [16]:
# Evaluate the early-exit rule over five shuffled 90/10 train/test splits.
saved_layers = []
accuracies = []
for random_seed in tqdm([38, 39, 40, 41, 42]):
    shuffled = my_df.sample(frac=1, random_state=random_seed)
    # Split with iloc rather than np.split: np.split on a DataFrame goes through
    # the deprecated DataFrame.swapaxes. Size the split by the frame being split.
    split_at = int(.9 * len(my_df))
    # reset_index() keeps the pre-shuffle labels in an "index" column.
    train = shuffled.iloc[:split_at].reset_index()
    test = shuffled.iloc[split_at:].reset_index()
    acc, saved = train_and_test(train, test)
    saved_layers.append(saved)
    accuracies.append(acc)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [18]:
# Summarize the per-seed runs: accuracy in percent and layers saved (mean and std).
arr_save = np.array(saved_layers)
arr_acc = 100*np.array(accuracies)
results = {
    'baseline': ['Our method'],
    'accuracy': [arr_acc.mean()],
    'accuracy (std)': [arr_acc.std()],
    'saved layers': [arr_save.mean()],
    'saved layers (std)': [arr_save.std()],
}
results_df = pd.DataFrame.from_dict(results)
results_df

Unnamed: 0,baseline,accuracy,accuracy (std),saved layers,saved layers (std)
0,Our method,94.86,0.557136,3.3916,0.138319
