In [1]:
from IPython.display import clear_output
from functions import *
import numpy as np
import pandas as pd
from transformers import *
from transformers.tokenization_utils import TextInputPair
from sklearn.neural_network import MLPClassifier
import tensorflow as tf
import pickle
import scipy as sc
import math as mt
from joblib import dump, load
import random
from sklearn.neighbors import KNeighborsTransformer

### Load model & data

In [2]:
# BERT
# This cell is one of three alternative model set-ups (BERT / GPT-2 / RoBERTa).
# Each of them rebinds the same globals (casing, tokenizer, config, model,
# n_cluster, n_pc, n_pc_global, spectkn), so the one executed last is the one
# the rest of the notebook uses.
casing = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(casing, do_lower_case=True, add_special_tokens=True)

# BERT's config fields for dropout are `hidden_dropout_prob` and
# `attention_probs_dropout_prob`. The previous `dropout` / `attention_dropout`
# keywords are not read by the BERT layers, so the intended 0.2 rate was not in effect.
config = BertConfig(hidden_dropout_prob=0.2, attention_probs_dropout_prob=0.2)
config.output_hidden_states = False  # if true outputs all layers

model = TFBertModel.from_pretrained(casing, config=config)
model.trainable = False  # the encoder is only used as a frozen feature extractor
emb_len = 768
clear_output()

# BERT
n_cluster = 27  # Number of clusters to use
n_pc = 12  # Number of main principal components to drop for local method
n_pc_global = 15  # Number of main principal components to drop for global method
spectkn = ""  # prefix prepended to the target token when matching tokenizer output (none for BERT)

In [88]:
# GPT-2
# Alternative model set-up: running this cell rebinds casing, tokenizer, config,
# model and the hyper-parameters below, replacing whatever an earlier set-up
# cell assigned. Run only the set-up cell for the model under study.
casing = "gpt2" 
tokenizer = GPT2Tokenizer.from_pretrained(casing, do_lower_case=True, add_special_tokens=True)
config = GPT2Config()
config.output_hidden_states = True

model = TFGPT2Model.from_pretrained(casing, config=config)
model.trainable = False

emb_len = 768
clear_output()

# GPT2
# n_pc is passed to cluster_based() in the results section; n_cluster and
# n_pc_global are not referenced by the cells visible in this part of the notebook.
n_cluster = 10
n_pc = 30
#n_pc = 12
n_pc_global = 30
# Prefix that is prepended to the target token when it is compared against
# tokenizer output in the results section.
spectkn = "Ġ"

In [144]:
# RoBERTa
# Alternative model set-up: running this cell rebinds casing, tokenizer, config,
# model and the hyper-parameters below, replacing whatever an earlier set-up
# cell assigned. Run only the set-up cell for the model under study.
casing = "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(casing, do_lower_case=True, add_special_tokens=True)
config = RobertaConfig.from_pretrained(casing)
config.output_hidden_states = True

model = TFRobertaModel.from_pretrained(casing, config=config)
model.trainable = False
emb_len = 768
clear_output()

# RoBERTa
# n_pc is passed to cluster_based() in the results section; n_cluster and
# n_pc_global are not referenced by the cells visible in this part of the notebook.
n_cluster = 27
n_pc = 12
n_pc_global = 25
# Prefix that is prepended to the target token when it is compared against
# tokenizer output; the token-selection cell clears it for "." under roberta-base.
spectkn = "Ġ"

In [3]:
# Load the pickled corpus into `x`; later cells iterate over it as a collection
# of sentence groups.
# SECURITY: pickle.load can execute arbitrary code, so only use a data file
# that comes from a trusted source.
with open('data.150k.pickle', 'rb') as f:
    x = pickle.load(f)

### Gather representations of each stop word / punctuation of interest

In [4]:
# Split the corpus into one pool per marker: a group goes into a pool when the
# marker string passes a membership test (`in`) against the group itself.
the_grps = list(filter(lambda grp: "the" in grp, x))
of_grps = list(filter(lambda grp: "of" in grp, x))
period_grps = list(filter(lambda grp: "." in grp, x))
comma_grps = list(filter(lambda grp: "," in grp, x))

