# Neural German Open IE (NeuralGerOIE)

### Setup

In [3]:
# !pip install pandas



In [4]:
!pip install simpletransformers

Collecting simpletransformers
  Downloading simpletransformers-0.63.6-py3-none-any.whl (249 kB)
[K     |████████████████████████████████| 249 kB 9.1 MB/s 
[?25hCollecting transformers>=4.6.0
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 61.4 MB/s 
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[K     |████████████████████████████████| 43 kB 2.0 MB/s 
Collecting wandb>=0.10.32
  Downloading wandb-0.12.14-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 11.9 MB/s 
[?25hCollecting streamlit
  Downloading streamlit-1.8.1-py2.py3-none-any.whl (10.1 MB)
[K     |████████████████████████████████| 10.1 MB 41.4 MB/s 
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 48.9 MB/s 
[?25hCollecting datasets
  Downloading datasets-2.0.0-py3-none-any.whl (325

In [1]:
import os
import re
import json
import logging
import numpy as np
import pandas as pd
from time import time

from google.colab import files
from google.colab import drive

from sklearn.model_selection import train_test_split
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
%cd /content/drive/My\ Drive/Colab Notebooks

/content/drive/My Drive/Colab Notebooks


In [5]:
path = os.path.dirname(os.path.realpath('__file__'))
print(path)

/content/drive/My Drive/Colab Notebooks


### Load WikiData

In [6]:
df = pd.read_csv("wiki_data.tsv", sep="\t").rename(columns={"sentences":"input_text","labels":"target_text"})
print(df.shape)
df.sample(5)

(6453, 2)


Unnamed: 0,input_text,target_text
4632,Tahanea ist ein großes unbewohntes Atoll des T...,<arg1> Tahanea </arg1> <rel> Archipel </rel> <...
4857,Die Canadian Championship 2015 (offiziell: 201...,<arg1> Canadian Championship 2015 </arg1> <rel...
412,Das Erdbeben in Pakistan 2008 mit der Magnitud...,<arg1> Erdbeben in Pakistan 2008 </arg1> <rel>...
5837,"Das Mineral zeigt keine Spaltbarkeit, die Art ...",<arg1> Zinkgartrellit </arg1> <rel> Spaltbarke...
831,Die Haltestelle Wien Praterstern im 2. Wiener ...,<arg1> Haltestelle Wien Praterstern </arg1> <r...


### Data Prep

In [7]:
def change_args(label):
    label = label.replace("<arg1> ", "<sub> ")
    label = label.replace("<rel>" ," <rel>")
    label = label.replace("<arg2>", " <obj>")
    label = label.replace(" </rel> ", "")
    label = label.replace(" </arg1> " ,"")
    label = label.replace("</arg2>", "<end>")
    return label

def rem_end_middle(label):
    str1 = ' <end> <sub>'
    str2 = '<rel>'
    newstring = ''

    reg = "(?<=%s).*?(?=%s)" % (str1, str2)
    r = re.compile(reg, re.DOTALL)
    result = r.sub(newstring, label)
    result = result.replace("<end> ", "")
    result = result.replace("<sub><rel> ", "<rel> ")
    return result

def find_nth(haystack, needle, n):
    start = haystack.find(needle)
    while start >= 0 and n > 1:
        start = haystack.find(needle, start+len(needle))
        n -= 1
    return start

def rebuild_label(label, index):
    nth_rel = find_nth(label, "<rel>", 0)
    first = label[:nth_rel]
    middle = f"<rel{index}>"
    last = label[nth_rel+5:]
    rebuild = first + middle + last
    
    nth_obj = find_nth(rebuild, "<obj>", 0)
    first = rebuild[:nth_obj]
    middle = f"<obj{index}> "
    last = label[nth_obj+4:]
    rebuild = first + middle + last
    return rebuild

def relabel(label):
    simple_str = rem_end_middle(label)
    idx = 0
    while "<rel>" in simple_str:
        simple_str = rebuild_label(simple_str, idx)
        idx +=1
    return simple_str

In [8]:
df["target_text"] = df.target_text.apply(change_args)

In [9]:
df = df.groupby('input_text').agg({'target_text': ' '.join}).reset_index()
df["target_text"] = df.target_text.apply(relabel)
df.sample(5)

Unnamed: 0,input_text,target_text
5103,Vom Schreierkopf nach Süden senkt sich der Gra...,<sub> Schreierkopf <rel0> Scharte <obj0> Kreu...
2507,Die Pine Mountain Jump wurde seither oftmals r...,<sub> Pine Mountain Jump <rel0> renoviert <obj...
8,Auch mit den öffentlichen Verkehrsmitteln ist...,<sub> Obertürkheim <rel0> Bus <obj0> 61 <rel1...
4769,Selig sind die Dummen (engl. Yokel Chords) ist...,<sub> Selig sind die Dummen <rel0> Fernsehseri...
2861,Die heutige Gemeinde Zagori besteht aus 5 Geme...,<sub> Zagori <rel0> Gemeindebezirke <obj0> 5 ...


In [10]:
train_df, eval_df = train_test_split(df, test_size=0.1, random_state=42)
eval_df, test_df = train_test_split(eval_df, test_size=0.5, random_state=42)
print(train_df.shape)
print("---")
print(eval_df.shape)
print("---")
print(test_df.shape)

(4834, 2)
---
(269, 2)
---
(269, 2)


### Model Setup & Training

In [11]:
# Configure the model
model_args = Seq2SeqArgs()
model_args.num_train_epochs = 10
model_args.train_batch_size = 8
model_args.per_device_train_batch_size=8
model_args.per_device_eval_batch_size=8
model_args.evaluate_generated_text = True
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 10000
model_args.evaluate_during_training_verbose = True
model_args.overwrite_output_dir = True
model_args.save_best_model = True
model_args.save_steps = -1
model_args.evaluation_strategy='steps',
model_args.save_eval_checkpoints = False
model_args.save_model_every_epoch = False
model_args.no_cache = True
model_args.save_optimizer_and_scheduler = True
model_args.max_length = 300
model_args.adafactor_decay_rate = 0.7
model_args.process_count = 4
model_args.save_total_limit=3,
model_args.num_beams = 3
model_args.manual_seed = 42
model_args.load_best_model_at_end=True
model_args.n_gpu = 1 # 4

In [20]:
def count_matches(labels, preds):
    print(f"labels: {labels}")
    print(f"preds: {preds}")
    return sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])

def count_in_matches(labels, preds):
    return sum([1 if pred in label else 0 for label, pred in zip(labels, preds)])

def in_matches(labels, preds):
    return [pred for label, pred in zip(labels, preds) if pred in label]

def out_matches(labels, preds):
    return [pred for label, pred in zip(labels, preds) if pred not in label]

In [13]:
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

model = Seq2SeqModel(
    encoder_decoder_type="bart", 
    encoder_decoder_name="Shahm/bart-german",
    args=model_args,
    use_cuda=True, 
)

Downloading:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/532M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/353 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/780k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

In [None]:
%%time
#Train the model
model.train_model(train_df, eval_data=eval_df, matches=count_matches)

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/4834 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model: Training started


Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Running Epoch 0 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.3926900208872907, 'matches': 31}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/best_model


Running Epoch 1 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.32038135695106845, 'matches': 51}
INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/best_model


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 2 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.29620879143476486, 'matches': 49}
INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/best_model


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 3 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.31156429955187964, 'matches': 60}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 4 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.3165673552190556, 'matches': 66}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 5 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.3066762194913976, 'matches': 68}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 6 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.3164037825430141, 'matches': 60}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 7 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.33538585857433434, 'matches': 60}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 8 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.33745645052369905, 'matches': 64}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

