In [1]:

# basics
import warnings
from pathlib import Path
from typing import List

import pandas as pd
from bs4 import BeautifulSoup as BS

# segnlp
from segnlp.datasets import PE
from segnlp.experiment import Experiment
from segnlp.models.lstm_crf import LSTM_CRF
from segnlp.utils import RangeDict
from segnlp.data import Sample
from segnlp.metrics import default_token_metric

Using backend: pytorch


In [2]:
def load_ecc_test_data(path_to_data:str) -> List[Sample]:
    """Load the ECC test QAs found under ``path_to_data`` as labelled sentence samples.

    Every ``*.xml`` file below ``path_to_data`` (searched recursively) is parsed
    with BeautifulSoup. The QA text is whatever follows the first ``"Question:"``
    marker in the ``<body>`` text, up to the next such marker if there is one.
    Each ``<segment>`` element is located in that text by a plain substring
    search, and the resulting character span is labelled with a running
    ``seg_id``. Only segment boundaries are labelled; the ``features``
    attribute of the segments is not read.

    Args:
        path_to_data: Directory that is searched recursively for ``*.xml`` files.

    Returns:
        The sentence-level pieces produced by ``Sample.split("sentence")`` for
        every file, concatenated in sorted file-path order. Empty if no xml
        file is found.

    Notes:
        * Segments whose text is empty or cannot be found in the QA text are
          skipped with a warning instead of being registered with a bogus span.
        * A segment text that occurs several times in the QA is always mapped
          to its first occurrence.
        * A file without a ``<body>`` element or without a ``"Question:"``
          marker still raises (AttributeError / IndexError), as before.
    """

    # find all xml files; sorted so that seg_ids and the order of the returned
    # samples do not depend on the filesystem's directory order
    all_test_qas = sorted(Path(path_to_data).rglob("*.xml"))

    # seg_id keeps counting across files, so ids are unique over the whole set
    seg_id = 0
    sentences = []
    for fp in all_test_qas:

        # Because of the data format we first read the xml file as a
        # string and then parse it with BeautifulSoup
        with open(fp, "r") as f:
            xml_str = f.read()

        # parse xml
        # NOTE(review): no parser is named, so bs4 picks whichever one is
        # installed -- consider pinning one once the intended parser is confirmed.
        soup = BS(xml_str)

        # find the QA body text
        QA = soup.find("body").text.split("Question:")[1]

        # init a range dict for our segments
        span_labels = RangeDict()

        # we find all the segments in the data
        for seg in soup.find_all("segment"):
            seg_text = seg.text

            # str.find returns -1 on a miss and 0 for an empty needle; both
            # would yield a meaningless span, so such segments are skipped
            start = QA.find(seg_text) if seg_text else -1
            if start == -1:
                warnings.warn(
                    f"Skipping a segment in {fp}: its text is empty or was not found in the QA text."
                )
                continue

            end = start + len(seg_text)

            span_labels[(start, end)] = {"seg_id": seg_id}
            seg_id += 1

        # we preprocess the QA
        s = Sample(QA)

        # label the QA
        s.add_span_labels(
                        span_labels,
                        task_labels = {"seg":["O","B","I"]}
                         )

        # split the QA into sentences
        sentences.extend(s.split("sentence"))

    return sentences

In [3]:
# Absolute, machine-specific location of the ECC test data; adjust it to
# wherever the xml files live locally.
ECC_DATA_DIR = "/home/axlalm/ecc"
ecc_sentences = load_ecc_test_data(ECC_DATA_DIR)

In [4]:
# PE dataset configured for the "seg" task only, with token-level
# predictions and one sample per sentence.
pe = PE(
    tasks=["seg"],
    prediction_level="token",
    sample_level="sentence",
)

In [5]:

# Experiment over the PE dataset; the metric is passed by name and a single
# random seed is requested.
exp = Experiment(
    id="pe_seg_sentence",
    dataset=pe,
    metric="default_token_metric",
    n_random_seeds=1,
)


In [6]:
# Hyperparameters for the LSTM_CRF run.
# NOTE(review): key spellings such as "paramater_dropout" (and the
# `hyperparamaters` keyword below) are kept exactly as written -- presumably
# they are the names segnlp expects; confirm before "correcting" them.
hps = {
    "general": {
        "optimizer": {
            "name": "SGD",
            "lr": 0.1,
        },
        "lr_scheduler": {
            "name": "ReduceLROnPlateau",
            "mode": "max",
            "factor": 0.5,
            "patience": 3,
            "min_lr": 0.0001,
        },
        "batch_size": 32,
        "max_epochs": 1,
        "patience": 5,
        "gradient_clip_val": 5.0,
    },
    "flair_embeddings": {
        "embs": "flair+bert+glove",
    },
    "linear_finetuner": {},
    "token_dropout": {
        "p": 0.05,
    },
    "paramater_dropout": {
        "p": 0.5,
    },
    "lstm": {
        "dropout": 0.5,
        "hidden_size": 256,
        "num_layers": 2,
        "bidir": True,
    },
    "crf": {},
}

# Train LSTM_CRF with the settings above, monitoring "f1".
exp.train(
    model=LSTM_CRF,
    hyperparamaters=hps,
    monitor_metric="f1",
)



Hyperparamaters:   0%|          | 0/1 [00:00<?, ?it/s]

Random Seeds: 0it [00:00, ?it/s]

 _______________  Val Scores  _______________
                          0
