In [15]:
import os
import tqdm
import torch
import pandas as pd
import torch.nn as nn
from utils import check_accuracy_classification
import transformers
from torch.optim import Adam
from models import BertProbeClassifer
from utils import text_to_dataloader, tokenize_word

In [2]:
train_path = os.path.join("data","en_partut-ud-train.conllu")
dev_path = os.path.join("data","en_partut-ud-dev.conllu")
test_path = os.path.join("data","en_partut-ud-test.conllu")

In [3]:
HEADER_CONST = "# sent_id = "
TEXT_CONST = "# text = "
STOP_CONST = "\n"
WORD_OFFSET = 1
LABEL_OFFSET = 3


def txt_to_dataframe(data_path):
    '''
    read UD text file and convert to df format
    '''
    with open(data_path, "r") as fp:
        df = pd.DataFrame(
            columns={
                "text",
                "word",
                "label"
            }
        )
        for line in fp.readlines():
            if TEXT_CONST in line:
                words_list = []
                labels_list = []
                text = line.split(TEXT_CONST)[1]
                # this is a new text, need to parse all the words in it
            elif line is not STOP_CONST and HEADER_CONST not in line:
                temp_list = line.split("\t")
                words_list.append(temp_list[WORD_OFFSET])
                labels_list.append(temp_list[LABEL_OFFSET])
            if line == STOP_CONST:
                # this is the end of the text, adding to df
                cur_df = pd.DataFrame(
                    {
                        "text": len(words_list) * [text],
                        "word": words_list,
                        "label": labels_list
                    }
                )
                df = pd.concat([df,cur_df])
        return df
            


In [4]:
df_train = txt_to_dataframe(train_path)
df_dev = txt_to_dataframe(dev_path)
df_test = txt_to_dataframe(test_path)

In [5]:
df_train["label"].value_counts()

NOUN     9249
ADP      5220
PUNCT    5105
DET      4616
VERB     4126
ADJ      3410
AUX      2076
PROPN    2033
PRON     1734
ADV      1707
CCONJ    1472
PART     1168
NUM       787
SCONJ     627
X         140
SYM        42
_          27
INTJ        6
Name: label, dtype: int64

In [6]:
df_dev["label"].value_counts()

NOUN     568
PUNCT    353
ADP      297
VERB     276
DET      266
ADJ      210
PRON     153
AUX      124
ADV      108
PROPN    107
CCONJ     88
NUM       60
PART      56
SCONJ     41
X         13
SYM        2
_          1
Name: label, dtype: int64

In [7]:
df_test["label"].value_counts()

NOUN     753
ADP      488
DET      439
PUNCT    339
VERB     326
AUX      234
ADJ      224
ADV      131
PRON     106
CCONJ     96
PROPN     90
PART      66
NUM       61
SCONJ     51
_          4
INTJ       2
X          2
Name: label, dtype: int64

In [85]:
df_test[df_test["label"] == "VERB"]

