In [1]:
import os
import tqdm
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from utils import check_accuracy_classification
import transformers
from torch.optim import Adam
from models import BertProbeClassifer
from utils import text_to_dataloader, tokenize_word
from bert_embedding import BertEmbeddingExtractor

In [2]:
#%load_ext autoreload
#%autoreload 2

In [3]:
# Locations of the train/dev/test CoNLL-U files of the en_partut UD treebank.
train_path, dev_path, test_path = (
    os.path.join("data", conllu_name)
    for conllu_name in (
        "en_partut-ud-train.conllu",
        "en_partut-ud-dev.conllu",
        "en_partut-ud-test.conllu",
    )
)

In [4]:
HEADER_CONST = "# sent_id = "
TEXT_CONST = "# text = "
STOP_CONST = "\n"
WORD_OFFSET = 1   # CoNLL-U column holding the word form
LABEL_OFFSET = 3  # CoNLL-U column holding the UPOS tag
COMMENT_PREFIX = "#"  # every CoNLL-U comment line starts with this


def txt_to_dataframe(data_path):
    '''
    Read a CoNLL-U (Universal Dependencies) file into a token-level DataFrame.

    Every token line of a sentence becomes one row with columns
    ``text`` (the sentence's "# text = " value, trailing newline included),
    ``word`` (tab column WORD_OFFSET) and ``label`` (tab column LABEL_OFFSET).
    A sentence ends at a blank line, at the next "# text = " line, or at end
    of file. Comment lines other than "# text = " are skipped. Token lines are
    not filtered further (e.g. multiword-token range lines are kept as-is).

    Parameters
    ----------
    data_path : str or os.PathLike
        Path to the .conllu file (read as UTF-8).

    Returns
    -------
    pandas.DataFrame
        Columns ["text", "word", "label"] in that order. The index restarts at
        0 for each sentence, i.e. it is the token's position in its sentence.
        An empty file yields an empty frame with the same columns.
    '''
    columns = ["text", "word", "label"]
    sentence_frames = []
    text = None
    words_list, labels_list = [], []

    def flush_sentence():
        # Emit the buffered sentence (if any), then reset the buffers so that
        # consecutive blank lines cannot emit the same sentence twice.
        nonlocal words_list, labels_list
        if words_list:
            sentence_frames.append(
                pd.DataFrame(
                    {
                        "text": len(words_list) * [text],
                        "word": words_list,
                        "label": labels_list,
                    }
                )
            )
        words_list, labels_list = [], []

    with open(data_path, "r", encoding="utf-8") as fp:
        for line in fp:
            if line.startswith(TEXT_CONST):
                # A new sentence starts; close any unterminated previous one.
                flush_sentence()
                text = line[len(TEXT_CONST):]
            elif line.startswith(COMMENT_PREFIX):
                # sent_id / newdoc / newpar / ... lines carry no tokens.
                continue
            elif line.strip():
                temp_list = line.split("\t")
                words_list.append(temp_list[WORD_OFFSET])
                labels_list.append(temp_list[LABEL_OFFSET])
            else:
                # Blank line: end of the current sentence.
                flush_sentence()
    # The file may not end with a blank line.
    flush_sentence()

    if not sentence_frames:
        return pd.DataFrame(columns=columns)
    # A single concat (instead of one per sentence) keeps this linear; the
    # per-sentence indices are preserved.
    return pd.concat(sentence_frames)[columns]
            


In [5]:
# Parse each split into a token-level frame with columns text / word / label.
df_train, df_dev, df_test = (
    txt_to_dataframe(split_path) for split_path in (train_path, dev_path, test_path)
)

In [6]:
# POS label inventory: the 17 UD UPOS tags in alphabetical order, followed by
# "_" (presumably CoNLL-U's "unspecified" placeholder, e.g. on multiword-token
# lines — confirm against the data). The order appears to match `label_idx` in
# the embedding frames (the sample output shows CCONJ=4, DET=5, VERB=15).
# Only referenced by the commented-out `display_df.index = TYPES` lines.
TYPES = [
    "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET",
    "INTJ", "NOUN", "NUM", "PART", "PRON", "PROPN",
    "PUNCT", "SCONJ", "SYM", "VERB", "X", "_",
]

In [7]:
file_name = 'tex_artifacts/label_dist_train.tex'
SORT_COL = "Count"

# Per-label token counts for the training split, as a LaTeX table.
# NOTE(review): the dev/test tables are sorted by "Type" while this one is
# sorted by "Count" (descending) — confirm which ordering the report wants.
display_df = (
    df_train["label"]
    .value_counts()
    .rename_axis("Type")
    .to_frame("Count")
    .reset_index()
    .sort_values(by=SORT_COL, ascending=False)
)
latex_data = display_df.to_latex(index=False)

