In [1]:
import os
import json
import sys
import torch
import random
import numpy as np
import spacy
from tqdm import tqdm
import pickle as pk
import json
import csv
import requests
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
from xml.etree.ElementTree import SubElement
from xml.etree.ElementTree import ElementTree
from nltk.corpus import wordnet as wn
from nltk.corpus.reader import Synset
# Full spaCy English pipeline; handed as `nlp` to the dataset loaders and to GetEntityList in the cells below.
nlp = spacy.load('en_core_web_sm')

from spacy.lang.en import English
# Separate English pipeline that only gets a sentencizer added; handed as `nlp_sen` to the PIQA and ANLI loaders.
nlp_sen = English()
# punct_chars=None — presumably makes the sentencizer use its built-in default punctuation set; confirm against the spaCy docs.
config = {"punct_chars": None}
nlp_sen.add_pipe("sentencizer", config=config)

### Presetting of ESC

#### Download the ESC model checkpoint published in the official GitHub repository (https://github.com/SapienzaNLP/esc) via the link https://drive.google.com/file/d/100jxjLIdmSzrMXXOWgrPz93EG0JBnkfr/view and put the ckpt file in the `./ESC_Model` folder (the path the model-loading cell reads from).

In [2]:
import argparse
import time
from typing import NamedTuple, List, Optional, Tuple

# Point CUDA-aware tooling at the system CUDA installation.
os.environ['CUDA_PATH'] = '/usr/local/cuda'

# os.environ stores strings verbatim: shell-style "$VAR" references are NOT
# expanded, so the library path has to be assembled explicitly.
_cuda_lib_dir = os.path.join(os.environ['CUDA_PATH'], 'lib64')
_ld_entries = [entry for entry in os.environ.get('LD_LIBRARY_PATH', '').split(':') if entry]
# Prepend only when missing so that re-running the cell does not keep growing the variable.
if _cuda_lib_dir not in _ld_entries:
    _ld_entries.insert(0, _cuda_lib_dir)
# NOTE(review): changing LD_LIBRARY_PATH after the interpreter has started is
# normally only seen by child processes, not by this process's own dynamic
# loader — confirm whether this setting is actually required.
os.environ['LD_LIBRARY_PATH'] = ':'.join(_ld_entries)

# Make the working directory importable so that the local `esc` package resolves.
ROOT_DIR = os.path.abspath('./')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)


import torch
from torch.utils.data import DataLoader

from esc.utils.definitions_tokenizer import get_tokenizer
from esc.utils.wordnet import synset_from_offset
from esc.utils.wsd import WSDInstance
from esc.esc_dataset import WordNetDataset, OxfordDictionaryDataset
from esc.esc_pl_module import ESCModule
from esc.utils.commons import list_elems_in_dir
from esc.predict import InstancePredictionReport,ScoresReport,PredictionReport,probabilistic_prediction,precision_recall_f1_accuracy_score,predict,process_prediction_result

import ntpath
import logging

In [3]:
%load_ext autoreload
%autoreload 2

In [3]:
import ntpath
import logging

### parameter and model
# Settings for the ESC word-sense-disambiguation model, gathered in a single literal.
args = {
    'ckpt': './ESC_Model/escher_semcor_best.ckpt',
    'dataset_paths': 'Test.data.xml',
    'prediction_types': 'probabilistic',
    'evaluate': False,
    'device': 0,
    'tokens_per_batch': 4000,
    'output_errors': False,
    'oxford_test': False,
}

# Restore the model from the checkpoint and freeze it through its own freeze() method.
wsd_model = ESCModule.load_from_checkpoint(args['ckpt'])
wsd_model.freeze()

# The tokenizer is built from the checkpoint's hyper-parameters; the
# "use_special_tokens" flag falls back to False when the checkpoint lacks it.
transformer_model_name = wsd_model.hparams.transformer_model
use_special_tokens = getattr(wsd_model.hparams, "use_special_tokens", False)
tokenizer = get_tokenizer(transformer_model_name, use_special_tokens)

prediction_type = args['prediction_types']

# Reduce the dataset path to its stem by removing the known suffixes.
dataset_path = args['dataset_paths']
for known_suffix in (".data.xml", ".gold.key.txt"):
    dataset_path = dataset_path.replace(known_suffix, "")

### Presetting of Extraction

Set **debug** to `True` to test the *Labeler* on a small subset of the data (the first five examples of each dataset).  

**input_file** is initially `False`. After the first run, all the XML files are stored locally; set **input_file** to `True` to reuse the stored XML files directly.  

The **Log** folder contains the progress and error information written while the extraction runs.  

