In [1]:
import pandas as pd
from ast import literal_eval

In [2]:
# Dataset splits to inspect, in order: the training CSV first, then the test CSV.
csvs = ['~/dataset/train/train.csv', '~/dataset/test/test_data.csv']

In [3]:
# Read every listed CSV into its own DataFrame; dfs[i] corresponds to csvs[i].
dfs = [pd.read_csv(csv_path) for csv_path in csvs]

In [4]:
def parse_entities(entities: pd.Series) -> pd.Series:
    """Convert a column of stringified entity literals into Python objects.

    Each element is evaluated with ``ast.literal_eval``, which only accepts
    Python literals (no arbitrary code is executed). Elements that are already
    dicts are returned unchanged, so applying this function to a column that
    has been parsed before is a no-op instead of an error; this keeps the
    parsing cell safe to re-run.

    Args:
        entities: Series whose elements are string representations of entity
            dicts, or dicts that were parsed earlier.

    Returns:
        A new Series, aligned with ``entities``, holding the parsed objects.

    Raises:
        ValueError, SyntaxError: propagated from ``literal_eval`` when a
            non-dict element is not a valid Python literal.
    """
    def _parse_one(entity):
        # Already parsed (e.g. the cell was executed twice): leave as is.
        if isinstance(entity, dict):
            return entity
        return literal_eval(entity)

    return entities.apply(_parse_one)

In [5]:
# Replace the stringified entity columns of every frame with parsed objects.
for df in dfs:
    for column in ('subject_entity', 'object_entity'):
        df[column] = parse_entities(df[column])
    

In [6]:
# An entity dict carries the keys: word, start_idx, end_idx, type.
def localize_entities(sentence: str, entity: dict) -> bool:
    """Check that an entity's recorded span matches its surface form.

    The slice of ``sentence`` from ``entity['start_idx']`` through
    ``entity['end_idx']`` (treated as inclusive, hence the ``+ 1``) is
    compared against ``entity['word']``.

    Returns:
        True when the sliced text equals the entity's word, False otherwise.
    """
    first = entity['start_idx']
    last = entity['end_idx']
    span_text = sentence[first:last + 1]
    return span_text == entity['word']

In [7]:
# Sanity check: in every frame, both entity columns must point at the exact
# text of their word inside the row's sentence (subject first, then object).
for df in dfs:
    for column in ('subject_entity', 'object_entity'):
        spans_match = df.apply(
            lambda row: localize_entities(row.sentence, row[column]),
            axis=1,
        )
        assert spans_match.all()

In [8]:
# For each dataset, print its path followed by the entity-type frequency
# tables of the subject and object columns, then a blank separator line.
for csv_path, df in zip(csvs, dfs):
    print(csv_path)
    for column in ('subject_entity', 'object_entity'):
        type_counts = df[column].apply(lambda entity: entity['type']).value_counts()
        print(type_counts)
    print()

~/dataset/train/train.csv
PER    16786
ORG    15684
Name: subject_entity, dtype: int64
PER    9788
ORG    9346
POH    5113
DAT    4249
LOC    3561
NOH     413
Name: object_entity, dtype: int64

~/dataset/test/test_data.csv
PER    3925
ORG    3839
LOC       1
Name: subject_entity, dtype: int64
POH    3171
LOC    1204
PER    1138
ORG    1047
DAT     790
NOH     415
Name: object_entity, dtype: int64

