#### Set the general settings

In [288]:
RANDOMSTATE = 1234

In [289]:
MODEL = 'peregrine/t5_t5base/models/pytorch_model_semtag_base_split.bin'
ORIGINAL_MODEL = 't5-base'
CONFIG = 'configs/t5_base-config.json'
PREFIX = 'semtag'
CSV_LOC = 'peregrine/t5_t5base'

#### Imports

In [290]:
'''Partially based on https://github.com/MathewAlexander/T5_nlg'''

import sys
import glob
import os
import pandas as pd

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, set_seed

from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

from pathlib import Path
from datetime import datetime

from tqdm import tqdm

#### Set random state

In [291]:
torch.manual_seed(RANDOMSTATE)
set_seed(RANDOMSTATE)

#### Setup functions for generation and scores

In [292]:
def gpu_checker():
    '''Checking for the GPU availability.'''
    if torch.cuda.is_available():
        dev = torch.device("cuda:0")
        print("Running on the GPU")
    else:
        dev = torch.device("cpu")
        print("Running on the CPU")

    return dev

In [293]:
def generate(text):
    '''Generates a sentence given a specific
    pytorch_model.bin and returns it.
    '''
    extra_tokens = ['~', 'ø']
    # model = T5ForConditionalGeneration.from_pretrained('pytorch_model.bin', return_dict=True, config='t5-small-config.json')
    dev = gpu_checker()
    model = T5ForConditionalGeneration.from_pretrained(MODEL, return_dict=True, config=CONFIG)
    model.to(dev)
    
    tokenizer = T5Tokenizer.from_pretrained(ORIGINAL_MODEL, model_max_length=512)

    new_tokens = tokenizer.additional_special_tokens + extra_tokens
    tokenizer = T5Tokenizer.from_pretrained(ORIGINAL_MODEL, additional_special_tokens=new_tokens, model_max_length=512)
    # print('\nSpecial tokens: ', tokenizer.get_added_vocab())

    input_ids = tokenizer.encode("{}: {}".format(PREFIX, text), return_tensors="pt").to(dev)  # Batch size 1
    outputs = model.generate(input_ids, num_beams=10, max_length=500).to(dev)
    gen_text = tokenizer.decode(outputs[0]).replace('<pad>','').replace('</s>','')

    return gen_text.lstrip().rstrip().replace(' ~ ', '~').replace(' ø ', 'ø')

In [294]:
def missing_item_checker(cur_output: str, cur_input_token_list: list, cur_input_tag_list: list):
    '''Checks for missing items using multiple conditions and adds them as MISSING to the tags.'''
    if ';;' in cur_output:
        cur_output.replace(';;', '";" ;')
    elif '; ;' in cur_output:
        cur_output.replace('; ;', '";" ;')

    output_list = cur_output.split(';')

    output_list_tag = []
    output_list_token = []
    cntr = -1

    for item in output_list:
        cntr += 1
        if ':' in item:
            tag, token = item.split(':', 1)
               
            if '";"' in token:
                token.replace('";"', ';')
            output_list_tag.append(tag.lstrip().rstrip())
            output_list_token.append(token.lstrip().rstrip())
        else:
            if '";"' in token:
                token.replace('";"', ';')
            if item.lstrip().rstrip() in cur_input_token_list:
                output_list_tag.append('MISSING')
                output_list_token.append(item.lstrip().rstrip())
            elif item.lstrip().rstrip() in cur_input_tag_list:
                output_list_tag.append(item.lstrip().rstrip())
                output_list_token.append(cur_input_token_list[cntr])
            else:
                output_list_tag.append('MISSING')
                output_list_token.append('MISSING')
                
    cntr = -1
    if len(cur_input_token_list) != len(output_list_tag):
        # print('The amount of items is not the same, starting missing item operation.')
        # for token in tqdm(cur_input_token_list, desc='Looking for missing items'):
        for token in cur_input_token_list:
            cntr += 1
            if token not in output_list_token:
                output_list_tag.insert(cntr, 'MISSING')

    return output_list_token, output_list_tag