The **Final_Participant** folder stores the final results of the participant extraction.

In [4]:
from participant import *

In [6]:
# debug=True limits every dataset below to its first five examples.
debug = True

# Folders for the run logs and for the final extraction results.
log_file_name = "./Log"
final_file_name = "./Final_Participant"
for output_folder in (log_file_name, final_file_name):
    mkdir(output_folder)

New folder created
New folder created


### CODAH

Get dataset

In [7]:
# Load the CODAH examples from the TSV file.
file_name = '../../Target_task/codah.tsv'
dev_dataset = get_codah_dataset(file_name)

# In debug mode only the first five examples are kept.
dev_dataset = dev_dataset[:5] if debug else dev_dataset

100%|██████████| 2776/2776 [00:01<00:00, 2338.01it/s]


Extraction

In [7]:
# Extract the entities of every CODAH example and store them on the example as
# sample_data['entity']: four lists, filled by candidate index.
xml_file_name = "./XML/CODAH"
mkdir(xml_file_name)
# False on the first run; per the notes above, True reuses the XML files stored by an earlier run.
input_file=False
# Keep the log next to the other run logs instead of hard-coding the folder.
codah_log_path = os.path.join(log_file_name, "CODAH_LOG.txt")
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the bar can show a total.
for data_index, sample_data in enumerate(tqdm(dev_dataset)):
    try:
        with open(codah_log_path, "a") as file:
            file.write("prompting begin at {data_index}\n".format(data_index=data_index))

        # The first sentence of the first candidate is extracted once (sol_idx=-1)
        # and later added to every candidate's list.
        # NOTE(review): this assumes the first sentence is shared by all candidates — confirm.
        sen_idx = 0
        sen = sample_data['sentences'][0][0]
        sol_idx = -1
        file_name = os.path.join(xml_file_name, "{data_index}_{sen_idx}_{sol_idx}.xml".format(data_index=data_index,sen_idx=sen_idx,sol_idx=sol_idx))
        pre_entity = list(GetEntityList(nlp,sen,file_name=file_name,input_file=input_file))

        # One slot per candidate; the four slots are hard-coded.
        entity_list = [[],[],[],[]]
        for sol_idx, sen_list in enumerate(sample_data['sentences']):
            for sen_idx, sen in enumerate(sen_list):
                # Index 0 is already covered by pre_entity above.
                if sen_idx > 0:
                    file_name = os.path.join(xml_file_name, "{data_index}_{sen_idx}_{sol_idx}.xml".format(data_index=data_index,sen_idx=sen_idx,sol_idx=sol_idx))
                    entity_list[sol_idx] += list(GetEntityList(nlp,sen,file_name=file_name,input_file=input_file))
        for index in range(len(entity_list)):
            entity_list[index] += pre_entity
        sample_data['entity'] = entity_list
    except Exception as error:
        # Best effort: a failing example is logged and skipped (it then has no
        # 'entity' key). Only Exception is caught so that Ctrl-C / kernel
        # interrupts still stop the loop, and the error itself is recorded.
        with open(codah_log_path, "a") as file:
            file.write("wrong at {data_index}\n".format(data_index=data_index))
            file.write("    {error!r}\n".format(error=error))
                

Preprocess and save

In [9]:
# Run the CODAH-specific post-processing, then persist the result with NumPy
# (np.save appends ".npy" to the given path).
codah_output_path = "Final_Participant/CODAH_Participant"
dev_dataset = codah_post_preprocess(dev_dataset)
np.save(codah_output_path, dev_dataset)

5it [00:00, 47.44it/s]
5it [00:00, 33026.02it/s]
5it [00:00, 1246.67it/s]


### ROCStories

Get dataset

In [10]:
# Load the ROCStories examples from the CSV file.
file_name = "../../Target_task/rocstories.csv"
dev_dataset = get_roc_dataset(file_name)

# In debug mode only the first five examples are kept.
dev_dataset = dev_dataset[:5] if debug else dev_dataset

1872it [00:00, 120176.58it/s]


Extraction

