In [18]:
import re
import torch
import numpy as np
import pandas as pd
from transformers import BertTokenizer, BertModel, BertConfig, BertForTokenClassification, TextDataset
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, RandomSampler, TensorDataset, SequentialSampler
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.metrics import classification_report,f1_score, accuracy_score

In [19]:
device = "cuda"

In [20]:
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




In [21]:
model = BertForTokenClassification.from_pretrained("bert-base-multilingual-cased")
model.cuda()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=714314041.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at 

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [22]:
grobid_input = pd.read_csv("./../gold/grobid_hum.tsv", sep="\t", index_col=0)

In [23]:
grobid_sample_id = list(set(grobid_input["sample"]))

In [24]:
def create_bert_input(grobid_input, grobid_sample_id, max_len, pad_token_id, pad_class):

    input_samples = []
    input_masks = []
    input_labels = []
    for sample_id in tqdm(grobid_sample_id):

        sample = grobid_input[grobid_input["sample"] == sample_id]
        input_sample = []
        input_mask = []
        input_label = []
        for index, row in sample.iterrows():

            label = row["class"]
            token = row["text"]
            encoded = tokenizer.encode_plus(token)
            input_sample.append(encoded["input_ids"][1:-1])
            input_mask.append(encoded["attention_mask"][1:-1])
            input_label.append([label]*len(encoded["input_ids"][1:-1]))

        input_sample=[102]+[x for y in input_sample for x in y][:max_len-1]+[103]
        input_mask=[1]+[x for y in input_mask for x in y][:max_len-1]+[1]
        input_label=[0]+[x for y in input_label for x in y][:max_len-1]+[0]

        while len(input_sample) != max_len:
            input_sample.append(pad_token_id)
            input_mask.append(0)
            input_label.append(pad_class)

        input_samples.append(input_sample)
        input_masks.append(input_mask)
        input_labels.append(input_label)


    return torch.tensor(input_samples), torch.tensor(input_masks), torch.tensor(input_labels)

In [25]:
X,mask,Y = create_bert_input(grobid_input.fillna(""), grobid_sample_id, 512, 0, 0)

HBox(children=(FloatProgress(value=0.0, max=6818.0), HTML(value='')))




In [17]:
tokenizer.decode(X[10][5:13])

'T. Sterling, J. Green'

In [26]:
train_dataset = TensorDataset(X[:6816],mask[:6816],Y[:6816])
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=8)

In [27]:
epochs = 10
total_steps = len(train_dataloader) * epochs
optimizer = AdamW(model.parameters(),lr = 2e-5, eps = 1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps = total_steps)

In [28]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds.reshape([512*8,2]), axis=1)
    labels_flat = labels.flatten()
    return pred_flat, labels_flat

In [29]:
loss_values = []
for epoch_i in range(0, epochs):
 ## Iteration über Epochen

    total_loss = 0 
    model.train() # Trainiert nicht, sondern ändert das Modell .train()==Trainingsmodus
    print("Epoch: "+str(epoch_i))
    i = 0
    for step, batch in enumerate(train_dataloader):
     ## Iteration über batches

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
          ## Batches auf GPU/CPU schieben

        model.zero_grad()
          ## Gradienten löschen für nächste Batch
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
          ## forward pass

        loss = outputs[0] 
        total_loss += loss.item()
          # pull out loss for printing

        loss.backward()
          ## Backward pass 
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
          ## Regularisierung

        optimizer.step() 
        scheduler.step()
          # Upadte optimizer + scheduler


        avg_train_loss = total_loss / (step+1) 
        logits = outputs[1]
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        p,t = flat_accuracy(logits, label_ids)

        if i % 50 == 0:
            print("Epoch:"+str(epoch_i)+" Batch "+str(step)+"/"+str(len(train_dataloader))+" Loss: "+str(i)+" "+str(avg_train_loss)[:5]+"    F1: "+str(f1_score(t, p, average="macro"))[:4])
        i+=1
        if i >= 850:
            break

Epoch: 0
Epoch:0 Batch 0/852 Loss: 0 0.697    F1: 0.37
Epoch:0 Batch 50/852 Loss: 50 0.090    F1: 0.72
Epoch:0 Batch 100/852 Loss: 100 0.062    F1: 0.61
Epoch:0 Batch 150/852 Loss: 150 0.050    F1: 0.63
Epoch:0 Batch 200/852 Loss: 200 0.040    F1: 0.65
Epoch:0 Batch 250/852 Loss: 250 0.034    F1: 0.60
Epoch:0 Batch 300/852 Loss: 300 0.029    F1: 0.61
Epoch:0 Batch 350/852 Loss: 350 0.027    F1: 0.65
Epoch:0 Batch 400/852 Loss: 400 0.025    F1: 0.58
Epoch:0 Batch 450/852 Loss: 450 0.024    F1: 0.60
Epoch:0 Batch 500/852 Loss: 500 0.022    F1: 0.70
Epoch:0 Batch 550/852 Loss: 550 0.021    F1: 0.59
Epoch:0 Batch 600/852 Loss: 600 0.020    F1: 0.65
Epoch:0 Batch 650/852 Loss: 650 0.019    F1: 0.60
Epoch:0 Batch 700/852 Loss: 700 0.019    F1: 0.65
Epoch:0 Batch 750/852 Loss: 750 0.018    F1: 0.62
Epoch:0 Batch 800/852 Loss: 800 0.018    F1: 0.60
Epoch: 1
Epoch:1 Batch 0/852 Loss: 0 0.000    F1: 0.60
Epoch:1 Batch 50/852 Loss: 50 0.004    F1: 0.68
Epoch:1 Batch 100/852 Loss: 100 0.004    F1:

