# Prepare environment

In [1]:
import os
import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from utils import format_result, load_raw_data

In [2]:
def format_labeling_df(filepath):
    """Parse a character-level labeling file into a single entity DataFrame.

    The file is read as UTF-8. Every non-blank line is split on whitespace;
    its first token is collected as a character and its second token as that
    character's label. A line consisting of only a newline closes the current
    article: the collected characters and labels are handed to
    ``format_result`` and the result becomes the rows of that article.

    Articles are numbered from 0 in file order and that number is stored in
    an ``article_id`` column inserted as the first column.

    Unlike the previous implementation, a final article that is not followed
    by a blank line is also emitted instead of being silently discarded.

    Parameters
    ----------
    filepath : str or path-like
        Location of the labeling file.

    Returns
    -------
    pandas.DataFrame
        The rows of all articles with a fresh 0..n-1 index, or an empty
        DataFrame when the file contains no article at all.

    Raises
    ------
    IndexError
        If a non-blank line has fewer than two whitespace-separated tokens.
    """
    frames = []
    chars = []
    labels = []
    count = 0

    with open(filepath, mode='r', encoding='utf-8') as f:
        for wl in f:
            if wl == '\n':
                # Blank separator line: close the current article.
                frames.append(_article_entity_frame(chars, labels, count))
                chars = []
                labels = []
                count += 1
            else:
                # str.split() already strips surrounding whitespace from tokens.
                fields = wl.split()
                chars.append(fields[0])
                labels.append(fields[1])

    # The file may end without a trailing blank line; keep that last article.
    if chars:
        frames.append(_article_entity_frame(chars, labels, count))

    if not frames:
        return pd.DataFrame()

    # One concat at the end instead of re-copying the result for every article.
    return pd.concat(frames, ignore_index=True)


def _article_entity_frame(chars, labels, article_id):
    """Build the rows of one article and tag them with ``article_id``."""
    # presumably format_result returns a list of entity records — verify in utils
    entity_result = format_result(chars, labels)
    df = pd.DataFrame(entity_result)
    df.insert(0, 'article_id', [article_id] * len(entity_result))
    return df

In [3]:
# Entity label names, in the order used for the per-label count lists and
# the chart axes built from them later in this notebook.
all_labels = [
    'name',
    'location',
    'time',
    'contact',
    'id',
    'profession',
    'biomarker',
    'family',
    'clinical_event',
    'special_skills',
    'unique_treatment',
    'account',
    'organization',
    'education',
    'money',
    'belonging_mark',
    'med_exam',
    'others',
]

# Check labels

In [25]:
# Earlier data source, kept for reference:
# traindf = format_labeling_df('data/TRAIN_FINAL')
# testdf = format_labeling_df('data/TEST_FINAL')

# Raw annotation files of the two training sets.
train_1_path = 'data/train_1/train_1.txt'
train_2_path = 'data/train_2/train_2_raw.txt'

# Each call is unpacked into the articles and the entity DataFrame of one set.
articles_1, df1 = load_raw_data(train_1_path)
articles_2, df2 = load_raw_data(train_2_path)

In [50]:
# Show every train 2 entity annotated as "contact".
contact_filter = 'entity_type == "contact"'
df2.query(contact_filter)

Unnamed: 0,article_id,start_position,end_position,entity_text,entity_type
1232,70,1025,1029,LINE,contact
1238,72,58,62,LINE,contact
1250,72,1646,1650,LINE,contact
1280,76,798,800,臉書,contact
1281,76,807,809,臉書,contact
1282,76,813,815,臉書,contact
1283,76,822,824,臉書,contact
1288,77,1528,1532,ＬＩＮＥ,contact
1289,77,1872,1876,ＬＩＮＥ,contact
1290,77,2042,2046,ＬＩＮＥ,contact


In [29]:
# Number of annotated entities per entity type in train 1.
df1.loc[:, 'entity_type'].value_counts()

time              1434
med_exam           220
name               169
location           161
money               78
family              25
contact             19
profession          13
ID                   8
clinical_event       5
education            3
organization         1
Name: entity_type, dtype: int64

In [30]:
# Entity-type counts for train 2, restricted to article ids below 120
# (presumably the articles shared with train 1 — TODO confirm).
below_120 = df2['article_id'] < 120
df2.loc[below_120, 'entity_type'].value_counts()