# The artifacts directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as tf:
    tf.write(latex_data)

In [8]:
file_name = 'tex_artifacts/label_dist_dev.tex'

# Per-label token counts for the dev split, sorted by label name, as a LaTeX table.
display_df = (
    df_dev["label"]
    .value_counts()
    .rename_axis("Type")
    .to_frame("Count")
    .reset_index()
    .sort_values(by="Type")
)
latex_data = display_df.to_latex(index=False)

# The artifacts directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as tf:
    tf.write(latex_data)

In [9]:
file_name = 'tex_artifacts/label_dist_test.tex'

# Per-label token counts for the test split, sorted by label name, as a LaTeX table.
display_df = (
    df_test["label"]
    .value_counts()
    .rename_axis("Type")
    .to_frame("Count")
    .reset_index()
    .sort_values(by="Type")
)
latex_data = display_df.to_latex(index=False)

# The artifacts directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as tf:
    tf.write(latex_data)

In [10]:
# Spot-check: every test-split token labelled VERB.
df_test.loc[df_test["label"].eq("VERB")]

Unnamed: 0,text,word,label
8,Any use of the work other than as authorized u...,authorized,VERB
16,Any use of the work other than as authorized u...,prohibited,VERB
2,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...,AGREED,VERB
11,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...,OFFERS,VERB
16,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...,IS,VERB
...,...,...,...
12,"In the 18th and 19th centuries, his reputation...",spread,VERB
6,Only a small minority of academics believe the...,believe,VERB
8,Only a small minority of academics believe the...,is,VERB
11,Only a small minority of academics believe the...,question,VERB


In [11]:
# Tokenizer for the same "bert-base-uncased" checkpoint the embedding extractors load.
bert_tokenizer = transformers.BertTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

In [12]:
# Arguments for utils.text_to_dataloader, named instead of repeated as magic
# numbers. NOTE(review): presumably (device, batch size, max sequence length) —
# TODO confirm against text_to_dataloader's signature.
DEVICE = "cuda"
BATCH_SIZE = 32
MAX_SEQ_LEN = 256

# NOTE(review): df_train / df_test are overwritten with the converted frames, so
# re-running this cell alone feeds already-converted data back in.
df_train, dataloader_train = text_to_dataloader(df_train, DEVICE, BATCH_SIZE, bert_tokenizer, MAX_SEQ_LEN)
df_test, dataloader_test = text_to_dataloader(df_test, DEVICE, BATCH_SIZE, bert_tokenizer, MAX_SEQ_LEN)

In [14]:
file_name = 'tex_artifacts/tokens_per_word_dist_train.tex'

INDEX_AXIS_NAME = "Tokens/Word"
SORT_COL = "Tokens/Word"

# Distribution of sum(query_mask) per training row, as a LaTeX table.
# NOTE(review): assumes query_mask is a 0/1 sequence marking the word's sub-word
# tokens, so its sum is the number of BERT tokens per word — TODO confirm in
# utils.text_to_dataloader.
tokens_per_word = df_train["query_mask"].apply(sum)
display_df = (
    tokens_per_word
    .value_counts()
    .rename_axis(INDEX_AXIS_NAME)
    .to_frame("Count")
    .reset_index()
    .sort_values(by=SORT_COL)
)
latex_data = display_df.to_latex(index=False)

# The artifacts directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as tf:
    tf.write(latex_data)


In [21]:
# Peek at a few rows of the embedding frame from the num_hidden_layers = 9
# extractor; the fixed random_state makes the saved output reproducible.
# NOTE(review): embedding_df_9 is created by the extraction cell further down
# (In [13]); this cell only worked through out-of-order execution and raises
# NameError on Restart & Run All — move it below that cell.
embedding_df_9.sample(10, random_state=0)

Unnamed: 0,word,label_idx,embedding,label,embedding_shape
30137,and,4,"[0.15574118, 1.0346084, 0.5065768, 0.45339814,...",CCONJ,"(768,)"
17961,the,5,"[-0.28361148, -0.43603727, 0.08641514, 0.36970...",DET,"(768,)"
39003,shakespeare,11,"[0.06348628, -0.30637914, 0.92618424, -0.28851...",PROPN,"(768,)"
32187,his,5,"[-0.39322716, 0.4311395, 0.48470542, 0.2271323...",DET,"(768,)"
32078,had,15,"[-1.1276822, 0.06973337, -0.034231067, -0.3327...",VERB,"(768,)"
13290,it,10,"[-0.1944471, 0.23595017, -0.3196561, -0.541352...",PRON,"(768,)"
22870,economy,7,"[0.036572564, 0.16637069, 0.23336361, 0.303162...",NOUN,"(768,)"
3224,when,13,"[-1.1226387, 0.61780995, -0.021228101, 0.22293...",SCONJ,"(768,)"
19222,the,5,"[-0.17872518, 0.046215557, 0.056803722, 0.0912...",DET,"(768,)"
23037,in,1,"[-0.3867881, 0.8922424, -0.4306744, -0.8759147...",ADP,"(768,)"