In [136]:
# Draw 200 groups per marker with a fixed seed. The pools are consumed in the
# order the / of / period / comma, which keeps the random stream (and therefore
# every sample) the same from run to run.
random.seed(1)
data_the, data_of, data_period, data_comma = (
    random.sample(pool, 200)
    for pool in (the_grps, of_grps, period_grps, comma_grps)
)

In [145]:
def _group_token_strings(groups):
    """Tokenise the first six sentences of each of the first 200 groups.

    Each sentence (a sequence of strings) is joined with single spaces, its
    last two characters are dropped and a "." is appended; the result is
    encoded with the global `tokenizer` and mapped back to token strings.
    NOTE(review): the `[:-2]` presumably strips a trailing " ." left by the
    join -- confirm against the data file.

    Returns a list of 200 lists, each holding six token-string lists.
    """
    tokenized = []
    for g in range(200):
        tokenized.append([
            tokenizer.convert_ids_to_tokens(
                tokenizer.encode(" ".join(groups[g][s])[:-2] + ".")
            )
            for s in range(6)
        ])
    return tokenized


# From here on the data_* names refer to tokenised sentences rather than raw
# word lists, so this cell must be run exactly once after the sampling cell.
data_comma_sents = _group_token_strings(data_comma)
data_comma = data_comma_sents

data_period_sents = _group_token_strings(data_period)
data_period = data_period_sents

data_of_sents = _group_token_strings(data_of)
data_of = data_of_sents

data_the_sents = _group_token_strings(data_the)
data_the = data_the_sents

In [155]:
# Convert the token strings of the first six sentences in every group back to
# vocabulary ids; the nesting mirrors data_*: ids_*[group][sentence] -> id list.
ids_the = [
    [tokenizer.convert_tokens_to_ids(grp[k]) for k in range(6)]
    for grp in data_the
]

ids_of = [
    [tokenizer.convert_tokens_to_ids(grp[k]) for k in range(6)]
    for grp in data_of
]

ids_period = [
    [tokenizer.convert_tokens_to_ids(grp[k]) for k in range(6)]
    for grp in data_period
]

ids_comma = [
    [tokenizer.convert_tokens_to_ids(grp[k]) for k in range(6)]
    for grp in data_comma
]

In [158]:
def _encode_groups(id_groups):
    """Run the global `model` on every sentence of every group.

    Parameters
    ----------
    id_groups : list of groups, each a list of per-sentence token-id lists.

    Returns
    -------
    list of lists of numpy arrays, indexed as result[group][sentence][token].

    Sentences inside a group generally differ in length, so the per-sentence
    outputs are kept in a plain list. Stacking them with np.asarray builds a
    ragged array, which emits a deprecation warning on older NumPy releases
    and raises ValueError on NumPy >= 1.24.
    """
    encoded = []
    for group_ids in id_groups:
        group_out = []
        for sent_ids in group_ids:
            batch = np.asarray([sent_ids], dtype="int32")  # batch of one sentence
            # First [0]: first model output; second [0]: the only item in the batch.
            group_out.append(np.asarray(model(batch)[0][0]))
        encoded.append(group_out)
    return encoded


reps_comma = _encode_groups(ids_comma)
reps_period = _encode_groups(ids_period)
reps_of = _encode_groups(ids_of)
reps_the = _encode_groups(ids_the)


  return array(a, dtype, copy=False, order=order)


## Results

Set the stop word / punctuation of interest

In [198]:
# Pick the marker to analyse. The three selections below must refer to the
# same marker.
token = ","  # one of: "the", "of", ",", "."
data_ = data_comma  # one of: data_the, data_of, data_comma, data_period
reps = reps_comma  # one of: reps_the, reps_of, reps_comma, reps_period
# The tokenizer prefix is dropped only for "." under roberta-base; every other
# combination keeps the prefix chosen in the model set-up cell.
spectkn_ = "" if (token == "." and casing == "roberta-base") else spectkn

Get results for the stop word / punctuation of interest

