In [None]:
import pandas as pd
import spacy

### Load data

In [None]:
# Language code: used to filter the dataset rows, passed to the pipeline,
# and embedded in the JSON / CSV file names.
lang = 'es'
# Dataset split name: selects ../Dataset/{Set}.csv and is embedded in the
# JSON / CSV file names.
Set = 'dev'
# Extractor choice: 'all' selects tag_extraction_from_tags, any other value
# selects tag_extraction_from_LM; also embedded in the output CSV name.
extractor = 'LM'

In [None]:
# Read the selected split, then show which language codes it contains.
data = pd.read_csv(f'../Dataset/{Set}.csv')
data['lang'].unique()

In [None]:
# Keep only the rows of the selected language. The explicit .copy() makes
# `data` an independent frame, so the later `data['augmented_sen'] = ...`
# column assignment does not write into a filtered slice of the loaded frame
# (which is what triggers pandas' SettingWithCopyWarning).
data = data[data['lang'] == lang].copy()

### Information Extraction Pipeline

In [None]:
# from transformers import AutoTokenizer, AutoModelForTokenClassification
# from transformers import pipeline
# tokenizer = AutoTokenizer.from_pretrained("Davlan/xlm-roberta-base-ner-hrl")
# model = AutoModelForTokenClassification.from_pretrained("Davlan/xlm-roberta-base-ner-hrl")
# nlp = pipeline("ner", model=model, tokenizer=tokenizer)


In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
# One name for the checkpoint id so the tokenizer and the model always match.
ner_checkpoint = "jplu/tf-xlm-r-ner-40-lang"

tokenizer = AutoTokenizer.from_pretrained(ner_checkpoint)
# from_tf=True: the weights are loaded from a TensorFlow checkpoint.
model = AutoModelForTokenClassification.from_pretrained(ner_checkpoint, from_tf=True)
# NER pipeline; tag_extraction_from_LM reads the 'entity', 'start' and 'end'
# fields of each prediction it returns.
nlp = pipeline("ner", model=model, tokenizer=tokenizer)

In [None]:
#import the pipeline class


from InformationExtraction import InformationExtractionPipeline


# Example extractor backed by the spaCy "en_core_web_lg" model.
NER = spacy.load("en_core_web_lg")


def tag_extraction_from_spacy(sen, model = NER):
    """Return the text of every PERSON, ORG or GPE entity found in `sen`.

    Args:
        sen: Sentence to annotate.
        model: Callable whose result exposes `.ents`; defaults to the
            module-level `NER` spaCy pipeline.

    Returns:
        List of entity strings, in the order of `model(sen).ents`.
    """
    wanted_labels = ('PERSON', 'ORG', 'GPE')
    names = []
    for ent in model(sen).ents:
        if ent.label_ in wanted_labels:
            names.append(ent.text)
    return names


def tag_extraction_from_LM(sen, model = nlp):
    """Extract entity mentions from `sen` with a token-classification pipeline.

    `model(sen)` must return a list of dicts carrying 'entity' (a BIO tag
    such as 'B-PER' / 'I-PER') and 'start' / 'end' character offsets into
    `sen`, which is how the module-level `nlp` pipeline is consumed here.

    A mention starts at every 'B' token and grows over the directly
    following 'I' tokens. Growth stops as soon as the next prediction is not
    an 'I' tag or is separated from the current span by anything other than
    whitespace. The adjacency check matters because the Hugging Face NER
    pipeline omits 'O' tokens from its output by default, so two consecutive
    list items are not necessarily neighbours in the sentence; without it a
    stray 'I' tag far away would stretch the mention across unrelated text.

    Args:
        sen: Sentence to annotate.
        model: Callable returning the token-level predictions described above.

    Returns:
        List of stripped substrings of `sen`, one per 'B' token, in order.
        'I' tokens that are not attached to a preceding 'B' are ignored.
    """
    ner_results = model(sen)
    extracted_names = []
    for idx, token in enumerate(ner_results):
        if token['entity'][0] != 'B':
            continue
        start, end = token['start'], token['end']
        for follower in ner_results[idx + 1:]:
            if follower['entity'][0] != 'I':
                break
            # Only whitespace may sit between the span and the next piece.
            if sen[end:follower['start']].strip():
                break
            end = follower['end']
        extracted_names.append(sen[start:end].strip())
    return extracted_names


