In [1]:
import os
import tqdm
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from utils import check_accuracy_classification
import transformers
from torch.optim import Adam
from models import BertProbeClassifer
from utils import text_to_dataloader, tokenize_word
from bert_embedding import BertEmbeddingExtractor

In [2]:
#%load_ext autoreload
#%autoreload 2

In [3]:
train_path = os.path.join("data","en_partut-ud-train.conllu")
dev_path = os.path.join("data","en_partut-ud-dev.conllu")
test_path = os.path.join("data","en_partut-ud-test.conllu")

In [4]:
HEADER_CONST = "# sent_id = "
TEXT_CONST = "# text = "
STOP_CONST = "\n"
WORD_OFFSET = 1
LABEL_OFFSET = 3


def txt_to_dataframe(data_path):
    '''
    read UD text file and convert to df format
    '''
    with open(data_path, "r") as fp:
        df = pd.DataFrame(
            columns={
                "text",
                "word",
                "label"
            }
        )
        for line in fp.readlines():
            if TEXT_CONST in line:
                words_list = []
                labels_list = []
                text = line.split(TEXT_CONST)[1]
                # this is a new text, need to parse all the words in it
            elif line is not STOP_CONST and HEADER_CONST not in line:
                temp_list = line.split("\t")
                words_list.append(temp_list[WORD_OFFSET])
                labels_list.append(temp_list[LABEL_OFFSET])
            if line == STOP_CONST:
                # this is the end of the text, adding to df
                cur_df = pd.DataFrame(
                    {
                        "text": len(words_list) * [text],
                        "word": words_list,
                        "label": labels_list
                    }
                )
                df = pd.concat([df,cur_df])
        return df
            


In [5]:
df_train = txt_to_dataframe(train_path)
df_dev = txt_to_dataframe(dev_path)
df_test = txt_to_dataframe(test_path)

In [6]:
TYPES = [
    "ADJ",
    "ADP",
    "ADV",
    "AUX",
    "CCONJ",
    "DET",
    "INTJ",
    "NOUN",
    "NUM",
    "PART",
    "PRON",
    "PROPN",
    "PUNCT",
    "SCONJ",
    "SYM",
    "VERB",
    "X",
    "_"
]

In [7]:
file_name = 'tex_artifacts/label_dist_train.tex'
SORT_COL = "Count"

with open(file_name,'w') as tf:
    display_df = df_train["label"].value_counts().rename_axis("Type").to_frame("Count").reset_index()
    #display_df.index = TYPES
    display_df.sort_values(by=SORT_COL, inplace=True, ascending=False)
    latex_data = display_df.to_latex(index=False)
    tf.write(latex_data)

In [8]:
file_name = 'tex_artifacts/label_dist_dev.tex'


with open(file_name,'w') as tf:
    display_df = df_dev["label"].value_counts().rename_axis("Type").to_frame("Count").reset_index()
    #display_df.index = TYPES
    display_df.sort_values(by="Type", inplace=True)
    latex_data = display_df.to_latex(index=False)
    tf.write(latex_data)

In [9]:
file_name = 'tex_artifacts/label_dist_test.tex'


with open(file_name,'w') as tf:
    display_df = df_test["label"].value_counts().rename_axis("Type").to_frame("Count").reset_index()
    #display_df.index = TYPES
    display_df.sort_values(by="Type", inplace=True)
    latex_data = display_df.to_latex(index=False)
    tf.write(latex_data)

In [10]:
df_test[df_test["label"] == "VERB"]

Unnamed: 0,label,word,text
8,VERB,authorized,Any use of the work other than as authorized u...
16,VERB,prohibited,Any use of the work other than as authorized u...
2,VERB,AGREED,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...
11,VERB,OFFERS,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...
16,VERB,IS,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...
...,...,...,...
12,VERB,spread,"In the 18th and 19th centuries, his reputation..."
6,VERB,believe,Only a small minority of academics believe the...
8,VERB,is,Only a small minority of academics believe the...
11,VERB,question,Only a small minority of academics believe the...


In [11]:
bert_tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

In [12]:
df_train, dataloader_train = text_to_dataloader(df_train, "cuda", 32, bert_tokenizer, 256)
df_test, dataloader_test = text_to_dataloader(df_test, "cuda", 32, bert_tokenizer, 256)

In [13]:
file_name = 'tex_artifacts/tokens_per_word_dist_train.tex'

INDEX_AXIS_NAME = "Tokens/Word"
SORT_COL = "Tokens/Word"

with open(file_name,'w') as tf:
    display_df = df_train["query_mask"].apply(lambda x: sum(x)).value_counts().rename_axis(INDEX_AXIS_NAME).to_frame("Count").reset_index()
    display_df.sort_values(by=SORT_COL, inplace=True)
    latex_data = display_df.to_latex(index=False)
    tf.write(latex_data)


