In [13]:
import os
import pandas as pd

from typing import Tuple, Dict

from spacy import displacy
import matplotlib.pyplot as plt
import plotly.express as px

from tqdm import tqdm

In [14]:
DATASET_FOLDER = 'feedback-prize-2021'

def load_data_csv() -> pd.DataFrame:
    """Read the span-level annotation table ``DATASET_FOLDER/train.csv``.

    ``discourse_id`` is parsed as int64 and the character offsets
    ``discourse_start`` / ``discourse_end`` as the platform ``int``; pandas
    raises if any of those columns contain non-integer values.
    """
    csv_path = os.path.join(DATASET_FOLDER, 'train.csv')
    integer_columns = {'discourse_id': 'int64', 'discourse_start': int, 'discourse_end': int}
    return pd.read_csv(csv_path, dtype=integer_columns)

def load_file(file_id: str, folder: str = 'train', encoding: str = 'utf-8') -> str:
    """Return the raw text of one essay.

    Args:
        file_id: essay id; the file read is ``DATASET_FOLDER/<folder>/<file_id>.txt``.
        folder: dataset sub-folder, e.g. ``'train'`` or ``'test'``.
        encoding: text encoding of the essay files. Passed explicitly so the
            result does not depend on the platform's locale default (which is
            not UTF-8 on e.g. Windows). Essays are assumed to be UTF-8 — TODO
            confirm against the dataset.

    Returns:
        The full file contents as a string.
    """
    path = os.path.join(DATASET_FOLDER, folder, file_id + '.txt')
    with open(path, 'r', encoding=encoding) as f:
        return f.read()

def load_texts(folder: str = 'train', encoding: str = 'utf-8') -> pd.Series:
    """Load every ``*.txt`` essay in ``DATASET_FOLDER/<folder>`` into a Series.

    Args:
        folder: dataset sub-folder, e.g. ``'train'`` or ``'test'``.
        encoding: text encoding of the essay files (explicit so decoding does
            not depend on the platform locale; UTF-8 assumed — TODO confirm).

    Returns:
        Series indexed by essay id (file name without the ``.txt`` suffix),
        holding each essay's full text, ordered by id.
    """
    data_path = os.path.join(DATASET_FOLDER, folder)
    suffix = '.txt'

    def read(filename: str) -> str:
        with open(os.path.join(data_path, filename), 'r', encoding=encoding) as f:
            return f.read()

    # Only regular .txt files: stray entries such as `.ipynb_checkpoints/` or
    # `.DS_Store` would otherwise crash `open` or become bogus essay ids.
    # Sorted so the Series order does not depend on os.listdir's arbitrary order.
    filenames = sorted(
        name for name in os.listdir(data_path)
        if name.endswith(suffix) and os.path.isfile(os.path.join(data_path, name))
    )
    # Strip only the trailing suffix (str.replace would also remove '.txt' mid-name).
    # dtype=object keeps the Series well-typed even when the folder is empty.
    return pd.Series(
        {name[:-len(suffix)]: read(name) for name in tqdm(filenames)},
        dtype=object,
    )

def load_dataset() -> Tuple[pd.Series, pd.DataFrame]:
    """Load the training split.

    Returns:
        ``(texts, annotations)``: the essay texts from the ``train`` folder
        (via ``load_texts``) and the span annotations from ``train.csv``
        (via ``load_data_csv``), loaded in that order.
    """
    essays = load_texts()
    annotations = load_data_csv()
    return essays, annotations

In [15]:
# Credit for this visualisation approach: https://www.kaggle.com/thedrcat

# displaCy highlight colour for each discourse type. Values are passed to the
# HTML renderer as CSS colours, so every hex code needs its leading '#'.
COLORS = {
    'Lead': '#8000ff',
    'Position': '#2b7ff6',
    'Evidence': '#2adddd',
    'Claim': '#80ffb4',
    'Concluding Statement': '#d4dd80',
    'Counterclaim': '#ff8042',
    'Rebuttal': '#ff0000',
}

def visualize(id_example, texts, data):
    """Render one essay with its annotated discourse spans highlighted (displaCy).

    Args:
        id_example: essay id to render; used to index ``texts`` and to select
            rows of ``data`` whose ``id`` column matches.
        texts: mapping from essay id to full essay text (e.g. the Series
            returned by ``load_texts``).
        data: annotation frame with ``id``, ``discourse_start``,
            ``discourse_end`` and ``discourse_type`` columns (character offsets
            into the essay text).

    The legend lists every discourse type present anywhere in ``data``; output
    is rendered inline in the notebook (``jupyter=True``).
    """
    # displaCy's manual renderer walks spans left to right with a running
    # offset, so spans must be ordered by start; don't rely on CSV row order.
    essay_rows = data.loc[data['id'] == id_example].sort_values('discourse_start', kind='stable')
    ents = [
        {'start': int(start), 'end': int(end), 'label': label}
        for start, end, label in zip(
            essay_rows['discourse_start'],
            essay_rows['discourse_end'],
            essay_rows['discourse_type'],
        )
    ]
    doc = {'text': texts[id_example], 'ents': ents, 'title': id_example}
    options = {'ents': data.discourse_type.unique().tolist(), 'colors': COLORS}
    displacy.render(doc, style='ent', options=options, manual=True, jupyter=True)