Unnamed: 0,word,text,label,label_idx,text_ids,attn_mask,query_mask
8,authorized,Any use of the work other than as authorized u...,VERB,15,"[101, 2151, 2224, 1997, 1996, 2147, 2060, 2084...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
16,prohibited,Any use of the work other than as authorized u...,VERB,15,"[101, 2151, 2224, 1997, 1996, 2147, 2060, 2084...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,AGREED,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...,VERB,15,"[101, 4983, 4728, 3530, 2000, 2011, 1996, 4243...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
11,OFFERS,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...,VERB,15,"[101, 4983, 4728, 3530, 2000, 2011, 1996, 4243...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
16,IS,UNLESS OTHERWISE AGREED TO BY THE PARTIES IN W...,VERB,15,"[101, 4983, 4728, 3530, 2000, 2011, 1996, 4243...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...
12,spread,"In the 18th and 19th centuries, his reputation...",VERB,15,"[101, 1999, 1996, 4985, 1998, 3708, 4693, 1010...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
6,believe,Only a small minority of academics believe the...,VERB,15,"[101, 2069, 1037, 2235, 7162, 1997, 15032, 290...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."
8,is,Only a small minority of academics believe the...,VERB,15,"[101, 2069, 1037, 2235, 7162, 1997, 15032, 290...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
11,question,Only a small minority of academics believe the...,VERB,15,"[101, 2069, 1037, 2235, 7162, 1997, 15032, 290...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ..."


In [8]:
bert_tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

In [9]:
dataloader_train = text_to_dataloader(df_train, "cuda", 32, bert_tokenizer, 256)
dataloader_test = text_to_dataloader(df_test, "cuda", 32, bert_tokenizer, 256)

In [10]:
dataloader_train

<torch.utils.data.dataloader.DataLoader at 0x7f0f2357e940>

In [11]:
df_train["query_mask"].apply(lambda x: sum(x)).value_counts()
df_train["text_from sklearn.preprocessing import OneHotEncoderlen"] = df_train["attn_mask"].apply(lambda x: sum(x))

In [12]:
df_train.sample(10)

Unnamed: 0,word,text,label,label_idx,text_ids,attn_mask,query_mask,text_from sklearn.preprocessing import OneHotEncoderlen
4,press,Public opinion and the press nowadays accuse u...,NOUN,7,"[101, 2270, 5448, 1998, 1996, 2811, 13367, 269...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",33
21,drama,He wrote them in a stylised language that does...,NOUN,7,"[101, 2002, 2626, 2068, 1999, 1037, 2358, 8516...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",26
32,belief,this right includes freedom to change his reli...,NOUN,7,"[101, 2023, 2157, 2950, 4071, 2000, 2689, 2010...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...",45
8,Korea,"Over the last half-century, South Korea has ma...",PROPN,11,"[101, 2058, 1996, 2197, 2431, 1011, 2301, 1010...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...",39
17,to,(31) For the purpose of definitions used in th...,ADP,1,"[101, 1006, 2861, 1007, 2005, 1996, 3800, 1997...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",31
16,efficiency,(34) Since Council Directive 92/42/EEC of 21 M...,NOUN,7,"[101, 1006, 4090, 1007, 2144, 2473, 16449, 622...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",117
11,requirements,Those having to deal with these risks should t...,NOUN,7,"[101, 2216, 2383, 2000, 3066, 2007, 2122, 1083...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...",15
5,the,(4) Everyone has the right to form and to join...,DET,5,"[101, 1006, 1018, 1007, 3071, 2038, 1996, 2157...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...",22
10,issues,That is why I want to highlight some of the is...,NOUN,7,"[101, 2008, 2003, 2339, 1045, 2215, 2000, 1294...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...",23
5,the,Ukraine will stand up to the bully – on our ow...,DET,5,"[101, 5924, 2097, 3233, 2039, 2000, 1996, 2071...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...",16


In [None]:
df_train[df_train["query_mask"].apply(lambda x: sum(x)) == 0]

In [21]:
bert_model = transformers.BertModel.from_pretrained("bert-base-uncased")
bert_model.to("cuda")

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [29]:
labels = []
contextual_embeddings = []

with torch.no_grad():
    for batch in tqdm.tqdm(dataloader_train):
        text, mask, _, batch_labels = batch
        batch_contextual_embeddings, _ = bert_model(text, mask)
        
        # need to get the embeddings of only masked words
        
        print(batch_labels)
        batch_contextual_embeddings.reshape([-1, 768])

  0%|          | 1/1361 [00:00<13:27,  1.69it/s]

tensor([ 7,  1,  5,  7,  3,  9, 15,  5,  7, 12,  7,  7, 12, 11, 11, 15,  5,  7,
         1,  5, 12,  1, 12, 15, 12,  7, 12, 11, 11, 15,  5,  7],
       device='cuda:0')


  0%|          | 2/1361 [00:01<13:18,  1.70it/s]

tensor([15,  5,  7, 15, 12,  4, 15,  7,  1,  7, 15,  1,  5,  7, 12,  7, 12,  5,
         7,  3, 15,  1,  7,  4, 12,  4,  0,  0,  7, 12,  1, 15],
       device='cuda:0')


  0%|          | 3/1361 [00:01<13:17,  1.70it/s]

tensor([ 5,  7,  1,  5,  7, 15,  2, 12, 10, 15,  4, 15,  9,  3, 15,  1,  5,  7,
         1,  5,  7, 12,  5,  7, 15, 10,  5,  7, 15,  2,  1,  7],
       device='cuda:0')


  0%|          | 4/1361 [00:02<13:15,  1.71it/s]

tensor([ 1,  5,  7,  1,  5,  7,  4,  7, 12,  5,  7, 10, 15,  5,  0,  7,  3,  9,
         3, 15,  5,  0,  7, 12,  1, 15,  2, 12,  1,  5,  7,  1],
       device='cuda:0')


  0%|          | 5/1361 [00:02<13:13,  1.71it/s]

tensor([ 5,  7, 12, 12, 11, 12, 15,  5,  0,  7,  1,  7, 15,  1,  5,  7,  1,  5,
         7, 12,  7, 12, 11, 12,  8, 12,  0,  7,  7, 12,  8, 12],
       device='cuda:0')


  0%|          | 6/1361 [00:03<13:14,  1.71it/s]

tensor([ 7,  7, 12,  0,  1,  5,  7,  4,  7,  1,  5,  7, 12, 11,  2, 15, 10,  5,
         0, 12,  7, 12,  0, 12,  2, 12,  0, 12,  0, 12,  1,  5],
       device='cuda:0')


  1%|          | 7/1361 [00:04<13:15,  1.70it/s]

tensor([ 7,  1,  5,  0,  7, 12,  7,  9, 15,  5,  7,  1,  5,  7,  1, 15,  2, 12,
         9, 15,  5,  7, 12,  9, 15,  5,  7,  1,  8,  4,  2,  0],
       device='cuda:0')


  1%|          | 8/1361 [00:04<13:09,  1.71it/s]

tensor([ 7, 12,  4,  9, 15,  5,  7,  1, 15,  1,  5,  0,  7, 12,  9, 15,  4, 15,
         0,  7, 12,  9, 15,  7,  4,  7,  1, 12, 15,  2, 12, 15],
       device='cuda:0')


  1%|          | 8/1361 [00:05<14:48,  1.52it/s]


KeyboardInterrupt: 

In [39]:
np.allclose(
    batch_contextual_embeddings[0,255,:].detach().cpu().numpy(), 
    batch_contextual_embeddings[0,250,:].detach().cpu().numpy()
)

False

In [68]:
batch_contextual_embeddings.shape

torch.Size([32, 256, 768])

In [72]:
mask.reshape([32,256,1]).shape

torch.Size([32, 256, 1])

In [65]:
a = torch.nonzero(mask)
batch_contextual_embeddings[:,  ,:].shape
torch.mul(contextual_embeddings, mask)

torch.Size([32, 32, 256, 768])

In [79]:
torch.mul(batch_contextual_embeddings, mask.reshape([32,256,1])).nonzero()

tensor([[  0,   0,   0],
        [  0,   0,   1],
        [  0,   0,   2],
        ...,
        [ 31,  30, 765],
        [ 31,  30, 766],
        [ 31,  30, 767]], device='cuda:0')