In [14]:
import warnings
warnings.filterwarnings(action="ignore")


# extract bert base embedding
for i in [1,2,3,4,5,6,7,8,9,10,11,12]:
    bex = BertEmbeddingExtractor(i, "bert-base-uncased")
    embedding_df = bex.extract_embedding(dataloader_train, "sum")
    save_path = os.path.join("bert_embeddings", f"bert_base_embedding_layer_{i}")
    embedding_df.to_csv(save_path, index=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.1.attention.self.query.weight', 'bert.encoder.layer.1.attention.self.query.bias', 'bert.encoder.layer.1.attention.self.key.weight', 'bert.encoder.layer.1.attention.self.key.bias', 'bert.encoder.layer.1.attention.self.value.weight', 'bert.encoder.layer.1.attention.self.value.bias', 'bert.encoder.layer.1.attention.output.dense.weight', 'bert.encoder.layer.1.attention.output.dense.bias', 'bert.encoder.layer.1.intermediate.dense.weight', 'bert.encoder.layer.1.intermediate.dense.bias', 'bert.encoder.layer.1.output.dense.weight', 'bert.encoder.layer.1.output.dense.bias', 'bert.encoder.layer.2.attention.self.query.weight', 'bert.encoder.layer.2.attention.self.query.bias', 'bert.encoder.layer.2.attention.self.key.weight', 'bert.encoder.layer.2.attention.self.key.bias', 'bert.encoder.layer.2.attention.self.value.weight', 'bert.encoder.layer.2.attention.self.value.bias', 'ber

100%|██████████| 1356/1356 [00:40<00:00, 33.87it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.2.attention.self.query.weight', 'bert.encoder.layer.2.attention.self.query.bias', 'bert.encoder.layer.2.attention.self.key.weight', 'bert.encoder.layer.2.attention.self.key.bias', 'bert.encoder.layer.2.attention.self.value.weight', 'bert.encoder.layer.2.attention.self.value.bias', 'bert.encoder.layer.2.attention.output.dense.weight', 'bert.encoder.layer.2.attention.output.dense.bias', 'bert.encoder.layer.2.intermediate.dense.weight', 'bert.encoder.layer.2.intermediate.dense.bias', 'bert.encoder.layer.2.output.dense.weight', 'bert.encoder.layer.2.output.dense.bias', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.3.attention.self.value.weight', 'b

100%|██████████| 1356/1356 [01:13<00:00, 18.34it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.3.intermediate.dense.bias', 'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.4.attention.self.key.bias', 'bert.encoder.layer.4.attention.self.value.weight', 'b

100%|██████████| 1356/1356 [02:15<00:00,  9.99it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'b

100%|██████████| 1356/1356 [03:26<00:00,  6.57it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.7.output.dense.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.8.attention.self.value.weight', 'b

100%|██████████| 1356/1356 [04:28<00:00,  5.04it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.9.attention.self.key.weight', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.9.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.weight

In [41]:
import umap


chosen_bert_layer = 1
bert_results_path = os.path.join("bert_embeddings", f"bert_base_embedding_layer_{chosen_bert_layer}")

query_df = pd.read_csv(
    bert_results_path
)

query_df.dropna(inplace=True)

embedding_columns_list = [str(i) for i in list(range(768))]
contextual_embedding_array = query_df[embedding_columns_list].to_numpy()

reducer = umap.UMAP()
lower_dim_data = reducer.fit_transform(
    contextual_embedding_array,
)

In [42]:
lower_dim_data.shape

(43323, 2)

In [43]:
import matplotlib.pyplot as plt
from pylab import cm
import mplcursors

%matplotlib qt

word_list = list(query_df["word"].tolist())
all_labels = query_df["label"].tolist()
labels = list(set(all_labels))
labels.sort()
n_colors = len(labels)


#create new colormap
cmap = cm.get_cmap('tab20', n_colors)


print(n_colors)


fig, ax = plt.subplots(figsize=(10,10))

sc = plt.scatter(
    lower_dim_data[:,0], 
    lower_dim_data[:,1], 
    c=query_df["label_idx"].tolist(),
    cmap=cmap,
    s=1
)

# cursor
crs = mplcursors.cursor(ax,hover=True)
crs.connect(
    "add", 
    lambda sel: sel.annotation.set_text(
        f"{word_list[sel.target.index]}\n{all_labels[sel.target.index]}"
    ))
    

# colorbar
c_ticks = np.arange(n_colors) * (n_colors / (n_colors + 1)) + (2 / n_colors)
cbar = plt.colorbar(sc, ticks=c_ticks)
#cbar = plt.colorbar()

ticklabs = cbar.ax.get_yticklabels()
cbar.ax.set_yticklabels(labels, ha="right")
cbar.ax.yaxis.set_tick_params(pad=40)
plt.show()

16


In [None]:
labels

In [None]:
set(display_labels)

In [None]:
display_labels