In [1]:
!pip install -q transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m43.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m98.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m35.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:

import logging
logger = logging.getLogger('my_logger')

# Remove all handlers associated with the root logger object.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    filename='app.log', # write to this file
    filemode='a', # open in append mode
    format='%(name)s - %(levelname)s - %(message)s'
    )

logging.getLogger().setLevel(logging.INFO)
logging.getLogger().setLevel(logging.DEBUG)

logging.getLogger("urllib3.connectionpool").disabled = True
logging.getLogger("filelock").disabled = True


In [3]:
from collections import defaultdict, Counter
import re
import os,time
from operator import itemgetter
import random
import json
import pathlib
import subprocess
import sys

import tqdm



class ConllEntry:
    def __init__(self, id, form, lemma, pos, cpos, feats=None, parent_id=None, relation=None,
        deps=None, misc=None):

        self.id = id
        self.form = form
        self.cpos = cpos
        self.pos = pos
        self.parent_id = parent_id
        self.relation = relation

        self.lemma = lemma
        self.feats = feats
        self.deps = deps
        self.misc = misc

        self.pred_parent_id = None
        self.pred_relation = None

        self.pred_pos = None
        self.pred_cpos = None


    def __str__(self):
        '''values = [str(self.id), self.form, self.lemma, \
                  self.pred_cpos if self.pred_cpos else self.cpos,\
                  self.pred_pos if self.pred_pos else self.pos,\
                  self.feats, str(self.pred_parent_id) if self.pred_parent_id \
                  is not None else str(self.parent_id), self.pred_relation if\
                  self.pred_relation is not None else self.relation, \
                  self.deps, self.misc]
        return '\t'.join(['_' if v is None else v for v in values])'''
        return self.form + " " + str(self.id)

class ParseForest:
    def __init__(self, sentence):
        self.roots = list(sentence)

        for root in self.roots:
            root.children = []
            root.scores = None # TODO: зачем?
            root.parent = None
            root.pred_parent_id = None
            root.pred_relation = None
            root.vecs = None

    def __len__(self):
        return len(self.roots)


    def Attach(self, parent_index, child_index):
        parent = self.roots[parent_index]
        child = self.roots[child_index]

        child.pred_parent_id = parent.id
        del self.roots[child_index]

    def __str__(self):
        return " ".join(map(str, self.roots))


def isProj(sentence):
    forest = ParseForest(sentence)
    unassigned = {entry.id: sum([1 for pentry in sentence if pentry.parent_id == entry.id]) for entry in sentence}

    for _ in xrange(len(sentence)):
        for i in xrange(len(forest.roots) - 1):
            if forest.roots[i].parent_id == forest.roots[i+1].id and unassigned[forest.roots[i].id] == 0:
                unassigned[forest.roots[i+1].id]-=1
                forest.Attach(i+1, i)
                break
            if forest.roots[i+1].parent_id == forest.roots[i].id and unassigned[forest.roots[i+1].id] == 0:
                unassigned[forest.roots[i].id]-=1
                forest.Attach(i, i+1)
                break

    return len(forest.roots) == 1


def get_irels(data):
    """
    Collect frequencies of words, cpos, pos and deprels + languages.
    """

    # could use sets directly rather than counters for most of these,
    # but having the counts might be useful in the future or possibly for debugging etc
    relCount = Counter()

    for sentence in data:
        for node in sentence:
            if isinstance(node, ConllEntry):
                relCount.update([node.relation])

    return list(relCount.keys())


def generate_root_token():
    return ConllEntry(0, '*root*', '*root*', 'ROOT-POS', 'ROOT-CPOS', '_', -1,
        'rroot', '_', '_')