In [199]:
# Walk over every token of the 200 x 6 sentences. All token vectors go into
# finalreps; for each sentence only the FIRST occurrence of the target token is
# recorded in tokenreps / tokenpositions.
tokenpositions = []  # tuples: (group, sent, index, fullidx, grpix)
finalreps = []
tokenreps = []
cnt = 0  # running index over all tokens (position in finalreps)
kix = 0  # running index over recorded target tokens (position in tokenreps)
target = spectkn_ + token
for grp_no in range(200):
    for sent_no in range(6):
        already_taken = False
        for tok_no, tok in enumerate(data_[grp_no][sent_no]):
            vec = reps[grp_no][sent_no][tok_no]
            finalreps.append(vec)
            if not already_taken and tok == target:
                tokenreps.append(vec)
                tokenpositions.append((grp_no, sent_no, tok_no, cnt, kix))
                already_taken = True
                kix += 1
            cnt += 1

In [200]:
# Fit a nearest-neighbour index on the raw target-token vectors.
# The trailing expression displays the neighbour graph returned by
# fit_transform; only the fitted `knn` object is used by the next cell.
knn = KNeighborsTransformer(n_neighbors=6)
knn.fit_transform(tokenreps)

<1200x1200 sparse matrix of type '<class 'numpy.float32'>'
	with 8400 stored elements in Compressed Sparse Row format>

In [201]:
# For every recorded target token, measure which fraction of its 6 nearest
# neighbours (queried from the fitted `knn`) belongs to the token's own group.
fullixs = set([pos[4] for pos in tokenpositions])  # not used below; kept for any later cell that reads it

# Build the group -> {token indices} lookup once; re-filtering tokenpositions
# for every token made the loop quadratic in the number of tokens.
group_to_ixs = {}
for grp, _sent, _index, _fullidx, tix in tokenpositions:
    group_to_ixs.setdefault(grp, set()).add(tix)

trueneighbors = []
for pos in tokenpositions:
    group, sent, index, fullidx, tokenidx = pos
    scores, neighs = knn.kneighbors([tokenreps[tokenidx]], 6)
    truegroupixs = group_to_ixs[group]
    trueneighbors.append(len(set(neighs[0]).intersection(truegroupixs)) / 6)

In [202]:
# Baseline result: mean fraction of same-group tokens among the 6 nearest
# neighbours, computed on the unmodified representations.
np.mean(trueneighbors)

0.8266666666666667

In [203]:
# Remove dominant directions and repeat
# `cluster_based` is provided by the star-import from `functions`; its
# definition is not part of this notebook.
# NOTE(review): the literal 1 is presumably the number of clusters (the
# n_cluster setting is not passed here) and n_pc the number of principal
# components to drop -- confirm against functions.cluster_based.
iso_tokenreps = cluster_based(np.array(tokenreps),1,n_pc,emb_len)

In [204]:
# Refit the nearest-neighbour index, this time on the post-processed vectors.
# This rebinds `knn`, so the baseline index from the earlier cell is replaced.
knn = KNeighborsTransformer(n_neighbors=6)
knn.fit_transform(iso_tokenreps)

<1200x1200 sparse matrix of type '<class 'numpy.float64'>'
	with 8400 stored elements in Compressed Sparse Row format>

In [205]:
# Same measurement as for the baseline, but on the post-processed vectors:
# fraction of each token's 6 nearest neighbours that share its group.
fullixs = set([pos[4] for pos in tokenpositions])  # not used below; kept for any later cell that reads it

# Build the group -> {token indices} lookup once; re-filtering tokenpositions
# for every token made the loop quadratic in the number of tokens.
group_to_ixs = {}
for grp, _sent, _index, _fullidx, tix in tokenpositions:
    group_to_ixs.setdefault(grp, set()).add(tix)

trueneighbors = []
for pos in tokenpositions:
    group, sent, index, fullidx, tokenidx = pos
    scores, neighs = knn.kneighbors([iso_tokenreps[tokenidx]], 6)
    truegroupixs = group_to_ixs[group]
    trueneighbors.append(len(set(neighs[0]).intersection(truegroupixs)) / 6)

In [206]:
# Mean same-group neighbour fraction after removing the dominant directions;
# compare with the baseline result above.
np.mean(trueneighbors)

0.7934722222222222