Running Epoch 9 of 10:   0%|          | 0/605 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.34270044525756554, 'matches': 70}
INFO:simpletransformers.seq2seq.seq2seq_model:Saving model into outputs/


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

INFO:simpletransformers.seq2seq.seq2seq_model: Training of Shahm/bart-german model complete. Saved to outputs/.


CPU times: user 1h 48min 39s, sys: 35min 56s, total: 2h 24min 35s
Wall time: 2h 25min 29s


(6050,
 {'global_step': [605, 1210, 1815, 2420, 3025, 3630, 4235, 4840, 5445, 6050],
  'eval_loss': [0.3926900208872907,
   0.32038135695106845,
   0.29620879143476486,
   0.31156429955187964,
   0.3165673552190556,
   0.3066762194913976,
   0.3164037825430141,
   0.33538585857433434,
   0.33745645052369905,
   0.34270044525756554],
  'train_loss': [0.48322927951812744,
   0.1532779484987259,
   0.03874498978257179,
   0.03558674827218056,
   0.14402231574058533,
   0.2741043269634247,
   0.07369143515825272,
   0.04165903478860855,
   0.10961683094501495,
   0.06922011077404022],
  'matches': [31, 51, 49, 60, 66, 68, 60, 60, 64, 70]})

### Evaluate the model

In [None]:
result = model.eval_model(eval_df, matches=count_matches)
print(result)

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/269 [00:00<?, ?it/s]

Running Evaluation:   0%|          | 0/34 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model:{'eval_loss': 0.34270044525756554, 'matches': 70}


