In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append('../../../')
from collections import Counter
from typing import List, Tuple

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from datasets import load_dataset
from multitask_nlp.settings import DATASETS_DIR

# Hook tqdm into pandas so that the `progress_*` variants (e.g.
# `progress_apply`) are available on pandas objects.
tqdm.pandas()

# Location of the KPWr n82 files inside the project's datasets directory.
# NOTE(review): DATASETS_DIR is combined with `/`, so it is presumably a
# pathlib.Path - confirm in multitask_nlp.settings.
dataset_path = DATASETS_DIR / 'kpwr_n82'

In [2]:
def read_file(filepath: str) -> List[List[Tuple[str, List[str], List[str]]]]:
    """Parse a tab-separated IOB file into documents of tagged sentences.

    Lines starting with ``-DOCSTART`` begin a new document, blank lines end
    the current sentence, and every other line describes one token: the
    first column is the word form and the fourth column is the NER tag.

    Args:
        filepath: Location of the file; anything accepted by ``open`` works.

    Returns:
        One list per document. Each document is a list of
        ``(sentence, tokens, tags)`` tuples, where ``sentence`` is the tokens
        joined with single spaces and ``tokens``/``tags`` are parallel lists.
        Documents that contain no sentence are omitted.

    Raises:
        AssertionError: If a token line has fewer than four tab-separated
            columns.
    """
    all_documents_data = []
    document_data = []
    sentence_tokens = []
    tags = []

    # The context manager guarantees the handle is closed even when a
    # malformed line aborts parsing.
    with open(filepath, encoding='UTF-8') as f:
        for i, line in enumerate(f, 1):
            is_doc_start = line.startswith('-DOCSTART')

            if is_doc_start or not line.strip():
                # Both a blank line and a document marker end the sentence in
                # progress; flushing here keeps a sentence from leaking into
                # the next document when the blank separator is missing.
                if sentence_tokens:
                    sentence = ' '.join(sentence_tokens)
                    document_data.append((sentence, sentence_tokens, tags))
                    sentence_tokens = []
                    tags = []

                if is_doc_start and document_data:
                    all_documents_data.append(document_data)
                    document_data = []
                continue

            splits = line.split('\t')
            # The tag is read from the fourth column, so four columns are the
            # real minimum for a valid token line.
            assert len(splits) >= 4, (
                "error on line {}. Found {} splits, expected at least 4"
                .format(i, len(splits))
            )
            word, ner_tag = splits[0], splits[3]
            sentence_tokens.append(word.strip())
            tags.append(ner_tag.strip())

    # The file may end without a trailing blank line or document marker.
    if sentence_tokens:
        sentence = ' '.join(sentence_tokens)
        document_data.append((sentence, sentence_tokens, tags))

    if document_data:
        all_documents_data.append(document_data)

    return all_documents_data

# Parallel per-sentence containers filled from the parsed training file.
# NOTE(review): `text_ids` stays empty and `text_id` is only counted, never
# stored - confirm whether ids were meant to be appended here.
text_ids = []
texts = []
texts_tokens = []
tags = []

f_name = 'kpwr-ner-n82-train-tune.iob'

text_id = 1

all_documents_data = read_file(dataset_path / f_name)

# Flatten the document -> sentence nesting into the flat lists above.
for document_data in all_documents_data:
    for sentence_record in document_data:
        sentence, sentence_tokens, sentence_tags = sentence_record
        texts.append(sentence)
        texts_tokens.append(sentence_tokens)
        tags.append(sentence_tags)
        text_id += 1

In [7]:
# all_documents_data[1]

In [3]:
# Sanity check: every sentence carries exactly one tag per token.
assert all(
    len(sentence_tokens_) == len(sentence_tags_)
    for sentence_tokens_, sentence_tags_ in zip(texts_tokens, tags)
)

In [4]:
len(texts)

13959

In [5]:
# Flatten the per-sentence tag lists into one corpus-wide list of tags.
all_tags = [
    tag_
    for sentence_tags_ in tags
    for tag_ in sentence_tags_
]

In [6]:
len(all_tags)

227982

In [7]:
# Tally how often each IOB tag occurs across all tokens.
counter = Counter()
counter.update(all_tags)

In [8]:
counter.most_common()