In [295]:
def batch_generator(df: pd.DataFrame):
    '''Generates the t5 output for each row in a given dataframe.'''
    df_output_strings = []

    for idx, row in tqdm(df.iterrows(), total=df.shape[0], desc='Generating T5 output for each row'):

        input_token_list = row['token']
        input_tag_list = row['semtag']
        input_token_str = ' '.join(input_token_list)
        cur_generation = generate(input_token_str)
        df_output_strings.append(cur_generation)

    df['generated_strings'] = df_output_strings

    return df

In [296]:
def taglen(df: pd.DataFrame, df_col):
    '''Returns the length for the given column in each row of a dataframe.'''
    len_list = [len(row[df_col]) for idx, row in df.iterrows()]
    return sum(len_list), len_list

In [297]:
def test_importer():
    '''Imports the train and test dataframes from the csv data.'''
    df = pd.read_csv('../Data/sem-pmb_4_0_0-gold.csv')

    grouped_sentences = df.groupby('sent_file').agg({'token': list, 'lemma': list, 'from': list, 'to': list, 'semtag': list}).reset_index()

    output_sentences = []
    for index, data in grouped_sentences.iterrows():
        current_sentence = []
        for word in data[1]:
            current_sentence.append('{}: '.format(data[-1][data[1].index(word)]))

            current_sentence.append(word)
            if word != data[1][-1]:
                current_sentence.append('; ')

        output_sentences.append(current_sentence)

    grouped_sentences['output'] = output_sentences
       
    df_train, df_test = train_test_split(grouped_sentences, test_size=0.2, random_state=RANDOMSTATE)
    
    return df_train, df_test


In [298]:
df_train, df_test = test_importer()

In [None]:
df_test.head()

Problematic output example index 14: 'DEF:ø; PER: Yunus ; EPS: founded ; DEF: the ; ORG: Grameen~Bank ; DIS:ø; 30 ; UOM: years ; PST: ago ; NIL:.'

30 has no tag

In [None]:
df_test = batch_generator(df_test)

In [149]:
output_path = f'test_outputs/{ORIGINAL_MODEL}'
Path(output_path).mkdir(parents=True, exist_ok=True)

now = datetime.now()
timestring = now.strftime('%Y-%m-%d_%H-%M-%S')

output_name = f'{output_path}/{timestring}-test_output_{ORIGINAL_MODEL}'

df_test.to_json(f'{output_name}.json')

In [262]:
df_test = pd.read_json(f'{output_name}.json')

In [263]:
df_test.drop(columns=['from', 'to', 'lemma'], inplace=True)

In [264]:
df_test.head()

Unnamed: 0,sent_file,token,semtag,output,generated_strings
10136,pmb-4.0.0/data/en/gold/p94/d1295/en.drs.xml,"[On, ø, October, 2, ,, 1942, ,, he, was, at, t...","[REL, DIS, MOY, DOM, EQU, YOC, NIL, PRO, EPS, ...","[REL: , On, ; , DIS: , ø, ; , MOY: , October, ...","REL: On ; DEF:ø; MOY: October ; DOM: 2 ; NIL:,..."
2673,pmb-4.0.0/data/en/gold/p14/d1574/en.drs.xml,"[What, kind, of, ø, American, accent, does, ø,...","[QUE, CON, REL, DIS, GPO, CON, NOW, DEF, PER, ...","[QUE: , What, ; , CON: , kind, ; , REL: , of, ...",QUE: What ; CON: kind ; REL: of ; DIS:ø; GPO: ...
9570,pmb-4.0.0/data/en/gold/p88/d1269/en.drs.xml,"[The, two, truck~drivers, were, arrested, .]","[DEF, QUC, ROL, PST, EXS, NIL]","[DEF: , The, ; , QUC: , two, ; , ROL: , truck~...",DEF: The ; QUC: two ; ROL: truck~drivers ; PST...
5608,pmb-4.0.0/data/en/gold/p42/d0760/en.drs.xml,"[ø, Hooper, bought, a, house, in, ø, Portland, .]","[DEF, PER, EPS, DIS, CON, REL, DEF, GPE, NIL]","[DEF: , ø, ; , PER: , Hooper, ; , EPS: , bough...",DEF:ø; PER: Hooper ; EPS: bought ; DIS: a ; CO...
9455,pmb-4.0.0/data/en/gold/p86/d2746/en.drs.xml,"[We, went, fishing, in, the, lake, .]","[PRO, PST, EXG, REL, DEF, CON, NIL]","[PRO: , We, ; , PST: , went, ; , EXG: , fishin...",PRO: We ; PST: went ; EXG: fishing ; REL: in ;...


