In [1]:
import nltk
from nltk.corpus import semcor
from IPython.display import clear_output
from functions import *
import numpy as np
import pandas as pd
from transformers import *
from transformers.tokenization_utils import TextInputPair
from sklearn.neural_network import MLPClassifier
import tensorflow as tf
import pickle
import scipy as sc
import math as mt
from joblib import dump, load
from sklearn.neighbors import KNeighborsTransformer
from collections import defaultdict
from scipy.spatial.distance import pdist, cdist, squareform

### Load model & data

In [2]:
# BERT
casing = "bert-base-uncased" 
tokenizer = BertTokenizer.from_pretrained(casing, do_lower_case=True, add_special_tokens=True)

config = BertConfig(dropout=0.2, attention_dropout=0.2 ) #hidden_dropout_prob=0.2, attention_probs_dropout_prob=0.2
config.output_hidden_states = False # if true outputs all layers

model = TFBertModel.from_pretrained(casing, config = config)
model.trainable = False
emb_len = 768
clear_output()

# BERT
n_cluster = 27 # Number of clusters to use
n_pc = 12 # Number of main principal components to drop for local method
n_pc_global = 15 # Number of main principal components to drop for global method

In [29]:
# GPT-2
casing = "gpt2" 
tokenizer = GPT2Tokenizer.from_pretrained(casing, do_lower_case=True, add_special_tokens=True)
config = GPT2Config()
config.output_hidden_states = True

model = TFGPT2Model.from_pretrained(casing, config=config)
model.trainable = False

emb_len = 768
clear_output()

# GPT2
n_cluster = 10
n_pc = 30
n_pc_global = 30

In [55]:
# RoBERTa
casing = "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(casing, do_lower_case=True, add_special_tokens=True)
config = RobertaConfig.from_pretrained(casing)
config.output_hidden_states = True

model = TFRobertaModel.from_pretrained(casing, config=config)
model.trainable = False
emb_len = 768
clear_output()

# RoBERTa
n_cluster = 27
n_pc = 12
n_pc_global = 25

In [None]:
nltk.download("semcor")

In [3]:
sents = semcor.tagged_sents(tag="sem")

### Get verbs representations

In [4]:
# Extract metadata - get all verbs from the dataset
verbs = [] # sentix, wordix, label, word
sentix = 0
for sent in sents:
    wordix = 0
    for word in sent:
        if type(word) == nltk.tree.Tree:
            if word.label() and type(word.label()) == nltk.corpus.reader.wordnet.Lemma:
                if word.label().synset() and word.label().synset().pos() == "v":
                    verbs.append((sentix,wordix,word.label(),word[0]))
        wordix+=1
    sentix+=1

In [5]:
# List of verbs we are interested in - provided by authors
trueverbs = [
    'say', 'said', 'have', 'had', 'win', 'won', 'study', 'studied', 'find', 'found', 'hold', 'held', 'make', 'made', 'tell', 'told', 'seek', 'sought', 'see', 'saw', 'get', 'got', 'shoot', 'shot', 'go', 'went', 'lead', 'led', 'leave', 'left', 'deny', 'denied', 'send', 'sent', 'keep', 'kept', 'lose', 'lost', 'feel', 'felt', 'spend', 'spent', 'draw', 'drew', 'throw', 'threw', 'try', 'tried', 'pay', 'paid', 'break', 'broke', 'come', 'came', 'run', 'ran', 'think', 'thought', 'carry', 'carried', 'catch', 'caught', 'lie', 'lay', 'fall', 'fell', 'write', 'wrote', 'know', 'knew', 'stand', 'stood', 'teach', 'taught', 'fight', 'fought', 'rise', 'rose', 'speak', 'spoken', 'choose', 'chosen', 'forget', 'forgotten', 'strike', 'struck', 'meet', 'met', 'build', 'built', 'apply', 'applied', 'sit', 'sat', 'sell', 'sold', 'buy', 'bought', 'feed', 'fed', 'ride', 'rode', 'drive', 'drove', 'wear', 'wore'
]

In [6]:
# filter out verbs we don't care about
filtered_verbs = list(filter(lambda x: x[3] in trueverbs,verbs))

In [7]:
# Get indices of only the sentences that contain the verbs we are interested in
sents_to_parse = set()
for vb in filtered_verbs:
    sents_to_parse = sents_to_parse.union([vb[0]])