# example extractor function that uses training labels
sent_to_tag = dict(zip(data['sent'],data['labels']))
def tag_extraction_from_tags(sent, sent_to_tag=sent_to_tag):
    """Extract entity mentions from `sent` using its label sequence.

    The sentence and its label string are both split on whitespace and
    paired token by token. A mention starts at every tag containing 'B-' and
    runs until an 'O' tag, the next 'B-' tag, or the end of the sentence.
    Closing the mention at the next 'B-' keeps back-to-back entities
    (e.g. 'B-PER I-PER B-LOC') separate instead of merging the second one
    into the first.

    Args:
        sent: Sentence to look up; must be a key of `sent_to_tag`.
        sent_to_tag: Mapping from sentence to its space-separated tag string.

    Returns:
        List of mentions (tokens joined by a single space) in sentence order.
        Tokens tagged with something other than 'O' that do not follow a
        'B-' tag are ignored.

    Raises:
        KeyError: if `sent` is not present in `sent_to_tag`.
    """
    tokens = sent.split()
    tags = sent_to_tag[sent].split()

    entity_list = []
    current = None  # tokens of the mention being built, or None outside one
    for token, tag in zip(tokens, tags):
        if 'B-' in tag:
            # A new mention begins; flush the one in progress first.
            if current is not None:
                entity_list.append(" ".join(current))
            current = [token]
        elif tag == 'O':
            if current is not None:
                entity_list.append(" ".join(current))
                current = None
        elif current is not None:
            current.append(token)

    if current is not None:
        entity_list.append(" ".join(current))

    return entity_list




In [None]:
# Create the pipeline object.
# Parameters:
#   extractor:    entity-extractor function that returns all the entities of a sentence
#   max_sen:      number of sentences to be added for each detected entity
#   lang:         language code; needed for the Wikipedia API
#   saveJson:     whether to save the extracted information as a JSON file
#                 (saves time if the pipeline has to be run again)
#   loadJson:     set it when a previously saved JSON file should be used
#   jsonPath:     path of the saved JSON file
#   saveJsonpath: presumably the path the JSON is saved to (same file as
#                 jsonPath here) -- confirm against InformationExtractionPipeline

entity_extractor = tag_extraction_from_tags if extractor == 'all' else tag_extraction_from_LM
wiki_json_path = f'wiki-info-{lang}-{Set}.json'

infoPipeline = InformationExtractionPipeline(
    extractor=entity_extractor,
    max_sen=2,
    lang=lang,
    loadJson=True,
    jsonPath=wiki_json_path,
    saveJson=True,
    saveJsonpath=wiki_json_path,
)

In [None]:
# Call the pipeline with the list of sentences as its argument.

sentences = data['sent'].values.tolist()
augmented = infoPipeline(sentences)

### Info Percentage

In [None]:
# Attach the pipeline output, then keep the rows whose sentence was changed.
data['augmented_sen'] = augmented
temp = data.loc[data['sent'].ne(data['augmented_sen'])]


In [None]:
# Fraction of sentences that differ from their augmented version.
changed_rows, total_rows = len(temp), len(data)
info_percent = changed_rows / total_rows
print(f"Info Percentage: {info_percent*100:.2f}%")

### Save Augmented Data

In [None]:
# Write the augmented frame to the dataset folder; the row index is not written.
output_csv = f'../Dataset/{Set}-wiki-{lang}-{extractor}.csv'
data.to_csv(output_csv, index=False)