def read_conll(filename, drop_nproj=False, train=True):
    fh = open(filename,'r',encoding='utf-8')
    logging.info(f"Reading {filename}")
    ts = time.time()
    dropped = 0
    sents_read = 0
    sentences = []
    tokens = [generate_root_token()]
    words = [] # all words from the dataset
    for line in fh:
        tok = line.strip().split('\t')
        if not tok or line.strip() == '': # empty line, add sentence to list or yield
            if len(tokens) > 1:
                sents_read += 1
                conll_tokens = [t for t in tokens if isinstance(t,ConllEntry)]
                if not drop_nproj or isProj(conll_tokens):
                    # keep going if it's projective or we're not dropping non-projective sents
                    if train:
                        inorder_tokens = inorder(conll_tokens)
                        for i,t in enumerate(inorder_tokens):
                            t.projective_order = i
                        for tok in conll_tokens:
                            tok.rdeps = [i.id for i in conll_tokens if i.parent_id == tok.id]
                            if tok.id != 0:
                                tok.parent_entry = [i for i in conll_tokens if i.id == tok.parent_id][0]
                    sentences.append(tokens)
                else:
                    logging.debug('Non-projective sentence dropped')
                    dropped += 1
            tokens = [generate_root_token()]
        else:
            if line[0] == '#' or '-' in tok[0] or '.' in tok[0]: # a comment line, add to tokens as is
                tokens.append(line.strip())
            else: # an actual ConllEntry, add to tokens
                if tok[2] == "_":
                    tok[2] = tok[1].lower()
                lemma = tok[2]
                words.append(lemma)
                token = ConllEntry(int(tok[0]), tok[1], lemma, tok[4], tok[3], tok[5], \
                    int(tok[6]) if tok[6] != '_' else -1, tok[7], tok[8], tok[9])

                tokens.append(token)

# deal with case where there are still tokens, that aren`t in sentences list
# e.g. when there is no newline at end of file
    if len(tokens) > 1:
        sentences.append(tokens)

    logging.debug(f'{sents_read} sentences read')

    te = time.time()
    logging.info(f'Time: {te-ts:.2g}s')
    return sentences, words


def write_conll(fn, conll_gen):
    logging.info(f"Writing to {fn}")
    sents = 0
    with open(fn, 'w', encoding='utf-8') as fh:
        for sentence in conll_gen:
            sents += 1
            for entry in sentence[1:]:
                fh.write(str(entry) + '\n')
            fh.write('\n')
        logging.debug(f"Wrote {sents} sentences")


numberRegex = re.compile("[0-9]+|[0-9]+\\.[0-9]+|[0-9]+[0-9,]+");
def normalize(word):
    return 'NUM' if numberRegex.match(word) else word.lower()


def inorder(sentence):
    queue = [sentence[0]]
    def inorder_helper(sentence,i):
        results = []
        left_children = [entry for entry in sentence[:i] if entry.parent_id == i]
        for child in left_children:
            results += inorder_helper(sentence,child.id)
        results.append(sentence[i])

        right_children = [entry for entry in sentence[i:] if entry.parent_id == i ]
        for child in right_children:
            results += inorder_helper(sentence,child.id)
        return results
    return inorder_helper(sentence,queue[0].id)


def set_seeds():
    python_seed = 1
    logging.debug("Using default Python seed")
    random.seed(python_seed)


def generate_seed():
    return random.randint(0,10**9) # this range seems to work for Dynet and Python's random function


In [4]:
train_dir = 'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-train.conllu'
val_dir = 'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-dev.conllu'
test_dir = 'sample_data/UD_Russian-SynTagRus/ru_syntagrus-ud-test.conllu'

train, train_words = read_conll(train_dir)
val, val_words = read_conll(val_dir)
test, test_words = read_conll(test_dir)
all_words = train_words + val_words + test_words
all_words = set(all_words)

In [5]:
print(len(all_words))

516


In [6]:
import time
from tqdm import tqdm

In [7]:
from transformers import AutoTokenizer, BertModel
def get_embed(tokenizer, model, word):
    inputs = tokenizer(word, return_tensors="pt")
    outputs = model(**inputs)

    last_hidden_states = outputs.last_hidden_state[0][0]
    return last_hidden_states.detach().cpu()

embeds = {}
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

print(f'Creating {len(all_words)} embeddings')
logging.info(f'Creating {len(all_words)} embeddings')
ts = time.time()
for word in tqdm(all_words):
    embeds[word] = get_embed(tokenizer, model, word)
logging.debug(f'{len(embeds)} embeddings were created')
te = time.time()
logging.info(f'Time of embedding creation: {te-ts:.2g}s')

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Creating 516 embeddings


100%|██████████| 516/516 [01:15<00:00,  6.86it/s]


In [8]:
import pickle

with open('embeds_small.pickle', 'wb') as f:
    pickle.dump(embeds, f)

In [9]:
from sys import getsizeof

getsizeof(embeds)

18520

In [10]:
#with open('embeds.pickle', 'rb') as f:
#    data_new = pickle.load(f)