In [265]:
df_test.iloc[115].generated_strings

'DEF:ø; CTC: http://www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.www.'

In [266]:
output_list_token_list = []
output_list_tag_list = []

for idx, row in tqdm(df_test.iterrows(), total=df_test.shape[0]):
    output_list_token = []
    output_list_tag = []

    output_list_token, output_list_tag = missing_item_checker(row['generated_strings'], row['token'], row['semtag'])
    
    output_list_token_list.append(output_list_token)
    output_list_tag_list.append(output_list_tag)

df_test['generated_tokens'] = output_list_token_list
df_test['generated_tags'] = output_list_tag_list

100%|██████████| 2143/2143 [00:00<00:00, 24365.46it/s]


In [267]:
df_test.head()

Unnamed: 0,sent_file,token,semtag,output,generated_strings,generated_tokens,generated_tags
10136,pmb-4.0.0/data/en/gold/p94/d1295/en.drs.xml,"[On, ø, October, 2, ,, 1942, ,, he, was, at, t...","[REL, DIS, MOY, DOM, EQU, YOC, NIL, PRO, EPS, ...","[REL: , On, ; , DIS: , ø, ; , MOY: , October, ...","REL: On ; DEF:ø; MOY: October ; DOM: 2 ; NIL:,...","[On, ø, October, 2, ,, 1942, ,, he, was, at, t...","[REL, DEF, MOY, DOM, NIL, YOC, NIL, PRO, EPS, ..."
2673,pmb-4.0.0/data/en/gold/p14/d1574/en.drs.xml,"[What, kind, of, ø, American, accent, does, ø,...","[QUE, CON, REL, DIS, GPO, CON, NOW, DEF, PER, ...","[QUE: , What, ; , CON: , kind, ; , REL: , of, ...",QUE: What ; CON: kind ; REL: of ; DIS:ø; GPO: ...,"[What, kind, of, ø, American, accent, does, ø,...","[QUE, CON, REL, DIS, GPO, CON, NOW, DEF, PER, ..."
9570,pmb-4.0.0/data/en/gold/p88/d1269/en.drs.xml,"[The, two, truck~drivers, were, arrested, .]","[DEF, QUC, ROL, PST, EXS, NIL]","[DEF: , The, ; , QUC: , two, ; , ROL: , truck~...",DEF: The ; QUC: two ; ROL: truck~drivers ; PST...,"[The, two, truck~drivers, were, arrested, .]","[DEF, QUC, ROL, PST, EXS, NIL]"
5608,pmb-4.0.0/data/en/gold/p42/d0760/en.drs.xml,"[ø, Hooper, bought, a, house, in, ø, Portland, .]","[DEF, PER, EPS, DIS, CON, REL, DEF, GPE, NIL]","[DEF: , ø, ; , PER: , Hooper, ; , EPS: , bough...",DEF:ø; PER: Hooper ; EPS: bought ; DIS: a ; CO...,"[ø, Hooper, bought, a, house, in, ø, Portland, .]","[DEF, PER, EPS, DIS, CON, REL, DEF, GPE, NIL]"
9455,pmb-4.0.0/data/en/gold/p86/d2746/en.drs.xml,"[We, went, fishing, in, the, lake, .]","[PRO, PST, EXG, REL, DEF, CON, NIL]","[PRO: , We, ; , PST: , went, ; , EXG: , fishin...",PRO: We ; PST: went ; EXG: fishing ; REL: in ;...,"[We, went, fishing, in, the, lake, .]","[PRO, PST, EXG, REL, DEF, CON, NIL]"


In [268]:
df_pred_len, df_pred_len_list = taglen(df_test, 'generated_tags')
df_gold_len, df_gold_len_list = taglen(df_test, 'semtag')

print(f'prediction_total_tokens: {df_pred_len}')
print(f'gold_total_tokens: {df_gold_len}')

prediction_total_tokens: 15751
gold_total_tokens: 15470


In [269]:
print(len(df_pred_len_list), len(df_gold_len_list))

