In [1]:
from transformers import BertTokenizer, BertModel
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',num_hidden_layers=1)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'bert.encoder.layer.4.output.LayerNorm.bias', 'bert.encoder.layer.10.intermediate.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.4.attention.output.dense.bias', 'bert.encoder.layer.2.intermediate.dense.weight', 'bert.encoder.layer.2.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.2.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.1.attention.output.dense.bias', 'bert.encoder.layer.2.

In [3]:
import nltk
from nltk.corpus import semcor
from nltk.tokenize import word_tokenize
from nltk.stem import PorterStemmer
from nltk.stem import WordNetLemmatizer

from num2words import num2words
import numpy as np
import tqdm.notebook as tqdm
from string import punctuation
import math
from nltk.corpus import wordnet as wn


lemmatizer = WordNetLemmatizer()

EXTRA_STOPWORDS = ["''", "'s", "``"]
STOPWORDS = nltk.corpus.stopwords.words('english') + EXTRA_STOPWORDS
STOPWORDS += list(punctuation)


def treebank2wn(tag):
    if tag.startswith('J'):
        return wn.ADJ
    elif tag.startswith('V'):
        return wn.VERB
    elif tag.startswith('N'):
        return wn.NOUN
    elif tag.startswith('R'):
        return wn.ADV
    else:
        return None


def syn2sense(syn):
    # get the sense (= lemma.postag.num) for a given synset
    s = syn.name()
    # n = ".".join(s.split(".")[-2:]) # n.01 and v.01 are different senses (eg: ash.n.01, ash.v.01)
    return s


def lemmatize(word, tag):
    if tag is None:
        return lemmatizer.lemmatize(word)
    else:
        return lemmatizer.lemmatize(word, tag)


def num2Word(s):
    if s.isnumeric() and s.lower() != "infinity" and s.lower() != "nan":
        s = num2words(s)
    return s


def clean(tokens):
    tagged = nltk.pos_tag(tokens)
    lemmatized = [lemmatize(w, treebank2wn(tag)) for w, tag in tagged]
    cleaned = [num2Word(w) for w in lemmatized if w.lower() not in STOPWORDS]
    return cleaned


In [4]:
def parse(sent):
    tokens = []
    senses = []

    for e in sent:
        if isinstance(e, nltk.tree.Tree):
            lemma = e.label()
            if isinstance(lemma, nltk.corpus.reader.wordnet.Lemma):
                synset = lemma.synset()
                sense = syn2sense(synset)
            else:
                sense = None
            le = len(e)
            if le == 1:
                w = e[0]
                if isinstance(w, nltk.tree.Tree) or isinstance(w, list):
                    lw = len(w)
                    w = " ".join([w[i] for i in range(lw)])
            else:
                w = " ".join([e[i] for i in range(le)])
        elif isinstance(e, list):
            w = e[0]
            sense = None
        else:
            invtype = type(e)
            raise Exception("Invalid type: %s" % invtype)
        if w:
            tokens.append(w)
            senses.append(sense)
    return tokens, senses


In [5]:
semcor_data = semcor.tagged_sents(tag='sem')

X = []
y = []
words = []
for sent in semcor_data:
    try:
        tokens, senses = parse(sent)
        tagged_tokens = clean(tokens)

        for i in range(len(tokens)):
            context = tagged_tokens[max(
                0, i - 4):i] + tagged_tokens[i + 1:min(len(tagged_tokens), i + 5)]
            context_str = ' '.join(
                [w.lower() for w in context])
            X.append(context_str)
            y.append(senses[i])
    except Exception as e:
        print(e)


In [6]:
# save the data
import pickle
with open('X.pickle', 'wb') as f:
    pickle.dump(X, f)
with open('y.pickle', 'wb') as f:
    pickle.dump(y, f)

In [3]:
# use cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)


cuda


In [None]:
# load the data
import pickle
with open('X.pickle', 'rb') as f:
    X = pickle.load(f)
with open('y.pickle', 'rb') as f:
    y = pickle.load(f)

window_size = max([len(x.split()) for x in X])
X = X[:2000]
for i in range(len(X)):
    tokens = X[i].split()
    if len(tokens) < window_size:
        tokens += ['[PAD]']* (window_size - len(tokens))
    X[i] = ' '.join(tokens)

input_ids = [tokenizer.convert_tokens_to_ids(tokens.split()) for tokens in X]
attention_mask = [[int(token_id != tokenizer.pad_token_id) for token_id in input_ids] for input_ids in input_ids]

input_ids = torch.tensor(input_ids).to(device)
attention_mask = torch.tensor(attention_mask).to(device)
with torch.no_grad():
    output = model(input_ids, attention_mask=attention_mask)

In [9]:
label_dict = {label: idx for idx, label in enumerate(set(y[:1000]))}
y = [label_dict[label] for label in y[:1000]]
print(input_ids.shape)
print(len(y))

torch.Size([1000, 34])
1000


In [11]:
import numpy as np
from sklearn.naive_bayes import GaussianNB

# convert BERT output to numpy array
bert_output = output[0][:, 0, :].cpu().numpy()

# train-test split
split_idx = int(len(bert_output)*0.8)
train_x, test_x = bert_output[:split_idx], bert_output[split_idx:]
train_y, test_y = y[:split_idx], y[split_idx:]

# fit Naive Bayes model
clf = GaussianNB()
clf.fit(train_x, train_y)

# evaluate model
acc = clf.score(test_x, test_y)
print('Accuracy:', acc)


Accuracy: 0.325
