In [1]:
import os
import sys
import random

import numpy as np

sys.path.append('../src/')

from copy import deepcopy
from elmo_ner import ELMo_NER

In [2]:
def make_dataset(path, percent_test=0.2, target_mapping={}):
    X = []
    y = []
    for filename in sorted(os.listdir(path)):
        if filename.endswith('.txt'):
            filename_full = os.path.join(path, filename)
            with open(filename_full) as f:
                text = f.read().strip()
            span_dict = {}
            with open(filename_full[:-4] + '.spans') as f:
                span = f.read().strip()
                for line in span.split('\n'):
                    metadata, items = line.strip().split('  # ')
                    postag, offset, length = metadata.split()[1:4]
                    try:
                        real_postag = target_mapping[postag]
                    except KeyError:
                        continue
                    if real_postag in span_dict:
                        span_dict[real_postag] += [(int(offset), int(offset)+int(length))]
                    else:
                        span_dict[real_postag] = [(int(offset), int(offset)+int(length))]
            X += [text]
            y += [span_dict]
    X_y = list(zip(X, y))
    random.shuffle(X_y)
    n_test = int(len(X_y)*percent_test)
    X_train, y_train = list(zip(*X_y[:-n_test]))
    X_test, y_test = list(zip(*X_y[-n_test:]))
    return X_train, X_test, y_train, y_test

In [3]:
target_mapping = {'loc_name': 'ORG',
                  'org_name': 'LOC',
                  'name': 'PER',
                  'surname': 'PER'}
X_train, X_test, y_train, y_test = make_dataset('/home/olya/datasets/factRuEval-2016/devset/',
                                                target_mapping=target_mapping)

In [4]:
ner = ELMo_NER(elmo_hub_module_handle='http://files.deeppavlov.ai/deeppavlov_data/elmo_ru-news_wmt11-16_1.5M_steps.tar.gz')

In [None]:
ner.fit(X_train[:10], y_train[:10])

INFO:tensorflow:Using /tmp/tfhub_modules to cache modules.
INFO:tensorflow:Initialize variable module/aggregation/elmo_output_ELMo_W:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with aggregation/elmo_output_ELMo_W
INFO:tensorflow:Initialize variable module/aggregation/elmo_output_ELMo_gamma:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with aggregation/elmo_output_ELMo_gamma
INFO:tensorflow:Initialize variable module/bilm/CNN/W_cnn_0:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with bilm/CNN/W_cnn_0
INFO:tensorflow:Initialize variable module/bilm/CNN/W_cnn_1:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with bilm/CNN/W_cnn_1
INFO:tensorflow:Initialize variable module/bilm/CNN/W_cnn_2:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/varia

INFO:tensorflow:Initialize variable module/bilm/Variable:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with bilm/Variable
INFO:tensorflow:Initialize variable module/bilm/Variable_1:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with bilm/Variable_1
INFO:tensorflow:Initialize variable module/bilm/Variable_2:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with bilm/Variable_2
INFO:tensorflow:Initialize variable module/bilm/Variable_3:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with bilm/Variable_3
INFO:tensorflow:Initialize variable module/bilm/Variable_4:0 from checkpoint b'/tmp/tfhub_modules/48ab951bfe35732722bd465a2822752b390006c5/variables/variables' with bilm/Variable_4
INFO:tensorflow:Initialize variable module/bilm/Variable_5:0 from checkpoint b'/tmp/tfhub_modules/4

In [7]:
y_train[33]

{'PER': [(177, 184), (187, 194), (1304, 1310), (1311, 1320)],
 'ORG': [(253, 261), (264, 270)],
 'LOC': [(629, 633), (691, 695), (1200, 1204), (1331, 1335)]}