In [11]:
# Extract the entities of every ROCStories example. For each example the
# entities of the story sentences are combined with those of each ending and
# stored as sample_data['entity_1'] / sample_data['entity_2'].
xml_file_name = "./XML/ROC"
mkdir(xml_file_name)
# False on the first run; per the notes above, True reuses the XML files stored by an earlier run.
input_file=False
# Keep the log next to the other run logs instead of hard-coding the folder.
roc_log_path = os.path.join(log_file_name, "ROC_LOG.txt")
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the bar can show a total.
for data_index, sample_data in enumerate(tqdm(dev_dataset)):
    try:
        with open(roc_log_path, "a") as file:
            file.write("prompting begin at {data_index}\n".format(data_index=data_index))

        # Entities of the story sentences (sol_idx=0), shared by both endings.
        context_entity = set()
        for sen_idx, sen in enumerate(sample_data['sentence']):
            sol_idx = 0
            file_name = os.path.join(xml_file_name, "{data_index}_{sen_idx}_{sol_idx}.xml".format(data_index=data_index,sen_idx=sen_idx,sol_idx=sol_idx))
            context_entity |= set(GetEntityList(nlp,sen,file_name=file_name,input_file=input_file))

        # Entities of each ending (sol_idx 1 and 2), unioned with the story entities.
        ending_entity = {}
        for sol_idx, ending_key in ((1, 'end1'), (2, 'end2')):
            sen = sample_data[ending_key]
            sen_idx = 0
            file_name = os.path.join(xml_file_name, "{data_index}_{sen_idx}_{sol_idx}.xml".format(data_index=data_index,sen_idx=sen_idx,sol_idx=sol_idx))
            ending_entity[sol_idx] = context_entity | set(GetEntityList(nlp,sen,file_name=file_name,input_file=input_file))

        # Written only after both endings succeeded, so an example never ends up half-filled.
        entity_1, entity_2 = ending_entity[1], ending_entity[2]
        sample_data['entity_1'] = entity_1
        sample_data['entity_2'] = entity_2
    except Exception as error:
        # Best effort: a failing example is logged and skipped (it then has no
        # entity keys). Only Exception is caught so that Ctrl-C / kernel
        # interrupts still stop the loop, and the error itself is recorded.
        with open(roc_log_path, "a") as file:
            file.write("wrong at {data_index}\n".format(data_index=data_index))
            file.write("    {error!r}\n".format(error=error))

New folder created


1it [00:18, 18.79s/it]

No synsets found in WordNet for instance d000.s000.t000. Skipping this instance
WSDInstance(annotated_token=AnnotatedToken(text='them', pos='NOUN', lemma='them'), labels=None, instance_id='d000.s000.t000')


5it [01:21, 16.32s/it]


Preprocess and save

In [12]:
# Run the shared post-processing, then persist the result with NumPy
# (np.save appends ".npy" to the given path).
roc_output_path = "./Final_Participant/ROC_Participant"
dev_dataset = post_preprocess(dev_dataset)
np.save(roc_output_path, dev_dataset)

5it [00:00, 65.97it/s]
5it [00:00, 14614.30it/s]
5it [00:00, 753.34it/s]


### PIQA

Get dataset

In [13]:
# Load the PIQA examples (inputs and labels live in separate files); both
# spaCy pipelines are handed to the loader.
input_name = '../../Target_task/piqa.jsonl'
label_name = '../../Target_task/piqa-labels.lst'
dev_dataset = get_piqa_dataset(input_name, label_name, nlp, nlp_sen)

# In debug mode only the first five examples are kept.
dev_dataset = dev_dataset[:5] if debug else dev_dataset

100%|██████████| 1838/1838 [00:00<00:00, 3392.62it/s]


Extraction

In [8]:
# Extract the entities of every PIQA example. The sentences of each
# goal+solution pair are processed separately and the resulting sets are
# stored as sample_data['entity_1'] / sample_data['entity_2'].
xml_file_name = "./XML/PIQA"
mkdir(xml_file_name)

# False on the first run; per the notes above, True reuses the XML files stored by an earlier run.
input_file=False
# Keep the log next to the other run logs instead of hard-coding the folder.
piqa_log_path = os.path.join(log_file_name, "PIQA_LOG.txt")
# Iterate over the whole dataset: the debug subset is already applied when the
# dataset is loaded, so an extra hard-coded [:5] here would silently leave every
# later example without entities on a full (debug=False) run.
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the bar can show a total.
for data_index, sample_data in enumerate(tqdm(dev_dataset)):
    try:
        with open(piqa_log_path, "a") as file:
            file.write("prompting begin at {data_index}\n".format(data_index=data_index))

        solution_entity = {}
        for sol_idx, sentences_key in ((1, 'goal_sol_1'), (2, 'goal_sol_2')):
            collected = set()
            for sen_idx, sen in enumerate(sample_data[sentences_key]):
                file_name = os.path.join(xml_file_name, "{data_index}_{sen_idx}_{sol_idx}.xml".format(data_index=data_index,sen_idx=sen_idx,sol_idx=sol_idx))
                collected |= set(GetEntityList(nlp,sen,file_name=file_name,input_file=input_file))
            solution_entity[sol_idx] = collected

        # Written only after both solutions succeeded, so an example never ends up half-filled.
        entity_1, entity_2 = solution_entity[1], solution_entity[2]
        sample_data['entity_1'] = entity_1
        sample_data['entity_2'] = entity_2
    except Exception as error:
        # Best effort: a failing example is logged and skipped (it then has no
        # entity keys). Only Exception is caught so that Ctrl-C / kernel
        # interrupts still stop the loop, and the error itself is recorded.
        with open(piqa_log_path, "a") as file:
            file.write("wrong at {data_index}\n".format(data_index=data_index))
            file.write("    {error!r}\n".format(error=error))