labels: ['<sub> Die Quereinsteigerinnen <rel0> per <obj0>  Christian Mrasek <end>', '<sub> The Transporter <rel0> Gang <obj0>  ja <end>', '<sub> San Junipero <rel0> Kanal <obj0>  Netflix <end>', '<sub> King William Island <rel0> Hauptort <obj0>  Gjoa Haven <end>', '<sub> Dar Bouazza <rel0> Region <obj0>  Casablanca-Settat <rel1> Hinterland <obj1>  Nouaceur <end>', '<sub> División de Honor de Rugby <rel0> Union <obj0>  Federación Española de Rugby <end>', '<sub> 9K37 Buk <rel0> Entwicklung <obj0>  1972 <end>', '<sub> Jalisco B <rel0> Rasse <obj0>  Selle Français <end>', '<sub> Einführungsgesetz zum Gerichtsverfassungsgesetz <rel0> Abkürzung <obj0>  EGGVG <rel1> Abkürzung <obj1>  GVGEG <end>', '<sub> Suribachi <rel0> Berg <obj0>  170 <rel1> Stelle <obj1>  Iwojima <end>', '<sub> Impulse Coaster <rel0> Hersteller <obj0>  Intamin <rel1> Kategorie <obj1>  Launched Coaster <rel2> Kategorie <obj2>  Shuttle Coaster <end>', '<sub> Schenker Storen <rel0> Wort für <obj0>  Schenker Storen AG <end>'

In [17]:
preds = model.predict(test_df.input_text.tolist())

Generating outputs:   0%|          | 0/34 [00:00<?, ?it/s]

In [18]:
count_in_matches(test_df.target_text.tolist(), preds) / test_df.shape[0]

0.3271375464684015

In [19]:
in_matches(test_df.target_text.tolist(), preds)

['<sub> Oxystannomikrolith <rel0> Spaltbarkeit <obj0>  keine Angaben <end>',
 '<sub> Musikkorps der Bundeswehr <rel0> Militär <obj0>  Bund <end>',
 '<sub> Kikuzuki <rel0> Indienststellung <obj0>  20. November 1926 <end>',
 '<sub> Walter H. Gale House <rel0> Architekt <obj0>  Frank Lloyd Wright <end>',
 '<sub> Li <rel0> Graph <obj0>  Li <end>',
 '<sub> Polasna <rel0> Status <obj0>  Siedlung städtischen Typs <end>',
 '<sub> The Weekly Standard <rel0> Staat <obj0>  Staaten <end>',
 '<sub> Sofijiwka-Park <rel0> angelegt <obj0>  1796 <end>',
 '<sub> The Intercept <rel0> online <obj0>  Februar 2014 <end>',
 '<sub> Condor Circuit <rel0> Ausgangspunkt <obj0>  Altos del Lircay <end>',
 '<sub> Tapti <rel0> Einzugsgebiet <obj0>  61575 <end>',
 '<sub> DRK <rel0> Funktion <obj0>  Humanitäres Völkerrecht <end>',
 '<sub> Gespanschaft Istrien <rel0> Städte <obj0>  10 <rel1> Gemeinden <obj1>  31 <end>',
 '<sub> You Don’t Know <rel0> Schrift <obj0>  Eminem <end>',
 '<sub> Space Pirate Captain Herlock: T

In [15]:
model_path = path + "/outputs/best_model"

In [14]:
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="outputs/best_model",
    args=model_args,
    use_cuda=False
)

In [16]:
# Use the model for prediction
print(
    model.predict(
        [
            'Die Spinnerei und Weberei am Sparrenlech Kahn & Arnold war eine deutsche Baumwollspinnerei und -weberei mit Sitz in Augsburg.'
        ]
    )
)

Generating outputs:   0%|          | 0/1 [00:00<?, ?it/s]

['<sub> Spinnerei und Weberei am Sparrenlech <rel0> Sitz <obj0>  Augsburg <end>']


In [21]:
out_matches(test_df.target_text.tolist(), preds)

['<sub> Suzhou Zhongnan Center <rel0> Innenstadt <obj0>  Suzhous <end>',
 '<sub> Lind ob Velden <rel0> Zählsprengel <obj0>  Lind-Sonnental <end>',
 '<sub> Bischof Genn <rel0> Weihbischof <obj0>  Ludger Schepers <end>',
 '<sub> Barnstable County <rel0> Stätten <obj0>  2 <end>',
 '<sub> ISO 19011 <rel0> ISO <obj0>  9001 <end>',
 '<sub> Daniel Hediger <rel0> Skiclub <obj0>  Bex <end>',
 '<sub> Kleinkainraths <rel0> Grundfläche <obj0>  157,4 Hektar <end>',
 '<sub> RTG Retail Trade Group <rel0> Mandat <obj0>  LEH Sortimente <rel1> Beteiligung <obj1>  Netto ApS & Co. KG <end>',
 '<sub> Atlantis Adventure <rel0> Gefälle <obj0>  72° <end>',
 '<sub> Direktion City <rel0> Sender <obj0>  Freies Berlin <end>',
 '<sub> Potsdam <rel0> Landeshauptstadt <obj0>  Brandenburg <end>',
 '<sub> Werkenntwen <rel0> Nutzer <obj0>  9,6 Millionen <end>',
 '<sub> Obertürkheim <rel0> Bus <obj0>  62 <end>',
 '<sub> Team Zagreb <rel0> Liga <obj0>  Slohokej Liga <end>',
 '<sub> Anatolia Story <rel0> Charakter <obj0> 