[('O', 204746),
 ('B-nam_liv_person', 2911),
 ('I-nam_liv_person', 1880),
 ('I-nam_org_institution', 1470),
 ('B-nam_loc_gpe_city', 1342),
 ('B-nam_loc_gpe_country', 992),
 ('I-nam_org_organization', 850),
 ('B-nam_org_institution', 783),
 ('I-nam_pro_title_document', 731),
 ('B-nam_org_organization', 591),
 ('B-nam_org_group_team', 464),
 ('B-nam_adj_country', 453),
 ('I-nam_org_group_team', 422),
 ('I-nam_pro_title', 391),
 ('I-nam_eve_human', 355),
 ('B-nam_org_company', 324),
 ('I-nam_org_company', 300),
 ('B-nam_pro_media_periodic', 299),
 ('I-nam_fac_goe', 293),
 ('I-nam_pro_media_periodic', 286),
 ('B-nam_fac_road', 265),
 ('B-nam_liv_god', 257),
 ('I-nam_eve_human_sport', 241),
 ('B-nam_org_nation', 231),
 ('B-nam_oth_tech', 229),
 ('B-nam_pro_media_web', 227),
 ('B-nam_fac_goe', 212),
 ('B-nam_eve_human', 209),
 ('B-nam_pro_title', 207),
 ('B-nam_pro_brand', 205),
 ('I-nam_pro_model_car', 200),
 ('I-nam_pro_brand', 198),
 ('I-nam_loc_gpe_city', 193),
 ('B-nam_org_political_par

In [9]:
# Distinct IOB tags found in the corpus; np.unique returns them as a sorted
# array.
iob_unique_tags = np.unique(all_tags)

In [10]:
len(iob_unique_tags)

160

In [11]:
iob_unique_tags

array(['B-nam_adj', 'B-nam_adj_city', 'B-nam_adj_country',
       'B-nam_adj_person', 'B-nam_eve', 'B-nam_eve_human',
       'B-nam_eve_human_cultural', 'B-nam_eve_human_holiday',
       'B-nam_eve_human_sport', 'B-nam_fac_bridge', 'B-nam_fac_goe',
       'B-nam_fac_goe_stop', 'B-nam_fac_park', 'B-nam_fac_road',
       'B-nam_fac_square', 'B-nam_fac_system', 'B-nam_liv_animal',
       'B-nam_liv_character', 'B-nam_liv_god', 'B-nam_liv_habitant',
       'B-nam_liv_person', 'B-nam_loc', 'B-nam_loc_astronomical',
       'B-nam_loc_country_region', 'B-nam_loc_gpe_admin1',
       'B-nam_loc_gpe_admin2', 'B-nam_loc_gpe_admin3',
       'B-nam_loc_gpe_city', 'B-nam_loc_gpe_conurbation',
       'B-nam_loc_gpe_country', 'B-nam_loc_gpe_district',
       'B-nam_loc_gpe_subdivision', 'B-nam_loc_historical_region',
       'B-nam_loc_hydronym', 'B-nam_loc_hydronym_lake',
       'B-nam_loc_hydronym_ocean', 'B-nam_loc_hydronym_river',
       'B-nam_loc_hydronym_sea', 'B-nam_loc_land',
       'B-nam_loc

In [12]:
# Drop the IOB prefix ("B-" / "I-") so that only the entity type remains;
# tags without a hyphen (such as "O") are kept unchanged. Splitting only at
# the first hyphen keeps an entity type that itself contains a hyphen intact
# instead of cutting it off at that hyphen.
unique_tags = np.unique([t.split('-', 1)[-1] for t in iob_unique_tags])

In [13]:
unique_tags

array(['O', 'nam_adj', 'nam_adj_city', 'nam_adj_country',
       'nam_adj_person', 'nam_eve', 'nam_eve_human',
       'nam_eve_human_cultural', 'nam_eve_human_holiday',
       'nam_eve_human_sport', 'nam_fac_bridge', 'nam_fac_goe',
       'nam_fac_goe_stop', 'nam_fac_park', 'nam_fac_road',
       'nam_fac_square', 'nam_fac_system', 'nam_liv_animal',
       'nam_liv_character', 'nam_liv_god', 'nam_liv_habitant',
       'nam_liv_person', 'nam_loc', 'nam_loc_astronomical',
       'nam_loc_country_region', 'nam_loc_gpe_admin1',
       'nam_loc_gpe_admin2', 'nam_loc_gpe_admin3', 'nam_loc_gpe_city',
       'nam_loc_gpe_conurbation', 'nam_loc_gpe_country',
       'nam_loc_gpe_district', 'nam_loc_gpe_subdivision',
       'nam_loc_historical_region', 'nam_loc_hydronym',
       'nam_loc_hydronym_lake', 'nam_loc_hydronym_ocean',
       'nam_loc_hydronym_river', 'nam_loc_hydronym_sea', 'nam_loc_land',
       'nam_loc_land_continent', 'nam_loc_land_island',
       'nam_loc_land_mountain', 'nam_loc_

In [14]:
len(unique_tags)

83

In [25]:
# NOTE(review): this import sits in the middle of the notebook; it would
# normally live in the first cell next to `from collections import Counter`.
from collections import defaultdict

# Maps a position in the tag hierarchy (integer key) to the set of labels
# observed at that position; it is filled by the loop over `unique_tags`.
tag_positions_unique = defaultdict(set)

In [26]:
# Break every entity type into its underscore-separated parts and record each
# part under its hierarchy level.
for tag in unique_tags:
    for position, part in enumerate(tag.split('_')):
        if part == 'O':
            # The outside tag is stored under the position where it occurs.
            tag_positions_unique[position].add(part)
            continue

        if part == 'nam':
            # The 'nam' marker itself is not recorded as a label.
            continue

        # Remaining parts are shifted one level up; the part at index 1 gets
        # the 'nam_' prefix put back.
        label = 'nam_' + part if position == 1 else part
        tag_positions_unique[position - 1].add(label)

In [27]:
tag_positions_unique

defaultdict(set,
            {0: {'O',
              'nam_adj',
              'nam_eve',
              'nam_fac',
              'nam_liv',
              'nam_loc',
              'nam_num',
              'nam_org',
              'nam_oth',
              'nam_pro'},
             1: {'animal',
              'astronomical',
              'award',
              'brand',
              'bridge',
              'character',
              'city',
              'company',
              'country',
              'currency',
              'data',
              'god',
              'goe',
              'gpe',
              'group',
              'habitant',
              'historical',
              'house',
              'human',
              'hydronym',
              'institution',
              'land',
              'license',
              'media',
              'model',
              'nation',
              'organization',
              'park',
              'person',
              'phone',
    