time              1358
med_exam           220
location           175
name               167
money               78
contact             30
family              25
profession          13
ID                   8
clinical_event       5
education            3
organization         1
Name: entity_type, dtype: int64

In [32]:
# train_labels_count = [traindf['entity_type'].value_counts().get(l, 0) for l in all_labels]
# test_labels_count = [testdf['entity_type'].value_counts().get(l, 0) for l in all_labels]
def _label_counts(df, labels):
    """Count the rows of ``df`` per entity type, in the order of ``labels``.

    Entity types are compared case-insensitively: the raw data spells one
    label as 'ID' while ``all_labels`` lists it as 'id', so a case-sensitive
    lookup would always report 0 for it. Labels that do not occur count as 0.
    """
    # Tally once per DataFrame instead of once per label.
    counts = df['entity_type'].str.lower().value_counts()
    return [counts.get(label.lower(), 0) for label in labels]


df1_lcount = _label_counts(df1, all_labels)
df2_lcount = _label_counts(df2, all_labels)

In [44]:
# One pie chart per training set, side by side in a single row.
pie_panels = [
    ('Label Counts For Train 1', df1_lcount),
    ('Label Counts For Train 2', df2_lcount),
]

fig = make_subplots(
    rows=1,
    cols=len(pie_panels),
    subplot_titles=[title for title, _ in pie_panels],
    # Pie traces need 'domain'-type subplot cells.
    specs=[[{'type': 'domain'} for _ in pie_panels]],
)

for col_no, (_, panel_counts) in enumerate(pie_panels, start=1):
    fig.add_trace(go.Pie(labels=all_labels, values=panel_counts), row=1, col=col_no)

fig.show()

In [None]:
# Grouped bar chart comparing the per-label counts of the two training sets.
# The previous version read train_labels_count / test_labels_count, which are
# only assigned in commented-out lines and therefore raised NameError on a
# fresh kernel; the chart now uses the counts that are actually computed.
# The figure gets its own name so that the `fig` uploaded further down is not
# replaced by this chart.
bar_fig = go.Figure(data=[
    go.Bar(x=all_labels, y=df1_lcount, name='train 1<br>labels',
           text=df1_lcount,
           textposition='outside'),
    go.Bar(x=all_labels, y=df2_lcount, name='train 2<br>labels',
           text=df2_lcount,
           textposition='outside')
])

bar_fig.show(config={'displayModeBar': True})

In [None]:
# NOTE(review): in the visible part of this notebook `testdf` is only assigned
# in a commented-out line, so this cell needs that assignment restored to run.
clinical_event_filter = 'entity_type == "clinical_event"'
testdf.query(clinical_event_filter)

In [None]:
# NOTE(review): in the visible part of this notebook `traindf` is only assigned
# in a commented-out line, so this cell needs that assignment restored to run.
rare_types_filter = 'entity_type == "clinical_event" or entity_type == "profession" or entity_type == "education"'
traindf.query(rare_types_filter)
# 2, 18, 10  (presumably the row counts observed for the three types — TODO confirm)

In [45]:
import chart_studio.tools as tls
import chart_studio.plotly as cplt

# Upload the current `fig` to Chart Studio.
# Credentials come from the environment so that no secret is stored in the
# notebook. An API key used to be hard-coded in this cell: treat that key as
# leaked and regenerate it in the Chart Studio account settings.
usr = os.environ.get('CHART_STUDIO_USERNAME', 'SharpKoi')
api_key = os.environ.get('CHART_STUDIO_API_KEY')
if not api_key:
    raise RuntimeError(
        'Set the CHART_STUDIO_API_KEY environment variable before uploading to Chart Studio.'
    )

# Note: set_credentials_file persists the credentials in the local Chart Studio
# configuration, not just for this session.
tls.set_credentials_file(username=usr, api_key=api_key)
cplt.plot(fig, file_name='Train Label Counts Diff', auto_open=False)

'https://plotly.com/~SharpKoi/28/'

# Check raw data

In [None]:
# Load train 2 again, this time from train_2.txt, for inspecting its raw text.
train_2_checked_path = 'data/train_2/train_2.txt'
articles, ldf = load_raw_data(train_2_checked_path)

In [None]:
# Look at positions 900 to 1099 of the article stored at index 64.
article_64 = articles[64]
article_64[900:1100]

In [None]:
# Show every entity of the loaded file annotated as "education".
education_filter = 'entity_type == "education"'
ldf.query(education_filter)