2143 2143


In [270]:
if df_pred_len != df_gold_len:
    dis = [i for i in range(len(df_gold_len_list)) if df_pred_len_list[i] != df_gold_len_list[i]]
    # dis = []
    # for i in range(len(df_gold_len_list)):
    #     if df_pred_len_list[i] != df_gold_len_list[i]:
    #         dis.append(i)
        
print('The following sentence indices differ: {}'.format(', '.join(map(str, dis))))

The following sentence indices differ: 46, 91, 105, 170, 235, 256, 349, 482, 483, 497, 516, 517, 572, 584, 715, 726, 760, 807, 814, 896, 900, 923, 937, 1036, 1063, 1122, 1145, 1232, 1270, 1302, 1454, 1579, 1611, 1624, 1675, 1682, 1699, 1703, 1806, 1861, 1982, 2057, 2090, 2107, 2125, 2136


In [271]:
def comparer(idx):
    print('ID: {}\nGold({}): {}\nPred({}): {}\n'.format(idx, len(df_test.iloc[idx].semtag), df_test.iloc[idx].semtag, len(df_test.iloc[idx].generated_tags), df_test.iloc[idx].generated_tags))

In [272]:
pred_tags = df_test.generated_tags.to_list()

In [273]:
for i in dis:
    del pred_tags[i][len(df_test.iloc[i].semtag):]

In [274]:
df_test['generated_tags'] = pred_tags

In [275]:
df_pred_len, df_pred_len_list = taglen(df_test, 'generated_tags')
df_gold_len, df_gold_len_list = taglen(df_test, 'semtag')

print(f'prediction_total_tokens: {df_pred_len}')
print(f'gold_total_tokens: {df_gold_len}')

prediction_total_tokens: 15464
gold_total_tokens: 15470


In [276]:
if df_pred_len != df_gold_len:
    dis = [i for i in range(len(df_gold_len_list)) if df_pred_len_list[i] != df_gold_len_list[i]]
    # dis = []
    # for i in range(len(df_gold_len_list)):
    #     if df_pred_len_list[i] != df_gold_len_list[i]:
    #         dis.append(i)
        
print('The following sentence indices differ: {}'.format(', '.join(map(str, dis))))

The following sentence indices differ: 349, 497, 937, 1302, 1699, 1861


In [277]:
for i in dis[::-1]:
    df_test.drop(df_test.iloc[i].name, inplace=True)

In [279]:
df_pred_len, df_pred_len_list = taglen(df_test, 'generated_tags')
df_gold_len, df_gold_len_list = taglen(df_test, 'semtag')

print(f'prediction_total_tokens: {df_pred_len}')
print(f'gold_total_tokens: {df_gold_len}')

prediction_total_tokens: 15388
gold_total_tokens: 15388


In [280]:
if df_pred_len != df_gold_len:
    dis = [i for i in range(len(df_gold_len_list)) if df_pred_len_list[i] != df_gold_len_list[i]]
    # dis = []
    # for i in range(len(df_gold_len_list)):
    #     if df_pred_len_list[i] != df_gold_len_list[i]:
    #         dis.append(i)
        
print('The following sentence indices differ: {}'.format(', '.join(map(str, dis))))

The following sentence indices differ: 349, 497, 937, 1302, 1699, 1861


In [287]:
print(classification_report(df_test['semtag'].sum(), df_test['generated_tags'].sum()))

              precision    recall  f1-score   support

         ALT       0.98      0.89      0.93        47
         AND       0.96      0.74      0.84       104
         APX       0.94      0.89      0.91        18
         ART       0.36      0.45      0.40        11
         BOT       1.00      1.00      1.00         2
         BUT       1.00      0.89      0.94        19
         CLO       1.00      0.90      0.95        39
         COL       0.89      0.93      0.91        27
         CON       0.97      0.98      0.98      1760
         COO       0.81      0.85      0.83        26
         CTC       1.00      1.00      1.00         4
         DEF       0.93      0.96      0.94      1597
         DEG       0.75      0.90      0.82        60
         DIS       0.90      0.87      0.89       980
         DOM       1.00      0.70      0.82        10
         DOW       1.00      1.00      1.00         5
         DST       0.96      0.96      0.96        51
         EFS       0.62    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