In [16]:
def plot_histograms(values: Dict[str, int], title: str, x_label: str, y_label: str):
    """Show an interactive plotly bar chart of category -> count.

    Args:
        values: mapping whose keys become the x categories (in insertion order)
            and whose values become the bar heights.
        title: chart title, centred near the top of the figure.
        x_label: x-axis title.
        y_label: y-axis title.
    """
    categories = list(values.keys())
    counts = list(values.values())
    figure = px.bar(x=categories, y=counts)
    figure.update_xaxes(title=x_label)
    figure.update_yaxes(title=y_label)
    centred_title = {'text': title, 'y': 0.95, 'x': 0.5, 'xanchor': 'center', 'yanchor': 'top'}
    figure.update_layout(showlegend=True, title=centred_title)
    figure.show()

In [17]:
# Training essays keyed by id, plus one annotation row per discourse span.
texts, data = load_texts(), load_data_csv()

100%|██████████| 15594/15594 [00:01<00:00, 12711.45it/s]


In [18]:
# Unannotated essays from the test folder, keyed by id.
test_texts = load_texts(folder='test')

100%|██████████| 5/5 [00:00<?, ?it/s]


In [20]:
# Peek at the first annotation rows.
data.head(5)

Unnamed: 0,id,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,predictionstring
0,423A1CA112E2,1622627660524,8,229,Modern humans today are always on their phone....,Lead,Lead 1,1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 1...
1,423A1CA112E2,1622627653021,230,312,They are some really bad consequences when stu...,Position,Position 1,45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
2,423A1CA112E2,1622627671020,313,401,Some certain areas in the United States ban ph...,Evidence,Evidence 1,60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
3,423A1CA112E2,1622627696365,402,758,"When people have phones, they know about certa...",Evidence,Evidence 2,76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9...
4,423A1CA112E2,1622627759780,759,886,Driving is one of the way how to get around. P...,Claim,Claim 1,139 140 141 142 143 144 145 146 147 148 149 15...


In [19]:
# Number of annotated spans per discourse type, most frequent first.
values = data.discourse_type.value_counts().to_dict()
plot_histograms(values, title='Discourse Type Distribution', x_label='Classes', y_label='Number of Rows')

In [21]:
# Discourse types are unordered categories, so use a bar chart: the pandas
# default line plot would draw a misleading "trend" between unrelated classes.
# ascending=True + barh puts the most frequent class at the top.
ax = data.discourse_type.value_counts(ascending=True).plot.barh(title='Discourse Type Distribution')
ax.set_xlabel('Number of Rows')
ax.set_ylabel('Classes');

In [None]:
# Relative share of each discourse type (complements the absolute-count
# histogram instead of re-plotting it verbatim).
shares = data['discourse_type'].value_counts(normalize=True).mul(100).round(2).to_dict()
plot_histograms(shares, 'Discourse Type Share', 'Classes', '% of Rows')

In [7]:
# Collect a few essays containing a long stretch of text that no annotated
# discourse span covers — before the first span, between spans, or after the last.
MIN_UNLABELED_CHARS = 100  # a gap longer than this many characters counts
N_EXAMPLES = 6             # stop scanning once this many essays are found

ids_with_parts_no_label = []

for group_id, group in data.groupby(by='id'):
    essay = texts[group_id]
    # Sort by start so gaps are measured left to right.
    spans = group.sort_values('discourse_start', kind='stable')[['discourse_start', 'discourse_end']]
    prev_end = 0  # start at 0 (not -1): text[-1:start] would silently be empty
    gaps = []
    for start, end in spans.itertuples(index=False):
        gaps.append(essay[prev_end:int(start)])
        prev_end = max(prev_end, int(end))  # max() tolerates overlapping spans
    gaps.append(essay[prev_end:])  # trailing text after the last span
    if any(len(gap) > MIN_UNLABELED_CHARS for gap in gaps):
        ids_with_parts_no_label.append(group_id)
    if len(ids_with_parts_no_label) >= N_EXAMPLES:
        break
    

In [8]:
# Render the essays with long unlabeled stretches found above. To inspect a
# random essay instead, use data['id'].sample(n=1, random_state=42).values.tolist().
examples = ids_with_parts_no_label
for essay_id in examples:
    visualize(essay_id, texts, data)
    print('\n')