Epoch:9 Batch 550/852 Loss: 550 0.000    F1: 0.70
Epoch:9 Batch 600/852 Loss: 600 0.000    F1: 0.64
Epoch:9 Batch 650/852 Loss: 650 0.000    F1: 0.63
Epoch:9 Batch 700/852 Loss: 700 0.000    F1: 0.63
Epoch:9 Batch 750/852 Loss: 750 0.000    F1: 0.63
Epoch:9 Batch 800/852 Loss: 800 0.000    F1: 0.62


In [46]:
batch[0]

tensor([[ 102,  205, 5122,  ...,    0,    0,    0],
        [ 102, 1593,  165,  ...,    0,    0,    0]])

In [14]:
dvjs_input = pd.read_csv("./../gold/dvjs_data.tsv", sep="\t", index_col=0)

In [15]:
dvjs_sample_id = list(set(dvjs_input["sample"]))

In [16]:
X,mask,Y = create_bert_input(dvjs_input.fillna(""), dvjs_sample_id, 512, 0, 0)

HBox(children=(FloatProgress(value=0.0, max=338.0), HTML(value='')))




In [17]:
train_dataset = TensorDataset(X[:250],mask[:250],Y[:250])
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=8)

In [18]:
validation_data = TensorDataset(X[250:], mask[250:], Y[250:])
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=12)

In [19]:
epochs = 10
total_steps = len(train_dataloader) * epochs
optimizer = AdamW(model.parameters())
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps = total_steps)

In [20]:
loss_values = []
for epoch_i in range(0, epochs):
 ## Iteration über Epochen

    total_loss = 0 
    model.train() # Trainiert nicht, sondern ändert das Modell .train()==Trainingsmodus
    print("Epoch: "+str(epoch_i))
    for step, batch in enumerate(train_dataloader):
     ## Iteration über batches

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
          ## Batches auf GPU/CPU schieben

        model.zero_grad()
          ## Gradienten löschen für nächste Batch
        outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
          ## forward pass

        loss = outputs[0] 
        total_loss += loss.item()
          # pull out loss for printing

        loss.backward()
          ## Backward pass 
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
          ## Regularisierung

        optimizer.step() 
        scheduler.step()
          # Upadte optimizer + scheduler


        avg_train_loss = total_loss / len(train_dataloader) 
        logits = outputs[1]
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        p,t = flat_accuracy(logits, label_ids)

        if i % 10 == 0:
            print("Batch Loss: "+str(i)+" "+str(avg_train_loss)[:5]+"    F1: "+str(f1_score(t, p, average="macro"))[:4])
        i+=1

         ## Model in Validation Modus bringen

        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
         ## Validation Ergebnisse zurücksetzen, um sie neu berechnen zu können

        for batch in validation_dataloader:


            batch = tuple(t.to(device) for t in batch)


            b_input_ids, b_input_mask, b_labels = batch

            with torch.no_grad():
                outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
                logits = outputs[0] 
                logits = logits.detach().cpu().numpy()
                label_ids = b_labels.to('cpu').numpy()


                p,t = flat_accuracy(logits, label_ids)
                logits_run.append(p)
                label_run.append(t)
            
            print(np.stack(label_run).flatten())
            print(np.stack(logits_run).flatten())
            print(f1_score(np.stack(label_run).flatten(), np.stack(logits_run).flatten(), average="macro")) 

Epoch: 0
Batch Loss: 0.012816149741411209
Validation
validation acc: 0.0


  after removing the cwd from sys.path.


validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.02427833154797554
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.03488123882561922
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.04293237440288067
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.05084572732448578
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.05652045737951994
Validation
valid

validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.07488986942917109
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.08183270366862416
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.08782179979607463
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.09383618598803878
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.09916382469236851
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
valid

Batch Loss: 0.1252406199928373
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.13135204440914094
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.13765986752696335
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.14496717206202447
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.15119058336131275
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
valida

validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.1782187670469284
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.18371260911226273
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.19236509315669537
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.1991415680386126
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.2045319266617298
Validation
validation acc: 0.0
validation acc: 0.0
validati

validation acc: 0.0
Batch Loss: 0.016296360176056623
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.02161365421488881
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.028726874385029078
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.03590253507718444
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.04189553437754512
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
val

validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.0768241542391479
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.083691849373281
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.0909799188375473
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.10087434109300375
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.1053878334350884
Validation
validation

validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.13278647884726524
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.13603725307621062
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.13990466599352658
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.14738679793663323
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.15109838312491775
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
valid

validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.17975472239777446
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.18504420900717378
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.19080018252134323
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Batch Loss: 0.1978056225925684
Validation
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
validation acc: 0.0
Epoch: 9
Batch Loss: 0.007604381535202265
Validat

In [21]:
dvjs_input

Unnamed: 0,text,class,sample
0,Holger,1,0
1,Dainat,1,0
2,",",0,0
3,„,0,0
4,»,0,0
...,...,...,...
22530,",",0,341
22531,hier,0,341
22532,:,0,341
22533,256,0,341


In [22]:
grobid_input[:20]

Unnamed: 0,text,class,sample
0,dazu,0,0
1,auch,0,0
2,D.,1,0
3,Sexty,1,0
4,",",0,0
5,PoS,0,0
6,LATTICE,0,0
7,2014,0,0
8,(,0,0
9,2015,0,0