In [13]:
# Extract per-word embeddings from bert-base-uncased with num_hidden_layers = 9,
# over the training dataloader.
# NOTE(review): presumably num_hidden_layers keeps the first N encoder layers and
# "sum" is the sub-word pooling mode — TODO confirm in bert_embedding.py.
# Per the sample output, the result has columns word, label_idx, embedding
# (shape (768,)), label, embedding_shape.
# Slow: about 6.7 minutes for the 1356 training batches in the saved run.
num_hidden_layers = 9


bex_9 = BertEmbeddingExtractor(num_hidden_layers, "bert-base-uncased")
embedding_df_9 = bex_9.extract_embedding(dataloader_train, "sum")

100%|██████████| 1356/1356 [06:41<00:00,  3.37it/s]


In [47]:
# Same extraction as for 9 layers, but with num_hidden_layers = 4.
# NOTE(review): "sum" is presumably the sub-word pooling mode — TODO confirm.
num_hidden_layers = 4


bex_4 = BertEmbeddingExtractor(num_hidden_layers, "bert-base-uncased")
# Must use bex_4: calling bex_9 here made embedding_df_4 a copy of the
# num_hidden_layers = 9 embeddings.
embedding_df_4 = bex_4.extract_embedding(dataloader_train, "sum")

100%|██████████| 1356/1356 [06:40<00:00,  3.38it/s]


In [77]:
import umap

# Embedding frame to project: embedding_df_4 or embedding_df_9.
query_df = embedding_df_9

# UMAP is stochastic; a fixed seed makes the projection (and the plot built from
# it) reproducible across Restart & Run All.
UMAP_SEED = 42
# NOTE(review): passing y= runs *supervised* UMAP — the POS labels shape the
# projection, so class separation in the plot is partly by construction rather
# than evidence of what the embeddings encode. Set to False for an unsupervised view.
SUPERVISED_UMAP = True

# Stack the per-row embedding vectors into one (n_rows, dim) array.
contextual_embedding_array = np.vstack(query_df["embedding"].values)

reducer = umap.UMAP(random_state=UMAP_SEED)
lower_dim_data = reducer.fit_transform(
    contextual_embedding_array,
    y=query_df["label_idx"].tolist() if SUPERVISED_UMAP else None,
)

In [81]:
import matplotlib.pyplot as plt
from pylab import cm
import mplcursors

%matplotlib qt

word_list = list(query_df["word"].tolist())
all_labels = query_df["label"].tolist()
labels = list(set(all_labels))
labels.sort()
n_colors = len(labels)


#create new colormap
cmap = cm.get_cmap('tab20', n_colors)


print(n_colors)


fig, ax = plt.subplots(figsize=(10,10))

sc = plt.scatter(
    lower_dim_data[:,0], 
    lower_dim_data[:,1], 
    c=query_df["label_idx"].tolist(),
    cmap=cmap,
    s=1
)

# cursor
crs = mplcursors.cursor(ax,hover=True)
crs.connect(
    "add", 
    lambda sel: sel.annotation.set_text(
        f"{word_list[sel.target.index]}\n{all_labels[sel.target.index]}"
    ))
    

# colorbar
c_ticks = np.arange(n_colors) * (n_colors / (n_colors + 1)) + (2 / n_colors)
cbar = plt.colorbar(sc, ticks=c_ticks)
#cbar = plt.colorbar()

ticklabs = cbar.ax.get_yticklabels()
cbar.ax.set_yticklabels(labels, ha="right")
cbar.ax.yaxis.set_tick_params(pad=40)
plt.show()

16


In [27]:
# Inspect the colour-bar label names built by the UMAP plotting cell.
# NOTE(review): the saved output below, (0, 1, ..., 15), is a tuple of ints, but
# the plotting cell sets `labels` to a list of label strings (e.g. "VERB"). The
# output is stale kernel state — re-run to refresh.
labels

(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)

In [None]:
# Distinct colour-bar label names.
# NOTE(review): this cell used `display_labels`, which is never defined in this
# notebook (NameError on a fresh kernel). The colour-bar tick labels built by
# the plotting cell live in `labels` — confirm that is what was meant.
set(labels)

In [None]:
# Colour-bar label names, in tick order.
# NOTE(review): this cell used `display_labels`, which is never defined in this
# notebook (NameError on a fresh kernel). The colour-bar tick labels built by
# the plotting cell live in `labels` — confirm that is what was meant, or delete
# this cell.
labels