Preprocess and save

In [15]:
# Run the shared post-processing, then persist the result with NumPy
# (np.save appends ".npy" to the given path).
piqa_output_path = "./Final_Participant/PIQA_Participant"
dev_dataset = post_preprocess(dev_dataset)
np.save(piqa_output_path, dev_dataset)

5it [00:00, 60787.01it/s]
5it [00:00, 36345.79it/s]
5it [00:00, 538.42it/s]


### ANLI

Get dataset

In [16]:
# Load the ANLI examples (inputs and labels live in separate files); both
# spaCy pipelines are handed to the loader.
input_name = '../../Target_task/anli.jsonl'
label_name = '../../Target_task/anli-labels.lst'
dev_dataset = get_anli_dataset(input_name, label_name, nlp, nlp_sen)

# In debug mode only the first five examples are kept.
dev_dataset = dev_dataset[:5] if debug else dev_dataset

100%|██████████| 1532/1532 [00:00<00:00, 2976.75it/s]


Extraction

In [17]:
# Extract the entities of every ANLI example. For each example the entities of
# the two observations are combined with those of each hypothesis and stored
# as sample_data['entity_1'] / sample_data['entity_2'].
xml_file_name = "./XML/ANLI"
mkdir(xml_file_name)
# False on the first run; per the notes above, True reuses the XML files stored by an earlier run.
input_file=False
# Keep the log next to the other run logs instead of hard-coding the folder.
anli_log_path = os.path.join(log_file_name, "ANLI_LOG_sup.txt")
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the bar can show a total.
for data_index, sample_data in enumerate(tqdm(dev_dataset)):
    try:
        with open(anli_log_path, "a") as file:
            file.write("prompting begin at {data_index}\n".format(data_index=data_index))

        # Entities of the two observations (sol_idx=0), shared by both hypotheses.
        observation_entity = set()
        for sen_idx, sen in enumerate([sample_data['obs1'], sample_data['obs2']]):
            sol_idx = 0
            file_name = os.path.join(xml_file_name, "{data_index}_{sen_idx}_{sol_idx}.xml".format(data_index=data_index,sen_idx=sen_idx,sol_idx=sol_idx))
            observation_entity |= set(GetEntityList(nlp,sen,file_name=file_name,input_file=input_file))

        # Entities of each hypothesis (hyp1 -> sol_idx 1, hyp2 -> sol_idx 2),
        # unioned with the observation entities.
        hypothesis_entity = {}
        for sol_idx, hypothesis_key in ((1, 'hyp1'), (2, 'hyp2')):
            sen = sample_data[hypothesis_key]
            sen_idx = 0
            file_name = os.path.join(xml_file_name, "{data_index}_{sen_idx}_{sol_idx}.xml".format(data_index=data_index,sen_idx=sen_idx,sol_idx=sol_idx))
            hypothesis_entity[sol_idx] = observation_entity | set(GetEntityList(nlp,sen,file_name=file_name,input_file=input_file))

        # Written only after both hypotheses succeeded, so an example never ends up half-filled.
        entity_1, entity_2 = hypothesis_entity[1], hypothesis_entity[2]
        sample_data['entity_1'] = entity_1
        sample_data['entity_2'] = entity_2
    except Exception as error:
        # Best effort: a failing example is logged and skipped (it then has no
        # entity keys). Only Exception is caught so that Ctrl-C / kernel
        # interrupts still stop the loop, and the error itself is recorded.
        with open(anli_log_path, "a") as file:
            file.write("wrong at {data_index}\n".format(data_index=data_index))
            file.write("    {error!r}\n".format(error=error))

New folder created


5it [00:43,  8.65s/it]


Preprocess and save

In [18]:
# Run the shared post-processing, then persist the result with NumPy
# (np.save appends ".npy" to the given path).
anli_output_path = "./Final_Participant/ANLI_Participant"
dev_dataset = post_preprocess(dev_dataset)
np.save(anli_output_path, dev_dataset)

5it [00:00, 30750.03it/s]
5it [00:00, 47127.01it/s]
5it [00:00, 944.79it/s]