In [8]:
len(sents_to_parse)

11838

In [9]:
idxs = list(sorted(sents_to_parse))

In [10]:
# maps to translate actual sentence index -> filtered sentence index and vice versa
translation_dict = dict(zip(idxs,range(len(idxs))))
translation_dict_inv = dict(zip(range(len(idxs)),idxs))

In [230]:
# Get the sentences we care about
finalsents = []
i = 0
for sent in semcor.sents():
    if i in idxs:
        finalsents.append(sent)
    i+=1

In [231]:
# Lowercase
for i in range(len(finalsents)):
    for j in range(len(finalsents[i])):
        finalsents[i][j] = finalsents[i][j].lower()

In [236]:
# Tokenize
ids = []
for sent in finalsents:
    ids.append(tokenizer.convert_tokens_to_ids(sent))

In [56]:
# Send through model... Run commented lines if no embeddings saved
with open('tense_roberta.pkl', 'rb') as f:
    reps = pickle.load(f)
# reps = []
# for i in range(len(ids)):
#     reps.append(model(np.asarray([ids[i]], dtype="int32"))[0])

In [238]:
# with open('gpt2_tense.pkl', 'wb') as f:
#     pickle.dump(reps, f)

In [57]:
# Data structure to keep our verbs embeddings organized based on sense and tense
list_of_dicts = [] # will keep dicts of verbs in order of trueverbs, {synset: [rep]}
# for each verb, (for present and past form), collect their CWRs from the model outputs, store them in dicts
for i in range(len(trueverbs)//2):
    present_verb = trueverbs[i*2]
    past_verb = trueverbs[i*2+1]
    present_dict = defaultdict(list)
    past_dict = defaultdict(list)
    # go over all filtered verbs, fetch their CWRs... 
    for j in range(len(filtered_verbs)):
        # present
        if filtered_verbs[j][3] == present_verb:
            sent_ix = translation_dict[filtered_verbs[j][0]]
            wrd_ix = filtered_verbs[j][1]
            present_dict[filtered_verbs[j][2].synset().name()].append(reps[sent_ix][0][wrd_ix].numpy())
        # past
        elif filtered_verbs[j][3] == past_verb:
            sent_ix = translation_dict[filtered_verbs[j][0]]
            wrd_ix = filtered_verbs[j][1]
            past_dict[filtered_verbs[j][2].synset().name()].append(reps[sent_ix][0][wrd_ix].numpy())
    
    list_of_dicts.append(present_dict)
    list_of_dicts.append(past_dict)

## Results

In [58]:
#### Metrics defined in paper

def st_sm(list_of_dicts):
    """
        From list of dicts, get average distance between verb representations of same tense and same meaning 
    """
    # go over all verbs
    means = []
    for i in range(len(list_of_dicts)//2):
        present_verb_dict = list_of_dicts[2*i]
        past_verb_dict = list_of_dicts[2*i+1]

        # go over all of this verbs present tense synsets
        for synset in present_verb_dict.keys():
            verb_representations = present_verb_dict[synset]
            if len(verb_representations) > 1:
                means+= list(pdist(np.array(verb_representations), 'euclidean'))

        # go over all of this verbs past tense synsets
        for synset in past_verb_dict.keys():
            verb_representations = past_verb_dict[synset]
            if len(verb_representations) > 1:
                means+= list(pdist(np.array(verb_representations), 'euclidean'))
                #means.append(pdist(np.array(verb_representations), 'euclidean').mean())

    return np.mean(means)

def st_dm(list_of_dicts):
    """
        From list of dicts, get average distance between verb representations of same tense and different meaning 
    """
    # go over all verbs
    means = []
    for i in range(len(list_of_dicts)//2):
        present_verb_dict = list_of_dicts[2*i]
        past_verb_dict = list_of_dicts[2*i+1]

        # go over all of this verbs present tense synsets
        for synset in present_verb_dict.keys():
            verb_representations = present_verb_dict[synset]
            # get all other synsets represenattions 
            other_synsets = [x for x in present_verb_dict.keys() if x!= synset]
            other_syn_reps = []
            for ss in other_synsets:
                other_syn_reps+= present_verb_dict[ss]
            other_syn_reps = np.array(other_syn_reps)

            
            if len(other_syn_reps.shape) == 2 and len(verb_representations) > 0:
                means.append(cdist(np.array(verb_representations), other_syn_reps, 'euclidean'))
            

        # go over all of this verbs past tense synsets
        for synset in past_verb_dict.keys():
            verb_representations = past_verb_dict[synset]
            # get all other synsets represenattions 
            other_synsets = [x for x in past_verb_dict.keys() if x!= synset]
            other_syn_reps = []
            for ss in other_synsets:
                other_syn_reps+= past_verb_dict[ss]
            other_syn_reps = np.array(other_syn_reps)

            if len(other_syn_reps.shape) == 2 and len(verb_representations) > 0:
                means.append(cdist(np.array(verb_representations), other_syn_reps, 'euclidean'))
                #means+= list(pdist(np.array(verb_representations), 'euclidean'))

    finalels = []
    for arr in means:
        finalels += arr.flatten().tolist()
    return np.mean(finalels)

def dt_sm(list_of_dicts):
    """
        From list of dicts, get average distance between verb representations of different tense and same meaning 
    """
    # go over all verbs
    means = []
    for i in range(len(list_of_dicts)//2):
        present_verb_dict = list_of_dicts[2*i]
        past_verb_dict = list_of_dicts[2*i+1]

        # look at intersection of synsets
        synsets_intersect = list(set(present_verb_dict.keys()).intersection(past_verb_dict.keys()))

        # go over all of this verbs filtered synsets
        for synset in synsets_intersect:
            verb_representations_pres = present_verb_dict[synset]
            verb_representation_past = past_verb_dict[synset]
            
            if  len(verb_representation_past) > 0 and len(verb_representations_pres) > 0:
                means.append(cdist(np.array(verb_representations_pres), np.array(verb_representation_past), 'euclidean'))
            
    finalels = []
    for arr in means:
        finalels += arr.flatten().tolist()
    return np.mean(finalels)

In [None]:
st_sm(list_of_dicts)

In [None]:
st_dm(list_of_dicts)

In [None]:
dt_sm(list_of_dicts)

Combine all representations to measure their isotropy...

In [59]:
all_reps = []
for ix in range(len(list_of_dicts)):
    for key in list_of_dicts[ix].keys():
        all_reps += list_of_dicts[ix][key]

In [34]:
len(all_reps)

14955

In [17]:
isotropy(all_reps)

2.4068748e-05

Repeat process with cluster-based method imrovement (removing dominant components)

In [60]:
# Isotropy calculated based on enhancing all verbs CWRs at once
clstrd_reps = cluster_based(np.array(all_reps), 1, n_pc, emb_len)
isotropy(clstrd_reps)

In [62]:
# Better approach : cluster & enhance based on each verb seperately
new_dicts = []

for i in range(len(list_of_dicts)//2): #len(list_of_dicts)//2
    present_verb_dict = list_of_dicts[2*i]
    past_verb_dict = list_of_dicts[2*i+1]

    combined_reps = [] 
    for key in present_verb_dict.keys():
        combined_reps += present_verb_dict[key]
    for key in past_verb_dict.keys():
        combined_reps += past_verb_dict[key]


    cb_reps = cluster_based(np.array(combined_reps), 1, n_pc, 768)

    present_dict_new = defaultdict(list)
    past_dict_new = defaultdict(list)

    u = 0
    for key in present_verb_dict.keys():
        present_dict_new[key] = []
        for ix in range(len(present_verb_dict[key])):
            present_dict_new[key].append(cb_reps[u])
            u+=1
    for key in past_verb_dict.keys():
        past_dict_new[key] = []
        for ix in range(len(past_verb_dict[key])):
            past_dict_new[key].append(cb_reps[u])
            u+=1
            

    new_dicts.append(present_dict_new)
    new_dicts.append(past_dict_new)

In [63]:
st_sm(new_dicts)

4.102383363457763

In [64]:
st_dm(new_dicts)

4.487159470883429

In [65]:
dt_sm(new_dicts)

4.463559610829546

In [53]:
all_reps_new = []
for ix in range(len(new_dicts)):
    for key in new_dicts[ix].keys():
        all_reps_new += new_dicts[ix][key]

In [None]:
isotropy(all_reps_new)