epoch                     0
split                   val
use_target_segs       False
seg_pretraining       False
seg-O-precision    0.941684
seg-O-recall       0.492415
seg-O-f1           0.641323
seg-B-precision    0.710876
seg-B-recall       0.767471
seg-B-f1           0.736306
seg-I-precision    0.802694
seg-I-recall       1.000918
seg-I-f1           0.889465
seg-precision      0.818418
seg-recall         0.753601
seg-f1             0.755698
precision          0.818418
recall             0.753601
f1                 0.755698
loss             138.313263
random_seed          710955


In [7]:
# Run the test step for LSTM_CRF, using "f1" as the monitored metric.
exp.test(model=LSTM_CRF, monitor_metric="f1")

In [8]:
# Put the experiment into prediction mode for LSTM_CRF before calling predict
# (presumably this loads the trained model -- verify against segnlp).
exp.prediction_mode(model=LSTM_CRF)

In [9]:
# Predict on the ECC sentences in batches of 32.
preds_sentence = exp.predict(
    list_samples=ecc_sentences,
    batch_size=32,
)

In [11]:
# calculate the metrics
# TEMPORARY SOLUTION: the per-sample frames are stitched together by hand and
# the task labels are read from a private attribute of the first prediction.
target_df = pd.concat([sample.df for sample in ecc_sentences])
pred_df = pd.concat([sample.df for sample in preds_sentence])

scores = default_token_metric(
    target_df=target_df,
    pred_df=pred_df,
    task_labels=preds_sentence[0]._task_labels,
)
scores

{'seg-O-precision': 0.6,
 'seg-O-recall': 0.4444444444444444,
 'seg-O-f1': 0.5106382978723405,
 'seg-B-precision': 0.2857142857142857,
 'seg-B-recall': 0.3333333333333333,
 'seg-B-f1': 0.30769230769230765,
 'seg-I-precision': 0.8697478991596639,
 'seg-I-recall': 0.92,
 'seg-I-f1': 0.8941684665226783,
 'seg-precision': 0.5851540616246499,
 'seg-recall': 0.5659259259259258,
 'seg-f1': 0.5708330240291088,
 'precision': 0.5851540616246499,
 'recall': 0.5659259259259258,
 'f1': 0.5708330240291088}

Unnamed: 0,id,sentence_token_id,char_start,char_end,str,pos,dephead,deprel,paragraph_sentence_id,document_paragraph_id,nr_paragraphs_doc,document_id,paragraph_id,document_sentence_id,sentence_id,paragraph_token_id,document_token_id,root_idx,seg_id,seg
25,25,0,98,101,but,CC,41,cc,2,0,2,0,0,2,2,25,25,0,-1.0,0
26,26,1,102,106,just,RB,29,advmod,2,0,2,0,0,2,2,26,26,0,-1.0,0
27,27,2,107,112,maybe,RB,29,advmod,2,0,2,0,0,2,2,27,27,0,-1.0,0
28,28,3,113,117,some,DT,29,det,2,0,2,0,0,2,2,28,28,0,-1.0,0
29,29,4,118,127,questions,NNS,41,advcl,2,0,2,0,0,2,2,29,29,0,-1.0,0
30,30,5,128,130,on,IN,29,prep,2,0,2,0,0,2,2,30,30,0,-1.0,0
31,31,6,131,134,the,DT,33,det,2,0,2,0,0,2,2,31,31,0,-1.0,0
32,32,7,135,139,down,JJ,33,amod,2,0,2,0,0,2,2,32,32,0,-1.0,0
33,33,8,140,144,time,NN,30,pobj,2,0,2,0,0,2,2,33,33,0,-1.0,0
34,34,9,145,147,to,IN,29,prep,2,0,2,0,0,2,2,34,34,0,-1.0,0


Unnamed: 0,id,sentence_token_id,char_start,char_end,str,pos,dephead,deprel,paragraph_sentence_id,document_paragraph_id,...,document_id,paragraph_id,document_sentence_id,sentence_id,paragraph_token_id,document_token_id,root_idx,seg_id,seg,target_id
66,66,0,277,280,but,CC,69,cc,3,0,...,0,0,3,3,66,66,0,-1,0.0,-1
67,67,1,281,285,also,RB,69,advmod,3,0,...,0,0,3,3,67,67,0,-1,0.0,-1
68,68,2,286,287,i,PRP,69,nsubj,3,0,...,0,0,3,3,68,68,0,0,1.0,-1
69,69,3,288,294,missed,VBD,63,ROOT,3,0,...,0,0,3,3,69,69,0,0,2.0,-1
70,70,4,295,296,a,DT,71,det,3,0,...,0,0,3,3,70,70,0,0,2.0,-1
71,71,5,297,306,component,NN,69,dobj,3,0,...,0,0,3,3,71,71,0,0,2.0,-1
72,72,6,307,312,where,WRB,74,advmod,3,0,...,0,0,3,3,72,72,0,0,2.0,-1
73,73,7,313,316,you,PRP,74,nsubj,3,0,...,0,0,3,3,73,73,0,0,2.0,-1
74,74,8,317,320,say,VBP,71,relcl,3,0,...,0,0,3,3,74,74,0,0,2.0,-1
75,75,9,321,323,in,IN,74,prep,3,0,...,0,0,3,3,75,75,0,0,2.0,-1
